import json
import time
import os
import random
from typing import List, Union, Dict, Tuple
from dataclasses import dataclass
from dataclasses_json import dataclass_json
import numpy as np
from tqdm import trange, tqdm
from scipy import stats
from propose_d5_descriptions import (
    propose_descriptions,
    D5ProposerResponse as ProposerResponse,
)
from transformers import GPT2Tokenizer
from utils import get_context_length, get_avg_length
from validate_descriptions import (
    validate_descriptions,
    Validator,
    GPTValidator,
    D5Validator,
    get_validator_by_name,
)
from d5_problem import D5Problem as Problem
import numpy as np
from sklearn.linear_model import LogisticRegression
from copy import deepcopy


def train_logistic_regression_with_l1_and_non_negative_coef(X, Y, K):
    """
    Fit an L1-penalised logistic regression whose number of positive
    coefficients is steered towards K.

    The inverse regularisation strength ``C`` is searched by geometric
    bisection between 1e-2 and 100 for at most 10 fits. Note that the sign of
    the coefficients is not constrained during fitting; negative coefficients
    are simply ignored when counting and when selecting features.

    Parameters:
    X (numpy array): The feature matrix.
    Y (numpy array): The target vector.
    K (int): The desired number of positive coefficients.

    Returns:
    Tuple[LogisticRegression, List[int]]: The last fitted model and the
    indices of at most K features with a positive coefficient, ordered from
    the largest coefficient to the smallest.
    """

    lower_c, upper_c = 1e-2, 100
    threshold = 1e-4
    search_steps = 10

    for _ in range(search_steps):
        # Geometric midpoint: the search interval spans several orders of magnitude.
        C = (lower_c * upper_c) ** 0.5
        model = LogisticRegression(penalty="l1", C=C, solver="saga", max_iter=5000)
        model.fit(X, Y)

        coef = model.coef_
        # Only coefficients clearly above zero count towards the target K.
        num_positive = np.count_nonzero(coef * (coef > threshold))
        if num_positive == K:
            break

        if num_positive < K:
            # Too few features survived: weaken the penalty (larger C).
            lower_c = C
        else:
            # Too many features survived: strengthen the penalty (smaller C).
            upper_c = C

        if abs(upper_c - lower_c) < threshold:
            break

    weights = coef[0]
    positive_idxes = [idx for idx, weight in enumerate(weights) if weight > 0]
    positive_idxes.sort(key=lambda idx: weights[idx], reverse=True)
    return model, positive_idxes[:K]


def add_noise_to_scores(scores: List[float], noise: float = 1e-4) -> List[float]:
    """
    Jitter every score by an independent uniform perturbation.

    Parameters
    ----------
    scores : List[float]
        The scores to perturb.
    noise : float
        Half-width of the perturbation; each offset is drawn uniformly
        from ``[-noise, noise]``.

    Returns
    -------
    List[float]
        A new list with one perturbed value per input score, in the same order.
    """
    noisy_scores = []
    for value in scores:
        jitter = random.uniform(-noise, noise)
        noisy_scores.append(value + jitter)
    return noisy_scores


# Absolute amount subtracted from the proposer model's context length before
# the in-context samples are budgeted (see get_max_num_samples_in_proposer).
CORPUS_PAIR_OVERHEAD = 1024
# Fraction of the remaining context length that is additionally held back as a
# safety buffer when budgeting in-context samples.
CORPUS_BUFFER_FRACTION = 0.25


@dataclass_json
@dataclass
class HypothesisInfo:
    """Validation record for a single proposed hypothesis (description)."""

    # The description text being validated.
    hypothesis: str
    # Mean validator score on corpus A minus mean validator score on corpus B.
    V_prime: float
    # Maps each validated text to the validator score it received.
    text2score: Dict[str, float]
    # Validator scores collected for texts from corpus A.
    scores_a: List[float]
    # Validator scores collected for texts from corpus B.
    scores_b: List[float]
    # p-value of the t-test comparing scores_a against scores_b.
    statistical_significance: float


@dataclass_json
@dataclass
class D5Result:
    """Output of one run of `d5`: what was proposed and how it validated."""

    # One proposer response per proposer round.
    proposer_results: List[ProposerResponse]
    # One validation record per description returned by the proposer.
    hypotheses_info: List[HypothesisInfo]


