import logging

_LOGGER = logging.getLogger(name='optim.learning_rate')

from util_py.arg_parsing  import ArgumentRegistry

from math import ceil

class AbsLearningRate(object):
    """Interface implemented by every learning-rate schedule in this module.

    A schedule provides:

    * ``get_name()`` -- the string used to select it via the ``type`` argument,
    * ``fill_args_registry(arg_reg)`` -- registration of its own arguments,
    * ``get_stepsize(algodata)`` -- the step size for the current state.
    """

    @staticmethod
    def get_name():
        """Return the name used to select this schedule (abstract)."""
        raise NotImplementedError("this is an abstract class defining the interface.")

    @staticmethod
    def fill_args_registry(arg_reg):
        """Register this schedule's arguments on *arg_reg* (abstract).

        The module's registry builder calls this on every schedule class, so it
        belongs to the interface alongside ``get_name``.
        """
        raise NotImplementedError("this is an abstract class defining the interface.")

    def get_stepsize(self, algodata):
        """Return the step size for the state described by *algodata* (abstract)."""
        raise NotImplementedError("this is an abstract class defining the interface.")

    def initialize(self):
        """Optional hook for subclasses; does nothing by default."""
        pass


class FixedStep(AbsLearningRate):
    """Schedule that returns the same step size at every iteration."""

    _name = 'fixed'

    # Registry argument name and default value for the constant step size.
    _argname_fixed, _argdefval_fixed = _name + '_size', .1

    @staticmethod
    def get_name():
        """Return the registry name of this schedule (``'fixed'``).

        Declared static (as in the base class) so it works both on the class
        and on an instance.
        """
        return FixedStep._name

    @staticmethod
    def fill_args_registry(arg_reg):
        """Register the constant step-size float argument on *arg_reg*."""
        arg_reg.register_float_arg(FixedStep._argname_fixed,
                                   'constant step size',
                                   FixedStep._argdefval_fixed)

    def __init__(self, arg_reg=None, arg_dict=None, siz=.1):
        """Configure the step size.

        If both *arg_reg* and *arg_dict* are given, the size is read from
        *arg_dict* under the registry's full name for ``fixed_size``.
        Otherwise *siz* is used directly.
        """
        if arg_dict is None or arg_reg is None:
            self.size = siz
        else:
            self.size = arg_dict[arg_reg.get_arg_fullname(FixedStep._argname_fixed)]

    def get_stepsize(self, algodata):
        """Return the constant step size; *algodata* is ignored."""
        return self.size
    
    
class PolyNumRate(AbsLearningRate):
    """Step size decaying polynomially in a strided iteration count.

    With ``n_itr`` the 0-based counter read from ``algodata.n_itr.value``,
    the step size is::

        min(max_val, ceil((n_itr + 1) / stride) ** (-expo))
    """

    _name = 'polynumerator'

    # Help-text rendering of the schedule. It is a raw string so the backslashes
    # stay literal (a plain literal turned "\rfloor" into a carriage return).
    # It spells out ceil((n_itr+1)/stride) to match get_stepsize.
    _formula = r'min( maxval, (\lceil (n_itr+1)/stride \rceil )^(-1.0 * expo) )'

    _argname_max, _argdefval_max = _name + '_max', 0.1
    _argname_stride, _argdefval_stride = _name + '_stride', 1.0
    _argname_exponent, _argdefval_exponent = _name + '_exponent', 1.0

    @staticmethod
    def get_name():
        """Return the registry name of this schedule (``'polynumerator'``)."""
        return PolyNumRate._name

    @staticmethod
    def fill_args_registry(arg_reg):
        """Register the ``maxval``, ``stride`` and ``expo`` float arguments."""
        for argname, symbol, default in (
                (PolyNumRate._argname_max, 'maxval', PolyNumRate._argdefval_max),
                (PolyNumRate._argname_stride, 'stride', PolyNumRate._argdefval_stride),
                (PolyNumRate._argname_exponent, 'expo', PolyNumRate._argdefval_exponent)):
            arg_reg.register_float_arg(
                argname,
                '"{}" in \n\t{}'.format(symbol, PolyNumRate._formula),
                default)

    def __init__(self, arg_reg=None, arg_dict=None, mx=0.1, strd=1.0, expo=1.0):
        """Configure the schedule.

        If both *arg_reg* and *arg_dict* are given, the parameters are read
        from *arg_dict* under the registry's full argument names. Otherwise
        *mx*, *strd* and *expo* are used directly.
        """
        if arg_dict is None or arg_reg is None:
            self.max_val = mx
            self.stride = strd
            self.expo = expo
        else:
            self.expo = arg_dict[arg_reg.get_arg_fullname(PolyNumRate._argname_exponent)]
            self.stride = arg_dict[arg_reg.get_arg_fullname(PolyNumRate._argname_stride)]
            self.max_val = arg_dict[arg_reg.get_arg_fullname(PolyNumRate._argname_max)]

    def get_stepsize(self, algodata):
        """Return the step size for the iteration in ``algodata.n_itr.value``."""
        # Read the counter once so the computation and the debug log agree.
        itr = algodata.n_itr.value + 1
        # float() keeps true division even for an integer stride under
        # Python 2, where int / int would truncate before ceil is applied.
        ceil_itr = ceil(itr / float(self.stride))
        retval = min(self.max_val, ceil_itr ** (-1.0 * self.expo))

        if _LOGGER.isEnabledFor(logging.DEBUG):
            _LOGGER.debug("itr {:5d} ret {:6.4e} = (ceil ({:5d}/{:6.4e}))**(-1*{:5.1f})".format(
                itr, retval, itr, self.stride, self.expo))

        return retval

