import json
import re
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import sentence_transformers
import torch
from fire import Fire
from IPython import embed
from mteb import MTEB
from nltk import word_tokenize
from scipy.stats import spearmanr
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Pooling, WordEmbeddings
from torchtyping import TensorType as TT
from tqdm import tqdm

from all_but_the_top import AllButTheTop
from modeling import CustomPooling
from SIF import SIF
from zipfian_whitening import UniformWhitening, ZipfianWhitening

# Default location of the word-frequency file read by `load_unigram_prob_enwiki_vocab_min200`
# (one "<word> <frequency>" pair per line).
# NOTE(review): the file name contains spaces while the directory name uses underscores —
# confirm this matches the file on disk.
PATH_ENWIKI_VOCAB_MIN200 = "data/enwiki_vocab_min200/enwiki vocab min200.txt"


@dataclass
class UnigramProb:
    """
    Unigram probabilities of the vocabulary together with the ids of the unused vocabulary entries.

    The unused ids are meant to be passed to `remove_unused_words` so that the corresponding
    rows can be dropped from the embeddings before whitening.
    """

    # 1-D tensor indexed by vocabulary id; the loaders in this file build it as a
    # probability distribution (non-negative entries summing to 1).
    prob: TT["model_vocab_size"]
    # Vocabulary ids for which the source frequency data had no entry.
    unused_vocab_ids: set[int]


class WrappedTokenizer(object):
    """
    Given an instantiated WhitespaceTokenizer, wrap it as a new tokenizer class that integrates the nltk tokenizer.

    Text is first normalized with nltk's word_tokenize (lowercased, punctuation-only tokens dropped,
    hyphenated words split) and the result is handed to the original tokenizer.
    Every attribute that is not defined here (e.g. `word2idx`) is looked up on the original tokenizer.
    """

    # Matches tokens containing at least one ASCII letter or digit, i.e. tokens that are not
    # pure punctuation. Compiled once at class creation instead of on every call.
    _NOT_PUNC = re.compile(".*[A-Za-z0-9].*")

    def __init__(
        self,
        original_tokenizer: "sentence_transformers.models.tokenizer.WhitespaceTokenizer",
    ):
        self.original_tokenizer = original_tokenizer

    def tokenize(self, text: str) -> list[str]:
        """
        Override the tokenize method to use nltk_tokenize before the original whitespace tokenization.
        Returns whatever the original tokenizer's `tokenize` returns for the normalized text.
        """
        text = self.nltk_tokenize(text)
        return self.original_tokenizer.tokenize(text)

    # Modified from https://github.com/kawine/usif/blob/71ffef5b6d7295c36354136bfc6728a10bd25d32/usif.py#L107-L137
    def nltk_tokenize(self, sentence: str) -> str:
        """
        Given a sentence, tokenize it using nltk's word_tokenize function.
        Tokens without any ASCII letter or digit are dropped; the remaining ones are preprocessed
        (see `_preprocess`) and joined with single spaces, ready to be passed to the
        whitespace tokenization of the model.
        """
        tokens: list[str] = []
        for token in word_tokenize(sentence):
            if self._NOT_PUNC.match(token):
                # extend() appends in place; `tokens = tokens + ...` re-copied the list every time
                tokens.extend(self._preprocess(token))
        return " ".join(tokens)

    @staticmethod
    def _preprocess(token: str) -> list[str]:
        """
        Lowercase a token, strip surrounding quotes/punctuation, map the contraction "n't" to "not"
        and split the result on hyphens.
        """
        token = token.lower().strip("';.:()").strip('"')
        if token == "n't":
            token = "not"
        return re.split(r"[-]", token)

    def __getattr__(self, name):
        """
        Forward all other attribute lookups to the original tokenizer.

        Raises AttributeError if the wrapped tokenizer has not been set yet.
        """
        # __getattr__ only runs when normal lookup fails. If `original_tokenizer` itself is
        # missing (e.g. on the blank instance that copy/pickle create before restoring the state),
        # forwarding would re-enter __getattr__ and recurse until RecursionError.
        if name == "original_tokenizer":
            raise AttributeError(name)
        return getattr(self.original_tokenizer, name)


def load_unigram_prob(model_name: str, model_vocab_size: int) -> UnigramProb:
    """
    Load the unigram probabilities stored in data/wikipedia/<basename of model_name>/unigram_prob.json.

    The JSON file is read as a mapping from vocabulary id to probability. Ids that do not
    appear in the file get probability 0 and are reported in `unused_vocab_ids`.

    Args:
        model_name: model identifier; only its last path component is used to locate the file.
        model_vocab_size: size of the model vocabulary, i.e. the length of the returned tensor.

    Returns:
        UnigramProb with the probability tensor and the set of ids missing from the file.

    Raises:
        AssertionError: if the loaded values do not form a probability distribution.
    """
    unigram_prob_path = Path(
        f"data/wikipedia/{Path(model_name).name}/unigram_prob.json"
    )
    with unigram_prob_path.open("r", encoding="utf-8") as f:
        raw_unigram_prob = json.load(f)

    # JSON object keys are always strings: convert them to int ids once, so that both the
    # tensor indexing and the set difference below operate on integers.
    id_to_prob = {
        int(word_id): float(prob) for word_id, prob in raw_unigram_prob.items()
    }

    unigram_prob_tensor = torch.zeros(model_vocab_size)
    for word_id, prob in id_to_prob.items():
        unigram_prob_tensor[word_id] = prob

    # ensure unigram_prob_tensor is a valid probability distribution
    assert torch.allclose(unigram_prob_tensor.sum(), torch.tensor(1.0))
    assert torch.all(unigram_prob_tensor >= 0)
    assert torch.all(unigram_prob_tensor <= 1)
    assert unigram_prob_tensor.shape[0] == model_vocab_size

    # With the raw string keys this difference removed nothing and every id was reported unused.
    unused_vocab_ids = set(range(model_vocab_size)) - set(id_to_prob)
    return UnigramProb(prob=unigram_prob_tensor, unused_vocab_ids=unused_vocab_ids)