@dataclass_json
@dataclass
class IterativeD5Result:
    """Snapshot of the state of `iterative_d5` after a round."""

    # Index of the round at which this snapshot was taken.
    iteration: int
    # The D5Result of every round run so far.
    d5_results: List[D5Result]
    # Records of the selected hypotheses, scored on all sampled texts.
    full_hypotheses_info: List[HypothesisInfo]
    # Per-text residuals (target minus predicted probability) for corpus A.
    texts_a_residuals: List[float]
    # Per-text residuals for corpus B (sign-flipped in paired mode).
    texts_b_residuals: List[float]


def get_max_num_samples_in_proposer(problem: Problem, proposer_model: str) -> int:
    """
    Compute how many in-context sample pairs fit into the proposer prompt.

    The budget is the model's context length minus a fixed overhead
    (``CORPUS_PAIR_OVERHEAD``), reduced by a relative safety buffer
    (``CORPUS_BUFFER_FRACTION``). It is divided by the combined average
    length of one text from corpus A plus one text from corpus B.

    Parameters
    ----------
    problem : Problem
        The D5 problem to solve.
    proposer_model : str
        The model used to propose descriptions.

    Returns
    -------
    int
        The maximal number of in-context samples (truncated towards zero).
    """
    length_after_overhead = get_context_length(proposer_model) - CORPUS_PAIR_OVERHEAD
    length_budget = length_after_overhead * (1 - CORPUS_BUFFER_FRACTION)

    # Average cost of adding one more sample from each corpus to the prompt.
    avg_pair_length = get_avg_length(problem.texts_a) + get_avg_length(
        problem.texts_b
    )
    return int(length_budget / avg_pair_length)


def calculate_diff_w_significance(
    scores_a: List[float], scores_b: List[float]
) -> Tuple[float, float]:
    """
    Compare two score samples by their means and with a t-test.

    Parameters
    ----------
    scores_a : List[float]
        Validator scores for the first corpus.
    scores_b : List[float]
        Validator scores for the second corpus.

    Returns
    -------
    Tuple[float, float]
        ``mean(scores_a) - mean(scores_b)`` and the p-value of an independent
        two-sample t-test between the two samples.
    """
    mean_a = np.mean(scores_a)
    mean_b = np.mean(scores_b)
    _, p_value = stats.ttest_ind(scores_a, scores_b)
    return mean_a - mean_b, p_value


