import numpy
import random
import time
import os
import matplotlib.pyplot as plt
import datasets
import models
import sys
import pickle

# Figure 2b and 2c
def plot_rate(output, rates, k=3):
    """Plot sensitivity to rho/lambda at a fixed M = 2**k * lambda (Figures 2b, 2c).

    Produces two figures and saves each as EPS and PNG under
    ``output/sensitivity-rho/<timestamp>/{eps,png}/``:

    * figure 1: ``output[(rate, k)]['sample']``      -> ``sensitivity-rho_sample.*``
    * figure 2: ``output[(rate, k)]['Inc_subopt']``  -> ``sensitivity-rho_subopt.*``

    Each figure has one curve per checkpoint i = 25, 50, 75, 100, i.e. the
    0-based indices 24, 49, 74 and 99 of each series, against ``rates`` on a
    base-2 log x-axis.

    Args:
        output: dict keyed by ``(rate, k)`` tuples. Each value must provide
            ``'sample'`` and ``'Inc_subopt'`` sequences with at least 100
            entries, such as the dicts built by ``median_outputs``.
        rates: x-axis values. Every ``(rate, k)`` pair must be a key of ``output``.
        k: exponent selecting M = 2**k * lambda. Defaults to 3, the value that
            was previously hard-coded.

    Raises:
        OSError: from ``os.makedirs`` if the timestamped directories already exist.
    """
    name = 'sensitivity-rho'
    current_time = time.strftime('%Y-%m-%d_%H%M%S')
    path_png = os.path.join('output', name, current_time, 'png')
    path_eps = os.path.join('output', name, current_time, 'eps')
    os.makedirs(path_png)
    os.makedirs(path_eps)

    msize = 7
    # (0-based index into each series, matplotlib format string) for each curve;
    # the legend label is the 1-based checkpoint number.
    checkpoints = [(24, 'mo-'), (49, 'gD-'), (74, 'rs-'), (99, 'bD-')]
    # Doubled braces survive str.format, giving 2^{k} so multi-digit k renders
    # as a single exponent in mathtext.
    m_label = r'$M = 2^{{{0}}}\lambda$'.format(k)
    # (figure number, output key, y label, title prefix, file-name suffix)
    panels = [
        (1, 'sample', r'$t^{STR}/t^D$', 'Sample-competitive ratio', 'sample'),
        (2, 'Inc_subopt', r'Suboptimality on $S_i$', 'Suboptimality', 'subopt'),
    ]

    plt.rcParams.update({'font.size': 18})

    for fig_num, key, ylabel, title, suffix in panels:
        plt.figure(fig_num)
        plt.clf()
        for idx, fmt in checkpoints:
            plt.semilogx(
                rates,
                [output[(rate, k)][key][idx] for rate in rates],
                fmt, markersize=msize, label='i={0}'.format(idx + 1),
                # NOTE: matplotlib >= 3.3 renamed this keyword to `base`
                # (and 3.5 removed `basex`); kept for the existing environment.
                basex=2,
            )
        plt.xlabel(r'$\rho/\lambda$')
        plt.ylabel(ylabel)
        plt.title('{0} ({1})'.format(title, m_label))
        plt.gcf().subplots_adjust(left=0.17, bottom=0.15)
        plt.legend()
        plt.savefig(os.path.join(path_eps, '{0}_{1}.eps'.format(name, suffix)), format='eps')
        plt.savefig(os.path.join(path_png, '{0}_{1}.png'.format(name, suffix)), format='png', dpi=200)


