import pytest
import torch as t

from hypo_interp.tasks import InductionTask

# Device string passed as `device=` when the shared InductionTask fixture is built.
DEVICE = "cpu"


@pytest.fixture(scope="module")
def induction_task():
    """
    Module-scoped fixture: the induction task is constructed once and the
    same instance is handed to every test in this module that requests it,
    so the construction cost is not paid per test.
    """
    yield InductionTask(
        device=DEVICE,
        seq_len=50,
        num_examples=40,
        zero_ablation=True,
    )


def test_eval_metric(induction_task):
    """
    Comparing the task's own per-prompt score against itself must yield a
    metric whose mean absolute value is (numerically) zero.
    """
    # score() returns a pair; only the first element is needed here.
    baseline, _ = induction_task.score(per_prompt=True)
    self_metric = induction_task.eval_metric(baseline, baseline, use_mean=True)

    mean_abs_metric = abs(self_metric).mean()
    assert mean_abs_metric < 1e-15


def test_l1_distance(induction_task):
    """
    Every element of the candidate is exactly 3 larger than the original;
    with distance="l1" and use_mean=False the test expects the scalar 3.
    """
    reference = t.tensor([1, 2, 3])
    shifted = t.tensor([4, 5, 6])

    result = induction_task.eval_metric(
        shifted, reference, distance="l1", use_mean=False
    )

    assert t.equal(result, t.tensor(3))


def test_l2_distance(induction_task):
    """
    Every element of the candidate is exactly 3 larger than the original;
    with distance="l2" and use_mean=False the test expects the scalar 9
    (i.e. 3 squared).
    """
    reference = t.tensor([1, 2, 3])
    shifted = t.tensor([4, 5, 6])

    result = induction_task.eval_metric(
        shifted, reference, distance="l2", use_mean=False
    )

    assert t.equal(result, t.tensor(9))


def test_eval_metric_use_mean_true(induction_task):
    """
    The two inputs have means 2 and 3. With use_mean=True and no explicit
    distance argument, the test expects a result of exactly 1.
    """
    first = t.tensor([1, 2, 3])
    second = t.tensor([4, 3, 2])
    wanted = t.tensor(1)  # L2 distance squared

    got = induction_task.eval_metric(first, second, use_mean=True)

    is_match = t.equal(got, wanted)
    assert is_match, "use_mean:True should return difference of means"


def test_eval_metric_per_use_mean_false(induction_task):
    """
    With use_mean=False the test expects the metric to equal 11/3, the mean
    of the element-wise squared differences [9, 1, 1] between the inputs.

    The comparison uses a small absolute tolerance rather than exact float
    equality: 11/3 is not exactly representable, so the computed value may
    differ from the reference in the last bits depending on reduction order
    or dtype.
    """
    original = t.tensor([1, 2, 3])
    candidate = t.tensor([4, 3, 2])
    expected_mean_diff = t.tensor(11 / 3)  # Mean of [9, 1, 1]
    tolerance = 1e-5  # a few float32 ulps around 3.67 is ~1e-6

    actual_mean_diff = induction_task.eval_metric(original, candidate, use_mean=False)

    deviation = abs(actual_mean_diff.item() - expected_mean_diff.item())
    assert (
        deviation < tolerance
    ), "use_mean=False should return mean of differences"
