import json
import os
import re
import argparse
import random
import time
import inspect

from utils.config import Config
from utils.logger import create_logger, display_exp_setting
from utils.data_loader import load_data
from utils.algorithm_loader import load_algorithm
from utils.evaluation import nlp_evaluation

from symbolic_compiler.compiler import SymbolicCompiler

# Command-line options for one experiment run. Parsing happens at module
# level, so `args` is a module-level name for the code further down.
parser = argparse.ArgumentParser()

# Algorithm choice and RNG seed. Known algorithm names: Ours, RAG-LLM, Pure-LLM.
parser.add_argument("--algorithm", default="Pure-LLM")
parser.add_argument("--seed", default=0, type=int)

# Dataset and context options.
parser.add_argument("--dataset", default="BioEng", type=str)
parser.add_argument("--query_scope", default=2, type=int)
# Expected relation: context_len = query_scope * 2 + 1 (not checked here).
parser.add_argument("--context_len", default=5, type=int)
parser.add_argument("--context_num", default=4, type=int)

parser.add_argument(
    "--embedding_model",
    default="text-embedding-ada-002",
    type=str,
)

# compiler
parser.add_argument("--dsl", default="autodsl", type=str)

parser.add_argument("--alpha", default=0.05, type=float)
parser.add_argument("--beta", default=0.45, type=float)
parser.add_argument("--gamma", default=0.50, type=float)

# Engine name, sampling options and cache directory.
parser.add_argument("--engine", default="openai/gpt-3.5-turbo", type=str)
parser.add_argument("--temperature", default=0.0, type=float)
parser.add_argument("--freq_penalty", default=0.0, type=float)
parser.add_argument("--max_tokens", default=2048, type=int)
parser.add_argument("--llm_cache_dir", default="llm_cache", type=str)

args = parser.parse_args()

def main() -> None:
    """Run the selected completion algorithm over every test example.

    Reads the module-level ``args``, builds the config, logger, dataset,
    optional symbolic compiler and completion algorithm, then processes each
    (query, answer) pair. After every example the accumulated results are
    rewritten to ``<cfg.result_dir>/results.json`` so partial progress is
    kept if a later example fails. The mean of all evaluations is logged at
    the end.
    """
    random.seed(args.seed)
    cfg = Config(args)
    metrics = nlp_evaluation()
    logger = create_logger(os.path.join(cfg.log_dir, 'log.txt'))
    display_exp_setting(logger, cfg)

    train_dataset, test_examples, avg_trainexample_len = load_data(
        cfg.algorithm, cfg.dataset, cfg.context_len, cfg.query_scope
    )

    # No symbolic compiler is built for the "Synthesis" dataset.
    if cfg.dataset != "Synthesis":
        compiler = SymbolicCompiler(
            cfg.dataset,
            cfg.dsl,
            cfg.engine,
            cfg.temperature,
            cfg.freq_penalty,
            cfg.max_tokens,
            cfg.llm_cache_dir,
            cfg.alpha,
            cfg.beta,
            cfg.gamma,
        )
    else:
        compiler = None

    completion_algorithm = load_algorithm(
        cfg.algorithm,
        cfg.dataset,
        train_dataset,
        cfg.engine,
        cfg.embedding_model,
        cfg.context_num,
        compiler,
        chunk_size=avg_trainexample_len * cfg.context_len,
    )

    # Built once: the path does not change between iterations.
    results_path = f"{cfg.result_dir}/results.json"

    query_list, answer_list, result_list = [], [], []
    evaluation_list, context_list, time_list = [], [], []
    # Parsed counterpart of evaluation_list, kept so the running mean does not
    # have to re-parse every stored JSON string on each iteration.
    # NOTE(review): assumes metrics.mean does not mutate its argument - confirm.
    parsed_evaluations = []

    for query, answer in zip(test_examples["query"], test_examples["answer"]):
        # perf_counter is monotonic, so the duration cannot be skewed by
        # system clock adjustments.
        start_time = time.perf_counter()

        # f-strings instead of "+" so a non-str value is logged, not a TypeError.
        logger.info(f"query: {query}")
        # NOTE(review): the reference answer is passed to invoke(); confirm the
        # algorithm does not use it when producing `result`.
        result, context = completion_algorithm.invoke(query, answer)
        logger.info(f"result: {result}")
        logger.info(f"answer: {answer}")

        elapsed = time.perf_counter() - start_time

        query_list.append(query)
        answer_list.append(answer)
        result_list.append(result)
        context_list.append(context)
        time_list.append(elapsed)

        # Evaluations are stored as JSON strings in the output file; the
        # parsed copy goes through the same dumps/loads round trip as before.
        evaluation_json = json.dumps(metrics.evaluation(result, answer))
        evaluation_list.append(evaluation_json)
        parsed_evaluations.append(json.loads(evaluation_json))

        json_results = {
            "query_list": query_list,
            "answer_list": answer_list,
            "result_list": result_list,
            "evaluation_list": evaluation_list,
            "context_list": context_list,
            "time_list": time_list,
            "result": metrics.mean(parsed_evaluations),
        }

        # Rewritten after every example so partial results are kept.
        with open(results_path, "w") as f:
            logger.info(f"dumping results to {results_path}")
            json.dump(json_results, f, indent=2)

    logger.info(metrics.mean(parsed_evaluations))


if __name__ == "__main__":
    main()