import json
import os
from dataclasses import dataclass

from tqdm import tqdm

from adapter import GPT3Adapter, HuggingfaceAdapter, DelphiAdapter
from data import CausalDataset, MoralDataset, Example, AbstractDataset, JsonSerializable
from evaluator import AccuracyEvaluatorWithAmbiguity, CorrelationEvaluator, RMSEEvaluator, AuROCEvaluator
from prompt import CausalJudgmentPrompt, MoralJudgmentPrompt, JudgmentPrompt


@dataclass
class ExperimentResult(JsonSerializable):
    """Metrics produced by one model/engine on one experiment-1 dataset.

    The exp1_* functions in this file construct it positionally, in field
    order, from their evaluators' outputs. produce_table1 serializes it through
    its ``json`` attribute (presumably provided by JsonSerializable).
    """
    acc: float  # accuracy from AccuracyEvaluatorWithAmbiguity.evaluate
    conf_interval: tuple[float, float]  # interval returned with acc; printed as (conf_interval[0], conf_interval[1])
    r: float  # correlation coefficient from CorrelationEvaluator.evaluate
    p: float  # second value from CorrelationEvaluator.evaluate, printed as "p="
    rmse: float  # from RMSEEvaluator.evaluate
    auroc: float  # from AuROCEvaluator.evaluate

def run_template_for_huggingface(cd: AbstractDataset, adapter: HuggingfaceAdapter,
                                 jp: JudgmentPrompt, method: str='yesno', batch_size: int=4):
    """Score every example of ``cd`` with a Hugging Face adapter, in mini-batches.

    Each example is rendered with ``jp.apply`` and buffered. Once the buffer
    holds exactly ``batch_size`` prompts, it is sent to
    ``adapter.adapt(buffer, method=method)`` and the returned scores are
    appended. Prompts left over after the loop are sent as a final, smaller
    batch.

    Returns:
        (choice_scores, answer_dists): the adapter's scores in dataset order,
        and each example's ``answer_dist`` in the same order.
    """
    pending = []
    scores = []
    answer_dists = []

    def flush():
        # Reads the current ``pending`` binding; callers rebind it afterwards.
        scores.extend(adapter.adapt(pending, method=method))

    example: Example
    for example in tqdm(cd):
        pending.append(jp.apply(example))
        answer_dists.append(example.answer_dist)

        if len(pending) != batch_size:
            continue
        flush()
        # Start a fresh list rather than clearing, so the adapter never sees
        # a list it was given being mutated afterwards.
        pending = []

    if pending:
        flush()

    return scores, answer_dists

def exp1_causal_huggingface(model_name: str="bert-base-uncased", batch_size: int=4):
    """Run experiment 1 (causal judgments) with a Hugging Face model.

    The causal dataset is scored with two prompt templates: a yes/no prompt and
    a multiple-choice prompt. Scores from both runs are pooled before computing
    accuracy (with confidence interval), correlation, RMSE and AuROC.

    Args:
        model_name: model identifier passed to ``HuggingfaceAdapter``.
        batch_size: number of prompts sent to the adapter per call.

    Returns:
        ExperimentResult holding the pooled metrics, which are also printed.
    """
    # Every model, large or small, is loaded on CPU. A per-model device choice
    # would belong here; there is currently nothing to branch on.
    adapter = HuggingfaceAdapter(model_name=model_name, device='cpu')

    cd = CausalDataset()

    evaluator = AccuracyEvaluatorWithAmbiguity()
    corr_evaluator = CorrelationEvaluator()
    rmse_evaluator = RMSEEvaluator()
    auroc_evaluator = AuROCEvaluator()

    # (template path, scoring method); results of all templates are pooled.
    prompt_configs = (
        ("./prompts/exp1_causal_prompt.jinja", 'yesno'),
        ("./prompts/exp1_causal_prompt_2.jinja", 'multiple_choice'),
    )

    all_choice_scores, all_label_indices = [], []
    for template_path, method in prompt_configs:
        choice_scores, label_indices = run_template_for_huggingface(
            cd, adapter, CausalJudgmentPrompt(template_path),
            method=method, batch_size=batch_size)
        all_choice_scores.extend(choice_scores)
        all_label_indices.extend(label_indices)

    acc, conf_interval = evaluator.evaluate(all_choice_scores, all_label_indices)
    r, p = corr_evaluator.evaluate(all_choice_scores, all_label_indices)
    rmse = rmse_evaluator.evaluate(all_choice_scores, all_label_indices)
    auroc = auroc_evaluator.evaluate(all_choice_scores, all_label_indices)

    # Recorded result: Causal 3-class Accuracy: 0.3403 (0.2629, 0.4177) (for multiple-choice)
    print()
    print("Model: ", model_name)
    print(f"Causal 3-class Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Causal Correlation: {r:.4f} (p={p:.4f})")
    print(f"Causal RMSE: {rmse:.4f}")
    print(f"Causal AuROC: {auroc:.4f}")

    return ExperimentResult(acc, conf_interval, r, p, rmse, auroc)

