import logging, sys
from util_py.arg_parsing import ArgumentRegistry
from util_py.configure_logging import getname_logging_iter_stride

_LOGGER = logging.getLogger(name='optim.line_search')
 
import numpy as np
import matplotlib
matplotlib.rcParams['text.usetex'] = True
#matplotlib.rcParams['text.latex.unicode'] = True
import matplotlib.pyplot as plt


class AbsLineSearch(object):
    r'''
    Interface shared by every line-search strategy in this module.

    The constructor stores the objective wrapper and a sign multiplier
    ``objmult`` (+1.0 when minimizing, -1.0 when maximizing) so subclasses
    can treat every problem as a minimization of ``objmult * f``.
    Subclasses are expected to override ``get_name`` and ``find_steplength``.
    '''

    def __init__(self, obj, ismax=False):
        self.objective = obj
        self.is_min = not ismax
        # sign applied to objective values so maximization looks like minimization
        self.objmult = 1.0 if self.is_min else -1.0

    @staticmethod
    def get_name():
        '''Return the registry name of the line search (abstract here).'''
        raise NotImplementedError("this is an abstract class defining the interface.")

    def initialize(self):
        '''Optional setup hook; the base implementation does nothing.'''
        return None

    def find_steplength(self, dirn):
        '''Return a step length along ``dirn`` (abstract here).'''
        raise NotImplementedError("this is an abstract class defining the interface.")


