from typing import Dict, Callable

try:
    import transformers
except ValueError:
    import transformers
import logging
from typing import Optional
import torch


def tokenizer_and_embedding_resize(
    special_tokens_dict: Dict[str, str],
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    other_tokens_dict: Optional[Dict[str, str]] = None,
    other_tokens_are_special_tokens: bool = True,
    embed_fn: str = "smart",
):
    """Dispatch to one of the tokenizer/embedding resize strategies.

    :param special_tokens_dict: forwarded unchanged to the selected strategy.
    :param tokenizer: forwarded unchanged to the selected strategy.
    :param model: forwarded unchanged to the selected strategy.
    :param other_tokens_dict: forwarded unchanged to the selected strategy.
    :param other_tokens_are_special_tokens: forwarded unchanged to the
        selected strategy.
    :param embed_fn: name of the strategy to run; one of "vipi", "smart" or
        "hf". Any other value raises KeyError from the lookup.
    :return: whatever the selected strategy returns.
    """
    # Registry mapping each strategy name to its implementation.
    strategies: Dict[str, Callable] = dict(
        vipi=vipi_tokenizer_and_embedding_resize,
        smart=smart_tokenizer_and_embedding_resize,
        hf=hf_init_tokenizer_and_embedding_resize,
    )
    resize_impl = strategies[embed_fn]

    # Every strategy shares the same keyword signature.
    forwarded = {
        "special_tokens_dict": special_tokens_dict,
        "tokenizer": tokenizer,
        "model": model,
        "other_tokens_dict": other_tokens_dict,
        "other_tokens_are_special_tokens": other_tokens_are_special_tokens,
    }
    return resize_impl(**forwarded)


def hf_init_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict[str, str],
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    other_tokens_dict: Optional[Dict[str, str]] = None,
    other_tokens_are_special_tokens: bool = True,
):
    """Use the HF default method.

    The special tokens are added via smart_tokenizer_and_embedding_resize();
    the 'other' tokens are added with tokenizer.add_tokens() and the model's
    embeddings are resized with model.resize_token_embeddings(), with no
    further initialization of the new rows done here.

    Note that this is the method used by e.g. LLaVA; see https://github.com/haotian-liu/LLaVA/blob/7775b12d6b20cd69089be7a18ea02615a59621cd/llava/model/builder.py#L134

    :param special_tokens_dict: special tokens passed to
        smart_tokenizer_and_embedding_resize().
    :param tokenizer: the tokenizer to modify.
    :param model: the model whose embedding matrix is resized to len(tokenizer).
    :param other_tokens_dict: optional extra tokens; only the values are added.
        None or an empty dict means there is nothing extra to add.
    :param other_tokens_are_special_tokens: passed as `special_tokens` to
        tokenizer.add_tokens().
    """
    # use smart_tokenizer_and_embedding_resize for the special tokens
    smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model)

    # other_tokens_dict defaults to None, so guard before reading its values.
    if other_tokens_dict:
        tokenizer.add_tokens(
            list(other_tokens_dict.values()),
            special_tokens=other_tokens_are_special_tokens,
        )
    model.resize_token_embeddings(len(tokenizer))


