import numpy as np
import cvxpy as cp
import pickle
from functools import *
from itertools import *
import dask.dataframe as dd
import pandas as pd
from ortools.graph.python import min_cost_flow



def apply(df, f, npartitions=7, ref=None):
    """Apply ``f`` to every row of ``df`` and merge the expanded result columns in.

    ``f`` is called with ``axis=1, result_type="expand"``; the resulting columns
    are written into ``df`` itself (the input frame is modified in place).

    When ``ref`` is given, ``persistency/{ref}.pkl`` acts as a cache:
    if it can be read it is returned immediately; otherwise the result is
    computed, every callable cell is replaced by its ``__name__``, and the frame
    is pickled there and read back.

    ``npartitions`` is currently unused (left over from a disabled dask path).
    """
    cache_path = None if ref is None else f"persistency/{ref}.pkl"
    if cache_path is not None:
        try:
            return pd.read_pickle(cache_path)
        except (OSError, IOError):
            print(f"apply({ref})")

    expanded = df.apply(f, axis=1, result_type="expand")
    df[expanded.columns] = expanded

    if cache_path is None:
        return df
    # Callables are not picklable in general; store their names instead.
    named_cells = df.applymap(lambda cell: cell.__name__ if callable(cell) else cell)
    named_cells.to_pickle(cache_path)
    return pd.read_pickle(cache_path)

def cross(df1, df2):
    """Return the Cartesian product (cross join) of the rows of ``df1`` and ``df2``."""
    product = df1.merge(right=df2, how="cross")
    return product

# LaTeX symbols used for keyword names in the labels built by ``bind``;
# names missing from this table are printed verbatim.
special_names = {"epsilon": "\\epsilon", "n_adv": "k", "n_imp": "n",
        "spread": "\\sigma", "n_types": "\\#\\mathrm{types}", "corr_prob": "p",
        "quota": "B", "nrows": "\\#\\mathrm{rows}"}

def bind(fn, **kwargs):
    """Pre-bind keyword arguments of ``fn`` and give the wrapper a LaTeX-ish label.

    Keyword arguments passed at call time override the bound ones. The
    wrapper's ``__name__`` becomes ``"<fn name> $ k1=v1,k2=v2 $"`` with each key
    rendered through ``special_names``.
    """
    def bound(*args, **call_kwargs):
        merged = dict(kwargs)
        merged.update(call_kwargs)
        return fn(*args, **merged)

    label = ",".join(f"{special_names.get(key, key)}={value}" for key, value in kwargs.items())
    bound.__name__ = f"{fn.__name__} $ {label} $"
    return bound

def named(name):
    """Decorator factory: set the decorated function's ``__name__`` to ``name``."""
    def decorate(func):
        setattr(func, "__name__", name)
        return func
    return decorate

partial_ = partial
def partial(f, *args, **kwargs):
    """Like ``functools.partial``, but the result keeps ``f.__name__`` (used as a label)."""
    bound = partial_(f, *args, **kwargs)
    setattr(bound, "__name__", f.__name__)
    return bound



class Impressions:
    """A stream of impressions stored as distinct types plus a per-impression type index.

    Attributes:
        imp_types: (n_types, n_adv) weight matrix, one row per distinct type.
        ix: type index of every impression, in arrival order.
        imp_supply: number of impressions of each type.
        is_adwords: flag stored as given (``opt_assignment`` reads it).
        n_types, n_adv: shape of ``imp_types``.
        len: number of impressions.

    Args:
        imp_types: a 2-D ``np.ndarray`` of shape (n_types, n_adv), or a list of
            ``{advertiser_index: weight}`` dicts (one per type) that is
            densified into such a matrix.
        imp_supply: copies of each type; only valid when ``ix`` is None.
            Defaults to one copy per type. Impressions are then laid out
            grouped by type, in type order.
        ix: explicit per-impression type indices; ``imp_supply`` is derived.
        is_adwords: stored as-is.
        shrink: if True (default), drop types that never occur in ``ix`` and
            renumber the remaining ones.

    Raises:
        ValueError: if both ``imp_supply`` and ``ix`` are given.
    """

    def __init__(self, imp_types, imp_supply=None, ix=None, is_adwords=False, shrink=True):

        if not isinstance(imp_types, np.ndarray):
            # Densify sparse {advertiser: weight} dicts. default=-1 lets empty
            # dicts through instead of crashing max() on an empty sequence.
            n_adv = max((max(imp_dict) for imp_dict in imp_types if imp_dict), default=-1) + 1
            imp_types_matrix = np.zeros((len(imp_types), n_adv))
            for i, imp_dict in enumerate(imp_types):
                imp_types_matrix[i, list(imp_dict.keys())] = list(imp_dict.values())
            imp_types = imp_types_matrix

        if ix is None:
            imp_supply = np.ones(len(imp_types), dtype=int) if imp_supply is None else np.array(imp_supply)
            # imp_supply[j] consecutive copies of type j, in type order.
            ix = np.repeat(np.arange(len(imp_supply)), imp_supply)
        else:
            if imp_supply is not None:
                raise ValueError("pass either imp_supply or ix, not both")
            types, counts = np.unique(ix, return_counts=True)
            imp_supply = np.zeros(len(imp_types))
            imp_supply[types] = counts

        self.ix = ix
        self.imp_supply = np.array(imp_supply, dtype=int)
        self.imp_types = imp_types
        self.is_adwords = is_adwords
        # Always define the size attributes: shrink() recomputes them, but with
        # shrink=False they were previously never set (len()/n_adv failed).
        self.n_types, self.n_adv = self.imp_types.shape
        self.len = len(self.ix)
        if shrink:
            self.shrink()

    def shrink(self):
        """Drop types that never occur in ``ix`` and renumber ``ix`` accordingly."""
        nonzero = np.zeros(len(self.imp_types), dtype=bool)
        nonzero[self.ix] = True
        # cumsum over the kept mask maps each old type index to its new index.
        self.ix = np.cumsum(nonzero)[self.ix] - 1
        self.imp_types = self.imp_types[nonzero]
        self.imp_supply = self.imp_supply[nonzero]
        self.n_types, self.n_adv = self.imp_types.shape
        self.len = len(self.ix)

    def copy(self):
        """Return a new ``Impressions`` with a private copy of ``ix``."""
        return Impressions(self.imp_types, ix=np.copy(self.ix), is_adwords=self.is_adwords)

    def __getitem__(self, item):
        """Select impressions by ``ix[item]`` (e.g. a slice) as a new, shrunk ``Impressions``."""
        return Impressions(self.imp_types, ix=self.ix[item], is_adwords=self.is_adwords)

    def __len__(self):
        return self.len

    class IterImpressions:
        """Iterator yielding the weight row of each impression in arrival order."""
        def __init__(self, impressions):
            self.i = 0
            self.impressions = impressions

        def __iter__(self):
            return self

        def __next__(self):
            if self.i >= len(self.impressions.ix):
                raise StopIteration
            imp = self.impressions.imp_types[self.impressions.ix[self.i]]
            self.i += 1
            return imp

    def __iter__(self):
        return Impressions.IterImpressions(self)

    def stats(self, debug=False):
        """Print one line of summary statistics per advertiser column.

        For every column ``x`` of the matrix ``m`` built below, prints the
        column index, ``mean((x / sum(x)) * row_index)``, the number of rows
        with positive weight, the variance of those row indices and the mean
        positive weight.

        Args:
            debug: drop into ``pdb`` after printing. The breakpoint used to be
                unconditional, which hung every non-interactive caller.
        """
        m = np.zeros((len(self.ix), self.n_adv))
        # NOTE(review): this writes imp_types[t] into row t (a *type* index) and
        # leaves the other rows zero; if a per-impression matrix was intended it
        # would be self.imp_types[self.ix] -- confirm.
        m[self.ix] = self.imp_types[self.ix]
        for i, x in enumerate(m.T):
            y = (x / np.sum(x)) * range(len(x))
            nz, = np.where(x > 0)
            print(f"{i}: {np.mean(y):.2f} {len(nz)} {np.var(nz):.2f} {np.mean(x[nz]):.2f}")
        if debug:
            import pdb
            pdb.set_trace()



