import logging

# Module-level logger for this file, registered under the name 'optim.momentum'.
_LOGGER = logging.getLogger('optim.momentum')

from util_py.arg_parsing  import ArgumentRegistry

    
class AbsMomentum(object):
    """Interface shared by the momentum strategies.

    An instance keeps a reference to the objective, a scratch buffer obtained
    from ``objective.get_blank_of_itersize()`` and a sign multiplier that is
    -1.0 when minimizing and +1.0 when maximizing.
    """

    @staticmethod
    def get_name():
        """Return the identifier of the strategy; concrete classes override this."""
        raise NotImplementedError("this is an abstract class defining the interface.")

    def __init__(self, objective, isMx=False):
        """Store the objective, set the search direction sign and allocate a buffer.

        :param objective: object providing ``get_blank_of_itersize()``.
        :param isMx: true when the objective is to be maximized.
        """
        self.objective = objective
        self.set_minimization(not isMx)
        self.grad_dirn = self.objective.get_blank_of_itersize()

    def set_minimization(self, ismn):
        """Record whether we minimize and derive the matching sign multiplier."""
        self.is_minimization = ismn
        self.dirn_mult = -1.0 if self.is_minimization else 1.0

    def take_step(self, stepsize):
        """Perform one update of the objective; concrete classes override this."""
        raise NotImplementedError("this is an abstract class defining the interface.")

    def initialize(self):
        """Reset internal state; nothing to reset in the base class."""
        pass

    
class NoMomentum(AbsMomentum):
    """Strategy without momentum: each step follows the current gradient only."""

    _name = 'none'

    @staticmethod
    def get_name():
        """Return the identifier of this strategy ('none')."""
        return NoMomentum._name

    def __init__(self, objective, ismx=False, arg_reg=None, arg_dict=None):
        """Create the strategy.

        :param objective: object providing the gradient and step operations.
        :param ismx: true when the objective is to be maximized.
        :param arg_reg: accepted and ignored, so this class can be constructed
            with the same four positional arguments as the other strategies.
        :param arg_dict: accepted and ignored, for the same reason.
        """
        super(NoMomentum, self).__init__(objective, ismx)

    @staticmethod
    def fill_args_registry(arg_reg):
        """Register this strategy's arguments; it has none."""
        pass

    def take_step(self, stepsize):
        """Move the objective by ``stepsize`` along the signed gradient.

        :param stepsize: multiplier handed to ``objective.add_step_along_dirn``.
        :returns: tuple ``(direction buffer, stepsize)``.
        """
        # Fill the scratch buffer with the gradient scaled by the +/-1 multiplier.
        self.objective.copy_gradient_to(self.grad_dirn, self.dirn_mult)

        self.objective.add_step_along_dirn(stepsize, self.grad_dirn)

        return self.grad_dirn, stepsize
    
    
