import numpy as np

from parameters import get_param

def groupify(users, C):
    """Split ``users`` into consecutive groups of at least ``C + 1`` members.

    Groups of exactly ``C + 1`` users are peeled off the front while at
    least ``2 * (C + 1)`` users remain; whatever is left (between ``C + 1``
    and ``2 * C + 1`` users) becomes the final group.

    Args:
        users: Sliceable sequence of users (sliced, never mutated).
        C: Non-negative integer; every returned group has more than ``C``
            members.

    Returns:
        A list of slices of ``users`` that together cover it in order.

    Raises:
        ValueError: If ``C`` is negative or ``users`` has ``C`` or fewer
            elements.
    """
    if C < 0:
        # With C == -1 the loop below would never advance.
        raise ValueError("C must be non-negative", C)

    group_size = C + 1
    if len(users) < group_size:
        raise ValueError("not enough users to groupify", users)

    groups = []
    ind = 0
    # Only cut a full group when another full group is still left over, so
    # the trailing remainder never drops below ``group_size``.
    while ind + 2 * group_size <= len(users):
        groups.append(users[ind:ind + group_size])
        ind += group_size

    groups.append(users[ind:])

    # The generator must be *inside* all(); asserting a bare generator
    # expression is always true and checks nothing.
    assert all(len(group) >= group_size for group in groups)
    return groups

# Feedback-Eliciting Sub-Algorithm
def FESA(assignment, instance):
    """Estimate per-user feedback by differencing group feedback.

    Users are bucketed by their assigned action.  Actions with more than
    ``instance.C`` users are split by ``groupify`` into "big" groups; the
    other non-empty buckets are kept whole as "small" groups.  The function
    then plays ``2 * C + 2`` rounds on ``instance``: one with the base
    grouping, followed by one round per position ``p`` in
    ``range(2 * C + 1)`` in which the ``p``-th member of every big group
    that has one is split off into its own singleton group.

    A user's estimate is the feedback of their big group in the base round
    minus the feedback of that group in the round where they were removed.

    Args:
        assignment: Indexable by user id ``0 .. N-1``; gives each user's
            action index in ``range(instance.K)``.
        instance: Object exposing ``N``, ``K``, ``C`` and
            ``do_round(groups, actions)``, which must return one feedback
            entry per group (presumably the group's aggregate reward;
            confirm against the instance implementation).

    Returns:
        A list of length ``N``; entry ``u`` is the estimate for user ``u``
        or ``None`` if the user was not in a big group.
    """
    n_users, n_actions, cap = instance.N, instance.K, instance.C

    # Bucket users by the action they were assigned.
    members = [[] for _ in range(n_actions)]
    for user in range(n_users):
        members[assignment[user]].append(user)

    big_groups, big_actions = [], []
    small_groups, small_actions = [], []
    for action, bucket in enumerate(members):
        if len(bucket) > cap:
            chunks = groupify(bucket, cap)
            big_groups.extend(chunks)
            big_actions.extend([action] * len(chunks))
        elif bucket:
            small_groups.append(bucket)
            small_actions.append(action)

    def split_off(pos):
        # Grouping in which the member at ``pos`` of each big group (when
        # present) plays alone.  Big groups keep their index so feedback
        # can be matched up with the base round.
        trimmed, loners, loner_actions = [], [], []
        for group, action in zip(big_groups, big_actions):
            if pos < len(group):
                trimmed.append(group[:pos] + group[pos + 1:])
                loners.append([group[pos]])
                loner_actions.append(action)
            else:
                trimmed.append(group)
        return (trimmed + small_groups + loners,
                big_actions + small_actions + loner_actions)

    # Round 0 is the unmodified grouping; round p + 1 removes position p.
    schedule = [(big_groups + small_groups, big_actions + small_actions)]
    schedule.extend([split_off(pos) for pos in range(2 * cap + 1)])

    feedbacks = []
    for groups, actions in schedule:
        assert len(groups) == len(actions)
        round_feedback = instance.do_round(groups, actions)
        assert len(round_feedback) == len(groups)
        assert len(groups) >= len(big_groups)
        feedbacks.append(round_feedback)

    estimates = [None] * n_users
    for slot, group in enumerate(big_groups):
        for pos, user in enumerate(group):
            estimates[user] = feedbacks[0][slot] - feedbacks[pos + 1][slot]

    return estimates

# Default number of batches used by BaSE.
DEFAULT_B = 5