class ArmijoBacktrackingLineSearch(AbsLineSearch):
    r'''
     This (simple) Armijo backtracking rule only checks for this condition:

          phi_k <= phi_0 + beta_1 * alpha_k * phi_0^'

     where  phi_0 = f(w_t) ,   phi_k = f( w_t + alpha_k * dirn ),
        phi_0^' is the derivative of phi at 0, i.e. = (grad * (dirn))
        and dirn could be grad or newton_dir = Hessian^{-1} grad

     For MAXIMIZATION, phi_0, phi_k and phi_0^' are all multiplied by (-1.0)
     before the same "<=" inequality is checked; this assumes the user is
     passing the CORRECT direction for problem improvement.

     This rule corresponds to checking that we get a function value reduction that is
     at least a beta_1-factor of the reduction predicted by the first-order (i.e. linear)
     Taylor expansion of the loss function f around the current iterate w_t
     along the direction dirn.

     For a descent direction this procedure stops after a finite number of iterations
     because, as per our assumptions, the loss function is smooth enough for the Taylor
     linear approx to be accurate very close to w_t, so when alpha_t gets sufficiently
     close to 0, we will exit this rule. Independently of that, the search gives up once
     the step drops to alpha_min or below. The alpha_t is reduced as
          alpha_t = alpha_0 * (beta_2)^k
     where k represents the number of times we have checked the rule and both beta_1 and
     beta_2 are factors in the interval (0 , 1). Typically beta1 is 0.01 and beta2 is set
     to .5 (the defaults here are beta1 = 0.1 and beta2 = 0.5), but if you want to exit
     quickly, then maybe beta_2 = .25. This will lead to smaller steps.

    This class registers four float args in the line-search ArgumentRegistry
    (see fill_args_registry): 'armijo_beta1', 'armijo_beta2', 'armijo_alpha0'
    and 'armijo_minalpha' ('minalpha' is alpha_min above). The full
    command-line name of each arg is whatever the registry's
    get_arg_fullname() produces for it.
    '''

    _name = 'armijo'

    @staticmethod
    def get_name():
        '''Return the registry name of this line search ('armijo').'''
        return ArmijoBacktrackingLineSearch._name

    # (registered arg name, default value) pairs; names are prefixed with _name
    argname_beta1, argdefval_beta1 = '{}_{}'.format(_name, 'beta1'), 1e-1
    argname_beta2, argdefval_beta2 = '{}_{}'.format(_name, 'beta2'), 5e-1
    argname_alpha0, argdefval_alpha0 = '{}_{}'.format(_name, 'alpha0'), 1e-0
    argname_minalpha, argdefval_minalpha = '{}_{}'.format(_name, 'minalpha'), 1e-10

    @staticmethod
    def fill_args_registry(arg_reg):
        r'''
        Register the four float args of this rule (beta1, beta2, alpha0,
        minalpha) in `arg_reg`, using the class-level names and defaults.
        '''
        # phi'(0) is the slope at the starting point, hence f'_0
        fmla = 'f_k \\le f_0 + beta1 * alpha0 * beta2^k * f\'_0 '
        cls = ArmijoBacktrackingLineSearch
        arg_specs = (
            (cls.argname_beta1, 'beta1', cls.argdefval_beta1),
            (cls.argname_beta2, 'beta2', cls.argdefval_beta2),
            (cls.argname_alpha0, 'alpha0', cls.argdefval_alpha0),
            (cls.argname_minalpha, 'min-alpha', cls.argdefval_minalpha),
        )
        for argname, label, defval in arg_specs:
            arg_reg.register_float_arg(
                    argname,
                    'value of {} in armijo backtracking rule {} '.format(label, fmla),
                    defval)

    def __init__(self, obj, ismax=False, args_dict=None, bt1=1e-1, bt2=5e-1, a0=1., amn=1e-10, lgstr=1):
        r'''
        Parameters
        ----------
        obj : objective wrapper; this class calls its get_blank_of_itersize(),
              evaluate_fn_and_derivatives(), copy_iterate_to(), set_iterate_from(),
              get_gradient_dot() and add_step_along_dirn().
        ismax : True if the objective is to be maximized.
        args_dict : if given, beta1/beta2/alpha0/minalpha and the logging stride are
              read from it (keys resolved through the line-search ArgumentRegistry)
              and bt1/bt2/a0/amn/lgstr are ignored.
        bt1, bt2, a0, amn : beta1, beta2, alpha0 and alpha_min when args_dict is None.
        lgstr : log only on every lgstr-th call of find_steplength.

        Raises
        ------
        ValueError : if beta2 is not in (0, 1) or alpha_min is not positive; the
              backtracking loop relies on both to terminate.
        '''
        super().__init__(obj, ismax)

        # we store the current iterate at the beginning of every linesearch here
        self._init_param = self.objective.get_blank_of_itersize()

        if args_dict is not None:
            # look args up under the exact names registered by fill_args_registry
            arg_reg = get_linesearch_args_registry()
            self.beta1 = args_dict[arg_reg.get_arg_fullname(self.argname_beta1)]
            self.beta2 = args_dict[arg_reg.get_arg_fullname(self.argname_beta2)]
            self.alpha0 = args_dict[arg_reg.get_arg_fullname(self.argname_alpha0)]
            self.alpha_mn = args_dict[arg_reg.get_arg_fullname(self.argname_minalpha)]

            # log out the args description
            self.log_stride = args_dict[getname_logging_iter_stride()]
            arg_reg.log_args_description(args_dict, _LOGGER)
        else:
            self.beta1 = bt1
            self.beta2 = bt2
            self.alpha0 = a0
            self.alpha_mn = amn
            self.log_stride = lgstr

        # beta2 >= 1 never shrinks alpha, and a non-positive alpha_min can never
        # be reached by a positive, shrinking alpha: both would loop forever
        if not 0.0 < self.beta2 < 1.0:
            raise ValueError('armijo beta2 must lie in (0, 1), got {}'.format(self.beta2))
        if not self.alpha_mn > 0.0:
            raise ValueError('armijo alpha_min must be positive, got {}'.format(self.alpha_mn))

        self.call_count = 0

        # debug-only: store (alpha_k, phi_k, rhs) of every trial step for plotting
        self.we_plot = False
        if self.we_plot:
            self.plot_values = np.zeros((self._count_trial_steps(), 3))

    def _count_trial_steps(self):
        '''Return how many step sizes find_steplength evaluates at most.'''
        # alpha0 is always tried; then alpha0 * beta2^k for as long as it exceeds
        # alpha_mn (same multiplication sequence as in find_steplength)
        n_steps, alpha = 1, self.alpha0 * self.beta2
        while alpha > self.alpha_mn:
            n_steps += 1
            alpha *= self.beta2
        return n_steps

    def _should_log(self, level):
        '''True if `level` is enabled and this call falls on the logging stride.'''
        return _LOGGER.isEnabledFor(level) and (self.call_count % self.log_stride) == 0

    def find_steplength(self, dirn):
        r'''
        Backtrack along 'dirn' from the current iterate until the Armijo rule holds.

        The direction is ALWAYS the one that we wish to follow to make progress on the
        optimization, be it MINIMIZE or MAXIMIZE. So, the user needs to multiply with
        +1/-1 as needed. The objective's iterate is restored after every trial step,
        so this method does not move the iterate.

        Returns
        -------
        (alpha_k, ls_func_evals) : the step size and the number of objective
            evaluations made (including the one at the starting point). If alpha_k
            drops to alpha_min or below before the rule holds, a warning is logged
            and that last, untested alpha_k is returned.
        '''
        phi_0 = self.objective.evaluate_fn_and_derivatives(True) * self.objmult

        # keep copy for resets
        self.objective.copy_iterate_to(self._init_param)

        graddirn = self.objective.get_gradient_dot(dirn)
        phi_0_prime = graddirn * self.objmult
        # dirn improves the (sign-adjusted) objective only if phi_0_prime < 0, i.e.
        # graddirn must be NEGATIVE for minimization and POSITIVE for maximization.
        if phi_0_prime > 0:
            mdir, badr = 'minimization', 'POSITIVE'
            if not self.is_min:
                mdir, badr = 'maximization', 'NEGATIVE'

            wmsg = "LS : {} is given an mvmt direction that has {}".format(mdir, badr)
            wmsg += " dot product {} with gradient.".format(graddirn)
            _LOGGER.warning(wmsg)

        alpha_k = self.alpha0
        ls_func_evals = 1  # the evaluation of phi_0 above

        if self._should_log(logging.INFO):
            _LOGGER.info("LS --    status  : {:15s} <= {:15s}  = {:12s} + {:12s} * {:9s} * {:12s} ".
                         format("phi_k", "RHS", "p0", "beta1", "stpsz", "ph0_prm"))

        avcnt = 0
        while True:
            # next candidate backtrack point
            self.objective.add_step_along_dirn(alpha_k, dirn)
            phi_k = self.objective.evaluate_fn_and_derivatives(True) * self.objmult

            # we restore the original iterate, else we will be double stepping!
            self.objective.set_iterate_from(self._init_param)
            ls_func_evals += 1

            # the armijo rule itself
            rule_rhs = phi_0 + self.beta1 * alpha_k * phi_0_prime

            if self.we_plot:
                try:
                    self.plot_values[avcnt] = [alpha_k, phi_k, rule_rhs]
                except IndexError as e:
                    print('avcnt is {} in len {} alpha_k {} alpha_mn {}, error: {}'.format(
                          avcnt, len(self.plot_values), alpha_k, self.alpha_mn, e), file=sys.stderr)
                    raise
                avcnt += 1

            if phi_k <= rule_rhs:
                if self._should_log(logging.INFO):
                    _LOGGER.info("LS --    success : {:15.12f} <= {:15.12f} = {:12.9f} + {:12.9f} * {:9.2e} * {:12.9f} ".
                                 format(phi_k, rule_rhs, phi_0, self.beta1, alpha_k, phi_0_prime))
                break

            if self._should_log(logging.DEBUG):
                _LOGGER.debug("LS -- keeplookin : {:15.12f} >  {:15.12f} = {:12.9f} + {:12.9f} * {:9.2e} * {:12.9f} ".
                              format(phi_k, rule_rhs, phi_0, self.beta1, alpha_k, phi_0_prime))
            alpha_k *= self.beta2

            if alpha_k <= self.alpha_mn:
                if _LOGGER.isEnabledFor(logging.WARNING):
                    _LOGGER.warning("LS -- toosmall  : {:15.12f} > {:15.12f} = {:12.9f} + {:12.9f} * {:9.2e} * {:12.9f} ".
                                    format(phi_k, rule_rhs, phi_0, self.beta1, alpha_k, phi_0_prime))

                if self.we_plot:  # DEBUGMODE: plot phi_k and the rule's RHS, then stop
                    _fig, ax1 = plt.subplots(nrows=1, ncols=1)
                    ax1.plot(self.plot_values[:, 0], self.plot_values[:, 1], color='black')
                    ax1.plot(self.plot_values[:, 0], self.plot_values[:, 2], color='blue')
                    ax1.set_xscale('log')
                    _LOGGER.warning("Stopping now.")
                    sys.exit()
                break

        self.call_count += 1
        return (alpha_k, ls_func_evals)


