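"""
Compare alternative sortability definitions that count paths differently.

For ER and SF graphs, compute var- and R^2-sortability under the original,
path-existence, and all-paths definitions, write the Kendall rank correlation
between each pair of definitions, and plot the definitions against each other.
"""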
import numpy as np
import CDExperimentSuite_DEV as CDES
import os
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import kendalltau


def visualize(opt, vsb_type):
    """Draw var-sortability KDE plots for one order-alignment definition.

    Args:
        opt: Experiment options, handed unchanged to ``CDES.Visualizer``.
        vsb_type: ``"Original"``, ``"Path-existence"`` or ``"All-paths"``.
            This selects the ``CDES.utils`` order-alignment function that is
            used as the sortability measure.

    Raises:
        KeyError: If ``vsb_type`` is not one of the names listed above.
    """
    # (display name, attribute on CDES.utils) pairs. The attribute is only
    # resolved after the name has been matched.
    known_definitions = (
        ("Original", "order_alignment"),
        ("Path-existence", "order_alignment_CountOne"),
        ("All-paths", "order_alignment_CountAll"),
    )
    for display_name, helper_attr in known_definitions:
        if vsb_type == display_name:
            break
    else:
        raise KeyError("vsb type not found")
    sortability_fn = getattr(CDES.utils, helper_attr)

    visualizer = CDES.Visualizer(opt)
    axis_label = f"{vsb_type.capitalize()} var-sortability"
    file_suffix = f"_{vsb_type}"
    # One KDE plot per accuracy measure.
    for measure in ("sid", "shd"):
        visualizer.varsortability_kde(
            acc_measure=measure,
            filters={"algorithm": ["sortnregressIC", "randomregressIC"]},
            vsb_fun=sortability_fn,
            vsb_name=axis_label,
            custom_name=file_suffix,
        )


# Column-name prefix -> order-alignment function implementing that definition.
ALIGNMENT_FUNCTIONS = {
    "orig": CDES.utils.order_alignment,  # original definition
    "pe": CDES.utils.order_alignment_CountOne,  # path-existence
    "pc": CDES.utils.order_alignment_CountAll,  # all-paths (path count)
}
# Column-name suffix -> results-row field holding the per-node scores.
SCORE_FIELDS = {"Var": "vars", "R2": "R2"}
# Human-readable definition names used as plot axis labels.
DEFINITION_LABELS = {"orig": "Original", "pe": "Path-existence", "pc": "Path-count"}
# Pairs of definitions to correlate and plot against each other.
COMPARISONS = [("orig", "pc"), ("orig", "pe"), ("pe", "pc")]
PLOT_RC_PARAMS = {
    "text.usetex": True,
    "text.latex.preamble": r"\usepackage{bm}",
    "axes.labelsize": 28,
    "xtick.labelsize": 20,
    "ytick.labelsize": 20,
    "legend.fontsize": 20,
    "lines.linewidth": 2,
    "lines.markersize": 14,
}


def _row_alignment(row, align, score_field):
    """Apply ``align`` to the first ``W_true`` entry and first score entry of ``row``.

    NOTE(review): assumes the ``W_true`` and score columns hold sequences whose
    first element is the relevant matrix/vector, as the original indexing implies.
    """
    return align(row.W_true[0], row[score_field][0])


for exp_name in ["ER", "SF"]:
    results_dir = f"src/results/SortabilityDefinitions/{exp_name}"
    raw_dir = f"{results_dir}/{exp_name}_raw"
    viz_dir = f"{raw_dir}/_viz"

    opt = {
        "overwrite": True,
        "base_dir": results_dir,
        # ---
        "MEC": False,
        "thres": 0,
        "thres_type": "standard",
        "vsb_function": CDES.utils.var_sortability,
        "R2sb_function": CDES.utils.r2_sortability,
        "CEVsb_function": CDES.utils.cev_sortability,
        # ---
        "n_repetitions": 50,
        "graphs": [exp_name],
        "edges": [2],
        "edge_types": ["fixed"],
        "noise_distributions": [
            CDES.utils.NoiseDistribution("gauss", "uniform", (0.5, 2.0)),
        ],
        "edge_weights": [(0.1, 0.5)],
        "n_nodes": [20],
        "n_obs": [1000],
    }

    opt["exp_name"] = "_raw"
    opt["scaler"] = CDES.Scalers.Identity()
    opt = CDES.utils.Options(**opt)

    # --------- generate new data if it doesn't exist already -----------

    if not os.path.exists(opt.base_dir):
        CDES.DataGenerator().generate_and_save(opt)
        expR = CDES.ExperimentRunner(opt)
        expR.randomregressIC()
        CDES.Evaluator(opt).evaluate(thresholding=opt.thres_type)

    CDES.utils.create_folder(f"{viz_dir}/")

    # --------- visualization and computation of rank correlation ---------

    df_raw = CDES.utils.load_results(f"{raw_dir}/_eval/standard_0.csv")
    df_raw.drop_duplicates(subset="dataset_hash", keep="first", inplace=True)

    # Add one column per (definition, score) pair, e.g. "orig_Var", "pc_R2".
    for prefix, align in ALIGNMENT_FUNCTIONS.items():
        for suffix, score_field in SCORE_FIELDS.items():
            df_raw[f"{prefix}_{suffix}"] = df_raw.apply(
                _row_alignment, axis=1, args=(align, score_field)
            )

    plt.rcParams.update(PLOT_RC_PARAMS)

    for a, b in COMPARISONS:
        df = df_raw.loc[
            :, ["dataset_hash", f"{a}_Var", f"{a}_R2", f"{b}_Var", f"{b}_R2"]
        ].copy()
        # Kendall rank correlation between the two definitions, per score type.
        tau_var = kendalltau(df[f"{a}_Var"].to_numpy(), df[f"{b}_Var"].to_numpy())
        tau_r2 = kendalltau(df[f"{a}_R2"].to_numpy(), df[f"{b}_R2"].to_numpy())
        # Terminate each entry with a newline so the two results are on
        # separate lines.
        with open(
            f"{viz_dir}/significance_{a}_{b}.txt", "w", encoding="utf-8"
        ) as f:
            f.write(f"var:\t{tau_var}\n")
            f.write(f"R2:\t{tau_r2}\n")

        # Reshape to one row per (dataset, score type). Columns are the two
        # definitions' human-readable names.
        df = df.melt(
            id_vars="dataset_hash",
            value_vars=[f"{a}_Var", f"{b}_Var", f"{a}_R2", f"{b}_R2"],
            var_name="measure",
            value_name="value",
        )
        df["type"] = df["measure"].apply(
            lambda col: DEFINITION_LABELS[col.split("_")[0]]
        )
        df["measure"] = df["measure"].apply(lambda col: col.split("_")[1])
        df = df.pivot(
            index=["dataset_hash", "measure"], columns="type", values="value"
        ).reset_index()
        # Assign the result back instead of using inplace=True on df.measure.
        # That chained-assignment form does not update df under pandas
        # Copy-on-Write.
        df["measure"] = df["measure"].replace(
            {
                "R2": r"$R^2$-sortability"
                + f";  Kendall rank correlation coefficient: {tau_r2.statistic:.2f}",
                "Var": f"Var-sortability; Kendall rank correlation coefficient: {tau_var.statistic:.2f}",
            }
        )

        x_label = DEFINITION_LABELS[a]
        y_label = DEFINITION_LABELS[b]
        fig, ax = plt.subplots(figsize=(9, 6))
        # Dashed reference line from the lower-left to the upper-right corner
        # of the data's bounding box. It coincides with y = x only when both
        # axes span the same range.
        ax.plot(
            np.linspace(df[x_label].min(), df[x_label].max(), 100),
            np.linspace(df[y_label].min(), df[y_label].max(), 100),
            color="black",
            linestyle="--",
            label="diagonal",
        )
        sns.scatterplot(
            data=df, x=x_label, y=y_label, hue="measure", ax=ax, s=50, alpha=0.5
        )
        legend = ax.legend(handlelength=1.5)
        # Enlarge the scatter markers in the legend. Line handles (the
        # diagonal) have no set_sizes and are deliberately left unchanged.
        for handle in legend.legend_handles:
            if hasattr(handle, "set_sizes"):
                handle.set_sizes([30])
        fig.tight_layout()
        fig.savefig(f"{viz_dir}/{exp_name}_sortability_{a}_vs_{b}.pdf")
        # Release the figure; otherwise one stays open per comparison.
        plt.close(fig)
