import numpy as np
# import matplotlib
from matplotlib import pyplot as plt
from numpy.core.fromnumeric import size
from numpy.lib.function_base import average
import pandas as pd
import seaborn as sns
import os, argparse

# Global matplotlib font size.
# NOTE(review): sns.set_theme() below applies seaborn's default plotting
# context, which likely resets font.size -- confirm whether this has any effect.
plt.rcParams['font.size'] = 100

print('plot inf-dim instance 10 seeds...')

# Problem size; it also fixes the plotted time window and appears in the
# output PDF file names.
n = 100

# Offline equilibrium vector `u` read from disk (not referenced by the
# plotting code in this section).
u_opt = np.loadtxt(os.path.join('results', 'inf-dim', 'offline-eq', 'u'))

sns.set_theme()

# Destination directory for the PDFs saved below.
os.makedirs('plots', exist_ok=True)

# average across seeds
from matplotlib import pyplot as plt
import seaborn as sns
# sns.set_theme()
import os, json

# Per-seed trajectories. Only the *_to_B lists are filled below; the
# beta/u equilibrium lists are declared but not populated here.
inf_norm_to_beta_eq_all_seeds = []
inf_norm_to_u_eq_all_seeds = []
inf_norm_to_B_all_seeds = []
ave_one_norm_to_beta_eq_all_seeds = []
ave_one_norm_to_u_eq_all_seeds = []
ave_one_norm_to_B_all_seeds = []

for seed in range(1, 11):
    seed_dir = os.path.join('results', 'inf-dim', 'sd-{}'.format(seed))
    # meta_data is JSON; only its 'T' entry is read (T is reassigned later,
    # before plotting).
    with open(os.path.join(seed_dir, 'meta_data'), 'r') as meta_file:
        meta_data = json.load(meta_file)
    T = meta_data['T']
    inf_norm_to_B_all_seeds.append(np.loadtxt(os.path.join(seed_dir, 'inf_norm_to_B.gz')))
    ave_one_norm_to_B_all_seeds.append(np.loadtxt(os.path.join(seed_dir, 'ave_one_norm_to_B.gz')))

# Stack seeds along axis 0 (np.array needs equal-length per-seed arrays to
# form a 2-D array) and take min / Q1 / median / Q3 / max across seeds at
# each remaining index (presumably one entry per time step).
inf_norm_to_B_all_seeds = np.array(inf_norm_to_B_all_seeds)
ave_one_norm_to_B_all_seeds = np.array(ave_one_norm_to_B_all_seeds)
quartile_levels = (0, 25, 50, 75, 100)
inf_norm_to_B_quartiles = np.percentile(inf_norm_to_B_all_seeds, q=quartile_levels, axis=0)
ave_one_norm_to_B_quartiles = np.percentile(ave_one_norm_to_B_all_seeds, q=quartile_levels, axis=0)

# Plotted time window: steps t0+1 .. T (1-based), every skip_size-th step.
# T is reassigned here, replacing the per-seed meta_data['T'] read above.
t0 = 5*n
T = n * 100
skip_size = 1
num_dp = (T - t0) // skip_size  # kept for reference; not used by the plots below

###### plot inf_norm_to_B & ave_norm_to_B stuff (same as in plot_budget_diff.py) ######
linestyles = ((0, (3, 1, 1, 1, 1, 1)), 'dotted', 'solid', 'dashdot', 'dashed')
labels = (r'min', r'$Q_1$', r'$Q_2$ (median)', r'$Q_3$', r'max')


def _plot_quartiles(quartiles, title, filename):
    """Plot five quartile curves of one error metric and save them as a PDF.

    Args:
        quartiles: array indexed ``quartiles[k, t]`` with k = 0..4 for
            min / Q1 / median / Q3 / max, as returned by
            ``np.percentile(..., axis=0)``.
        title: mathtext title string.
        filename: file name written inside the ``plots`` directory.

    The figure is closed after saving so repeated calls do not accumulate
    open figures.
    """
    fig = plt.figure(figsize=(6, 4))
    x = range(t0 + 1, T + 1, skip_size)  # 1-based time steps
    # Only the first five rows are drawn (zip stops at len(labels)).
    for curve, label, style in zip(quartiles[:, t0:T:skip_size], labels, linestyles):
        plt.plot(x, curve, label=label, linestyle=style)
    # Epoch lines: every multiple of 10*n within [t0, T].
    epoch_len = n * 10
    first_epoch = -(-t0 // epoch_len) * epoch_len  # smallest multiple >= t0
    for pt in range(first_epoch, T + 1, epoch_len):
        plt.axvline(pt, linewidth=1.0, linestyle='dotted')
    plt.legend(prop={'size': 15})
    plt.title(title, fontsize=15)
    plt.savefig(os.path.join('plots', filename))
    plt.close(fig)


_plot_quartiles(
    inf_norm_to_B_quartiles,
    r'Inf-Dim, Quartiles of $||(\bar{{b}}^t - B)/B||_\infty$',
    'quartiles-max-relative-error-spending-inf-dim-n-{}.pdf'.format(n),
)
_plot_quartiles(
    ave_one_norm_to_B_quartiles,
    r'Inf-Dim, Quartiles of $||(\bar{{b}}^t - B)/B||_1/n$',
    'quartiles-ave-relative-error-spending-inf-dim-n-{}.pdf'.format(n),
)