import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import pickle
import sys
import pandas as pd


'''
This file creates figure 4 and table 1 of the paper. Use it with the following arguments:
    seed: network seed for fig. 4 (a, b)
    lin, non, nonlin, lincomp or noncomp: string for the type of figure. lin and non show
        4a,b for the linear/non-linear system, respectively. nonlin shows them for both
        systems. lincomp/noncomp show 4a,b with the linear or non-linear system, but also
        create 4c,d.
    any number of further seeds over which table 1 is computed (optional)
    
    Example:
    python ECG_evaluation 1234 noncomp 1234 1235 1236 1237 4321
    creates fig 4 for network seed 1234 and table 1 over the seeds mentioned after noncomp
'''


#%%

# Command line: argv[1] is the seed of the figure, argv[2] the figure type.
# Fail early with a readable message instead of an IndexError here or a
# NameError further down when the axes/systems were never created.
if len(sys.argv) < 3:
    sys.exit('usage: SEED {lin|non|nonlin|lincomp|noncomp} [TABLE_SEED ...]')

seed = int(sys.argv[1])
print('figure', sys.argv[2], 'at seed', seed)


def _add_panel(fig, cell, tag, xlabel, ylabel):
    """Create one subplot in `cell` of `fig` and label it.

    `tag` is placed as a left-aligned title (e.g. '(a)'); `xlabel` and
    `ylabel` are the axis labels. Returns the new axes object.
    """
    panel = fig.add_subplot(cell)
    panel.set_title(tag, loc='left')
    panel.set_xlabel(xlabel)
    panel.set_ylabel(ylabel)
    return panel


# axis labels shared by all histogram panels
_HIST_XLABEL = 'distance from separating hyperplane'
_HIST_YLABEL = 'frequency'

# Set up figure 4
if sys.argv[2] in ['lin', 'non']:
    # Histograms like fig. 4 (a,b) for linear or non-linear reservoir
    systems = [sys.argv[2]]

    fig1 = plt.figure(1, figsize=(20,8))
    plt.clf()
    gs = GridSpec(1,2, hspace=0.2, wspace=0.2, bottom=0.15, left=0.07, right=0.99)
    plt.rcParams.update({'font.size': 24})

    ax1a = _add_panel(fig1, gs[0,0], '(a)', _HIST_XLABEL, _HIST_YLABEL)
    ax1b = _add_panel(fig1, gs[0,1], '(b)', _HIST_XLABEL, _HIST_YLABEL)

elif sys.argv[2] == 'nonlin':
    # histograms like fig. 4 (a,b), but for both linear and non-linear reservoirs
    systems = ['lin', 'non']

    fig1 = plt.figure(1, figsize=(20,16))
    plt.clf()
    gs = GridSpec(2,2, hspace=0.2, wspace=0.2, top=0.95, bottom=0.05, left=0.07, right=0.99)
    plt.rcParams.update({'font.size': 24})

    ax1a = _add_panel(fig1, gs[0,0], '(a)', _HIST_XLABEL, _HIST_YLABEL)
    ax1b = _add_panel(fig1, gs[0,1], '(b)', _HIST_XLABEL, _HIST_YLABEL)
    ax1c = _add_panel(fig1, gs[1,0], '(c)', _HIST_XLABEL, _HIST_YLABEL)
    ax1d = _add_panel(fig1, gs[1,1], '(d)', _HIST_XLABEL, _HIST_YLABEL)

elif sys.argv[2] in ['lincomp', 'noncomp']:
    # fig. 4 (a-d), using linear ('lincomp') or non-linear ('noncomp') reservoir
    # the first three characters of the mode name the system: 'lin' or 'non'
    systems = [sys.argv[2][:3]]

    fig1 = plt.figure(1, figsize=(20,12))
    plt.clf()
    gs = GridSpec(2,2, hspace=0.4, wspace=0.2, top=0.95, bottom=0.1, left=0.07, right=0.9, height_ratios=[1.0, 0.4])
    plt.rcParams.update({'font.size': 24})

    ax1a = _add_panel(fig1, gs[0,0], '(a)', _HIST_XLABEL, _HIST_YLABEL)
    ax1a.yaxis.set_major_locator(plt.MultipleLocator(50))

    ax1b = _add_panel(fig1, gs[0,1], '(b)', _HIST_XLABEL, _HIST_YLABEL)
    ax1b.yaxis.set_major_locator(plt.MultipleLocator(50))

    ax1c = _add_panel(fig1, gs[1,0], '(c)', 'optimization step', r'$\kappa_{\eta}$')
    ax1d = _add_panel(fig1, gs[1,1], '(d)', r'time $t$', r'$\kappa_{\eta}$')