def exp1_moral_huggingface(model_name: str="bert-base-uncased", batch_size: int=4):
    """Run experiment 1 (moral judgments) with a Hugging Face model.

    The moral dataset is scored with a yes/no template and then a
    multiple-choice template. Scores from both are pooled and evaluated for
    accuracy (with confidence interval), correlation, RMSE and AuROC. The
    adapter is built with its default device settings.

    Args:
        model_name: model identifier passed to ``HuggingfaceAdapter``.
        batch_size: number of prompts sent to the adapter per call.

    Returns:
        ExperimentResult holding the pooled metrics, which are also printed.
    """
    adapter = HuggingfaceAdapter(model_name=model_name)
    dataset = MoralDataset()

    accuracy_eval = AccuracyEvaluatorWithAmbiguity()
    correlation_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    auroc_eval = AuROCEvaluator()

    templates = [
        ("./prompts/exp1_moral_prompt.jinja", 'yesno'),
        ("./prompts/exp1_moral_prompt_2.jinja", 'multiple_choice'),
    ]

    pooled_scores, pooled_dists = [], []
    for path, method in templates:
        scores, dists = run_template_for_huggingface(dataset, adapter, MoralJudgmentPrompt(path),
                                                     method=method, batch_size=batch_size)
        pooled_scores.extend(scores)
        pooled_dists.extend(dists)

    acc, conf_interval = accuracy_eval.evaluate(pooled_scores, pooled_dists)
    r, p = correlation_eval.evaluate(pooled_scores, pooled_dists)
    rmse = rmse_eval.evaluate(pooled_scores, pooled_dists)
    auroc = auroc_eval.evaluate(pooled_scores, pooled_dists)

    # Recorded result: Moral 3-class Accuracy: 0.2742 (0.1631, 0.3852)
    print()
    print("Model: ", model_name)
    print(f"Moral 3-class Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Moral Correlation: {r:.4f} (p={p:.4f})")
    print(f"Moral RMSE: {rmse:.4f}")
    print(f"Moral AuROC: {auroc:.4f}")

    return ExperimentResult(acc, conf_interval, r, p, rmse, auroc)

def exp1_moral_delphi():
    """Evaluate the Delphi adapter on the moral dataset (accuracy only).

    Delphi is queried with raw text (``story + " " + question``) rather than a
    JudgmentPrompt template, because that is how Delphi is used. Each query is
    scored with ``method='yesno'``.

    Only accuracy and its confidence interval are computed. The r, p, rmse and
    auroc fields of the returned ExperimentResult are 0 placeholders, not
    measured values.

    NOTE: the original author remarked that Delphi should not really have a
    confidence interval, because the probabilities assigned to its outputs are
    fake. The interval is still computed and reported; interpret it with care.
    """
    adapter = DelphiAdapter()
    dataset = MoralDataset()

    accuracy_eval = AccuracyEvaluatorWithAmbiguity()

    scores, dists = [], []
    example: Example
    for example in tqdm(dataset):
        query = example.story + " " + example.question
        scores.append(adapter.adapt(query, method='yesno'))
        dists.append(example.answer_dist)

    acc, conf_interval = accuracy_eval.evaluate(scores, dists)

    # A recorded note here read "Moral 3-class Accuracy: 0.2742 (0.1631, 0.3852)",
    # identical to the Hugging Face note. TODO confirm it is a real Delphi result.
    print()
    print("Delphi")
    print(f"Moral 3-class Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")

    return ExperimentResult(acc, conf_interval, 0, 0, 0, 0)