def d5(
    problem: Problem,
    num_descriptions_per_prompt: int,
    validator: Validator,
    d5_problem_name: Union[str, None] = None,
    proposer_num_rounds: int = 3,
    proposer_model: str = "gpt-4",
    early_stopping_significance_threshold: float = 5e-4,
    minimal_samples_to_validate: int = 128,
    max_samples_to_validate: int = 1024,
    template_name: str = "orig",
) -> D5Result:
    """
    The D5 algorithm.

    Proposes descriptions with the proposer model, then validates every
    description on batches of texts from both corpora, tracking the mean
    score difference and its t-test p-value. Arguments, proposer output and
    intermediate results are written to a fresh directory under
    ``experiments/``.

    Parameters
    ----------
    problem : Problem
        The D5 problem to solve.
    num_descriptions_per_prompt : int
        The number of descriptions to propose per prompt.
    validator : Validator
        The validator to use.
    d5_problem_name : Union[str, None], optional
        The name of the D5 problem, by default None. Used for saving the results.
    proposer_num_rounds : int, optional
        The number of rounds to run the proposer, by default 3
    proposer_model : str, optional
        The model used to propose descriptions, by default "gpt-4"
    early_stopping_significance_threshold : float, optional
        The significance threshold for early stopping, by default 5e-4. I.e. if the p-value of the t-test is below this threshold, the algorithm will accept the hypothesis.
    minimal_samples_to_validate : int, optional
        The minimal number of samples to validate, by default 128. Early stopping is only considered once the batch offset has reached this number.
    max_samples_to_validate : int, optional
        The maximal number of samples to validate, by default 1024. The algorithm will validate at most this number of samples (rounded up to a whole batch) per corpus. This ensures that the validation process does not take too long.
    template_name : str, optional
        The name of the template used to generate the prompts, by default "orig" but can also be "detailed" to propose more detailed descriptions.

    Returns
    -------
    D5Result
        The results of the D5 algorithm. Contains the proposer results and the hypotheses info.
        It is returned (and saved) even when no description was proposed or no text could be validated.
    """

    # Propose descriptions
    max_num_samples_in_proposer = get_max_num_samples_in_proposer(
        problem, proposer_model
    )

    d5_output_dir = f"experiments/d5-{d5_problem_name}-{validator.model_name}-prop_n_samples={max_num_samples_in_proposer}-prop_model={proposer_model}-n_desc_per_prompt={num_descriptions_per_prompt}-n_rounds={proposer_num_rounds}-template_name={template_name}-time={int(time.time())}"
    os.makedirs(d5_output_dir, exist_ok=True)

    saved_argument_dict = {
        "problem": problem.to_dict(),
        "num_descriptions_per_prompt": num_descriptions_per_prompt,
        "validator_name": validator.model_name,
        "proposer_num_rounds": proposer_num_rounds,
        "proposer_model": proposer_model,
        "early_stopping_significance_threshold": early_stopping_significance_threshold,
        "minimal_samples_to_validate": minimal_samples_to_validate,
        "max_samples_to_validate": max_samples_to_validate,
        "template_name": template_name,
    }

    with open(f"{d5_output_dir}/arguments.json", "w") as f:
        f.write(json.dumps(saved_argument_dict))

    all_proposer_results = []
    for _ in trange(proposer_num_rounds, desc="Proposer Rounds"):
        proposer_result = propose_descriptions(
            problem=problem,
            num_descriptions_per_prompt=num_descriptions_per_prompt,
            model=proposer_model,
            num_samples=max_num_samples_in_proposer,
            template_name=template_name,
            example_descriptions=problem.example_descriptions,
        )
        all_proposer_results.append(proposer_result)

    # save the descriptions
    with open(f"{d5_output_dir}/proposer_results.json", "w") as f:
        f.write(
            json.dumps(
                [proposer_result.to_dict() for proposer_result in all_proposer_results]
            )
        )

    # every proposed description starts with an empty validation record
    all_hypotheses_info = [
        HypothesisInfo(
            hypothesis=d,
            V_prime=0.0,
            text2score={},
            scores_a=[],
            scores_b=[],
            statistical_significance=1.0,
        )
        for proposer_result in all_proposer_results
        for d in proposer_result.descriptions
    ]

    print([x.hypothesis for x in all_hypotheses_info])

    # One shared random order over positions; it spans the larger corpus, so
    # positions beyond the end of the smaller corpus are skipped for it below.
    text_random_order = list(range(max(len(problem.texts_a), len(problem.texts_b))))
    random.shuffle(text_random_order)

    # validate descriptions by validating blocks of 32 text samples from each corpus
    # and then calculating the difference in scores between the two corpora
    # and the statistical significance of the difference
    batch_size = 32
    result_save_path = f"{d5_output_dir}/results.json"

    # The result object holds references to the lists mutated below, so it is
    # built once up front; this also guarantees it exists at return time.
    d5_result = D5Result(
        proposer_results=all_proposer_results,
        hypotheses_info=all_hypotheses_info,
    )

    def _save_result() -> None:
        with open(result_save_path, "w") as f:
            f.write(d5_result.to_json())

    for hypothesis_info in tqdm(all_hypotheses_info, desc="Validating Descriptions"):
        hypothesis = hypothesis_info.hypothesis
        for j in range(0, max_samples_to_validate, batch_size):
            text_index_to_validate = text_random_order[j : j + batch_size]

            # early stopping if the statistical significance is below the threshold
            if (
                hypothesis_info.statistical_significance
                < early_stopping_significance_threshold
                and j >= minimal_samples_to_validate
            ):
                print("Early stopping for hypothesis", hypothesis)
                break

            # both corpora are exhausted: nothing left to validate
            if not text_index_to_validate:
                break

            # validate descriptions for corpus a and b
            for corpus_name in ("a", "b"):
                texts = getattr(problem, f"texts_{corpus_name}")
                # keep only positions that exist in this corpus (the corpora
                # may differ in size)
                text_batch = [texts[i] for i in text_index_to_validate if i < len(texts)]
                if not text_batch:
                    continue

                scores = [
                    float(x[0])
                    for x in validate_descriptions(
                        descriptions=[hypothesis],
                        texts=text_batch,
                        validator=validator,
                    )
                ]
                for text, score in zip(text_batch, scores):
                    hypothesis_info.text2score[text] = score

                if corpus_name == "a":
                    hypothesis_info.scores_a.extend(scores)
                else:
                    hypothesis_info.scores_b.extend(scores)

            # calculate the difference in scores between the two corpora
            (
                hypothesis_info.V_prime,
                hypothesis_info.statistical_significance,
            ) = calculate_diff_w_significance(
                hypothesis_info.scores_a, hypothesis_info.scores_b
            )

            # checkpoint after every batch so partial progress survives a crash
            _save_result()

    # make sure results.json exists even if no batch was validated
    _save_result()

    return d5_result