else:
    sys.exit('unknown figure type ' + repr(sys.argv[2])
             + '; expected one of lin, non, nonlin, lincomp, noncomp')




# Draw the panels of figure 4 for every requested system ('lin' and/or 'non').
for idx, system in enumerate(systems):
    # load data for figure of seed and type named
    # NOTE: unpickling can execute arbitrary code; only load result files
    # that were produced by a trusted run.
    with open('data/ECG/ECG5000_'+str(seed)+'_'+system+'.txt', 'rb') as handle:
        full_dictionary = pickle.load(handle)

    rand_dictionary = full_dictionary['rand_dictionary']
    figure_dictionary = full_dictionary['figure_dictionary']
    composition_dictionary = full_dictionary['composition_dict']

    # test accuracies of the random input vectors
    accs_test_rand = rand_dictionary['accs_test_rand']

    train_labels = figure_dictionary['train_labels']
    test_labels = figure_dictionary['test_labels']

    opt_distances = figure_dictionary['opt_distances']
    rand_distances = figure_dictionary['rand_distances']

    composition_optimization = composition_dictionary['composition_optimization']
    composition_time = composition_dictionary['composition_temporal']
    times = composition_dictionary['times']

    #%%

    # create figure: the first system goes to panels (a,b), the second to (c,d)
    if idx == 0:
        axrand = ax1a
        axopt = ax1b
    if idx == 1:
        axrand = ax1c
        axopt = ax1d

    # train labels followed by test labels; the sign of a label selects the class
    labels = np.append(train_labels, test_labels)
    positive = labels > 0
    negative = labels < 0

    # find random input vector with worst (min) and best (max) accuracy
    min_rand_idx = np.argmin(accs_test_rand)
    max_rand_idx = np.argmax(accs_test_rand)

    #axrand
    # common bins spanning all random and optimized distances
    all_distances = np.append(rand_distances.flatten(), opt_distances)
    bins = np.linspace(np.min(all_distances), np.max(all_distances), 150)

    # filled histograms: distances averaged over the first axis of rand_distances
    mean_rand_distances = np.mean(rand_distances, axis=0)
    axrand.hist(mean_rand_distances[positive], color='deepskyblue', bins=bins, alpha=0.5, histtype='stepfilled')
    axrand.hist(mean_rand_distances[negative], color='crimson', bins=bins, alpha=0.5, histtype='stepfilled')

    # solid outlines: random vector with the lowest test accuracy
    axrand.hist(rand_distances[min_rand_idx, positive], edgecolor='navy', facecolor='None', bins=bins, histtype='stepfilled')
    axrand.hist(rand_distances[min_rand_idx, negative], edgecolor='darkred', facecolor='None', bins=bins, histtype='stepfilled')

    # dotted outlines: random vector with the highest test accuracy
    axrand.hist(rand_distances[max_rand_idx, positive], edgecolor='navy', facecolor='None', bins=bins, histtype='stepfilled', linestyle='dotted')
    axrand.hist(rand_distances[max_rand_idx, negative], edgecolor='darkred', facecolor='None', bins=bins, histtype='stepfilled', linestyle='dotted')

    #axopt
    axopt.hist(opt_distances[positive], color='deepskyblue', bins=bins, alpha=0.5, histtype='stepfilled')
    axopt.hist(opt_distances[negative], color='crimson', bins=bins, alpha=0.5, histtype='stepfilled')

    # use the same xlim for comparison
    x_limits = [axrand.get_xlim(), axopt.get_xlim()]
    shared_xlim = [np.min(x_limits), np.max(x_limits)]
    axrand.set_xlim(shared_xlim)
    axopt.set_xlim(shared_xlim)

    # ... and the same ylim
    y_limits = [axrand.get_ylim(), axopt.get_ylim()]
    shared_ylim = [np.min(y_limits), np.max(y_limits)]
    axrand.set_ylim(shared_ylim)
    axopt.set_ylim(shared_ylim)

    # fig. 4 c,d
    if sys.argv[2] in ['lincomp', 'noncomp']:

        end_step = 10 # look at the first end_step optimization steps in fig. 4c
        # Slice once and derive the x-range from the slice, so that runs with
        # fewer than end_step stored steps do not break fill_between.
        first_steps = composition_optimization[:end_step]
        step_axis = range(len(first_steps))
        ax1c.plot(first_steps[:, 0], c='deepskyblue', label=r'$v^\mathrm{T}M^{u}$')
        ax1c.plot(first_steps[:, 1], c='crimson', label=r'$\frac{\eta}{2}v^\mathrm{T}\Sigma^{u} v$')
        ax1c.fill_between(step_axis, first_steps[:, 0], first_steps[:, 1], alpha=0.2, facecolor='cornflowerblue')

        ax1c.legend(loc='lower center')
        # NOTE(review): panel (c) takes its y-range from composition_time (the
        # data of panel (d)), presumably so both panels share a scale - confirm.
        ax1c.set_ylim([np.min(composition_time * 1.05), np.max(composition_time)*1.05])

        ax1d.plot(times, composition_time[:, 0], c='deepskyblue', label=r'$v^\mathrm{T}M^{u}$')
        ax1d.plot(times, composition_time[:, 1], c='crimson', label=r'$\frac{\eta}{2}v^\mathrm{T}\Sigma^{u} v$')
        ax1d.fill_between(times, composition_time[:, 0], composition_time[:, 1], alpha=0.2, facecolor='cornflowerblue')

        ax1d.legend(loc='upper left')
        ax1d.set_ylim([np.min(composition_time * 1.05), np.max(composition_time)*1.05])
    
