import numpy as np
import logging, os

_LOGGER=logging.getLogger('wdro.rob_profile')

from func_f import FApprox
#from func_f_c_onetheta import FcApproxSingleTheta
from func_f_c_Jdim import FcApproxJDim

from util_py.csv_readwrite import get_excelcsv_writer
from util_py.file_ops import ensure_dir_exists

from optim_py.abs_objective import NumpyObjective
from optim_py.sampling import  FixedSamplingRate
from optim_py.stoch_gradient_descent import StochasticGradientDescent
from optim_py.stop_criterion import StopMinObjRelChange,StopMaxIterations,StopSet
from optim_py.algodata import AbsAlgoData
from optim_py.learning_rate import PolyDenRate
from optim_py.momentum import NoMomentum


class RobProfObjective(NumpyObjective):
    '''
    Stochastic objective for the robust profile, evaluated at the function
    weights held in ``self.iterate``:

        (1/M) * sum_m f(y_m)  -  (1/N) * sum_n fcvals[n]

    where the y_m are the current P^* sample (``curr_samples``, M =
    ``n_curr_samples``) and ``fcvals`` are the per-point values returned by
    ``FcApproxJDim.optimize_delta`` for the N rows of the empirical set
    (presumably f^c(x_n) -- confirm against FcApproxJDim).
    '''

    def __init__(self, f, empr_set, p_star_sam, n_proc=1, nstarts=60):
        '''
        :param f: function approximation whose weights are optimised; must
            expose ``n_weights`` and ``dim``.
        :param empr_set: empirical sample set, one row per sample.
        :param p_star_sam: P^* sampler; stored but not read by this class.
        :param n_proc: process count passed to ``FcApproxJDim``.
        :param nstarts: number of starts passed to ``FcApproxJDim``.
        '''
        self.func = f

        super(RobProfObjective, self).__init__(f.n_weights)

        self.func_c_eval = FcApproxJDim(self.func, n_proc, nstarts)

        self.empr_set = empr_set
        self.p_n_size = self.empr_set.shape[0]

        # weight of each empirical point in the (subtracted) E_{P_n} term
        self.p_n_recip = -1.0 / self.p_n_size

        self.p_star_sampler = p_star_sam

        # NOTE(review): not read anywhere in this class
        self.n_inner_iters = 20

        # Holds x_n + delta_n, which is generally non-integer: keep the
        # buffer floating point even when empr_set holds integers, otherwise
        # the assignment in evaluate_fn_and_derivatives silently truncates.
        if np.issubdtype(self.empr_set.dtype, np.floating):
            buf_dtype = self.empr_set.dtype
        else:
            buf_dtype = np.float64
        self.f_c_x_vals = np.zeros_like(self.empr_set, dtype=buf_dtype)

        _LOGGER.info("Num Weights {} , PMF dim {}".format(self.func.n_weights,
                     self.func.dim))

    def set_samples(self, sampl, ssiz):
        '''Store the current P^* sample ``sampl`` and its size ``ssiz`` for the next evaluation.'''
        _LOGGER.debug("Size of samp {} vs shape {}".format(ssiz, sampl.shape))
        self.curr_samples = sampl
        self.n_curr_samples = ssiz

    def initialize(self):
        '''
        Call ``func.randomize_weights(0.01)``, copy the resulting weights into
        the iterate, and reset the perturbed points to the empirical set.
        '''
        # random (rather than copied) starting weights make the plots of the
        # early iterates more representative
        self.func.randomize_weights(0.01)
        self.func.copy_weights_to(self.iterate)
        self.f_c_x_vals[:] = self.empr_set

    def evaluate_fn_and_derivatives(self, skipgrad=False):
        '''
        Evaluate the objective at ``self.iterate``.

        Unless ``skipgrad`` is True, ``self.gradient`` is zeroed and then
        filled by two ``fill_gradient_wrt_weights`` calls (P^* term with
        weight 1/M, empirical term with weight -1/N).

        Side effect: ``self.f_c_x_vals`` is set to ``empr_set + optdeltas``
        from the inner optimisation, even when ``skipgrad`` is True.

        :returns: the objective value described in the class docstring.
        '''
        # load the current weights into the function approximation
        self.func.set_weights_from(self.iterate)

        # average of f over the current sample drawn from p_star
        retval = np.sum(self.func.get_value(self.curr_samples)) / self.n_curr_samples

        # inner optimisation over the perturbations delta_n of each x_n
        fcvals, optdeltas = self.func_c_eval.optimize_delta(self.empr_set, self.curr_samples)

        self.f_c_x_vals[:] = self.empr_set + optdeltas

        # subtract the empirical average (p_n_recip == -1/N)
        for n in range(self.p_n_size):
            retval += self.p_n_recip * fcvals[n]

        if not skipgrad:
            self.gradient.fill(0.0)

            mult = 1. / self.n_curr_samples
            self.func.fill_gradient_wrt_weights(self.curr_samples, self.gradient, mult)

            # the empirical term uses the gradient of f at the optimised
            # points x_n + delta_n, with the deltas held fixed
            self.func.fill_gradient_wrt_weights(self.f_c_x_vals, self.gradient, self.p_n_recip)

        return retval


