import os

import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import GridSearchCV, train_test_split

from dataset_utils import apply_noise

# model_init(noise, n, t) should return a fresh, unfitted sklearn estimator (it is wrapped in GridSearchCV)
def run_experiment(
    X,
    y,
    noises,
    model_init,
    parameter_name,
    parameter_list,
    num_draws = 5,
    num_splits = 5,
):
    """Track which hyperparameter value grid search selects as label noise grows.

    For each noise level in ``noises`` and each draw ``n`` in ``range(num_draws)``,
    a noisy copy of the labels is obtained from ``apply_noise`` with a seed derived
    from ``n`` and the noise level. That labelling is split ``num_splits`` times into
    80% train / 20% held-out (the split seed depends only on the split index ``t``),
    and ``GridSearchCV`` over ``{parameter_name: parameter_list}`` is fit on the
    training portion. Progress is printed as ``noise n t``.

    Args:
        X: Features, passed to ``apply_noise`` and ``train_test_split``.
        y: Clean labels; must expose ``.shape`` (``y.shape[0]`` is passed to
            ``apply_noise``).
        noises: Iterable of noise levels to sweep.
        model_init: Callable ``(noise, n, t)`` returning an unfitted sklearn estimator.
        parameter_name: Name of the hyperparameter to tune.
        parameter_list: Candidate values for that hyperparameter.
        num_draws: Number of noise draws per noise level.
        num_splits: Number of train/held-out splits per draw.

    Returns:
        A tuple ``(params, chosen_params)``:
        - ``params``: DataFrame with columns ``'Parameter'`` (best value found) and
          ``'Noise'``, one row per (noise, draw, split) in loop order.
        - ``chosen_params``: list of ``num_draws`` lists; ``chosen_params[n]`` holds
          the best values found at ``noise == 0`` for draw ``n``, in split order.
          These lists stay empty if 0 is not in ``noises``.
    """
    rows = []
    chosen_params = [[] for _ in range(num_draws)]

    for noise in noises:
        for n in range(num_draws):
            # Same seed expression as find_diffs, so both request the same
            # noise draw for a given (noise, n).
            draw_seed = int((n+1) * ((noise+3.3) ** 2) * 1010)
            _, y_noisy, _, _ = apply_noise(
                X, y, noise, y.shape[0],
                seed=draw_seed, exact=False, replace=False,
            )

            for t in range(num_splits):
                print(noise, n, t)
                split_seed = int(((t+3.3) ** 2) * 1010)
                # Only the training part is used here; the 20% held out is the
                # same validation split that find_diffs scores on.
                X_train, _, y_train, _ = train_test_split(
                    X, y_noisy, test_size=0.2, random_state=split_seed,
                )

                search = GridSearchCV(
                    model_init(noise, n, t),
                    {parameter_name: parameter_list},
                    return_train_score=True,
                )
                search.fit(X_train, y_train)
                best_value = search.best_params_[parameter_name]

                rows.append([best_value, noise])
                if noise == 0:
                    chosen_params[n].append(best_value)

    return pd.DataFrame(rows, columns=['Parameter', 'Noise']), chosen_params

# model_init(noise, n, t, parameter=value) should return a fresh, unfitted sklearn model configured with the tuned parameter value
def find_diffs(
    X,
    y,
    chosen_params,
    noises,
    model_init,
    num_draws = 5,
    num_splits = 5,

):
    """Score models built with pre-chosen hyperparameters across label-noise levels.

    Reproduces the noise draws and train/held-out splits of ``run_experiment``
    (same seed expressions, same 80/20 split). For every (noise, draw ``n``,
    split ``t``), it builds a model with ``chosen_params[n][t]`` and fits it on the
    noisy training part. It then records the estimator's ``score`` on the training
    part and on the held-out part.

    Args:
        X: Features, passed to ``apply_noise`` and ``train_test_split``.
        y: Clean labels; must expose ``.shape``.
        chosen_params: Indexable as ``chosen_params[n][t]`` for every draw and
            split (e.g. the second return value of ``run_experiment``).
        noises: Iterable of noise levels to sweep.
        model_init: Callable ``(noise, n, t, parameter=value)`` returning an
            unfitted sklearn estimator.
        num_draws: Number of noise draws per noise level.
        num_splits: Number of train/held-out splits per draw.

    Returns:
        DataFrame with columns ``'Train'``, ``'Val'``, ``'Train-Val'``, ``'Noise'``,
        ``'n'``, ``'t'``, with one row per (noise, draw, split) in loop order.
    """
    rows = []

    for noise in noises:
        for n in range(num_draws):
            draw_seed = int((n+1) * ((noise+3.3) ** 2) * 1010)
            _, y_noisy, _, _ = apply_noise(
                X, y, noise, y.shape[0],
                seed=draw_seed, exact=False, replace=False,
            )

            for t in range(num_splits):
                # The split seed depends only on the split index t.
                split_seed = int(((t+3.3) ** 2) * 1010)
                X_train, X_val, y_train, y_val = train_test_split(
                    X, y_noisy, test_size=0.2, random_state=split_seed,
                )

                model = model_init(noise, n, t, parameter=chosen_params[n][t])
                model.fit(X_train, y_train)

                train_score = model.score(X_train, y_train)
                val_score = model.score(X_val, y_val)
                rows.append([train_score, val_score, train_score - val_score, noise, n, t])

    return pd.DataFrame(rows, columns=['Train', 'Val', 'Train-Val', 'Noise', 'n', 't'])

