import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from numpy.lib.function_base import average
import pandas as pd
import seaborn as sns
import os, argparse

# Global matplotlib font size for all figures.
# NOTE(review): sns.set_theme(), called further down, resets rcParams (it sets
# font.size through the plotting context), so this value is presumably
# overridden for the saved plots — confirm before relying on it.
font = {'size': 24}
matplotlib.rc('font', **font)

# Toggle for the budget-capped experiment variant. It only selects the prefix
# used in the per-seed results directory names and the output plot filenames.
budget_cap = False
prefix = 'budget-cap-' if budget_cap else ''

# Choose the dataset from the command line.
# parse_known_args() ignores argv entries this parser does not define (e.g.
# arguments injected by an interactive host), instead of aborting and dropping a
# valid --dataset. argparse still raises SystemExit for -h or a malformed
# --dataset; only that case falls back to the default. Catching SystemExit
# (rather than a bare except) lets KeyboardInterrupt and real errors propagate.
try:
    parser = argparse.ArgumentParser(description='Experiment arguments')
    parser.add_argument('--dataset', '-ds', default='MovieLens')
    args, _unknown_args = parser.parse_known_args()
    dataset = args.dataset
except SystemExit:
    print('not parsing command line inputs. use given parameters.')
    dataset = 'MovieLens'

# Load the valuation matrix v for the selected dataset and normalize it.
# Each entry maps a dataset name to (CSV path, extra pandas.read_csv options).
# The MovieLens file is read with header=None (its first row is data); the
# others use pandas' default header handling.
dataset_sources = {
    'MovieLens': ("../../data/movielens_1500x1500_lr0.1_wd1e-05_dim20_rmse0.88640326.csv", {'header': None}),
    'Jokes': ("../../data/jokes_7200.csv", {}),
    'Household': ("../../data/household_items_understood.csv", {}),
}
# Fail fast with a clear message; otherwise an unknown name leaves v undefined
# and the script dies later with a confusing NameError.
if dataset not in dataset_sources:
    raise ValueError('unknown dataset {!r}; expected one of {}'.format(
        dataset, sorted(dataset_sources)))
csv_path, read_kwargs = dataset_sources[dataset]
df = pd.read_csv(csv_path, **read_kwargs)
v = df.to_numpy()

n, m = v.shape
# Rescale every row of v so that it sums to m.
v = m * (v.T / np.sum(v, 1)).T
# Uniform budgets: each of the n rows gets 1/n.
B = np.ones(n) / n

# Proportional-share (PS) baseline. Splitting every item in proportion to the
# budgets, x_ij = B_i / m, yields u_i = B_i * sum_j v_ij / m = B_i = 1/n, because
# each row of v was rescaled to sum to m and all budgets equal 1/n.
# The offline equilibrium allocation x_opt is loaded from disk and its utilities
# u_opt serve as the reference for every relative-error curve.
x_opt = np.loadtxt(os.path.join('results', dataset, 'offline-eq', 'x'))
u_opt = (v * x_opt).sum(axis=1)
u_proportional = np.ones(n) / n

# Per-row relative error of the PS utilities w.r.t. the equilibrium utilities;
# its max and mean give the horizontal baselines in the two plots.
ps_relative_error = np.abs(u_proportional - u_opt) / u_opt
inf_norm_to_u_eq_baseline = np.max(ps_relative_error)
ave_norm_to_u_eq_baseline = np.mean(ps_relative_error)

sns.set_theme()
os.makedirs('plots', exist_ok=True)

# average across seeds
from matplotlib import pyplot as plt
import seaborn as sns
# sns.set_theme()
import os, json

print('plotting {}, budget_cap = {}'.format(dataset, budget_cap))

# Per-seed log files ('<name>.gz' in each seed directory) to aggregate. After
# loading, each metric is stacked into a <name>_all_seeds array with one row per
# seed.
metric_names = (
    'inf_norm_to_beta_eq', 'inf_norm_to_u_eq', 'inf_norm_to_B',
    'ave_one_norm_to_beta_eq', 'ave_one_norm_to_u_eq', 'ave_one_norm_to_B',
)
seeds = range(1, 11)
logs_by_metric = {name: [] for name in metric_names}

for seed in seeds:
    fpath = os.path.join('results', dataset, prefix + 'sd-{}'.format(seed))
    with open(os.path.join(fpath, 'meta_data'), 'r') as ff:
        meta_data = json.load(ff)
    # Horizon recorded by the run. The plotting section below chooses its own
    # window, so this value is informational only.
    T = meta_data['T']
    for name in metric_names:
        logs_by_metric[name].append(np.loadtxt(os.path.join(fpath, name + '.gz')))

# Stack each metric into a (num_seeds, num_logged_steps) array.
inf_norm_to_beta_eq_all_seeds = np.array(logs_by_metric['inf_norm_to_beta_eq'])
inf_norm_to_u_eq_all_seeds = np.array(logs_by_metric['inf_norm_to_u_eq'])
inf_norm_to_B_all_seeds = np.array(logs_by_metric['inf_norm_to_B'])
ave_one_norm_to_beta_eq_all_seeds = np.array(logs_by_metric['ave_one_norm_to_beta_eq'])
ave_one_norm_to_u_eq_all_seeds = np.array(logs_by_metric['ave_one_norm_to_u_eq'])
ave_one_norm_to_B_all_seeds = np.array(logs_by_metric['ave_one_norm_to_B'])