def run_template_for_gpt3(cd: AbstractDataset, adapter: GPT3Adapter,
                            jp: JudgmentPrompt, method: str='yesno'):
    """Score every example of ``cd`` with a GPT-3 adapter, one prompt per call.

    Args:
        cd: dataset to iterate (wrapped in a tqdm progress bar).
        adapter: called as ``adapter.adapt(instance, method=method)`` once per
            example.
        jp: prompt template applied to each example via ``jp.apply``.
        method: scoring method forwarded to the adapter.

    Returns:
        (choice_scores, answer_dists): two parallel lists in dataset order.
    """
    # Score first, then read answer_dist, matching the per-example call order.
    pairs = [
        (adapter.adapt(jp.apply(example), method=method), example.answer_dist)
        for example in tqdm(cd)
    ]
    return [score for score, _ in pairs], [dist for _, dist in pairs]

def exp1_causal(engine: str='text-davinci-002'):
    """Run experiment 1 (causal judgments) against a GPT-3 engine.

    The causal dataset is scored with a yes/no template and then a
    multiple-choice template. Scores from both are pooled and evaluated for
    accuracy (with confidence interval), correlation, RMSE and AuROC.

    Args:
        engine: engine name passed to ``GPT3Adapter``.

    Returns:
        ExperimentResult holding the pooled metrics, which are also printed.
    """
    dataset = CausalDataset()

    accuracy_eval = AccuracyEvaluatorWithAmbiguity()
    correlation_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    auroc_eval = AuROCEvaluator()

    adapter = GPT3Adapter(engine=engine)

    # A third run, ("./prompts/exp1_causal_prompt_3.jinja", 'multiple_choice'),
    # was previously disabled; add it here to re-enable it.
    templates = [
        ("./prompts/exp1_causal_prompt.jinja", 'yesno'),
        ("./prompts/exp1_causal_prompt_2.jinja", 'multiple_choice'),
    ]

    pooled_scores, pooled_dists = [], []
    for path, method in templates:
        scores, dists = run_template_for_gpt3(dataset, adapter, CausalJudgmentPrompt(path),
                                              method=method)
        pooled_scores.extend(scores)
        pooled_dists.extend(dists)

    acc, conf_interval = accuracy_eval.evaluate(pooled_scores, pooled_dists)
    r, p = correlation_eval.evaluate(pooled_scores, pooled_dists)
    rmse = rmse_eval.evaluate(pooled_scores, pooled_dists)
    auroc = auroc_eval.evaluate(pooled_scores, pooled_dists)

    # Recorded results:
    #   Causal Accuracy: 0.6250 (0.5459, 0.7041)  Temp = 0.7, 0.9
    #   Causal Accuracy: 0.6181 (0.5387, 0.6974)  Temp = 0.5
    print()
    print(f"engine: {engine}")
    print(f"Causal Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Causal Correlation: {r:.4f} (p={p:.4f})")
    print(f"Causal RMSE: {rmse:.4f}")
    print(f"Causal AuROC: {auroc:.4f}")

    return ExperimentResult(acc, conf_interval, r, p, rmse, auroc)


