"""

This file implements the f^c calculation using the J-dimensional formulation.
J is the number of thetas in the problem.

The derived form of the optimizaiton problem is :
    max_v  \sum_j f_j(v_j + v_{0j})  - v^t (UU^t)^{-1}v

We will keep the origin constant by instead working with the formulation
    max_v  \sum_j f_j(v_j )  - (v - v_{0})^t (UU^t)^{-1} (v-v_0)

We solve this using a deterministic gradient following / newton numerical method.
        v(k+1) = v(k) + step * gradient(v(k))   and
        v(k+1) = v(k) + step * hessian_inv_times_gradient(v(k))  
where gradient and hessinv_x_gradient is computed from the wavelet functions 
and the step is calculated by the Armijo backtracking line search methods.
The initial iterations are by the first method and then the newton method takes over.

We provide the gradient and more importantly the hessian inverse without invoking the 
inversion function in numpy.

"""

import numpy as np
from enum import Enum

import logging
_LOGGER = logging.getLogger(name='wdro.func_c_approx')

from util_py.sorting import is_point_in_interval

from optim_py.det_gradient_descent import DeterministicGradientDescent
from optim_py.abs_objective import NumpyObjective
from optim_py.stop_criterion import StopMinObjRelChange,StopMaxIterations,StopSet
from optim_py.line_search import ArmijoBacktrackingLineSearch
from optim_py.algodata import AbsAlgoData
from optim_py.dirn_modifier import NoModification, NewtonDirection


class InitSampleSyle(Enum):
    '''
    Strategy used by FcApproxObjectiveJDim.initialize to draw a random
    starting iterate. (The misspelt class name is kept for compatibility.)
    '''
    # uniform draw over each theta dimension's wavelet range
    Uniform = 0
    # per-dimension rejection sampling relative to the anchor point
    NegNorm = 1
    # Gaussian jitter around one of the supplied candidate points
    CandidatePts = 2
    

