import math
import numpy as np
from my_logistic_regression import MyLogisticRegression
import time



class CompareAlgs:
    """Run several iterative algorithms on one problem and collect their traces.

    Each call to add_algo stores, for one method, the list of iterates, the
    elapsed wall-clock time at every iterate and the method name. The
    *_vals methods then evaluate a metric on every stored iterate.
    """

    def __init__(self,lr,dataset,optimal_w,iters=10,w0=None,reg=1e-9,pb=None):
        """Initialize the problem.

            lr = an instance of MyLogisticRegression
            dataset = dataset in the format of (features,label)
            optimal_w = optimal minimizer of logistic loss on dataset without privacy
            iters = number of iterations
            w0 = initialization; when None a random unit-norm Gaussian vector is drawn
            reg = regularizer of logistic loss (accepted but not read by this class)
            pb = {"total": Total privacy budget, "grad_frac": Fraction of privacy budget for gradient vs search direction, "num_iteration": num_iter}
        """
        features, _ = dataset
        self.w_opt = optimal_w
        num_samples, dim = np.shape(features)
        print(f"dataset is created: (number of samples, dimension)={num_samples},{dim}")

        if w0 is None:
            # Random direction: standard Gaussian draw scaled to unit norm.
            gaussian_draw = np.random.multivariate_normal(np.zeros(dim), np.eye(dim))
            w0 = gaussian_draw/np.linalg.norm(gaussian_draw)
        self.w0 = w0  # initial iterate shared by all algorithms
        self.iters = iters
        self.pb = pb
        self.lr = lr
        self.clock_time = []  # one list of elapsed times per algorithm
        self.params = []  # one list of iterates per algorithm
        self.names = []  # algorithm names, aligned with params/clock_time
        # Norm threshold above which an iterate is considered to have diverged.
        self.cutoff = 100*np.linalg.norm(self.w_opt)+100*np.linalg.norm(self.w0)+100

    def add_algo(self,update_rule,name,t_stop=None):
        """Run an iterative update method and record its trace under `name`.

        update_rule is a function that takes 5 arguments:
          current iterate
          LogisticRegression problem
          index of current iterate
          total number of iterations
          pb = privacy budget and other info

        t_stop, when given, is a wall-clock limit in seconds: the run ends
        before the first iteration that would start at or after that limit.
        """
        current = self.w0
        iterates = [current]
        elapsed = [0]
        t_begin = time.time()
        for step in range(self.iters):
            out_of_time = t_stop is not None and time.time()-t_begin >= t_stop
            if out_of_time:
                print("time is up!")
                print(str(step))
                break
            current = update_rule(current,self.lr,step,self.iters,self.pb)
            if np.linalg.norm(current) > self.cutoff:
                # Diverged iterate: restart from the initial point.
                current = self.w0
                print("Stop Things Exploding!")
            iterates.append(current)
            elapsed.append(time.time()-t_begin)

        self.clock_time.append(elapsed)
        self.params.append(iterates)
        self.names.append(name)
        print()

    def wall_clock_alg(self):
        """output the elapsed wall-clock time per iteration for different methods
        """
        return {
            name: [timings]
            for timings, name in zip(self.clock_time,self.names)
        }

    def loss_vals(self):
        """output the excess loss (relative to optimal_w) per iteration for different methods
        """
        reference = self.lr.loss_wor(self.w_opt)
        return {
            name: [[self.lr.loss_wor(w)-reference for w in trace]]
            for trace, name in zip(self.params,self.names)
        }

    def accuracy_vals(self):
        """output the accuracy per iteration for different methods
        """
        return {
            name: [[self.lr.accuracy(w) for w in trace]]
            for trace, name in zip(self.params,self.names)
        }

    def accuracy_np(self):
        """output the accuracy of the optimal model without privacy
        """
        return self.lr.accuracy(self.w_opt)

    def gradnorm_vals(self):
        """output the gradient norm per iteration for different methods
        """
        return {
            name: [[np.linalg.norm(self.lr.grad_wor(w)) for w in trace]]
            for trace, name in zip(self.params,self.names)
        }


