import numpy
import matplotlib.pyplot as plt
import os
import time
import sys
import pickle

def plot_subopt(data, arrival_type):
    """Plot median suboptimality curves for every (dataset, rate) pair.

    Draws a 2x3 grid of panels. The columns are the datasets ``a9a``,
    ``rcv`` and ``movielens1m``. The top row uses rate 1 and the bottom
    row rate 5.

    Each panel shows the ``'subopt'`` series on a log y-axis for:

    - STRSAGA (``'Inc'``)
    - SSVRG (``'S'``, top row only)
    - DYNASAGA (``'B'``)
    - SGD (``'SGD_pass'``)

    The ``'sample'`` series is drawn on a secondary green y-axis fixed to
    [0, 1.1]. Zero values are replaced by NaN before log plotting so they
    are left out of the curves.

    The figure is saved as EPS and as PNG (200 dpi) under
    ``output/<arrival_type>-subopt/<timestamp>/{eps,png}/``.

    Args:
        data: dict keyed by ``(dataset, arrival_type, rate)``. Each value is
            a per-method result dict such as the one built by
            ``median_outputs``. Every plotted series must have ``b`` (= 100)
            entries, because it is drawn against ``range(b)``.
        arrival_type: arrival-type label. It selects entries of ``data``
            and names the output directory and files.
    """
    metric = 'subopt'
    b = 100
    current_time = time.strftime('%Y-%m-%d_%H%M%S')
    name = arrival_type + '-' + metric
    path_png = os.path.join('output', name, current_time, 'png')
    path_eps = os.path.join('output', name, current_time, 'eps')
    for path in (path_png, path_eps):
        # A second call within the same second reuses the directory instead
        # of crashing in os.makedirs (Python 2 has no exist_ok).
        if not os.path.isdir(path):
            os.makedirs(path)

    plt.rcParams.update({'font.size': 10})
    plt.rc('xtick', labelsize=7)
    plt.rc('ytick', labelsize=7)

    xx = range(b)
    msize = 4
    datasets = ['a9a', 'rcv', 'movielens1m']  # one column per dataset
    rates = [1, 5]                            # one row per rate

    # squeeze=False guarantees a 2-D array of axes indexed [row, col].
    fig, axes = plt.subplots(nrows=len(rates), ncols=len(datasets),
                             sharex=True, sharey=True, figsize=(12, 6),
                             squeeze=False)

    for row, rate in enumerate(rates):
        for col, dataset in enumerate(datasets):
            # Draw on the axes created above. Re-creating them with
            # plt.subplot() would, in older matplotlib, replace them and
            # lose the sharex/sharey links.
            ax1 = axes[row, col]
            d = data[(dataset, arrival_type, rate)]

            l1, = ax1.semilogy(xx, zero_to_nan(d['Inc'][metric]), 'mo-', markersize=msize, markevery=10)
            if rate == 1:
                # SSVRG is only drawn in the rate-1 row; its legend handle
                # comes from that row.
                l2, = ax1.semilogy(xx, zero_to_nan(d['S'][metric]), 'yD-', markersize=msize, markevery=10)
            l3, = ax1.semilogy(xx, zero_to_nan(d['B'][metric]), 'rs-', markersize=msize, markevery=10)
            l4, = ax1.semilogy(xx, zero_to_nan(d['SGD_pass'][metric]), 'b^-', markersize=msize, markevery=10)

            # Secondary y-axis (shares x with ax1) for the 'sample' ratio.
            ax2 = ax1.twinx()
            l0, = ax2.plot(xx, d['sample'], 'g--')
            ax2.set_ylim(0, 1.1)
            ax2.tick_params(axis='y', labelcolor='green')

            ax1.set_title(r'{0}, $\rho = {1}\lambda$'.format(dataset, rate))
            ax1.set_xlim(0, b)

    fig.text(0.5, 0.08, 'Time', ha='center')
    fig.text(0.04, 0.5, r'Suboptimality on $S_i$', va='center', rotation='vertical')
    fig.text(0.96, 0.5, r'$t^{STR}/t^D$', va='center', rotation='vertical', color='green')

    handles = [l1, l2, l3, l4, l0]
    labels = ['STRSAGA', 'SSVRG', r'DYNASAGA($\rho$)', 'SGD', r'$t^{STR}/t^D$']
    fig.legend(handles, labels, loc='lower center', ncol=len(handles), labelspacing=0.)

    fig.subplots_adjust(left=.1, bottom=0.15, right=0.9, top=0.9, wspace=0.3, hspace=0.3)

    fig.savefig(os.path.join(path_eps, '{0}.eps'.format(name)), format='eps')
    fig.savefig(os.path.join(path_png, '{0}.png'.format(name)), format='png', dpi=200)
    # Release the figure: pyplot otherwise keeps every figure alive.
    plt.close(fig)
    
