import argparse
import os
import json
import datetime
import logging
import sys
import torch
import time
from sklearn.metrics import f1_score
import torch.optim as optim

from transformers import AutoTokenizer, LlamaForCausalLM
from tqdm import tqdm
from utils import LLMModel
from prompts_qa import Prompt_Loader
import evaluate
from prune_model import PruningEnv, BertPruningModel, PPOAgent, InferenceModel

class MindMap_Execution:
    def __init__(self, args) -> None:
        """Store the parsed options and build the LLM client and prompt loader.

        Args:
            args: Parsed command-line namespace. It is kept whole as
                ``self.args`` and its commonly used fields are mirrored onto
                the instance for convenience.
        """
        self.args = args

        # Options mirrored onto the instance under the same attribute name.
        for option in (
            'dataset_name', 'data_path', 'save_path', 'max_new_tokens',
            'executor', 'mode', 'sample_k', 'filter', 'claim_type',
            'paradigm', 'model_name', 'folder', 'cot',
        ):
            setattr(self, option, getattr(args, option))

        # Options exposed under a different attribute name.
        self.bsz = args.batch_size
        self.prompt = args.prompt_mode

        # Collaborators shared by the training and inference entry points.
        self.llm = LLMModel(args.api_key, args.model_name, args.stop_words, args.max_new_tokens)
        self.prompt_loader = Prompt_Loader(args.prompt_mode, args.paradigm, args.cot)


    def create_logging_path(self):
        """Create a timestamped run directory, set up logging and dump the args.

        The directory ``<save_path>/<dataset_name>/<timestamp>`` is created and
        stored in ``self.log_path``.  The root logger is stored in
        ``self._logger``; its existing handlers and filters are removed, then a
        file handler (``all.log`` inside the run directory) and a stdout
        handler are attached and the logger level is set to INFO.  Finally
        ``self.args`` is written to ``args.json`` and ``args.txt`` in the run
        directory.
        """
        log_name = str(datetime.datetime.now()).replace(' ', '_')
        self.log_path = os.path.join(self.save_path, self.dataset_name, log_name)
        self.create_directories_dir(self.log_path)

        # logging and console logging
        log_formatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s")
        self._logger = logging.getLogger()

        # Iterate over copies: removing from the live lists while iterating
        # them would skip entries.
        for handler in self._logger.handlers[:]:
            self._logger.removeHandler(handler)
        for log_filter in self._logger.filters[:]:
            # The Logger method is removeFilter (singular).
            self._logger.removeFilter(log_filter)

        file_handler = logging.FileHandler(os.path.join(self.log_path, 'all.log'))
        file_handler.setFormatter(log_formatter)
        self._logger.addHandler(file_handler)

        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_formatter)
        self._logger.addHandler(console_handler)

        self._logger.setLevel(logging.INFO)

        # logging arguments
        name = 'args'
        arg_items = vars(self.args)

        # 1. as json
        with open(os.path.join(self.log_path, '%s.json' % name), 'w') as json_file:
            json.dump(arg_items, json_file, indent=4)

        # 2. as string, one "key = value" pair per line
        with open(os.path.join(self.log_path, '%s.txt' % name), 'w') as text_file:
            text_file.write('\n'.join("%s = %s" % (key, value) for key, value in arg_items.items()))

    def create_directories_dir(self, d):
        """Make sure directory ``d`` exists, creating parents as needed.

        Args:
            d: Directory path. An empty or ``None`` value is returned as-is
                without touching the filesystem.

        Returns:
            ``d`` unchanged, so the call can be used inline.
        """
        if d and not os.path.exists(d):
            # exist_ok covers the case where something else creates the
            # directory between the existence check and this call.
            os.makedirs(d, exist_ok=True)
        return d
    
    # mode == 'sample' returns only the sampled records; any other mode returns all records
    def get_datas(self, data_path, folder='gpt4', mode='sample'):
        """Load the dev mind-map records, optionally restricted to a sample.

        Files are read from ``<data_path>/<dataset_name>/<folder>/``.

        Args:
            data_path: Root directory of the datasets.
            folder: Sub-directory of the dataset directory holding the files.
            mode: ``'sample'`` keeps only the records whose ``id`` is listed in
                ``sample_ids_<sample_k>.json``; any other value returns
                everything.

        Returns:
            The parsed content of ``dev_mind_map.json``; in sample mode, the
            list of its records with a listed ``id``, in file order.
        """
        data_dir = os.path.join(data_path, self.dataset_name, folder)
        data_file = os.path.join(data_dir, 'dev_mind_map.json')

        self._logger.info('data file:' + data_file)

        with open(data_file) as data_handle:
            full_datas = json.load(data_handle)

        if mode != 'sample':
            return full_datas

        sample_file = os.path.join(data_dir, 'sample_ids_' + str(self.sample_k) + '.json')
        with open(sample_file) as sample_handle:
            sample_datas_ids = json.load(sample_handle)

        # A set gives O(1) membership tests; ids are assumed to be hashable
        # JSON scalars (strings or numbers).
        sample_ids = {sample['id'] for sample in sample_datas_ids}
        return [data for data in full_datas if data['id'] in sample_ids]
    

    def execute(self):
        """Train the pruning policy with PPO.

        Creates the run directory and logger, logs the run configuration and
        loads the records via ``get_datas``.  Their ``mind_maps``, ``claim``,
        ``label`` and ``evidence`` fields are handed to a ``PruningEnv``; a
        ``BertPruningModel`` (moved to CUDA) is then optimised with Adam through
        ``PPOAgent.train`` for ``args.num_episodes`` episodes.
        """
        # create output dir and logging path
        self.create_logging_path()
        log = self._logger.info

        # run configuration
        run_summary = (
            ('dataset name:', self.dataset_name),
            ('original data path:', self.data_path),
            ('save path: ', self.log_path),
            ('datas mode:', self.mode),
            ('prompt mode:', self.prompt),
            ('paradigm mode:', self.paradigm),
            ('executor:', self.executor),
            ('model name:', self.model_name),
            ('use cot:', str(self.cot)),
        )
        for prefix, value in run_summary:
            log(prefix + value)

        # load dataset
        original_datas = self.get_datas(self.data_path, self.folder, self.mode)
        log('Total: ' + str(len(original_datas)))

        # NOTE: the executor-based dispatch (url / open model / OpenAI batch)
        # is not used here; this entry point only trains the pruning policy.

        def column(key):
            # One value per record, in record order.
            return [record[key] for record in original_datas]

        extracted_context = column('mind_maps')
        questions = column('claim')
        labels = column('label')
        evidences = column('evidence')

        # Creating an Environment
        env = PruningEnv(
            extracted_context,
            questions,
            labels,
            evidences,
            self.prompt_loader,
            self.llm,
            self.dataset_name,
        )

        # Create policy model and optimizer
        policy_model = BertPruningModel().cuda()
        optimizer = optim.Adam(policy_model.parameters(), lr=self.args.lr)

        # Create PPO agent and start training
        ppo_agent = PPOAgent(
            env,
            policy_model,
            optimizer,
            self.args.clip_param,
            self.args.ppo_epochs,
            mini_batch_size=self.args.batch_size,
        )
        ppo_agent.train(num_episodes=self.args.num_episodes)

            
    def inference(self):
        """Prune each record's mind maps, query the LLM and evaluate the answers.

        Steps: create the run directory and logger, load the records, load the
        trained ``BertPruningModel`` weights from ``args.prune_model_path``,
        and for every record build a prompt from the pruned mind maps, send it
        to ``self.llm`` and store the answer under the record's
        ``'prediction'`` key.  All records are written to ``predict.json`` in
        the run directory and then scored by the dataset-specific function of
        the ``evaluate`` module.

        The records are annotated in place; ``self.result_dict`` holds the
        same dict objects that ``get_datas`` returned.
        """
        # predict result
        self.result_dict = []

        # create output dir and logging path
        self.create_logging_path()

        # run configuration
        self._logger.info('dataset name:' + self.dataset_name)
        self._logger.info('original data path:' + self.data_path)
        self._logger.info('save path: ' + self.log_path)
        self._logger.info('datas mode:' + self.mode)
        self._logger.info('prompt mode:' + self.prompt)
        self._logger.info('paradigm mode:' + self.paradigm)
        self._logger.info('executor:' + self.executor)
        self._logger.info('prune model path:' + self.args.prune_model_path)
        self._logger.info('model name:' + self.model_name)
        self._logger.info('use cot:' + str(self.cot))

        # load dataset
        original_datas = self.get_datas(self.data_path, self.folder, self.mode)
        self._logger.info('Total: ' + str(len(original_datas)))

        # Load the trained BERT pruning model.  The model object is not moved
        # to a GPU in this method, so the weights are mapped to CPU; this also
        # lets a checkpoint saved from a GPU load on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        trained_pruning_model = BertPruningModel()
        state_dict = torch.load(self.args.prune_model_path, map_location='cpu')
        trained_pruning_model.load_state_dict(state_dict)
        trained_pruning_model.eval()

        # Creating the inference model
        inference_model = InferenceModel(trained_pruning_model, self.llm)

        # The response layout depends on the backing model family.
        is_llama = self.model_name.startswith('llama-2')

        for data in original_datas:
            question = data['claim']
            evidence = data['evidence']

            pruned_mind_maps = inference_model.prune_text((data['mind_maps'], question))
            full_prompts = self.prompt_loader.prompt_construction(question, evidence, pruned_mind_maps, self.dataset_name)
            messages = [{"role": "user", "content": full_prompts}]
            response = self.llm.predict(messages)
            response_json = json.loads(response.text)
            if is_llama:
                predict = response_json['output']
            else:
                predict = response_json['choices'][0]['message']['content']

            data['prediction'] = predict
            self.result_dict.append(data)

        # save result; the file is closed before the evaluators re-open it
        predict_path = os.path.join(self.log_path, 'predict.json')
        with open(predict_path, 'w') as predict_file:
            json.dump(self.result_dict, predict_file, indent=4)
        self._logger.info('prediction.json saved in logger path!')

        def log_report(report):
            # Reports are multi-line strings; log them one row at a time.
            for row_data in report.split('\n'):
                self._logger.info(row_data)

        # start evaluate the result
        self._logger.info('start to evaluate the result with macro F1 .')
        dataset = self.args.dataset_name
        paradigm = self.args.paradigm
        if dataset in ('FEVEROUS', 'SCIFACT'):
            macro_report, _ = evaluate.evaluate_feverous(predict_path, paradigm=paradigm)
            self._logger.info('********************macro_report*******************')
            log_report(macro_report)
        elif dataset == 'HOVER':
            macro_reports = evaluate.evaluate_hover(predict_path, paradigm=paradigm)
            self._logger.info('********************macro_report*******************')
            for key, macro_report in macro_reports.items():
                self._logger.info(f'********************{key}*******************')
                log_report(macro_report)
        elif dataset in ('HOTPOTQA', '2WIKIMULTIHOPQA', 'MuSiQue', 'QANGAROO'):
            metrics = evaluate.evaluate_hotpotqa(predict_path, paradigm=paradigm)
            self._logger.info(metrics)
        elif dataset == 'STRATEGYQA':
            macro_report, _ = evaluate.evaluate_strategyqa(predict_path, paradigm=paradigm)
            self._logger.info('********************macro_report*******************')
            log_report(macro_report)

        self._logger.info('********************End !*******************')
        return

    def execution_with_openai_batch(self, datas):
        """Query the OpenAI batch endpoint chunk by chunk, then save and evaluate.

        ``datas`` is split into chunks of ``self.bsz`` examples.  For each
        chunk one prompt per example is built from its claim, evidence and
        first mind map, and the whole chunk is sent through
        ``self.llm.batch_generate_with_openai``.  Answers are stored under each
        example's ``'prediction'`` key (the example dicts are modified in
        place) and collected in ``self.result_dict``.

        Args:
            datas: Sequence of example dicts with ``'claim'``, ``'evidence'``
                and ``'mind_maps'`` keys.
        """
        # This method may be called without inference() having run first, so
        # make sure the accumulator exists.
        if not hasattr(self, 'result_dict'):
            self.result_dict = []

        dataset_chunks = [datas[i:i + self.bsz] for i in range(0, len(datas), self.bsz)]
        filter_results = []
        for chunk in tqdm(dataset_chunks):
            full_prompts = [
                self.prompt_loader.prompt_construction(
                    example['claim'], example['evidence'], str(example['mind_maps'][0]), self.dataset_name
                )
                for example in chunk
            ]

            batch_outputs = self.llm.batch_generate_with_openai(full_prompts)

            if not batch_outputs:
                # Keep the original best-effort behaviour (skip the chunk) but
                # leave a trace instead of dropping the examples silently.
                self._logger.warning('empty batch output, skipping %d example(s)', len(chunk))
                continue

            for example, output in zip(chunk, batch_outputs):
                filter_results.append(example)
                example['prediction'] = output
                self.result_dict.append(example)
            # Pause between successful batches.
            time.sleep(1)

        # save result; the file is closed before the evaluator re-opens it
        predict_path = os.path.join(self.log_path, 'predict.json')
        with open(predict_path, 'w') as predict_file:
            json.dump(self.result_dict, predict_file, indent=4)
        self._logger.info('prediction.json saved in logger path!')

        if self.filter:
            # save the examples that received an answer
            filter_dir = os.path.join(self.data_path, self.dataset_name, 'filter')
            os.makedirs(filter_dir, exist_ok=True)
            with open(os.path.join(filter_dir, 'dev_mind_map.json'), 'w') as filter_file:
                json.dump(filter_results, filter_file, indent=4)
            self._logger.info('fliter dev_mind_map.json saved in filter path!')
            self._logger.info(str(len(filter_results)))

        # start evaluate the result
        self._logger.info('start to evaluate the result with macro F1 .')
        if self.args.dataset_name == 'FEVEROUS':
            # NOTE(review): the other call sites pass paradigm=self.args.paradigm;
            # confirm whether the evaluator's default is intended here.
            macro_report, _ = evaluate.evaluate_feverous(predict_path)
            self._logger.info('********************macro_report*******************')
            for row_data in macro_report.split('\n'):
                self._logger.info(row_data)
            
            
    def execute_with_open_model(self, datas):
        """Generate answers with a local Llama checkpoint, then save and evaluate.

        The tokenizer and model are loaded from ``args.model_path``.  ``datas``
        is processed in chunks of ``self.bsz`` examples; each chunk is
        tokenised with padding, passed to ``model.generate`` and decoded.  The
        decoded text is stored under each example's ``'prediction'`` key (the
        example dicts are modified in place) and collected in
        ``self.result_dict``.

        Args:
            datas: Sequence of example dicts with ``'claim'`` (or
                ``'claim_maps'``), ``'evidence'`` and ``'mind_maps'`` keys,
                plus ``'candidates'`` for QANGAROO.
        """
        # This method may be called without inference() having run first, so
        # make sure the accumulator exists.
        if not hasattr(self, 'result_dict'):
            self.result_dict = []

        tokenizer = AutoTokenizer.from_pretrained(self.args.model_path)
        tokenizer.pad_token = '[PAD]'
        model = LlamaForCausalLM.from_pretrained(self.args.model_path)

        dataset_chunks = [datas[i:i + self.bsz] for i in range(0, len(datas), self.bsz)]
        for chunk in tqdm(dataset_chunks):
            full_prompts = []
            for example in chunk:
                if self.claim_type == 'text':
                    claim = example['claim']
                elif self.claim_type == 'map':
                    claim = str(example['claim_maps'])
                evidence = example['evidence']
                mind_maps = example['mind_maps']
                if self.dataset_name == 'SCIFACT':
                    # SCIFACT evidence and mind maps are joined one per line.
                    evidence = '\n'.join(example['evidence'])
                    mind_maps = '\n'.join(str(mind_map) for mind_map in example['mind_maps'])

                if self.dataset_name == 'QANGAROO':
                    candidates = example['candidates']
                    full_prompt = self.prompt_loader.prompt_construction_mc(claim, evidence, mind_maps, candidates, self.dataset_name)
                else:
                    full_prompt = self.prompt_loader.prompt_construction(claim, evidence, mind_maps, self.dataset_name)

                full_prompts.append(full_prompt)

            batch_inputs = tokenizer(full_prompts, padding=True, return_tensors="pt")
            # Pass the attention mask so padded positions are ignored.  The
            # generated ids are decoded directly; they were previously moved
            # to a hard-coded 'cuda:1' first, which fails on hosts with fewer
            # than two GPUs and does not affect the decoded text.
            # NOTE(review): max_length=300 also counts the prompt tokens - confirm.
            generate_ids = model.generate(
                batch_inputs.input_ids,
                attention_mask=batch_inputs.attention_mask,
                max_length=300,
            )
            batch_predict = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for sample, predict in zip(chunk, batch_predict):
                sample['prediction'] = predict
                self.result_dict.append(sample)

        # save results; the file is closed before the evaluators re-open it
        predict_path = os.path.join(self.log_path, 'predict.json')
        with open(predict_path, 'w') as predict_file:
            json.dump(self.result_dict, predict_file, indent=4)
        self._logger.info('prediction.json saved in logger path!')

        # start evaluate the result
        self._logger.info('start to evaluate the result with macro F1 .')
        dataset = self.args.dataset_name
        paradigm = self.args.paradigm
        if dataset in ('FEVEROUS', 'SCIFACT'):
            macro_report, _ = evaluate.evaluate_feverous(predict_path, paradigm=paradigm)
            self._logger.info('********************macro_report*******************')
            for row_data in macro_report.split('\n'):
                self._logger.info(row_data)

        elif dataset == 'HOVER':
            macro_reports = evaluate.evaluate_hover(predict_path, paradigm=paradigm)
            self._logger.info('********************macro_report*******************')
            for key, macro_report in macro_reports.items():
                self._logger.info(f'********************{key}*******************')
                for row_data in macro_report.split('\n'):
                    self._logger.info(row_data)

        elif dataset in ('HOTPOTQA', '2WIKIMULTIHOPQA', 'MuSiQue', 'QANGAROO'):
            metrics = evaluate.evaluate_hotpotqa(predict_path, paradigm=paradigm)
            self._logger.info(metrics)

        elif dataset == 'STRATEGYQA':
            macro_report, _ = evaluate.evaluate_strategyqa(predict_path, paradigm=paradigm)
            self._logger.info('********************macro_report*******************')
            for row_data in macro_report.split('\n'):
                self._logger.info(row_data)

        # Logged once for every dataset (HOVER previously logged it twice).
        self._logger.info('********************End !*******************')


    def execution_with_url(self, datas):
        """Query the LLM endpoint once per example, retrying until an answer is usable.

        For every example a prompt is built and sent with
        ``self.llm.call_gpt_with_url``.  The response is retried (with
        temperature 0.7 / top_p 0.95) for as long as it carries neither an
        ``'output'`` field nor a first choice that finished with ``'stop'`` and
        has message content.  Two error messages trigger a rebuilt prompt
        before retrying: a context-length error drops the mind maps, and a
        content-filter error (in ``'map'`` prompt mode) substitutes the
        evidence for the mind maps.  There is no cap on the number of retries.

        Answers are stored under each example's ``'prediction'`` key (the
        example dicts are modified in place), collected in
        ``self.result_dict``, written to ``predict.json`` and evaluated.

        Args:
            datas: Sequence of example dicts with ``'id'``, ``'claim'`` (or
                ``'claim_maps'``), ``'evidence'`` and ``'mind_maps'`` keys,
                plus ``'candidates'`` for QANGAROO.
        """
        # This method may be called without inference() having run first, so
        # make sure the accumulator exists.
        if not hasattr(self, 'result_dict'):
            self.result_dict = []

        def needs_retry(payload):
            # A response carrying 'output' is always accepted; otherwise the
            # first choice must have finished normally and contain content.
            if 'output' in payload:
                return False
            if 'choices' not in payload:
                return True
            first_choice = payload['choices'][0]
            return first_choice['finish_reason'] != 'stop' or 'content' not in first_choice['message']

        context_length_error = "This model's maximum context length is 4097 tokens"
        content_filter_error = "The response was filtered due to the prompt triggering Azure OpenAI's content management policy"

        for example in tqdm(datas):
            if self.claim_type == 'text':
                claim = example['claim']
            elif self.claim_type == 'map':
                claim = str(example['claim_maps'])
            evidence = example['evidence']
            mind_maps = example['mind_maps']

            if self.dataset_name == 'QANGAROO':
                candidates = example['candidates']
                full_prompts = self.prompt_loader.prompt_construction_mc(claim, evidence, mind_maps, candidates, self.dataset_name)
            else:
                full_prompts = self.prompt_loader.prompt_construction(claim, evidence, mind_maps, self.dataset_name)
            # Each example is sent as a fresh single-message conversation.
            messages = [{"role": "user", "content": full_prompts}]

            # call the gpt by url
            response = self.llm.call_gpt_with_url(messages, max_tokens=self.max_new_tokens, temperature=0.0, top_p=1.0)
            response_json = json.loads(response.text)

            # if result is not usable, recall the url
            while needs_retry(response_json):
                print('try again !')
                print(response_json)
                print(example['id'])
                if 'error' in response_json:
                    error_message = response_json['error']['message']
                    if context_length_error in error_message:
                        # Prompt too long: retry without any mind maps.
                        mind_maps = []
                        full_prompts = self.prompt_loader.prompt_construction(claim, evidence, mind_maps, self.dataset_name)
                        messages = [{"role": "user", "content": full_prompts}]
                    elif content_filter_error in error_message and self.prompt == 'map':
                        # Prompt rejected by the content filter: use the evidence in place of the mind maps.
                        mind_maps = evidence
                        full_prompts = self.prompt_loader.prompt_construction(claim, evidence, mind_maps, self.dataset_name)
                        messages = [{"role": "user", "content": full_prompts}]

                response = self.llm.call_gpt_with_url(messages, max_tokens=self.max_new_tokens, temperature=0.7, top_p=0.95)
                response_json = json.loads(response.text)

            if self.model_name.startswith('llama'):
                predict = response_json['output']
            else:
                predict = response_json['choices'][0]['message']['content']
            print('**********************')
            print(predict)

            example['prediction'] = predict
            self.result_dict.append(example)

        # save result; the file is closed before the evaluators re-open it
        predict_path = os.path.join(self.log_path, 'predict.json')
        with open(predict_path, 'w') as predict_file:
            json.dump(self.result_dict, predict_file, indent=4)
        self._logger.info('prediction.json saved in logger path!')

        # start evaluate the result
        self._logger.info('start to evaluate the result with macro F1 .')
        dataset = self.args.dataset_name
        paradigm = self.args.paradigm
        if dataset in ('FEVEROUS', 'SCIFACT'):
            macro_report, _ = evaluate.evaluate_feverous(predict_path, paradigm=paradigm)
            self._logger.info('********************macro_report*******************')
            for row_data in macro_report.split('\n'):
                self._logger.info(row_data)
        elif dataset == 'HOVER':
            macro_reports = evaluate.evaluate_hover(predict_path, paradigm=paradigm)
            self._logger.info('********************macro_report*******************')
            for key, macro_report in macro_reports.items():
                self._logger.info(f'********************{key}*******************')
                for row_data in macro_report.split('\n'):
                    self._logger.info(row_data)
        elif dataset in ('HOTPOTQA', '2WIKIMULTIHOPQA', 'MuSiQue', 'QANGAROO'):
            metrics = evaluate.evaluate_hotpotqa(predict_path, paradigm=paradigm)
            self._logger.info(metrics)
        elif dataset == 'STRATEGYQA':
            macro_report, _ = evaluate.evaluate_strategyqa(predict_path, paradigm=paradigm)
            self._logger.info('********************macro_report*******************')
            for row_data in macro_report.split('\n'):
                self._logger.info(row_data)

        self._logger.info('********************End !*******************')
    
    

                
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # dataset args
    parser.add_argument('--model_mode', choices=['train', 'inference'], type=str, default='train')
    parser.add_argument('--dataset_name', choices=['FEVEROUS', 'HOVER', 'SCIFACT', 'HOTPOTQA', '2WIKIMULTIHOPQA', 'MuSiQue', 'STRATEGYQA', 'QANGAROO'], default='FEVEROUS', type=str)
    parser.add_argument('--data_path', default='./MindmapFC/datasets', type=str)  # dataset path
    parser.add_argument('--save_path', default='./MindmapFC/result', type=str)

    # fact checker args
    parser.add_argument('--model_name', choices=['gpt-4', 'text-davinci-003', 'llama-2-70b'], type=str, default='gpt-4')
    parser.add_argument('--executor', choices=['url', 'api', 'open'], type=str, default='url')
    parser.add_argument('--model_path', type=str, default='local model_path')
    parser.add_argument('--api_key', default='xxx', type=str)
    parser.add_argument('--stop_words', type=str, default='# The claim is')
    parser.add_argument('--max_new_tokens', type=int, default=1024)
    parser.add_argument('--mode', choices=['sample', 'full'], type=str, default='sample')
    parser.add_argument('--sample_k', type=int, default=1000)
    parser.add_argument('--prompt_mode', choices=['text', 'map', 'mix'], type=str, default='text')
    parser.add_argument('--filter', action='store_true', default=False)
    parser.add_argument('--cot', action='store_true', default=False)
    parser.add_argument('--paradigm', choices=['fc', 'qa'], type=str, default='fc')
    parser.add_argument('--claim_type', choices=['text', 'map'], type=str, default='text')
    parser.add_argument('--folder', choices=['gpt4', 'gpt35'], type=str, default='gpt4')

    ## for ppo
    parser.add_argument('--clip_param', type=float, default=0.2, help="clip_param")
    parser.add_argument('--ppo_epochs', type=int, default=5, help="ppo_epochs")
    # --batch_size used to be registered twice (defaults 1 and 4), which makes
    # argparse raise a conflicting-option error before anything runs.  A single
    # definition is kept with the PPO default, since training is the default
    # mode and uses this value as its mini-batch size.
    parser.add_argument('--batch_size', type=int, default=4, help="batch_size")
    parser.add_argument('--num_episodes', type=int, default=1000, help="num_episodes")
    parser.add_argument('--lr', type=float, default=2e-5, help="The initial learning rate for adam.")
    parser.add_argument('--prune_model_path', default='./infore/model', type=str)

    args = parser.parse_args(argv)
    return args


if __name__ == '__main__':
    # Script entry point: parse the CLI options, then run the selected mode.
    args = parse_args()
    runner = MindMap_Execution(args)
    # 'train' runs PPO training; any other mode runs inference.
    entry_point = runner.execute if args.model_mode == 'train' else runner.inference
    entry_point()
     
