from collections import defaultdict
import numpy as np
import pandas as pd
import cvxpy as cp
from pf_optimization import kl, PFOptimization
import ipopt
import sys

def arm_dict_from_groups(groups_to_arms):
    """Invert a group -> arms mapping into an arm -> groups mapping.

    Args:
        groups_to_arms: mapping from each group id to an iterable of the arm
            ids that group can pull.

    Returns:
        A plain ``dict`` mapping each arm id to the list of groups listing it,
        in the order the groups are visited. An arm listed twice by the same
        group shows that group twice.
    """
    inverted = {}
    for group, group_arms in groups_to_arms.items():
        for arm in group_arms:
            inverted.setdefault(arm, []).append(group)
    return inverted

# Erdos-Renyi-style random instance: each group includes each arm independently with probability p.
def gen_erdos_renyi(G, K, p=0.5):
    """Generate a random group -> arms structure.

    Each of the ``G`` groups independently includes each of the ``K`` arms
    with probability ``p``. A group that ends up with no arms instead gets
    arms drawn uniformly from ``range(K)`` (two draws, duplicates removed),
    so every group is non-empty.

    Args:
        G: number of groups (ids ``0..G-1``).
        K: number of arms (ids ``0..K-1``); must be >= 1 for the fallback.
        p: inclusion probability of each (group, arm) pair.

    Returns:
        dict mapping each group id to a list of distinct ``int`` arm ids.
    """
    # The result of this draw is unused, but the call is kept so the global
    # NumPy RNG stream (and hence the instance generated for a given seed)
    # is unchanged.
    np.random.rand(K)

    groups_to_arms = dict()
    for g in range(G):
        arms = [k for k in range(K) if np.random.random() < p]
        if not arms:
            # np.random.choice samples with replacement by default, so both
            # draws can be the same arm. A repeated arm would make
            # arm_dict_from_groups list this group twice and count the arm
            # as shared, so deduplicate (order-preserving) and use plain ints
            # like the main path does.
            drawn = np.random.choice(range(K), 2)
            arms = list(dict.fromkeys(int(a) for a in drawn))
        groups_to_arms[g] = arms
    return groups_to_arms


def gen_contrived(G):
    """Build the hand-crafted worst-case group -> arms structure.

    Layout for ``G`` groups over arms ``0..G``:
      * group 0 gets arm 0 plus every "bridge" arm ``2..G``;
      * each group ``g`` in ``1..G-1`` gets arm 1 and its own bridge arm
        ``g + 1``, which it shares only with group 0.

    The intended reward ordering (arm 0 best for group 0, arm 1 best for the
    other groups, bridge arms suboptimal) is supplied by the caller's
    ``mus``; this function only decides which arms each group can pull.
    """
    groups_to_arms = {g: [1, g + 1] for g in range(1, G)}
    groups_to_arms[0] = [0] + list(range(2, G + 1))
    return groups_to_arms


def prune(groups_to_arms, mus):
    """Repeatedly drop groups that have no shared, strictly sub-best arm.

    A group is kept only if it can pull some arm that is (a) strictly worse
    under ``mus`` than the group's own best arm and (b) listed at least twice
    across all groups (i.e. shared). Dropping a group can un-share arms, so
    the check repeats until a full pass removes nothing.

    Args:
        groups_to_arms: group id -> non-empty list of arm ids. Modified in
            place: pruned groups are deleted from it.
        mus: indexable arm id -> mean reward.

    Returns:
        The same ``groups_to_arms`` object after pruning (possibly empty).
    """

    def _improvable(arms, shared):
        # Improvable: some strictly sub-best arm of this group is shared.
        best = max([mus[a] for a in arms])
        return any(mus[a] < best and a in shared for a in arms)

    while True:
        # Count how many times each arm is listed across groups; a group that
        # lists an arm twice counts twice (as arm_dict_from_groups would).
        listings = defaultdict(int)
        for arms in groups_to_arms.values():
            for arm in arms:
                listings[arm] += 1
        shared = {arm for arm, count in listings.items() if count >= 2}

        doomed = [g for g, arms in groups_to_arms.items()
                  if not _improvable(arms, shared)]
        if not doomed:
            return groups_to_arms
        for g in doomed:
            del groups_to_arms[g]


