import json
import os
import re
from copy import deepcopy

import faiss
import matplotlib.pyplot as plt
import numpy as np
from get_similar_data import *
from langchain_community.embeddings.huggingface import (
    HuggingFaceEmbeddings,
)
from sklearn.decomposition import PCA


def cosine_similarity(a, b):
    """Return the cosine similarity of vectors *a* and *b*.

    The result is the dot product of the two inputs divided by the product
    of their norms as computed by ``np.linalg.norm``. No special handling
    is applied when either norm is zero.
    """
    inner_product = np.dot(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return inner_product / norm_product


def process_webarena_data(source_path, component):
    """Collect one field for every user message in a WebArena JSON dump.

    Args:
        source_path: Path to a JSON file mapping an example key to a dict
            that has an ``"intent"`` entry and a ``"messages"`` list of
            dicts.
        component: Which field to collect for each user message. One of
            ``"objective"`` (the example's intent), ``"observation"`` (the
            text from the first ``observation:`` to the end of the user
            message), ``"examples"`` (the example key) or
            ``"index_in_messages"`` (position of the message in the list).

    Returns:
        A list with one entry per message containing a ``"user"`` key, in
        file order.

    Raises:
        AttributeError: If a user message contains no ``observation:``
            section (the regex search returns ``None``).
        KeyError: If ``component`` is not one of the fields listed above.
    """
    # Context manager so the handle is closed deterministically; JSON is
    # read as UTF-8 rather than whatever the locale default happens to be.
    with open(source_path, "r", encoding="utf-8") as source_file:
        source_dict = json.load(source_file)

    wa_res = []
    for examples_key, examples in source_dict.items():
        for idx, message in enumerate(examples["messages"]):
            if "user" not in message:
                continue
            # DOTALL lets ".*" run across newlines, so the capture spans
            # from the first "observation:" to the end of the message.
            observation = re.search(
                r"(observation:.*)", message["user"], re.DOTALL
            ).group()
            extracted_parts = {
                "objective": examples["intent"],
                "observation": observation,
                "examples": examples_key,
                "index_in_messages": idx,
            }
            # Copy so callers never share mutable JSON values (e.g. a
            # non-string intent) with each other.
            wa_res.append(deepcopy(extracted_parts[component]))
    return wa_res


def get_embeddings(data: dict[str, list[str]], info: str) -> list[np.ndarray]:
    """Embed every dataset in *data* and return one matrix per dataset.

    Embeddings are produced with the ``sentence-transformers/all-mpnet-base-v2``
    model and persisted as FAISS indexes under
    ``./data_vis/.tmp/faiss/{info}/{key}/``. For keys containing
    ``"webarena"`` or ``"mind2web"`` an existing index directory is reused
    instead of being recomputed; every other key is always re-embedded.

    Args:
        data: Mapping from dataset name to the list of texts to embed.
        info: Sub-directory name used to separate caches of different runs.

    Returns:
        A list of arrays, one per key of *data* in iteration order, each of
        shape ``(len(data[key]), index.d)``. The arrays are views created by
        ``faiss.rev_swig_ptr`` over the index storage, not copies.

    Note:
        ``Store`` is not defined in this file; presumably it is provided by
        the star import from ``get_similar_data`` -- verify against that
        module.
    """
    # Imported locally so this function does not rely on the star import
    # above happening to expose ``torch``.
    import torch

    model_kwargs = {"device": "cuda:0"} if torch.cuda.is_available() else {}
    embed_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs=model_kwargs,
    )
    store = Store(embed_model)

    cache_dirs = {k: f"./data_vis/.tmp/faiss/{info}/{k}" for k in data}

    # Calculate and save embeddings.
    for k, texts in data.items():
        # ``and`` binds tighter than ``or``: without the explicit grouping a
        # "webarena" key was skipped even when no cached index existed, and
        # the existence check below then failed.
        reuse_cache = "webarena" in k or "mind2web" in k
        if reuse_cache and os.path.exists(f"{cache_dirs[k]}/"):
            continue
        store.load_doc(k, texts)
        store.docs[k].save_local(f"{cache_dirs[k]}/")

    # Load embeddings (freshly written or previously cached).
    for k, cache_dir in cache_dirs.items():
        assert os.path.exists(cache_dir)
        store.load_embeddings(k, cache_dir)

    embedding_list = []
    for k, texts in data.items():
        index = store.docs[k].index
        # Take the vector width from the index itself rather than
        # hard-coding the model's 768 dimensions.
        n_rows, dim = len(texts), index.d
        embedding_list.append(
            faiss.rev_swig_ptr(index.get_xb(), n_rows * dim).reshape(
                n_rows, dim
            )
        )
    return embedding_list


def plot_embedding(data, info, tag: str):
    """Plot a 2-D PCA projection of the embeddings of every dataset.

    All embeddings returned by ``get_embeddings(data, info)`` are stacked,
    projected onto two principal components, and drawn as one scatter
    series per dataset. The figure is written as SVG and PDF to
    ``./data_vis/.tmp/plots/{tag}_tsne_{info}.*``.

    Args:
        data: Mapping from dataset name to its list of texts; the names
            become the legend labels.
        info: Cache/run identifier, forwarded to ``get_embeddings`` and
            embedded in the output file names.
        tag: Prefix for the output file names.

    Note:
        The file names contain ``tsne`` although the projection is PCA; the
        names are kept unchanged so existing consumers still find the files.
        At most the first six datasets are drawn.
    """
    embedding_list = get_embeddings(data, info)
    pca_result = PCA(n_components=2).fit_transform(np.vstack(embedding_list))

    # Row range occupied by each dataset inside the stacked matrix, computed
    # once instead of re-summing the shapes for every slice.
    row_ranges = []
    row_start = 0
    for emb in embedding_list:
        row_stop = row_start + emb.shape[0]
        row_ranges.append((row_start, row_stop))
        row_start = row_stop

    labels = list(data)
    colors = ["r", "r", "r", "r", "cyan", "b", "y", "orange"]
    # Dataset 4 is drawn after dataset 5 (presumably so it renders on top).
    # Indices beyond the number of datasets are dropped so fewer than six
    # datasets no longer raise IndexError.
    draw_order = [i for i in (0, 1, 2, 3, 5, 4) if i < len(embedding_list)]

    plt.figure(figsize=(8, 6))
    for i in draw_order:
        row_start, row_stop = row_ranges[i]
        plt.scatter(
            pca_result[row_start:row_stop, 0],
            pca_result[row_start:row_stop, 1],
            color=colors[i],
            label=labels[i],
            marker=".",
        )
    plt.legend()

    # savefig does not create missing directories.
    os.makedirs("./data_vis/.tmp/plots", exist_ok=True)
    plt.savefig(f"./data_vis/.tmp/plots/{tag}_tsne_{info}.svg")
    plt.savefig(
        f"./data_vis/.tmp/plots/{tag}_tsne_{info}.pdf", bbox_inches="tight"
    )
    # plt.show()
