import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import networkx as nx
import scipy
import os

from graphons import graphon_families as gf
from graphons import graphon_est_ell2 as ge 
from graphons import utils
from itertools import product
from datetime import datetime
from graphons.utils import truncate_sparse_matrix
from graphons.graphon_experiments import get_boolean_sample
import graphons.bionetworks_utils as bu 


import warnings
# Suppress DeprecationWarning / FutureWarning process-wide so they do not
# flood the experiment output.
warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings("ignore", category=FutureWarning) 


# Global plotting defaults applied at import time: default figure size,
# seaborn "whitegrid" style, and 28pt base font / tick labels.
plt.rcParams['figure.figsize'] = [15, 10]
sns.set_style('whitegrid')
plt.rcParams['font.size'] = 28.0
plt.rcParams['xtick.labelsize'] = 28.0
plt.rcParams['ytick.labelsize'] = 28.0


################################################################
################################
### ESTIMATION FUNCTIONS
################################
################################################################

def iter_exp_graphon_all_methods(n_vals,
                                 alpha_vals, beta_vals,
                                 p_flip_vals,
                                 np_fixed, num_trials):
    """Run the simulated graphon-transfer experiment over a parameter grid.

    For every ``(nq, alpha, beta)`` in the Cartesian product of ``n_vals``,
    ``alpha_vals`` and ``beta_vals``, ``num_trials`` independent
    ``gf.GraphonPair`` instances are drawn.  Each trial evaluates

    * ``ge.matrix_completion_from_matrices`` (column ``our_algo``),
    * ``ge.eigenvec_feature_transfer_baseline`` (column ``spectral_algo``),
    * ``ge.get_usvt_bitflipped_estimator`` once per entry of ``p_flip_vals``
      (column ``bitflip_algo``),

    and scores each estimate against ``gp.Q_extended.g1`` with
    ``ge.frob_error``.

    Args:
        n_vals: Iterable of target sample sizes ``nq``.
        alpha_vals: Smoothness values for the source graphon; rounded to two
            decimals before use.
        beta_vals: Smoothness values for the target graphon; rounded to two
            decimals before use.
        p_flip_vals: Flip probabilities passed to the bit-flip estimator.
        np_fixed: Source sample size passed as ``n_p``.
        num_trials: Number of repetitions per parameter combination.

    Returns:
        ``(results, kept_graphs)``.  ``results`` is a DataFrame with one row
        per (combination, trial, ``p_flip``).  ``kept_graphs`` maps
        ``true_<nq>``, ``ours_<nq>``, ``spectral_<nq>`` and
        ``oracle_<pct>_<nq>`` to the matrices of the final trial at the
        largest ``nq``.  The keys do not encode alpha/beta, so later
        combinations overwrite earlier ones.

    Note:
        The global NumPy RNG is re-seeded with 0 on every call.
    """
    np.random.seed(0)
    df_col_names = ['alpha', 'beta', 'np', 'nq', 'our_algo', 'spectral_algo',
                    'bitflip_algo', 'p_flip']
    results = pd.DataFrame(columns=df_col_names)
    kept_graphs = {}

    # Materialise once so the grid and the "largest nq" test see the same
    # values even when an iterator is passed.
    n_vals = list(n_vals)
    largest_nq = max(n_vals, default=None)

    # The progress line below reports the last p_flip that was evaluated.
    # Pre-binding the name keeps that line valid when the inner loops never
    # run (num_trials == 0 or empty p_flip_vals).
    p_flip = None

    for (nq, alpha_val, beta_val) in product(n_vals, alpha_vals, beta_vals):
        # Bind the rounded smoothness values *before* building the closures
        # so the lambdas never reference a name that is not yet assigned.
        alpha = np.round(alpha_val, 2)
        beta = np.round(beta_val, 2)
        graphon_fn_alpha = lambda x, y: gf.graphon_alpha(x, y, smoothness=alpha, scale=1, bias=0)
        graphon_fn_beta = lambda x, y: gf.graphon_alpha(x, y, smoothness=beta, scale=1, bias=0)

        for trial_number in range(num_trials):
            keep_this_trial = (nq == largest_nq
                               and trial_number == num_trials - 1)

            gp = gf.GraphonPair(p_graphon_fn=graphon_fn_alpha,
                                q_graphon_fn=graphon_fn_beta,
                                n_p=np_fixed,
                                n_q=nq)

            p_sample = gp.P.g1_sample.toarray()
            q_sample = gp.Q.g1_sample.toarray()
            subset_indices = gp.subset_indices
            q_matrix_full = gp.Q_extended.g1

            q_hat_ours = ge.matrix_completion_from_matrices(
                p_sample, q_sample, subset_indices)
            qhat_ours_err = ge.frob_error(q_hat_ours, q_matrix_full)

            # Rank parameters for the spectral baseline: n ** (1 / (1 + s)).
            kq = int(np.power(nq, 1 / (1 + beta)))
            kp = int(np.power(np_fixed, 1 / (1 + alpha)))

            q_hat_eigvec = ge.eigenvec_feature_transfer_baseline(
                p_sample, q_sample, subset_indices, kp=kp, kq=kq)
            qhat_eigvec_err = ge.frob_error(q_hat_eigvec, q_matrix_full)

            for p_flip in p_flip_vals:
                qhat_bitflip, _, _, _ = ge.get_usvt_bitflipped_estimator(
                    q_matrix_full, subset_indices, p_flip=p_flip)
                qhat_bitflip_err = ge.frob_error(qhat_bitflip, q_matrix_full)
                # Row order follows df_col_names.
                results.loc[len(results)] = [
                    alpha, beta, np_fixed, nq, qhat_ours_err, qhat_eigvec_err,
                    qhat_bitflip_err, p_flip
                ]
                if keep_this_trial:
                    # round() rather than bare int(): 100 * 0.29 is
                    # 28.999999999999996 and would otherwise truncate to 28.
                    k = int(round(100 * p_flip))
                    kept_graphs[f'oracle_{k}_{nq}'] = qhat_bitflip
            if keep_this_trial:
                kept_graphs[f'true_{nq}'] = q_matrix_full
                kept_graphs[f'ours_{nq}'] = q_hat_ours
                kept_graphs[f'spectral_{nq}'] = q_hat_eigvec

        print(f'{nq=}, {p_flip=}, {alpha=}, {beta_val=}')
    return results, kept_graphs