class BaSE:
    """Per-learner arm elimination over a fixed number of batches.

    Presumably "Batched Successive Elimination"; confirm against the
    reference the project follows.  The object tracks a set of active arms,
    shrinks it on every ``update`` and reports the size of the next batch.
    """

    def __init__(self, T, K, gamma, num_batches=DEFAULT_B):
        """Start with all ``K`` arms active and no batch completed.

        Args:
            T: Horizon used to size batches.
            K: Number of arms.
            gamma: Numerator of the elimination threshold's square root.
            num_batches: Number of ``update`` calls after which the active
                set collapses to a single arm.
        """
        self.T = T
        self.K = K
        self.gamma = gamma

        self.num_batches = num_batches
        self.cur_batch = 0
        self.active_arms = set(range(K))

    def update(self, mus, ns):
        """Eliminate arms using current estimates and advance one batch.

        Args:
            mus: Estimated mean per arm, indexable by arm id.
            ns: Sample count per arm; the maximum is used as the batch's
                sample size.
        """
        tau = max(ns)
        # NOTE(review): the leader is the argmax over *all* arms, including
        # ones already eliminated; confirm this is intended.
        leader = np.argmax(mus)

        gap_needed = get_param('base_threshold') * np.sqrt(self.gamma / tau)

        losers = [arm for arm in self.active_arms
                  if mus[leader] - mus[arm] >= gap_needed]

        # don't remove everything
        if len(losers) < len(self.active_arms):
            self.active_arms.difference_update(losers)

        self.cur_batch += 1
        if self.cur_batch == self.num_batches:
            # After the last batch, commit to the current leader only.
            self.active_arms = {leader}

    def get_batch_size(self):
        """Return ``0.5 * T ** (1 - 2 ** -cur_batch)``, or ``T`` once all batches are done."""
        if self.cur_batch >= self.num_batches:
            return self.T
        exponent = 1 - (2 ** (-self.cur_batch))
        return 0.5 * pow(self.T, exponent)

    def get_next_batch(self):
        """Return ``(batch_size, active_arms_as_list)`` for the next batch."""
        size = self.get_batch_size()
        return size, list(self.active_arms)


class IndependentUCB:
    """Baseline in which every user runs their own UCB index policy.

    Each round every user forms a singleton group and plays the arm with
    the largest index ``mu_hat + ci_scale * sqrt(2 * log(T) / n)``.
    Estimates start from mean 0 with count 1, so the first real reward is
    averaged with one zero-valued pseudo-observation.
    """

    def __init__(self, instance):
        """Initialise per-user, per-arm statistics.

        Args:
            instance: Object exposing ``N`` (users), ``K`` (arms), ``T``
                (horizon) and ``do_round(groups, actions,
                anonymize_feedback=...)``.
        """
        self.instance = instance
        n_users, n_arms = instance.N, instance.K

        # Confidence-width multiplier; stored so the very same scale is
        # used for the initial indices and for every later refresh.
        self.ci_scale = get_param('ucb_ci_size')

        self.mu_hats = [[0. for _ in range(n_arms)] for _ in range(n_users)]
        self.n_hats = [[1 for _ in range(n_arms)] for _ in range(n_users)]

        # Index for mean 0 and count 1.
        initial_ucb = self.ci_scale * np.sqrt(2 * np.log(instance.T))
        self.ucbs = [[initial_ucb for _ in range(n_arms)]
                     for _ in range(n_users)]

    def do_round(self):
        """Play one round and refresh the statistics of the arms played."""
        n_users = self.instance.N

        # for each user, get the arm with the highest upper confidence bound
        groups = [[i] for i in range(n_users)]
        actions = [np.argmax(self.ucbs[i]) for i in range(n_users)]

        feedback = self.instance.do_round(groups, actions, anonymize_feedback=False)
        log_term = 2 * np.log(self.instance.T)

        for i in range(n_users):
            action = actions[i]
            reward = feedback[i]

            cur_mu = self.mu_hats[i][action]
            cur_n = self.n_hats[i][action]
            self.mu_hats[i][action] = (cur_mu * cur_n + reward) / (cur_n + 1)
            self.n_hats[i][action] += 1
            # Apply the same multiplier as at initialisation; without it
            # refreshed and untouched indices live on different scales.
            self.ucbs[i][action] = (
                self.mu_hats[i][action]
                + self.ci_scale * np.sqrt(log_term / self.n_hats[i][action])
            )

