import numpy as np
import pandas as pd
from utils import *
import argparse, os

# Only recorded in the run's metadata file at the end of the script.
varying_budgets = False

# Parse command-line arguments. argparse reports bad/unknown arguments by
# calling sys.exit(), i.e. raising SystemExit -- which happens e.g. inside a
# notebook kernel that passes its own flags. In that case fall back to the
# hard-coded parameters below. An exit with code 0 (`--help`) is honored
# instead of silently running the full experiment.
try:
    parser = argparse.ArgumentParser(description='Experiment arguments')
    parser.add_argument('--seed', '-sd', type=int, default=2333)
    parser.add_argument('--rank', '-rk', type=int, default=3)
    parser.add_argument('--n', '-n', type=int, default=100)
    parser.add_argument('--m', '-m', type=int, default=1000)
    args = parser.parse_args()
    seed, n, m, d = args.seed, args.n, args.m, args.rank
except SystemExit as exc:
    if exc.code == 0:
        raise
    print('not parsing command line inputs. use given parameters.')
    # T and dataset are reassigned further down; only seed, n, m, d carry over.
    T, seed, dataset = int(5e5), 23333, 'Random'
    n, m, d = 100, 1000, 10

# Build a random rank-d market: v[i, j] = alpha[i] . theta[j].
# The instance RNG uses a fixed seed (1) so every run sees the same market;
# `seed` is applied later and only drives the online item sampling.
dataset = 'Random'
np.random.seed(1)
alpha = np.random.exponential(size=(n, d)) + 0.05   # buyer factors, (n, d)
theta = np.abs(np.random.normal(size=(m, d))) + 1   # item factors, (m, d)
# Scale each buyer's factors so that its total value over all items is 1.
alpha = alpha / (alpha @ np.sum(theta, 0))[:, None]
v = alpha @ theta.T                                 # valuations, (n, m)

n, m = v.shape
T = 200 * n  # horizon scales with the number of buyers
print('seed = {}, dataset = {}, n = {}, m = {}'.format(seed, dataset, n, m))
B = np.full(n, 1.0 / n)  # equal budgets that sum to 1

# Rescale every buyer's row of valuations to sum to m.
v = m * (v / np.sum(v, 1)[:, None])

# Offline (static) equilibrium for this instance: load the cached x and p,
# or compute them with compute_me_mosek (from utils) and cache them.
fpath = os.path.join('results', 'random-' + 'n-{}-m-{}-d-{}'.format(n, m, d), 'offline-eq')
os.makedirs(fpath, exist_ok=True)
try:
    x_opt, p_opt = np.loadtxt(os.path.join(fpath, 'x')), np.loadtxt(os.path.join(fpath, 'p'))
    print('load offline equilibrium eq. x and p...')
except (OSError, ValueError):
    # Missing cache (OSError) or unparsable file (ValueError): recompute it.
    # Other errors propagate instead of silently triggering a solve.
    print('compute offline equilibrium x and p')
    p_opt, x_opt = compute_me_mosek(v, B)
    np.savetxt(os.path.join(fpath, 'x'), x_opt)
    np.savetxt(os.path.join(fpath, 'p'), p_opt)

u_opt = np.sum(v * x_opt, axis=1)  # equilibrium utility of each buyer
beta_opt = B / u_opt  # B_i / u_i: the reference value for `beta` in the main loop
pobj, dobj = eg_primal_obj_val(x_opt, v, B), eg_dual_obj_val(beta_opt, v, B)
print('offline (x,p) has dobj = {:.4e}, EG duality gap = {:.4e}'.format(dobj, dobj - pobj))

np.random.seed(seed)  # `seed` drives the online item sampling below
################################################################
# experiment parameters
do_stochastic = True  # True: one uniformly sampled item per step; False: all items every step
budget_cap = False  # True: exclude buyers whose next payment would exceed B[i] * T

# Rows of v were already scaled to sum to m above, so this guard only fires for
# instances whose first row still sums to exactly 1.
if np.sum(v[0]) == 1:
    v = m * v # rescale, now supply of each item is 1/m, not 1