# Plot window: skip the first 5n logged steps and stop at 100n (this replaces
# the T read from meta_data). The horizon is clipped to the shortest plotted
# log, so shorter runs are drawn as far as they go instead of raising IndexError.
t0 = 5 * n
T_requested = n * 100
T = min([T_requested] + [logs.shape[1] for logs in (
    inf_norm_to_beta_eq_all_seeds, inf_norm_to_u_eq_all_seeds,
    ave_one_norm_to_beta_eq_all_seeds, ave_one_norm_to_u_eq_all_seeds)])
if T < T_requested:
    print('logs cover only {} steps; plotting up to t = {} instead of {}'.format(T, T, T_requested))
skip_size = 1
# Indices into the per-seed logs; log index t is drawn at time step t + 1.
plot_idx = np.arange(t0, T, skip_size)
t_axis = plot_idx + 1
# Number of plotted points. It is taken from the index range itself so that it
# always equals the x-axis length; (T - t0) // skip_size under-counts when
# skip_size does not divide T - t0, which broke the baseline plt.plot call.
num_dp = len(plot_idx)


def _mean_and_se(all_seeds):
    """Return the across-seed mean and standard error at the plotted indices.

    ``all_seeds`` has one row per seed and one column per logged step. The
    standard error divides by sqrt(number of rows) instead of a hard-coded 10.
    ``np.std`` uses its default ddof=0.
    """
    window = all_seeds[:, plot_idx]
    num_seeds = window.shape[0]
    return np.mean(window, axis=0), (1 / np.sqrt(num_seeds)) * np.std(window, axis=0)


def _plot_relative_errors(curves, baseline, baseline_label, title, filename):
    """Draw one relative-error figure on the current figure, save it, then clear it.

    curves: sequence of (all_seeds, label, linestyle, errorevery_divisor); each
        is drawn as its across-seed mean with standard-error bars.
    baseline: constant relative error of the proportional-share allocation,
        drawn as a horizontal reference line.
    baseline_label: legend label for the baseline line.
    title: text shown in parentheses after the dataset/size title.
    filename: PDF file name written under ``plots/``.
    """
    for all_seeds, label, linestyle, divisor in curves:
        mean, se = _mean_and_se(all_seeds)
        # errorevery must be a positive int; guard windows shorter than divisor.
        plt.errorbar(t_axis, mean, se, label=label, linestyle=linestyle,
                     errorevery=max(1, num_dp // divisor))
    plt.plot(t_axis, np.full(num_dp, baseline), label=baseline_label,
             linestyle=(0, (3, 5, 1, 5, 1, 5)))
    # Dotted vertical guides at every multiple of 10n that lies inside [t0, T].
    for pt in range(0, T + 1, 10 * n):
        if pt >= t0:
            plt.axvline(pt, linewidth=1.0, linestyle='dotted')
    plt.xticks(range(0, T + 1, max(1, T // 5)))
    plt.title('{}, n = {}, m = {} ({})'.format(dataset, n, m, title), fontsize=15)
    plt.legend(prop={'size': 15}, loc='center right')
    plt.savefig(os.path.join('plots', filename))
    plt.clf()


fig = plt.figure(figsize=(6, 4))

###### max relative errors ######
# To also show budget errors, add
# (inf_norm_to_B_all_seeds, r'$||(\bar{b}^t - B)/B||_\infty$', 'dashdot', 6).
_plot_relative_errors(
    [
        (inf_norm_to_beta_eq_all_seeds, r'$||(\beta^t - \beta^*)/\beta^*||_\infty$', 'solid', 10),
        (inf_norm_to_u_eq_all_seeds, r'$||(\bar{u}^t - u^*)/u^*||_\infty$', 'dashed', 8),
    ],
    inf_norm_to_u_eq_baseline,
    r'$||(u^{\rm PS} - u^*)/u^*||_\infty$',
    'Max Relative Errors',
    prefix + 'max-relative-error-{}-n-{}-m-{}-mean-and-se.pdf'.format(dataset, n, m),
)

###### ave relative errors ######
# To also show budget errors, add
# (ave_one_norm_to_B_all_seeds, r'$||(\bar{b}^t - B)/B||_1/n$', 'dashdot', 6).
_plot_relative_errors(
    [
        (ave_one_norm_to_beta_eq_all_seeds, r'$||(\beta^t - \beta^*)/\beta^*||_1/n$', 'solid', 10),
        (ave_one_norm_to_u_eq_all_seeds, r'$||(\bar{u}^t - u^*)/u^*||_1/n$', 'dashed', 8),
    ],
    ave_norm_to_u_eq_baseline,
    r'$||(u^{\rm PS} - u^*)/u^*||_1/n$',
    'Ave. Relative Errors',
    prefix + 'ave-relative-error-{}-n-{}-m-{}-mean-and-se.pdf'.format(dataset, n, m),
)