import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from numpy.lib.function_base import average
import pandas as pd
from pandas.core.indexing import is_label_like
import seaborn as sns
import os, argparse

# Global base font size for every figure this script draws (titles and legends
# below override it with explicit, smaller sizes).
font = dict(size=24)
matplotlib.rcParams.update({'font.' + key: value for key, value in font.items()})

# Set budget_cap to True to read the 'budget-cap-' result directories and to
# prefix the generated plot filenames accordingly.
budget_cap = False
prefix = 'budget-cap-' if budget_cap else ''

# Parse command-line options. argparse signals bad/unknown arguments (and --help)
# by raising SystemExit; presumably this fallback exists for environments that
# inject their own argv (e.g. notebooks), so fall back to the default dataset there.
# Only SystemExit is caught so that real errors and Ctrl-C are not swallowed.
try:
    parser = argparse.ArgumentParser(description='Experiment arguments')
    parser.add_argument('--dataset', '-ds', default='MovieLens')
    args = parser.parse_args()
    dataset = args.dataset
except SystemExit:
    print('not parsing command line inputs. use given parameters.')
    dataset = 'MovieLens'

# Load the valuation matrix v for the selected dataset and normalize it.
# Rows correspond to the n agents that receive budgets B (and utilities below);
# columns correspond to the m items.

# CSV location and pandas `header` setting for each supported dataset: the
# MovieLens file has no header row, the others use pandas' default inference.
dataset_csv = {
    'MovieLens': ("../../data/movielens_1500x1500_lr0.1_wd1e-05_dim20_rmse0.88640326.csv", None),
    'Jokes': ("../../data/jokes_7200.csv", 'infer'),
    'Household': ("../../data/household_items_understood.csv", 'infer'),
}
if dataset not in dataset_csv:
    # Fail fast with a clear message instead of leaving `v` undefined below.
    raise ValueError(
        'unknown dataset {!r}; expected one of {}'.format(dataset, sorted(dataset_csv)))
csv_path, csv_header = dataset_csv[dataset]
df = pd.read_csv(csv_path, header=csv_header)
v = df.to_numpy()

n, m = v.shape
# Rescale every row so that it sums to m.
v = m * (v / v.sum(axis=1, keepdims=True))
# Equal prescribed budgets B_i = 1/n.
B = np.ones(n) / n

# Offline equilibrium allocation (loaded from disk) and the per-row utilities it induces.
x_opt = np.loadtxt(os.path.join('results', dataset, 'offline-eq', 'x'))
u_opt = (v * x_opt).sum(axis=1)

# Proportional-share baseline: giving row i the share B_i/m of every item yields
# utility sum_j v_ij * B_i / m = B_i = 1/n, since each row of v sums to m.
u_proportional = np.full(n, 1.0 / n)

# Relative utility error of the baseline w.r.t. the equilibrium (max and mean over rows).
rel_err_to_u_eq_baseline = np.abs(u_proportional - u_opt) / u_opt
inf_norm_to_u_eq_baseline = rel_err_to_u_eq_baseline.max()
ave_norm_to_u_eq_baseline = rel_err_to_u_eq_baseline.mean()

sns.set_theme()
os.makedirs('plots', exist_ok=True)

# average across seeds
from matplotlib import pyplot as plt
import seaborn as sns
# sns.set_theme()
import os, json

# Collect the per-seed spending-error traces (seeds 1..10) and summarize them by
# their quartiles across seeds.
print('plotting "budget spending deviation from prescribed budget" for [dataset = {}]'.format(dataset))

# Only the *_to_B lists are filled below; the beta/u lists are placeholders for the
# (currently commented-out) mean-and-SE plots.
inf_norm_to_beta_eq_all_seeds, inf_norm_to_u_eq_all_seeds = [], []
ave_one_norm_to_beta_eq_all_seeds, ave_one_norm_to_u_eq_all_seeds = [], []
inf_norm_to_B_all_seeds, ave_one_norm_to_B_all_seeds = [], []

for seed in range(1, 11):
    fpath = os.path.join('results', dataset, prefix + 'sd-{}'.format(seed))
    with open(os.path.join(fpath, 'meta_data'), 'r') as ff:
        meta_data = json.load(ff)
    T = meta_data['T']
    for trace_file, collected in (('inf_norm_to_B.gz', inf_norm_to_B_all_seeds),
                                  ('ave_one_norm_to_B.gz', ave_one_norm_to_B_all_seeds)):
        collected.append(np.loadtxt(os.path.join(fpath, trace_file)))