delta0 = 0.05  # beta is clipped to [(1 - delta0) * B, 1 + delta0] in the main loop
beta = np.ones(n)  # current multipliers beta(t)
beta_ave = np.zeros(n)  # running time-average of beta
g_ave = np.zeros(n)  # running time-average of collected utilities
# u_running = np.zeros(n) # to collect utilities along the way
# record stuff across time
# (builtin `int` dtype: the `np.int` alias was removed in NumPy 1.24)
items_all_t = np.zeros(T, dtype=int) # j(t) sampled uniformly at random from {0, 1, ..., m-1}
winners_all_t = np.zeros(T, dtype=int) # i(t) = min of argmax over i of beta[i] * v[i, j(t)]
# spending[i] := cumulative spending of buyer i
# it gets incremented by beta[t,i] * v[i,j] if j = j(t) is sampled at time t and i = i(t) is the winner

# some are logged in every t (error norms in variables)
# some are only periodically (dgap and envy gap)
spending = np.zeros(n)
inf_norm_to_u_eq, inf_norm_to_beta_eq, inf_norm_to_B = [], [], []
ave_one_norm_to_u_eq, ave_one_norm_to_beta_eq, ave_one_norm_to_B = [], [], []
# log_interval = int(T//num_logs)
# duality_gap, max_envy, ave_envy = [], [], []

# x_cumulative[i, j]: amount of item j allocated to buyer i, summed over steps
x_cumulative = np.zeros((n, m))
# proportional allocation x[i, j] = B[i] / m (only used by the commented-out
# duality-gap code in the main loop)
x_proportional = (B * np.ones(shape=(n, m)).T).T / m

