import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.lines import Line2D
import numpy as np
import argparse

# Second version of the plotting code.

# Default matplotlib colour cycle, used to colour the plotted series.
colors = rcParams['axes.prop_cycle'].by_key()["color"]

# Command-line options.  nx and ny are embedded in the input/output file
# names; any nonzero --convex selects the "_cv" variant of those files.
parser = argparse.ArgumentParser(description='ReLU Experiment Plotter')
parser.add_argument('--nx', metavar='N', default=20, type=int)
parser.add_argument('--ny', metavar='N', default=20, type=int)
parser.add_argument('--convex', metavar='N', default=0, type=int)

args = parser.parse_args()

nx, ny, convex = args.nx, args.ny, args.convex
# Shared stem of the log file that is read and the PDF that is written.
filestring = "nx={:04d}_ny={:04d}".format(nx, ny) + ("_cv" if convex else "")

# One panel in total; the figure is sized at 5 x 4 inches per panel.  With a
# 1x1 grid, plt.subplots returns a single Axes object rather than an array.
n_col = n_row = 1
fig, axs = plt.subplots(n_row, n_col, figsize=(5 * n_col, 4 * n_row))

# Dimensions d; their log2 values (0..5) presumably correspond to the x-axis
# positions of the plot.  NOTE(review): not otherwise referenced in this script.
dim_list = [1, 2, 4, 8, 16, 32]

log_data = np.loadtxt("log_" + filestring + ".dat")
time_steps = 6
nruns = 20

# Row layout of the log file as consumed here: `time_steps` consecutive
# chunks; each chunk holds `nruns` runs; each run holds `n_methods` groups of
# `n_samples` rows, in the order LGV, MD, WFR.
n_methods = 3
n_samples = 10
rows_per_step = nruns * n_methods * n_samples
# Column that is aggregated and plotted (presumably NI, per the y label) -- TODO confirm.
metric_col = 2

# Fail loudly instead of silently aggregating truncated/empty slices.
if (log_data.ndim != 2
        or log_data.shape[0] < time_steps * rows_per_step
        or log_data.shape[1] <= metric_col):
    raise ValueError(
        "log_{}.dat: expected at least {} rows and {} columns, got shape {}".format(
            filestring, time_steps * rows_per_step, metric_col + 1, log_data.shape))

# One (rows_per_step, ncols) slice per time step; extra trailing rows are ignored.
log_data_list = [log_data[rows_per_step * i:rows_per_step * (i + 1), :]
                 for i in range(time_steps)]

# samples[m][t]: metric values of method m at time step t, pooled over all runs
# in run order (run 0's rows first).  The reshape indexes rows as
# [run, method, sample] because row = run*30 + method*10 + sample.
samples = [
    [step[:, metric_col].reshape(nruns, n_methods, n_samples)[:, m, :].ravel()
     for step in log_data_list]
    for m in range(n_methods)
]
lgv_samples, md_samples, wfr_samples = samples

# Per-time-step summary statistics (np.std is the population std, ddof=0).
log_data_LGV = [np.mean(s) for s in lgv_samples]
log_data_LGV_min = [np.min(s) for s in lgv_samples]
log_data_LGV_max = [np.max(s) for s in lgv_samples]
log_data_LGV_dev = [np.std(s) for s in lgv_samples]

log_data_MD = [np.mean(s) for s in md_samples]
log_data_MD_min = [np.min(s) for s in md_samples]
log_data_MD_max = [np.max(s) for s in md_samples]
log_data_MD_dev = [np.std(s) for s in md_samples]

log_data_WFR = [np.mean(s) for s in wfr_samples]
log_data_WFR_min = [np.min(s) for s in wfr_samples]
log_data_WFR_max = [np.max(s) for s in wfr_samples]
log_data_WFR_dev = [np.std(s) for s in wfr_samples]

def _log_errorbar_data(mean, dev, floor=np.exp(-10)):
    """Return log-scale centres and asymmetric error bars for mean +/- dev.

    Parameters
    ----------
    mean, dev : sequence of float
        Per-point mean and standard deviation.
    floor : float
        Lower clamp applied before every log, so non-positive values
        (e.g. mean - dev < 0) map to log(floor) instead of -inf/nan.

    Returns
    -------
    center : ndarray
        log(max(mean, floor)).
    yerr : list of two ndarrays
        [center - log(max(mean - dev, floor)),
         log(max(mean + dev, floor)) - center],
        i.e. the (lower, upper) distances expected by ``Axes.errorbar``.
    """
    mean = np.asarray(mean, dtype=float)
    dev = np.asarray(dev, dtype=float)
    center = np.log(np.maximum(mean, floor))
    lower = np.log(np.maximum(mean - dev, floor))
    upper = np.log(np.maximum(mean + dev, floor))
    return center, [center - lower, upper - center]


# (legend label, per-time-step means, per-time-step std devs); list order
# fixes the colour index of each series.
methods = [
    ('LGV', log_data_LGV, log_data_LGV_dev),
    ('MD', log_data_MD, log_data_MD_dev),
    ('WFR', log_data_WFR, log_data_WFR_dev),
]

# One x position per aggregated time step (derived from the data, not hard-coded).
row_vec = range(len(log_data_LGV))
for k, (_, mean, dev) in enumerate(methods):
    center, yerr = _log_errorbar_data(mean, dev)
    axs.errorbar(row_vec, center, yerr=yerr, marker='o', color=colors[k])

axs.set_ylabel(r"log(NI)")
axs.set_xlabel(r"$\log_2(d)$")

# Proxy legend handles: a coloured marker on a white line, so the legend
# shows markers only rather than the error-bar artists.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=colors[k],
           label=name, markersize=8)
    for k, (name, _, _) in enumerate(methods)
]
axs.legend(handles=legend_elements)

fig.tight_layout()
fig.savefig("plots_" + filestring + ".pdf")