################################################################
################################
### PLOTTING FUNCTIONS
################################
################################################################
def plot_res_df(res_subset, savepath=None):
    """Plot error-vs-``nq`` curves with 5%-95% bands on a log-scaled y axis.

    Curves drawn on a fresh 20x8 figure:

    * "Our Algorithm" (column ``our_algo``) and "Spectral Baseline"
      (column ``spectral_algo``), taken from the rows of the first
      ``p_flip`` value only;
    * one "Oracle" curve (column ``bitflip_algo``) per distinct ``p_flip``.

    Args:
        res_subset: DataFrame with columns ``np``, ``nq``, ``p_flip``,
            ``our_algo``, ``spectral_algo`` and ``bitflip_algo``.
        savepath: If not ``None``, the figure is saved there at 700 dpi.
    """
    plt.figure(figsize=(20, 8))
    np_fixed = res_subset['np'].unique()[0]

    def _curve_with_band(grouped, column, center, label, marker):
        # Aggregate only the requested column.  The x positions are taken
        # from the aggregation index so they always pair with the right y
        # values, whatever order (or subset of) nq appears in the frame.
        per_nq = grouped[column]
        center_vals = per_nq.mean() if center == 'mean' else per_nq.median()
        x_vals = center_vals.index.to_numpy()
        plt.plot(x_vals, center_vals.to_numpy(), label=label, marker=marker)
        plt.fill_between(
            x_vals,
            y1=per_nq.quantile(.05).to_numpy(),
            y2=per_nq.quantile(.95).to_numpy(),
            alpha=0.1,
        )

    for idx, p_flip in enumerate(res_subset['p_flip'].unique()):
        res_flip_subset = res_subset[res_subset['p_flip'] == p_flip].groupby('nq')

        if idx == 0:
            # NOTE(review): this curve uses the mean while every other curve
            # (and plot_results_dfs) uses the median -- confirm this is
            # intentional.
            _curve_with_band(res_flip_subset, 'our_algo', 'mean',
                             'Our Algorithm', 'o')
            _curve_with_band(res_flip_subset, 'spectral_algo', 'median',
                             'Spectral Baseline', 'p')

        pflip = float(p_flip)
        _curve_with_band(res_flip_subset, 'bitflip_algo', 'median',
                         f'Oracle, $p={pflip}$', 'o')

    plt.ylabel('MSE (Log Scale)')
    plt.xlabel(f'$n_Q$ Value, with $n_P = {np_fixed}$')
    plt.yscale('log')

    # Legend sits above the axes; the top margin is widened to make room.
    plt.legend(loc='upper center',
               shadow=True,
               ncol=3, bbox_to_anchor=(-0.2, 1.25))
    plt.gcf().subplots_adjust(top=0.8)
    plt.gcf().subplots_adjust(bottom=0.1)

    if savepath is not None:
        plt.savefig(savepath, dpi=700.0)

