import numpy as np


class MFMAB:
    """Multi-fidelity multi-armed bandit environment with Bernoulli rewards.

    Pulling arm ``k`` at fidelity ``m`` returns a Bernoulli sample with
    success probability ``MuMatrix[k, m]``; :meth:`cost` reports the price
    ``LambdaVector[m]`` of sampling at fidelity ``m``.

    Args:
        K: Number of arms (rows of ``MuMatrix``).
        M: Number of fidelities (columns of ``MuMatrix``).
        MuMatrix: ``(K, M)`` array of success probabilities.
        ZetaVector: Length-``M`` sequence of per-fidelity offsets. Stored for
            the learner; this class does not use it.
        LambdaVector: Length-``M`` sequence of per-fidelity costs.
        rng: Optional random source exposing ``binomial(n=..., p=...)``, for
            example ``numpy.random.default_rng(seed)``. When omitted, numpy's
            global random state (``np.random``) is used.

    Raises:
        ValueError: if ``MuMatrix`` is not ``(K, M)`` or a vector is not of
            length ``M``.
    """

    def __init__(self, K, M, MuMatrix, ZetaVector, LambdaVector, rng=None):
        self.K = K
        self.M = M

        # asarray returns the very same object for an existing ndarray, so a
        # caller holding a reference to its matrix keeps sharing it.
        self.MuMatrix = np.asarray(MuMatrix)
        if self.MuMatrix.ndim != 2 or self.MuMatrix.shape != (K, M):
            raise ValueError(
                f"MuMatrix must have shape ({K}, {M}), got {self.MuMatrix.shape}")

        self.ZetaVector = ZetaVector
        self.LambdaVector = LambdaVector
        # Explicit checks instead of assert: asserts vanish under `python -O`.
        if len(self.ZetaVector) != self.M:
            raise ValueError(
                f"ZetaVector must have length {M}, got {len(self.ZetaVector)}")
        if len(self.LambdaVector) != self.M:
            raise ValueError(
                f"LambdaVector must have length {M}, got {len(self.LambdaVector)}")

        # The np.random module and a Generator both provide binomial(n, p).
        self.rng = np.random if rng is None else rng

    def pull(self, k, m):
        """Draw one Bernoulli(``MuMatrix[k, m]``) sample (0 or 1) for arm k at fidelity m."""
        return self.rng.binomial(n=1, p=self.MuMatrix[k, m])

    def cost(self, m):
        """Return the cost ``LambdaVector[m]`` of one pull at fidelity m."""
        return self.LambdaVector[m]