def load_unigram_prob_enwiki_vocab_min200(
    tokenizer: Union[
        sentence_transformers.models.tokenizer.WhitespaceTokenizer, WrappedTokenizer
    ],
    model_vocab_size: int,
    path: str = PATH_ENWIKI_VOCAB_MIN200,
    topk: Optional[int] = None,
) -> UnigramProb:
    """
    Load the unigram probabilities of the words in the vocabulary from the enwiki_vocab_min200.txt file.
    Only available for glove/word2vec. (Could be used for BERT-based models as well, but subword tokenization cause very sparse unigram probabilities without doing alignment)

    Each line of the file must be "<word> <frequency>". Words that are not in `tokenizer.word2idx`
    are skipped; the frequencies of the remaining words are normalized into probabilities.

    Args:
        tokenizer: tokenizer exposing a `word2idx` mapping from word to vocabulary id.
        model_vocab_size: size of the model vocabulary, i.e. the length of the returned tensor.
        path: location of the frequency file.
        topk: if given, reading stops after the first `topk` lines of the file.

    Returns:
        UnigramProb with the probability tensor and the ids that received no frequency.

    Raises:
        AssertionError: if a line is malformed, or if `topk` is given and the lines read
            did not yield exactly `topk` distinct in-vocabulary ids.
        ValueError: if no word of the file is in the vocabulary (nothing to normalize).
    """
    # Resolve the mapping once; with a WrappedTokenizer each attribute access is forwarded.
    word2idx = tokenizer.word2idx

    frequency_dict: Dict[int, int] = {}
    # load the frequency of the words in the vocabulary
    with open(path, "r") as f:
        for count, line in enumerate(f):
            word_and_freq = line.rstrip().split(" ")
            assert (
                len(word_and_freq) == 2
            )  # ensuring that the line has only two elements, otherwise the file is not formatted correctly or the line is corrupted
            word, freq = word_and_freq
            freq = int(freq)
            if word in word2idx:
                frequency_dict[word2idx[word]] = freq
            if topk is not None and (count + 1) >= topk:
                break

    # create a tensor of the unigram probabilities
    unigram_prob = torch.zeros(model_vocab_size)
    for word_id, freq in frequency_dict.items():
        unigram_prob[word_id] = freq

    # normalize the unigram probabilities
    total = unigram_prob.sum()
    if total == 0:
        # Dividing by zero would silently produce NaNs; fail with an explicit message instead.
        raise ValueError(
            f"No word from {path} was found in the tokenizer vocabulary; cannot build unigram probabilities."
        )
    unigram_prob = unigram_prob / total
    assert torch.allclose(unigram_prob.sum(), torch.tensor(1.0))
    assert torch.all(unigram_prob >= 0)
    assert torch.all(unigram_prob <= 1)
    assert unigram_prob.shape[0] == model_vocab_size

    # check the top k most frequent words
    assert topk is None or len(frequency_dict) == topk

    # Cover every id in [0, model_vocab_size), as load_unigram_prob does. The previous
    # `range(model_vocab_size - 1)` never reported the last id as unused.
    unused_vocab_ids = set(range(model_vocab_size)) - set(frequency_dict)
    return UnigramProb(prob=unigram_prob, unused_vocab_ids=unused_vocab_ids)


def remove_unused_words(
    unused_vocab_ids: set[int],
    W: "TT['num_words', 'hidden_dim']",
    p: "Optional[TT['num_words']]" = None,
) -> "tuple[TT['num_words', 'hidden_dim'], Optional[TT['num_words']]]":
    """
    Remove the unused words from the input embeddings and unigram probabilities.

    Rows of `W` (and entries of `p`, when given) whose index is in `unused_vocab_ids` are dropped.
    The remaining probabilities are renormalized so that they sum to 1.
    The inputs are left untouched: mask indexing returns new tensors.

    Args:
        unused_vocab_ids: row indices to drop.
        W: embedding matrix with one row per word.
        p: optional unigram probabilities, one entry per row of `W`.

    Returns:
        The filtered embeddings and the filtered, renormalized probabilities (None if `p` is None).
    """
    keep_mask = torch.ones(W.shape[0], dtype=torch.bool)
    # An explicit long index tensor also covers the empty-set case without relying on
    # dtype inference for an empty Python list.
    drop_ids = torch.tensor(sorted(unused_vocab_ids), dtype=torch.long)
    keep_mask[drop_ids] = False

    W = W[keep_mask]
    if p is None:
        return W, None

    p = p[keep_mask]
    p = p / p.sum()
    assert W.shape[0] == p.shape[0]
    # Floating-point renormalization rarely sums to exactly 1, so compare with a tolerance
    # (an exact `== 1` check can fail spuriously).
    p_total = p.sum()
    assert torch.allclose(p_total, torch.ones_like(p_total))
    return W, p


def load_word2vec_model(model_name: str) -> SentenceTransformer:
    """
    Build a SentenceTransformer from static word vectors stored in the text file `<model_name>.txt`.

    The model is a WordEmbeddings module followed by a Pooling module created with its
    default settings for the embedding dimension.
    """
    embeddings_path = model_name + ".txt"
    word_embeddings = WordEmbeddings.from_text_file(embeddings_path)
    embedding_dim = word_embeddings.get_word_embedding_dimension()
    modules = [word_embeddings, Pooling(embedding_dim)]
    return SentenceTransformer(modules=modules)