if sys.argv[2] == 'nonlin':
    # All four panels are histograms here, so give them one common x- and
    # y-range. Panels (a) and (c) stand in for their rows, since each row was
    # already brought to common limits inside the plotting loop.
    histogram_axes = (ax1a, ax1b, ax1c, ax1d)

    row_xlims = [ax1a.get_xlim(), ax1c.get_xlim()]
    common_xlim = [np.min(row_xlims), np.max(row_xlims)]
    for panel in histogram_axes:
        panel.set_xlim(common_xlim)

    row_ylims = [ax1a.get_ylim(), ax1c.get_ylim()]
    common_ylim = [np.min(row_ylims), np.max(row_ylims)]
    for panel in histogram_axes:
        panel.set_ylim(common_ylim)

# write the finished figure next to the data files
output_path = 'data/ECG/ECG5000_histogram_' + str(seed) + '_' + sys.argv[2] + '.pdf'
fig1.savefig(output_path)








#%%

# Evaluate performance gains

def _load_table_entries(seed_label, system):
    """Read the table-relevant results of one seed and one system.

    `seed_label` is the seed as given on the command line (a string) and
    `system` is 'lin' or 'non'. Returns the tuple
    (soft_margins_train, accs_test, soft_margins_train_rand, accs_test_rand)
    taken from the 'opt_dictionary' and 'rand_dictionary' entries of the file.
    """
    # NOTE: unpickling can execute arbitrary code; only load result files
    # that were produced by a trusted run.
    with open('data/ECG/ECG5000_'+seed_label+'_'+system+'.txt', 'rb') as handle:
        full_dictionary = pickle.load(handle)

    opt_dictionary = full_dictionary['opt_dictionary']
    rand_dictionary = full_dictionary['rand_dictionary']
    return (opt_dictionary['soft_margins_train'],
            opt_dictionary['accs_test'],
            rand_dictionary['soft_margins_train_rand'],
            rand_dictionary['accs_test_rand'])


def _mean_std(values):
    """Format `values` as the string '<mean>+-<standard deviation>'."""
    return str(np.mean(values))+'+-'+str(np.std(values))


def _percent_negative(values):
    """Return 100 * (number of negative entries) / len(values)."""
    return np.count_nonzero(values < 0)*100./len(values)