def fold_design(X):
    """
    Turn a stacked design matrix into signed differences between its halves.

    The first ``len(X) // 2`` rows and the remaining rows are subtracted
    element-wise; the result stacks that difference on top of its negation.

    Parameters
    ----------
    X : numpy array
        Rows of the first half followed by rows of the second half.

    Returns
    -------
    numpy array
        ``[first - second, second - first]`` concatenated along axis 0.
    """
    midpoint = len(X) // 2
    first_half, second_half = X[:midpoint], X[midpoint:]
    delta = first_half - second_half
    return np.concatenate((delta, -delta), axis=0)


def iterative_d5(
    problem: Problem,
    num_descriptions_per_prompt: int,
    validator: Validator,
    sub_problem_size: int,
    d5_problem_name: str,
    num_rounds: int,
    max_samples_to_validate: int = 4096,
    template_name: str = "orig",
    proposer_model: str = "gpt-4",
    paired: bool = False,
) -> IterativeD5Result:
    """
    Iteratively run D5 on the texts that the current hypotheses explain worst.

    Each round (1) picks the texts with the largest residuals as a
    sub-problem, (2) runs `d5` on it to propose and score new descriptions,
    (3) selects descriptions with an L1-penalised logistic regression over old
    and new descriptions, (4) scores the newly selected descriptions on all
    sampled texts and (5) refits a logistic regression on all selected
    descriptions to update the residuals. A snapshot is written to
    ``results.json`` in the experiment directory after every round that
    selected at least one new description.

    Parameters
    ----------
    problem : Problem
        The D5 problem to solve. It is deep-copied; the caller's object is not modified.
    num_descriptions_per_prompt : int
        The number of descriptions to propose per prompt.
    validator : Validator
        The validator to use.
    sub_problem_size : int
        The number of texts per corpus in each round's sub-problem.
    d5_problem_name : str
        The name of the problem. Used for naming the experiment directories.
    num_rounds : int
        The number of propose/select/refit rounds.
    max_samples_to_validate : int, optional
        The maximal number of texts kept per corpus, by default 4096.
    template_name : str, optional
        The name of the proposer prompt template, by default "orig".
    proposer_model : str, optional
        The model used to propose descriptions, by default "gpt-4".
    paired : bool, optional
        If True, ``texts_a[i]`` and ``texts_b[i]`` are treated as a pair; both
        corpora must have the same length. By default False.

    Returns
    -------
    IterativeD5Result
        The latest snapshot: all per-round D5 results, the selected hypotheses
        scored on all sampled texts and the current residuals. If no round
        selected a new description, a snapshot of the initial state is
        saved and returned.
    """
    problem = deepcopy(problem)

    # sample a subset of the texts to validate
    if paired:
        assert len(problem.texts_a) == len(
            problem.texts_b
        ), "paired=True requires texts_a and texts_b to have the same length"
        # Draw ONE set of positions and apply it to both corpora; sampling the
        # corpora independently would break the a[i] <-> b[i] pairing.
        kept_indices = random.sample(
            range(len(problem.texts_a)),
            min(max_samples_to_validate, len(problem.texts_a)),
        )
        sampled_a = [problem.texts_a[i] for i in kept_indices]
        sampled_b = [problem.texts_b[i] for i in kept_indices]
        problem.texts_a, problem.texts_b = sampled_a, sampled_b
    else:
        problem.texts_a = random.sample(
            problem.texts_a, min(max_samples_to_validate, len(problem.texts_a))
        )
        problem.texts_b = random.sample(
            problem.texts_b, min(max_samples_to_validate, len(problem.texts_b))
        )

    texts_a = problem.texts_a
    texts_b = problem.texts_b

    # add noise to the scores to avoid ties
    texts_a_residuals = np.array(add_noise_to_scores([0.5 for _ in texts_a]))
    texts_b_residuals = np.array(add_noise_to_scores([-0.5 for _ in texts_b]))

    # description -> HypothesisInfo scored on ALL sampled texts
    full_computed_hypotheses_info = {}
    all_d5_results = []
    args = {
        "problem": problem.to_dict(),
        "num_descriptions_per_prompt": num_descriptions_per_prompt,
        "validator_name": validator.model_name,
        "sub_problem_size": sub_problem_size,
        "d5_problem_name": d5_problem_name,
        "num_rounds": num_rounds,
        "max_samples_to_validate": max_samples_to_validate,
        "template_name": template_name,
        "proposer_model": proposer_model,
        "paired": paired,
    }

    experiment_dir = f"experiments/iterative_d5-{d5_problem_name}-{validator.model_name}-prop_n_samples={max_samples_to_validate}-prop_model={proposer_model}-n_desc_per_prompt={num_descriptions_per_prompt}-n_rounds={num_rounds}-template_name={template_name}-time={int(time.time())}"
    os.makedirs(experiment_dir, exist_ok=True)
    with open(f"{experiment_dir}/args.json", "w") as f:
        f.write(json.dumps(args, indent=4))

    result_save_path = f"{experiment_dir}/results.json"

    def _save_snapshot(iteration: int) -> IterativeD5Result:
        """Build a snapshot of the current state and write it to results.json."""
        snapshot = IterativeD5Result(
            iteration=iteration,
            d5_results=all_d5_results,
            full_hypotheses_info=list(full_computed_hypotheses_info.values()),
            texts_a_residuals=texts_a_residuals.tolist(),
            texts_b_residuals=texts_b_residuals.tolist(),
        )
        with open(result_save_path, "w") as f:
            json.dump(snapshot.to_dict(), f, indent=2)
        return snapshot

    result = None

    for round_id in range(num_rounds):
        loss = (
            np.mean(np.abs(texts_a_residuals)) + np.mean(np.abs(texts_b_residuals))
        ) / 2
        print(f"Round {round_id} loss: {loss}")

        # create a sub-problem
        # select the highest scoring text_a and lowest scoring text_b
        # and create a sub-problem with them
        texts_a_selected_indices = np.argsort(texts_a_residuals)[::-1][
            :sub_problem_size
        ]
        if paired:
            # keep the pairs intact: take the partners of the selected a-texts
            texts_b_selected_indices = texts_a_selected_indices
        else:
            texts_b_selected_indices = np.argsort(texts_b_residuals)[:sub_problem_size]

        sub_problem_texts_a = [texts_a[i] for i in texts_a_selected_indices]
        sub_problem_texts_b = [texts_b[i] for i in texts_b_selected_indices]

        sub_problem = Problem(
            texts_a=sub_problem_texts_a,
            texts_b=sub_problem_texts_b,
            goal=problem.goal,
            example_descriptions=problem.example_descriptions,
        )

        # run D5 on the sub-problem
        # Every sub-problem text must receive a score (text2score is read for
        # all of them below), so the validation cap is raised to cover the
        # whole sub-problem instead of relying on d5's default.
        d5_result = d5(
            problem=sub_problem,
            num_descriptions_per_prompt=num_descriptions_per_prompt,
            validator=validator,
            template_name=template_name,
            d5_problem_name=f"{d5_problem_name}_round_{round_id}",
            proposer_num_rounds=1,
            early_stopping_significance_threshold=0.0,
            proposer_model=proposer_model,
            max_samples_to_validate=max(
                len(sub_problem_texts_a), len(sub_problem_texts_b)
            ),
        )
        all_d5_results.append(d5_result)

        # Descriptions that were already selected in an earlier round are
        # dropped here: keeping them would duplicate a column of the design
        # matrix and re-validate them on every text.
        new_hypotheses_info_dict = {
            hypothesis_info.hypothesis: hypothesis_info
            for hypothesis_info in d5_result.hypotheses_info
            if hypothesis_info.hypothesis not in full_computed_hypotheses_info
        }
        if not new_hypotheses_info_dict:
            continue

        num_old_descriptions = len(full_computed_hypotheses_info)
        old_and_new_descriptions = list(full_computed_hypotheses_info.keys()) + list(
            new_hypotheses_info_dict.keys()
        )
        new_and_old_descriptions_design_matrix = []
        sub_problem_texts = sub_problem_texts_a + sub_problem_texts_b
        for description in old_and_new_descriptions:
            if description in full_computed_hypotheses_info:
                hypothesis_info = full_computed_hypotheses_info[description]
            else:
                hypothesis_info = new_hypotheses_info_dict[description]
            scores = np.array(
                [hypothesis_info.text2score[text] for text in sub_problem_texts]
            )
            # orient each feature so that a higher value points towards corpus A
            sign = (
                1
                if np.mean(hypothesis_info.scores_a) > np.mean(hypothesis_info.scores_b)
                else -1
            )
            new_and_old_descriptions_design_matrix.append(scores * sign)

        # rows are texts, columns are descriptions
        new_and_old_descriptions_design_matrix = np.array(
            new_and_old_descriptions_design_matrix
        ).T

        gold_label = [1.0] * len(sub_problem_texts_a) + [0.0] * len(sub_problem_texts_b)
        gold_label = np.array(gold_label)
        if paired:
            new_and_old_descriptions_design_matrix = fold_design(
                new_and_old_descriptions_design_matrix
            )

        (
            model,
            selected_indices,
        ) = train_logistic_regression_with_l1_and_non_negative_coef(
            X=new_and_old_descriptions_design_matrix,
            Y=gold_label,
            K=num_old_descriptions + 4,
        )

        # indices past the old descriptions refer to this round's proposals
        new_descriptions = [
            old_and_new_descriptions[i]
            for i in selected_indices
            if i >= num_old_descriptions
        ]
        if len(new_descriptions) == 0:
            continue
        print(f"New descriptions: {new_descriptions}")
        all_texts = problem.texts_a + problem.texts_b
        validator_scores = validate_descriptions(
            validator=validator,
            descriptions=new_descriptions,
            texts=all_texts,
            progress_bar=True,
        )

        for i, description in enumerate(new_descriptions):
            validator_score = validator_scores[:, i]
            text2score = {
                text: float(score) for text, score in zip(all_texts, validator_score)
            }
            scores_a = [text2score[text] for text in problem.texts_a]
            scores_b = [text2score[text] for text in problem.texts_b]
            V_prime, p_value = calculate_diff_w_significance(
                scores_a=scores_a, scores_b=scores_b
            )
            hypothesis_info = HypothesisInfo(
                hypothesis=description,
                scores_a=scores_a,
                scores_b=scores_b,
                V_prime=V_prime,
                text2score=text2score,
                statistical_significance=p_value,
            )
            full_computed_hypotheses_info[description] = hypothesis_info

        # refit on every selected description over all texts to get residuals
        all_existing_descriptions_design_matrix = []
        for hypothesis_info in full_computed_hypotheses_info.values():
            scores = np.array([hypothesis_info.text2score[text] for text in all_texts])
            all_existing_descriptions_design_matrix.append(scores)

        target = np.array([1.0] * len(problem.texts_a) + [0.0] * len(problem.texts_b))
        all_existing_descriptions_design_matrix = np.array(
            all_existing_descriptions_design_matrix
        ).T
        if paired:
            all_existing_descriptions_design_matrix = fold_design(
                all_existing_descriptions_design_matrix
            )

        model = LogisticRegression(fit_intercept=False)
        model.fit(all_existing_descriptions_design_matrix, target)
        y_prob = model.predict_proba(all_existing_descriptions_design_matrix)[:, 1]
        residual = target - y_prob

        texts_a_residuals = residual[: len(problem.texts_a)]
        if not paired:
            texts_b_residuals = residual[len(problem.texts_a) :]
        else:
            texts_b_residuals = -residual[len(problem.texts_a) :]

        result = _save_snapshot(round_id)

    if result is None:
        # No round produced a snapshot (num_rounds == 0 or nothing was ever
        # selected): still save and return the current state.
        result = _save_snapshot(num_rounds - 1)

    return result