def plot_heatmaps_with_source(out_d, plot_titles=False, cmap='Spectral', savepath=None):
    """Draw four comparison heatmaps plus a shared colourbar.

    Every panel overlays the transposed strict upper triangle of
    ``true_<nq>`` (so it lands in the lower triangle) with the strict upper
    triangle of one other matrix from ``out_d``: ``source_<nq>``,
    ``ours_<nq>``, ``sbm_<nq>`` and ``oracle_10_<nq>``, in that order.  The
    ``<nq>`` suffix is read from the first key of ``out_d``.

    Args:
        out_d: Mapping from ``<name>_<nq>`` keys to square matrices.
        plot_titles: When true, put a title above each panel.
        cmap: Colormap forwarded to ``sns.heatmap``.
        savepath: When truthy, the figure is saved there at 700 dpi.
    """
    nq = list(out_d)[0].rsplit('_', 1)[-1]

    def strict_upper(name):
        # np.triu already returns a new array with the diagonal and lower
        # triangle zeroed.
        return np.triu(out_d[f'{name}_{nq}'], k=1)

    truth_lower = strict_upper('true').T
    panel_mats = [
        strict_upper(name) + truth_lower
        for name in ('source', 'ours', 'sbm', 'oracle_10')
    ]

    fig, axs = plt.subplots(
        ncols=5,
        gridspec_kw=dict(width_ratios=[2, 2, 2, 2, 0.2]),
        figsize=(15, 4),
    )

    # All panels share the fixed [0, 1] colour range.
    heatmap_kwargs = dict(
        cmap=cmap, square=True, cbar=False,
        xticklabels=False, yticklabels=False,
        vmin=0.0, vmax=1.0,
    )
    for panel_ax, panel in zip(axs, panel_mats):
        sns.heatmap(panel, ax=panel_ax, **heatmap_kwargs)

    # The last (narrow) axes hosts the colourbar.
    fig.colorbar(axs[2].collections[0], cax=axs[-1])
    plt.tight_layout()

    if plot_titles:
        # NOTE(review): the third panel shows the `sbm_` matrix but is
        # titled 'Spectral' -- confirm the intended label.
        for panel_ax, title in zip(
                axs, ('Source', 'Our Algorithm', 'Spectral', 'Oracle')):
            panel_ax.set_title(title)
    plt.tight_layout()

    if savepath:
        plt.savefig(savepath, dpi=700.0)

def plot_heatmaps_side_by_side(out_d, plot_titles=False, cmap='Spectral', savepath=None):
    """Draw three estimator-vs-truth heatmaps plus a shared colourbar.

    Every panel overlays the transposed strict upper triangle of
    ``true_<nq>`` (so it lands in the lower triangle) with the strict upper
    triangle of ``ours_<nq>``, ``spectral_<nq>`` and ``oracle_10_<nq>``
    respectively.  The ``<nq>`` suffix is read from the first key of
    ``out_d``.

    All three panels are drawn on the same fixed ``[0, 1]`` colour range so
    that the single colourbar (taken from the third panel) is valid for each
    of them.

    Args:
        out_d: Mapping from ``<name>_<nq>`` keys to square matrices.
        plot_titles: When true, put a title above each panel.
        cmap: Colormap forwarded to ``sns.heatmap``.
        savepath: When truthy, the figure is saved there at 700 dpi.
    """
    nq = list(out_d.keys())[0].split('_')[-1]
    true_mat = np.triu(out_d[f'true_{nq}'], k=1).T
    ours_mat = np.triu(out_d[f'ours_{nq}'], k=1)
    spectral_mat = np.triu(out_d[f'spectral_{nq}'], k=1)
    oracle_10 = np.triu(out_d[f'oracle_10_{nq}'], k=1)

    fig, axs = plt.subplots(ncols=4,
                            gridspec_kw=dict(width_ratios=[2, 2, 2, 0.2]),
                            figsize=(15, 6))

    panels = [
        true_mat + ours_mat,
        true_mat + spectral_mat,
        true_mat + oracle_10,
    ]

    vmin_value = 0.0
    vmax_value = 1.0
    # zip stops after the three panel axes; axs[3] is kept for the colourbar.
    for ax, panel in zip(axs, panels):
        sns.heatmap(panel, cmap=cmap, square=True, cbar=False,
                    xticklabels=False, yticklabels=False, ax=ax,
                    vmin=vmin_value, vmax=vmax_value)

    fig.colorbar(axs[2].collections[0], cax=axs[3])
    if plot_titles:
        axs[0].set_title('Our Algorithm')
        axs[1].set_title('Spectral')
        axs[2].set_title('Oracle')
    plt.tight_layout()

    if savepath:
        plt.savefig(savepath, dpi=700.0)


# def iter_exp_email_all_methods(P_full_np, Q_full_np, n_vals, 
#                                    p_flip_vals,
#                                 np_fixed, 
#                                 num_trials,
#                                 seed=42):
#     np.random.seed(seed)
#     df_col_names = ['np', 'nq', 'our_algo', 'spectral_algo',
#        'bitflip_algo', 'p_flip']
#     results = pd.DataFrame(
#         columns = df_col_names
#     )
#     P_full_sample = get_boolean_sample(P_full_np).toarray()
#     d_est_p = ge.slice_distances_ell2_est(P_full_sample)
    
#     for nq in n_vals: 
#         h_quantile = 10 * np.sqrt(np.log(nq) / nq)
#         n_p = P_full_np.shape[0]
#         n_q = nq
#         for _ in range(num_trials): 
            
#             subset_context = np.random.choice(n_p, size=n_q, replace=False)
#             subset_indices = subset_context
#             Q_subset = truncate_sparse_matrix(
#                 Q_full_np[:, subset_context][subset_context].tocsr()
#             )
#             Q_subset_sample = get_boolean_sample(Q_subset).toarray()
            