# Built-in line-search classes; the first entry is the default 'type' choice.
_all_ls_classes = [
    ArmijoBacktrackingLineSearch,
]
# Base name of the ArgumentRegistry that holds every line-search arg.
_arg_reg_base = 'linesearch'
# Name of the string arg selecting which line-search class to build.
_argname_type = 'type'

def get_linesearch_args_registry(extra_classes=None):
    r'''
    Build the ArgumentRegistry holding all line-search arguments.

    The registry (named by _arg_reg_base) gets a string 'type' arg whose
    allowed values are the get_name() of every built-in class followed by
    those of `extra_classes`, defaulting to the first built-in one. Each
    class then registers its own args through fill_args_registry().

    extra_classes : optional iterable of additional line-search classes.
    Returns the populated ArgumentRegistry.
    '''
    builtin_names = [ls_cls.get_name() for ls_cls in _all_ls_classes]
    extra_names = []
    if extra_classes is not None:
        extra_names = [ls_cls.get_name() for ls_cls in extra_classes]
    type_choices = builtin_names + extra_names

    registry = ArgumentRegistry(_arg_reg_base)
    registry.register_str_arg(_argname_type,
                              'which line search type to use',
                              type_choices[0],
                              type_choices)

    # built-ins first, then the caller's extras (traversed a second time here)
    class_groups = (_all_ls_classes, () if extra_classes is None else extra_classes)
    for group in class_groups:
        for ls_cls in group:
            ls_cls.fill_args_registry(registry)

    return registry

