import numpy as np
from math import sqrt,ceil,pi

import logging

_LOGGER= logging.getLogger('wdro.wavelet')

import matplotlib
matplotlib.rcParams['text.usetex'] = True
#matplotlib.rcParams['text.latex.unicode'] = True
import matplotlib.pyplot as plt
from util_py.plotting import colors_ring

class Marr:
    r'''
    Marr wavelet, also known as the Mexican-hat wavelet.  Standard form:

        \psi(t) = \frac{2}{\sqrt{3}\,\pi^{1/4}} (1 - t^2) e^{-t^2/2}

    First derivative:

        \psi'(t) = \frac{2}{\sqrt{3}\,\pi^{1/4}} (t^3 - 3t) e^{-t^2/2}

    Second derivative:

        \psi''(t) = \frac{2}{\sqrt{3}\,\pi^{1/4}} (-t^4 + 6t^2 - 3) e^{-t^2/2}

    All evaluators work on scalars or elementwise on numpy arrays.
    '''

    # normalisation constant 2 / (sqrt(3) * pi^(1/4))
    const = (2. / (sqrt(3.) * pow(pi, 1. / 4)))
    # half-width of the interval that get_range() reports as the support
    valid_range = 4.5

    def get_range(self):
        '''
        Return the truncated support [-valid_range, valid_range] (default [-4.5, 4.5]).

        Reads ``self.valid_range`` rather than ``Marr.valid_range`` so that a
        subclass or an instance can narrow or widen the support.
        '''
        return [-self.valid_range, self.valid_range]

    def get_value_at(self, x):
        '''Return psi(x) = const * (1 - x^2) * exp(-x^2/2).'''
        px = np.power(x, 2.)
        return self.const * (1 - px) * np.exp(-px / 2.)

    def get_deriv_at(self, x):
        '''Return the first derivative: const * x * (x^2 - 3) * exp(-x^2/2).'''
        px = np.power(x, 2.)
        return self.const * x * (px - 3.) * np.exp(-px / 2.)

    def get_second_deriv_at(self, x):
        '''Return the second derivative: const * exp(-x^2/2) * (-x^4 + 6x^2 - 3).'''
        px = np.power(x, 2.)
        return self.const * np.exp(-px / 2.) * (-px * px + 6 * px - 3)
        