class FcApproxObjectiveJDim(NumpyObjective):
    r'''
    Objective of the J-dimensional f^c subproblem, maximized over v (length J):

        g(v) = fill_value_thetax(v) - (v - v_0)^t (UU^t)^{-1} (v - v_0)

    where U is ``f_obj.thetas`` (J rows) and v_0 = U x_0 is the image of the
    anchor set with :meth:`set_x_anchor`. The optimizer drives ``self.iterate``
    and reads ``self.gradient`` (both presumably allocated by NumpyObjective).
    '''

    def __init__(self, f_obj, initsm=InitSampleSyle.CandidatePts):
        '''
        Parameters
        ----------
        f_obj : object exposing ``n_thetas``, ``thetas`` (the matrix U),
            ``wavelets`` (indexable per theta, with ``get_deriv`` and
            ``get_second_deriv``), ``fill_value_thetax`` and
            ``get_reasonable_wavelet_limits``.
        initsm : InitSampleSyle
            Sampling strategy used by :meth:`initialize`.
        '''
        super().__init__(f_obj.n_thetas)
        self.func = f_obj
        self.Jdimen = self.func.n_thetas
        self.range_Jdimen = range(self.Jdimen)

        # (UU^t)^{-1} appears in the penalty, its gradient and the Hessian
        UUtr = self.func.thetas @ self.func.thetas.transpose()
        self.UUtr_inverse = np.linalg.inv(UUtr)

        # f_j''(v_j) at the latest evaluated iterate; filled by
        # evaluate_fn_and_derivatives, consumed by get_hessian_inverse_vector_prod
        self.second_deriv_values = np.zeros((self.Jdimen,))

        # initialization helpers
        self.shd_initialize = True
        self.init_sample_style = initsm
        self.deduce_wavelet_limits()
        self.__starting_candidates = None
        self.__n_starting_cs = 0

    def set_x_anchor(self, xanch):
        '''
        Set the anchor x_0 and its image v_0 = U x_0.

        Also records, per theta dimension, the larger of the distances from
        v_0 to the two wavelet limits (used by the NegNorm initializer), and
        logs a warning when v_0 lies outside a dimension's wavelet range.
        '''
        self.x_anchor = xanch
        self.v_anchor = self.func.thetas @ self.x_anchor

        self.fj_far_ends = np.zeros(self.iterate.shape)
        for v in self.range_Jdimen:
            if not is_point_in_interval(self.v_anchor[v], self.fj_lims[v]):
                # Logger.warn is a deprecated alias of Logger.warning
                _LOGGER.warning("Anchor pt {} for theta dim {} is outside wavelet range {}".format(
                        self.v_anchor, v, self.fj_lims[v]))
            self.fj_far_ends[v] = max(abs(self.fj_lims[v, 1] - self.v_anchor[v]),
                                      abs(self.v_anchor[v] - self.fj_lims[v, 0]))

    def deduce_wavelet_limits(self):
        '''Cache per-dimension wavelet limits (column 0 low, column 1 high) and widths.'''
        self.fj_lims = self.func.get_reasonable_wavelet_limits()
        self.fj_range = self.fj_lims[:, 1] - self.fj_lims[:, 0]

    def set_starting_points_set(self, stpts_x):
        '''
        Store candidate starting points, given one x-space point per row, after
        mapping them to v-space. ``None`` leaves the current candidates untouched.
        '''
        if stpts_x is not None:
            self.__starting_candidates = stpts_x @ self.func.thetas.transpose()
            self.__n_starting_cs = self.__starting_candidates.shape[0]

    def __initialize_uniform(self):
        '''Draw each coordinate uniformly on its wavelet range.'''
        unirand = np.random.random(self.iterate.shape)
        for v in self.range_Jdimen:
            self.iterate[v] = self.fj_lims[v, 0] + self.fj_range[v] * unirand[v]

    def __initialize_negnorm(self):
        '''
        Rejection-sample each coordinate c on its wavelet range, accepting with
        probability 1 - ((c - v_0)/far_end)^2: an inverted quadratic that is 1 at
        the anchor and falls to 0 at the farther wavelet limit.
        '''
        n_samps = 0
        for v in self.range_Jdimen:
            far_end = self.fj_far_ends[v]
            while True:
                unipr = np.random.random((2,))
                cdd = self.fj_lims[v, 0] + self.fj_range[v] * unipr[0]
                n_samps += 1
                if self.fj_range[v] <= 0.:
                    # zero-width range: the only candidate is the limit itself, and
                    # its acceptance probability can be exactly 0 (endless loop)
                    self.iterate[v] = cdd
                    break
                # far_end is a distance, so it scales (cdd - v_0); since cdd lies
                # between the limits, |cdd - v_0| <= far_end and tlim is in [0, 1]
                tlim = 1. - pow((cdd - self.v_anchor[v]) / far_end, 2.)
                if unipr[1] <= tlim:
                    self.iterate[v] = cdd
                    break
        _LOGGER.debug("Needed {} smaples to initialize under plan '{}'".format(
                n_samps, self.init_sample_style))

    def __initialize_candidate_pts(self):
        '''Pick a random stored candidate and jitter it with Gaussian noise.'''
        # noise scale: 1% of the narrowest wavelet range
        stscle = np.min(self.fj_range) / 100.
        unint = np.random.randint(0, self.__n_starting_cs)
        gsn = np.random.normal(0.0, stscle, (self.Jdimen,))
        self.iterate[:] = self.__starting_candidates[unint] + gsn

    def initialize(self):
        '''
        Randomize ``self.iterate`` according to ``init_sample_style`` unless
        ``shd_initialize`` is False. CandidatePts falls back to NegNorm when no
        candidates were supplied.

        Raises ValueError for an unsupported sampling style.
        '''
        if self.shd_initialize:

            if self.init_sample_style == InitSampleSyle.Uniform:
                self.__initialize_uniform()
            elif self.init_sample_style == InitSampleSyle.NegNorm:
                self.__initialize_negnorm()
            elif self.init_sample_style == InitSampleSyle.CandidatePts:
                if self.__n_starting_cs > 0:
                    self.__initialize_candidate_pts()
                else:
                    self.__initialize_negnorm()
            else:
                raise ValueError("Have not implemented sampling style '{}' yet".format(
                        self.init_sample_style))

    def evaluate_fn_and_derivatives(self, skipgrad=False):
        '''
        Return fill_value_thetax(v) - (v - v_0)^t (UU^t)^{-1} (v - v_0) at the
        current iterate v.

        Unless ``skipgrad`` is True, also fill ``self.gradient`` with
        f_j'(v_j) - 2 (UU^t)^{-1} (v - v_0) and ``self.second_deriv_values``
        with f_j''(v_j).
        '''
        retval = self.func.fill_value_thetax(self.iterate, 0.0)
        v_minus_v0 = self.iterate - self.v_anchor
        # (UU^t)^{-1}(v - v_0) is shared by the penalty and its gradient
        # (UU^t is symmetric, so the penalty's gradient is 2 * this vector)
        uinv_dv = self.UUtr_inverse @ v_minus_v0
        retval -= v_minus_v0 @ uinv_dv

        if not skipgrad:
            for v in self.range_Jdimen:
                wavelet = self.func.wavelets[v]
                self.gradient[v] = wavelet.get_deriv(self.iterate[v])
                self.second_deriv_values[v] = wavelet.get_second_deriv(self.iterate[v])
            self.gradient -= 2. * uinv_dv

        return retval

    def get_hessian_inverse_vector_prod(self, vec, prod):
        '''
        Write H^{-1} vec into ``prod`` (in place), where
        H = diag(f_j''(v_j)) - 2 (UU^t)^{-1} uses the second derivatives stored
        by the last evaluate_fn_and_derivatives call made with skipgrad False.

        Raises numpy.linalg.LinAlgError if H is singular.
        '''
        hessian = np.diag(self.second_deriv_values) - 2. * self.UUtr_inverse
        # a linear solve is cheaper and more accurate than forming H^{-1}
        prod[:] = np.linalg.solve(hessian, vec)

    def get_optimal_delta_in_x(self):
        r'''
        Return the x-space displacement for the current iterate v:

            Delta = U^t (UU^t)^{-1} (v - v_0)

        the minimum-norm Delta with U Delta = v - v_0, so that
        |Delta|^2 = (v - v_0)^t (UU^t)^{-1} (v - v_0), the penalty being optimized.
        (The form U^t (UU^t)^{-1} v - x_0 only agrees with this when U is square.)
        '''
        return self.func.thetas.transpose() @ (self.UUtr_inverse @ (self.iterate - self.v_anchor))


