"""
Basic file containing some tests for utilities with circuits
"""
import random
import time

from hypo_interp.tasks import InductionTask, TracrProportionTask
from hypo_interp.types_ import Circuit
from hypo_interp.utils.circuit_utils import _CircuitGraph, sample_circuit_from_circuit

DEVICE = "cpu"


def test_circuit_graph_object_builds_fast():
    """
    Tests that the circuit graph object builds and checks
    that it doesn't take too long (we are using it in a bunch
    of places so it should be fast-ish).

    This is a bit of a sinful test since it uses a private class,
    but I think it is worth it.
    """
    max_build_seconds = 0.05

    task = InductionTask(zero_ablation=True, device=DEVICE)
    circuit: Circuit = task.complete_circuit

    # perf_counter is monotonic and high-resolution; time.time() is a wall
    # clock that may be adjusted or be too coarse for a 50 ms budget.
    start_time = time.perf_counter()
    # Only the construction cost matters here, so the graph is not kept.
    _CircuitGraph(circuit, use_pos_embed=task.use_pos_embed)
    elapsed_time = time.perf_counter() - start_time

    assert elapsed_time < max_build_seconds, (
        f"_CircuitGraph took {elapsed_time:.4f}s to build "
        f"(budget: {max_build_seconds}s)"
    )


def test_sample_circuit_sanity_check():
    """
    Sanity check for sample_circuit_from_circuit on the induction task.

    Samples a circuit with at least half as many edges as the complete
    circuit and compares its score against two references:

    * the task's score before any circuit is set (the default), and
    * a uniformly random subset of edges of the same minimum size.

    The test asserts the ordering
    ``score_random_bad_circuit > score_sampled_circuit > score_default``.
    """
    # set random seed
    random_seed = 42
    percentage_of_full_circuit = 0.5
    random.seed(random_seed)

    task = InductionTask(zero_ablation=True, device=DEVICE)
    score_default = task.score()
    circuit: Circuit = task.complete_circuit
    minimum_number_of_edges = int(percentage_of_full_circuit * len(circuit))

    sampled_circuit = sample_circuit_from_circuit(
        circuit,
        minimum_number_of_edges,
        use_pos_embed=task.use_pos_embed,
        seed=random_seed,
    )
    task.set_circuit(sampled_circuit)
    score_sampled_circuit = task.score()

    # get a random circuit that is bad of size minimum_number_of_edges.
    # Shuffle a copy: random.shuffle works in place, and shuffling the object
    # returned by task.complete_circuit directly could reorder the task's own
    # circuit (and the `circuit` alias above) if that object is shared.
    random_bad_circuit = list(task.complete_circuit)
    random.shuffle(random_bad_circuit)
    random_bad_circuit = random_bad_circuit[:minimum_number_of_edges]

    # score the task again with the random circuit
    task.set_circuit(random_bad_circuit)
    score_random_bad_circuit = task.score()

    assert len(sampled_circuit) >= minimum_number_of_edges
    assert score_random_bad_circuit > score_sampled_circuit
    assert score_sampled_circuit > score_default


def test_sample_circuit_tracr_sanity_check():
    """
    Sanity check for sample_circuit_from_circuit on the tracr proportion task.

    Samples a circuit with at least half as many edges as the complete
    circuit and compares its score against the task's default score and
    against a uniformly random subset of edges of the same minimum size.

    Unlike the induction variant of this test, the random circuit is only
    required to score *differently* from the sampled one; the sampled
    circuit must still score above the default.
    """
    # set random seed
    random_seed = 42
    percentage_of_full_circuit = 0.5
    random.seed(random_seed)

    task = TracrProportionTask(device=DEVICE)
    score_default = task.score()
    circuit: Circuit = task.complete_circuit
    minimum_number_of_edges = int(percentage_of_full_circuit * len(circuit))

    sampled_circuit = sample_circuit_from_circuit(
        circuit,
        minimum_number_of_edges,
        use_pos_embed=task.use_pos_embed,
        seed=random_seed,
    )
    task.set_circuit(sampled_circuit)
    score_sampled_circuit = task.score()

    # get a random circuit that is bad of size minimum_number_of_edges.
    # Shuffle a copy: random.shuffle works in place, and shuffling the object
    # returned by task.complete_circuit directly could reorder the task's own
    # circuit (and the `circuit` alias above) if that object is shared.
    random_bad_circuit = list(task.complete_circuit)
    random.shuffle(random_bad_circuit)
    random_bad_circuit = random_bad_circuit[:minimum_number_of_edges]

    # score the task again with the random circuit
    task.set_circuit(random_bad_circuit)
    score_random_bad_circuit = task.score()

    assert len(sampled_circuit) >= minimum_number_of_edges
    assert score_random_bad_circuit != score_sampled_circuit
    assert score_sampled_circuit > score_default
