from collections import defaultdict
from matplotlib import pyplot as plt
import numpy as np

# Problem size used by this experiment.
n, m = 100, 300
# Candidate period lengths. NOTE(review): not read in the visible code.
all_plen = (5, 10, 20, 50, 100, 200)

# Per-run accumulators, each a distinct empty list.
# NOTE(review): none of these are read in the visible code; they look like
# leftovers from a related script.
(
    u_hseq_all,
    u_ave_baseline_all,
    ave_one_norm_to_beta_hseq_all,
    inf_norm_to_u_hseq_all,
    ave_one_norm_to_u_hseq_all,
    max_rel_buyer_regret_all,
    ave_rel_buyer_regret_all,
    inf_norm_u_ave_baseline_all,
) = ([] for _ in range(8))

# Default period length; the loading loop rebinds it.
plen = 10

# Only plen = 5, 50 and 200 are compared.
# Maps period length -> array of per-seed beta convergence traces.
all_data = {}

# Load the beta convergence data for every seed and every compared period
# length, stacking the per-seed traces (one row per seed) into all_data[plen].
for plen in (5, 50, 200):
    inf_norm_to_beta_hseq_all = []
    for sd in range(10):
        path = 'results/iid_vs_periodic_100_300_{}_{}.npz'.format(plen, sd)
        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # the context manager closes it once the arrays have been read
        # (unpacking reads every member into memory), instead of leaking one
        # open file per (plen, seed).
        # NOTE(review): positional unpacking relies on the order in which the
        # arrays were saved; exactly six arrays are expected per file.
        with np.load(path) as npz:
            (
                u_hseq, B,
                inf_norm_u_ave_baseline,
                inf_norm_to_beta_hseq,  # ave_one_norm_to_beta_hseq,
                inf_norm_to_u_hseq,  # ave_one_norm_to_u_hseq,
                inf_norm_to_B,  # ave_one_norm_to_B,
            ) = npz.values()
        inf_norm_to_beta_hseq_all.append(inf_norm_to_beta_hseq)

    inf_norm_to_beta_hseq_all = np.array(inf_norm_to_beta_hseq_all)
    all_data[plen] = inf_norm_to_beta_hseq_all

# Plot, for each compared period length, the mean across seeds of the beta
# convergence trace with standard-error bars, then save and show the figure.
T = len(inf_norm_to_beta_hseq)  # trace length, taken from the last file loaded
fig = plt.figure(figsize=(6, 4))
t0 = int(T//50)  # drop the first 2% of iterations from the plot
skip_size = max(int(T//2000), 5)  # stride ~T/2000, but never below 5
num_dp = (T - t0) // skip_size

import seaborn as sns
plt.clf()
sns.set_theme()

all_data_arrays = (
    all_data[5], all_data[50], all_data[200]
)

all_labels = tuple(
    'period length = {}'.format(plen) for plen in (5, 50, 200)
)

# The subsampled (0-based) trace indices are identical for every curve, so
# compute them once; the x axis shows the matching 1-based time steps.
sample_idx = np.arange(t0, T, skip_size)
t_axis = sample_idx + 1
# matplotlib requires a positive errorevery step; with fewer than 10 data
# points num_dp // 10 would be 0 and errorbar would raise.
error_every = max(num_dp // 10, 1)

for data_array, label in zip(all_data_arrays, all_labels):
    sampled = data_array[:, sample_idx]  # one row per seed
    num_seeds = sampled.shape[0]
    plt.errorbar(
        t_axis,
        np.mean(sampled, axis=0),
        # standard error of the mean across seeds (uses the actual seed
        # count rather than a hard-coded 10)
        np.std(sampled, axis=0) / np.sqrt(num_seeds),
        errorevery=error_every,
        linestyle='solid',
        label=label,
    )

# Vertical dotted lines at every multiple of n*20 within [t0, T].
vline_spacing = n * 20
first_vline = -(-t0 // vline_spacing) * vline_spacing  # ceil(t0/spacing)*spacing
for pt in range(first_vline, T + 1, vline_spacing):
    plt.axvline(pt, linewidth=1.0, linestyle='dotted')
plt.xticks(range(0, T + 1, max(T // 5, 1)))  # a step of 0 would raise for T < 5
plt.xlabel('t')
plt.legend(prop={'size': 12}, loc='center right')
plt.title(r'max$_i$ $|\beta_i^t - \beta^{\rm HS}_i|/\beta^{\rm HS}_i$, Periodic Input')
plt.savefig('../plots/iid_vs_periodic.pdf')
plt.show()