class PlainMomentum(AbsMomentum):
    """Momentum strategy with a velocity factor and a dampening factor.

    Each step updates ``momentum = velocity * momentum + scale * gradient`` and
    moves the objective by the resulting momentum, where ``scale`` is the
    stepsize, additionally multiplied by ``(1 - dampfactor)`` on every
    iteration except the first one after ``initialize()``.
    """

    _name = 'plain'

    @staticmethod
    def get_name():
        """Return the identifier of this strategy ('plain')."""
        return PlainMomentum._name

    _argname_velo, _argdefval_velo = '{}_velocity'.format(_name), 0.90
    _argname_damp, _argdefval_damp = '{}_dampening'.format(_name), 0.50

    @staticmethod
    def fill_args_registry(arg_reg):
        """Register the velocity and dampening float arguments on ``arg_reg``."""
        momformla ='(mom = velo * mom + (1-damp) * grad )'

        arg_reg.register_float_arg(PlainMomentum._argname_velo,
                                   'velocity "velo" in '+momformla,
                                   PlainMomentum._argdefval_velo)        #@fixme!
        arg_reg.register_float_arg(PlainMomentum._argname_damp,
                                   'dampening factor "damp" in '+momformla,
                                   PlainMomentum._argdefval_damp )        #@fixme!

    def __init__(self, objective, ismx=False, arg_reg=None, arg_dict=None,vel=0.9, damp=0.5):
        """Create the strategy.

        :param objective: object providing the gradient and step operations.
        :param ismx: true when the objective is to be maximized.
        :param arg_reg: registry used to resolve full argument names, or None.
        :param arg_dict: mapping from full argument names to values, or None.
        :param vel: velocity used when ``arg_reg`` or ``arg_dict`` is None.
        :param damp: dampening used when ``arg_reg`` or ``arg_dict`` is None.
        """
        super(PlainMomentum, self).__init__(objective, ismx)

        self.momentum = self.objective.get_blank_of_itersize()

        if (arg_dict is None) or (arg_reg is None):
            self.velocity = vel
            self.dampfactor = damp
        else:
            # Look the names up on the concrete class so that a subclass which
            # overrides _argname_velo/_argname_damp reads its own arguments.
            cls = type(self)
            self.velocity = arg_dict[arg_reg.get_arg_fullname(cls._argname_velo)]
            self.dampfactor = arg_dict[arg_reg.get_arg_fullname(cls._argname_damp)]

        self.initialize()

    def initialize(self):
        """Reset the iteration counter and zero the momentum buffer."""
        self.itcnt = 0
        self.momentum = self.momentum * 0.0

    def take_step(self, stepsize):
        """Update the momentum from the current gradient and apply it.

        :param stepsize: scale applied to the gradient before it is added to
            the momentum.
        :returns: tuple ``(momentum, 1.0)``; the stepsize is already folded
            into the momentum, so the applied multiplier is 1.0.
        """
        self.objective.copy_gradient_to(self.grad_dirn, self.dirn_mult)

        # The sign is already carried by dirn_mult; here we only build the
        # magnitude of the factor applied to the gradient.
        grad_scale = stepsize
        # Only the initial iteration (itcnt == 0) leaves the gradient
        # undampened; every later iteration applies (1 - dampfactor).
        if self.itcnt > 0:
            grad_scale *= (1.0 - self.dampfactor)

        self.momentum = self.momentum * self.velocity + grad_scale * self.grad_dirn

        self.objective.add_step_along_dirn(1.0, self.momentum)
        self.itcnt += 1

        return self.momentum, 1.0
    
    
# Reference only (commented out): momentum update that appears to be excerpted
# from a PyTorch-style SGD step, kept for comparison with the classes here.
#
#    if weight_decay != 0:
#        d_p.add_(weight_decay, p.data)
#    if momentum != 0:
#        param_state = self.state[p]
#        if 'momentum_buffer' not in param_state:
#            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
#            buf.mul_(momentum).add_(d_p)
#        else:
#            buf = param_state['momentum_buffer']
#            buf.mul_(momentum).add_(1 - dampening, d_p)
#        if nesterov:
#            d_p = d_p.add(momentum, buf)
#        else:
#            d_p = buf


# @Fixme !!! this doesn't work with the stock Rosenbrock problem!!
class NesterovMomentum(PlainMomentum):
    """Momentum strategy that backs off along the previous momentum, updates
    the momentum from the gradient, and then takes a longer step along the
    new momentum.
    """

    _name = 'nesterov'

    @staticmethod
    def get_name():
        """Return the identifier of this strategy ('nesterov')."""
        return NesterovMomentum._name

    _argname_velo, _argdefval_velo = '{}_velocity'.format(_name), 0.90
    _argname_damp, _argdefval_damp = '{}_dampening'.format(_name), 0.50

    @staticmethod
    def fill_args_registry(arg_reg):
        """Register the velocity and dampening float arguments on ``arg_reg``."""
        momformla ='(mom = velo * mom + (1-damp) * grad )'

        arg_reg.register_float_arg(NesterovMomentum._argname_velo,
                                   'velocity "velo" in '+momformla,
                                   NesterovMomentum._argdefval_velo)        #@fixme!
        arg_reg.register_float_arg(NesterovMomentum._argname_damp,
                                   'dampening factor "damp" in '+momformla,
                                   NesterovMomentum._argdefval_damp )        #@fixme!

    def __init__(self, objective, ismx=False, arg_reg=None, arg_dict=None,vel=0.9, damp=0.5):
        """Create the strategy.

        :param objective: object providing the gradient and step operations.
        :param ismx: true when the objective is to be maximized.
        :param arg_reg: registry used to resolve full argument names, or None.
        :param arg_dict: mapping from full argument names to values, or None.
        :param vel: velocity used when ``arg_reg`` or ``arg_dict`` is None.
        :param damp: dampening used when ``arg_reg`` or ``arg_dict`` is None.
        """
        # Resolve this class's own arguments ('nesterov_*') here and hand the
        # plain values to the parent, so the parent does not read the
        # arguments registered for another strategy.
        if (arg_reg is not None) and (arg_dict is not None):
            vel = arg_dict[arg_reg.get_arg_fullname(NesterovMomentum._argname_velo)]
            damp = arg_dict[arg_reg.get_arg_fullname(NesterovMomentum._argname_damp)]

        super(NesterovMomentum, self).__init__(objective, ismx, None, None, vel, damp)

    def take_step(self, stepsize):
        """Back off, refresh the momentum from the gradient, and step forward.

        :param stepsize: scale applied to the gradient before it is added to
            the momentum.
        :returns: tuple ``(momentum, 1.0 + velocity)``, the direction and the
            multiplier used for the forward step.
        """
        # we take a reverse step along last time's momentum first
        self.objective.add_step_along_dirn((-1 * self.velocity), self.momentum)

        # NOTE(review): whether the gradient copied here reflects the point
        # before or after the reverse step depends on the objective - confirm.
        self.objective.copy_gradient_to(self.grad_dirn, self.dirn_mult)

        # The sign is already carried by dirn_mult; here we only build the
        # magnitude of the factor applied to the gradient.
        grad_scale = stepsize
        # Only the initial iteration (itcnt == 0) leaves the gradient
        # undampened; every later iteration applies (1 - dampfactor).
        if self.itcnt > 0:
            grad_scale *= (1.0 - self.dampfactor)

        self.momentum = self.momentum * self.velocity + grad_scale * self.grad_dirn

        # now we take a larger step along new momentum dirn, nearly canceling
        # previous backstep.
        self.objective.add_step_along_dirn(1.0 + self.velocity, self.momentum)

        self.itcnt += 1

        return self.momentum, 1.0 + self.velocity
    
    

