import os
import torch
from torch.utils.cpp_extension import load

# Directory of this file; the extension sources are resolved relative to it so
# the module works regardless of the current working directory.
parent_dir = os.path.dirname(os.path.abspath(__file__))
sources = ['cuda/adam_upd.cpp', 'cuda/adam_upd_kernel.cu']
# Compile and load the C++/CUDA update kernels when this module is imported;
# verbose=True echoes the build output.
adam_upd_cuda = load(
    verbose=True,
    name='adam_upd_cuda',
    sources=[os.path.join(parent_dir, rel_path) for rel_path in sources],
)


''' Extended Adam optimizer:
1. supports per-voxel learning rates;
2. masked update (ignores zero gradients), which speeds up training.
'''
class MaskedAdam(torch.optim.Optimizer):
    """Adam variant whose parameter updates run in the ``adam_upd_cuda`` extension.

    Compared to plain Adam it adds:

    1. Per-voxel learning rates: after ``set_pervoxel_lr`` is called, any
       parameter whose shape matches a stored scale tensor is updated with
       ``adam_upd_cuda.adam_upd_with_perlr`` using that tensor.
    2. Masked updates: param groups with ``skip_zero_grad=True`` are updated
       with ``adam_upd_cuda.masked_adam_upd`` (intended to ignore entries with
       zero gradient, which speeds up training).

    All remaining parameters are updated with ``adam_upd_cuda.adam_upd``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 skip_zero_grad=False):
        """Create the optimizer.

        Args:
            params: iterable of tensors or of param-group dicts, as accepted by
                ``torch.optim.Optimizer``. A group dict may override ``lr``,
                ``betas``, ``eps`` and ``skip_zero_grad``.
            lr: learning rate (must be >= 0).
            betas: ``(beta1, beta2)`` moving-average coefficients, each in [0, 1).
            eps: numerical-stability term (must be >= 0).
            skip_zero_grad: value used by groups that do not set the key
                themselves; when True the group uses the masked update kernel.

        Raises:
            ValueError: if a hyperparameter is out of range.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        # skip_zero_grad is part of the defaults so every param group has it;
        # previously a group without the key made step() raise KeyError.
        defaults = dict(lr=lr, betas=betas, eps=eps, skip_zero_grad=skip_zero_grad)
        # Optional learning-rate scale tensors installed by set_pervoxel_lr().
        self.per_lr = None
        self.f_per_lr = None
        super(MaskedAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore state on unpickling / deepcopy / ``load_state_dict``.

        ``torch.optim.Optimizer.__getstate__`` (in current PyTorch) serializes
        only ``defaults``, ``state`` and ``param_groups``, so the per-voxel lr
        attributes are not carried over. Default them to None when they are
        absent so ``step`` does not fail with AttributeError.
        """
        super(MaskedAdam, self).__setstate__(state)
        self.__dict__.setdefault('per_lr', None)
        self.__dict__.setdefault('f_per_lr', None)

    @staticmethod
    def _normalize_count(count):
        """Return ``count`` as float divided by its max; all zeros if max <= 0."""
        count = count.float()
        peak = count.max()
        if peak > 0:
            return count / peak
        # Avoid 0/0 -> NaN learning rates, which would poison the parameters.
        return torch.zeros_like(count)

    def set_pervoxel_lr(self, count, f_count=None):
        """Install per-element learning-rate scales derived from count tensors.

        Each count is converted to float and divided by its maximum, so the
        scales lie in [0, 1] for non-negative counts. A count whose maximum is
        not positive yields an all-zero scale instead of NaNs.

        Args:
            count: tensor with the same shape as the first parameter of the
                first param group; its scale applies to every parameter of
                that shape.
            f_count: optional second count tensor; its scale applies to
                parameters whose shape matches it but not ``count``'s shape.
                Passing None clears any previously installed second scale.
        """
        assert self.param_groups[0]['params'][0].shape == count.shape, \
            'count must have the same shape as the first parameter'
        self.per_lr = self._normalize_count(count)
        self.f_per_lr = None if f_count is None else self._normalize_count(f_count)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim.Optimizer.step`` contract).
                It is run with gradients enabled.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            # .get(): groups restored from a state dict saved without this key
            # fall back to the plain update instead of raising KeyError.
            skip_zero_grad = group.get('skip_zero_grad', False)
            for param in group['params']:
                if param.grad is None:
                    continue
                state = self.state[param]
                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

                state['step'] += 1

                # Kernel selection: per-voxel scales take precedence (per_lr
                # before f_per_lr), then the masked kernel, then plain Adam.
                if self.per_lr is not None and param.shape == self.per_lr.shape:
                    adam_upd_cuda.adam_upd_with_perlr(
                            param, param.grad, state['exp_avg'], state['exp_avg_sq'], self.per_lr,
                            state['step'], beta1, beta2, lr, eps)
                elif self.f_per_lr is not None and param.shape == self.f_per_lr.shape:
                    adam_upd_cuda.adam_upd_with_perlr(
                            param, param.grad, state['exp_avg'], state['exp_avg_sq'], self.f_per_lr,
                            state['step'], beta1, beta2, lr, eps)
                elif skip_zero_grad:
                    adam_upd_cuda.masked_adam_upd(
                            param, param.grad, state['exp_avg'], state['exp_avg_sq'],
                            state['step'], beta1, beta2, lr, eps)
                else:
                    adam_upd_cuda.adam_upd(
                            param, param.grad, state['exp_avg'], state['exp_avg_sq'],
                            state['step'], beta1, beta2, lr, eps)
        return loss

