from sklearn.mixture import GaussianMixture as GMM
import numpy as np
from gym.spaces import Box
from collections import deque

def proportional_choice(v, eps=0.):
    """Sample an index of ``v`` with probability proportional to its value.

    With probability ``eps``, or whenever ``v`` sums to zero, the index is
    instead drawn uniformly from ``range(np.size(v))``.

    Args:
        v: sequence/array of weights (expected to be non-negative).
        eps: probability of falling back to a uniform draw.

    Returns:
        The sampled index.
    """
    total = np.sum(v)
    # The uniform fallback also covers an all-zero ``v``, where normalising
    # would divide by zero; thanks to short-circuiting, rand() is not drawn
    # in that case.
    if total == 0 or np.random.rand() < eps:
        return np.random.randint(np.size(v))
    weights = np.array(v) / total
    # A single multinomial trial yields a one-hot vector marking the pick.
    one_hot = np.random.multinomial(1, weights)
    return np.flatnonzero(one_hot == 1)[0]

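# Disagreement-based task sampler: picks goals where a value ensemble disagrees most.
# NOTE: this header previously read "Implementation of IGMM
# (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3893575/) + minor improvements", which appears
# to have been copied from the IGMM teacher; the class below does not fit a GMM.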
class Disagreement():
    """Task sampler that prefers goals on which a value ensemble disagrees.

    On each :meth:`sample_task` call, ``params['presample_fun']`` produces a
    batch of candidate goals. ``compute_vals_fun`` is then evaluated on the
    matching inputs. The mean (``mu``) and standard deviation (``std``) of the
    returned values along axis 0 are passed to ``self.disagreement_fun``,
    which returns the index of the chosen candidate.

    Keys read from ``params``:
        presample_fun: callable ``presample_fun(o=..., g_list=...)`` returning
            ``(q_inputs, presampled_goals)``.
        disagreement_str: ``'id'`` (sample proportionally to ``std``) or
            ``'exp_<lambda>_<temperature>'`` (softmax over
            ``(lambda * mu + std) / temperature``; argmax when the
            temperature is ``inf``).
        reuse_past_goals_ratio: probability of building the candidate batch
            from previously stored tasks. This only happens once
            ``presample_size`` tasks have been stored.
        presample_size: capacity of the stored-task buffer, and the number of
            stored tasks resampled when past goals are reused.
        random_task_ratio: probability of returning ``presampled_goals[0]``
            from a fresh presample without computing any disagreement.
    """

    def __init__(self, mins, maxs, seed=None, params=None):
        """Configure the sampler and seed NumPy's global RNG.

        Args:
            mins: unused by this class (kept for interface compatibility).
            maxs: unused by this class (kept for interface compatibility).
            seed: seed passed to ``np.random.seed``. ``None`` draws a random
                seed in ``[42, 424242)``.
            params: configuration dict; see the class docstring for the keys.

        Raises:
            KeyError: if a required key is missing from ``params``.
            NotImplementedError: if ``params['disagreement_str']`` is neither
                ``'id'`` nor of the form ``'exp_<lambda>_<temperature>'``.
        """
        if params is None:
            params = {}

        # Compare against None (not truthiness) so that an explicit seed of 0
        # is honoured instead of being silently replaced by a random one.
        self.seed = seed
        if seed is None:
            self.seed = np.random.randint(42, 424242)
        np.random.seed(self.seed)

        self.presample_fun = params['presample_fun']
        disagreement_str = params['disagreement_str']
        if disagreement_str == 'id':
            def disagreement_fun(mu, std):
                # Sample proportionally to std. Fall back to uniform when the
                # stds sum to (nearly) zero, where normalising would divide by 0.
                if np.allclose(np.sum(std), 0):
                    p = None
                else:
                    p = std / np.sum(std)
                return np.random.choice(len(mu), p=p)
        elif disagreement_str.startswith('exp'):
            _, lmbda, temp = disagreement_str.split('_')
            lmbda, temp = float(lmbda), float(temp)
            if temp != float('inf'):
                def disagreement_fun(mu, std):
                    logits = (lmbda * mu + std) / temp
                    # Softmax is shift-invariant. Subtracting the max keeps
                    # every exponent <= 0, so np.exp cannot overflow to inf.
                    # Overflow would turn the probabilities into NaN and make
                    # np.random.choice raise.
                    logits = logits - np.max(logits)
                    p = np.exp(logits)
                    p /= np.sum(p)
                    return np.random.choice(len(mu), p=p)
            else:
                def disagreement_fun(mu, std):
                    # NOTE(review): by this file's convention a temperature of
                    # 'inf' selects the greedy argmax, which is the opposite of
                    # the usual softmax limit (temperature -> inf is uniform).
                    # Kept as-is because configs may rely on it.
                    return np.argmax(lmbda * mu + std)
        else:
            raise NotImplementedError(
                f'Unknown disagreement_str: {disagreement_str!r}')

        self.disagreement_fun = disagreement_fun

        self.reuse_past_goals_ratio = params['reuse_past_goals_ratio']
        self.presample_size = params['presample_size']
        self.random_task_ratio = params['random_task_ratio']

        # Most recent tasks passed to update(). The buffer is bounded, so only
        # the latest `presample_size` tasks can be reused as candidates.
        self.tasks = deque(maxlen=self.presample_size)

        # Extra entries merged into the dict passed to dump(). Empty unless
        # populated elsewhere.
        self.bk = {}

    def update(self, task, reward):
        """Record ``task`` in the bounded buffer of past tasks.

        Args:
            task: the task/goal that was attempted.
            reward: unused here. Per the original note it is the normalized
                undiscounted episodic reward.
        """
        self.tasks.append(task)

    def sample_task(self, compute_vals_fun, init_o):
        """Return one goal, chosen by ensemble disagreement.

        Args:
            compute_vals_fun: callable mapping ``q_inputs`` (from
                ``presample_fun``) to an array. Axis 0 of that array is
                reduced to get per-candidate ``mu`` / ``std`` (presumably the
                ensemble axis; confirm against the caller).
            init_o: observation forwarded to ``presample_fun`` as ``o``.

        Returns:
            One element of the ``presampled_goals`` returned by
            ``presample_fun``.
        """
        # With probability random_task_ratio, skip disagreement entirely.
        if np.random.random() < self.random_task_ratio:
            _, presampled_goals = self.presample_fun(o=init_o, g_list=None)
            return presampled_goals[0]

        # Once the buffer is full, sometimes use past tasks as the candidate
        # set instead of fresh ones. They are sampled with replacement.
        g_list = None
        if (len(self.tasks) >= self.presample_size
                and np.random.random() < self.reuse_past_goals_ratio):
            idx = np.random.choice(len(self.tasks), size=self.presample_size,
                                   replace=True)
            g_list = np.asarray(self.tasks)[idx]

        q_inputs, presampled_goals = self.presample_fun(o=init_o, g_list=g_list)
        vals = compute_vals_fun(q_inputs)
        mu = np.mean(vals, axis=0)
        std = np.std(vals, axis=0)

        goal_idx = self.disagreement_fun(mu, std)
        return presampled_goals[goal_idx]

    def dump(self, dump_dict):
        """Merge ``self.bk`` into ``dump_dict`` in place and return it."""
        dump_dict.update(self.bk)
        return dump_dict