def exp1_moral(engine: str='text-davinci-002'):
    """Run experiment 1 (moral judgments) against a GPT-3 engine.

    The moral dataset is scored with a yes/no template and then a
    multiple-choice template. Scores from both are pooled and evaluated for
    accuracy (with confidence interval), correlation, RMSE and AuROC.

    Args:
        engine: engine name passed to ``GPT3Adapter``.

    Returns:
        ExperimentResult holding the pooled metrics, which are also printed.
    """
    dataset = MoralDataset()

    accuracy_eval = AccuracyEvaluatorWithAmbiguity()
    correlation_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    auroc_eval = AuROCEvaluator()

    adapter = GPT3Adapter(engine=engine)

    templates = [
        ("./prompts/exp1_moral_prompt.jinja", 'yesno'),
        ("./prompts/exp1_moral_prompt_2.jinja", 'multiple_choice'),
    ]

    pooled_scores, pooled_dists = [], []
    for path, method in templates:
        scores, dists = run_template_for_gpt3(dataset, adapter, MoralJudgmentPrompt(path),
                                              method=method)
        pooled_scores.extend(scores)
        pooled_dists.extend(dists)

    acc, conf_interval = accuracy_eval.evaluate(pooled_scores, pooled_dists)
    r, p = correlation_eval.evaluate(pooled_scores, pooled_dists)
    rmse = rmse_eval.evaluate(pooled_scores, pooled_dists)
    auroc = auroc_eval.evaluate(pooled_scores, pooled_dists)

    # Recorded results:
    #   davinci-2: Moral Accuracy: 0.5000 (0.3755, 0.6245)
    #   davinci-1: Moral Accuracy: 0.4032 (0.2811, 0.5253)
    #   davinci-2: Moral Accuracy: 0.3387 (0.2209, 0.4565)  prompt2
    print()
    print(f"engine={engine}")
    print(f"Moral Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    print(f"Moral Correlation: {r:.4f} (p={p:.4f})")
    print(f"Moral RMSE: {rmse:.4f}")
    print(f"Moral AuROC: {auroc:.4f}")

    return ExperimentResult(acc, conf_interval, r, p, rmse, auroc)

def produce_table1():
    """Run every experiment-1 configuration and write the results as JSON.

    Causal results (Hugging Face models, then GPT-3 engines) are written to
    ``../../results/exp1_causal_full_result.json`` as soon as the causal runs
    finish. Moral results (Hugging Face models, Delphi, then GPT-3 engines) go
    to ``../../results/exp1_moral_full_result.json``. Each file maps model /
    engine name (or ``'delphi'``) to that run's ``ExperimentResult.json``.
    Paths are relative to the current working directory.
    """
    def write_results(result, path):
        # Create the results directory up front so that a missing folder does not
        # discard hours of model runs with a FileNotFoundError at the very end.
        # The ``with`` block guarantees the file is flushed and closed.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2)

    # causal
    result = {}
    for model_name in ['google/electra-large-generator', 'bert-base-uncased', 'bert-large-uncased',
                       'roberta-large', 'albert-xxlarge-v2', 'gpt2-xl']:
        er = exp1_causal_huggingface(model_name=model_name, batch_size=32)
        result[model_name] = er.json

    for engine in ["text-babbage-001", 'text-curie-001', 'text-davinci-002']:
        er = exp1_causal(engine=engine)
        result[engine] = er.json

    write_results(result, '../../results/exp1_causal_full_result.json')

    # moral
    # NOTE: unlike the causal runs (batch_size=32), these use
    # exp1_moral_huggingface's default batch_size.
    result = {}
    for model_name in ['roberta-large', 'google/electra-large-generator', 'bert-base-uncased',
                       'bert-large-uncased', 'albert-xxlarge-v2', 'gpt2-xl']:
        er = exp1_moral_huggingface(model_name=model_name)
        result[model_name] = er.json

    er = exp1_moral_delphi()
    result['delphi'] = er.json

    for engine in ["text-babbage-001", 'text-curie-001', 'text-davinci-002']:
        er = exp1_moral(engine=engine)
        result[engine] = er.json

    write_results(result, '../../results/exp1_moral_full_result.json')

if __name__ == '__main__':
    # Default: regenerate the whole of Table 1.
    # Single experiments can be run instead, e.g.:
    #   exp1_causal()
    #   exp1_moral()
    #   exp2_causal()  # recorded: 54.2%
    #   exp2_moral()   # recorded: 0.7419 (0.6330, 0.8509)
    produce_table1()