class SolverForDelta():
    '''
    One worker's optimizer for the f^c subproblem: a few gradient-ascent
    iterations followed by Newton iterations, both using Armijo backtracking.
    '''

    # When True, optimize() records the iterate/value path of each run and
    # process_worker ships it back alongside the result.
    _Gather_Plot_Data = False

    def __init__(self, pid, fcn):

        self.pid = pid
        # creating it once ensures that the iterate info persists over re-runs of 'calculate'
        self.objective = FcApproxObjectiveJDim(fcn)

        # our stop criterion will be the min of max-iters and rel-change, with the
        # expectation that the max-iter param will be changed dynamically. This will
        # likely be triggered most.
        is_Maximization = True

        self.stopmxitr = StopMaxIterations()
        self.stopmnchg = StopMinObjRelChange(None, None, 8, 0.25, 5, 1e-3)
        self.stopcrit = StopSet(None, None, None, [self.stopmxitr, self.stopmnchg])
        self.stopcrit.is_min(not is_Maximization)

        # the object containing output formatter and linesearch
        self.algodata = AbsAlgoData(None, 1)

        self.grad_dirnmod = NoModification(self.objective, is_Maximization)
        self.newtn_dirnmod = NewtonDirection(self.objective, is_Maximization)

        # initial line-search step and iteration budget for each phase
        self.newton_maxstp, self.grad_maxstp = 1., 0.25
        self.max_grad_itrs, self.max_nwtn_itrs = 5, 32

        self.linesearch = ArmijoBacktrackingLineSearch(self.objective, is_Maximization, None,
                                                       5e-5, 5e-1, self.grad_maxstp)

        self.optimizer = DeterministicGradientDescent(self.objective, self.stopcrit,
                                                      self.algodata, self.linesearch,
                                                      self.grad_dirnmod, None, is_Maximization)

    def initialize_for_multirun(self, stset=None):
        '''Refresh wavelet limits and (optionally) the candidate starting points.'''
        self.objective.deduce_wavelet_limits()
        self.objective.set_starting_points_set(stset)

    def _run_phase(self, opval):
        '''
        Step the optimizer until its stop criterion fires, then terminate it.
        Returns the last value reported by step(), or ``opval`` if no step ran.
        '''
        while not self.optimizer.should_itrs_stop():
            if SolverForDelta._Gather_Plot_Data:
                self.iterpath.append(self.objective.iterate.copy())

            opval = self.optimizer.step()

            if SolverForDelta._Gather_Plot_Data:
                self.funcpath.append(opval)

        self.optimizer.terminate()
        return opval

    def optimize(self, runid, anchn, x_anch, startingpt=None):
        '''
        Run gradient ascent then Newton iterations for anchor ``x_anch`` and
        return the last objective value reported by the optimizer.

        The iterate is randomized by the first optimizer.initialize() call
        unless ``startingpt`` is given, in which case it starts there.
        '''
        self.objective.set_x_anchor(x_anch)

        # set up for gradient ascent first
        self.optimizer.set_dirn_modifier(self.grad_dirnmod)
        self.stopmxitr.max_iters = self.max_grad_itrs
        self.linesearch.alpha0 = self.grad_maxstp

        if startingpt is not None:
            self.objective.iterate[:] = startingpt
            self.objective.shd_initialize = False

        if SolverForDelta._Gather_Plot_Data:
            self.iterpath = []
            self.funcpath = []

        try:
            # initialize all counters etc. Note that this also
            # randomizes the initial starting point!
            self.optimizer.initialize()

            _LOGGER.debug("Starting run # {} for anch # {} at iter {}".format(
                         runid, anchn, self.objective.iterate))

            opval = self._run_phase(-float('inf'))

            # We may have stopped early (no change in objective value); either
            # way, switch to the Newton direction from the current iterate,
            # keeping the iteration count (the budget is cumulative).
            self.objective.shd_initialize = False
            self.optimizer.set_dirn_modifier(self.newtn_dirnmod)
            self.stopmxitr.max_iters = self.max_grad_itrs + self.max_nwtn_itrs
            self.linesearch.alpha0 = self.newton_maxstp

            self.algodata.shd_initialize = False
            # we will let the stoprelchg initialize
            self.optimizer.initialize()

            opval = self._run_phase(opval)
        finally:
            # restore so the next run randomizes its start even if this one failed
            self.objective.shd_initialize = True
            self.algodata.shd_initialize = True

        return opval

    def get_optimal_delta(self):
        '''Optimal x-space displacement for the most recent run.'''
        return self.objective.get_optimal_delta_in_x()

    def _pack_path(self):
        '''Rows are [objective value, iterate...] for each recorded step.'''
        detrest = np.zeros((len(self.funcpath), self.objective.iterate.size + 1))
        for n, (fval, itr) in enumerate(zip(self.funcpath, self.iterpath)):
            detrest[n, 0] = fval
            detrest[n, 1:] = itr
        return detrest

    def process_worker(self, inpQ, outQ):
        '''
        Worker loop: consume (run#, anchor#, x_anchor) tasks from ``inpQ`` until
        'STOP', putting (value, delta, run#, anchor#, pid, path-or-None) on ``outQ``.

        Every task yields exactly one result, since the parent waits for one per task.
        A failing run is logged and reported with value -inf and delta None.
        '''
        np.random.seed()  # reseed so workers do not replay the same random stream
        for smpn, anchn, xanch in iter(inpQ.get, 'STOP'):
            try:
                result = self.optimize(smpn, anchn, xanch)
                resdel = self.get_optimal_delta()
            except Exception:
                _LOGGER.exception("Proc # {} failed run # {} for anchor #{}".format(
                                  self.pid, smpn, anchn))
                outQ.put((-float('inf'), None, smpn, anchn, self.pid, None))
                continue

            detrest = self._pack_path() if SolverForDelta._Gather_Plot_Data else None

            outQ.put((result, resdel, smpn, anchn, self.pid, detrest))
                