# Stack into arrays with one row per seed, then take min/Q1/median/Q3/max over seeds
# (axis 0), giving one row per quartile level.
quartile_levels = (0, 25, 50, 75, 100)
inf_norm_to_B_all_seeds = np.array(inf_norm_to_B_all_seeds)
ave_one_norm_to_B_all_seeds = np.array(ave_one_norm_to_B_all_seeds)
inf_norm_to_B_quartiles = np.percentile(inf_norm_to_B_all_seeds, q=quartile_levels, axis=0)
ave_one_norm_to_B_quartiles = np.percentile(ave_one_norm_to_B_all_seeds, q=quartile_levels, axis=0)

# Plot window: skip the first t0 = 20n rounds and stop at round T = 100n.
# NOTE(review): this overrides the T read from each seed's meta_data above;
# confirm that plotting only the first 100n rounds is intended.
t0 = 20*n
T = n * 100
skip_size = 1
num_dp = (T - t0) // skip_size  # number of plotted points per curve

linestyles = ((0, (3, 1, 1, 1, 1, 1)), 'dotted', 'solid', 'dashdot', 'dashed')
labels = (r'min', r'$Q_1$', r'$Q_2$ (median)', r'$Q_3$', r'max')
steps = range(t0+1, T+1, skip_size)

# (quartile curves, title template, output-file stem) per spending-error metric.
# The infinity-norm plot is the *max* relative error and the 1-norm/n plot is the
# *average* one, so the stems use "max-"/"ave-" accordingly (they were swapped before).
spending_plots = (
    (inf_norm_to_B_quartiles,
     r'{}, Quartiles of $||(\bar{{b}}^t - B)/B||_\infty$',
     'quartiles-max-relative-error-spending'),
    (ave_one_norm_to_B_quartiles,
     r'{}, Quartiles of $||(\bar{{b}}^t - B)/B||_1/n$',
     'quartiles-ave-relative-error-spending'),
)

for quartiles, title, stem in spending_plots:
    fig = plt.figure(figsize=(6, 4))
    # Rows of `quartiles` follow the percentile order min, Q1, median, Q3, max.
    for label, style, curve in zip(labels, linestyles, quartiles):
        plt.plot(steps, curve[t0:T:skip_size], label=label, linestyle=style)
    plt.legend(prop={'size': 15})
    plt.title(title.format(dataset), fontsize=15)
    plt.savefig(os.path.join('plots', prefix + '{}-{}-n-{}-m-{}.pdf'.format(stem, dataset, n, m)))
    # Release the figure instead of only clearing it, so figures do not accumulate.
    plt.close(fig)

