import pytest
from causal_graphs.graph_generation import generate_categorical_graph, get_graph_func

from causal_discovery import enco, intervention_strategies, new_approach

# Size parameters of the toy graph used by the tests in this module.
NUM_VARS = 5
NUM_CATEGS = 10

# Shared graph, built once at import time and reused by every test below.
# The seed is fixed (presumably so the generated graph is reproducible
# across runs -- confirm against generate_categorical_graph).
graph = generate_categorical_graph(
    # Generation / structure options.
    seed=1,
    graph_func=get_graph_func("random"),
    num_vars=NUM_VARS,
    edge_prob=0.2,
    connected=True,
    # min and max are equal, so presumably every variable ends up with
    # exactly NUM_CATEGS categories.
    min_categs=NUM_CATEGS,
    max_categs=NUM_CATEGS,
    # Options for the variables' distributions.
    use_nn=True,
    embed_dim=4,
)


@pytest.mark.parametrize("strategy", intervention_strategies.strategies.keys())
def test_standard_strategy_initialization(strategy):
    """Build ENCO with each known strategy and sample one variable index.

    The test is parametrized over every key of
    ``intervention_strategies.strategies``. For most strategies, constructing
    ``enco.ENCO`` and calling ``sample_next_var_idx`` must succeed; a few are
    expected to raise in this plain setup (see below).
    """

    def build_and_sample():
        # Construct the method with the strategy under test and ask its graph
        # fitting module for the next variable index.
        method = enco.ENCO(graph=graph, interventions_policy=strategy)
        fitting_module = method.graph_fitting_module
        fitting_module.sample_next_var_idx(method.gamma, method.theta)

    # Some strategies require specific initialization and are expected to
    # fail here with a particular exception type.
    if strategy == "nonempty_round_robin":
        expected_error = ValueError
    elif "trained" in strategy:
        expected_error = AssertionError
    else:
        expected_error = None

    # Strategies with no expected failure must simply run through.
    if expected_error is None:
        build_and_sample()
        return

    with pytest.raises(expected_error):
        build_and_sample()


def test_nonempty_round_robin_initialization():
    """Check that "nonempty_round_robin" works once a batch has been added.

    ``NewApproach`` is created with the same policy for data collection and
    for interventions, a batch for variable 0 is added to its interventional
    dataset, and then a next variable index is sampled. The test passes if
    none of these steps raises.
    """
    policy = "nonempty_round_robin"
    method = new_approach.NewApproach(
        graph=graph,
        int_data_collection_policy=policy,
        interventions_policy=policy,
    )

    # Add a batch for variable 0 before sampling; presumably this is the
    # initialization the policy needs (without it, the ENCO-based test in
    # this module expects a ValueError for the same policy).
    method.int_dataset.add_batch(var_idx=0)

    fitting_module = method.graph_fitting_module
    fitting_module.sample_next_var_idx(method.gamma, method.theta)