from multiprocessing import Process, Queue

class FcApproxJDim(object):
    '''
    Multi-process driver for f^c: for every anchor point it runs
    ``n_random_starts`` SolverForDelta optimizations spread over ``n_procs``
    worker processes and keeps the best value per anchor.
    '''

    @staticmethod
    def get_args_registries():
        r'''
        Return the argument registries this class relies on. It adds none of its
        own, so this is just DeterministicGradientDescent's list.
        '''
        return list(DeterministicGradientDescent.get_args_registries())

    def __init__(self, fcn, nprocs=1, nstarts=60):
        r'''
        Set up the solvers used to approximate the f^c function.

        Recall that : f^c(x) = sup_{\Delta} f(x+\Delta) - |\Delta|^2

        Each solver maximizes this by deterministic gradient ascent followed by
        Newton iterations, with the step chosen by Armijo backtracking line search.

        Parameters
        ----------
        fcn : the function object (passed to each SolverForDelta)
        nprocs : number of worker processes (one SolverForDelta each)
        nstarts : number of random starts per anchor point
        '''
        self.fcn = fcn

        # creating them once ensures that the iterate info persists over re-runs of 'calculate'
        self.n_procs = nprocs
        self.solvers = [SolverForDelta(n, fcn) for n in range(self.n_procs)]

        self.n_random_starts = nstarts

    def optimize_delta(self, x_anchs, stset=None):
        '''
        Run all random starts for every anchor (one anchor per row of ``x_anchs``).

        ``stset``, if not None, holds candidate starting points (one x-space point
        per row) that the solvers jitter to start from.

        Returns ``(ret_fmax, ret_opt_delta)``: the best value per anchor (-inf if
        no run beat -inf) and the corresponding displacement, shape
        (n_anchs, fcn.dim).
        '''
        # initialize solvers for this run (before forking, so workers inherit it)
        for n in range(self.n_procs):
            self.solvers[n].initialize_for_multirun(stset)

        n_anchs = x_anchs.shape[0]
        ret_fmax = np.full((n_anchs,), -float('inf'))
        ret_opt_delta = np.zeros((n_anchs, self.fcn.dim))

        task_queue = Queue()
        done_queue = Queue()

        # Submit tasks
        for i in range(self.n_random_starts):
            for xn in range(n_anchs):
                task_queue.put((i, xn, x_anchs[xn]))

        # Start worker processes, keeping handles so they can be joined
        workers = []
        for i in range(self.n_procs):
            proc = Process(target=self.solvers[i].process_worker, args=(task_queue, done_queue))
            proc.start()
            workers.append(proc)

        if SolverForDelta._Gather_Plot_Data:
            self.detailed_results_HACK = {}

        # exactly one result is expected per submitted task
        for i in range(n_anchs * self.n_random_starts):

            fnval, iterv, rndx, anchndx, pid, detailed_results = done_queue.get()

            if SolverForDelta._Gather_Plot_Data:
                self.detailed_results_HACK.setdefault(anchndx, []).append(detailed_results)

            if _LOGGER.isEnabledFor(logging.DEBUG):
                _LOGGER.debug("Proc # {} reported run # {} for anchor #{} with value {} at {}".format(
                             pid, rndx, anchndx, fnval, iterv))

            if fnval > ret_fmax[anchndx]:
                if _LOGGER.isEnabledFor(logging.INFO):
                    _LOGGER.info("(Proc# {:3d}, run# {:3d}) for anch #{:3d}: updated best from {:12.9f} to {:12.9f}".format(
                             pid, rndx, anchndx, ret_fmax[anchndx], fnval))
                ret_fmax[anchndx] = fnval
                ret_opt_delta[anchndx, :] = iterv

        # Tell child processes to stop, then reap them. All of their results
        # have been consumed above, so joining cannot block on a full queue.
        for _ in workers:
            task_queue.put('STOP')
        for proc in workers:
            proc.join()

        return ret_fmax, ret_opt_delta

    def do_detailed_plots(self, filepre, anchor_set, opt_pts_set, sep_plots=False, derivplot=False):
        ''' screw this, will figure it out later'''
        pass
        
