import json
from typing import Any, List

import torch
from tqdm import tqdm

def prepare_encodings(batch, tokenizer, length):
    """
    Create encodings of a text batch truncated to a certain length.

    Args:
        batch (List[str]): Batch of text to encode
        tokenizer: Llama tokenizer. Its ``encode`` method is called once on the
            whole batch with the positional flags ``(True, False)``. It is
            expected to return one token list per input text.
        length (int): Prompt length. Each encoding keeps at most its first
            ``length`` tokens.

    Return:
        List of token lists, one per entry returned by ``tokenizer.encode``,
        each truncated to at most ``length`` tokens.
    """
    # NOTE(review): the positional flags presumably mean bos=True, eos=False
    # (Llama tokenizer convention) -- confirm against the tokenizer in use.
    tokens = tokenizer.encode(batch, True, False)
    return [encoded_text[:length] for encoded_text in tokens]

def evaluate_mixed_losses(data: List[List[str]],
                          model: Any,
                          tokenizer: Any,
                          prompt_len: int,
                          max_gen_len: int,
                          alpha: float,
                          temp: float,
                          n_drafts: int,
                          n_token_consider: int,
                          n_token_sample: int,
                          mixing_method: str,
                          smoothing: str,
                          debug: bool = False,
                          bsz=16,
                          i_weights=None,
                          i_length=None,
                          ngrams=None,
                          sample_beams=False,
                          sample_tokens=False,
                          get_time=False,
                          penalty=200,
                          marker=True):
    """
    Evaluate perplexity for mixed embeddings.

    Prompts are processed in batches of ``bsz``. Each batch is tokenized and
    truncated to ``prompt_len`` tokens, then passed to ``model.beam_generate``.
    The ``n_drafts`` candidates with the lowest score are kept per prompt.

    Args:
        data (List[List[String]]): Input data. Each slice ``data[i:i+bsz]`` is
            passed to ``prepare_encodings`` as a batch of texts.
            NOTE(review): that helper documents its batch as ``List[str]``, so
            the ``List[List[str]]`` annotation may be inaccurate -- confirm.
        model: Llama model exposing ``beam_generate``
        tokenizer: Llama tokenizer
        prompt_len (int): Number of tokens in starting prompt
        max_gen_len (int): Maximum numbers of tokens to generate
        alpha (float): Alpha value, forwarded to ``beam_generate``
        temp (float): Temperature
        n_drafts (int): Number of drafts kept per prompt
        n_token_consider (int): Forwarded to ``beam_generate``
        n_token_sample (int): Forwarded to ``beam_generate``
        mixing_method (str): Mixing method
        smoothing (str): Smoothing strategy
        debug (bool): Control whether to print debugging information (default False)
        bsz (int): Batch size (default = 16)
        i_weights (List[float], Optional): List of weights corresponding to each ngram model
        i_length (List[int], Optional): List of ngram models to consider (1 for bigram, 2 for trigram, etc.)
        ngrams (Tuple, Optional): Tuple of ngram models
        sample_beams (bool, Optional): Whether to sample beams
        sample_tokens (bool): Whether to sample tokens
        get_time (bool): If True, ``beam_generate`` is expected to also return a
            timing value. These values are summed across batches and returned.
        penalty (int): Forwarded to ``beam_generate`` (default 200). Its
            semantics are defined there.
        marker (bool): Progress bar toggle

    Return:
        sequences (torch.Tensor): Generated sequences (n_prompts, n_drafts, prompt_len+max_gen_len)
        ppl (torch.Tensor): Perplexity (n_prompts, n_drafts), lowest first per prompt
        ovr_time: Only when ``get_time`` is True. The sum of the per-batch
            timing values, or None if ``data`` is empty.
    """
    if debug:
        print("### DEBUG MODE ON ###")

    batch_starts = range(0, len(data), bsz)
    if marker:
        batch_starts = tqdm(batch_starts)
    sequences = torch.zeros(len(data), n_drafts, prompt_len + max_gen_len, dtype=torch.long)
    ppl = torch.zeros(len(data), n_drafts)
    ovr_time = None
    for b_start in batch_starts:
        b_end = b_start + bsz
        # preprocessing
        batch = data[b_start:b_end]
        truncated_tokens = prepare_encodings(batch, tokenizer, prompt_len)

        # inference
        outputs = model.beam_generate(prompt_tokens=truncated_tokens,
                                      max_gen_len=max_gen_len,
                                      mixing_method=mixing_method,
                                      smoothing=smoothing,
                                      n_token_consider=n_token_consider,
                                      n_token_sample=n_token_sample,
                                      alpha=alpha,
                                      temp=temp,
                                      n_drafts=n_drafts,
                                      debug=debug,
                                      i_weights=i_weights,
                                      i_length=i_length,
                                      ngrams=ngrams,
                                      sample_beams=sample_beams,
                                      sample_tokens=sample_tokens,
                                      get_time=get_time,
                                      penalty=penalty)
        # beam_generate returns two (seq, ppl) pairs, presumably alive and
        # finished beams, plus a timing value when get_time is set.
        if get_time:
            (alive_seq, alive_ppl), (fin_seq, fin_ppl), ngram_time = outputs
            ovr_time = ngram_time if ovr_time is None else ovr_time + ngram_time
        else:
            (alive_seq, alive_ppl), (fin_seq, fin_ppl) = outputs

        # Pool both candidate sets along the draft axis, then keep the
        # n_drafts lowest-scoring candidates (largest=False) per prompt.
        combined_ppl = torch.cat([alive_ppl, fin_ppl], dim=1)
        combined_seq = torch.cat([alive_seq, fin_seq], dim=1)
        top_ppl, top_idx = torch.topk(combined_ppl, n_drafts, dim=-1, largest=False)
        # top_idx is (batch, n_drafts). The trailing singleton axis lets
        # take_along_dim broadcast the draft index over every token position.
        top_seq = torch.take_along_dim(combined_seq, top_idx.unsqueeze(dim=2), dim=1)
        ppl[b_start:b_end, :] = top_ppl
        sequences[b_start:b_end, :, :] = top_seq

    if get_time:
        return sequences, ppl, ovr_time
    return sequences, ppl