# Online dynamics. At each step t:
#   1. allocate the sampled item (do_stochastic) or every item to the buyer
#      with the largest bid beta[i] * v[i, j] (smallest index on ties); the
#      winner pays that bid;
#   2. fold the collected utility into the running average g_ave;
#   3. set beta = clip(B / g_ave, (1 - delta0) * B, 1 + delta0);
#   4. log relative errors of g_ave, beta and average spending w.r.t.
#      u_opt, beta_opt and B.
log_every = max(1, T // 20)  # ~20 progress prints; max() avoids modulo-by-zero when T < 20
for t in range(1, T+1):
    if do_stochastic:
        # sample an item
        j = np.random.choice(m)
        items_all_t[t-1] = j
        bids = beta * v[:, j]
        if budget_cap:
            # keep only buyers whose payment stays within their total budget B[i] * T
            eligible = np.flatnonzero(spending + bids <= B * T)
            if eligible.size == 0:
                raise ValueError('no buyer has enough budget left at t = {}'.format(t))
            # argmax over the eligible subset, mapped back to a real buyer index
            # (argmax of the masked sub-array alone is a position in that subset)
            winner = eligible[np.argmax(bids[eligible])]
        else:
            winner = np.argmax(bids)  # np.argmax returns the first (smallest) index on ties
        winners_all_t[t-1] = winner
        spending[winner] += bids[winner] # option 1: use beta(t) to compute prices
        # u_running[winner] += v[winner, j] # winner gets its v[winner, j]
        # update g_bar: only the winner's entry can potentially be incremented.
        # At t = 1 everyone starts at 1/n instead of 0, which keeps B / g_ave
        # below finite for buyers that have not won anything yet.
        g_ave = (t-1) * g_ave / t if t > 1 else np.ones(n) / n
        g_ave[winner] += v[winner, j] / t
        x_cumulative[winner, j] += 1
    else: # find the full subgradient
        items = np.arange(m)
        winners = np.argmax(beta * v.T, 1) # winners[j] wins item j
        g_ave = (t-1) * g_ave / t
        for j, winner in enumerate(winners):
            # u_running[winner] += v[winner, j]/m # winners collect their rewards: same as g_ave
            g_ave[winner] += (v[winner, j]/m) / t
        # Every item (supply 1/m each) is allocated this step: record all of
        # them and charge each winner its bid; np.add.at accumulates correctly
        # when one buyer wins several items.
        x_cumulative[winners, items] += 1 / m
        np.add.at(spending, winners, beta[winners] * v[winners, items] / m)
    # update beta
    # (option 2 would instead charge spending[winner] += beta[winner] * v[winner, j] using beta(t+1))
    beta = np.maximum((1-delta0) * B, np.minimum(1 + delta0, B / g_ave))
    beta_ave = (t-1) * beta_ave / t + beta / t
    # logging: per-buyer relative errors, max over buyers and mean over buyers
    inf_norm_to_u_eq.append(np.max(np.abs(g_ave - u_opt)/u_opt)) # relative to each u_opt
    inf_norm_to_beta_eq.append(np.max(np.abs(beta - beta_opt)/beta_opt))
    inf_norm_to_B.append(np.max(np.abs(B - spending/t)/B))
    ave_one_norm_to_u_eq.append(np.mean(np.abs(g_ave - u_opt)/u_opt))
    ave_one_norm_to_beta_eq.append(np.mean(np.abs(beta - beta_opt)/beta_opt))
    ave_one_norm_to_B.append(np.mean(np.abs(B - spending/t)/B))
    # np.linalg.norm(beta - beta_opt) / np.linalg.norm(beta_opt)
    # if t % log_interval == 0:
    #     # compute duality gap - using a "feasible" x_feas
    #     x_feas = 1/m * x_cumulative / np.maximum(1, np.sum(x_cumulative, axis=0)) # make it primal (supply-)feasible by normalizing each column
    #     x_feas = (1 - 1/t) * x_feas + 1/t * x_proportional
    #     pobj, dobj = eg_primal_obj_val(x_feas, v, B), eg_dual_obj_val(beta, v, B)
    #     dgap = dobj - pobj
    #     duality_gap.append(dgap)
    #     # compute envy gaps - using a time averaged x (may not be feasible w.r.t. s[j] = 1/m supplies)
    #     x = 1/m * x_cumulative
    #     umat_budget_scaled = np.array([[v[i].T @ x[k] / B[k] for k in range(n)] for i in range(n)]) # umat[2,3] - v[2].T @ x[3] == 0
    #     envy_gap_all_buyer = np.max(umat_budget_scaled, 1) - np.diag(umat_budget_scaled)
    #     max_envy.append(np.max(envy_gap_all_buyer))
    #     ave_envy.append(np.average(envy_gap_all_buyer))
    if t % log_every == 0:
        # print('t = {}, dobj = {}, dgap = {:.4f}'.format(t, dobj, dgap))
        print('t = {}, max_beta_error = {:.4f}, max_u_error = {:.4f},  max_b_error = {:.4f}'.format(t, inf_norm_to_beta_eq[-1], inf_norm_to_u_eq[-1], inf_norm_to_B[-1]))
    
# Final time-averaged utilities relative to the equilibrium utilities; ratios
# near 1 mean g_ave has approached u_opt. g_ave is already a time average, so
# it is not divided by T again.
res = g_ave / u_opt
print('max and min of g_ave divided by u_opt[i]: {:.4f}, {:.4f}'.format(np.max(res), np.min(res)))

# Time-averaged allocation accumulated over the T steps.
x = x_cumulative / T

# plot something
# NOTE(review): plt and sns are not used anywhere below in this script; only the
# seaborn theme is applied here. Presumably kept for interactive plotting after
# a run -- confirm or drop.
from matplotlib import pyplot as plt
import seaborn as sns
sns.set_theme()

# save results: the per-step error curves plus a JSON metadata file, written
# to a per-seed directory next to the cached offline equilibrium.
import pandas as pd
import json
fpath = os.path.join('results', 'random-' + 'n-{}-m-{}-d-{}'.format(n, m, d), 'sd-{}'.format(seed))
print('fpath = {}'.format(fpath))
os.makedirs(fpath, exist_ok=True)
# np.savetxt(os.path.join(fpath, 'duality_gap'), duality_gap, fmt='%.4e')
curves = {
    'inf_norm_to_beta_eq.gz': inf_norm_to_beta_eq,
    'ave_one_norm_to_beta_eq.gz': ave_one_norm_to_beta_eq,
    'inf_norm_to_u_eq.gz': inf_norm_to_u_eq,
    'ave_one_norm_to_u_eq.gz': ave_one_norm_to_u_eq,
    'inf_norm_to_B.gz': inf_norm_to_B,
    'ave_one_norm_to_B.gz': ave_one_norm_to_B,
}
# np.savetxt writes gzip-compressed text when the file name ends in '.gz'.
for curve_name, curve in curves.items():
    np.savetxt(os.path.join(fpath, curve_name), curve, fmt='%.4e')
# np.savetxt(os.path.join(fpath, 'ave_envy'), ave_envy, fmt='%.4e')
# np.savetxt(os.path.join(fpath, 'max_envy'), max_envy, fmt='%.4e')
meta_data = dict(
    T=T, dataset=dataset, n=n, m=m, d=d,
    seed=seed, delta0=delta0,
    varying_budgets=varying_budgets,
)
with open(os.path.join(fpath, 'meta_data'), 'w') as meta_file:
    json.dump(meta_data, meta_file, indent=4)