class RobustProfile(object):
    r'''
    The robust profile function is
        R_n = sup_{f \in \cF} E_{P^*} f(X) - E_{P_n} f^c(X)

    This follows from the duality result that has been established.
    In our case, we model the set \cF using a basis of wavelet functions.

    The supremum is approximated by running stochastic gradient descent,
    configured for maximisation (``is_maxima = True``), over the weights of
    an ``FApprox`` function; see ``calculate_one_sample``.
    '''

    def __init__(self, ths, eset, glvl, samplr, m_siz, n_proc):
        '''
        Set up the function approximation, the objective and the SGD pieces.

        :param ths: directions theta; the empirical set is projected onto
            them as ``eset @ ths.T`` (one column per theta).
        :param eset: empirical sample set, one row per sample.
        :param glvl: grain level passed to ``FApprox``.
        :param samplr: P^* sampler; must provide ``get_range_of_lincomb`` and
            is handed to the optimizer as its sampler.
        :param m_siz: number of P^* samples per step (fixed sampling rate).
        :param n_proc: process count forwarded to ``RobProfObjective``.
        '''
        self.thetas = ths
        self.empr_set = eset
        self.grain_level = glvl

        self.p_star_sampler = samplr
        self.p_star_sample_size = m_siz

        _LOGGER.info("P_star sample size M = {}".format(self.p_star_sample_size))

        # Range of each theta-projection for FApprox: start from what P^*
        # reports and widen it so it also covers the empirical set.
        # np.array copies, so the in-place corrections below never write into
        # an array the sampler may still be holding.
        rngset = np.array(samplr.get_range_of_lincomb(self.thetas))
        _LOGGER.info("Fn range as reported by p_star: {}".format(rngset))
        eset_thx = eset @ ths.transpose()

        eset_rng_mx = np.max(eset_thx, axis=0)
        _LOGGER.info("Max as reported by emp_set: {}".format(eset_rng_mx))
        eset_mx_btr, = np.where(eset_rng_mx > rngset[:, 1])
        if eset_mx_btr.size > 0:
            rngset[eset_mx_btr, 1] = eset_rng_mx[eset_mx_btr]
            _LOGGER.info("Max corrected to: {}".format(rngset[:, 1]))
        else:
            _LOGGER.info("No correction needed for max from emp_set.")

        eset_rng_mn = np.min(eset_thx, axis=0)
        _LOGGER.info("Min as reported by emp_set: {}".format(eset_rng_mn))
        eset_mn_btr, = np.where(eset_rng_mn < rngset[:, 0])
        if eset_mn_btr.size > 0:
            rngset[eset_mn_btr, 0] = eset_rng_mn[eset_mn_btr]
            _LOGGER.info("Min corrected to: {}".format(rngset[:, 0]))
        else:
            _LOGGER.info("No correction needed for min from emp_set.")

        self.func = FApprox(self.thetas, rngset, self.grain_level)

        self.n_thetas = self.func.n_thetas
        self.empr_set_size = eset.shape[0]

        self.rob_objective = RobProfObjective(self.func, self.empr_set, self.p_star_sampler, n_proc)
        self.is_maxima = True

        # sundries needed for sgd
        self.algodata = AbsAlgoData(None, 1)

        # our stop criterion will be the min of max-iters and rel-change, with the
        # expectation that the max-iter param will be changed dynamically. This will
        # likely be triggered most.
        self.stopmxitr = StopMaxIterations()
        self.stopmnchg = StopMinObjRelChange(None, None, 60, 0.7, 40)
        self.stopcrit = StopSet(None, None, None, [self.stopmxitr, self.stopmnchg])
        self.stopcrit.is_min(not self.is_maxima)

        # calculate_one_sample overrides learning_rate.initial with its maxstep
        self.learning_rate = PolyDenRate(None, None, .1, 100.0)
        self.sampling_rate = FixedSamplingRate(None, None, self.p_star_sample_size)
        self.momentum = NoMomentum(self.rob_objective, self.is_maxima)

    def __plot_function_sets(self, outbasedir, filepre, itrn):
        '''
        Write the detailed f^c plots for iteration ``itrn`` to
        ``<outbasedir>/<filepre>/iter[<itrn+1>]``.
        '''
        if _LOGGER.isEnabledFor(logging.DEBUG):
            # only needed for the shape diagnostics, so build it lazily
            if itrn == 0:
                plot_empr_sets = [self.empr_set, self.rob_objective.f_c_x_vals]
            else:
                plot_empr_sets = [self.rob_objective.curr_samples, self.empr_set,
                                  self.rob_objective.f_c_x_vals]
            for jk, pset in enumerate(plot_empr_sets):
                _LOGGER.debug("set[{}] shape {}".format(jk, pset.shape))

        pngnm = os.path.join(outbasedir, filepre, 'iter[{:04d}]'.format(itrn + 1))
        self.rob_objective.func_c_eval.do_detailed_plots(
                pngnm, self.empr_set, self.rob_objective.f_c_x_vals)
        _LOGGER.info("At itr {:5d}, wrote iterate to {}".format(itrn + 1, pngnm))

    @staticmethod
    def _is_plot_iter(itrn, plot_stride):
        '''True when iteration ``itrn`` should be snapshotted; a stride of 0 means every iteration.'''
        return plot_stride == 0 or itrn % plot_stride == 0

    def calculate_one_sample(self, nitrs, outbasedir, maxstep=0.1, filepre=None, plot_stride=20):
        '''
        Run SGD on the robust-profile objective and return the last objective value.

        :param nitrs: maximum number of iterations (sets ``stopmxitr.max_iters``).
        :param outbasedir: base directory for the csv output, snapshots and plots.
        :param maxstep: value assigned to ``learning_rate.initial``.
        :param filepre: output prefix under ``outbasedir``; when None, no csv,
            snapshots or plots are written.
        :param plot_stride: save the iterate and plot every ``plot_stride``
            iterations (0 means every iteration; negative disables plotting).
        :returns: the value of the last ``optimizer.step()``, or ``nan`` if the
            stop criterion fired before any step was taken.
        '''
        self.stopmxitr.max_iters = nitrs

        optimizer = StochasticGradientDescent(self.rob_objective, self.stopcrit,
                                              self.algodata, self.p_star_sampler,
                                              self.sampling_rate,
                                              self.learning_rate, self.momentum,
                                              None, True)

        self.learning_rate.initial = maxstep

        # snapshots and plots live under <outbasedir>/<filepre>/, so they need
        # a prefix (os.path.join(outbasedir, None, ...) raises TypeError)
        do_plots = filepre is not None and plot_stride >= 0

        # open a csv output channel
        filwr = None
        if filepre is not None:
            if do_plots:
                ensure_dir_exists(os.path.join(outbasedir, filepre, ''))
            csvwr, filwr = get_excelcsv_writer(os.path.join(outbasedir, filepre))
            self.algodata.attach_csvwriter(csvwr)

        # defaults for when the stop criterion fires before the first step;
        # itrn == -1 makes the final plot below be labelled as iteration 0
        itrn = -1
        objv = np.nan
        try:
            optimizer.initialize()

            while not optimizer.should_itrs_stop():
                itrn = self.algodata.n_itr.value

                if do_plots and self._is_plot_iter(itrn, plot_stride):
                    npynm = os.path.join(outbasedir, filepre, 'iter[{:04d}]'.format(itrn + 1))
                    np.save(npynm, self.rob_objective.iterate)
                    self.__plot_function_sets(outbasedir, filepre, itrn)

                objv = optimizer.step()

                _LOGGER.info("At itr {:5d}, objv {:20.15f}".format(itrn + 1, objv))

            if do_plots:
                self.__plot_function_sets(outbasedir, filepre, itrn + 1)

            optimizer.terminate()
        finally:
            # release the csv file even if the optimisation raised
            if filwr is not None:
                filwr.close()

        retval = objv
        _LOGGER.info("Calculated sample: {:20.15f}".format(retval))
        return retval

