from typing import Literal

import pytest
import torch as t

from hypo_interp.tasks import (
    GreaterThanTask,
    InductionTask,
    IoITask,
    TracrProportionTask,
    TracrReverseTask,
)
from hypo_interp.tasks.docstring.task import DocstringTask
from hypo_interp.tasks.mech_interp_task import MechInterpTask

# Torch device string passed as ``device=`` to the tasks built by the fixture
# and tests in this module.
DEVICE = "cpu"


@pytest.fixture(scope="module")
def induction_task():
    """
    Provide one InductionTask shared by every test in this module.

    Because of ``scope="module"`` the task is constructed once per module
    instead of once per test, which should speed the suite up. Tests in this
    module that call ``set_circuit`` construct their own task, so this shared
    instance is not modified by them.
    """
    settings = dict(seq_len=50, num_examples=40, zero_ablation=True)
    shared_task = InductionTask(device=DEVICE, **settings)
    yield shared_task


def test_sanity_full_circuit():
    """
    Setting the complete circuit must not change the task's score.

    The score with the complete circuit is compared against the score taken
    before any circuit was set, both directly and after ``task.reset()``.
    """
    task = InductionTask(device=DEVICE)
    baseline, _ = task.score()

    circuit = task.complete_circuit

    task.set_circuit(circuit)
    without_reset, _ = task.score()

    task.reset()
    task.set_circuit(circuit)
    with_reset, _ = task.score()

    # Every score must agree with the baseline to within a negligible tolerance.
    for candidate in (without_reset, with_reset):
        assert abs(baseline - candidate) < 1e-15


def test_tracr_proportion_canonical_circuit():
    """
    Restricting TracrProportionTask to its canonical circuit should change its score.

    The score taken before any circuit is set is compared with the score after
    ``set_circuit(task.canonical_circuit)``; the two must differ.
    """
    # Named distinctly from ``test_induction_canonical_circuit`` so that pytest
    # collects both tests (a duplicate name silently shadows the earlier one).
    # No ``induction_task`` fixture is requested: this test never uses it.
    task = TracrProportionTask(device=DEVICE)
    default_score, _ = task.score()
    task.set_circuit(task.canonical_circuit)
    score_canonical_circuit, _ = task.score()

    assert abs(default_score - score_canonical_circuit) > 1e-15


def test_eval_metric(induction_task):
    """Comparing the per-prompt scores with themselves should give a (near) zero metric."""
    per_prompt_scores, _ = induction_task.score(per_prompt=True)
    distance = induction_task.eval_metric(
        per_prompt_scores, per_prompt_scores, use_mean=True
    )

    assert abs(distance).mean() < 1e-15


def test_l1_distance(induction_task):
    """With distance="l1" and use_mean=False, [4, 5, 6] vs [1, 2, 3] is expected to give 3."""
    result = induction_task.eval_metric(
        t.tensor([4, 5, 6]),
        t.tensor([1, 2, 3]),
        distance="l1",
        use_mean=False,
    )
    assert t.equal(result, t.tensor(3))


def test_l2_distance(induction_task):
    """With distance="l2" and use_mean=False, [4, 5, 6] vs [1, 2, 3] is expected to give 9."""
    result = induction_task.eval_metric(
        t.tensor([4, 5, 6]),
        t.tensor([1, 2, 3]),
        distance="l2",
        use_mean=False,
    )
    assert t.equal(result, t.tensor(9))


def test_eval_metric_use_mean_true(induction_task):
    """With use_mean=True the metric should compare the means of the two inputs."""
    # The input means are 2 and 3; the expected value 1 is their squared difference.
    lhs = t.tensor([1, 2, 3])
    rhs = t.tensor([4, 3, 2])

    result = induction_task.eval_metric(lhs, rhs, use_mean=True)

    assert t.equal(result, t.tensor(1)), "use_mean:True should return difference of means"


def test_eval_metric_per_use_mean_false(induction_task):
    """
    With use_mean=False the metric should average the per-element distances.

    For [1, 2, 3] vs [4, 3, 2] the element-wise differences are [-3, -1, 1],
    so the squared differences are [9, 1, 1] and their mean is 11 / 3.
    """
    original = t.tensor([1, 2, 3])
    candidate = t.tensor([4, 3, 2])
    expected_mean_diff = 11 / 3  # mean of [9, 1, 1]

    actual_mean_diff = induction_task.eval_metric(original, candidate, use_mean=False)

    # Compare with a tolerance: the metric is a floating-point tensor (likely
    # float32), so exact equality with 11/3 depends on rounding details.
    assert (
        abs(actual_mean_diff.item() - expected_mean_diff) < 1e-6
    ), "use_mean=False should return mean of differences"


