import asyncio
import json
import logging
import os
import tempfile
from typing import Any

import aiolimiter
import openai
from aiohttp import ClientSession
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio

# Configure the root logger at INFO level as soon as this module is imported.
# Per the stdlib, basicConfig does nothing if the root logger already has
# handlers, so an application that configured logging first keeps its setup.
# NOTE(review): import-time side effect that affects every importer; consider
# moving this to the program's entry point.
logging.basicConfig(level=logging.INFO)


async def openai_chat_acompletion(
    messages: list[dict[str, str]],
    tools: list[dict[str, Any]],
    model: str,
    max_tokens: int,
    temperature: float,
    limiter: aiolimiter.AsyncLimiter,
    top_p: float = 1.0,
    max_attempts: int = 3,
    retry_delay: float = 10.0,
) -> openai.types.chat.chat_completion_message.ChatCompletionMessage:
    """Request a single chat completion, retrying on any exception.

    Args:
        messages: Conversation forwarded verbatim as ``messages``.
        tools: Tool definitions. When falsy (empty list or None), the request
            is sent without ``tools``/``tool_choice``. Otherwise
            ``tool_choice="auto"`` is used.
        model: Model name forwarded to ``chat.completions.create``.
        max_tokens: Forwarded to ``chat.completions.create``.
        temperature: Forwarded to ``chat.completions.create``.
        limiter: Shared rate limiter. One unit is acquired per HTTP attempt,
            so retries are rate-limited too.
        top_p: Forwarded to ``chat.completions.create``.
        max_attempts: Total number of request attempts (default 3).
        retry_delay: Seconds to wait between failed attempts (default 10).
            No wait happens after the last attempt.

    Returns:
        The message of the first choice. Returns ``{}`` if every attempt
        failed; this is kept for backward compatibility, and callers test the
        result for truthiness.
    """
    request_kwargs: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }
    if tools:
        request_kwargs["tools"] = tools
        request_kwargs["tool_choice"] = "auto"

    # Use the client as an async context manager so its HTTP connection pool
    # is closed when we are done, instead of leaking one pool per call.
    async with AsyncOpenAI() as client:
        for attempt in range(1, max_attempts + 1):
            # Acquire capacity per attempt so retries also respect the limit.
            async with limiter:
                try:
                    response = await client.chat.completions.create(
                        **request_kwargs
                    )
                    # Inside the try: a response without choices is retried too.
                    return response.choices[0].message
                except Exception:
                    logging.exception(
                        "Chat completion attempt %d/%d failed",
                        attempt,
                        max_attempts,
                    )
            if attempt < max_attempts:
                await asyncio.sleep(retry_delay)

    return {}


async def _none_on_error(coro):
    """Await *coro*; if it raises, log the traceback and return None instead."""
    try:
        return await coro
    except Exception:
        logging.exception("Chat completion task failed")
        return None


