import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

# Dataset keys to plot, each paired with the label shown in legends.
# Other dataset keys used in earlier runs: 'bcc', 'penguin', 'digits4',
# 'occupancy_detection', 'give_credit'.
_DATASET_LABELS = [
    ('banknote', 'Banknote'),
    ('iran_customer_churn', 'Iran'),
    ('seeds', 'Seeds'),
    ('column', 'Column'),
    ('fico', 'Fico'),
    ('australian_credit', 'Australian'),
]
datasets = [key for key, _ in _DATASET_LABELS]
names = [label for _, label in _DATASET_LABELS]

# Noise standard deviations (sigma) on the x-axis.  str(noises) is
# interpolated verbatim into the result and figure file names, so the exact
# values and their order must match the runs that produced the results.
noises = [
    0.0, 0.05, 0.1, 0.15, 0.2, 0.25,
    0.3, 0.35, 0.4, 0.45, 0.5,
]
# Number of random draws per noise level; also part of the file names.
num_trials = 100
# Only used to build file names here (presumably the training epoch count of
# the run that produced the results -- confirm).
epochs = 1000

def compute_norm_frames(expe_weights, draw_weights, mask, noise_levels):
    """Turn masked weight arrays into long-form squared-norm DataFrames.

    Args:
        expe_weights: array of shape (len(noise_levels), n_features); one
            weight vector per noise level.
        draw_weights: array of shape (len(noise_levels), n_draws, n_features);
            n_draws weight vectors per noise level.
        mask: per-feature multiplier of length n_features, broadcast over the
            last axis (presumably a 0/1 relevance mask -- confirm).
        noise_levels: sequence of sigma values, one per leading-axis entry.

    Returns:
        ``(expe_df, draw_df)``: DataFrames with float columns ``'sigma'`` and
        ``'squared_norm'``.  ``expe_df`` has one row per noise level;
        ``draw_df`` has ``n_draws`` rows per noise level, grouped by noise
        level in the order of ``noise_levels``.
    """
    sigma = np.asarray(noise_levels, dtype=float)
    mask = np.asarray(mask)

    # Squared L2 norm over the feature (last) axis, after masking.
    expe_sq = ((expe_weights * mask) ** 2).sum(axis=-1)
    draw_sq = ((draw_weights * mask) ** 2).sum(axis=-1)

    expe_df = pd.DataFrame({
        'sigma': sigma,
        'squared_norm': np.asarray(expe_sq, dtype=float),
    })

    # draw_sq is (n_levels, n_draws); ravel() is row-major, so repeating each
    # sigma n_draws times keeps every norm aligned with its noise level.
    n_draws = draw_sq.shape[1]
    draw_df = pd.DataFrame({
        'sigma': np.repeat(sigma, n_draws),
        'squared_norm': np.asarray(draw_sq, dtype=float).ravel(),
    })
    return expe_df, draw_df


def load_norm_frames(dataset):
    """Load the saved weights for *dataset* and return its squared-norm frames.

    Reads the feature mask from
    ``exponential_loss_csvs/datasets/<dataset>/filter.npy`` and the expected /
    drawn weight arrays from ``./results/``.  The file names are built from
    the module-level ``noises``, ``num_trials`` and ``epochs``.

    Returns:
        ``(expe_df, draw_df)`` as produced by :func:`compute_norm_frames`.
    """
    mask = np.load(f'exponential_loss_csvs/datasets/{dataset}/filter.npy')
    stem = f'./results/{dataset}_{noises}_{num_trials}_{epochs}'
    expe_weights = np.load(f'{stem}_expe_weight.npy')
    draw_weights = np.load(f'{stem}_draw_weight.npy')
    return compute_norm_frames(expe_weights, draw_weights, mask, noises)


def _save_and_close(fig, *paths):
    """Save *fig* to each of *paths* with a tight bounding box, then close it."""
    import os

    try:
        for path in paths:
            # savefig does not create missing directories.
            os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
            fig.savefig(path, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until closed (and warns once more
        # than 20 are open), so release it as soon as it is written.
        plt.close(fig)


def graph_indiv():
    """Write two PNG line plots per dataset in ``datasets``.

    The plots show squared weight norm against noise sigma: one for the
    expected weights (``*_expe.png``) and one for the random draws
    (``*_draw.png``).  Files go to ``./figures/``.
    """
    for dataset in datasets:
        expe_df, draw_df = load_norm_frames(dataset)
        stem = f'./figures/{dataset}_{noises}_{num_trials}_{epochs}'

        plots = (
            (expe_df, 'Expected Squared Norm', 'expe'),
            (draw_df, 'Squared Norm', 'draw'),
        )
        for frame, ylabel, suffix in plots:
            fig, ax = plt.subplots()
            sns.lineplot(data=frame, x='sigma', y='squared_norm', ax=ax)
            ax.set_xlabel(r"Standard Deviation, $\sigma$", size=22)
            ax.set_ylabel(ylabel, size=22)
            ax.legend([dataset])
            _save_and_close(fig, f'{stem}_{suffix}.png')


def graph_together():
    """Plot all datasets together on log-scaled squared-norm axes.

    Writes the expected-weight figure (``all_..._expe.png``) and the
    random-draw figure (``all_..._draw.pdf`` and ``.png``) to
    ``./figures/``.  Each dataset is labelled with its entry in ``names``.

    Raises:
        ValueError: if ``datasets`` and ``names`` differ in length, since the
            pairing between them would otherwise be silently truncated.
    """
    if len(datasets) != len(names):
        raise ValueError(
            f'datasets ({len(datasets)}) and names ({len(names)}) '
            'must have the same length'
        )

    expe_parts = []
    draw_parts = []
    for dataset, name in zip(datasets, names):
        expe_df, draw_df = load_norm_frames(dataset)
        expe_parts.append(expe_df.assign(Dataset=name))
        draw_parts.append(draw_df.assign(Dataset=name))

    # Fresh index: the per-dataset frames each start at 0.
    expe_df = pd.concat(expe_parts, ignore_index=True)
    draw_df = pd.concat(draw_parts, ignore_index=True)
    stem = f'./figures/all_{noises}_{num_trials}_{epochs}'

    # Expected weights: one line per dataset.
    fig, ax = plt.subplots()
    sns.lineplot(data=expe_df, x='sigma', y='squared_norm', hue='Dataset', ax=ax)
    ax.set_xlabel(r"Standard Deviation, $\sigma$", size=20)
    ax.set_ylabel('Expected Squared Norm', size=20)
    ax.set_yscale('log')
    ax.tick_params(axis='x', labelsize=13)
    ax.tick_params(axis='y', labelsize=13)
    _save_and_close(fig, f'{stem}_expe.png')

    # Random draws: tall, large-font figure with the legend below the axes.
    fig, ax = plt.subplots()
    fig.set_size_inches(8, 10)
    sns.lineplot(data=draw_df, x='sigma', y='squared_norm', hue='Dataset', ax=ax)
    ax.set_xlabel(r"Standard Deviation, $\sigma$", size=30)
    ax.set_ylabel('Squared Norm of Relevant Weights', size=30)
    ax.set_yscale('log')
    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', labelsize=30)
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=2, fontsize=30)
    _save_and_close(fig, f'{stem}_draw.pdf', f'{stem}_draw.png')


if __name__ == '__main__':
    # Only the combined figures are produced by default; call graph_indiv()
    # here as well to also write the per-dataset figures.
    graph_together()