def test_induction_canonical_circuit():
    """
    Restricting InductionTask to its canonical circuit should change its score.

    The score taken before any circuit is set must differ from the score
    obtained after ``set_circuit(task.canonical_circuit)``.
    """
    task = InductionTask(device=DEVICE)
    before, _ = task.score()

    canonical = task.canonical_circuit
    task.set_circuit(canonical)
    after, _ = task.score()

    assert abs(before - after) > 1e-15


def _test_full_vs_canonical_vs_random(task):
    """
    Score ``task`` under four settings and return the scores keyed by name.

    The settings, in the order they are applied, are:

    * ``"default"``   - the task as passed in, before any circuit is set here;
    * ``"full"``      - with ``task.complete_circuit`` set;
    * ``"canonical"`` - with ``task.canonical_circuit`` set;
    * ``"random"``    - with a seeded random subset of the complete circuit
      containing as many elements as the canonical circuit.

    Every score is computed with ``per_prompt=False``. When this returns, the
    task is left with the random circuit set.
    """
    import random

    scores = {}
    scores["default"], _ = task.score(per_prompt=False)

    full_circuit = task.complete_circuit
    task.set_circuit(full_circuit)
    scores["full"], _ = task.score(per_prompt=False)

    canonical_circuit = task.canonical_circuit
    task.set_circuit(canonical_circuit)
    scores["canonical"], _ = task.score(per_prompt=False)

    # Shuffle a private copy: ``complete_circuit`` may hand back the task's own
    # list (possibly the one just passed to set_circuit), which an in-place
    # shuffle would reorder. A dedicated Random(0) reproduces the same
    # permutation as seeding the module RNG, without touching global state.
    rng = random.Random(0)
    shuffled = list(full_circuit)
    rng.shuffle(shuffled)
    task.set_circuit(shuffled[: len(canonical_circuit)])
    scores["random"], _ = task.score(per_prompt=False)

    return scores


def get_task(
    task_name: str, ablation: Literal["canonical", "zero", "corrupted"]
) -> MechInterpTask:
    """
    Build a small instance of the named task on ``DEVICE`` for testing.

    Args:
        task_name: Class name of the task to build, e.g. ``"InductionTask"``.
        ablation: ``"zero"`` passes ``zero_ablation=True``; ``"corrupted"``
            passes ``zero_ablation=False``; ``"canonical"`` uses the per-task
            default from the table below.

    Returns:
        A freshly constructed task with a small, task-specific ``num_examples``.

    Raises:
        ValueError: if ``task_name`` or ``ablation`` is not recognised.
    """
    # task name -> (task class, zero_ablation used for "canonical", num_examples)
    # TODO: maybe the canonical ablation should be a property of the task?
    specs = {
        "InductionTask": (InductionTask, True, 10),
        "TracrProportionTask": (TracrProportionTask, False, 6),
        "TracrReverseTask": (TracrReverseTask, False, 6),
        "GreaterThanTask": (GreaterThanTask, False, 3),
        "IoITask": (IoITask, False, 20),
        "DocstringTask": (DocstringTask, False, 3),
    }
    # Validate up front so an unknown name raises ValueError for every
    # ablation mode (previously "canonical" surfaced a bare KeyError).
    if task_name not in specs:
        raise ValueError(f"Unknown task {task_name}")
    task_cls, canonical_zero_ablation, num_examples = specs[task_name]

    if ablation == "zero":
        zero_ablation = True
    elif ablation == "corrupted":
        zero_ablation = False
    elif ablation == "canonical":
        zero_ablation = canonical_zero_ablation
    else:
        # Reject typos instead of silently treating them as "canonical".
        raise ValueError(f"Unknown ablation {ablation}")

    return task_cls(
        zero_ablation=zero_ablation, device=DEVICE, num_examples=num_examples
    )


@pytest.mark.slow
@pytest.mark.parametrize(
    "task_name",
    (
        "InductionTask",
        "TracrProportionTask",
        "TracrReverseTask",
        "GreaterThanTask",
        "IoITask",
        "DocstringTask",
    ),
)
def test_full_vs_canonical_vs_random(task_name):
    """
    Compare full, canonical and random circuits for each task.

    The full circuit must reproduce the default score exactly, and a random
    circuit of the canonical circuit's size must score higher than both the
    full and the canonical circuit.
    """
    scores = _test_full_vs_canonical_vs_random(
        get_task(task_name, ablation="canonical")
    )
    print(scores)

    default, full = scores["default"], scores["full"]
    canonical, rand = scores["canonical"], scores["random"]

    assert default == full
    assert rand > full
    # The random subset is seeded, so a run is reproducible, but the subset is
    # arbitrary: for some task/seed combinations this comparison might fail.
    assert rand > canonical