#             q_hat_ours, q_hat_sbm, q_hat_bitflip = ge.est_q_three_methods(
#                 p_sample, q_sample, 
#                 q_full_ground_truth, 
#                 subset_indices, 
#                 p_flip=0.1, 
#                 kp=None, 
#                 kq=None,
#             )

#             d_est_subset = d_est_p[:, subset_indices].copy()
#             # print(d_est_subset.shape)
#             mask = ge.bottom_percentile_mask(d_est_subset, 
#                                             h_quantile).astype(np.float32)
#             row_sums = np.sum(mask, axis=1)
#             row_norms = np.divide(1.0, row_sums)
#             mask_normalized = np.diag(row_norms) @ mask
            
#             q_hat_ours = mask_normalized @ Q_subset_sample @ mask_normalized.T

#             qhat_ours_err = ge.frob_error(q_hat_ours, Q_full_np.toarray())

#             q_hat_eigvec = ge.eigenvec_feature_transfer_baseline(P_full_sample, Q_subset_sample, 
#                                                                  subset_indices)
#             qhat_eigvec_err = ge.frob_error(q_hat_eigvec, Q_full_np.toarray())
            

#             for p_flip in p_flip_vals: 
#                 qhat_bitflip, _, _, _ = ge.get_usvt_bitflipped_estimator(Q_full_np.toarray(), 
#                                                                         subset_indices, p_flip = p_flip)
#                 qhat_bitflip_err = ge.frob_error(qhat_bitflip, Q_full_np.toarray())
#                 results.loc[len(results)] = [
#                     np_fixed, nq, qhat_ours_err, qhat_eigvec_err, 
#                     qhat_bitflip_err, p_flip
#                 ]
#         print(f'{nq=}')
#     return results

def iter_exp_three_methods(P_full_np_list, Q_full_np_list,
                           P_labels, Q_labels,
                           n_vals,
                           p_flip_vals,
                           num_trials,
                           seed=42):
    """Compare three estimators of Q over several (P, Q) matrix pairs.

    For each pair, each ``nq`` in ``n_vals`` and each of ``num_trials``
    repetitions, ``nq`` node indices are drawn without replacement from the
    rows of P.  The matching sub-matrix of Q is passed through
    ``truncate_sparse_matrix`` and ``get_boolean_sample``, and
    ``ge.est_q_three_methods`` produces three kinds of estimate.  Each is
    scored against the full dense Q with ``ge.frob_error``.

    Args:
        P_full_np_list: Source matrices (anything
            ``scipy.sparse.csr_matrix`` accepts).
        Q_full_np_list: Target matrices, same length as ``P_full_np_list``.
        P_labels: Label per source matrix, written to column ``p_label``.
        Q_labels: Label per target matrix, written to column ``q_label``.
        n_vals: Subset sizes ``nq`` to evaluate.
        p_flip_vals: Flip probabilities forwarded as ``p_flip_list``.  One
            result row is written per entry.
        num_trials: Repetitions per ``(pair, nq)``.
        seed: Seed for the global NumPy RNG (set once at the start).

    Returns:
        DataFrame with columns ``np``, ``nq``, ``p_label``, ``q_label``,
        ``our_algo``, ``sbm_algo``, ``bitflip_algo`` and ``p_flip``.
    """
    np.random.seed(seed)
    assert len(P_full_np_list) == len(Q_full_np_list), \
        'P_full_np_list and Q_full_np_list must have the same length'
    df_col_names = ['np', 'nq', 'p_label', 'q_label', 'our_algo', 'sbm_algo',
                    'bitflip_algo', 'p_flip']
    results = pd.DataFrame(columns=df_col_names)

    for idx in range(len(Q_full_np_list)):
        P_full_np = scipy.sparse.csr_matrix(P_full_np_list[idx])
        Q_full_np = scipy.sparse.csr_matrix(Q_full_np_list[idx])
        np_fixed = int(P_full_np.shape[0])
        P_full_sample = get_boolean_sample(P_full_np).toarray()
        d_est_p = ge.slice_distances_ell2_est(P_full_sample)

        # Dense ground truth used for scoring; built once per pair instead
        # of once per error computation.
        q_truth = Q_full_np.toarray()

        for nq in n_vals:
            for _ in range(num_trials):
                subset_indices = np.random.choice(np_fixed, size=nq,
                                                  replace=False)
                Q_subset = truncate_sparse_matrix(
                    Q_full_np[:, subset_indices][subset_indices].tocsr()
                )
                Q_subset_sample = get_boolean_sample(Q_subset).toarray()

                q_hat_ours, q_hat_sbm, q_hat_bitflips_list = ge.est_q_three_methods(
                    P_full_sample, Q_subset_sample,
                    # A fresh dense copy is still handed over here because
                    # this file does not show whether the estimator modifies
                    # its ground-truth argument.
                    Q_full_np.toarray(),
                    subset_indices,
                    d_est_large=d_est_p,
                    p_flip_list=p_flip_vals,
                    kp=None,
                    kq=None,
                )

                qhat_ours_err = ge.frob_error(q_hat_ours, q_truth)
                qhat_sbm_err = ge.frob_error(q_hat_sbm, q_truth)

                # The i-th bit-flip estimate is paired with the i-th p_flip.
                for indx2, p_flip in enumerate(p_flip_vals):
                    qhat_bitflip = q_hat_bitflips_list[indx2]
                    qhat_bitflip_err = ge.frob_error(qhat_bitflip, q_truth)
                    results.loc[len(results)] = [
                        np_fixed, nq,
                        P_labels[idx], Q_labels[idx],
                        qhat_ours_err, qhat_sbm_err,
                        qhat_bitflip_err, p_flip
                    ]
            print(f'{nq=}')
    return results