def private_newton(w,lr,i,iters,pb):
    """ implementation of private newton method from [ABL21] with a non-private step size

        w = current iterate
        lr = an instance of MyLogisticRegression
        i = the index of current iterate (not used)
        iters = total number of iterations
        pb =  privacy budget and other info; reads pb["total"] and pb["grad_frac"]

        return the next iterate

        The per-iteration budget total/iters is split between the gradient
        (grad_frac) and the Hessian (1-grad_frac). Both receive Gaussian
        noise; the noisy Hessian is symmetrized and eigen-clipped before the
        linear solve. The step size is computed from the *non-private* Newton
        direction.
    """
    total = pb["total"]
    grad_frac = pb["grad_frac"]
    hess = lr.hess(w)
    rho_grad =  grad_frac * total / iters  # divide total privacy budget up.
    rho_H =  (1-grad_frac) * total / iters
    grad_scale = (1/lr.n)*math.sqrt(0.5/rho_grad)
    grad_noise = np.random.normal(scale=grad_scale,size=lr.d)
    H_scale = (0.25/lr.n)*math.sqrt(0.5/rho_H)
    H_noise = np.random.normal(scale=H_scale,size=(lr.d,lr.d))
    H_noise = (H_noise + H_noise.T)/2  # keep the perturbed Hessian symmetric
    hess_noisy = eigenclip(hess + H_noise)
    grad =  lr.grad(w)
    grad_noisy = grad+grad_noise
    direction = np.linalg.solve(hess_noisy,grad_noisy)
    newton_size = np.linalg.norm(np.linalg.solve(hess,grad)) # non-private version
    if newton_size == 0:
        # log(1+x)/x -> 1 as x -> 0; evaluating it at x == 0 would give NaN.
        stepsize = 1.0
    else:
        # log1p keeps precision when the Newton direction is tiny.
        stepsize = min(np.log1p(newton_size)/newton_size,1)
    return w - stepsize * direction


def eigenclip(A, min_eval=1e-5):
    """ raise every eigenvalue of A that is below min_eval up to min_eval

        A = symmetric matrix
        min_eval = minimum eigenvalue for clipping

        return the matrix rebuilt from the clipped eigenvalues and the
        original eigenvectors
    """
    spectrum, basis = np.linalg.eigh(A)
    floored = np.maximum(spectrum, min_eval)
    # basis * floored scales column j of the eigenvector matrix by eigenvalue j.
    return (basis * floored).dot(basis.T)


def gd_priv(w,lr,i,iters,pb):
    """Implementation of DP-GD.

        w = current point
        lr = logistic regression
        i = iteration number (not used)
        iters = total number of iterations; the budget pb["total"] is split evenly over them
        pb = auxiliary information; only pb["total"] is read

        output is the next iterate: a gradient step with step size 1/0.25
        plus Gaussian noise
    """
    inverse_step = 0.25  # reciprocal of the learning rate, based on the smoothness
    sensitivity = 1/(lr.n*inverse_step)
    per_step_rho = pb["total"] / iters
    noise_std = sensitivity/np.sqrt(2*per_step_rho)
    perturbation = np.random.normal(scale=noise_std,size=lr.d)
    descent = lr.grad(w)/inverse_step
    return w - descent + perturbation