def savefig(path, fig):
    """Write ``fig`` to ``<path>.png`` and then ``<path>.pdf``, tightly cropped.

    ``path`` is the output path without a file extension.
    """
    for extension in ('png', 'pdf'):
        fig.savefig(f'{path}.{extension}', bbox_inches='tight')

# Graph a single line.
def graph_single(path, data, y_name, y_label, legend_label):
    """Plot ``y_name`` against ``'Noise'`` from ``data`` as one line and save it.

    This is the one-series case of ``graph_together``. It produces the same
    single-axes figure, axis labels, tick and legend sizes, and writes
    ``<path>.png`` and ``<path>.pdf``.
    """
    graph_together(path, [data], y_name, y_label, [legend_label])

# Graph multiple lines on the same axes.
def graph_together(path, data_list, y_name, y_label, legend_labels):
    """Overlay one line per dataset on a single axes and save the figure.

    Each element of ``data_list`` is plotted as ``y_name`` against ``'Noise'``
    and labelled with the matching entry of ``legend_labels``. Pairs beyond the
    shorter of the two sequences are ignored, because ``zip`` truncates. The
    figure is written to ``<path>.png`` and ``<path>.pdf``.
    """
    figure, axes = plt.subplots()

    for frame, series_name in zip(data_list, legend_labels):
        sns.lineplot(data=frame, x='Noise', y=y_name, ax=axes, label=series_name)

    # x label first, then y label, both at the same font size.
    axis_labels = (
        (axes.set_xlabel, r"Label noise, $\rho$"),
        (axes.set_ylabel, y_label),
    )
    for set_label, text in axis_labels:
        set_label(text, size=22)
    axes.tick_params(axis='both', labelsize=18)
    axes.legend(fontsize=18)

    savefig(path, figure)

# Graph up to four lines, each on its own axes in a 2x2 grid.
def graph_separate(path, data_list, y_name, y_label, legend_labels):
    """Plot each dataset on its own panel of a 2x2 grid and save the figure.

    Series ``j`` goes to row ``j // 2``, column ``j % 2`` and is drawn in the
    j-th ``tab10`` color. The grid is fixed at 2x2, so at most four series fit;
    a fifth would index past the axes array. Columns share their x-axis. Only
    the bottom row gets an x label, and a single y label is placed on the
    figure. The figure is written to ``<path>.png`` and ``<path>.pdf``.
    """
    fig = plt.figure()
    gs = fig.add_gridspec(2, 2)
    axs = gs.subplots(sharex='col')

    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap is the spelling that works across versions.
    cmap = plt.get_cmap('tab10')

    for j, (data, label) in enumerate(zip(data_list, legend_labels)):
        row, col = divmod(j, 2)
        # tab10 has 10 discrete colors; 0.05 + 0.1*j lands in the j-th bin.
        color = cmap(0.05 + j * 0.1)
        sns.lineplot(data=data, x='Noise', y=y_name, ax=axs[row][col], label=label, color=color)

    for i, axes_row in enumerate(axs):
        for ax in axes_row:
            ax.set_xlabel('')
            ax.set_ylabel('')
            if i == 1:
                ax.set_xlabel(r"Label noise, $\rho$", size = 22)
            ax.tick_params(axis='both', labelsize=18)
            ax.legend(fontsize=18, loc='upper left')

    # One shared y label for the whole figure; set once instead of per panel.
    fig.supylabel(y_label, size = 22)

    fig.set_size_inches(13.8, 4.8)

    savefig(path, fig)