def vipi_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict[str, str],
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    other_tokens_dict: Optional[Dict[str, str]] = None,
    other_tokens_are_special_tokens: bool = True,
):
    """A form of the VIPI tokenizer, applied only to 'other tokens dict'.

    Special tokens are added via smart_tokenizer_and_embedding_resize(). Each
    genuinely new 'other' token then has its input (and, when the model has
    one, output) embedding row set to the mean of the rows of the sub-tokens
    the tokenizer split it into *before* it was added to the vocabulary.

    :param special_tokens_dict: special tokens passed to
        smart_tokenizer_and_embedding_resize().
    :param tokenizer: the tokenizer to modify.
    :param model: the model whose embedding matrix is resized to len(tokenizer).
    :param other_tokens_dict: optional extra tokens; only the values are added.
        None is treated like an empty dict.
    :param other_tokens_are_special_tokens: passed as `special_tokens` to
        tokenizer.add_tokens().
    :raises ValueError: if the pre-existing sub-token ids of a new token are
        empty or do not decode back to exactly that token.
    """
    # use smart_tokenizer_and_embedding_resize for the special tokens
    smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model)

    # other_tokens_dict defaults to None; normalize so the rest can iterate it.
    if other_tokens_dict is None:
        other_tokens_dict = {}

    # for the 'other' tokens, use VIPI initialization.
    # Fetch the vocab once instead of once per candidate token, and drop
    # duplicate values while keeping their first-seen order.
    vocab = tokenizer.get_vocab()
    new_tokens = list(
        dict.fromkeys(x for x in other_tokens_dict.values() if x not in vocab)
    )
    # Ids must be computed before add_tokens(), while each token still splits
    # into pre-existing sub-tokens.
    new_tokens_prev_ids = [
        tokenizer(x, add_special_tokens=False).input_ids for x in new_tokens
    ]
    logging.warning(
        f"adding tokens {other_tokens_dict} to vocab (as special tokens={other_tokens_are_special_tokens}"
    )
    num_new_tokens = tokenizer.add_tokens(
        list(other_tokens_dict.values()), special_tokens=other_tokens_are_special_tokens
    )

    logging.info(f"adding {num_new_tokens} to vocab")
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        # get_output_embeddings() may return None; then only inputs are set.
        output_layer = model.get_output_embeddings()
        output_embeddings = (
            output_layer.weight.data if output_layer is not None else None
        )

        for token, prev_token_ids in zip(new_tokens, new_tokens_prev_ids):
            new_token_id = tokenizer.convert_tokens_to_ids(token)

            # Sanity check that the prev_token_ids exactly reconstruct the
            # token. Raised explicitly (not `assert`) so it survives `-O`.
            if not prev_token_ids or tokenizer.decode(prev_token_ids) != token:
                raise ValueError(
                    f"sub-token ids {prev_token_ids} do not reconstruct token {token!r}"
                )

            input_embeddings[new_token_id, :] = torch.stack(
                [input_embeddings[i] for i in prev_token_ids]
            ).mean(dim=0)
            if output_embeddings is not None:
                output_embeddings[new_token_id, :] = torch.stack(
                    [output_embeddings[i] for i in prev_token_ids]
                ).mean(dim=0)
    logging.debug(f"len(tokenizer) after resize is {len(tokenizer)}")


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict[str, str],
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    other_tokens_dict: Optional[Dict[str, str]] = None,
    other_tokens_are_special_tokens: bool = True,
):
    """Resize tokenizer and embedding matrix, adding both special_tokens_dict and other_tokens_dict.

    The rows for the newly added tokens are initialized to the mean of the
    rows that precede them, for the input embeddings and, when the model has
    them, the output embeddings.

    :param special_tokens_dict: special tokens that can be added with tokenizer.add_special_tokens().
        Typically this only includes tokens like bos_token, eos_token, pad_token.
        See transformers.tokenization_utils method .add_special_tokens() for more info.
    :param other_tokens_dict: tokens that cannot be added with tokenizer.add_special_tokens().
        This is where most tokens should be added. Only the values are used.
    :param other_tokens_are_special_tokens: passed as `special_tokens` to
        tokenizer.add_tokens() for the tokens in other_tokens_dict.
    :param tokenizer: the tokenizer to modify.
    :param model: the model to be used with the tokenizer; its embedding matrix will be resized accordingly.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    logging.debug(f"len(tokenizer) before resize is {len(tokenizer)}")
    logging.warning(f"adding special tokens {special_tokens_dict} to vocab")
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    if other_tokens_dict:
        logging.warning(
            f"adding tokens {other_tokens_dict} to vocab (as special tokens={other_tokens_are_special_tokens}"
        )
        num_new_tokens += tokenizer.add_tokens(
            list(other_tokens_dict.values()),
            special_tokens=other_tokens_are_special_tokens,
        )
    logging.info(f"adding {num_new_tokens} to vocab")
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        # get_output_embeddings() may return None (no output head); skip it
        # instead of failing on `.weight`.
        layers = (model.get_input_embeddings(), model.get_output_embeddings())
        weights = [layer.weight.data for layer in layers if layer is not None]

        # Compute every average before writing any row, so the result does not
        # depend on whether the two matrices share storage.
        averages = [
            weight[:-num_new_tokens].mean(dim=0, keepdim=True) for weight in weights
        ]
        for weight, average in zip(weights, averages):
            weight[-num_new_tokens:] = average

    logging.debug(f"len(tokenizer) after resize is {len(tokenizer)}")


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Gather the model's state dict and write it to disk through the trainer.

    The state dict is always collected from trainer.model; the save itself only
    happens when trainer.args.should_save is truthy, and every value is moved
    to CPU (via .cpu()) before being handed to trainer._save().

    via https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py

    :param trainer: trainer providing .model, .args.should_save and ._save().
    :param output_dir: directory passed through to trainer._save().
    """
    # Collected before the should_save check, exactly as in the upstream code.
    weights = trainer.model.state_dict()
    if not trainer.args.should_save:
        return

    weights_on_cpu = {}
    for name, tensor in weights.items():
        weights_on_cpu[name] = tensor.cpu()
    # Drop the reference to the original state dict before saving.
    del weights
    trainer._save(output_dir, state_dict=weights_on_cpu)  # noqa