def sgd_priv(w,lr,i,iters,pb):
    """Implementation of DP-SGD.

        w = current point
        lr = logistic regression
        i = iteration number (not used)
        iters = total number of iterations (not used)
        pb = auxiliary information; reads pb['batch_size'] and pb['noise_multiplier']

        output is the next iterate

        Every sample joins the batch independently with probability
        batch_size/n, so the realized batch size is random. The minibatch
        gradient is rescaled by realized/nominal batch size before the
        Gaussian noise (also divided by the nominal batch size) is added.
    """
    step = 4  # learning rate based on the smoothness
    nominal_batch = pb['batch_size']
    noise_std = pb['noise_multiplier']
    inclusion_prob = nominal_batch/lr.n
    # Bernoulli mask over all samples: 1 marks a sample chosen for this batch.
    mask = np.random.binomial(n=1, p=inclusion_prob, size=lr.n)
    chosen = np.flatnonzero(mask == 1)
    gaussian = np.random.normal(scale = noise_std, size=lr.d)
    # Per the original comment, grad_wor returns the average gradient over `chosen`.
    batch_grad = lr.grad_wor(w,chosen)
    noisy_update = len(chosen) / nominal_batch * batch_grad + gaussian / nominal_batch
    return w - step * noisy_update


def gd_priv_optls(w,lr,i,iters,pb):
    """Implementation of DP-GD with back-tracking line search
        !!! this method is not private. We only use it as a baseline.

        w = current point
        lr = logistic regression
        i = iteration number (not used)
        iters = total number of iterations; the budget pb["total"] is split evenly over them
        pb = auxiliary information; only pb["total"] is read

        output is the next iterate: a step along the noisy gradient whose
        length is chosen by backtracking_ls
    """
    per_step_rho = pb["total"] / iters
    noise_std = (1/lr.n)*math.sqrt(0.5/per_step_rho)
    perturbation = np.random.normal(scale=noise_std,size=lr.d)
    noisy_direction = lr.grad(w)+perturbation
    # The line search evaluates the loss directly, which is why this baseline is not private.
    chosen_step = backtracking_ls(lr,noisy_direction, w)
    return w - chosen_step * noisy_direction


def backtracking_ls(lr,dir, w_0, alpha=0.4, beta=0.95):
    """Implementation of backtracking line search

        lr = logistic regression
        dir = the "noisy" gradient direction
        w_0 = current point
        alpha and beta tradeoff the precision and complexity of the linesearch

        output is an (close to) optimal stepsize

        Starting from t = 100, the step is multiplied by beta while
        loss(w_0 - t*dir) >= loss(w_0) - t*alpha*<dir, grad(w_0)>.
        The search gives up as soon as t drops below 1e-6 and returns that t.
    """
    # The loss and gradient at w_0 do not depend on t, so evaluate them once
    # instead of on every backtracking step.
    base_loss = lr.loss(w_0)
    directional = np.dot(dir, lr.grad(w_0))
    t = 100
    while lr.loss(w_0 - t*dir) >= base_loss - t*alpha* directional:
        t = beta * t
        if t <1e-6:
            break
    return t


def backtracking_ls_private(lr,dir,w_0):
    """Implementation of backtracking line search based on [ABL21]

        lr = logistic regression
        dir = the "noisy" gradient direction
        w_0 = current point

        output is an (close to) optimal stepsize

        Note: In [ABL21] paper the authors propose noisy linesearch.
        However, it is not working for our examples. For this reason we consider,
        non-private LS with a "limited" number of backtracking LS.

        Starting from t = 1, the step is multiplied by beta = 0.1 until
        loss(w_0 - t*dir) - loss(w_0) <= -t*alpha*<dir, grad(w_0)> holds,
        with at most num_steps + 1 = 6 checks. If none succeeds, the step
        left after the sixth shrink is returned.
    """
    alpha=0.3
    beta=0.1
    num_steps = 5
    # The loss and gradient at w_0 do not depend on t, so evaluate them once.
    base_loss = lr.loss(w_0)
    directional = np.dot(dir, lr.grad(w_0))
    t = 1
    for _ in range(num_steps + 1):
        if lr.loss(w_0 - t*dir) - base_loss <= - t*alpha* directional:
            break
        t = beta * t
    return t


