import argparse
import functools
import json
import os

import numpy as np
import torch
from peft import PeftModel
from sklearn.metrics import f1_score
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

def parse_args():
    """Parse the command-line options of the evaluation script.

    Every option is a string flag:

    * ``--data_dir``: directory whose entries are evaluated one by one.
    * ``--model_name_or_path``: base causal LM to load.
    * ``--lora_name_or_path``: optional adapter; ``None`` when omitted.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()
    option_table = (
        ("--data_dir", "data/babi/data"),
        ("--model_name_or_path", "EleutherAI/gpt-j-6B"),
        ("--lora_name_or_path", None),
    )
    for flag, fallback in option_table:
        cli.add_argument(flag, type=str, default=fallback)
    return cli.parse_args()

@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name_or_path, lora_name_or_path=None):
    """Load the tokenizer and 8-bit causal LM, optionally wrapped with a LoRA adapter.

    Memoized on the two path arguments so repeated ``main`` calls (one per
    data file and normalization mode) reuse a single model instead of
    reloading it onto GPU 0 every time.

    Returns:
        ``(model, tokenizer)``, with the model already in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, load_in_8bit=True, device_map={"": 0}
    )
    if lora_name_or_path is not None:
        model = PeftModel.from_pretrained(model, lora_name_or_path, device_map={"": 0})
    model.eval()
    return model, tokenizer


def _answer_log_prob(model, context_ids, answer_ids):
    """Return the summed log-probability of ``answer_ids`` following ``context_ids``.

    Args:
        model: causal LM whose output exposes ``.logits`` shaped
            (batch, seq, vocab).
        context_ids: 1-D LongTensor of conditioning token ids (must be non-empty).
        answer_ids: 1-D LongTensor of the answer token ids to score.

    Returns:
        Python float: sum over answer tokens of log P(token | preceding tokens).
    """
    # Concatenate ids (instead of re-tokenizing "prompt + answer" as one string)
    # so the scored tokens are guaranteed to be exactly ``answer_ids``.
    input_ids = torch.cat([context_ids, answer_ids]).unsqueeze(0).to(model.device)
    logits = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids)).logits
    # Logit row i predicts token i + 1, so the answer tokens at positions
    # ctx_len .. seq_len - 1 are scored by rows ctx_len - 1 .. seq_len - 2.
    ctx_len = context_ids.size(0)
    # Upcast before log_softmax: half-precision logits lose accuracy there.
    answer_logits = logits[0, ctx_len - 1:-1].float()
    log_probs = torch.log_softmax(answer_logits, dim=-1)
    targets = answer_ids.to(log_probs.device).unsqueeze(-1)
    return log_probs.gather(-1, targets).sum().item()


def main(args):
    """Evaluate one JSONL file by ranking candidate answers by LM likelihood.

    Each non-blank line of ``args.input_path`` is a JSON object with
    ``input`` (the prompt), ``answer_list`` (candidate answers) and
    ``answer`` (the gold answer, which must appear in ``answer_list``).
    Each candidate's log-likelihood after the prompt is normalized according
    to ``args.norm``:

    * ``'length'``: divided by the number of answer tokens;
    * anything else (used as ``'unconditioned'``): minus the candidate's
      log-likelihood after the neutral context ``'Answer:'``.

    The highest-scoring candidate is the prediction.

    Args:
        args: namespace from ``parse_args``, plus ``input_path`` and ``norm``,
            which the caller must set.

    Returns:
        Weighted F1 (sklearn ``average='weighted'``) of predicted vs. gold
        candidate indices.

    Raises:
        ValueError: if a sample's gold ``answer`` is not in its ``answer_list``.
    """
    model, tokenizer = _load_model_and_tokenizer(
        args.model_name_or_path, args.lora_name_or_path
    )

    def encode(text, add_special_tokens=True):
        # 1-D LongTensor of token ids for ``text``.
        return tokenizer(
            text, add_special_tokens=add_special_tokens, return_tensors="pt"
        ).input_ids[0]

    uncond_context_ids = encode("Answer:")
    with open(args.input_path, "r", encoding="utf-8") as f:
        lines = [line for line in f if line.strip()]  # skip blank/trailing lines

    y_pred, y_true = [], []
    with torch.no_grad():
        for line in tqdm(lines):
            sample = json.loads(line)
            prompt_ids = encode(sample["input"])
            answer_list = sample["answer_list"]
            scores = []  # normalized log-prob of each candidate answer
            for answer in answer_list:
                # The leading space mirrors how the answer follows the prompt. No
                # special tokens: a BOS here would be scored as part of the answer.
                answer_ids = encode(f" {answer}", add_special_tokens=False)
                log_prob = _answer_log_prob(model, prompt_ids, answer_ids)
                if args.norm == "length":
                    scores.append(log_prob / answer_ids.size(0))
                else:
                    uncond_log_prob = _answer_log_prob(model, uncond_context_ids, answer_ids)
                    scores.append(log_prob - uncond_log_prob)
            # NOTE(review): candidate indices are used as class labels, which
            # assumes answer_list order is consistent across samples — confirm.
            y_true.append(answer_list.index(sample["answer"]))
            y_pred.append(int(np.argmax(scores)))

    return f1_score(y_true, y_pred, average="weighted")

if __name__ == "__main__":
    # Evaluate every file in --data_dir under both normalization modes and
    # print one "<path>: {mode: weighted F1}" line per file.
    args = parse_args()
    # Sorted for a deterministic report order; os.listdir order is arbitrary.
    for name in sorted(os.listdir(args.data_dir)):
        file = os.path.join(args.data_dir, name)
        if not os.path.isfile(file):  # subdirectories etc. would crash open()
            continue
        args.input_path = file
        f1s = {}
        for norm in ['length', 'unconditioned']:
            args.norm = norm
            f1s[norm] = main(args)
        print(f'{file}: {f1s}')