def load_email_dataset(path='data/email-eu/email_data_binned_10_bins.csv',
                       num_nodes=1005):
    """Load the binned email edge list as one adjacency matrix per time bin.

    The CSV must provide the columns ``u``, ``v`` and ``time_category``.
    For every distinct ``time_category`` a dense ``num_nodes x num_nodes``
    float matrix is built, with ``matrix[u, v] = 1`` for each row of that
    bin.  Edges are stored as given; the matrix is not symmetrised.

    Args:
        path: Location of the CSV file.
        num_nodes: Side length of every adjacency matrix.  Node ids must be
            smaller than this value.

    Returns:
        dict mapping ``int(time_category)`` to its adjacency matrix, in
        order of first appearance in the file.
    """
    email_df = pd.read_csv(path)

    adjacency_matrices = {}
    for ts_category in email_df['time_category'].unique():
        # Rows belonging to the current time bin.
        filtered_df = email_df[email_df['time_category'] == ts_category]

        u_values = filtered_df['u'].astype(int).to_numpy()
        v_values = filtered_df['v'].astype(int).to_numpy()

        adj_matrix = np.zeros((num_nodes, num_nodes))
        # Vectorised equivalent of setting adj_matrix[u, v] = 1 edge by edge.
        adj_matrix[u_values, v_values] = 1

        # Key by a plain Python int (the original computed this key but then
        # stored under the raw pandas value).
        adjacency_matrices[int(ts_category)] = adj_matrix
    return adjacency_matrices


    
### metabolic stuff
def load_metabolic_networks(path_to_saved='data/BiGG_data/metabolite_networks_processed/'):
    """Load every saved metabolic network, restricted to the shared nodes.

    A species is discovered from each ``<species>_sparse_mat.npz`` file in
    ``path_to_saved``.  The companion files
    ``<species>_metabolite_names.txt`` and
    ``<species>_compartment_names.txt`` are read alongside it, and the three
    are combined with ``bu.create_df``.  Every frame is then passed through
    ``bu.filter_df_by_metabolites`` with the index labels that occur in
    *all* species.

    Args:
        path_to_saved: Directory containing the processed network files.

    Returns:
        dict mapping species name to its filtered DataFrame.

    Raises:
        FileNotFoundError: If the directory holds no ``*_sparse_mat.npz``
            file.
    """
    mat_suffix = '_sparse_mat.npz'
    # Match on the exact suffix so the companion .txt files of a species
    # whose name happens to contain "mat" are never mistaken for networks.
    # Sorted for a reproducible processing order.
    uq_species = sorted(
        fname[:-len(mat_suffix)]
        for fname in os.listdir(path_to_saved)
        if fname.endswith(mat_suffix)
    )
    if not uq_species:
        raise FileNotFoundError(
            f'no *{mat_suffix} files found in {path_to_saved!r}'
        )

    df_dict = {}
    for k in uq_species:
        df_name = os.path.join(path_to_saved, k + mat_suffix)
        txt_name = os.path.join(path_to_saved, k + '_metabolite_names.txt')
        comp_names = os.path.join(path_to_saved, k + '_compartment_names.txt')

        sparse_mat = scipy.sparse.load_npz(df_name)
        # atleast_1d: np.loadtxt returns a 0-d array for a single-entry file.
        metabolite_names = [
            str(x) for x in np.atleast_1d(np.loadtxt(txt_name, dtype=str))
        ]
        compartment_names = [
            str(x) for x in np.atleast_1d(np.loadtxt(comp_names, dtype=str))
        ]
        df_dict[k] = bu.create_df(sparse_mat, metabolite_names, compartment_names)

    # Index labels present in every species.  No particular species has to
    # exist for this to work.
    shared_nodes_all = set.intersection(
        *(set(df.index) for df in df_dict.values())
    )
    # Fixed ordering: bare set iteration order differs between interpreter
    # runs for strings.
    shared_nodes = sorted(shared_nodes_all, key=str)

    shared_df_dict = {}
    for k, df in df_dict.items():
        df_filtered, _ = bu.filter_df_by_metabolites(df, shared_nodes)
        shared_df_dict[k] = df_filtered
    return shared_df_dict