# Figure 2a
def plot_m(output, ks, rate=1.0):
    """Plot sensitivity to M/lambda at a fixed rho = rate * lambda (Figure 2a).

    Plots ``output[(rate, k)]['sample']`` at the checkpoints i = 25, 50, 75, 100
    (0-based indices 24, 49, 74, 99) against M = 2**k on a base-2 log x-axis.
    The figure is saved as ``sensitivity-m_sample.{eps,png}`` under
    ``output/sensitivity-m/<timestamp>/{eps,png}/``.

    Args:
        output: dict keyed by ``(rate, k)`` tuples. Each value must provide a
            ``'sample'`` sequence with at least 100 entries, such as the dicts
            built by ``median_outputs``.
        ks: exponents k; M = 2**k (in units of lambda) is the x value for each.
        rate: the fixed rho/lambda used for both the data lookup and the title.
            Defaults to 1.0, the value that was previously hard-coded.

    Raises:
        OSError: from ``os.makedirs`` if the timestamped directories already exist.
    """
    name = 'sensitivity-m'
    current_time = time.strftime('%Y-%m-%d_%H%M%S')
    path_png = os.path.join('output', name, current_time, 'png')
    path_eps = os.path.join('output', name, current_time, 'eps')
    os.makedirs(path_png)
    os.makedirs(path_eps)

    plt.rcParams.update({'font.size': 18})

    # Same as 1 << k for non-negative int k, but also accepts negative k.
    Ms = [2 ** k for k in ks]
    msize = 7
    # (0-based index into the 'sample' series, matplotlib format string).
    checkpoints = [(24, 'mo-'), (49, 'gD-'), (74, 'rs-'), (99, 'bD-')]

    plt.figure(1)
    plt.clf()
    for idx, fmt in checkpoints:
        plt.semilogx(
            Ms,
            [output[(rate, k)]['sample'][idx] for k in ks],
            fmt, markersize=msize, label='i={0}'.format(idx + 1),
            # NOTE: matplotlib >= 3.3 renamed this keyword to `base`.
            basex=2,
        )
    plt.xlabel(r'$M/\lambda$')
    plt.ylabel(r'$t^{STR}/t^D$')
    # Use the rate actually plotted; '{0:g}' renders 1.0 as '1', matching the
    # previously hard-coded title.
    plt.title(r'Sample-competitive ratio ($\rho = {0:g}\lambda$)'.format(rate))
    plt.gcf().subplots_adjust(left=0.17, bottom=0.15)
    plt.legend()
    plt.savefig(os.path.join(path_eps, '{0}_sample.eps'.format(name)), format='eps')
    plt.savefig(os.path.join(path_png, '{0}_sample.png'.format(name)), format='png', dpi=200)


def median_outputs(output_list, b):
    """Collapse several runs into per-step medians of their series.

    Args:
        output_list: sequence of per-run result dicts. Each must provide
            ``o['Inc']['test']``, ``o['Inc']['subopt']`` and ``o['sample']``,
            each indexable at positions ``0 .. b-1``.
        b: number of leading steps to aggregate.

    Returns:
        dict with keys ``'Inc_test'``, ``'Inc_subopt'`` and ``'sample'``. Each
        maps to a list of length ``b`` whose entry ``t`` is the ``numpy.median``
        across runs of that series at step ``t``. An empty ``output_list``
        yields NaN entries, following ``numpy.median``'s behavior on empty input.
    """
    def step_medians(series_of):
        # `range` (not Python-2-only `xrange`) keeps this runnable on 2 and 3.
        return [numpy.median([series_of(o)[t] for o in output_list])
                for t in range(b)]

    return {
        'Inc_test': step_medians(lambda o: o['Inc']['test']),
        'Inc_subopt': step_medians(lambda o: o['Inc']['subopt']),
        'sample': step_medians(lambda o: o['sample']),
    }
    

if __name__ == "__main__":
    # Number of leading steps to median over; the plot functions read
    # indices up to 99, so this must be at least 100.
    b = 100
    rates = [0.25, 0.5, 1.0, 2.0, 4.0, 8.0]
    ks = [0, 1, 2, 3, 4]

    # (rate, k) -> per-step medians across the runs stored for that setting.
    data = {}

    for rr in rates:
        for kk in ks:
            # Pickle streams are bytes: binary mode is required on Python 3
            # and for binary pickle protocols on Windows under Python 2.
            # NOTE: pickle.load can execute arbitrary code; only load result
            # files produced by this project.
            with open('output/output_data/rcv-r{0}_m{1}.pkl'.format(rr, kk), 'rb') as f:
                outputs = pickle.load(f)
            data[(rr, kk)] = median_outputs(outputs, b)

    plot_rate(data, rates)
    plot_m(data, ks)
