from typing import List, Type, Optional, Callable, Any
import time
from functools import wraps
import requests as _requests
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import json

def retry_on_specific_exceptions(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
    """Retry a function on specific exceptions with exponential backoff.

    Intended for an LLM provider's rate-limit errors. For example, to use
    with OpenAI::

        from openai import RateLimitError

        # Recommend specifying max_retries to avoid infinite loops!
        @retry_on_specific_exceptions([RateLimitError], max_retries=3)
        def completion(...):
            # Wrap OpenAI completion function here
            ...

    Args:
        on_exceptions: Exception types that trigger a retry. Any other
            exception propagates immediately.
        max_retries: Maximum total number of calls to the wrapped function
            (the first call counts). ``None`` retries forever. When every
            attempt fails, the last exception is re-raised.
        backoff_time: Seconds to sleep after the first failure.
        backoff_multiplier: Factor applied to the sleep time after each
            failure.
        on_exception_callback: Called as ``callback(exc, sleep_time)`` just
            before each backoff sleep (not after the final failed attempt).

    Returns:
        A decorator that wraps a function with the retry loop.
    """
    retry_on = tuple(on_exceptions)

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while max_retries is None or attempt < max_retries:
                try:
                    return func(*args, **kwargs)
                except retry_on as e:
                    attempt += 1
                    if max_retries is not None and attempt >= max_retries:
                        # Out of attempts: surface the failure instead of
                        # sleeping once more and silently returning None.
                        raise
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier
            # Only reachable when max_retries <= 0: func is never called.
            return None

        return wrapper

    return decorator

def together_completion(**kwargs):
    """POST to the Together API, retrying on transport-level errors.

    All keyword arguments are forwarded to ``requests.post`` (``url``,
    ``headers``, ``json``, ...). Any ``requests.exceptions.RequestException``
    is printed with its traceback, and the request is retried with
    exponential back-off. There is no limit on the number of attempts.

    If the caller does not pass ``timeout``, a 60-second default is applied.
    A stalled connection then raises ``requests.exceptions.Timeout``, which is
    retried, instead of blocking forever.

    HTTP error status codes (4xx/5xx) do not raise a ``RequestException``, so
    they are NOT retried here. The returned ``requests.Response`` may carry an
    error payload, and the caller must validate it.

    Returns:
        The ``requests.Response`` of the first request that did not raise.
    """
    # requests has no default timeout; without one, a hung socket would
    # block this call indefinitely and the retry logic would never run.
    kwargs.setdefault("timeout", 60)

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        import traceback

        traceback.print_exc()
        print(f"Retrying in {sleep_time:.1f}s")

    @retry_on_specific_exceptions(
        on_exceptions=[_requests.exceptions.RequestException],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        return _requests.post(**kwargs)

    return completion()

_TOGETHER_COMPLETIONS_URL = "https://api.together.xyz/v1/completions/"


def _context_logprob(model, prompt, api_key, max_attempts):
    """Return the summed echo logprob of *prompt* (first token excluded), or None.

    Queries the Together completions endpoint with ``echo=True`` and retries
    up to *max_attempts* times when the response cannot be parsed (for
    example, an error payload instead of logprobs). Returns ``None`` if every
    attempt fails.
    """
    for _ in range(max_attempts):
        response = together_completion(
            url=_TOGETHER_COMPLETIONS_URL,
            headers={"Authorization": "Bearer " + api_key},
            json={"prompt": prompt, "model": model, "logprobs": 1, 'echo': True, "max_tokens": 1},
        )
        try:
            resp = response.json()
            # Skip the first token. Presumably the API reports no logprob
            # (null) for it; verify against the Together API docs.
            return sum(resp['prompt'][0]['logprobs']["token_logprobs"][1:])
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # Not JSON, or not the expected shape: query again.
            print(e)
    return None


def fix_acc_norm(model, location, location2, benchmark, max_attempts=100):
    """Recompute ``acc_norm`` in a samples file and patch the results file.

    Steps for each unique ``doc_id`` in the samples JSON at *location*:

    1. Fetch the echo logprob of ``arguments[0][0]``, excluding the first
       token. NOTE(review): this assumes every choice in the row shares that
       same context string; confirm this for the benchmark.
    2. Subtract that logprob from each response score in ``resps`` and
       ``filtered_resps``.
    3. Divide each adjusted ``filtered_resps`` score by the character length
       of ``arguments[k][1]``.
    4. Set ``acc_norm`` to whether the argmax of those normalized scores
       equals ``target``.

    The samples are written back to *location* (overwriting it). The mean
    ``acc_norm`` is stored in *location2* at
    ``["results"][benchmark]["acc_norm,none"]``.

    The correction is not idempotent: running it twice on the same file
    subtracts the context logprob twice. Callers must ensure it runs at most
    once per file.

    Args:
        model: Together model name sent with each request.
        location: Path to the samples JSON (read with ``lines=False``).
        location2: Path to the results JSON to update.
        benchmark: Key under ``results`` in *location2*.
        max_attempts: Parse attempts per row before giving up on it. A row
            that never gets a valid response is left unchanged, with a
            warning printed.

    Raises:
        RuntimeError: If ``TOGETHER_API_KEY`` is not set.
        FileNotFoundError / KeyError: If the results file or its
            ``results[benchmark]`` entry is missing. This is checked before
            anything is written.
    """
    api_key = os.getenv("TOGETHER_API_KEY")
    if not api_key:
        raise RuntimeError("TOGETHER_API_KEY environment variable is not set")

    # Load and validate the results file before touching the samples file.
    # Otherwise a failure here would leave the samples corrected on disk
    # without the results updated, and a rerun would correct them twice.
    with open(location2) as f:
        results = json.load(f)
    benchmark_results = results['results'][benchmark]

    df = pd.read_json(location, lines=False)
    # drop_duplicates keeps the original index labels. Reset them so that
    # row position i and label i refer to the same row below.
    df = df.drop_duplicates(subset=['doc_id']).reset_index(drop=True)

    for i in tqdm(range(len(df))):
        arguments = df.at[i, 'arguments']
        logprob = _context_logprob(model, arguments[0][0], api_key, max_attempts)
        if logprob is None:
            print(f"Warning: no valid API response for row {i} after "
                  f"{max_attempts} attempts; leaving it unchanged")
            continue

        # Mutate the nested lists in place, exactly once per row. The
        # DataFrame holds references to these list objects.
        resps = df.at[i, 'resps']
        filtered_resps = df.at[i, 'filtered_resps']
        adjusted = []
        for k in range(len(resps)):
            resps[k][0][0] -= logprob
            filtered_resps[k][0] -= logprob
            adjusted.append(filtered_resps[k][0])

        normalized = [score / len(arg[1]) for score, arg in zip(adjusted, arguments)]
        df.loc[i, 'acc_norm'] = float(np.argmax(normalized) == df.at[i, 'target'])

    df.to_json(location, lines=False, indent=4)
    benchmark_results['acc_norm,none'] = float(np.mean(df['acc_norm']))
    with open(location2, 'w') as f:
        json.dump(results, f, indent=4)


# Script entry: locate the lm-eval output files for --model/--benchmark under
# base_path and apply fix_acc_norm once. An is_fixed.txt marker prevents a
# second (double-correcting) run.
base_path = "output"

import argparse
arguments = argparse.ArgumentParser()
arguments.add_argument("--model", type=str, required=True,
                       help="Model name, e.g. org/model (used in output paths and API calls)")
arguments.add_argument("--benchmark", type=str, required=True,
                       help="Benchmark name (output sub-directory and results key)")

args = arguments.parse_args()

model_not_slash = args.model.replace('/', '__')
output_dir = f"{base_path}/{args.model}/{args.benchmark}"
location = f"{output_dir}/model__{model_not_slash}_{args.benchmark}.jsonl"
if not os.path.exists(location):
    # Fall back to the alternative samples-file naming scheme.
    location = f"{output_dir}/pretrained__{model_not_slash}_{args.benchmark}.jsonl"

location2 = f"{output_dir}/results.json"

fix_file_location = f"{output_dir}/is_fixed.txt"

if os.path.exists(fix_file_location):
    print("Already fixed")
elif not os.path.exists(location):
    print(f"Samples file does not exist: {location}")
elif not os.path.exists(location2):
    # Check before starting: fix_acc_norm rewrites the samples file and
    # needs the results file too.
    print(f"Results file does not exist: {location2}")
else:
    fix_acc_norm(args.model, location, location2, args.benchmark)
    with open(fix_file_location, 'w') as f:
        f.write("Fixed")