class LUCB:
    """LUCB-style best-arm identification on a multi-fidelity bandit.

    Every round ranks the arms by their smallest upper confidence bound
    across fidelities, takes the top arm as ``l`` and the runner-up as ``u``,
    and samples both through :meth:`explore`. Sampling stops once the largest
    lower confidence bound of ``l`` (over fidelities) exceeds the smallest
    upper confidence bound of ``u``.

    Args:
        mfmab: Environment exposing ``K``, ``M``, ``ZetaVector``,
            ``LambdaVector`` and ``pull(k, m)``.
        delta: Confidence parameter. It appears inside the logarithms of
            :meth:`beta` and of the type-B commitment threshold.
        TildeMu1: Reference value. ``HatDelta`` for an arm other than ``l``
            is ``TildeMu1 - (estimate + zeta)``.
        TildeMu2: Reference value. ``HatDelta`` for arm ``l`` is
            ``(estimate - zeta) - TildeMu2``.
    """

    def __init__(self, mfmab, delta, TildeMu1, TildeMu2):
        # inputs
        self.MFMAB = mfmab
        self.K = mfmab.K
        self.M = mfmab.M
        self.delta = delta
        self.TildeMu1 = TildeMu1
        self.TildeMu2 = TildeMu2
        # Multiplier inside the logarithms of beta() and the type-B threshold.
        self.L = 4 * self.K * self.M

        # internal quantities
        self.t = 1                                    # 1 + number of pulls so far
        self.EstMuMatrix = np.zeros([self.K, self.M])  # running sample means
        self.HatDelta = np.zeros([self.K, self.M])     # gap estimates (see update)
        self.NumMatrix = np.zeros([self.K, self.M])    # pull counts per (arm, fidelity)
        self.l = 0                                    # current leading arm
        self.u = 1                                    # current challenger arm
        # Type-B fidelity each arm is committed to; -1 means "not committed".
        self.HatMVector = -1 * np.ones([self.K], dtype=int)

        # cost complexity: sum of LambdaVector[m] over every pull
        self.Cost = 0

    def frame(self, type):
        """Sample until the stopping rule holds and return the selected arm.

        Args:
            type: ``'A'`` or ``'B'``; the fidelity-selection rule handed to
                :meth:`explore` every round.

        Returns:
            Tuple ``(l, Cost)``: the arm index returned as best and the total
            cost accumulated by this instance.

        Raises:
            ValueError: if ``type`` is not ``'A'`` or ``'B'``, or the
                environment has fewer than two arms.
        """
        if type not in ('A', 'B'):
            raise ValueError(f"unknown exploration type {type!r}; expected 'A' or 'B'")
        if self.K < 2:
            # l and u must be two distinct arms (u starts at index 1).
            raise ValueError("LUCB needs at least two arms (K >= 2)")

        # Unsampled (arm, fidelity) pairs have an infinite radius, so the loop
        # always runs until every relevant pair has been sampled.
        while self._max_lcb(self.l) <= self._min_ucb(self.u):
            # NOTE(review): both l and u are chosen by ranking on the smallest
            # UCB; classical LUCB picks l by empirical mean — confirm intended.
            rank = np.argsort([self._min_ucb(k) for k in range(self.K)])
            self.l = rank[-1]
            self.u = rank[-2]
            self.explore(type, self.l, self.u)

        return self.l, self.Cost

    def explore(self, type, ell, u):
        """Sample arm ``ell`` and then arm ``u`` according to rule ``type``.

        Rule ``'A'`` pulls each arm once at the fidelity maximising
        :meth:`fUCB`. Rule ``'B'`` pulls a committed arm at its committed
        fidelity; otherwise it pulls the arm once at every fidelity and may
        then commit it.

        Raises:
            ValueError: if ``type`` is not ``'A'`` or ``'B'``.
        """
        if type == 'A':
            for k in (ell, u):
                m = self._select_fidelity_a(k)
                self.update(k, m, self.MFMAB.pull(k, m))
        elif type == 'B':
            for k in (ell, u):
                self._explore_b(k)
        else:
            raise ValueError(f"unknown exploration type {type!r}; expected 'A' or 'B'")

    def _select_fidelity_a(self, k):
        """Return the fidelity of arm k with the largest fUCB (first index on ties)."""
        return np.argmax([self.fUCB(k, m) for m in range(self.M)])

    def _explore_b(self, k):
        """Perform the type-B sampling step for a single arm k."""
        committed = self.HatMVector[k]
        if committed != -1:
            self.update(k, committed, self.MFMAB.pull(k, committed))
            return

        for m in range(self.M):
            self.update(k, m, self.MFMAB.pull(k, m))

        scores = self._scaled_gaps(k)
        if np.max(scores) > self._commit_threshold(k):
            self.HatMVector[k] = np.argmax(scores)

    def _scaled_gaps(self, k):
        """Return HatDelta[k, m] / sqrt(LambdaVector[m]) for every fidelity m."""
        return [self.HatDelta[k, m] / np.sqrt(self.MFMAB.LambdaVector[m])
                for m in range(self.M)]

    def _commit_threshold(self, k):
        """Return the score an arm's best scaled gap must exceed to commit (type B)."""
        # NOTE(review): uses fidelity 0's cost and pull count for every arm;
        # confirm this matches the intended rule.
        return 3 * np.sqrt(np.log(self.L / self.delta)
                           / (self.MFMAB.LambdaVector[0] * self.NumMatrix[k, 0]))

    def _max_lcb(self, k):
        """Largest lower confidence bound of arm k over all fidelities."""
        return max(self.CB('LCB', k, m) for m in range(self.M))

    def _min_ucb(self, k):
        """Smallest upper confidence bound of arm k over all fidelities."""
        return min(self.CB('UCB', k, m) for m in range(self.M))

    def update(self, k, m, obs):
        """Record observation ``obs`` from pulling arm k at fidelity m.

        Advances ``t``, adds ``LambdaVector[m]`` to ``Cost``, updates the
        running mean, recomputes ``HatDelta[k, m]`` and increments the count.
        """
        self.t += 1
        self.Cost += self.MFMAB.LambdaVector[m]
        n = self.NumMatrix[k, m]
        self.EstMuMatrix[k, m] = (self.EstMuMatrix[k, m] * n + obs) / (n + 1)
        zeta = self.MFMAB.ZetaVector[m]
        # The gap formula depends on k's role at the time of this update.
        # NOTE(review): other fidelities of k keep values computed under the
        # role k had when they were last updated.
        if k == self.l:
            self.HatDelta[k, m] = (self.EstMuMatrix[k, m] - zeta) - self.TildeMu2
        else:
            self.HatDelta[k, m] = self.TildeMu1 - (self.EstMuMatrix[k, m] + zeta)
        self.NumMatrix[k, m] += 1

    def CB(self, type, k, m):
        """Return a confidence bound for arm k using fidelity m's samples.

        ``'UCB'`` gives ``estimate + zeta + beta``; ``'LCB'`` gives
        ``estimate - zeta - beta``.

        Raises:
            ValueError: if ``type`` is neither ``'UCB'`` nor ``'LCB'``.
        """
        radius = self.MFMAB.ZetaVector[m] + self.beta(self.NumMatrix[k, m])
        if type == 'UCB':
            return self.EstMuMatrix[k, m] + radius
        if type == 'LCB':
            return self.EstMuMatrix[k, m] - radius
        raise ValueError(f"unknown bound type {type!r}; expected 'UCB' or 'LCB'")

    def beta(self, n):
        """Confidence radius sqrt(log(L * t**4 / delta) / n); +inf when n == 0."""
        if n == 0:
            # Explicit instead of dividing by zero (same value, no RuntimeWarning).
            return np.inf
        return np.sqrt(np.log(self.L * self.t**4 / self.delta) / n)

    def fUCB(self, k, m):
        """Type-A index of arm k at fidelity m; +inf if that pair is unsampled.

        ``HatDelta[k, m] / sqrt(lam) + sqrt(2 log(N_k) / (lam * N_km))`` with
        ``lam = LambdaVector[m]``, ``N_k`` the total pulls of arm k and
        ``N_km`` the pulls of arm k at fidelity m.
        """
        n_km = self.NumMatrix[k, m]
        if n_km == 0:
            # Unsampled fidelities are tried first (lowest index on ties),
            # without relying on NaN/inf arithmetic from 0/0 divisions.
            return np.inf
        lam = self.MFMAB.LambdaVector[m]
        bonus = np.sqrt(2 * np.log(np.sum(self.NumMatrix[k, :])) / (lam * n_km))
        return self.HatDelta[k, m] / np.sqrt(lam) + bonus