def rename(groups_to_arms, mus):
    """Relabel groups and arms with dense 0-based indices.

    Groups are renumbered ``0..G-1`` and the arms listed by at least one
    group are renumbered ``0..K-1``, both in ascending order of their
    original ids, so relative order is preserved. Arms no group lists are
    dropped.

    Args:
        groups_to_arms: group id -> list of arm ids.
        mus: indexable original arm id -> mean reward.

    Returns:
        ``(groups_to_arms, arms_to_groups, mus)`` in the new labels: each
        group's arm list is sorted, ``arms_to_groups`` is the inverse mapping,
        and ``mus`` is a dict new arm id -> mean reward.
    """
    group_index = {g: i for i, g in enumerate(sorted(groups_to_arms))}
    old_arms = sorted(arm_dict_from_groups(groups_to_arms))
    arm_index = {a: i for i, a in enumerate(old_arms)}

    # Keep the input's group iteration order so arms_to_groups lists come
    # out in the same order as before.
    relabeled = {
        group_index[g]: sorted(arm_index[a] for a in arms)
        for g, arms in groups_to_arms.items()
    }
    relabeled_mus = {arm_index[a]: mus[a] for a in old_arms}
    return relabeled, arm_dict_from_groups(relabeled), relabeled_mus

def do_opt(groups_to_arms, arms_to_groups, mus, pf_opt, all_arms, all_groups,
           solver=None):
    """Solve the proportionally-fair allocation program with cvxpy.

    Maximizes ``sum_g log(regret_decrease[g])`` over a nonnegative matrix
    ``group_alpha[a, g]`` (allocation of arm ``a`` to group ``g``) subject to:
      * for every arm in ``pf_opt.suboptimal_arms``, the total allocation
        across groups is at least ``pf_opt.Js[a]``;
      * ``regret_decrease[g] == pf_opt.disagreement(g) - group_alpha[:, g] @ deltas[g]``
        with ``deltas[g][a] = pf_opt.opts[g] - mus[a]``;
      * a group gets zero allocation on arms it cannot pull.

    Arms and groups are used directly as matrix indices, so ``all_arms`` and
    ``all_groups`` must be ``0..K-1`` and ``0..G-1`` (as ``rename`` produces).

    Args:
        groups_to_arms: group id -> list of arm ids available to the group.
        arms_to_groups: unused; kept for call compatibility.
        mus: arm id -> mean reward.
        pf_opt: object exposing ``disagreement(g)``, ``opts[g]``,
            ``suboptimal_arms`` and ``Js[a]`` (presumably a ``PFOptimization``;
            the meaning of those quantities is defined there).
        all_arms: sorted arm ids.
        all_groups: sorted group ids.
        solver: cvxpy solver name; ``None`` (default) means ``cp.MOSEK``.

    Returns:
        ``(prob, group_alpha, regret_decrease, disagreement_point)``: the
        solved problem, the optimal values of the two variables as NumPy
        arrays, and a dict group id -> ``pf_opt.disagreement(g)``.

    Raises:
        RuntimeError: if the solver does not report an optimal (or
            optimal-inaccurate) solution, since no variable values exist then.
    """
    if solver is None:
        solver = cp.MOSEK

    mu_vec = np.array([mus[a] for a in all_arms])  # loop-invariant
    disagreement_point = dict()
    deltas = dict()
    group_to_unavail_actions = dict()
    for g in all_groups:
        disagreement_point[g] = pf_opt.disagreement(g)
        deltas[g] = pf_opt.opts[g] - mu_vec
        available = set(groups_to_arms[g])
        group_to_unavail_actions[g] = [a for a in all_arms if a not in available]

    G = len(all_groups)
    K = len(all_arms)
    group_alpha = cp.Variable((K, G), nonneg=True)
    regret_decrease = cp.Variable(G, nonneg=True)

    constraints = []
    # Lower bound on the total (summed over groups) allocation of each
    # suboptimal arm. NOTE(review): presumably Js[a] is the required amount of
    # exploration of arm a -- confirm against PFOptimization.
    for a in all_arms:
        if a in pf_opt.suboptimal_arms:
            constraints.append(cp.sum(group_alpha[a]) >= pf_opt.Js[a])

    for g in all_groups:
        # Regret decrease of each group relative to its disagreement point.
        constraints.append(
            regret_decrease[g] == disagreement_point[g] - group_alpha[:, g] @ deltas[g]
        )
        # Arms a group cannot pull get no allocation for that group.
        if group_to_unavail_actions[g]:
            constraints.append(group_alpha[group_to_unavail_actions[g], g] == 0)

    prob = cp.Problem(cp.Maximize(cp.sum(cp.log(regret_decrease))), constraints)
    prob.solve(solver=solver, verbose=False)
    # Without this check an infeasible/failed solve returns None values that
    # only blow up later, far from the cause.
    if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
        raise RuntimeError(f'do_opt: solver {solver} returned status {prob.status!r}')
    return prob, group_alpha.value, regret_decrease.value, disagreement_point