def instantiate_linesearch(arg_dict, addlcls=None, obj=None, ismax=False):
    r'''
    Create the line search selected by the 'type' arg of the line-search registry.

    Parameters
    ----------
    arg_dict : parsed args; must contain the full name of the 'type' arg and
        whatever args the selected class reads from it.
    addlcls : optional iterable of additional line-search classes to choose from.
    obj : objective handed to the line-search constructor. It must be supplied
        for ArmijoBacktrackingLineSearch, whose constructor uses it immediately.
    ismax : True to maximize the objective.

    The chosen class is constructed as cls(obj, ismax, arg_dict), following the
    (obj, ismax, args_dict) constructor order of the line-search classes.

    Raises
    ------
    ValueError : if no known class has the requested name.
    '''
    if addlcls is not None:
        # materialize once: it is traversed by the registry builder and again below
        addlcls = list(addlcls)

    arg_registry = get_linesearch_args_registry(addlcls)

    # read in the args
    lsnm = arg_dict[arg_registry.get_arg_fullname(_argname_type)]

    candidates = list(_all_ls_classes)
    if addlcls is not None:
        candidates.extend(addlcls)

    for ls_cls in candidates:
        if lsnm == ls_cls.get_name():
            return ls_cls(obj, ismax, arg_dict)

    raise ValueError('have not implemented linesearch of  type \'{}\' yet.'.format(lsnm))
        