class wavelet:
    r'''
    Discrete (dyadic) wavelet approximation of a one-dimensional function f.

    Continuous picture:
        f(t) = \int_s \int_u w(s,u) (1/\sqrt{s}) \phi((t-u)/s) du ds,
    where the weights are w(s,u) = \int_t f(t) (1/\sqrt{s}) \phi((t-u)/s) dt.
    If f is a density then w(s,u) = E[(1/\sqrt{s}) \phi((t-u)/s)].
    The variable u is called the translation and s the scale.

    Discretisation used here: let the mother wavelet \phi have support [-m, m]
    (as reported by ``mthr.get_range()``) and let [a, b] be the min/max of the
    values passed to the constructor, with centre c = (a + b)/2.  For each
    level j = 0, ..., J the basis functions are

        \phi_{j,k}(t) = 2^{j/2} \phi(2^j (t - c_{j,k})),   c_{j,k} = c + k * 2m / 2^j,

    for k = -n_j, ..., n_j.  Here 2 n_j + 1 is the smallest odd number of
    adjacent supports (each of width 2m / 2^j), centred on c, that covers
    [a, b].  Each \phi_{j,k} is evaluated only on its own support
    [c_{j,k} - m/2^j, c_{j,k} + m/2^j].

    All weights live in the flat array ``self.weights``.  Level 0 comes first,
    and within a level k is in increasing order.
    '''

    def __init__(self, emprlims, mthr, J):
        '''
        Set up the basis layout.  The weights are created but all start at 0.

        Parameters
        ----------
        emprlims : array-like
            Values whose min and max give the interval [a, b] to cover.
        mthr : object
            Mother wavelet.  It must provide get_range(), get_value_at(),
            get_deriv_at() and get_second_deriv_at() (e.g. Marr).
        J : int
            Finest level; levels 0, ..., J are used.

        Raises
        ------
        ValueError
            If J is negative.
        '''
        if J < 0:
            # The braces of 2^{-J} are doubled so that str.format treats them as
            # literal text.  A single pair is parsed as a field named "-J" and
            # raises KeyError instead of the intended ValueError.
            raise ValueError(
                "Granularity of descretization {} can't be negative. "
                "Specify J such that 2^{{-J}} grains allowed.".format(J))

        self.J = J
        self._mother = mthr

        self._fn_lims = [np.min(emprlims), np.max(emprlims)]
        self._fn_range = self._fn_lims[1] - self._fn_lims[0]
        self._fn_center = np.sum(self._fn_lims) * 0.5

        _LOGGER.debug("Fn lims: {} , so ctr {} and rng {}".format(
            self._fn_lims, self._fn_center, self._fn_range))

        self._mthr_lims = self._mother.get_range()
        self._mthr_range = self._mthr_lims[1] - self._mthr_lims[0]

        # Covering scheme: at level j each support has width trng, and the
        # wavelets sit at fn_center + k*trng.  K = ceil(fn_range / trng)
        # adjacent supports are enough to span the data range.  K is rounded
        # up to the next odd number so that the wavelets are symmetric about
        # the centre (k = -n_per, ..., n_per).
        self.n_weights = 0
        self.n_weights_per_level = []
        # union of all supports actually placed (default plotting range)
        self._wvapx_tru_rng = [float('inf'), -float('inf')]

        for j in range(self.J + 1):  # levels 0..J inclusive
            trng, _, _ = self.get_wavelet_params(j)
            K = ceil(self._fn_range / trng)
            # smallest odd number >= K (1 when K == 0, i.e. a single point)
            n_t_wvlts = 1 + (K - 1) + (K - 1) % 2
            self.n_weights_per_level.append(n_t_wvlts)
            self.n_weights += n_t_wvlts

            if _LOGGER.isEnabledFor(logging.DEBUG):
                _LOGGER.debug("for J {:d} , this rng {:6.4f} vs fnrng {:6.4f}, K is {} so adding {} more".format(
                    j, trng, self._fn_range, K, n_t_wvlts))

            n_per = int((n_t_wvlts - 1) / 2)
            for k in range(-n_per, n_per + 1):
                ctr = self._fn_center + trng * k
                _LOGGER.debug("cntr: {}".format(ctr))
                if ctr - trng * 0.5 < self._wvapx_tru_rng[0]:
                    self._wvapx_tru_rng[0] = ctr - trng * 0.5
                if ctr + trng * 0.5 > self._wvapx_tru_rng[1]:
                    self._wvapx_tru_rng[1] = ctr + trng * 0.5

        _LOGGER.info("total n weights: {}".format(self.n_weights))
        self.weights = np.zeros((self.n_weights, ))

    def __wtndx_to_lvl(self, ndx):
        '''
        Map a flat weight index to (level, flat index of that level's first weight).

        Returns None if ndx is past the last weight.
        '''
        sm = 0
        for n in range(self.J + 1):
            if ndx < sm + self.n_weights_per_level[n]:
                return n, sm
            sm += self.n_weights_per_level[n]

    def guess_reasonable_funclimits(self):
        '''
        Return np.array([lo, hi]) spanning the supports of significant basis functions.

        A basis function is significant when its |weight| is at least 1e-4
        times the largest |weight|.  Weights are visited in decreasing |weight|
        order, and the scan stops at the first one below the threshold.  If all
        weights are zero, every support is included.
        '''
        retval = np.array([float('inf'), -float('inf')])

        order = np.argsort(-np.abs(self.weights))
        wmax = abs(self.weights[order[0]])

        for i in range(self.n_weights):
            tndx = order[i]
            if abs(self.weights[tndx]) < 1e-4 * wmax:
                break

            lvl, lvlst = self.__wtndx_to_lvl(tndx)
            wvrng, _, _ = self.get_wavelet_params(lvl)
            n_per = int((self.n_weights_per_level[lvl] - 1) / 2)
            # position inside the level, shifted so the middle wavelet is k = 0
            k = tndx - lvlst - n_per

            ctr = self._fn_center + wvrng * k
            if retval[0] > ctr - wvrng * 0.5:
                retval[0] = ctr - wvrng * 0.5
            if retval[1] < ctr + wvrng * 0.5:
                retval[1] = ctr + wvrng * 0.5

        return retval

    def get_wavelet_params(self, lvl):
        '''
        Return (support width, dilation, amplitude scale) for level lvl.

        The values are (mother support width / 2^lvl, 2^lvl, sqrt(2^lvl)).
        '''
        mult = 2 ** lvl
        scl = sqrt(mult)
        wvrng = self._mthr_range / mult
        return wvrng, mult, scl

    def __iterate_over_levels(self, fn):
        r'''
        Call fn(ctr, half_width, ndx, fscale, mult, j) once per basis function.

        Calls are made in flat-weight-index order (ndx = 0, 1, ...).  The
        arguments are:
          - ctr: the centre of the basis function
          - half_width: the half-width of its support
          - fscale = sqrt(2^j): the amplitude factor
          - mult = 2^j: the dilation
          - j: the level
        '''
        ndx = -1
        for j in range(self.J + 1):
            wvrng, mult, scl = self.get_wavelet_params(j)
            n_per = int((self.n_weights_per_level[j] - 1) / 2)
            for k in range(-n_per, n_per + 1):
                ndx += 1
                ctr = self._fn_center + wvrng * k
                fn(ctr, wvrng * 0.5, ndx, scl, mult, j)

    def __get_summation_over(self, x, fn):
        r'''
        Evaluate  sum_{j,k} w_{j,k} * 2^{j/2} * fn(x - c_{j,k}, 2^j)  at x.

        Only basis functions whose (closed) support contains x contribute.
        fn(delta, mult) must return the mother-wavelet quantity at offset
        delta for dilation mult: its value, or a derivative already multiplied
        by its chain-rule factor.

        x may be any scalar (Python or numpy, int or float), in which case a
        numpy float is returned.  It may also be an array or sequence, in which
        case an array of the same shape is returned.  Non-floating inputs are
        converted to float64; floating arrays keep their dtype.
        '''
        if np.ndim(x) == 0:
            xv = float(x)
            # 1-element accumulator: the nested callback can add into it
            # without rebinding, and retval[0] is returned as a numpy float.
            retval = np.zeros((1,))

            def _summ_over_iter_1x(ctr, rng, ndx, fscale, xmult, j):
                if ctr - rng <= xv <= ctr + rng:
                    retval[0] += self.weights[ndx] * fscale * fn(xv - ctr, xmult)

            self.__iterate_over_levels(_summ_over_iter_1x)
            return retval[0]

        xa = np.asarray(x)
        if not np.issubdtype(xa.dtype, np.floating):
            # An integer accumulator cannot hold the float contributions
            # (the in-place add would fail to cast), so upcast first.
            xa = xa.astype(float)
        retval = np.zeros_like(xa)

        def _summ_over_iter(ctr, rng, ndx, fscale, xmult, j):
            nddd = np.where((xa >= ctr - rng) & (xa <= ctr + rng))
            if np.size(nddd) > 0:
                retval[nddd] += self.weights[ndx] * fscale * fn(xa[nddd] - ctr, xmult)

        self.__iterate_over_levels(_summ_over_iter)
        return retval

    def get_value(self, x):
        '''
        Evaluate the approximation f at x (scalar or array) from the mother
        function's value and the weights.
        '''
        return self.__get_summation_over(x, lambda dlx, mlt: self._mother.get_value_at(dlx * mlt))

    def get_deriv(self, x):
        '''
        Evaluate f' at x (scalar or array) from the mother function's derivative
        and the weights.  The chain rule contributes a factor of 2^j per level.
        '''
        return self.__get_summation_over(x, lambda dlx, mlt: mlt * self._mother.get_deriv_at(dlx * mlt))

    def get_second_deriv(self, x):
        '''
        Evaluate f'' at x (scalar or array) from the mother function's second
        derivative and the weights.  The chain rule contributes a factor of
        4^j per level.
        '''
        return self.__get_summation_over(x, lambda dlx, mlt: mlt * mlt * self._mother.get_second_deriv_at(dlx * mlt))

    def fill_basis(self, x, retval, offset, multext):
        r'''
        Write the summed basis-function values at the points x into retval.

        For every basis function with flat index ndx at level j this sets

            retval[ndx + offset] = multext * sum_{x_i in support} 2^{j/2} \phi(2^j (x_i - ctr)).

        Entries whose support contains none of the points are left untouched.
        The caller presumably passes a zero-initialised retval (verify against
        callers).  x may be a scalar, a sequence or an array of points.
        '''
        pts = np.atleast_1d(x)

        def _extract_basis_over_iter(ctr, rng, ndx, fscale, xmult, j):
            nddd = np.where((pts >= ctr - rng) & (pts <= ctr + rng))
            if np.size(nddd) > 0:
                fvals_sum = np.sum(self._mother.get_value_at(xmult * (pts[nddd] - ctr)))
                retval[ndx + offset] = fscale * fvals_sum * multext

        self.__iterate_over_levels(_extract_basis_over_iter)

    def init_weights_from_emprdistn(self, emprset):
        r'''
        Set the weights to the empirical estimate of E[\phi_{j,k}(X)].

        For each basis function,

            w_{j,k} = (1/N) sum_{x_i in support} 2^{j/2} \phi(2^j (x_i - c_{j,k})),

        computed over the N samples in emprset (a scalar, sequence or array).
        Basis functions whose support contains no sample get weight 0.
        '''
        emprset = np.atleast_1d(emprset)
        N = np.size(emprset)
        self.weights.fill(0.0)

        def _set_wts_over_iter(ctr, rng, ndx, fscale, xmult, j):
            nddd = np.where((emprset >= ctr - rng) & (emprset <= ctr + rng))
            if np.size(nddd) > 0:
                fvals_sum = np.sum(self._mother.get_value_at(xmult * (emprset[nddd] - ctr)))
                self.weights[ndx] = fscale * fvals_sum / N

        self.__iterate_over_levels(_set_wts_over_iter)

    def _plot_value(self, ax, func, x=None, color='k', lw=0.5):
        '''
        Plot func(x) on ax and return (max, min) of the plotted values.

        x defaults to a grid with spacing 0.01 over the union of all supports.
        '''
        if x is None:
            x_lims = self._wvapx_tru_rng
            x = np.arange(x_lims[0], x_lims[1], .01)

        fv = func(x)

        ax.plot(x, fv, color=color, linewidth=lw)
        return np.max(fv), np.min(fv)

    def _plot_place_funcvalue_dots(self, ax, ptset, func, color='k', lenr=1., relwidth=0.005):
        '''
        Mark the points ptset on ax.

        Each point gets a bar of height lenr (skipped when |lenr| <= 1e-9) and
        a dot at (p, func(p)).  The bar width is relwidth times the width of the
        union of all supports.
        '''
        x_lims = self._wvapx_tru_rng
        width = relwidth * (x_lims[1] - x_lims[0])

        if abs(lenr) > 1e-9:
            ax.bar(ptset, np.ones_like(ptset) * lenr, width=width, color=color)

        fvals = func(ptset)
        ax.plot(ptset, fvals, color=color, linestyle='', marker='o',
                markersize=5, fillstyle='full')

    def plot_wavelet(self, theta_times_x_es=None, need_deriv=True):
        r'''
        Plot the approximation and its weights.

        Panels, top to bottom: weight circles; the function value; and, if
        need_deriv, the first and second derivatives.

        Parameters
        ----------
        theta_times_x_es : sequence of point sets, optional
            Each set is drawn on the value panel as bars plus dots at f(x).
            The sets must be one-dimensional, i.e. already transformed by the
            respective thetas before being passed here.
        need_deriv : bool
            Whether to add the two derivative panels.

        Returns
        -------
        With need_deriv, returns
            ((fmax, (dmax, dmin), (smax, smin)), fig, (ax0, ax1, ax2, ax3)).
        Without need_deriv, returns
            (fmax, fig, (ax0, ax1)).
        fmax is the max of the plotted value.  (dmax, dmin) and (smax, smin)
        are the (max, min) of the plotted first and second derivatives.
        '''
        x_lims = self.guess_reasonable_funclimits()
        _LOGGER.debug("X_lim for plotting is now: {}".format(x_lims))
        x = np.arange(x_lims[0], x_lims[1], .01)

        if need_deriv:
            fig1, (ax0, ax1, ax2, ax3) = plt.subplots(nrows=4, ncols=1)
            fig1.set_size_inches(10., 12.75)
            axes = (ax0, ax1, ax2, ax3)
        else:
            fig1, (ax0, ax1) = plt.subplots(nrows=2, ncols=1)
            fig1.set_size_inches(10., 6.35)
            axes = (ax0, ax1)
        for ax in axes:
            ax.set_xlim(x_lims)

        # Panel 0: one circle per weight at (centre, level).  The radius is
        # proportional to |weight|, with the largest drawn at 0.45.  Negative
        # weights are red.  The y-range is widened from the default [0, 1] so
        # the circles fit.
        ax0.set_ylim((-1, self.J + 1))
        ax0.set_ylabel("Relative Weights")
        radii = np.abs(self.weights)
        radius_mx, plt_rad_mx = np.max(radii), 0.45
        if radius_mx > 0:
            # guard: all-zero weights would otherwise give 0/0 -> NaN radii
            radii *= plt_rad_mx / radius_mx

        def _place_circles_over(ctr, rng, ndx, fscale, xmult, j):
            rad = radii[ndx]
            if rad > 1e-4:
                col = 'r' if self.weights[ndx] < 0. else 'k'
                ax0.add_artist(plt.Circle((ctr, j), rad, color=col))

        self.__iterate_over_levels(_place_circles_over)

        # Panel 1: the function value
        fmx, fmn = self._plot_value(ax1, self.get_value, x)
        ax1.set_ylabel("Fn value")

        # Overlay each empirical set on the value panel.  The n-th set's bars
        # have height fmx * lrdn^n.
        if theta_times_x_es is not None and len(theta_times_x_es) > 0:
            n_sets = len(theta_times_x_es)
            lrdn = 0.8
            for n in range(n_sets):
                color = colors_ring[0]
                colors_ring.rotate(1)

                self._plot_place_funcvalue_dots(ax1, theta_times_x_es[n], self.get_value,
                                                color, fmx * pow(lrdn, n))

            # undo the rotations so repeated calls use the same colours
            colors_ring.rotate(-n_sets)

        if need_deriv:
            dmx = self._plot_value(ax2, self.get_deriv, x)
            ax2.set_ylabel("deriv value")

            smx = self._plot_value(ax3, self.get_second_deriv, x)
            ax3.set_ylabel("second deriv value")
            return (fmx, dmx, smx), fig1, axes
        else:
            return fmx, fig1, axes

