import math
import torch


class ESGD(torch.optim.Optimizer):
    r"""
        Epoch-GDA for WCSC Min-Max Problems: for primal variables.

        Every call to :meth:`step` applies, for each parameter ``w`` with a gradient,

            ``w <- (1 - 1/gamma) * w - lr * clip(grad) + (1/gamma) * w_ref``

        where ``w_ref`` is a per-parameter reference point (state key ``'w_buffer'``).
        The iterates are summed in the state key ``'mv_w_buffer'``; after every
        ``epoch_steps`` calls the reference point is replaced by the average of the
        iterates of that epoch and the parameter is reset to that average.

        Args:
            params (iterable): iterable of parameters to optimize
            eta (float, optional): accepted for interface compatibility; not used by this class (default: ``1.0``)
            lr (float): learning rate (default: ``1e-3``)
            clip_value (float, optional): gradients are clamped element-wise to
                ``[-clip_value, clip_value]`` before the update (default: ``1.0``)
            weight_decay (float, optional): validated to be non-negative and stored in the
                parameter groups; not applied by :meth:`step` (default: ``0``)
            epoch_decay (float, optional): validated to be non-negative; when positive, the
                ``model_ref`` / ``model_acc`` buffers are allocated, but :meth:`step` does not
                use them (default: ``0``)
            betas (Tuple[float, float], optional): validated to lie in ``[0, 1)`` and stored in the
                parameter groups; not used by :meth:`step` (default: ``(0.9, 0.999)``)
            eps (float, optional): validated to be non-negative and stored in the parameter groups;
                not used by :meth:`step` (default: ``1e-8``)
            amsgrad (bool, optional): stored in the parameter groups; not used by :meth:`step` (default: ``False``)
            momentum (float, optional): stored in the parameter groups; not used by :meth:`step` (default: ``0``)
            nesterov (bool, optional): stored in the parameter groups; not used by :meth:`step` (default: ``False``)
            dampening (float, optional): stored in the parameter groups; not used by :meth:`step` (default: ``0``)
            verbose (bool, optional): stored as ``self.verbose``; not consulted by this class (default: ``False``)
            device (torch.device, optional): device used for the ``model_ref`` / ``model_acc`` buffers,
                e.g., 'cpu' or 'cuda'; picked automatically when omitted (default: ``None``)
            gamma (float, optional): the proximal term is weighted by ``1/gamma``; must be non-zero (default: ``1``)
            rand_init_wbuf (bool, optional): if ``True`` the reference point starts as ``N(0, 0.01^2)`` noise,
                otherwise as a copy of the current parameter value (default: ``True``)
            epoch_steps (int, optional): number of :meth:`step` calls per epoch; must be at least 1 (default: ``100``)

        Example:
            >>> optimizer = ESGD(model.parameters(), lr=0.1, gamma=500, epoch_steps=100)
            >>> optimizer.zero_grad()
            >>> loss_fn(model(input), target).backward()
            >>> optimizer.step()
    """

    def __init__(self,
                 params,
                 eta=1.0,
                 lr=1e-3,
                 clip_value=1.0,
                 weight_decay=0,
                 epoch_decay=0,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 amsgrad=False,
                 momentum=0,
                 nesterov=False,
                 dampening=0,
                 verbose=False,
                 device=None,
                 gamma=1,
                 rand_init_wbuf=True,
                 epoch_steps = 100,
                 **kwargs):

        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= epoch_decay:
            raise ValueError("Invalid epoch_decay value: {}".format(epoch_decay))
        # Both values are used as divisors in step(); reject them up front
        # instead of failing with a ZeroDivisionError in the middle of training.
        if epoch_steps < 1:
            raise ValueError("Invalid epoch_steps value: {}".format(epoch_steps))
        if gamma == 0:
            raise ValueError("Invalid gamma value: {}".format(gamma))

        self.params = list(params)  # support optimizing partial parameters of models

        self.lr = lr
        self.model_ref = self.__init_model_ref__(self.params) if epoch_decay > 0 else None
        self.model_acc = self.__init_model_acc__(self.params) if epoch_decay > 0 else None
        self.T = 0  # for epoch_decay
        self.steps = 0  # total optimization steps
        self.verbose = verbose  # print updates for lr/regularizer
        self.epoch_decay = epoch_decay
        self.gamma = gamma
        self.rand_init_wbuf = rand_init_wbuf
        self.epoch_steps = epoch_steps  # epoch-wise updates

        defaults = dict(lr=lr, betas=betas, eps=eps, momentum=momentum, nesterov=nesterov, dampening=dampening,
                        epoch_decay=epoch_decay, weight_decay=weight_decay, amsgrad=amsgrad,
                        clip_value=clip_value, model_ref=self.model_ref, model_acc=self.model_acc)
        super(ESGD, self).__init__(self.params, defaults)

    def __setstate__(self, state):
        r"""
        Restore the optimizer state and back-fill group options that may be missing.
        """
        super(ESGD, self).__setstate__(state)
        # This class has no ``mode`` attribute, so both options are back-filled
        # unconditionally rather than branching on it.
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('amsgrad', False)

    def __init_model_ref__(self, params):
        """Return one ``N(0, 0.01^2)`` tensor on ``self.device`` per non-``None`` entry of ``params``."""
        if not isinstance(params, list):
            params = list(params)
        return [torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device)
                for var in params if var is not None]

    def __init_model_acc__(self, params):
        """Return one float32 zero tensor on ``self.device`` per non-``None`` entry of ``params``."""
        if not isinstance(params, list):
            params = list(params)
        return [torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False)
                for var in params if var is not None]

    def _init_reference(self, p):
        """Build the initial reference point (``'w_buffer'``) for parameter ``p``."""
        if self.rand_init_wbuf:
            # Sampled on the CPU and then moved, so the buffer always matches
            # the parameter's own device and dtype.
            noise = torch.empty(p.shape).normal_(mean=0, std=0.01)
            return noise.to(device=p.device, dtype=p.dtype)
        # clone() is essential: detach() alone shares storage with ``p``, so the
        # in-place parameter update would silently modify the reference as well.
        return p.detach().clone()

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        prox_weight = 1.0 / self.gamma
        # An epoch ends once ``epoch_steps`` iterates have been accumulated.
        # ``self.steps`` counts completed calls, hence the ``+ 1``; testing
        # ``self.steps`` directly would fire on the very first call and replace
        # the weights by a single iterate divided by ``epoch_steps``.
        epoch_end = (self.steps + 1) % self.epoch_steps == 0

        for group in self.param_groups:
            lr = group['lr']
            clip_value = group['clip_value']

            for p in group['params']:
                if p.grad is None:
                    continue

                param_state = self.state[p]
                if 'w_buffer' not in param_state:
                    param_state['w_buffer'] = self._init_reference(p)  # last epoch weight
                    # The running sum starts at zero so that the epoch average
                    # covers exactly the iterates produced during the epoch.
                    param_state['mv_w_buffer'] = torch.zeros_like(p.detach())  # epoch cumulative weight
                ref = param_state['w_buffer']

                d_p = torch.clamp(p.grad, -clip_value, clip_value)
                # w <- (1 - 1/gamma) * w - lr * d_p + (1/gamma) * w_ref
                p.mul_(1 - prox_weight).add_(d_p, alpha=-lr)
                p.add_(ref, alpha=prox_weight)

                param_state['mv_w_buffer'].add_(p)
                if epoch_end:
                    # New reference point = average iterate of the finished
                    # epoch; the parameter restarts from that average.
                    param_state['w_buffer'] = param_state['mv_w_buffer'] / self.epoch_steps
                    param_state['mv_w_buffer'] = torch.zeros_like(p.detach())
                    p.copy_(param_state['w_buffer'])

        self.steps += 1
        self.T += 1
        return loss

    def update_lr(self, decay_factor=None):
        """Divide the learning rate of every parameter group by ``decay_factor``.

        Does nothing when ``decay_factor`` is ``None``. The printed message
        reports the learning rate of the first parameter group.
        """
        if decay_factor is not None:
            for group in self.param_groups:
                group['lr'] = group['lr'] / decay_factor  # for learning rate
            print('Reducing lr to %.5f @ T=%s!' % (self.param_groups[0]['lr'], self.steps))