if len(sys.argv) > 3:

    # Seeds for the tables; a separate name keeps the integer `seed` of the
    # figure from being overwritten by the loop.
    table_seeds = sys.argv[3:]

    lin_entries = [_load_table_entries(table_seed, 'lin') for table_seed in table_seeds]
    non_entries = [_load_table_entries(table_seed, 'non') for table_seed in table_seeds]

    # one array per quantity, first axis running over the seeds
    (soft_margins_train_lin, accs_test_lin,
     soft_margins_train_rand_lin, accs_test_rand_lin) = (
        np.array(column) for column in zip(*lin_entries))

    (soft_margins_train_non, accs_test_non,
     soft_margins_train_rand_non, accs_test_rand_non) = (
        np.array(column) for column in zip(*non_entries))

    #%%

    # main tables: soft margins and accuracies as in Table 1.
    index = ['random', 'optimized']

    soft_margins_data = {
            'linear': [_mean_std(soft_margins_train_rand_lin),
                       _mean_std(soft_margins_train_lin)],
            'nonlinear': [_mean_std(soft_margins_train_rand_non),
                          _mean_std(soft_margins_train_non)]
            }

    soft_margins_main_table = pd.DataFrame(soft_margins_data, index=index)
    print('Soft margins:')
    print(soft_margins_main_table)

    acc_data = {
            'linear': [_mean_std(accs_test_rand_lin),
                       _mean_std(accs_test_lin)],
            'nonlinear': [_mean_std(accs_test_rand_non),
                          _mean_std(accs_test_non)]
            }

    acc_main_table = pd.DataFrame(acc_data, index=index)
    print('Accuracies:')
    print(acc_main_table)

    #%%

    # nonlinearity gain tables
    # Difference of accuracy/soft margin in the same linear and non-linear system

    nonlinearity_gain_soft_margin_rand = soft_margins_train_rand_non - soft_margins_train_rand_lin
    nonlinearity_gain_soft_margin = soft_margins_train_non - soft_margins_train_lin

    nonlinearity_gain_soft_margin_data = {
            'nonlinearity gain': [_mean_std(nonlinearity_gain_soft_margin_rand),
                                  _mean_std(nonlinearity_gain_soft_margin)]
            }

    nonlinearity_gain_soft_margin_table = pd.DataFrame(nonlinearity_gain_soft_margin_data, index=['random', 'optimal'])
    print('Nonlinearity gain soft margin:')
    print(nonlinearity_gain_soft_margin_table)

    print('Percentage of negative nonlinearity gain in optimized case:', _percent_negative(nonlinearity_gain_soft_margin), '%')

    nonlinearity_gain_acc_rand = accs_test_rand_non - accs_test_rand_lin
    nonlinearity_gain_acc = accs_test_non - accs_test_lin

    nonlinearity_gain_acc_data = {
            'nonlinearity gain': [_mean_std(nonlinearity_gain_acc_rand),
                                  _mean_std(nonlinearity_gain_acc)]
            }

    nonlinearity_gain_acc_table = pd.DataFrame(nonlinearity_gain_acc_data, index=['random', 'optimal'])
    print('Nonlinearity gain accuracy:')
    print(nonlinearity_gain_acc_table)

    print('Percentage of negative nonlinearity gain in optimized case:', _percent_negative(nonlinearity_gain_acc), '%')

    #%%

    # optimization gain tables
    # how much soft margin, accuracy is achieved from optimization

    # optimization gain
    # [:, None] adds a trailing axis so the optimized value of a seed is
    # broadcast against the random-vector values of the same seed.
    optimization_gain_soft_margin_lin = soft_margins_train_lin[:, None] - soft_margins_train_rand_lin
    optimization_gain_soft_margin_non = soft_margins_train_non[:, None] - soft_margins_train_rand_non

    optimization_gain_soft_margin_data = {
            'optimization gain': [_mean_std(optimization_gain_soft_margin_lin),
                                  _mean_std(optimization_gain_soft_margin_non)]
            }

    optimization_gain_soft_margin_table = pd.DataFrame(optimization_gain_soft_margin_data, index=['linear', 'nonlinear'])
    print('optimization gain soft margin:')
    print(optimization_gain_soft_margin_table)

    print('Percentage of negative optimization gain in linear case:', _percent_negative(optimization_gain_soft_margin_lin.flatten()), '%')
    print('Percentage of negative optimization gain in nonlinear case:', _percent_negative(optimization_gain_soft_margin_non.flatten()), '%')

    optimization_gain_acc_lin = accs_test_lin[:, None] - accs_test_rand_lin
    optimization_gain_acc_non = accs_test_non[:, None] - accs_test_rand_non

    optimization_gain_acc_data = {
            'optimization gain': [_mean_std(optimization_gain_acc_lin),
                                  _mean_std(optimization_gain_acc_non)]
            }

    optimization_gain_acc_table = pd.DataFrame(optimization_gain_acc_data, index=['linear', 'nonlinear'])
    print('optimization gain accuracy:')
    print(optimization_gain_acc_table)

    print('Percentage of negative optimization gain in linear case:', _percent_negative(optimization_gain_acc_lin.flatten()), '%')
    print('Percentage of negative optimization gain in nonlinear case:', _percent_negative(optimization_gain_acc_non.flatten()), '%')