def iter_exp_metabolic_all_methods(P_df, Q_df, n_vals,
                                   p_flip_vals,
                                   np_fixed,
                                   num_trials, seed=11):
    """Compare the neighbour-averaging, spectral and bit-flip estimators.

    All columns of ``P_df`` / ``Q_df`` except the last are used as the
    source / target matrices.  For each ``nq`` in ``n_vals`` and each trial,
    ``nq`` indices are drawn without replacement from ``range(len(P_df))``
    and the matching sub-matrix of Q is sampled.  Three estimates are then
    scored against the truncated dense Q with ``ge.frob_error``:

    * ``our_algo``: ``M @ Q_sample @ M.T``, where ``M`` is the row-normalised
      ``ge.bottom_percentile_mask`` of the estimated slice distances
      restricted to the sampled columns;
    * ``spectral_algo``: ``ge.eigenvec_feature_transfer_baseline``;
    * ``bitflip_algo``: ``ge.get_usvt_bitflipped_estimator``, once per
      ``p_flip``.

    Args:
        P_df: Source DataFrame.  NOTE(review): the dropped last column is
            presumably a non-numeric label column -- confirm against
            ``bu.create_df``.
        Q_df: Target DataFrame with the same layout.
        n_vals: Subset sizes ``nq`` to evaluate.
        p_flip_vals: Flip probabilities for the bit-flip estimator.
        np_fixed: Value written to the ``np`` column.  It does not influence
            sampling, which uses ``len(P_df)``.
        num_trials: Repetitions per ``nq``.
        seed: Seed for the global NumPy RNG.

    Returns:
        DataFrame with columns ``np``, ``nq``, ``our_algo``,
        ``spectral_algo``, ``bitflip_algo`` and ``p_flip``.
    """
    np.random.seed(seed)
    df_col_names = ['np', 'nq', 'our_algo', 'spectral_algo',
                    'bitflip_algo', 'p_flip']
    results = pd.DataFrame(columns=df_col_names)
    P_full = scipy.sparse.csr_matrix(P_df.iloc[:, :-1])
    Q_full = scipy.sparse.csr_matrix(Q_df.iloc[:, :-1])

    # Dense ground truth for scoring.  The dense copy of P that used to be
    # built next to it was never read and is no longer materialised.
    Q_full_np = truncate_sparse_matrix(Q_full.copy()).toarray()

    P_full_sample = get_boolean_sample(
        truncate_sparse_matrix(P_full)
    ).toarray()
    d_est_p = ge.slice_distances_ell2_est(P_full_sample)

    n_p = len(P_df)
    # The progress line below reports the last p_flip evaluated.  Pre-binding
    # the name keeps it valid when the inner loops never run.
    p_flip = None

    for nq in n_vals:
        h_quantile = 50 * np.sqrt(np.log(nq) / nq)
        for _ in range(num_trials):
            subset_indices = np.random.choice(n_p, size=nq, replace=False)
            Q_subset = truncate_sparse_matrix(
                Q_full[:, subset_indices][subset_indices].tocsr()
            )
            Q_subset_sample = get_boolean_sample(Q_subset).toarray()

            # Keep only the distance columns of the sampled nodes, then turn
            # the mask into row-wise averaging weights.
            d_est_subset = d_est_p[:, subset_indices].copy()
            mask = ge.bottom_percentile_mask(d_est_subset,
                                             h_quantile).astype(np.float32)
            row_sums = np.sum(mask, axis=1)
            # NOTE(review): a row with an all-zero mask yields a division by
            # zero here -- confirm bottom_percentile_mask rules that out.
            row_norms = np.divide(1.0, row_sums)
            mask_normalized = np.diag(row_norms) @ mask

            q_hat_ours = mask_normalized @ Q_subset_sample @ mask_normalized.T
            qhat_ours_err = ge.frob_error(q_hat_ours, Q_full_np)

            q_hat_eigvec = ge.eigenvec_feature_transfer_baseline(
                P_full_sample, Q_subset_sample, subset_indices)
            qhat_eigvec_err = ge.frob_error(q_hat_eigvec, Q_full_np)

            for p_flip in p_flip_vals:
                qhat_bitflip, _, _, _ = ge.get_usvt_bitflipped_estimator(
                    Q_full_np, subset_indices, p_flip=p_flip)
                qhat_bitflip_err = ge.frob_error(qhat_bitflip, Q_full_np)
                # Row order follows df_col_names.
                results.loc[len(results)] = [
                    np_fixed, nq, qhat_ours_err, qhat_eigvec_err,
                    qhat_bitflip_err, p_flip
                ]
        print(f'{nq=}, {p_flip=}')
    return results
    