def evaluate_nucleus_losses(data,
                            model,
                            tokenizer,
                            prompt_len,
                            max_gen_len,
                            temp,
                            bsz=16,
                            marker=True,
                            top_p=0.9):
    """
    Evaluate perplexity for nucleus sampling.

    Prompts are processed in batches of ``bsz``. Each batch is tokenized and
    truncated to ``prompt_len`` tokens, then passed to ``model.generate`` with
    nucleus (top-p) sampling.

    Args:
        data (List[List[String]]): Input data. Each slice ``data[i:i+bsz]`` is
            passed to ``prepare_encodings`` as a batch of texts.
        model (Any): Model exposing ``generate``
        tokenizer (Any): Llama tokenizer
        prompt_len (int): Number of tokens in starting prompt
        max_gen_len (int): Maximum numbers of tokens to generate
        temp (float): Temperature
        bsz (int): Batch size (default = 16)
        marker (bool): Progress bar toggle
        top_p (float): Nucleus sampling threshold passed to ``model.generate``
            (default 0.9, matching the previously hard-coded value)

    Return:
        sequences (torch.Tensor): Generated sequences (n_prompts, prompt_len+max_gen_len)
        ppl (torch.Tensor): Perplexity (n_prompts)
    """
    batch_starts = range(0, len(data), bsz)
    if marker:
        batch_starts = tqdm(batch_starts)
    sequences = torch.zeros(len(data), prompt_len + max_gen_len, dtype=torch.long)
    ppl = torch.zeros(len(data), dtype=torch.float32)
    for b_start in batch_starts:
        b_end = b_start + bsz
        # preprocessing
        batch = data[b_start:b_end]
        truncated_tokens = prepare_encodings(batch, tokenizer, prompt_len)

        # inference
        curr_seq, curr_ppl = model.generate(prompt_tokens=truncated_tokens,
                                            max_gen_len=max_gen_len,
                                            temperature=temp,
                                            top_p=top_p,
                                            logprobs=True)
        sequences[b_start:b_end, :] = curr_seq
        ppl[b_start:b_end] = curr_ppl
    return sequences, ppl
        
            
def parse_params(param_path):
    """
    Load evaluation hyper-parameters from a JSON file.

    Args:
        param_path (str | os.PathLike): Path to a JSON file whose top-level
            object contains at least the keys listed below. Extra keys are
            ignored.

    Return:
        Tuple ``(alpha, temp, prompt_len, mixing_method, smoothing,
        sample_tokens, sample_beams, i_weights, i_length, ckpt_path)``, in that
        order.

    Raises:
        KeyError: If any of the required keys is missing.
    """
    # JSON text is UTF-8. Be explicit so decoding does not depend on the locale.
    with open(param_path, "r", encoding="utf-8") as f:
        p = json.load(f)
    return (
        p["alpha"],
        p["temp"],
        p["prompt_len"],
        p["mixing_method"],
        p["smoothing"],
        p["sample_tokens"],
        p["sample_beams"],
        p["i_weights"],
        p["i_length"],
        p["ckpt_path"],
    )