async def openai_chat_acompletion_batch(
    batch: list[list[dict[str, str]]],
    tools: list[dict[str, Any]],
    model: str,
    rate_limit: int,
    max_tokens: int,
    temperature: float,
    top_p: float = 1.0,
    tqdm_disable: bool = True,
) -> list[Any]:
    """Run one chat completion per conversation in *batch* concurrently.

    Args:
        batch: Conversations; one request is issued for each.
        tools: Tool definitions shared by every request (may be empty).
        model: Forwarded to ``openai_chat_acompletion``.
        rate_limit: ``max_rate`` of the shared ``aiolimiter.AsyncLimiter``.
            This is per aiolimiter's default 60-second period, because no
            ``time_period`` is passed.
        max_tokens: Forwarded to ``openai_chat_acompletion``.
        temperature: Forwarded to ``openai_chat_acompletion``.
        top_p: Forwarded to ``openai_chat_acompletion``.
        tqdm_disable: Hide the tqdm progress bar when True.

    Returns:
        One entry per conversation, in input order. Each entry is the value
        returned by ``openai_chat_acompletion``, or None if that task raised.
        A failing task no longer discards the results of the others.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is not set in the environment.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )

    # NOTE(review): with openai>=1.0 these module-level settings configure only
    # the module-level client. `AsyncOpenAI()` instances read OPENAI_API_KEY
    # from the environment themselves. Kept for compatibility; confirm whether
    # they are still needed.
    openai.api_key = os.environ["OPENAI_API_KEY"]
    openai.organization = os.environ.get("OPENAI_ORGANIZATION", "")

    limiter = aiolimiter.AsyncLimiter(rate_limit)
    tasks = [
        _none_on_error(
            openai_chat_acompletion(
                model=model,
                tools=tools,
                messages=messages,
                limiter=limiter,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )
        )
        for messages in batch
    ]
    return await tqdm_asyncio.gather(*tasks, disable=tqdm_disable)


def _to_jsonable(obj: Any) -> Any:
    """``json`` ``default=`` hook: serialize objects exposing ``model_dump()``.

    Any other object raises ``TypeError``, as the ``json`` module expects.
    """
    model_dump = getattr(obj, "model_dump", None)
    if callable(model_dump):
        return model_dump()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _save_intermediate_results(responses: dict[int, Any]) -> None:
    """Best-effort JSON dump of *responses* to a new, non-deleted temp file.

    Integer keys become JSON strings. Failures are logged, never raised, so a
    checkpoint problem cannot abort the batch run.
    """
    try:
        # Serialize first so an unserializable value leaves no partial file.
        payload = json.dumps(responses, default=_to_jsonable)
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as temp_file:
            temp_file.write(payload)
    except (OSError, TypeError, ValueError):
        logging.exception("Failed to save intermediate results")
        return
    logging.info("Save intermediate results to %s", temp_file.name)


def safe_batch_chat_completion(
    batch: list[list[dict[str, str]]],
    model: str,
    rate_limit: int,
    max_tokens: int,
    temperature: float,
    top_p: float = 1.0,
    tqdm_disable: bool = True,
    minimal_percent: float = 1.0,
    tools: "list[dict[str, Any]] | None" = None,
    max_retries: int = 5,
) -> list[Any]:
    """Some completion may fail due to various reason, rerun until every example has the response.

    Each round re-submits only the conversations that still lack a truthy
    response. Rounds stop early once the answered fraction reaches
    *minimal_percent* or nothing is pending, and at most *max_retries* rounds
    are run. Every second round, the current responses are written to a
    temporary JSON file as a checkpoint.

    Args:
        batch: Conversations to complete.
        model: Forwarded to ``openai_chat_acompletion_batch``.
        rate_limit: Forwarded to ``openai_chat_acompletion_batch``.
        max_tokens: Forwarded to ``openai_chat_acompletion_batch``.
        temperature: Forwarded to ``openai_chat_acompletion_batch``.
        top_p: Forwarded to ``openai_chat_acompletion_batch``.
        tqdm_disable: Forwarded to ``openai_chat_acompletion_batch``.
        minimal_percent: Fraction in [0, 1] of conversations that must be
            answered before stopping early (despite the name, not 0-100).
        tools: Tool definitions for every request. None means no tools.
        max_retries: Maximum number of submission rounds.

    Returns:
        One entry per conversation, in input order. Conversations never
        answered are None.

    Note:
        Uses ``asyncio.run``, so it must not be called from a running event
        loop.
    """
    if tools is None:
        tools = []  # avoid a shared mutable default argument
    total = len(batch)
    if total == 0:
        return []  # the completion ratio below would divide by zero

    responses: dict[int, Any] = {idx: None for idx in range(total)}
    save_per_n_tries = 2

    for tries in range(1, max_retries + 1):
        # Failed calls yield falsy values ({} or None), so they are retried.
        pending = [idx for idx, response in responses.items() if not response]
        if not pending or 1 - (len(pending) / total) >= minimal_percent:
            break

        generations = asyncio.run(
            openai_chat_acompletion_batch(
                batch=[batch[idx] for idx in pending],
                model=model,
                rate_limit=rate_limit,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                tqdm_disable=tqdm_disable,
                tools=tools,
            )
        )
        for idx, generation in zip(pending, generations):
            if generation:
                responses[idx] = generation

        if tries % save_per_n_tries == 0:
            _save_intermediate_results(responses)

    return [responses[idx] for idx in range(total)]