# Built-in momentum strategies, in the order they are searched by name; the
# name of the first entry is used as the default value of the 'type' argument.
MomClassList = [NoMomentum, PlainMomentum, NesterovMomentum]
# Base name handed to ArgumentRegistry when the momentum registry is created.
_arg_reg_base = 'momentum'
# Name of the string argument that selects which momentum class to use.
_argname_type='type'
    
def get_momentum_arg_registry(extra_classes=None) :
    """Build the ArgumentRegistry holding every momentum-related argument.

    Registers the string argument that selects the momentum type (defaulting
    to the name of the first built-in class) and then lets every class add
    its own arguments through ``fill_args_registry``.

    :param extra_classes: optional iterable of additional momentum classes,
        each providing ``get_name()`` and ``fill_args_registry(arg_reg)``.
    :returns: the populated ArgumentRegistry.
    """
    # Materialise the class list once: the classes are visited twice below,
    # which would silently skip a one-shot iterable on the second pass.
    all_classes = list(MomClassList)
    if extra_classes is not None:
        all_classes.extend(extra_classes)

    type_names = [cls.get_name() for cls in all_classes]

    arg_reg = ArgumentRegistry(_arg_reg_base)

    # The full list of names is passed as the last argument (presumably the
    # set of accepted values - confirm against ArgumentRegistry).
    arg_reg.register_str_arg(_argname_type,
                             'which momentum type to use',
                             type_names[0],
                             type_names)

    for cls in all_classes:
        cls.fill_args_registry(arg_reg)

    return arg_reg


def instantiate_momentum(objective, ismx, arg_dict, addlcls=None):
    """Create the momentum strategy selected in ``arg_dict``.

    The selected name is read from ``arg_dict`` under the full name of the
    'type' argument and compared against the built-in classes first, then
    against ``addlcls``. The matching class is constructed as
    ``cls(objective, ismx, arg_registry, arg_dict)``.

    :param objective: objective handed to the momentum class.
    :param ismx: true when the objective is to be maximized.
    :param arg_dict: mapping from full argument names to values.
    :param addlcls: optional iterable of additional momentum classes.
    :returns: an instance of the selected momentum class.
    :raises ValueError: if no class reports the selected name.
    """
    # Materialise once: the extra classes are needed both for building the
    # registry and for the lookup below, so a one-shot iterable would
    # otherwise be exhausted before the lookup.
    extras = list(addlcls) if addlcls is not None else None

    arg_registry = get_momentum_arg_registry(extras)

    # read in the args
    momnm = arg_dict[arg_registry.get_arg_fullname(_argname_type)]

    candidates = list(MomClassList)
    if extras is not None:
        candidates.extend(extras)

    for cls in candidates:
        if momnm == cls.get_name():
            return cls(objective, ismx, arg_registry, arg_dict)

    raise ValueError('have not implemented momenturm of  type \'{}\' yet.'.format(momnm))
        