class PolyDenRate(AbsLearningRate):

    _name= 'polydenominator'
    def get_name():
        return PolyDenRate._name
    
    _formula = 'initial * BigNum/(BigNum + n_itr) '

    _argname_bignum , _argdefval_bignum    = _name+'_bignum' , 100.0   
    _argname_initial, _argdefval_initial   = _name+'_initial' , 0.1   

    def fill_args_registry(arg_reg):
        arg_reg.register_float_arg(  PolyDenRate._argname_bignum, 
                                   '"BigNum" in \n\t{}'.format(PolyDenRate._formula), 
                                   PolyDenRate._argdefval_bignum)        
        arg_reg.register_float_arg(  PolyDenRate._argname_initial, 
                                   '"initial" in \n\t{}'.format(PolyDenRate._formula), 
                                   PolyDenRate._argdefval_initial)        
        

    def __init__(self,arg_reg=None, arg_dict=None, init=.1, bign=100.0):
        if (arg_dict is None ) or (arg_reg is None):
            self.bignum = bign
            self.initial = init
        else:
            self.bignum  = arg_dict[arg_reg.get_arg_fullname(PolyDenRate._argname_bignum)]
            self.initial = arg_dict[arg_reg.get_arg_fullname(PolyDenRate._argname_initial)]
      

    def get_stepsize(self, algodata):
        retval=  self.initial * (self.bignum / (self.bignum + algodata.n_itr.value + 1 )) 
        
        if _LOGGER.isEnabledFor(logging.DEBUG):
            _LOGGER.debug("itr {:5d} ret {:6.4e} = {:6.4e} * ( {:5.3e}/({:5.3e} + {:5d}))".format(
                algodata.n_itr.value+1, retval, self.initial, self.bignum,self.bignum, 
                algodata.n_itr.value+1))

        return retval
        



# Base name passed to ArgumentRegistry for all learning-rate arguments.
_arg_reg_base = 'learningrate'
# Argument that selects which schedule class is instantiated.
_argname_type = 'type'
# Built-in schedule classes, in registration order. The first entry's name
# is used as the default value of the selector argument.
LRClassList = [
    PolyNumRate,
    PolyDenRate,
    FixedStep,
]

    
def get_learning_rate_arg_registry(extra_classes=None):
    """Build an ArgumentRegistry covering every learning-rate schedule.

    Args:
        extra_classes: optional iterable of additional schedule classes. Each
            must provide ``get_name()`` and ``fill_args_registry(arg_reg)``,
            like the built-in classes in ``LRClassList``.

    Returns:
        An ``ArgumentRegistry`` named ``_arg_reg_base``. It holds a string
        ``type`` argument whose choices are the names of all classes (default:
        the first built-in), plus each class's own arguments.
    """
    # Materialize once: iterating ``extra_classes`` twice would silently skip
    # the argument registration of a one-shot iterable such as a generator.
    all_classes = list(LRClassList)
    if extra_classes is not None:
        all_classes.extend(extra_classes)

    type_choices = [cl.get_name() for cl in all_classes]

    arg_reg = ArgumentRegistry(_arg_reg_base)
    arg_reg.register_str_arg(_argname_type,
                             'which learning rate schedule to use',
                             type_choices[0],
                             type_choices)

    for cl in all_classes:
        cl.fill_args_registry(arg_reg)

    return arg_reg


def instantiate_learning_rate(arg_dict, addlcls=None):
    """Create the learning-rate schedule selected in *arg_dict*.

    Args:
        arg_dict: mapping from full argument names (as returned by the
            registry's ``get_arg_fullname``) to values.
        addlcls: optional iterable of extra schedule classes to consider in
            addition to the built-ins in ``LRClassList``.

    Returns:
        An instance of the first class whose ``get_name()`` equals the
        selected type, constructed as ``cls(arg_registry, arg_dict)``.

    Raises:
        ValueError: if no known class matches the selected type name.
    """
    # Materialize once: the iterable is used both to build the registry and
    # for the lookup below, and a generator would be exhausted by the first use.
    if addlcls is not None:
        addlcls = list(addlcls)

    arg_registry = get_learning_rate_arg_registry(addlcls)
    lrnm = arg_dict[arg_registry.get_arg_fullname(_argname_type)]

    # Built-ins are checked before the additional classes.
    for cls in list(LRClassList) + (addlcls or []):
        if lrnm == cls.get_name():
            return cls(arg_registry, arg_dict)

    raise ValueError('have not implemented learning rate \'{}\' yet.'.format(lrnm))