def yahoo_advs_imps(day):
    """Load the Yahoo impressions for ``day`` from ``pickles/yahoo-{day}.pickle``.

    Each pickled record ``r`` contributes a type ``r[1]`` with supply
    ``int(r[2])``. No budgets are stored, so ``None`` is returned in their
    place; the types are not shrunk.

    NOTE: ``pickle.load`` can execute arbitrary code -- only load trusted files.
    """
    path = f"pickles/yahoo-{day}.pickle"
    with open(path, 'rb') as handle:
        records = pickle.load(handle)
    types = [record[1] for record in records]
    supplies = [int(record[2]) for record in records]
    return None, Impressions(types, supplies, shrink=False)

def ipinyou_advs_imps(day):
    """Load the iPinYou instance for ``day`` from ``pickles/ipinyou-{day}.pickle``.

    The pickle holds ``(advs, imp_list)``. Returns ``advs`` unchanged, together
    with ``imp_list`` wrapped in a (shrunk) ``Impressions``.

    NOTE: ``pickle.load`` can execute arbitrary code -- only load trusted files.
    """
    path = f"pickles/ipinyou-{day}.pickle"
    with open(path, 'rb') as handle:
        advs, imp_list = pickle.load(handle)
    return advs, Impressions(imp_list)



def create_advertisers(n_adv, min_budget=50, equal_budget=False):
    """Draw ``n_adv`` integer budgets uniformly from ``[min_budget, upper]``.

    ``upper`` is ``2 * min_budget``, or ``min_budget`` itself when
    ``equal_budget`` is set (so every budget equals ``min_budget``).
    """
    if equal_budget:
        upper = min_budget
    else:
        upper = 2 * min_budget
    return np.random.randint(min_budget, upper + 1, n_adv)

def create_advertisers_(_, imps, **kwargs):
    """Random budgets for the ``imps.n_adv`` advertisers of ``imps``.

    The first argument is ignored; ``kwargs`` are forwarded to ``create_advertisers``.
    """
    n_adv = imps.n_adv
    return create_advertisers(n_adv, **kwargs)

def impressions_random_order(imps):
    """Return a copy of ``imps`` whose arrival order (``ix``) is randomly permuted.

    The input is left untouched; only the copy's ``ix`` is shuffled in place.
    """
    shuffled = imps.copy()
    np.random.shuffle(shuffled.ix)
    return shuffled

def supply_ascending_order(imps):
    """Rebuild ``imps`` with all copies of each type contiguous, lowest supply first.

    The ``is_adwords`` flag is carried over. Previously it was silently reset
    to False, which changed the solver chosen by ``opt_assignment``.
    """
    order = np.argsort(imps.imp_supply)
    return Impressions(imps.imp_types[order], imps.imp_supply[order], is_adwords=imps.is_adwords)

def supply_descending_order(imps):
    """Rebuild ``imps`` with all copies of each type contiguous, highest supply first.

    The order is the reverse of ``argsort`` on the supplies. The ``is_adwords``
    flag is carried over; previously it was silently reset to False, which
    changed the solver chosen by ``opt_assignment``.
    """
    order = np.argsort(imps.imp_supply)[::-1]
    return Impressions(imps.imp_types[order], imps.imp_supply[order], is_adwords=imps.is_adwords)

def create_advertisers_worst_case(n_adv, n_imp):
    """Give each of the ``n_adv`` advertisers the same integer budget ``n_imp // n_adv``."""
    share = n_imp // n_adv
    return np.ones(n_adv, dtype=int) * share

def create_advertisers_rnd_quota(_, imps):
    """Budgets from spreading each impression evenly over its weights, randomly rounded.

    Each impression with positive total weight adds ``row / row.sum()`` to the
    advertisers' fractional quotas. Each quota is then rounded up with
    probability equal to its fractional part, and down otherwise.
    """
    quota = np.zeros(imps.n_adv)
    for row in imps:
        mass = np.sum(row)
        if mass > 0:
            quota += row / mass
    whole = quota.astype(int)
    round_up = (quota - whole) > np.random.rand(len(quota))
    return whole + round_up.astype(int)

def create_advertisers_leastdeg_quota(_, imps):
    """Intended "least-degree" quota rule -- currently returns all zeros.

    NOTE(review): each advertiser's degree (number of impressions with positive
    weight for it) is computed but never used, so the result is always an
    all-zero int vector of length ``imps.n_adv``. This looks unfinished; confirm
    the intended assignment rule before relying on it.
    """
    quota = np.zeros(imps.n_adv)
    degrees = np.zeros(imps.n_adv)
    for row in imps:
        degrees[row > 0] += 1
    return quota.astype(int)

