# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
import sys
import time
from pathlib import Path
from typing import Any, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

from llama.model import ModelArgs
from llama.mixed_model import MixedTransformer
from llama.tokenizer import Tokenizer
from llama.beam import Beam
from llama.utils import *
from ngrams.ngram_models import make_models

class MixedLlama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        device = None,
        model_parallel_size: Optional[int] = None,
        seed: int = 1,
    ) -> "MixedLlama":
        """
        Build a MixedLlama instance by initializing and loading a pre-trained model.

        Args:
            ckpt_dir (str): Path to the directory containing the ``*.pth`` checkpoint
                shards and ``params.json``.
            tokenizer_path (str): Path to the tokenizer file.
            max_seq_len (int): Maximum sequence length for input text.
            max_batch_size (int): Maximum batch size for inference.
            device: Device the model is moved to. When None, the CUDA device indexed
                by the ``LOCAL_RANK`` environment variable is selected and used.
            model_parallel_size (Optional[int], optional): Number of model parallel processes.
                If not provided, it is taken from the ``WORLD_SIZE`` environment variable
                when model parallelism is initialized here, or from the already
                initialized model-parallel group otherwise. Defaults to None.
            seed (int): Seed passed to ``torch.manual_seed``. Defaults to 1.

        Returns:
            MixedLlama: An instance wrapping the loaded model and tokenizer.

        Raises:
            AssertionError: If there are no checkpoint files in the specified directory,
                or if the model parallel size does not match the number of checkpoint files.

        Note:
            This method initializes the distributed process group (NCCL backend) and
            model parallelism if they are not initialized yet, redirects stdout to
            ``os.devnull`` on local ranks greater than 0, and sets the default tensor
            type to ``torch.cuda.HalfTensor``.

        """
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)
        elif model_parallel_size is None:
            # Model parallelism was set up before this call (e.g. by an earlier
            # build); read the size back so the shard-count check below compares
            # against a real number instead of None.
            from fairscale.nn.model_parallel.initialize import (
                get_model_parallel_world_size,
            )

            model_parallel_size = get_model_parallel_world_size()

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        print(local_rank)
        if device is None:
            torch.cuda.set_device(local_rank)
            device = torch.cuda.current_device()
        # seed must be the same in all processes
        torch.manual_seed(seed)

        # Only local rank 0 keeps printing; other ranks are silenced.
        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        # Each model-parallel rank loads its own shard (sorted by file name).
        ckpt_path = checkpoints[get_model_parallel_rank()]
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.load(f)

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        # The vocabulary size always follows the tokenizer, not params.json.
        model_args.vocab_size = tokenizer.n_words
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = MixedTransformer(model_args)
        # strict=False: keys missing from or unexpected in the checkpoint do not raise.
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")
        return MixedLlama(model, tokenizer, device)

    def __init__(self, model: MixedTransformer, tokenizer: Tokenizer, device):
        print(device)
        self.model = model.to(device).eval()
        self.tokenizer = tokenizer
        self.device = device
        
    @torch.inference_mode()
    def beam_generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        mixing_method: str,
        smoothing: str,
        n_token_consider: int,
        n_token_sample: int,
        alpha: int, # weight on bigram probs
        temp: int,
        n_drafts: int = 1, # number of beams
        debug: bool = False,
        verbose: bool = False,
        i_weights = None,
        i_length = None,
        ngrams = None,
        sample_beams: bool = False,
        sample_tokens: bool = False,
        get_time: bool = False,
        penalty = 200
    ):
        """
        Run multi-sequence generation using mixed embeddings.
        Args:
            prompt_tokens (List[List[int]]): Initial tokenized prompts
            max_gen_len (int): Max generation length
            mixing_method (str): Mixing method
            smoothing (str): Smoothing method
            ngram_length (int): Length of ngrams for smoothing
            n_token_consider (int): Number of tokens to normalize for before running beam search
            n_token_sample (int): Number of tokens to consider from n_token_consider
            alpha (float): Weight for N-Gram probabilities
            temp (float): Temperature
            n_drafts (int): Number of drafts
            debug (bool): Whether to print outputs
            verbose (bool): Whether to store and return model hidden states
            i_weights (list): List of weights corresponding to ngrams in i_length
            i_length (list): Ngram lengths to use in interpolation
            sample_tokens (bool): Whether to sample next tokens passed into the beam search algorithm
            sample_beams (bool): Whether to sample beams
            get_time (bool): Return time spent
            penalty (int): Penalty on uninterpolated drafts
        Returns:
            (alive_seq, alive_ppl), (fin_seq, fin_ppl)
        """
        # check batch size and prompt lengths
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert min_prompt_len == max_prompt_len, "Prompt lenghts must be equal"
        prompt_len = min_prompt_len
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
        pad_id = self.tokenizer.pad_id
        
        # initialize token tensor
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=self.device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, :len(t)] = torch.tensor(t, dtype=torch.long, device=self.device)
        
        # if no generation possible
        if min_prompt_len == total_len:
            raise RuntimeError("no generation possible")

        ### INTIALIZATION ###
        initial_tokens = tokens.unsqueeze(1).repeat(1, n_drafts, 1)
        beam_search = Beam(initial_tokens, 
                           tokenizer=self.tokenizer,
                           vocab_size=params.vocab_size,
                           mixing_method=mixing_method,
                           smoothing=smoothing,
                           alpha=alpha,
                           verbose=debug,
                           i_weights=i_weights,
                           i_length=i_length,
                           ngrams=ngrams,
                           sample_beams=sample_beams,
                           sample_tokens=sample_tokens,
                           get_time=get_time,
                           penalty=penalty)
        unseen_first = torch.ones(bsz) # 1 if still parsing prompt
        token_weights = torch.zeros(bsz, self.model.vocab_size)
        if verbose:
            state_list = []
        prev_pos = 0
        ### INFERENCE ###
        for cur_pos in range(min_prompt_len, total_len):
            input_text_mask = tokens != pad_id
            
            # Model step
            if cur_pos == min_prompt_len:
                token_weights = None
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], 
                                        start_pos=prev_pos, 
                                        token_weights=token_weights, 
                                        verbose=verbose)
            if verbose:
                logits, states = logits
            
            # Softmax
            if temp > 0:
                probs = torch.softmax(logits[:, -1] / temp, dim=-1)
            else:
                raise RuntimeError("Temperature must be greater than 0 while mixing")
            if verbose:
                states["end_probs"] = probs
                state_list.append(states)

            # Flag prompts on first generation
            is_first = torch.mul(tokens[:, cur_pos] == pad_id, unseen_first)
            unseen_first[is_first.nonzero(as_tuple=True)[0]] = 0
            
            # Flag prompts not yet generating
            still_prompt = input_text_mask[:, cur_pos]
            
            # Beam pass
            token_weights = beam_search(probs, still_prompt, is_first, cur_pos, n_token_consider, n_token_sample, use_mix=True)
            
            # Do not mix for prompts not yet generating
            keep_idx = input_text_mask[:, cur_pos].ravel().nonzero()
            keep_token_weights = torch.zeros_like(token_weights)
            keep_token_weights[keep_idx, tokens[keep_idx, cur_pos]] = 1
            token_weights = torch.where(input_text_mask[:, cur_pos].unsqueeze(1).expand(-1, self.model.vocab_size), 
                                        keep_token_weights, token_weights)
            prev_pos = cur_pos
            
        ### RETURN ###
        results = beam_search.return_results(prompt_len)
        if verbose:
            return results, state_list
        else:
            return results