import time
import os
from typing import Any, Dict, List

import openai
import logging

from llm_interface.large_language_model import LargeLanguageModel
from prompt_compiler.data_structs.llm_response import LLMResponse

# Module-level logger, bound to the shared "global_logger" name rather than
# __name__, so this module logs through the same logger as its callers.
logger = logging.getLogger("global_logger")

class ChatGPT(LargeLanguageModel):
    """LargeLanguageModel backed by the OpenAI chat-completions endpoint.

    Prompts are sent as a two-message conversation: the first line of the
    prompt becomes the system message and the rest becomes the user message.
    """

    # Total number of request attempts before giving up.
    _MAX_ATTEMPTS = 6
    # Seconds to wait between consecutive failed attempts.
    _RETRY_DELAY_SECONDS = 1

    def __init__(self, model_name: str) -> None:
        """Create a client for the OpenAI chat model ``model_name``.

        The API key is read from the ``OPENAI_API_KEY`` environment variable
        and installed as the module-global ``openai.api_key``.
        """
        self._model_name = model_name
        self.OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
        if self.OPENAI_API_KEY is None:
            # Surface the misconfiguration early instead of only at the
            # first (failing) request.
            logger.warning("OPENAI_API_KEY is not set; OpenAI API requests may fail.")
        openai.api_key = self.OPENAI_API_KEY

    def get_id(self) -> str:
        """Return an identifier of the form ``chatgpt_<model_name>``."""
        return f"chatgpt_{self._model_name}"

    def _sample_completions(
            self,
            prompt: str,
            temperature: float,
            stop_token: str,
            max_tokens: int,
            freq_penalty: float,
            num_completions: int = 1,
            sleep_time: int = 10) -> List[LLMResponse]:
        """Query the chat model and return one LLMResponse per completion.

        The prompt is split at its FIRST newline: the text before it is sent
        as the system message and everything after it as the user message
        (an empty string if the prompt contains no newline).

        Requests that fail with a rate-limit, connection, or generic API
        error are retried up to ``_MAX_ATTEMPTS`` times in total, waiting
        ``_RETRY_DELAY_SECONDS`` between attempts.

        Args:
            prompt: System prompt and user prompt separated by a newline.
            temperature: Sampling temperature passed to the API.
            stop_token: Stop sequence passed to the API.
            max_tokens: Maximum number of tokens to generate per completion.
            freq_penalty: Frequency penalty passed to the API.
            num_completions: Number of completions to request (``n``).
            sleep_time: Currently unused; kept for interface compatibility
                (a post-request throttle that used it was disabled).

        Returns:
            A list with one LLMResponse per returned choice.

        Raises:
            RuntimeError: If every attempt fails with a retryable OpenAI
                error (the last such error is chained as the cause), or if
                the API returns a different number of choices than requested.
        """
        sys_prompt, _, user_prompt = prompt.partition("\n")
        messages = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_prompt},
        ]

        response = None
        last_error = None
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                response = openai.ChatCompletion.create(
                    model=self._model_name,
                    messages=messages,
                    temperature=temperature,
                    stop=stop_token,
                    max_tokens=max_tokens,
                    frequency_penalty=freq_penalty,
                    n=num_completions)
                # Successfully queried, so stop retrying.
                break
            except (openai.error.RateLimitError,
                    openai.error.APIConnectionError,
                    openai.error.APIError) as err:
                last_error = err
                logger.warning("OpenAI request failed (attempt %d/%d): %s",
                               attempt, self._MAX_ATTEMPTS, err)
                # Back off before retrying; sleeping after the final attempt
                # would only delay the error.
                if attempt < self._MAX_ATTEMPTS:
                    time.sleep(self._RETRY_DELAY_SECONDS)

        if response is None:
            raise RuntimeError("Failed to query OpenAI API.") from last_error

        choices = response["choices"]
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(choices) != num_completions:
            raise RuntimeError(
                f"Expected {num_completions} completions from OpenAI API, "
                f"got {len(choices)}.")
        return [
            self._raw_to_llm_response(choice, prompt, temperature,
                                      stop_token, num_completions)
            for choice in choices
        ]

    @staticmethod
    def _raw_to_llm_response(raw_response: Dict[str, Any],
                             prompt: str,
                             temperature: float,
                             stop_token: str,
                             num_completions: int) -> LLMResponse:
        """Wrap one element of the API's ``choices`` list in an LLMResponse.

        The completion text is taken from ``raw_response["message"]["content"]``.
        The sampling settings are recorded in ``prompt_info``, and a shallow
        copy of the raw choice is kept in ``other_info``.
        """
        text = raw_response["message"]["content"]
        prompt_info = {
            "temperature": temperature,
            "num_completions": num_completions,
            "stop_token": stop_token,
        }
        return LLMResponse(prompt,
                           text,
                           prompt_info=prompt_info,
                           other_info=raw_response.copy())