def create_advertisers_maxmin_quota(_, imps):
    """Greedy max-min quotas.

    Each impression is given to the eligible advertiser (positive weight) with
    the smallest quota so far (lowest index on ties); returns the quotas.

    Impressions with no eligible advertiser are skipped. Previously
    ``np.argmin`` on the empty selection raised ValueError.
    """
    advs = np.zeros(imps.n_adv)
    for imp in imps:
        (ixs,) = np.where(imp > 0)
        if len(ixs) == 0:
            continue
        ix = ixs[np.argmin(advs[ixs])]
        advs[ix] += 1
    return advs.astype(int)

def create_advertisers_maxmin_quota2(_, imps):
    """Run-length variant of the greedy max-min quota rule.

    ``imps`` is a 2-D array (n_imp, n_adv) of impression weights. Each maximal
    run of consecutive identical rows is given, as a whole, to the eligible
    advertiser (positive weight) with the smallest quota so far. The quota
    grows by the run length. Runs with no eligible advertiser are skipped.

    Fixes over the previous version:

    * whole rows are compared (array ``==`` was used in a boolean context);
    * the scan stops at the last row (it used to index past the end);
    * ``np.where`` is unpacked to an index array (the whole tuple was indexed,
      so every eligible advertiser received the run).
    """
    rows = np.asarray(imps)
    n_imp, n_adv = rows.shape
    advs = np.zeros(n_adv)
    i = 0
    while i < n_imp:
        first_imp = rows[i]
        supply = 1
        while i + supply < n_imp and (rows[i + supply] == first_imp).all():
            supply += 1
        (ixs,) = np.where(first_imp > 0)
        if len(ixs) > 0:
            ix = ixs[np.argmin(advs[ixs])]
            advs[ix] += supply
        i += supply
    return advs.astype(int)

def create_impressions_synthetic(n_adv, n_imp):
    """Return an (n_imp, n_adv) matrix of i.i.d. Exponential(1) impression weights."""
    shape = (n_imp, n_adv)
    return np.random.exponential(1, shape)

def create_impressions_synthetic2(n_adv, n_imp, n_types=10):
    """Sample ``n_imp`` impressions uniformly from ``n_types`` random Exponential(1) types.

    Returns an (n_imp, n_adv) array whose rows are copies of the type rows.
    """
    type_rows = np.random.exponential(1, (n_types, n_adv))
    picks = np.random.choice(n_types, n_imp)
    return type_rows[picks]

def create_impressions_synthetic3(n_adv, n_imp, n_types=10, spread=0.1):
    """Synthetic impressions whose types arrive clustered in time.

    ``n_types // 2`` types get Exponential(1) weights and the remaining types
    get Uniform[0, 1) weights. Type ``i`` contributes ``n_imp // n_adv``
    impressions with arrival "times" drawn from a normal centred at
    ``i / n_types`` (wrapped into [0, 1)). The scale of that normal is
    ``spread`` times a fresh uniform draw per type. Impressions are emitted in
    time order.

    An odd ``n_types`` used to silently yield ``n_types - 1`` types; the
    uniform half now takes the extra type. For even ``n_types`` the random
    draws are unchanged.
    """
    n_exp = n_types // 2
    types = np.vstack((np.random.exponential(1, (n_exp, n_adv)),
                       np.random.rand(n_types - n_exp, n_adv)))
    # NOTE(review): the per-type count uses n_adv, not n_types, so the total
    # is n_types * (n_imp // n_adv) rather than n_imp -- confirm intended.
    n_type_ads = n_imp // n_adv
    tss = []
    for i in range(len(types)):
        ts = np.mod(np.random.normal(i / n_types, scale=spread * np.random.rand(), size=n_type_ads), 1)
        tss.append(ts)
    times = np.hstack(tss)
    # Position k in time order belongs to type (original slot // n_type_ads).
    ixs = np.argsort(times) // n_type_ads
    return Impressions(types[ixs])

def create_impressions_worst_case(n_advs, n_imp):
    """Build the staircase worst-case instance.

    The ``n_imp`` impressions are cut into ``n_advs`` consecutive chunks at
    rounded, evenly spaced boundaries. Impressions in chunk ``i`` have weight
    1 for advertisers ``i .. n_advs-1`` and 0 for the others.
    """
    weights = np.zeros((n_imp, n_advs))
    cuts = [round(frac * n_imp) for frac in np.linspace(0, 1, n_advs + 1)]
    for adv, (lo, hi) in enumerate(zip(cuts[:-1], cuts[1:])):
        weights[lo:hi, adv:] = 1
    return Impressions(np.vstack(weights))



def opt_assignment(advs, imps):
    """Return an optimal offline assignment of ``imps`` to advertisers.

    When ``imps.is_adwords`` is set, this delegates to ``opt_assignment_adwords``.
    Otherwise it solves an integer min-cost flow with OR-Tools:

    * nodes ``0 .. n_advs-1`` are advertisers, with supply ``advs[a]``;
    * nodes ``n_advs .. n_advs+n_types-1`` are impression types, with demand
      ``imp_supply[t]``;
    * node ``x = n_advs + n_types`` balances the network. It has supply
      ``imps.len - sum(advs)``, absorbing unused budget via advertiser -> x
      arcs and serving leftover impressions via x -> type arcs.

    Arc ``t * n_advs + a`` (the first ``n_advs * n_types`` arcs) goes from
    advertiser ``a`` to type ``t`` with cost ``-f * imp_types[t, a]``. All
    remaining arcs cost 0. Weights are scaled by ``f = 1000`` so they can be
    given to the solver as (approximately) integer costs.

    Returns:
        ``(asgn, val)``. ``asgn[k]`` is the advertiser of impression ``k``, or
        -1 if unassigned; within a type, labels are shuffled randomly over that
        type's impressions. ``val`` is the optimal total weight.

    Raises:
        RuntimeError: if the solver does not report ``OPTIMAL``.
    """
    if imps.is_adwords:
        return opt_assignment_adwords(advs, imps)

    n_advs = len(advs)
    f = 1000
    x = n_advs + imps.n_types
    start_nodes = np.hstack((
        np.repeat([range(n_advs)], imps.n_types+1, axis=0).flatten(),
        np.repeat(x, imps.n_types)))
    end_nodes = np.hstack((
        n_advs + np.repeat(range(imps.n_types), n_advs),
        np.repeat(x, n_advs),
        n_advs + np.arange(imps.n_types)))
    # NOTE(review): capacities and unit_costs are float arrays while OR-Tools
    # works with int64 -- confirm the conversion behaves as intended.
    capacities = 1e10 * np.ones_like(start_nodes)
    unit_costs = np.hstack((-f * imps.imp_types.flatten(),
        np.zeros(n_advs + imps.n_types)))
    supplies = np.hstack((advs, -imps.imp_supply, imps.len - np.sum(advs))).astype(int)

    smcf = min_cost_flow.SimpleMinCostFlow()
    smcf.add_arcs_with_capacity_and_unit_cost(start_nodes, end_nodes, capacities, unit_costs)
    for i, supply in enumerate(supplies):
        smcf.set_node_supply(i, supply)
    # solve() must run unconditionally: wrapped in an ``assert`` it would be
    # skipped entirely under ``python -O``, leaving every flow at zero.
    status = smcf.solve()
    if status != smcf.OPTIMAL:
        raise RuntimeError(f"min-cost flow did not reach OPTIMAL (status {status})")
    val = -smcf.optimal_cost() / f
    asgn = -np.ones(imps.len, dtype=int)
    for i in range(imps.n_types):
        # One label per impression of type i; slots without real-advertiser
        # flow stay -1.
        alla = -np.ones(imps.imp_supply[i], dtype=int)
        j = 0
        for a in range(n_advs):
            flow = smcf.flow(i * n_advs + a)
            if flow > 0:
                alla[j:j+flow] = a
                j += flow
        np.random.shuffle(alla)
        asgn[imps.ix == i] = alla
    return asgn, val