def newton(dataset,w_init, bias=True, num_iters=8, reg=1e-9):
    """Implementation of the newton method with linesearch without privacy constraints

        dataset = dataset in the format of (features,label)
        w_init = initialization point
        bias = if truthy, a column of ones is prepended to the features, so
               w_init must already have one entry per resulting column
        num_iters = number of Newton iterations (default 8)
        reg = regularizer passed to MyLogisticRegression (default 1e-9)

        output is the model parameter: the final Newton iterate if its
        loss_wor is strictly smaller than that of w_init, otherwise w_init
    """
    X,y = dataset
    if bias:
        X = np.hstack((np.ones(shape=(np.shape(X)[0],1)), X))
    lr = MyLogisticRegression(X,y,reg=reg)
    w = w_init
    for _ in range(num_iters):
        direction = np.linalg.solve(lr.hess(w),lr.grad(w))
        step_size = backtracking_ls(lr,direction, w)
        w = w - step_size * direction
    # Keep the starting point when the iterations did not improve on it.
    # The strict "<" also rejects a NaN loss.
    if lr.loss_wor(w)<lr.loss_wor(w_init):
        return w
    return w_init



class DoubleNoiseMech:
    """Our Method: Double Noise Mechanism

    Every update (1) adds Gaussian noise to the gradient, (2) takes a
    Newton-like step using a curvature matrix whose small eigenvalues are
    regularized ('add' a multiple of the identity, or 'clip' them from
    below), and (3) adds a second Gaussian noise to the resulting iterate.
    """

    _SMOOTHNESS = 0.25  # smoothness parameter used in the sensitivity formulas

    def __init__(self,lr,type_reg='add',hyper_tuning=False,curvature_info='hessian'):
        """ Initializer of the double noise mechanism

            lr = an instance of MyLogisticRegression
            type_reg = minimum eigenvalue modification type, it can be either 'add' or 'clip'
            hyper_tuning = True if we want to tune the minimum eigenvalue for modification
            curvature_info = type of the second-order information, it can be either 'hessian' or 'ub'

            raises ValueError for any other type_reg or curvature_info, so a
            typo is reported here instead of surfacing later as a missing
            attribute or a None iterate
        """
        if type_reg not in ('add', 'clip'):
            raise ValueError("type_reg must be 'add' or 'clip', got %r" % (type_reg,))
        if curvature_info == 'hessian':
            curvature = lr.hess_wor
        elif curvature_info == 'ub':
            curvature = lr.upperbound_wor
        else:
            raise ValueError("curvature_info must be 'hessian' or 'ub', got %r" % (curvature_info,))
        self.type_reg = type_reg
        self.hyper_tuning = hyper_tuning
        self.curvature_info = curvature_info
        self.H = curvature  # callable returning the curvature matrix at a point

    def _direction_sensitivity(self,grad_norm,n_eff,min_eval):
        """Scale of the second noise before the privacy multiplier is applied.

            grad_norm = norm of the noisy gradient
            n_eff = (expected) number of samples behind the curvature matrix
            min_eval = minimum eigenvalue used for the modification

            The denominator differs between 'add' (+) and 'clip' (-).
        """
        m = self._SMOOTHNESS
        if self.type_reg == 'add':
            return grad_norm * m/(n_eff * min_eval**2 + m * min_eval)
        if self.type_reg == 'clip':
            return grad_norm * m/(n_eff * min_eval**2 - m * min_eval)
        raise ValueError("type_reg must be 'add' or 'clip', got %r" % (self.type_reg,))

    @staticmethod
    def _clipped_inverse(H,min_eval):
        """Inverse of H after raising its eigenvalues below min_eval up to min_eval.

            Built as (1/min_eval) * I plus a correction on the eigenvectors
            whose eigenvalue is already >= min_eval.
        """
        dim = np.shape(H)[0]
        evals,evecs = np.linalg.eigh(H)
        kept = evals[evals>=min_eval]
        num_eig = len(kept)
        if num_eig == 0:
            return 1/min_eval * np.eye(dim)
        # eigh returns eigenvalues in ascending order, so the kept ones are the last columns.
        top = evecs[:,-num_eig:]
        return np.dot(top * (1/kept - 1/min_eval * np.ones(num_eig)), top.T) + 1/min_eval * np.eye(dim)

    def hess_mod(self,w,min_eval):
        """Curvature matrix at w with its spectrum regularized by min_eval.

            w = current point
            min_eval = minimum eigenvalue used for the modification

            'add' returns H + min_eval * I; 'clip' raises every eigenvalue of
            H below min_eval up to min_eval. These are the matrices whose
            inverse update_rule applies to the noisy gradient.
        """
        H = self.H(w)
        if self.type_reg == 'add':
            return H + min_eval * np.eye(len(H))
        evals,evecs = np.linalg.eigh(H)
        return np.dot(evecs * np.maximum(evals,min_eval), evecs.T)

    def find_opt_reg_wop(self,w,lr,noisy_grad,rho_hess):
        """Implementation of fine tuning lambda without privacy.

            w = current point
            lr = logistic regression
            noisy_grad = noisy gradient
            rho_hess = privacy budget for hessian

            output is the optimal minimum eigenvalue

            Candidates grow geometrically (factor 1.25) up to 0.25. For each
            one the loss after the update is summed over 5 draws of the
            second noise, and the candidate with the smallest sum wins. If no
            candidate beats the initial threshold, the starting candidate is
            returned.
        """
        increase_factor = 1.25 # at each step we increase the clipping by increase_factor
        if self.type_reg == 'clip':
            lambda_cur = 0.25/lr.n + 1e-5 # starting parameter, the denominator has to be greater than zero
        else:
            lambda_cur = 5e-6  # starting parameter
        num_noise_sample = 5 # we want to estimate the expected value over the second noise
        grad_norm = np.linalg.norm(noisy_grad)
        noise_multiplier = math.sqrt(0.5/rho_hess)
        best_loss = 1e6 # a large dummy number
        lambda_star = lambda_cur  # fallback so a value is always returned
        while lambda_cur <= 0.25:
            H = self.hess_mod(w,lambda_cur)
            sens2 = self._direction_sensitivity(grad_norm,lr.n,lambda_cur)
            # The noiseless step is the same for every noise draw: solve once.
            w_next = w - np.linalg.solve(H,noisy_grad)
            loss_ = 0
            for _ in range(num_noise_sample):
                noise2 = np.random.normal(scale=sens2 * noise_multiplier, size=lr.d)
                loss_ = loss_ + lr.loss_wor(w_next + noise2)
            if loss_ < best_loss:
                best_loss = loss_
                lambda_star = lambda_cur
            lambda_cur = lambda_cur * increase_factor
        return lambda_star

    def update_rule(self,w,lr,i,iters,pb):
        """Implementation of the double noise mechanism update rule

            w = current iterate
            lr = an instance of MyLogisticRegression
            i = the index of current iterate (not used)
            iters = total number of iterations
            pb =  privacy budget and other info; reads "total", "grad_frac",
                  "trace_frac" and "trace_coeff"

            return the next iterate
        """
        total = pb["total"]
        grad_frac = pb["grad_frac"]
        frac_trace = pb["trace_frac"]  # fraction of privacy budget for estimating the trace
        trace_coeff = pb["trace_coeff"]
        rho1 = grad_frac * total / iters  # divide total privacy budget for gradient
        rho2 = (1-grad_frac) * total / iters  # divide total privacy budget for direction
        sc1 = (1/lr.n) * math.sqrt(0.5/rho1)
        noise1 = np.random.normal(scale=sc1,size=lr.d)
        noisy_grad = lr.grad(w)+noise1
        grad_norm = np.linalg.norm(noisy_grad)
        H = self.H(w)
        if self.hyper_tuning:
            min_eval = self.find_opt_reg_wop(w,lr,noisy_grad,rho2)
        else:
            # Choose the minimum eigenvalue from a noisy, non-negative trace estimate.
            trace_noise = np.random.normal(scale = (0.25/lr.n) * np.sqrt(0.5/(frac_trace*rho2)))
            noisy_trace = trace_coeff * max(np.trace(H) + trace_noise,0)
            min_eval = max((noisy_trace/((lr.n)**2 * (1-frac_trace)*rho2))**(1/3), 1/(lr.n))

        sens2 = self._direction_sensitivity(grad_norm,lr.n,min_eval)
        noise2 = np.random.normal(scale = sens2 * np.sqrt(0.5/((1-frac_trace)*rho2)),size=lr.d)
        if self.type_reg == 'add':
            return w - np.linalg.solve(H + min_eval * np.eye(lr.d) ,noisy_grad) + noise2
        H_modified_inv = self._clipped_inverse(H,min_eval)
        return w - (H_modified_inv @ noisy_grad) + noise2

    def update_rule_stochastic(self,w,lr,i,iters,pb):
        """Implementation of the stochastic double noise mechanism update rule

            w = current iterate
            lr = an instance of MyLogisticRegression
            i = the index of current iterate (not used)
            iters = total number of iterations (not used)
            pb =  other info; reads 'noise_multiplier_grad',
                  'noise_multiplier_hess', 'batchsize_grad', 'batchsize_hess'
                  and 'min_eval'

            return the next iterate

            The gradient and the curvature matrix are each computed on an
            independently sampled minibatch. Every sample is included with
            probability batchsize/n, so the realized batch sizes are random.
        """
        std_grad = pb['noise_multiplier_grad']
        std_hess = pb['noise_multiplier_hess']
        p1 = pb['batchsize_grad']/lr.n
        p2 = pb['batchsize_hess']/lr.n
        min_eval = pb['min_eval']
        #### minibatching gradient
        sample_vec = np.random.binomial(n=1, p=p1, size=lr.n) # bernoulli vector showing which samples are chosen for computing gradient
        batch_idx_grad = np.where(sample_vec == 1)[0] # index of batch for gradient
        batch_size_grad_t = len(batch_idx_grad) # realization of the batch size
        grad_minibatch = lr.grad_wor(w,batch_idx_grad)
        noise_g = np.random.normal(scale = std_grad, size=lr.d)
        grad_noisy = (batch_size_grad_t / (lr.n * p1)) * grad_minibatch + noise_g / (lr.n * p1)
        grad_norm = np.linalg.norm(grad_noisy)
        #### minibatching hessian
        sample_vec = np.random.binomial(n=1, p=p2, size=lr.n) # bernoulli vector showing which samples are chosen for computing hessian
        batch_idx_hess = np.where(sample_vec == 1)[0] # index of batch for hessian
        batch_size_hess_t = len(batch_idx_hess) # realization of the batch size
        H = (batch_size_hess_t)/(lr.n * p2) * self.H(w,batch_idx_hess)
        if self.type_reg == 'add':
            sens2 = self._direction_sensitivity(grad_norm,lr.n*p2,min_eval)
            noise2 = np.random.normal(scale = sens2 * std_hess, size = lr.d)
            return w - np.linalg.solve(H + min_eval * np.eye(len(H)) ,grad_noisy) + noise2
        # 'clip': the minimum eigenvalue is never allowed below 1/(expected hessian batch size).
        min_eval_c = max(min_eval , 1/((lr.n * p2)))
        sens2 = self._direction_sensitivity(grad_norm,lr.n*p2,min_eval_c)
        noise2 = np.random.normal(scale = sens2 * std_hess , size=lr.d)
        H_modified_inv = self._clipped_inverse(H,min_eval_c)
        return w - (H_modified_inv @ grad_noisy) + noise2
