"""
Show average R2-sortability for different weights and avgerage in-degrees for ER and SF graphs.
"""

from auxiliary.vsb_emergence import vsb_investigation, approx_b, plot_heatmap
from CDExperimentSuite_DEV import *
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams["text.usetex"] = True
plt.rcParams["text.latex.preamble"] = r"\usepackage{bm}\usepackage{amssymb}"
plt.rcParams["axes.labelsize"] = 25
plt.rcParams["xtick.labelsize"] = 18
plt.rcParams["ytick.labelsize"] = 18
plt.rcParams["legend.fontsize"] = 18
plt.rcParams["lines.linewidth"] = 3
plt.rcParams["lines.markersize"] = 12


def _run_graph_heatmap(opt, graph, **overrides):
    """
    Run one R2-sortability heatmap experiment restricted to a single graph type.

    Args:
        opt: base option dict. It is deep-copied, so the caller's dict is
            never mutated.
        graph: graph type name placed in ``"graphs"`` (e.g. ``"ER"`` or
            ``"SF"``); also used to build ``exp_name`` as ``graph_<graph>``.
        **overrides: extra option entries applied after the shared settings
            (e.g. a graph-specific ``edges`` grid).
    """
    run_opt = deepcopy(opt)
    # Both graph types share the same sub-directory suffix; presumably
    # base_dir is where results/plots are written (TODO confirm in utils).
    run_opt["base_dir"] += "graph"
    run_opt["exp_name"] = f"graph_{graph}"
    run_opt["graphs"] = [graph]
    # NOTE(review): presumably Gaussian noise whose scale is drawn uniformly
    # from (0.5, 2); confirm against utils.NoiseDistribution.
    run_opt["noise_distributions"] = [
        utils.NoiseDistribution("gauss", "uniform", (0.5, 2))
    ]
    run_opt.update(overrides)
    vsb_investigation(
        utils.Options(**run_opt),
        vsb_functions={"R2": utils.r2_sortability},
        plot_fun=plot_heatmap,
    )


def node_limit_graph(opt):
    """
    Visualize R2-sortability as heatmaps for ER and SF graphs.

    Runs two experiments in sequence, first ER and then SF. Each one uses a
    deep copy of ``opt``, so the caller's dict is left unchanged.

    Args:
        opt: base option dict that is passed (after per-graph adjustments) to
            ``utils.Options``.
    """
    # ER graphs: use the in-degree grid from opt["edges"] as given.
    _run_graph_heatmap(opt, "ER")

    # SF graphs: override the in-degree grid with the integers 1..8.
    # NOTE(review): presumably because the SF generator needs integer
    # attachment counts; confirm.
    _run_graph_heatmap(
        opt,
        "SF",
        edge_types=["fixed"],
        edges=list(np.arange(1, 9, 1)),
    )


if __name__ == "__main__":
    w_ranges = []
    for i in np.arange(-1.0, 3.0, 0.5):
        a = 0.2
        res = approx_b(a=a, k=i)
        w_ranges.append((a, res[0]))
    opt = {
        "overwrite": False,
        "base_dir": f"src/results/Heatmap/",
        "scaler": Scalers.Identity(),
        # ---
        "MEC": False,
        "thres": 0,
        "thres_type": "standard",
        "vsb_function": utils.var_sortability,
        "R2sb_function": utils.r2_sortability,
        "CEVsb_function": utils.cev_sortability,
        # ---
        "n_repetitions": 2,
        "edge_types": ["fixed"],
        "edges": [0.4] + list(np.arange(1.5, 8, 1)),
        "edge_weights": w_ranges,
        "n_nodes": [50],
        "n_obs": [1000],
    }

    # run experiments
    node_limit_graph(opt)