class ExploreThenCommit:
    """Explore arms round-robin via FESA, then play each user's best arm.

    For the first ``R`` calls to ``do_round`` every user is assigned the
    same arm (cycling through the arms) and per-user estimates come from
    ``FESA``.  Afterwards each user plays, alone, the arm with their
    highest estimated mean.
    """

    def __init__(self, instance):
        """Set up estimates and the length of the exploration phase.

        Args:
            instance: Object exposing ``N``, ``K``, ``C``, ``T`` and
                ``do_round(groups, actions)``.
        """
        self.instance = instance

        n_users, n_arms = instance.N, instance.K
        cap, horizon = instance.C, instance.T

        # Means start at 0 with a count of 1 (one zero pseudo-observation).
        self.mu_hats = [[0.] * n_arms for _ in range(n_users)]
        self.n_hats = [[1] * n_arms for _ in range(n_users)]

        # Exploration budget, capped at half the horizon.
        budget = get_param('etc_explore_length') * (cap**2 * n_arms * horizon**2) ** (1/3)
        self.T_exp = min(budget, 0.5 * horizon)
        # Each FESA call plays 2C + 2 rounds on the instance.
        self.R = int(self.T_exp / (2 * cap + 2))

        self.cur_round = 0

    def do_round(self):
        """Run one exploration step or one commit round."""
        n_users = self.instance.N

        if self.cur_round >= self.R:
            # Commit: every user plays their empirically best arm alone.
            solo_groups = [[user] for user in range(n_users)]
            best_arms = [np.argmax(self.mu_hats[user]) for user in range(n_users)]
            self.instance.do_round(solo_groups, best_arms)
            return

        # Explore: everybody is assigned the same arm this step.
        arm = self.cur_round % self.instance.K
        estimates = FESA([arm] * n_users, self.instance)

        for user, reward in enumerate(estimates):
            if reward is not None:
                prev_n = self.n_hats[user][arm]
                prev_mu = self.mu_hats[user][arm]
                self.mu_hats[user][arm] = (prev_mu * prev_n + reward) / (prev_n + 1)
                self.n_hats[user][arm] = prev_n + 1
            # A None estimate means FESA could not isolate this user; the
            # original author notes this should not happen here.

        self.cur_round += 1

class MatchAlgorithm:
    """Batched elimination per user, scheduled through a decomposer.

    Every user owns a ``BaSE`` learner.  At the start of a batch each
    learner is updated and asked for its demand (batch size and active
    arms); ``decomposer`` turns those demands into a list of assignments,
    which are then played one per ``do_round`` call through ``FESA``.
    """

    def __init__(self, instance, decomposer, alpha_bound):
        """Create one learner per user and empty batch state.

        Args:
            instance: Object exposing ``N``, ``K``, ``C``, ``T`` and
                ``do_round(groups, actions)``.
            decomposer: Callable ``(K, C, demand_sets, max_demand)``
                returning a sequence of assignments, each indexable by user.
            alpha_bound: Divisor applied, together with ``2C + 2``, to the
                horizon handed to the learners.
        """
        self.instance = instance
        self.decomposer = decomposer

        n_users, n_arms = instance.N, instance.K
        rounds_per_fesa = 2 * instance.C + 2

        learner_horizon = int(instance.T / (alpha_bound * rounds_per_fesa))
        gamma = np.log(n_users * n_arms * learner_horizon)
        self.bases = [BaSE(learner_horizon, n_arms, gamma) for _ in range(n_users)]

        # Empty batch, so the first do_round immediately requests one.
        self.cur_round_of_batch = 0
        self.num_rounds_in_batch = 0
        self.cur_assignments = []

        # Means start at 0 with a count of 1 (one zero pseudo-observation).
        self.mu_hats = [[0.] * n_arms for _ in range(n_users)]
        self.n_hats = [[1] * n_arms for _ in range(n_users)]

    def _start_next_batch(self):
        # Update every learner, gather demands, and ask the decomposer for
        # the assignments of the new batch.
        demand_sets = []
        largest_demand = 0
        for user, learner in enumerate(self.bases):
            learner.update(self.mu_hats[user], self.n_hats[user])
            batch_size, active = learner.get_next_batch()
            demand_sets.append(active)
            largest_demand = max(largest_demand, batch_size)

        self.cur_assignments = self.decomposer(self.instance.K,
                                               self.instance.C,
                                               demand_sets,
                                               int(largest_demand))
        self.num_rounds_in_batch = len(self.cur_assignments)
        self.cur_round_of_batch = 0

    def do_round(self):
        """Play the next assignment of the batch and update estimates."""
        if self.cur_round_of_batch >= self.num_rounds_in_batch:
            self._start_next_batch()

        assignment = self.cur_assignments[self.cur_round_of_batch]
        estimates = FESA(assignment, self.instance)

        for user, reward in enumerate(estimates):
            if reward is not None:
                arm = assignment[user]
                prev_n = self.n_hats[user][arm]
                prev_mu = self.mu_hats[user][arm]
                self.mu_hats[user][arm] = (prev_mu * prev_n + reward) / (prev_n + 1)
                self.n_hats[user][arm] = prev_n + 1

        self.cur_round_of_batch += 1

        



