import argparse
import json
import sys

import openai
from transformers import LlamaTokenizer
from vllm import LLM, SamplingParams

# Module-level tokenizer; main() uses it to count prompt tokens so that
# over-long prompts can be skipped before generation.
# NOTE(review): from_pretrained() is called with no arguments here. The
# model/tokenizer path appears to be missing and must be supplied before
# this module can be imported -- confirm the intended checkpoint path.
tokenizer = LlamaTokenizer.from_pretrained(
    
)


def config_llm(args):
    """Build the inference backend selected by ``args.inference_type``.

    Args:
        args: Parsed command-line namespace. ``inference_type`` selects the
            backend; ``api_base`` (``<host>:<port>``) is read only for the
            ``"openai_api"`` backend.

    Returns:
        For ``"direct"``, a vLLM ``LLM`` instance loaded from the fine-tuned
        checkpoint directory. For ``"openai_api"``, the id of the first model
        listed by the server at ``api_base``. ``None`` for any other
        inference type.

    Raises:
        ValueError: If ``inference_type`` is ``"openai_api"`` but no
            ``api_base`` was supplied.
    """
    if args.inference_type == "direct":
        # Alternative checkpoint that was tried previously (with
        # tensor_parallel_size=4):
        #   /home/bo/agent/fine_tune/output_5k_qlora_llama/checkpoint-1725
        return LLM(
            model="/home/bo/agent/fine_tune/output_5k/",
            tensor_parallel_size=1,
            max_num_batched_tokens=4096,
        )

    if args.inference_type == "openai_api":
        # Fail early with a clear message instead of building the URL
        # "http://None/v1" and hitting an opaque connection error later.
        if not args.api_base:
            raise ValueError(
                "api_base (<host>:<port>) is required when inference_type "
                "is 'openai_api'"
            )
        # Placeholder key; presumably the target server does not validate
        # it -- confirm against the deployment.
        openai.api_key = "EMPTY"
        openai.api_base = f"http://{args.api_base}/v1"
        # Use the first model advertised by the server.
        return openai.Model.list()["data"][0]["id"]

    return None


def generate_from_huggingface_completion(
    args,
    llm,
    prompts: list,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
) -> list:
    """Generate one completion per prompt with the configured backend.

    Args:
        args: Parsed command-line namespace; only ``inference_type`` is read.
        llm: The backend returned by ``config_llm``: an object exposing
            ``generate(prompts)`` for ``"direct"``, or a model id string for
            ``"openai_api"``.
        prompts: List of prompt strings.
        temperature: Currently unused; the backend defaults apply.
        top_p: Currently unused; the backend defaults apply.
        max_new_tokens: Currently unused; the backend defaults apply.

    Returns:
        A list of generated strings, one per prompt, in the order the backend
        returned them. Empty when ``inference_type`` is neither ``"direct"``
        nor ``"openai_api"``.
    """
    generated_texts = []
    if args.inference_type == "direct":
        # One batched call; each result carries its candidate completions
        # and only the first candidate is kept.
        for output in llm.generate(prompts):
            generated_text = output.outputs[0].text
            print(generated_text)
            generated_texts.append(generated_text)
    elif args.inference_type == "openai_api":
        # The chat endpoint is queried once per prompt, as a single user turn.
        for prompt in prompts:
            response = openai.ChatCompletion.create(
                model=llm,
                messages=[{"role": "user", "content": prompt}],
            )
            generated_text = response.choices[0]["message"]["content"]
            print(generated_text)
            generated_texts.append(generated_text)
    return generated_texts


def main():
    """Evaluate a model on an evaluation set and report match statistics.

    Command-line arguments:
        --eval_set / -i: Path to a JSON file containing a list of entries.
            Each entry has a ``conversations`` list whose first item's
            ``value`` is the prompt and whose second item's ``value`` is the
            ground truth.
        --inference_type: ``openai_api`` or ``direct`` (required).
        --api_base: ``<host>:<port>`` of the server; required for
            ``openai_api``.

    Prompts longer than 3500 tokens are skipped and counted. A prediction is
    an exact match when it equals the ground truth after removing spaces.
    Running statistics (overall, and for ground truths containing ``CLICK``
    or ``TYPE``) are printed after every sample. The prompts, predictions and
    ground truths are written as JSON to ``<eval_set>_pred``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eval_set",
        "-i",
        type=str,
        required=True,
        help="Path to the evaluation set",
    )
    parser.add_argument(
        "--inference_type",
        type=str,
        choices=["openai_api", "direct"],
        # Without a backend nothing is generated and the run is meaningless,
        # so reject the omission at the command line.
        required=True,
    )
    parser.add_argument(
        "--api_base",
        type=str,
        required=False,
        help="in the format like <host>:<port>",
    )

    args = parser.parse_args()
    if args.inference_type == "openai_api" and not args.api_base:
        parser.error(
            "--api_base is required when --inference_type is openai_api"
        )

    llm = config_llm(args)
    filename = args.eval_set
    with open(filename, "r", encoding="utf-8") as file:
        data = json.load(file)

    # Prompts above this token count are skipped rather than sent to the
    # model. NOTE(review): presumably chosen to stay below the 4096-token
    # batch limit used for the direct backend -- confirm.
    max_prompt_tokens = 3500

    prompts = []
    ground_truths = []
    skip_cnt = 0
    for entry in data:
        prompt = entry["conversations"][0]["value"]
        prompt_len = len(tokenizer.encode(prompt))
        if prompt_len > max_prompt_tokens:
            print(prompt_len)
            skip_cnt += 1
            continue
        prompts.append(prompt)
        ground_truths.append(entry["conversations"][1]["value"])

    preds = generate_from_huggingface_completion(
        args,
        llm,
        prompts=prompts,
        temperature=-1,
        top_p=-1,
        max_new_tokens=-1,
    )

    tot = mat = 0
    click_tot = click_cnt = click_full = 0
    type_tot = type_cnt = type_full = 0
    res = []

    for prompt, ground_truth, pred in zip(prompts, ground_truths, preds):
        # Whitespace-insensitive exact match, computed once per sample.
        exact = ground_truth.replace(" ", "") == pred.replace(" ", "")
        if exact:
            mat += 1
        if "CLICK" in ground_truth:
            click_tot += 1
            if "CLICK" in pred:
                click_cnt += 1
            if exact:
                click_full += 1
        if "TYPE" in ground_truth:
            type_tot += 1
            if "TYPE" in pred:
                type_cnt += 1
            if exact:
                type_full += 1
        tot += 1

        # Running statistics; the small epsilon keeps the ratios defined
        # while a category is still empty.
        print("skip_cnt", skip_cnt)
        print("overall", mat, tot, mat / (tot + 0.0001))
        print("click", click_cnt, click_tot, click_cnt / (click_tot + 0.0001))
        print("type", type_cnt, type_tot, type_cnt / (type_tot + 0.0001))
        print(
            "click_full",
            click_full,
            click_tot,
            click_full / (click_tot + 0.0001),
        )
        print(
            "type_full", type_full, type_tot, type_full / (type_tot + 0.0001)
        )

        res.append(
            {
                "prompt": prompt,
                "pred": pred,
                "ground_truth": ground_truth,
            }
        )

    with open(filename + "_pred", "w", encoding="utf-8") as file:
        json.dump(res, file)

    # Guard the final accuracy: every prompt may have been skipped, or the
    # evaluation set may be empty.
    print(mat, tot, mat / tot if tot else 0.0)


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