# ###### max relative errors ######
# fig = plt.figure(figsize=(6, 4))
# plt.errorbar(range(t0+1, T+1, skip_size), np.mean(inf_norm_to_beta_eq_all_seeds[:, range(t0, T, skip_size)], axis=0), (1/np.sqrt(10)) * np.std(inf_norm_to_beta_eq_all_seeds[:, range(t0, T, skip_size)], axis=0), label = r'$||(\beta^t - \beta^*)/\beta^*||_\infty$', linestyle='solid', errorevery=num_dp//10)
# plt.errorbar(range(t0+1, T+1, skip_size), np.mean(inf_norm_to_u_eq_all_seeds[:, range(t0, T, skip_size)], axis=0), (1/np.sqrt(10)) * np.std(inf_norm_to_u_eq_all_seeds[:, range(t0, T, skip_size)], axis=0), label = r'$||(\bar{u}^t - u^*)/u^*||_\infty$', linestyle='dotted', errorevery=num_dp//8)
# plt.errorbar(range(t0+1, T+1, skip_size), np.mean(inf_norm_to_B_all_seeds[:, range(t0, T, skip_size)], axis=0), (1/np.sqrt(10)) * np.std(inf_norm_to_B_all_seeds[:, range(t0, T, skip_size)], axis=0), label = r'$||(\bar{b}^t - B)/B||_\infty$', linestyle='dashdot', errorevery=num_dp//6)
# plt.plot(range(t0+1, T+1, skip_size), np.ones(num_dp) * inf_norm_to_u_eq_baseline, label = r'$||(u^{\rm PS} - u^*)/u^*||_\infty$', linestyle = (0, (3, 5, 1, 5, 1, 5)))
# # plt.vlines([pt for pt in range(t0, T+1) if pt % (n*10) == 0], ymin=0, ymax=0.5, linestyles='dotted', linewidth=1.0) #, label=r'multiplies of $n$')
# [plt.axvline(pt, linewidth=1.0, linestyle = 'dotted') for pt in range(t0, T+1) if pt % (n*10) == 0]
# # plt.errorbar(range(1, T+1, log_interval), np.mean(duality_gap_all_seeds, axis=0), np.std(duality_gap_all_seeds, axis=0), label = r'${\rm dgap}_t$', linestyle='dashed', errorevery=num_logs//4)
# # plt.yscale('log') #, plt.xscale('log')
# plt.xticks(range(0, T+1, T//5))
# # plt.xlabel('t')
# plt.title('{}, n = {}, m = {} (Max Relative Errors)'.format(dataset, n, m), fontsize=15)
# # if dataset == 'MovieLens':
# plt.legend(prop={'size': 15}, loc='center right')
# # plt.savefig(os.path.join('plots', '{}-n-{}-m-{}-seed-{}'.format(dataset, n, m, seed)))
# plt.savefig(os.path.join('plots', prefix + 'max-relative-error-{}-n-{}-m-{}-mean-and-se.pdf'.format(dataset, n, m)))
# plt.clf()

# ###### ave relative errors ######
# plt.errorbar(range(t0+1, T+1, skip_size), np.mean(ave_one_norm_to_beta_eq_all_seeds[:, range(t0, T, skip_size)], axis=0), (1/np.sqrt(10)) * np.std(ave_one_norm_to_beta_eq_all_seeds[:, range(t0, T, skip_size)], axis=0), label = r'$||(\beta^t - \beta^*)/\beta^*||_1/n$', linestyle='solid', errorevery=num_dp//10)
# plt.errorbar(range(t0+1, T+1, skip_size), np.mean(ave_one_norm_to_u_eq_all_seeds[:, range(t0, T, skip_size)], axis=0), (1/np.sqrt(10)) * np.std(ave_one_norm_to_u_eq_all_seeds[:, range(t0, T, skip_size)], axis=0), label = r'$||(\bar{u}^t - u^*)/u^*||_1/n$', linestyle='dotted', errorevery=num_dp//8)
# plt.errorbar(range(t0+1, T+1, skip_size), np.mean(ave_one_norm_to_B_all_seeds[:, range(t0, T, skip_size)], axis=0), (1/np.sqrt(10)) * np.std(ave_one_norm_to_B_all_seeds[:, range(t0, T, skip_size)], axis=0), label = r'$||(\bar{b}^t - B)/B||_1/n$', linestyle='dashdot', errorevery=num_dp//6)
# plt.plot(range(t0+1, T+1, skip_size), np.ones(num_dp) * ave_norm_to_u_eq_baseline, label = r'$||(u^{\rm PS} - u^*)/u^*||_1/n$', linestyle = (0, (3, 5, 1, 5, 1, 5)))
# # plt.vlines([pt for pt in range(t0, T+1) if pt % (n*10) == 0], ymin=0, ymax=1, linestyles='dotted', linewidth=1.0)
# [plt.axvline(pt, linewidth=1.0, linestyle = 'dotted') for pt in range(t0, T+1) if pt % (n*10) == 0]
# # plt.yscale('log') #, plt.xscale('log')
# plt.xticks(range(0, T+1, T//5))
# # plt.xlabel('t')
# plt.title('{}, n = {}, m = {} (Ave. Relative Errors)'.format(dataset, n, m), fontsize=15)
# # if dataset == 'MovieLens':
# plt.legend(prop={'size': 15}, loc='center right')
# # plt.savefig(os.path.join('plots', '{}-n-{}-m-{}-seed-{}'.format(dataset, n, m, seed)))
# plt.savefig(os.path.join('plots', prefix + 'ave-relative-error-{}-n-{}-m-{}-mean-and-se.pdf'.format(dataset, n, m)))
# plt.clf()