def opt_value(advs, imps, epsilon=1):
    """Solve the dual LP of the assignment problem with cvxpy (Gurobi).

    Types with zero supply, and advertisers with zero budget or no positive
    weight, are dropped first. With ``w`` the remaining weights, it solves

        min  sum_a epsilon * advs[a] * beta[a] + sum_t supply[t] * z[t]
        s.t. z[t] >= w[t, a] - beta[a],   beta >= 0,  z >= 0.

    Returns ``(objective value, betas)``, with one beta per advertiser of
    ``advs`` (0 for dropped advertisers).

    NOTE(review): ``maximumSeconds`` is CBC's option name (Gurobi's is
    ``TimeLimit``) -- confirm Gurobi accepts it.
    """
    keep_rows = np.where(imps.imp_supply > 0)[0]
    keep_cols = np.where(np.any(imps.imp_types > 0, axis=0) & (advs > 0))[0]
    supply = imps.imp_supply[keep_rows]
    weights = imps.imp_types[keep_rows][:, keep_cols]
    budgets = advs[keep_cols]

    beta = cp.Variable(len(budgets), nonneg=True)
    z = cp.Variable(len(weights), nonneg=True)
    objective = cp.Minimize(cp.sum(cp.multiply(budgets, epsilon*beta)) +
                            cp.sum(cp.multiply(supply, z)))
    constraints = [z >= weights[:, col] - beta[col] for col in range(len(budgets))]
    problem = cp.Problem(objective, constraints)
    problem.solve(verbose=False, solver="GUROBI", maximumSeconds=10)

    full_betas = np.zeros(len(advs))
    full_betas[keep_cols] = beta.value
    return objective.value, full_betas

def pred_value(advs, imps, pred):
    """Value of the assignment ``pred`` under free disposal.

    ``pred[k]`` is the advertiser of impression ``k`` (negative = unassigned).
    Advertiser ``a`` keeps only its ``advs[a]`` heaviest assigned weights. The
    result is the sum of kept weights.
    """
    per_adv = [[] for _ in advs]
    for row, adv in zip(imps, pred):
        if adv < 0:
            continue
        per_adv[adv].append(row[adv])
    return sum(np.sum(np.sort(ws)[::-1][:cap]) for ws, cap in zip(per_adv, advs))



@named("random corruption")
def corrupt_prediction_rnd(pred, corr_prob=0.1):
    """Independently replace each prediction, with probability ``corr_prob``, by
    a uniformly random advertiser in ``[0, max(pred)]``."""
    size = len(pred)
    hit = np.random.rand(size) < corr_prob
    upper = np.max(pred) + 1
    replacement = np.random.randint(0, upper, size)
    keep = 1 - hit
    return keep * pred + hit * replacement

@named("biased corruption")
def corrupt_prediction_biased(pred, corr_prob=0.1):
    """Relabel each prediction, with probability ``corr_prob``, through one random
    permutation of the labels ``{-1, ..., max(pred)}``.

    Index -1 (unassigned) reads the last slot of the permutation array, so
    every label, including -1, has a distinct image.
    """
    size = len(pred)
    hit = np.random.rand(size) < corr_prob
    n_advs = np.max(pred) + 1
    relabel = np.random.choice(range(-1, n_advs), n_advs+1, replace=False)
    relabeled = relabel[pred]
    keep = 1 - hit
    return keep * pred + hit * relabeled

def corrupt_prediction_delete(pred, corr_prob=0.1):
    """Drop each prediction independently with probability ``corr_prob``.

    Dropped entries become -1 (unassigned); the others are unchanged.
    """
    dropped = np.random.rand(len(pred)) < corr_prob
    kept = 1 - dropped
    return kept * pred - dropped





def no_pred(advs, imps, _, on_advs=None, on_imps=None):
    """Baseline "predictor" that leaves every impression unassigned (-1)."""
    return -1 * np.ones(len(imps), dtype=int)