def run_one(seed, structure, G, K, p):
    """Run one price-of-fairness (PoF) experiment and return its summary.

    Builds an instance (``'random'``: Erdos-Renyi group/arm structure over
    ``K`` arms; ``'worst_case'``: the contrived structure over ``G + 1``
    arms), prunes and relabels it, solves the fair program with ``do_opt``
    and compares its total regret decrease with the optimum reported by
    ``PFOptimization``.

    Args:
        seed: seed for the global NumPy RNG.
        structure: ``'random'`` or ``'worst_case'``.
        G: number of groups.
        K: number of arms (overridden to ``G + 1`` for ``'worst_case'``).
        p: arm-inclusion probability for ``'random'``.

    Returns:
        dict with keys ``G, K, structure, seed, pof, p``.

    Raises:
        ValueError: if ``structure`` is not recognised.
    """
    np.random.seed(seed)
    # Sorted so arm ids follow reward order. 'worst_case' replaces these
    # means below, but the draw still happens so the RNG stream is unchanged.
    mus = dict(enumerate(sorted(np.random.rand(K))))
    if structure == 'random':
        groups_to_arms = gen_erdos_renyi(G, K, p=p)
    elif structure == 'worst_case':
        # Arm 0 gets the high mean, arm 1 the mid mean, all others the low one.
        K = G + 1
        low, mid, high = sorted(np.random.random(3))
        mus = np.array([high, mid] + [low] * (K - 2))
        groups_to_arms = gen_contrived(G)
    else:
        raise ValueError(
            f"unknown structure {structure!r}; expected 'random' or 'worst_case'")

    groups_to_arms = prune(groups_to_arms, mus)
    groups_to_arms, arms_to_groups, mus = rename(groups_to_arms, mus)

    all_arms = sorted(arms_to_groups)
    all_groups = sorted(groups_to_arms)
    pf_opt = PFOptimization(all_groups, all_arms, mus, groups_to_arms, arms_to_groups)

    _, _, regret_decrease, disagreement_point = do_opt(
        groups_to_arms, arms_to_groups, mus, pf_opt, all_arms, all_groups)

    fair_utility_gain = sum(regret_decrease)
    total_disagreement = sum(disagreement_point.values())
    # NOTE(review): presumably optimal_regret is the regret of the
    # unconstrained optimum -- confirm against PFOptimization.
    opt_utility_gain = total_disagreement - pf_opt.optimal_regret
    # PoF: fraction of the optimal utility gain lost by the fair solution.
    pof = (opt_utility_gain - fair_utility_gain) / opt_utility_gain

    return dict(G=G, K=K, structure=structure, seed=seed, pof=pof, p=p)


def main(structure, num_runs, p=0.5, K=10, group_counts=(3, 5, 10, 50)):
    """Run ``num_runs`` seeds per group count and print PoF statistics.

    Args:
        structure: ``'random'`` or ``'worst_case'`` (forwarded to ``run_one``).
        num_runs: number of seeds (``0..num_runs-1``) per group count.
        p: arm-inclusion probability for ``'random'``.
        K: number of arms for ``'random'``; ``'worst_case'`` uses ``G + 1``.
        group_counts: the values of ``G`` to sweep.

    Returns:
        DataFrame with one row per successful run (the dicts from
        ``run_one``). Runs that raise are reported on stderr and skipped.

    Raises:
        ValueError: if ``structure`` is not recognised (checked up front so
            the sweep does not fail once per run).
    """
    if structure not in ('random', 'worst_case'):
        raise ValueError(
            f"unknown structure {structure!r}; expected 'random' or 'worst_case'")

    all_results = []
    for G in group_counts:
        num_arms = G + 1 if structure == 'worst_case' else K
        for seed in range(num_runs):
            try:
                info = run_one(seed, structure, G, num_arms, p)
            except Exception as e:
                # Best-effort sweep: one failed instance/solve is skipped,
                # but reported with enough context to reproduce it.
                print(f'run failed (structure={structure!r}, G={G}, seed={seed}): {e!r}',
                      file=sys.stderr)
                continue
            all_results.append(info)

    df = pd.DataFrame(all_results)
    # An empty frame has no 'G' column, so groupby would raise KeyError.
    if df.empty:
        print('No successful runs; nothing to summarize.')
        return df

    print('Median:')
    print(df.groupby(['G']).pof.median())
    print('95th percentile:')
    print(df.groupby(['G']).pof.quantile(0.95))

    return df

if __name__ == '__main__':
    # CLI: <script> {random,worst_case} NUM_RUNS
    if len(sys.argv) < 3:
        sys.exit(f'usage: {sys.argv[0]} {{random,worst_case}} NUM_RUNS')
    structure = sys.argv[1]
    num_runs = int(sys.argv[2])
    main(structure, num_runs)