# Problem instance: K arms, M fidelities.
K = 5
M = 3
delta = 0.1  # NOTE(review): not used below; the sweep uses deltaVector.
TildeMu1 = 0.95
TildeMu2 = 0.75
LambdaVector = [1, 2, 3]       # cost of one pull at each fidelity
ZetaVector = [0.3, 0.15, 0]    # per-fidelity offsets used by LUCB's bounds

# Row k holds arm k's success probability at each fidelity.
MuMatrix = np.array([
    [0.7, 0.8, 0.9],
    [0.75, 0.775, 0.8],
    [0.5, 0.6, 0.7],
    [0.5, 0.55, 0.6],
    [0.3, 0.45, 0.5],
])

mfmab = MFMAB(K, M, MuMatrix, ZetaVector, LambdaVector)

nsim = 100
deltaVector = [0.05, 0.1, 0.15, 0.2, 0.25]


def run_experiment(env, delta_vector, num_sims, tilde_mu1, tilde_mu2):
    """Run LUCB variants 'A' and 'B' repeatedly and collect their costs.

    Args:
        env: Environment passed to every LUCB instance.
        delta_vector: Confidence parameters to sweep.
        num_sims: Number of repetitions per confidence parameter.
        tilde_mu1, tilde_mu2: Reference values forwarded to LUCB.

    Returns:
        Array of shape ``(len(delta_vector), num_sims, 2)``; the last axis
        holds the cost of variant 'A' (index 0) and variant 'B' (index 1).
    """
    cost_complex = np.zeros([len(delta_vector), num_sims, 2])
    for sim in range(num_sims):
        for n, d in enumerate(delta_vector):
            # A fresh learner per run; 'A' runs before 'B' as in every sweep.
            for col, variant in enumerate(('A', 'B')):
                algo = LUCB(env, d, tilde_mu1, tilde_mu2)
                _, cost_complex[n, sim, col] = algo.frame(type=variant)
    return cost_complex


if __name__ == "__main__":
    # Guarded so importing this module does not launch the long simulation
    # or overwrite the results file.
    CostComplex = run_experiment(mfmab, deltaVector, nsim, TildeMu1, TildeMu2)
    print(CostComplex)
    # File name follows nsim instead of a hard-coded "100".
    np.save(f'CostComplex{nsim}', CostComplex)