@named("DualBase")
def dual_base(advs, imps, opt_pred=None, epsilon=0.01, shuffle_betas=False, verbose=True, on_advs=None, on_imps=None):
    """Greedy assignment against dual prices learned on a prefix sample.

    The first ``int(epsilon * len(imps))`` impressions are solved with
    ``opt_value`` (budgets scaled by ``epsilon``), giving one price ``beta``
    per advertiser. Each impression of ``on_imps`` then goes to an advertiser
    maximising ``weight - beta`` (near-ties within 1e-5 broken uniformly at
    random). The assignment only happens if that gain is non-negative and the
    advertiser still has budget left in ``on_advs``; otherwise the impression
    stays unassigned (-1).

    ``on_advs``/``on_imps`` default to ``advs``/``imps``. ``opt_pred`` and
    ``shuffle_betas`` are accepted but unused.

    Fixes:

    * verbose progress no longer divides by zero for fewer than 20 impressions;
    * remaining budget is counted in whole impressions instead of float
      fractions, which drifted (1 - 3*(1/3) > 0) and let an advertiser take
      one impression too many.

    Returns:
        int array with one advertiser index (or -1) per impression of ``on_imps``.

    Raises:
        ValueError: if the prefix sample would be empty.
    """
    if on_advs is None: on_advs = advs
    if on_imps is None: on_imps = imps

    n_samples = int(epsilon * len(imps))
    if n_samples <= 0:
        raise ValueError(f"epsilon={epsilon} gives an empty sample for {len(imps)} impressions")
    if verbose: print(f"DualBase {epsilon:.2f}", end="", flush=True)
    sample_imps = imps[:n_samples]
    _, betas = opt_value(advs, sample_imps, epsilon=epsilon)

    # Impressions each advertiser may still take (non-positive budgets: none).
    rem_budget = np.where(on_advs > 0, on_advs, 0).astype(float)
    pred = -np.ones(len(on_imps), dtype=int)
    # Progress-dot interval; max(1, ...) avoids modulo by zero.
    step = max(1, len(on_imps) // 20) if verbose else 1
    for i, imp in enumerate(on_imps):
        if verbose and i % step == 0: print(".", end="", flush=True)
        discounted_gain = imp - betas
        ismax = discounted_gain > np.max(discounted_gain) - 1e-5
        a = np.random.choice(*np.where(ismax))
        if discounted_gain[a] >= 0 and rem_budget[a] > 0:
            pred[i] = a
            rem_budget[a] -= 1
    if verbose: print(" done")
    return pred


"""
def exp_avg_pred_dual_base(advs, imps, pred=None, alpha=2, epsilon=0.1, verbose=False):
    n_samples = int(epsilon * len(imps))
    assert(n_samples > 0)
    sample_imps = imps[:n_samples]
    _, betas_dual_base = opt_value(advs, sample_imps, epsilon=epsilon)
    betas_dual_base_ = np.hstack((betas_dual_base, 0))

    mask = advs > 0
    B = np.min(advs[mask])
    f = B * ((1 + 1/B)**alpha - 1)
    n_adv = len(advs)
    betas_ = np.zeros(n_adv + 1)
    weights = [np.zeros(budget) for budget in advs]
    imp_ = np.zeros(n_adv + 1)

    x = 0
    if verbose: print(f"ExpAvg(alpha={alpha}) B={B} ", end="", flush=True)
    for i, imp in enumerate(imps):
        if verbose and i % (len(imps) // 20) == 0: print(".", end="", flush=True)
        imp_[:n_adv] = mask * imp
        discounted_gain = imp_ - betas_

        discounted_gain_dual_base = imp_ - betas_dual_base_
        max_dual_base = np.max(discounted_gain_dual_base)
        if max_dual_base > 0:
            a_pred = discounted_gain_dual_base > max_dual_base - 1e-5
            x += 1
            # import pdb; pdb.set_trace()
        else:
            a_pred = np.zeros_like(imp_, dtype=bool)
        tmp = np.sum(((f - 1) * a_pred + 1)) - len(advs) - 1
        # if tmp > 0: print(tmp)
        a = np.argmax(((f - 1) * a_pred + 1) * discounted_gain)

        if 0 <= a < n_adv and mask[a]:
            weights[a][0] = imp_[a]
            weights[a].sort()
            budget = advs[a]
            ix = 1 + np.arange(budget)
            expsum = np.sum(weights[a] * np.power(1 + 1/budget, alpha * (budget - ix)))
            betas_[a] = ((1 + 1/budget)**alpha - 1) / ((1 + 1/budget)**(budget * alpha) - 1) * expsum

    total_weight = sum(np.sum(ws) for ws in weights)
    if verbose: print(f" total value={total_weight} {x}/{len(imps)}")
    return total_weight
"""


def exp_avg(advs, imps, pred, alpha=None, verbose=False): # ignores pred and alpha
    """Online exponential-averaging assignment with free disposal (no predictions).

    Each impression goes to the advertiser maximising ``weight - beta``. A
    dummy slot (index ``n_adv``, weight 0, beta 0) means "leave unassigned".
    Advertiser ``a`` holds ``advs[a]`` weight slots, kept sorted ascending. A
    new impression replaces the smallest slot, and ``beta[a]`` becomes
    ``sum_i w_i (1+1/B)^(B-i) / (B((1+1/B)^B - 1))``.

    Zero-budget advertisers are masked out of the argmax, as in
    ``exp_avg_pred``. Before, picking one silently wasted the impression. The
    verbose progress no longer divides by zero for fewer than 20 impressions.

    ``pred`` and ``alpha`` are ignored (kept for a uniform signature).
    Returns the total weight held at the end.
    """
    mask = advs > 0
    n_adv = len(advs)
    betas_ = np.zeros(n_adv + 1)
    weights = [np.zeros(budget) for budget in advs]
    imp_ = np.zeros(n_adv + 1)
    # Progress-dot interval; max(1, ...) avoids modulo by zero.
    step = max(1, len(imps) // 20) if verbose else 1

    if verbose: print("ExpAvg(Feldman) ", end="", flush=True)
    for i, imp in enumerate(imps):
        if verbose and i % step == 0: print(".", end="", flush=True)
        imp_[:n_adv] = mask * imp
        discounted_gain = imp_ - betas_
        a = np.argmax(discounted_gain)

        if 0 <= a < n_adv and mask[a]:
            weights[a][0] = imp_[a]  # replace the smallest held weight
            weights[a].sort()
            budget = advs[a]
            ix = 1 + np.arange(budget)
            expsum = np.sum(weights[a] * np.power(1 + 1/budget, budget - ix))
            betas_[a] = (1 / (budget * ((1 + 1/budget)**budget - 1))) * expsum

    total_weight = sum(np.sum(ws) for ws in weights)
    if verbose: print(f" total value={total_weight}")
    return total_weight

def trivial(advs, imps, pred, alpha=2, p=0.5, verbose=False):
    """Randomised combiner of the prediction and ``exp_avg``.

    With probability ``1 - 1/alpha`` it returns the value of following
    ``pred`` (``pred_value``); otherwise it returns ``exp_avg``'s value.

    NOTE(review): the ``p`` argument is ignored (it is replaced by
    ``1 - 1/alpha``), and ``verbose`` is unused.
    """
    follow_prob = 1 - 1 / alpha
    if np.random.random() >= follow_prob:
        return exp_avg(advs, imps, pred)
    return pred_value(advs, imps, pred)

def exp_avg_pred(advs, imps, pred, alpha=2, verbose=False):
    """Exponential averaging with predictions (free disposal).

    For each impression the algorithm's choice ``a_algo`` maximises
    ``weight - beta``; index ``n_adv`` is a dummy slot with gain 0. The
    predicted advertiser ``a_pred`` wins when
    ``f * gain[a_pred] > gain[a_algo]``, where
    ``f = B * ((1 + 1/B)**alpha - 1)`` (``B`` = smallest positive budget) for
    ``alpha > 1``, and 1 otherwise. Prices use the alpha-tempered exponential
    average of the held weights.

    Fixed: verbose progress no longer divides by zero for fewer than 20
    impressions.

    Returns the total weight held at the end.
    """
    mask = advs > 0
    B = np.min(advs[mask])
    f = B * ((1 + 1/B)**alpha - 1) if alpha > 1 else 1.0
    n_adv = len(advs)
    betas_ = np.zeros(n_adv + 1)
    weights = [np.zeros(budget) for budget in advs]
    imp_ = np.zeros(n_adv + 1)
    # Progress-dot interval; max(1, ...) avoids modulo by zero.
    step = max(1, len(imps) // 20) if verbose else 1

    if verbose: print(f"ExpAvg(alpha={alpha}) B={B} ", end="", flush=True)
    for i, (imp, a_pred) in enumerate(zip(imps, pred)):
        if verbose and i % step == 0: print(".", end="", flush=True)
        imp_[:n_adv] = mask * imp
        discounted_gain = imp_ - betas_
        a_algo = np.argmax(discounted_gain)
        # a_pred == -1 reads the dummy slot (gain 0). a_algo's gain is >= 0
        # (the dummy is a candidate), so "no prediction" never wins.
        a = a_pred if f * discounted_gain[a_pred] > discounted_gain[a_algo] else a_algo # or >=

        if 0 <= a < n_adv and mask[a]:
            weights[a][0] = imp_[a]  # replace the smallest held weight
            weights[a].sort()
            budget = advs[a]
            ix = 1 + np.arange(budget)
            expsum = np.sum(weights[a] * np.power(1 + 1/budget, alpha * (budget - ix)))
            betas_[a] = ((1 + 1/budget)**alpha - 1) / ((1 + 1/budget)**(budget * alpha) - 1) * expsum

    total_weight = sum(np.sum(ws) for ws in weights)
    if verbose: print(f" total value={total_weight}")
    return total_weight


def exp_avg_wpo(advs, imps, verbose=False):
    """Exponential averaging with the large-budget (continuous) price formula.

    Same assignment rule as ``exp_avg``. With held weights
    ``w_1 <= ... <= w_B`` and increments ``ds_j = w_j - w_{j-1}`` (``w_0 = 0``),
    the price is

        beta = (sum_j ds_j * exp((B - j) / B) - w_B) / (e - 1),

    the large-B limit of exp_avg's price
    ``sum_i w_i (1+1/B)^(B-i) / (B((1+1/B)^B - 1))``, rewritten by summation by
    parts.

    Fixes:

    * the sum now applies ``exp``. It previously used the bare exponent
      ``(B - j) / B``, which makes every price <= 0 for non-negative weights;
    * zero-budget advertisers are masked out (selecting one raised IndexError
      on its empty slot array);
    * verbose progress no longer divides by zero for fewer than 20 impressions.

    Returns the total weight held at the end.
    """
    mask = advs > 0
    n_adv = len(advs)
    betas_ = np.zeros(n_adv + 1)
    weights = [np.zeros(budget) for budget in advs]
    imp_ = np.zeros(n_adv + 1)
    # Progress-dot interval; max(1, ...) avoids modulo by zero.
    step = max(1, len(imps) // 20) if verbose else 1

    if verbose: print("ExpAvg(Feldman) ", end="", flush=True)
    for i, imp in enumerate(imps):
        if verbose and i % step == 0: print(".", end="", flush=True)
        imp_[:n_adv] = mask * imp
        discounted_gain = imp_ - betas_
        a = np.argmax(discounted_gain)

        if 0 <= a < n_adv and mask[a]:
            weights[a][0] = imp_[a]  # replace the smallest held weight
            weights[a].sort()
            budget = advs[a]
            ix = 1 + np.arange(budget)
            ds = weights[a] - np.hstack((0, weights[a]))[:-1]
            expsum = np.sum(ds * np.exp((budget - ix) / budget))
            betas_[a] = (expsum - weights[a][-1]) / (np.exp(1) - 1)

    total_weight = sum(np.sum(ws) for ws in weights)
    if verbose: print(f" total value={total_weight}")
    return total_weight



"""
for _ in range(10):
    n_adv = 12
    n_imp = 2000
    advs = create_advertisers(n_adv)
    imps = create_impressions_synthetic3(n_adv, n_imp, spread=0.01, n_types=5)
    opt_val, _ = opt_value(advs, imps)
    alg_val = exp_avg_wpo(advs, imps)
    alg2_val = exp_avg(advs, imps)

    print(alg_val / opt_val, alg2_val / opt_val, 1-1/np.exp(1))
"""



# TODO: use the lower bound from the WINE paper (and from one other paper -- add its reference).
# TODO: redo the plots that appear in the LyX document.


"""
advs, imps = ipinyou_advs_imps(20130606)
advs //= 10
print(advs)
# imps = imps[1000:2000]

print("loading imps...")
# advs, imps = yahoo_advs_imps(20)
# advs = create_advertisers(imps.n_adv, min_budget=10, equal_budget=False)
# imps = impressions_random_order(imps)

print("computing dual-base pred...")
pred = dual_base(advs, imps, epsilon=0.1)

print("done")
print(pred_value(advs, imps, pred))

# print("computing opt val...")
# val, _ = opt_value(advs, imps)
# print(val)

print("running exp avg...")
import time
start = time.time()
exp_avg_val = exp_avg(advs, imps, verbose=True)
print(exp_avg_val, time.time() - start)

print("computing opt asgn...")
opt_pred, opt_val = opt_assignment(advs, imps)
print(opt_val)

import pdb; pdb.set_trace()
"""




"""
def opt_assignment_adwords(advs, imps):
    n_advs = len(advs)
    f = 1000
    x = n_advs + imps.n_types
    start_nodes = np.hstack((
        np.repeat([range(n_advs)], imps.n_types+1, axis=0).flatten(),
        np.repeat(x, imps.n_types)))
    end_nodes = np.hstack((
        n_advs + np.repeat(range(imps.n_types), n_advs),
        np.repeat(x, n_advs),
        n_advs + np.arange(imps.n_types)))
    capacities = f * imps.imp_types
    unit_costs = np.hstack((-f * imps.imp_types.flatten(),
        np.zeros(n_advs + imps.n_types)))
    supplies = np.hstack((advs, -imps.imp_supply, imps.len - np.sum(advs))).astype(int)

    print(start_nodes)
    print(end_nodes)
    print(capacities)
    print(unit_costs)

    smcf = min_cost_flow.SimpleMinCostFlow()
    smcf.add_arcs_with_capacity_and_unit_cost(start_nodes, end_nodes, capacities, unit_costs)
    for i, supply in enumerate(supplies):
        smcf.set_node_supply(i, supply)
    assert(smcf.solve() == smcf.OPTIMAL)
    val = -smcf.optimal_cost() / f
    asgn = -np.ones(imps.len, dtype=int)
    for i in range(imps.n_types):
        alla = -np.ones(imps.imp_supply[i], dtype=int)
        j = 0
        for a in range(n_advs):
            flow = smcf.flow(i * n_advs + a)
            if flow > 0:
                alla[j:j+flow] = a
                j += flow
        np.random.shuffle(alla)
        asgn[imps.ix == i] = alla
    return asgn, val



raw_imps = np.array([[1, 2], [3, 4]])
imps = Impressions(raw_imps)
advs = np.array([5, 6])
opt_assignment_adwords(advs, imps)
"""




"""
def opt_assignment_adwords(advs, imps, imps_sizes):
    A = cp.Variable(imps.shape, boolean=True)
    objective = cp.Maximize(cp.sum(cp.multiply(A, imps)))
    budget_constraint = cp.sum(cp.multiply(A, imps_sizes), axis=0) <= advs
    imp_constraint = cp.sum(A, axis=1) <= 1
    prob = cp.Problem(objective, [budget_constraint, imp_constraint])
    prob.solve(verbose=False, solver="GUROBI", maximumSeconds=10) # CBC
    return A.value, objective.value
"""

def opt_assignment_adwords(advs, imps):
    """Optimal integral AdWords assignment as a cvxpy integer program (Gurobi).

    Types with zero supply, and advertisers with zero budget or no positive
    weight, are dropped first. ``X[t, a]`` is the number of copies of type
    ``t`` given to advertiser ``a``. Each type gives out at most
    ``imp_supply[t]`` copies, and each advertiser's assigned weight stays
    within its budget. With unit supplies this is the previous 0/1 model.

    Fixes:

    * ``imp_supply`` was computed but ignored, so every type counted only once;
    * the assignment was written at the *type* index instead of the positions
      of that type's impressions in ``imps.ix``.

    Returns:
        ``(asgn, value)``. ``asgn[k]`` is the advertiser of impression ``k``
        (-1 if unassigned); ``value`` is the optimal objective.

    NOTE(review): ``maximumSeconds`` is CBC's option name (Gurobi's is
    ``TimeLimit``) -- confirm Gurobi accepts it.
    """
    (non_zero_rows,) = np.where(imps.imp_supply > 0)
    (non_zero_cols,) = np.where(np.any(imps.imp_types > 0, axis=0) & (advs > 0))
    imp_supply_ = imps.imp_supply[non_zero_rows]
    imps_ = imps.imp_types[non_zero_rows][:, non_zero_cols]
    advs_ = advs[non_zero_cols]

    X = cp.Variable((len(imps_), len(advs_)), integer=True)
    objective = cp.Maximize(cp.sum(cp.multiply(X, imps_)))
    budget_constraint = cp.sum(cp.multiply(X, imps_), axis=0) <= advs_
    supply_constraint = cp.sum(X, axis=1) <= imp_supply_
    prob = cp.Problem(objective, [X >= 0, budget_constraint, supply_constraint])
    prob.solve(verbose=False, solver="GUROBI", maximumSeconds=10)

    asgn = -np.ones(imps.len, dtype=int)
    # Round away solver noise before using the copy counts as integers.
    counts_per_type = np.maximum(np.round(X.value), 0).astype(int)
    for t, counts in zip(non_zero_rows, counts_per_type):
        (positions,) = np.where(imps.ix == t)
        labels = np.repeat(non_zero_cols, counts)[:len(positions)]
        asgn[positions[:len(labels)]] = labels
    return asgn, objective.value

def opt_value_adwords_relaxed(advs, imps, epsilon=1):
    """Solve the relaxed AdWords dual LP with cvxpy (Gurobi).

    Types with zero supply, and advertisers with zero budget or no positive
    weight, are dropped first. With ``w`` the remaining weights, it solves

        min  sum_a epsilon * advs[a] * beta[a] + sum_t supply[t] * z[t]
        s.t. z[t] >= w[t, a] * (1 - beta[a]),   beta >= 0,  z >= 0.

    Returns ``(objective value, betas)``, with 0 for dropped advertisers.

    NOTE(review): ``maximumSeconds`` is CBC's option name (Gurobi's is
    ``TimeLimit``) -- confirm Gurobi accepts it.
    """
    keep_rows = np.where(imps.imp_supply > 0)[0]
    keep_cols = np.where(np.any(imps.imp_types > 0, axis=0) & (advs > 0))[0]
    supply = imps.imp_supply[keep_rows]
    weights = imps.imp_types[keep_rows][:, keep_cols]
    budgets = advs[keep_cols]

    beta = cp.Variable(len(budgets), nonneg=True)
    z = cp.Variable(len(weights), nonneg=True)
    objective = cp.Minimize(cp.sum(cp.multiply(budgets, epsilon*beta)) +
                            cp.sum(cp.multiply(supply, z)))
    constraints = [z >= weights[:, col] * (1 - beta[col]) for col in range(len(budgets))]
    problem = cp.Problem(objective, constraints)
    problem.solve(verbose=False, solver="GUROBI", maximumSeconds=10)

    full_betas = np.zeros(len(advs))
    full_betas[keep_cols] = beta.value
    return objective.value, full_betas


if __name__ == "__main__":
    # Tiny smoke test of the AdWords solvers. It is guarded so that importing
    # this module no longer runs Gurobi and prints as a side effect.
    raw_imps = np.array([[1, 2], [3, 4]])
    imps = Impressions(raw_imps)
    advs = np.array([5, 5])
    print(opt_assignment_adwords(advs, imps))
    print(opt_value_adwords_relaxed(advs, imps))







def opt_assignment_gap(advs, imps, imps_sizes):
    """Solve the generalized assignment problem exactly as a 0/1 program (Gurobi).

    ``imps`` and ``imps_sizes`` are 2-D arrays of the same shape. Rows are
    impressions and columns are advertisers. The program maximises the total
    assigned weight, with each impression going to at most one advertiser and
    each advertiser's total assigned size at most ``advs``.

    Returns ``(assignment matrix values, optimal objective)``.
    """
    assign = cp.Variable(imps.shape, boolean=True)
    total = cp.Maximize(cp.sum(cp.multiply(assign, imps)))
    constraints = [
        cp.sum(cp.multiply(assign, imps_sizes), axis=0) <= advs,
        cp.sum(assign, axis=1) <= 1,
    ]
    problem = cp.Problem(total, constraints)
    problem.solve(verbose=False, solver="GUROBI", maximumSeconds=10) # CBC
    return assign.value, total.value

def mahdian(advs, imps, pred, alpha=2, verbose=False):
    """Budgeted allocation with predictions using a tradeoff function.

    The tradeoff is ``phi[a] = 1 - exp(alpha * (spent_fraction[a] - 1))``. An
    impression's value for advertiser ``a`` is ``phi[a] * weight``, and index
    ``n_adv`` is a dummy with value 0. The predicted advertiser wins when
    ``alpha * value[a_pred] >= value[a_algo]``; ``a_pred == -1`` reads the
    dummy. An advertiser accepts while its remaining budget is positive, and
    the full bid is counted even if it overshoots.

    Fixed: budgets are tracked as floats. ``np.copy(advs)`` kept an integer
    dtype for integer budgets, so subtracting fractional bids was silently
    truncated and exhausted budgets early. ``verbose`` is unused.

    Returns the total accepted weight.
    """
    n_adv = len(advs)
    phi = 1 - np.exp(alpha * (np.zeros(n_adv + 1) - 1))
    rem_budget = np.array(advs, dtype=float)
    total_weight = 0

    for imp, a_pred in zip(imps, pred):
        imp_ = np.hstack((imp, 0))
        value = phi * imp_
        a_algo = np.argmax(value)
        a = a_pred if alpha * value[a_pred] >= value[a_algo] else a_algo

        if 0 <= a < n_adv and rem_budget[a] > 0:
            rem_budget[a] -= imp[a]
            total_weight += imp[a]
            # -rem/advs == spent_fraction - 1
            phi[a] = 1 - np.exp(alpha * (- rem_budget[a] / advs[a]))

    return total_weight

def gap_pred(advs, imps, pred, alpha=2, verbose=False):
    """Online GAP-style assignment with predictions and fractional free disposal.

    The gain of advertiser ``a`` is ``weight - size * beta[a]``. Index
    ``n_adv`` is a dummy with gain 0, so ``a_pred == -1`` reads the dummy. The
    predicted advertiser wins when ``f * gain[a_pred] > gain[a_algo]``, with
    ``f = alpha``. Each advertiser keeps a list of ``(weight, size)`` pieces,
    sorted by weight/size. A new item evicts the lowest-density pieces until
    its size fits; the last piece touched may be cut down proportionally.

    NOTE(review): sizes are the weights themselves (``imps_sizes = imps``, and
    ``sizes_`` is built from ``imp`` rather than ``sizes``), so every real
    piece has weight/size 1 -- confirm this is intended.
    NOTE(review): an item of size 0 would produce 0/0 in the sort key and in
    the proportional split.

    Returns the total weight still held. ``verbose`` only controls the final print.
    """
    imps_sizes = imps  # sizes == weights (see NOTE above)
    mask = advs > 0
    B = np.min(advs[mask])  # unused since f no longer depends on B
    f = alpha # B * ((1 + 1/B)**alpha - 1)
    n_adv = len(advs)
    betas_ = np.zeros(n_adv + 1)
    # Per advertiser: (weight, size) pieces, starting with one weight-0 piece
    # that fills the whole budget.
    ratios = [[(0, budget)] for budget in advs]

    for i, (imp, sizes, a_pred) in enumerate(zip(imps, imps_sizes, pred)):
        imp_ = np.hstack((mask * imp, 0))
        sizes_ = np.hstack((mask * imp, 0))  # NOTE(review): built from imp, not sizes
        discounted_gain = imp_ - sizes_ * betas_
        a_algo = np.argmax(discounted_gain)
        a = a_pred if f * discounted_gain[a_pred] > discounted_gain[a_algo] else a_algo

        # print(a, discounted_gain[a])
        if 0 <= a < n_adv and mask[a]:
            s = sizes_[a]
            ratios[a].append((imp_[a], s))
            ratios[a].sort(key=lambda r: r[0] / r[1])
            # Free s units of size from the lowest-density pieces. A partially
            # evicted piece keeps weight proportional to its remaining size.
            # (This loop reuses i, shadowing the outer index, which is unused.)
            for i, (w, s2) in enumerate(ratios[a]):
                d = min(s, s2)
                s -= d
                s2new = s2 - d
                ratios[a][i] = (w * s2new / s2, s2new)
                if s2new > 1e-10: break
            # Drop fully evicted pieces. Without a break, only the last
            # (emptied) piece is kept.
            ratios[a] = ratios[a][i:]
            budget = advs[a]
            rs = np.array(ratios[a])
            # fs[k] = total size of the pieces after piece k (denser ones); 0 for the last.
            fs = np.hstack((np.cumsum(rs[1:,1][::-1])[::-1], 0))
            # expsum = np.sum(rs[:,0] * np.power(1 + 1/budget, fs / budget))
            expsum = np.sum(rs[:,0] * np.exp(alpha * fs / budget))
            # betas_[a] = (alpha / (budget * ((1 + 1/budget)**budget - 1))) * expsum
            betas_[a] = (alpha / (1 + budget * (np.exp(alpha) - 1))) * (1 + expsum)

    # import pdb; pdb.set_trace()
    total_weight = sum(sum(w for (w, _) in rs) for rs in ratios)
    if verbose: print(f" total value={total_weight}")
    return total_weight

def pred_value_gap(advs, imps, imps_sizes, pred):
    """Value of the GAP assignment ``pred`` under greedy density packing.

    ``pred[k]`` is the advertiser of impression ``k`` (negative = unassigned).
    Each advertiser takes its items in decreasing weight/size order until the
    cumulative size reaches its budget. The item that reaches or overshoots
    the budget is still counted in full. Returns the total counted weight.
    """
    items = [[] for _ in advs]
    for weights, sizes, adv in zip(imps, imps_sizes, pred):
        if adv < 0:
            continue
        items[adv].append((weights[adv], sizes[adv]))

    total_weight = 0
    for budget, owned in zip(advs, items):
        used = 0
        for weight, size in sorted(owned, key=lambda item: item[0] / item[1], reverse=True):
            total_weight += weight
            used += size
            if used >= budget:
                break
    return total_weight