def plot_test(data, arrival_type):
    """Plot median average-test-loss curves for every (dataset, rate) pair.

    Draws a 2x3 grid of panels. The columns are the datasets ``a9a``,
    ``rcv`` and ``movielens1m``. The top row uses rate 1 and the bottom
    row rate 5.

    Each panel shows the ``'test'`` series on a linear y-axis for:

    - STRSAGA (``'Inc'``)
    - SSVRG (``'S'``, top row only)
    - DYNASAGA (``'B'``)
    - SGD (``'SGD_pass'``)

    Only time steps 10..b-1 are drawn, while the x-axis still spans [0, b].

    The figure is saved as EPS and as PNG (200 dpi) under
    ``output/<arrival_type>-test/<timestamp>/{eps,png}/``.

    Args:
        data: dict keyed by ``(dataset, arrival_type, rate)``. Each value is
            a per-method result dict such as the one built by
            ``median_outputs``. Every plotted series must have ``b`` (= 100)
            entries, because it is drawn against ``range(b)``.
        arrival_type: arrival-type label. It selects entries of ``data``
            and names the output directory and files.
    """
    metric = 'test'
    b = 100
    first_step = 10  # steps before this index are not drawn
    current_time = time.strftime('%Y-%m-%d_%H%M%S')
    name = arrival_type + '-' + metric
    path_png = os.path.join('output', name, current_time, 'png')
    path_eps = os.path.join('output', name, current_time, 'eps')
    for path in (path_png, path_eps):
        # A second call within the same second reuses the directory instead
        # of crashing in os.makedirs (Python 2 has no exist_ok).
        if not os.path.isdir(path):
            os.makedirs(path)

    plt.rcParams.update({'font.size': 10})
    plt.rc('xtick', labelsize=7)
    plt.rc('ytick', labelsize=7)

    xx = range(b)
    msize = 4
    datasets = ['a9a', 'rcv', 'movielens1m']  # one column per dataset
    rates = [1, 5]                            # one row per rate

    # squeeze=False guarantees a 2-D array of axes indexed [row, col].
    fig, axes = plt.subplots(nrows=len(rates), ncols=len(datasets),
                             sharex=True, sharey=True, figsize=(12, 6),
                             squeeze=False)

    for row, rate in enumerate(rates):
        for col, dataset in enumerate(datasets):
            # Draw on the axes created above. Re-creating them with
            # plt.subplot() would, in older matplotlib, replace them and
            # lose the sharex/sharey links.
            ax1 = axes[row, col]
            d = data[(dataset, arrival_type, rate)]

            l1, = ax1.plot(xx[first_step:], d['Inc'][metric][first_step:], 'mo-', markersize=msize, markevery=10)
            if rate == 1:
                # SSVRG is only drawn in the rate-1 row; its legend handle
                # comes from that row.
                l2, = ax1.plot(xx[first_step:], d['S'][metric][first_step:], 'yD-', markersize=msize, markevery=10)
            l3, = ax1.plot(xx[first_step:], d['B'][metric][first_step:], 'rs-', markersize=msize, markevery=10)
            l4, = ax1.plot(xx[first_step:], d['SGD_pass'][metric][first_step:], 'b^-', markersize=msize, markevery=10)

            ax1.set_title(r'{0}, $\rho = {1}\lambda$'.format(dataset, rate))
            ax1.set_xlim(0, b)

    fig.text(0.5, 0.08, 'Time', ha='center')
    fig.text(0.04, 0.5, 'Average test loss', va='center', rotation='vertical')

    handles = [l1, l2, l3, l4]
    labels = ['STRSAGA', 'SSVRG', r'DYNASAGA($\rho$)', 'SGD']
    # One legend column per entry (previously ncol=5 with only 4 entries).
    fig.legend(handles, labels, loc='lower center', ncol=len(handles), labelspacing=0.)

    fig.subplots_adjust(left=.1, bottom=0.15, right=0.9, top=0.9, wspace=0.3, hspace=0.3)

    fig.savefig(os.path.join(path_eps, '{0}.eps'.format(name)), format='eps')
    fig.savefig(os.path.join(path_png, '{0}.png'.format(name)), format='png', dpi=200)
    # Release the figure: pyplot otherwise keeps every figure alive.
    plt.close(fig)
    

