import glob
import json
import os
import random
import re
import subprocess
import sys

import numpy as np
import tiktoken
import yaml
from dotenv import load_dotenv
from openai.types.chat.chat_completion_message import (
    ChatCompletionMessage,
)

from data_generation.utils import (
    calc_openai_cost,
    check_overwrite,
    get_step_save_file,
    print_generation,
    reorder_response_file,
)
from data_generation.wikihow.generate_data import parse_step1_result

# Load environment variables from a .env file, if present. The API keys and
# VIJAY_* settings read in main() presumably come from there.
load_dotenv()

# Fix the seeds of both RNGs so repeated runs are reproducible.
np.random.seed(42)
random.seed(42)


def main(
    data_file: str,
    model: str,
    rate_limit: int,
    token_limit: int,
    max_tokens: int,
    prompt_file: str,
    prompt_version: str,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> None:
    """Build step-2 error-recovery chat requests and run them in parallel.

    Reads the step-1 response file derived from ``data_file``. For every
    (API call, task summary) pair, it fills the prompt template and writes one
    JSON chat-completion body per line to the step-2 request file. It then
    records the example indexes and runs
    ``llms/providers/openai_request_parallel.py`` to send the requests.
    Afterwards the response and request files are passed to
    ``reorder_response_file``, and the request file is deleted.

    Args:
        data_file: Source data file. Used to derive the step-1 response file
            and the step-2 request/response/index file paths.
        model: Model name placed in every request body. Names starting with
            ``"vijay"`` go to the Azure endpoint built from the ``VIJAY_*``
            environment variables. Anything else goes to the public OpenAI
            endpoint with ``OPENAI_API_KEY``.
        rate_limit: Max requests per minute passed to the parallel runner.
        token_limit: Max tokens per minute passed to the parallel runner.
        max_tokens: ``max_tokens`` for each completion request.
        prompt_file: YAML file holding the prompt versions.
        prompt_version: Key in ``prompt_file`` selecting the prompt. The entry
            must provide ``system`` and ``user_message`` (the latter containing
            the ``__task__`` / ``__trajectory__`` placeholders).
        temperature: Sampling temperature for each request.
        top_p: Nucleus-sampling ``top_p`` for each request.

    Raises:
        KeyError: If a required prompt key or environment variable is missing.
        subprocess.CalledProcessError: If the parallel request runner exits
            with a non-zero status. The request file is kept in that case so
            the run can be inspected or retried.
    """
    # parse the results from step 1
    step1_result_file = get_step_save_file(data_file, 1, 1, "response")
    api_calls, task_summaries = parse_step1_result(step1_result_file)

    with open(prompt_file, "r") as f:
        prompt = yaml.safe_load(f)[prompt_version]
    system_prompt = prompt["system"]
    user_prompt = prompt["user_message"]

    request_file = get_step_save_file(data_file, 2, 1, "request")
    save_file = get_step_save_file(data_file, 2, 1, "response_er")
    index_file = get_step_save_file(data_file, 2, 1, "index")
    indexes = []

    # Both placeholders are substituted in a single pass, so a literal
    # "__trajectory__" inside a task summary is not re-expanded. The function
    # replacement also keeps backslashes in the inserted text literal.
    placeholder_re = re.compile(r"__task__|__trajectory__")

    # construct the request file: one JSON request body per line
    with open(request_file, "w") as f:
        for e_idx, (cur_api_call, cur_task) in enumerate(
            zip(api_calls, task_summaries)
        ):
            fills = {"__task__": cur_task, "__trajectory__": cur_api_call}
            user_content = placeholder_re.sub(
                lambda m: fills[m.group(0)], user_prompt
            )
            cur_body = {
                "model": model,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content},
                ],
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
            }
            f.write(json.dumps(cur_body) + "\n")
            indexes.append(e_idx)

    # save the indexes to a file
    with open(index_file, "w") as f:
        f.write(json.dumps(indexes))

    print(f"Number of examples: {len(indexes)}")

    check_overwrite(save_file)

    if model.startswith("vijay"):
        resource = os.environ["VIJAY_RESOURCE_NAME"]
        # NOTE(review): the resource name is also used as the deployment name
        # in the URL path; confirm this matches the Azure deployment.
        request_url = (
            f"https://{resource}.openai.azure.com/openai/deployments/"
            f"{resource}/chat/completions"
            f"?api-version={os.environ['VIJAY_VERSION']}"
        )
        api_key = os.environ["VIJAY_API_KEY"]
    else:
        request_url = "https://api.openai.com/v1/chat/completions"
        api_key = os.environ["OPENAI_API_KEY"]

    # sys.executable runs the runner under this same interpreter/virtualenv
    # rather than whatever "python" is first on PATH. The script path is
    # relative, so this presumably must be launched from the repo root.
    # NOTE(review): the API key is passed on the command line, where it is
    # visible in process listings.
    # check=True stops here on failure, before the request file is deleted.
    subprocess.run(
        [
            sys.executable,
            "llms/providers/openai_request_parallel.py",
            "--request_url",
            request_url,
            "--api_key",
            api_key,
            "--requests_filepath",
            request_file,
            "--save_filepath",
            save_file,
            "--max_requests_per_minute",
            str(rate_limit),
            "--max_tokens_per_minute",
            str(token_limit),
        ],
        check=True,
    )
    reorder_response_file(save_file, request_file)

    os.remove(request_file)

    os.remove(request_file)


if __name__ == "__main__":
    # Optional CLI argument: "1" prints the existing step-2 responses for each
    # file before main() runs for it (the default "0" skips the printing).
    # NOTE(review): with "1" the generation still runs afterwards; confirm
    # whether print-only mode was intended.
    costs = []
    arg = sys.argv[1] if len(sys.argv) > 1 else "0"
    for file_idx in ["18"]:
        data_file = f"data/wikihow/wikihow_digital_sep/wh_{file_idx}.jsonl"
        if arg == "1":
            print_generation(
                get_step_save_file(data_file, 2, 1, "response_er")
            )
        main(
            data_file,
            model="gpt-4-turbo-2024-04-09",
            # Alternative: model="vijay-gpt-4" routes through the Azure endpoint.
            rate_limit=1500,
            token_limit=200_000,
            temperature=1.0,
            max_tokens=4096,
            prompt_file="data_generation/error_recovery/prompts/prompt.yaml",
            prompt_version="v2",
        )
        costs.append(
            calc_openai_cost(
                get_step_save_file(data_file, 2, 1, "response_er")
            )
        )
    # Report what calc_openai_cost returned for each processed file.
    print(f"Costs: {costs}")
