import numpy as np
import random
import math
from algorithms.AutoTuning import *

class LinUCB:
    """LinUCB contextual bandit with several exploration-rate tuning schemes.

    Every public ``linucb_*`` method runs ``self.T`` rounds and returns a
    length-``T`` numpy array holding the cumulative regret after each round.

    ``class_context`` is read for:
      - ``d``: feature dimension;
      - ``fv[t]``: the arms' feature vectors at round ``t`` (assumed to be
        1-D arrays of length ``d`` -- TODO confirm);
      - ``random_sample(t, arm)``: the reward the learner observes;
      - ``reward[t][arm]`` and ``optimal[t]``: used only to compute regret;
      - ``sigma`` and ``max_norm``: used only by the theoretical exploration
        rate in ``linucb_theoretical_explore``.

    ``auto_tuning``, ``op_tuning`` and ``log_barrier`` are expected to come
    from the ``algorithms.AutoTuning`` star import.
    """

    def __init__(self, class_context, T):
        self.data = class_context
        self.T = T  # number of rounds
        self.d = self.data.d

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _select_arm(features, theta_hat, B_inv, explore):
        """Return the index of the arm with the largest UCB score.

        The score of arm ``x`` is ``x . theta_hat + explore * sqrt(x^T B_inv x)``.
        The quadratic form is clamped at 0 before the square root: after many
        incremental (Sherman-Morrison) updates, rounding error may make it a
        tiny negative number, on which ``math.sqrt`` would raise ValueError.
        Ties go to the lowest index (``np.argmax`` semantics).
        """
        scores = [
            x.dot(theta_hat) + explore * math.sqrt(max(x.dot(B_inv).dot(x), 0.0))
            for x in features
        ]
        return np.argmax(scores)

    @staticmethod
    def _sherman_morrison(B_inv, x):
        """Rank-one update of ``B_inv`` in place: B_inv <- (B + x x^T)^{-1}."""
        tmp = B_inv.dot(x)
        B_inv -= np.outer(tmp, tmp) / (1 + x.dot(tmp))

    @staticmethod
    def _exp3_gamma(K, T):
        """EXP3 mixing rate min(1, sqrt(K ln K / ((e - 1) T))) for K options over T rounds."""
        return min(1, math.sqrt(K * math.log(K) / ((np.exp(1) - 1) * T)))

    @staticmethod
    def _combination_grid(explore_rates, lamdas):
        """Return every (explore, lamda) pair as the rows of a ``(-1, 2)`` array.

        Rows are ordered by explore rate first, then by lamda.
        """
        return np.array(np.meshgrid(explore_rates, lamdas)).T.reshape(-1, 2)

    def _record_regret(self, regret, t, pull):
        """Store the cumulative regret after round ``t`` into ``regret[t]``.

        Round 0 starts from an explicit zero instead of reading ``regret[-1]``.
        """
        previous = regret[t - 1] if t > 0 else np.float64(0.0)
        regret[t] = previous + self.data.optimal[t] - self.data.reward[t][pull]

    # ------------------------------------------------------------------
    # Single-LinUCB variants
    # ------------------------------------------------------------------

    def linucb_theoretical_explore(self, lamda=1, delta=0.1, explore = -1):
        """Run LinUCB with either the theoretical or a fixed exploration rate.

        Args:
            lamda: ridge regularisation; the design matrix starts at ``lamda * I``.
            delta: confidence parameter of the theoretical exploration rate.
            explore: fixed exploration rate. The sentinel ``-1`` (default)
                means "recompute the theoretical rate every round":
                ``sigma * sqrt(d * log((t * max_norm**2 / lamda + 1) / delta)) + sqrt(lamda)``.
                A fixed value is intended only for grid search.

        Returns:
            np.ndarray of length T with the cumulative regret.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xr = np.zeros(d)
        B_inv = np.identity(d) / lamda
        theta_hat = np.zeros(d)

        use_theory = explore == -1
        for t in range(T):
            feature = self.data.fv[t]
            if use_theory:
                explore_t = self.data.sigma * math.sqrt(
                    d * math.log((t * self.data.max_norm**2 / lamda + 1) / delta)
                ) + math.sqrt(lamda)
            else:
                explore_t = explore

            pull = self._select_arm(feature, theta_hat, B_inv, explore_t)
            observe_r = self.data.random_sample(t, pull)

            x = feature[pull]
            self._sherman_morrison(B_inv, x)
            xr += x * observe_r
            theta_hat = B_inv.dot(xr)
            self._record_regret(regret, t, pull)
        return regret

    def linucb_tl(self, explore_rates, lamda=1):
        """Run LinUCB whose exploration rate is chosen online by an EXP3 layer (TL).

        Args:
            explore_rates: candidate exploration rates.
            lamda: ridge regularisation (fixed).

        Returns:
            np.ndarray of length T with the cumulative regret.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xr = np.zeros(d)
        B_inv = np.identity(d) / lamda
        theta_hat = np.zeros(d)

        # EXP3 layer over the candidate exploration rates, random initial pick.
        Kexp = len(explore_rates)
        logw = np.zeros(Kexp)
        p = np.ones(Kexp) / Kexp
        gamma = self._exp3_gamma(Kexp, T)
        index = np.random.choice(Kexp)
        explore = explore_rates[index]

        for t in range(T):
            feature = self.data.fv[t]
            pull = self._select_arm(feature, theta_hat, B_inv, explore)
            observe_r = self.data.random_sample(t, pull)

            # LinUCB update.
            x = feature[pull]
            self._sherman_morrison(B_inv, x)
            xr += x * observe_r
            theta_hat = B_inv.dot(xr)
            self._record_regret(regret, t, pull)

            # The EXP3 layer picks next round's exploration rate from this reward.
            logw, p, index = auto_tuning(logw, p, observe_r, index, gamma)
            explore = explore_rates[index]
        return regret

    def linucb_op(self, explore_rates, lamda=1):
        """Run LinUCB whose exploration rate is chosen online by ``op_tuning``.

        Args:
            explore_rates: candidate exploration rates.
            lamda: ridge regularisation (fixed).

        Returns:
            np.ndarray of length T with the cumulative regret.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xr = np.zeros(d)
        B_inv = np.identity(d) / lamda
        theta_hat = np.zeros(d)

        # Per-candidate counters handed to op_tuning, both initialised to 1
        # (presumably success/failure counts -- see op_tuning).
        Kexp = len(explore_rates)
        s = np.ones(Kexp)
        f = np.ones(Kexp)
        index = np.random.choice(Kexp)
        explore = explore_rates[index]

        for t in range(T):
            feature = self.data.fv[t]
            pull = self._select_arm(feature, theta_hat, B_inv, explore)
            observe_r = self.data.random_sample(t, pull)

            # LinUCB update.
            x = feature[pull]
            self._sherman_morrison(B_inv, x)
            xr += x * observe_r
            theta_hat = B_inv.dot(xr)
            self._record_regret(regret, t, pull)

            # op_tuning picks next round's exploration rate from this reward.
            s, f, index = op_tuning(s, f, observe_r, index)
            explore = explore_rates[index]
        return regret

    def linucb_syndicated(self, explore_rates, lamdas):
        """Tune exploration rate and lamda with two independent EXP3 layers.

        Args:
            explore_rates: candidate exploration rates.
            lamdas: candidate ridge regularisation values.

        Returns:
            np.ndarray of length T with the cumulative regret.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xr = np.zeros(d)
        theta_hat = np.zeros(d)

        # EXP3 layer for the exploration rate, random initial pick.
        Kexp = len(explore_rates)
        logw = np.zeros(Kexp)
        p = np.ones(Kexp) / Kexp
        gamma = self._exp3_gamma(Kexp, T)
        index = np.random.choice(Kexp)
        explore = explore_rates[index]

        # EXP3 layer for lamda, random initial pick.
        Klam = len(lamdas)
        loglamw = np.zeros(Klam)
        plam = np.ones(Klam) / Klam
        gamma_lam = self._exp3_gamma(Klam, T)
        index_lam = np.random.choice(Klam)
        lamda = lamdas[index_lam]

        # lamda can change every round, so keep the raw Gram matrix and
        # re-invert instead of using rank-one updates.
        xxt = np.zeros((d, d))
        B_inv = np.identity(d) / lamda
        for t in range(T):
            feature = self.data.fv[t]
            pull = self._select_arm(feature, theta_hat, B_inv, explore)
            observe_r = self.data.random_sample(t, pull)

            # Both layers are updated from the same observed reward.
            logw, p, index = auto_tuning(logw, p, observe_r, index, gamma)
            explore = explore_rates[index]
            loglamw, plam, index_lam = auto_tuning(loglamw, plam, observe_r, index_lam, gamma_lam)
            lamda = lamdas[index_lam]

            # The newly chosen lamda is already used for this round's estimate.
            x = feature[pull]
            xxt += np.outer(x, x)
            B_inv = np.linalg.inv(xxt + lamda * np.identity(d))
            xr += x * observe_r
            theta_hat = B_inv.dot(xr)
            self._record_regret(regret, t, pull)
        return regret

    def linucb_tl_combined(self, explore_rates, lamdas):
        """Tune (exploration rate, lamda) jointly with one EXP3 layer over all pairs.

        Args:
            explore_rates: candidate exploration rates.
            lamdas: candidate ridge regularisation values.

        Returns:
            np.ndarray of length T with the cumulative regret.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xr = np.zeros(d)
        theta_hat = np.zeros(d)

        # One EXP3 layer over every (explore, lamda) combination, random initial pick.
        explore_lamda = self._combination_grid(explore_rates, lamdas)
        Kexp = len(explore_lamda)
        logw = np.zeros(Kexp)
        p = np.ones(Kexp) / Kexp
        gamma = self._exp3_gamma(Kexp, T)
        index = np.random.choice(Kexp)
        explore, lamda = explore_lamda[index]

        # lamda can change every round, so keep the raw Gram matrix and re-invert.
        xxt = np.zeros((d, d))
        B_inv = np.identity(d) / lamda
        for t in range(T):
            feature = self.data.fv[t]
            pull = self._select_arm(feature, theta_hat, B_inv, explore)
            observe_r = self.data.random_sample(t, pull)

            logw, p, index = auto_tuning(logw, p, observe_r, index, gamma)
            explore, lamda = explore_lamda[index]

            x = feature[pull]
            xxt += np.outer(x, x)
            B_inv = np.linalg.inv(xxt + lamda * np.identity(d))
            xr += x * observe_r
            theta_hat = B_inv.dot(xr)
            self._record_regret(regret, t, pull)
        return regret

    # ------------------------------------------------------------------
    # CORRAL variants (a master over M LinUCB base learners)
    # ------------------------------------------------------------------

    def _run_corral(self, explores, base_lamdas):
        """Run CORRAL over one LinUCB base learner per (explores[i], base_lamdas[i]).

        Every round each base proposes an arm. One base is sampled from
        ``pbar``, and its arm is played. Every base then updates its own
        design matrix with its own proposed arm. Only the sampled base
        receives the observed reward; the others are updated with reward 0.
        NOTE(review): rewards are not divided by ``pbar`` here -- confirm
        this is intended relative to ``log_barrier``.

        Raises:
            ZeroDivisionError: if round 0 has a single arm (log K == 0) or T == 1
                (log T == 0).
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)

        K = len(self.data.fv[0])
        eta0 = 1 / math.sqrt(K * T * math.log(K))

        M = len(explores)
        p = np.ones(M) / M
        pbar = np.ones(M) / M
        gamma = 1 / T
        beta = np.exp(1 / math.log(T))
        rho = [2 * M] * M
        eta = [eta0] * M

        xr = [np.zeros(d) for _ in range(M)]
        B_inv = [np.identity(d) / lam for lam in base_lamdas]
        theta_hat = [np.zeros(d) for _ in range(M)]

        for t in range(T):
            feature = self.data.fv[t]
            pulls = [
                self._select_arm(feature, theta_hat[b], B_inv[b], explores[b])
                for b in range(M)
            ]

            chosen_base = np.random.choice(M, p=pbar)
            played = pulls[chosen_base]
            observe_r = self.data.random_sample(t, played)
            self._record_regret(regret, t, played)

            # Base-learner updates (B_inv[b] is modified in place).
            for b in range(M):
                rew = observe_r if b == chosen_base else 0
                x = feature[pulls[b]]
                self._sherman_morrison(B_inv[b], x)
                xr[b] += x * rew
                theta_hat[b] = B_inv[b].dot(xr[b])

            # Master update: only the sampled base gets a non-zero loss.
            passl = np.zeros(M)
            passl[chosen_base] = -observe_r
            p = log_barrier(p, passl, eta)
            pbar = (1 - gamma) * p + gamma / M
            for b in range(M):
                if 1 / pbar[b] >= rho[b]:
                    rho[b] = 2 / pbar[b]
                    eta[b] *= beta
        return regret

    def linucb_corral(self, explore_rates, lamda=1):
        """CORRAL over LinUCB bases that differ only in exploration rate.

        Args:
            explore_rates: one exploration rate per base learner.
            lamda: ridge regularisation shared by all bases.

        Returns:
            np.ndarray of length T with the cumulative regret.
        """
        return self._run_corral(explore_rates, [lamda] * len(explore_rates))

    def linucb_corral_combined(self, explore_rates, lamdas):
        """CORRAL over one LinUCB base per (exploration rate, lamda) pair.

        Args:
            explore_rates: candidate exploration rates.
            lamdas: candidate ridge regularisation values.

        Returns:
            np.ndarray of length T with the cumulative regret.
        """
        explore_lamda = self._combination_grid(explore_rates, lamdas)
        return self._run_corral(
            [row[0] for row in explore_lamda],
            [row[1] for row in explore_lamda],
        )