def median_outputs(output_list, b):
    """Collapse several runs into one by taking the per-time-step median.

    Args:
        output_list: non-empty list of per-run output dicts. Each run must
            provide every series listed below, and each series must hold at
            least ``b`` entries.
        b: number of time steps to keep. Only indices ``0 .. b-1`` are read.

    Returns:
        A new dict with this layout, where every series is a list of ``b``
        medians (index ``t`` is the median of the runs' values at ``t``)::

            {'Inc':      {'test': [...], 'subopt': [...]},
             'S':        {'test': [...], 'subopt': [...]},
             'B':        {'test': [...], 'subopt': [...]},
             'A':        {'test': [...]},
             'sample':   [...],
             'SGD_pass': {'test': [...], 'subopt': [...]},
             'SGD_unif': {'test': [...], 'subopt': [...]}}

    Raises:
        ValueError: if ``output_list`` is empty. The median is undefined
            then; previously this produced silent NaNs.
    """
    if not output_list:
        raise ValueError('median_outputs needs at least one run')

    # (method key, metric keys) for every nested series to aggregate.
    nested_series = (
        ('Inc', ('test', 'subopt')),
        ('S', ('test', 'subopt')),
        ('B', ('test', 'subopt')),
        ('A', ('test',)),
        ('SGD_pass', ('test', 'subopt')),
        ('SGD_unif', ('test', 'subopt')),
    )

    def median_series(series_list):
        # Median across runs, computed independently at each time step.
        return [numpy.median([s[t] for s in series_list]) for t in range(b)]

    output = {}
    for method, metrics in nested_series:
        output[method] = {
            metric: median_series([o[method][metric] for o in output_list])
            for metric in metrics
        }
    output['sample'] = median_series([o['sample'] for o in output_list])
    return output
        
def zero_to_nan(values):
    """Return a new list where every entry equal to 0 becomes NaN.

    Non-zero entries are copied through unchanged, and ``values`` itself is
    not modified. ``plot_subopt`` applies this to series drawn with
    ``semilogy``, where zero values cannot be shown.
    """
    converted = []
    for value in values:
        if value == 0:
            converted.append(float('nan'))
        else:
            converted.append(value)
    return converted
        
if __name__ == "__main__":
    # Usage: python <this script> ARRIVAL_TYPE   (e.g. skewed or poisson)
    if len(sys.argv) < 2:
        # sys.exit(msg) writes msg to stderr and exits with status 1, so a
        # calling shell or script sees the usage error. The old bare exit()
        # returned 0, and the print statement was Python 2 only.
        sys.exit("needs argument arrival_type (skewed, poisson)")
    arrival_type = sys.argv[1]

    datasets = ['a9a', 'rcv', 'movielens1m']
    arrival_types = [arrival_type]
    rates = [1, 5]

    # data[(dataset, arrival_type, rate)] -> per-time-step medians over the
    # runs stored in the matching pickle file.
    data = {}
    for d in datasets:
        for a in arrival_types:
            for r in rates:
                name = d + '-' + a
                pkl_path = os.path.join('output', 'output_data',
                                        '{0}_r{1}.pkl'.format(name, r))
                # Read pickles in binary mode: text mode 'r' fails on
                # Python 3 and can corrupt data on Windows. Only load
                # trusted files, because unpickling can execute code.
                with open(pkl_path, 'rb') as f:
                    outputs = pickle.load(f)
                data[(d, a, r)] = median_outputs(outputs, 100)

    plot_subopt(data, arrival_type=arrival_type)
    plot_test(data, arrival_type=arrival_type)
