import os
from datasets import load_dataset
import torch
import json
from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM
from tqdm import tqdm
import numpy as np
import random
import argparse
# from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
import torch.distributed as dist
import torch.multiprocessing as mp

def parse_args(args=None):
    """Parse command-line options for the prediction script.

    Args:
        args: Optional list of argument strings; ``None`` parses ``sys.argv``.

    Returns:
        argparse.Namespace with ``model`` (str) and ``e`` (bool: evaluate on
        LongBench-E instead of full LongBench).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default="llama2-7b-chat-4k-lm-infinite")
    # `--e` defaults to True, so on its own it can never be switched off;
    # `--no-e` writes False into the same destination to run full LongBench.
    parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E", default=True)
    parser.add_argument('--no-e', dest='e', action='store_false',
                        help="Evaluate on full LongBench instead of LongBench-E")
    return parser.parse_args(args)

# This is the customized building prompt for chat models
def build_chat(tokenizer, prompt, model_name):
    """Apply the chat template for *model_name* to *prompt*.

    Only model names containing "llama2" are wrapped (``<s>[INST]...[/INST]``);
    every other model gets the prompt back untouched. ``tokenizer`` is part of
    the call signature but is not used.
    """
    if "llama2" not in model_name:
        return prompt
    # NOTE(review): if the tokenizer also prepends its own BOS token, the
    # literal "<s>" here may yield a doubled BOS — confirm for each tokenizer.
    return f"<s>[INST]{prompt}[/INST]"

def post_process(response, model_name):
    """Remove model-specific artifacts from a decoded generation.

    - names containing "xgen": trim surrounding whitespace, then delete every
      "Assistant:" occurrence;
    - names containing "internlm": keep only the text before the first "<eoa>";
    - anything else: returned as-is.
    The "xgen" check takes precedence when a name matches both.
    """
    if "xgen" in model_name:
        return response.strip().replace("Assistant:", "")
    if "internlm" in model_name:
        before_marker, _, _ = response.partition("<eoa>")
        return before_marker
    return response

def get_pred(rank, world_size, data, max_length, max_gen, prompt_format, dataset, device, model_name, model2path, out_path,
             seed=42, max_context_tokens=10 * 1024):
    """Generate predictions for one shard of a dataset and append them to ``out_path``.

    Intended as the target of one ``mp.Process`` per rank; each rank loads its
    own model copy onto ``cuda:{rank}``.

    Args:
        rank: Worker index; selects the CUDA device ``cuda:{rank}``.
        world_size: Total number of workers (not used in this function).
        data: Samples of this shard. Each must supply the fields referenced by
            ``prompt_format`` plus ``"answers"`` and ``"all_classes"``.
        max_length: Per-model length budget from the config. Not used here —
            over-long prompts are skipped via ``max_context_tokens`` instead.
        max_gen: ``max_new_tokens`` passed to ``model.generate``.
        prompt_format: ``str.format`` template filled from each sample.
        dataset: Dataset name; decides whether the chat template is applied.
        device: Ignored; replaced by ``cuda:{rank}``.
        model_name: Key into ``model2path``; also selects the loading branch.
        model2path: Mapping from model name to checkpoint path.
        out_path: JSONL file; one line per generated prediction is appended.
        seed: Seed applied inside this worker process.
        max_context_tokens: Prompts longer than this many tokens are skipped
            and produce no output line.
    """
    # With the 'spawn' start method the child re-imports this module without
    # running its `__main__` block, so the parent's seed_everything() call
    # (and its cuDNN determinism flags) does not reach the worker.
    seed_everything(seed)
    device = torch.device(f'cuda:{rank}')
    model, tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device)

    # Chat models are better off without the chat template on these tasks.
    no_chat_template = {"trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"}

    for json_obj in tqdm(data):
        prompt = prompt_format.format(**json_obj)
        if dataset not in no_chat_template:
            prompt = build_chat(tokenizer, prompt, model_name)

        input_ids = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids
        context_length = input_ids.shape[-1]
        # Over-long prompts are dropped entirely (no truncation); downstream
        # scoring will simply see fewer lines for this dataset.
        if context_length > max_context_tokens:
            continue
        input_ids = input_ids.to(device)

        # These decoding settings are applied to every dataset. They were
        # originally added for samsum, where the model endlessly repeated
        # "\nDialogue" (possibly a prompting issue).
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=max_gen,
                temperature=0.0,
                top_p=0.9,
            )[0]

        # Decode only the newly generated tokens, not the echoed prompt.
        pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
        pred = post_process(pred, model_name)
        record = {
            "pred": pred,
            "answers": json_obj["answers"],
            "all_classes": json_obj["all_classes"],
            "length": context_length,
        }
        # Reopened per sample so each record is flushed as soon as it exists;
        # all ranks append to the same file.
        with open(out_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    # This script never calls dist.init_process_group(); destroying a
    # non-existent default group raises, so only tear down one that exists.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def load_model_and_tokenizer(path, model_name, device):
    """Load the model and tokenizer for *model_name* from *path* onto *device*.

    The loading branch is picked by substrings of ``model_name`` ("chatglm",
    "internlm", "xgen", or "llama2" plus a method tag such as "rerope",
    "leaky-rerope", "weave", "lm-infinite", "streaming-llm", "dynamic-ntk").

    Args:
        path: Checkpoint path or hub id.
        model_name: Model identifier used to choose the loading branch.
        device: Target device (``torch.device`` or string).

    Returns:
        ``(model, tokenizer)``, with ``model.eval()`` already called.

    Raises:
        ValueError: If ``model_name`` matches no supported model family.
    """
    if "chatglm" in model_name or "internlm" in model_name or "xgen" in model_name:
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
    elif "llama2" in model_name:
        # replace_llama_attn_with_flash_attn()
        if "rerope" in model_name and "leaky" not in model_name:
            # Imported only for its side effect; presumably it patches the
            # LLaMA attention implementation — confirm in methods/rerope.
            import methods.rerope
            tokenizer = LlamaTokenizer.from_pretrained(path)
            model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(device)
        elif "leaky-rerope" in model_name:
            # Side-effect import, presumably a patch like the rerope branch — confirm.
            import methods.leaky_rerope_patch
            tokenizer = LlamaTokenizer.from_pretrained(path)
            model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(device)
        elif "weave" in model_name:
            import methods.weave_v20
            # Settings used for llama2-7b-chat; trailing comments are alternative values.
            from methods.weave_v20 import position_set
            position_set.push_width = 20 # 10
            position_set.last_context = 800 #612 #1024
            position_set.context_window_length = 2048
            position_set.train_length = 2048
            tokenizer = LlamaTokenizer.from_pretrained(path)
            model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(device)
        elif "lm-infinite" in model_name:
            # Fixed values standing in for the CLI options of the LM-Infinite
            # reference script:
            # hack_args = (args.use_lambda_attention, args.local_branch, args.global_branch,
            #              args.limit_distance, args.triangle_offset)
            use_lambda_attention = True
            local_branch = 2048
            global_branch = 100
            limit_distance = 2048
            triangle_offset = 0.0
            hack_args = (use_lambda_attention, local_branch, global_branch,
                         limit_distance, triangle_offset)
            from methods.lminfinite.lm_infinite import LLAMA_Model
            max_length = 32770
            truncation_side = "right"
            load_in_4bit = False
            model = LLAMA_Model(
                path, path,
                max_length, truncation_side,
                load_in_4bit, device, *hack_args
            )
            tokenizer = model.tokenizer
            from tensor_parallel import tensor_parallel
            # NOTE(review): the value returned by tensor_parallel(...) is
            # discarded, so generation still runs on `model` unchanged — confirm
            # whether this call is needed or its result should be kept.
            tensor_parallel(model.model)

        elif "streaming-llm" in model_name:
            from methods.streaming_llm.enable_streaming_llm import enable_streaming_llm
            from methods.streaming_llm_method import StreamingLLM
            tokenizer = AutoTokenizer.from_pretrained(path)
            model = AutoModelForCausalLM.from_pretrained(path, device_map=device, torch_dtype=torch.float16)
            # KV cache keeps `start_size` leading tokens plus `recent_size` recent ones.
            start_size, recent_size = 4, 2040
            kv_cache = enable_streaming_llm(model, start_size=start_size, recent_size=recent_size)
            model = StreamingLLM(model=model, tokenizer=tokenizer, kv_cache=kv_cache, max_gen_len=50)
        elif "dynamic-ntk" in model_name:
            from transformers import LlamaConfig
            config = LlamaConfig.from_pretrained(path)
            config.rope_scaling = {
                "type": "dynamic",
                "factor": 2.0
            }
            model = LlamaForCausalLM.from_pretrained(path, device_map=device, torch_dtype=torch.float16, config=config)
            tokenizer = LlamaTokenizer.from_pretrained(path)

        else:
            tokenizer = LlamaTokenizer.from_pretrained(path)
            model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(device)
    else:
        # Previously this fell through and crashed with UnboundLocalError on
        # `model.eval()`; fail early with an actionable message instead.
        raise ValueError(
            f"Unsupported model_name {model_name!r}: expected it to contain one of "
            "'chatglm', 'internlm', 'xgen' or 'llama2'."
        )

    model.eval()
    return model, tokenizer

if __name__ == '__main__':
    # Seeds only this parent process; spawned workers do not run this block.
    seed_everything(42)
    args = parse_args()
    mp.set_start_method('spawn', force=True)

    # Explicit UTF-8: the prompt templates may contain non-ASCII text.
    with open("config/model2path.json", "r", encoding="utf-8") as f:
        model2path = json.load(f)
    with open("config/model2maxlen.json", "r", encoding="utf-8") as f:
        model2maxlen = json.load(f)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Single-machine run: use at most 2 workers. Worker `rank` runs on
    # `cuda:{rank}`, so never start more workers than there are visible GPUs.
    world_size = min(2, torch.cuda.device_count())
    if world_size == 0:
        raise SystemExit("No CUDA device found: each prediction worker needs its own GPU.")

    model_name = args.model
    # define your model
    max_length = model2maxlen[model_name]
    if args.e:
        datasets = ["qasper", "multifieldqa_en",
                    "hotpotqa", "2wikimqa",
                    "gov_report", "multi_news",
                    "trec", "triviaqa", "samsum",
                    "passage_count", "passage_retrieval_en",
                    "lcc", "repobench-p"]
        # datasets = ["qasper", "triviaqa", "passage_retrieval_en", "lcc", "repobench-p"]
        # datasets = ["repobench-p"]
    else:
        datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh",
                    "hotpotqa", "2wikimqa", "musique", "dureader",
                    "gov_report", "qmsum", "multi_news", "vcsum",
                    "trec", "triviaqa", "samsum", "lsht",
                    "passage_count", "passage_retrieval_en", "passage_retrieval_zh",
                    "lcc", "repobench-p"]
    # Each task has its own prompt format and max generation length; feel free
    # to modify them to optimize model output.
    with open("config/dataset2prompt.json", "r", encoding="utf-8") as f:
        dataset2prompt = json.load(f)
    with open("config/dataset2maxlen.json", "r", encoding="utf-8") as f:
        dataset2maxlen = json.load(f)

    os.makedirs("pred", exist_ok=True)
    os.makedirs("pred_e", exist_ok=True)
    # predict on each dataset
    for dataset in datasets:
        if args.e:
            data_file = f"data/{dataset}_e.jsonl"
            out_dir = f"pred_e/{model_name}"
        else:
            data_file = f"data/{dataset}.jsonl"
            out_dir = f"pred/{model_name}"
        data = load_dataset('json', data_files=data_file)['train']
        os.makedirs(out_dir, exist_ok=True)
        # NOTE: workers append to this file, so re-running without deleting it
        # mixes predictions from the previous run into the results.
        out_path = f"{out_dir}/{dataset}.jsonl"

        prompt_format = dataset2prompt[dataset]
        max_gen = dataset2maxlen[dataset]
        data_all = list(data)
        # Round-robin sharding: rank i gets samples i, i + world_size, ...
        data_subsets = [data_all[i::world_size] for i in range(world_size)]
        processes = []
        for rank in range(world_size):
            p = mp.Process(target=get_pred, args=(rank, world_size, data_subsets[rank], max_length,
                        max_gen, prompt_format, dataset, device, model_name, model2path, out_path))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()