def plot_results_dfs(results_df_list, titles_list, savepath=None):
    """Plot one error-vs-``nq`` panel per results frame, sharing the y axis.

    Each panel shows the median with a 5%-95% band for:

    * "Our Algorithm" (``our_algo``) and "Spectral Baseline"
      (``spectral_algo``), from the rows of the first ``p_flip`` value only;
    * one "Oracle" curve (``bitflip_algo``) per distinct ``p_flip``.

    Args:
        results_df_list: Non-empty list of DataFrames with columns ``np``,
            ``nq``, ``p_flip``, ``our_algo``, ``spectral_algo`` and
            ``bitflip_algo``.
        titles_list: One title per frame, or ``None`` for no titles.
        savepath: If not ``None``, the figure is saved there at 700 dpi.

    Raises:
        ValueError: If ``results_df_list`` is empty.
    """
    num_exps = len(results_df_list)
    if num_exps == 0:
        raise ValueError('results_df_list must contain at least one DataFrame')

    # squeeze=False gives a 2-D array of axes even for a single panel, so the
    # indexing below works for any number of experiments.
    fig, axs = plt.subplots(figsize=(20, 8), ncols=num_exps, sharey=True,
                            squeeze=False)
    axs = axs.ravel()

    np_fixed = results_df_list[0]['np'].unique()[0]

    def _median_with_band(ax, grouped, column, label, marker):
        # x positions come from the aggregation index so they always pair
        # with the right y values.
        per_nq = grouped[column]
        medians = per_nq.median()
        x_vals = medians.index.to_numpy()
        ax.plot(x_vals, medians.to_numpy(), label=label, marker=marker)
        ax.fill_between(
            x_vals,
            y1=per_nq.quantile(.05).to_numpy(),
            y2=per_nq.quantile(.95).to_numpy(),
            alpha=0.2,
        )

    for idx, results_df in enumerate(results_df_list):
        ax = axs[idx]
        for idx2, p_flip in enumerate(results_df['p_flip'].unique()):
            res_flip_subset = results_df[results_df['p_flip'] == p_flip].groupby('nq')
            if idx2 == 0:
                _median_with_band(ax, res_flip_subset, 'our_algo',
                                  'Our Algorithm', 'o')
                _median_with_band(ax, res_flip_subset, 'spectral_algo',
                                  'Spectral Baseline', 'p')
            pflip = float(p_flip)
            _median_with_band(ax, res_flip_subset, 'bitflip_algo',
                              f'Oracle, $p={pflip}$', 'o')
        ax.set_yscale('log')
        if titles_list is not None:
            ax.set_title(titles_list[idx])
        # Label every panel, not just the first two.
        ax.set_xlabel(f'$n_Q$ Value, with $n_P = {int(np_fixed)}$')

    axs[0].set_ylabel('MSE (Log Scale)')

    # One legend, built from the first panel's entries and placed above the
    # axes; the top margin is widened to make room for it.
    handles, labels = axs[0].get_legend_handles_labels()
    plt.legend(handles, labels, loc='upper center',
               shadow=True,
               ncol=3, bbox_to_anchor=(-0.2, 1.25))
    plt.gcf().subplots_adjust(top=0.8)
    plt.gcf().subplots_adjust(bottom=0.1)

    if savepath is not None:
        plt.savefig(savepath, dpi=700.0)


def run_all_realworld():
    """Run the Email-EU transfer experiment and write its results to CSV.

    Time bin 0 is the source for three targets (bins 1, 4 and 7).  Results
    from ``iter_exp_three_methods`` are written to
    ``exp-results/results_3_methods/<run_label>.csv``, and the elapsed
    wall-clock time is printed at the end.
    """
    startTime = datetime.now()

    # Create the output directory up front so a missing folder is noticed
    # before the experiment runs rather than when the CSV is written.
    out_dir = 'exp-results/results_3_methods'
    os.makedirs(out_dir, exist_ok=True)

    print('Beginning Email EU Run...')
    email_adj_dict = load_email_dataset()
    time_zero_list = [0, 0, 0]
    time_one_list = [1, 4, 7]
    P_full_np_list = [email_adj_dict[t] for t in time_zero_list]
    Q_full_np_list = [email_adj_dict[t] for t in time_one_list]
    P_labels = [f't{t}' for t in time_zero_list]
    Q_labels = [f't{t}' for t in time_one_list]
    num_trials = 50
    p_flip_vals = [0.01, 0.05, 0.1]
    seed = 91
    nq_vals = [20, 40, 60, 80, 100]
    res_df = iter_exp_three_methods(P_full_np_list, Q_full_np_list,
                                    P_labels, Q_labels,
                                    n_vals=nq_vals,
                                    p_flip_vals=p_flip_vals,
                                    num_trials=num_trials,
                                    seed=seed)
    run_label = f'email_eu_seed_{seed}_numtrials_{num_trials}_smaller_p_flips'
    res_df.to_csv(f'{out_dir}/{run_label}.csv')

    # Print total time
    time_diff = datetime.now() - startTime
    print(f'Total time {time_diff}')

def _run_three_methods_and_save(p_matrix_full, q_matrix_full,
                                p_label, q_label, run_label,
                                nq_vals, p_flip_vals, num_trials, seed,
                                out_dir='exp-results/results_3_methods'):
    """Run ``iter_exp_three_methods`` on one (P, Q) pair and save the CSV."""
    os.makedirs(out_dir, exist_ok=True)
    res_df = iter_exp_three_methods(
        [scipy.sparse.csr_matrix(p_matrix_full)],
        [scipy.sparse.csr_matrix(q_matrix_full)],
        [p_label], [q_label],
        n_vals=nq_vals,
        p_flip_vals=p_flip_vals,
        num_trials=num_trials,
        seed=seed,
    )
    res_df.to_csv(f'{out_dir}/{run_label}.csv')
    return res_df


def run_all_simul():
    """Run the four simulated transfer experiments and save one CSV each.

    Experiments, in order: two Hölder-type graphons (``gf.graphon_alpha``),
    a pair of periodic graphons, an MMSB pair, and a 10-dimensional graphon
    pair.  All share ``n_P = 200``, ``n_Q`` in ``{20, 30, 40, 50}``, 50
    trials, flip probabilities ``{0.1, 0.3, 0.5}`` and seed 91.  Results go
    to ``exp-results/results_3_methods/<run_label>.csv``, and the elapsed
    wall-clock time is printed at the end.
    """
    startTime = datetime.now()

    # Settings shared by every experiment below.
    np_fixed = 200
    nq_vals = [20, 30, 40, 50]
    num_trials = 50
    p_flip_vals = [0.1, 0.3, 0.5]
    seed = 91
    shared = dict(nq_vals=nq_vals, p_flip_vals=p_flip_vals,
                  num_trials=num_trials, seed=seed)

    print('Beginning Graphons 1 Run...')
    alpha = 0.1
    beta = 0.5
    graphon_fn_alpha = lambda x, y: gf.graphon_alpha(x, y, smoothness=alpha, scale=1, bias=0)
    graphon_fn_beta = lambda x, y: gf.graphon_alpha(x, y, smoothness=beta, scale=1, bias=0)
    gp = gf.GraphonPair(p_graphon_fn=graphon_fn_alpha,
                        q_graphon_fn=graphon_fn_beta,
                        n_p=np_fixed,
                        n_q=20)
    _run_three_methods_and_save(
        gp.P.g1.copy(), gp.Q_extended.g1.copy(),
        'alpha_01', 'beta_05',
        f'graphon_alpha_01_beta_05_seed_{seed}_numtrials_{num_trials}',
        **shared,
    )

    print('Beginning Graphons 2 Run...')
    graphon_fn_alpha = lambda x, y: gf.graphon_liza_rotated(x, y,
                                                            period=3, bias=0.5,
                                                            invert=True, phase_shift=False)
    graphon_fn_beta = lambda x, y: gf.graphon_liza_2(x, y, period=3,
                                                     bias=0.5, invert=False,
                                                     phase_shift=False)
    gp = gf.GraphonPair(p_graphon_fn=graphon_fn_alpha,
                        q_graphon_fn=graphon_fn_beta,
                        n_p=np_fixed,
                        n_q=20)
    _run_three_methods_and_save(
        gp.P.g1.copy(), gp.Q_extended.g1.copy(),
        'liza_period_3_rotated', 'liza_period_3_ordinary',
        f'graphon_liza_period_3_rotated_liza_period_3_ordinary_seed_{seed}_numtrials_{num_trials}',
        **shared,
    )

    print('Beginning MMSB 1 Run...')
    p_matrix_full, q_matrix_full, _, _, _ = gf.get_mmsb_pair_two_param(
        20, np_fixed,
        kp_power=0.5, kq_power=0.5,
        ap=0.7, bp=0.3,
        aq=0.9, bq=0.1,
        noisy=True,
        noise_level=0.01, seed=seed)
    _run_three_methods_and_save(
        p_matrix_full, q_matrix_full,
        'mmsb_a_07_b_03_noise_001', 'mmsb_a_09_b_01_noise_001',
        f'mmsb_seed_{seed}_numtrials_{num_trials}',
        **shared,
    )

    print('Beginning High Dim Run...')
    graphon_fn_alpha = lambda x, y: gf.graphon_d_dim_exp(x, y,
                                                         scale=2.5)
    graphon_fn_beta = lambda x, y: gf.graphon_d_dim_exp(x, y,
                                                        scale=1.0)
    gp = gf.GraphonPairHighDim(
        dim=10,
        p_graphon_fn=graphon_fn_alpha,
        q_graphon_fn=graphon_fn_beta,
        n_p=np_fixed,
        n_q=20
    )
    # A copy-pasted get_mmsb_pair_two_param call used to sit here.  Its
    # outputs were never used, and iter_exp_three_methods re-seeds the NumPy
    # RNG before drawing anything, so it has been dropped.
    _run_three_methods_and_save(
        gp.P.g1.copy(), gp.Q_extended.g1.copy(),
        'graphon_d_dim_scale_25', 'graphon_d_dim_scale_10',
        f'd_dim_scale_p_25_q_10_seed_{seed}_numtrials_{num_trials}',
        **shared,
    )

    time_diff = datetime.now() - startTime
    print(f'Total time {time_diff}')

# Script entry point: run the real-world (Email-EU) experiment first, then
# the simulated experiments.
if __name__ == '__main__': 
    run_all_realworld()
    run_all_simul()