import logging

import numpy as np

from optim_py.sampling import AbsSampler

# Module logger; the explicit name groups this module's output under 'wdro'.
_LOGGER = logging.getLogger('wdro.pstar')


class GaussMix(AbsSampler):
    """Sampler for a finite mixture of axis-aligned (diagonal) Gaussians.

    Component ``h`` is chosen with probability ``mixp[h]``; it has mean
    ``means[h]`` and independent coordinates with standard deviations
    ``stdevs[h]``, i.e. covariance ``diag(stdevs[h] ** 2)``.

    Parameters
    ----------
    mixp : array_like, shape (n_mxs,)
        Mixture weights. They are passed to ``np.random.choice`` as ``p`` when
        sampling, so they must be non-negative and sum to 1 at that point.
    means : array_like, shape (n_mxs, dimen)
        Mean vector of each component.
    stdevs : array_like, shape (n_mxs, dimen)
        Per-coordinate standard deviations (not variances) of each component.

    Raises
    ------
    ValueError
        If the shapes of ``mixp``, ``means`` and ``stdevs`` are inconsistent,
        or if any standard deviation is negative.
    """

    def __init__(self, mixp, means, stdevs):
        mixp = np.asarray(mixp, dtype=float)
        means = np.asarray(means, dtype=float)
        stdevs = np.asarray(stdevs, dtype=float)

        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if means.ndim != 2:
            raise ValueError(
                "means must be 2-D (n_mxs, dimen), got shape {}".format(means.shape))
        if mixp.shape != (means.shape[0],):
            raise ValueError(
                "mixp must have one weight per component: expected shape {}, got {}"
                .format((means.shape[0],), mixp.shape))
        if stdevs.shape != means.shape:
            raise ValueError(
                "stdevs must have the same shape as means: expected {}, got {}"
                .format(means.shape, stdevs.shape))
        if np.any(stdevs < 0):
            raise ValueError("stdevs must be non-negative")

        # NOTE(review): AbsSampler.__init__ is deliberately not called (the
        # original call was commented out); confirm whether the base class
        # requires initialisation.

        self.dimen = means.shape[1]

        self.mixture = mixp
        self.n_mxs = len(self.mixture)

        self.means = means
        self.stdevs = stdevs
        # Covariance matrices of the components: variances (stdev squared) on
        # the diagonal. Kept as a list of matrices for callers that read it.
        self.covars = [np.diag(self.stdevs[h] ** 2) for h in range(self.n_mxs)]

    def next_sample(self, n_samp):
        """Draw ``n_samp`` i.i.d. points from the mixture.

        Parameters
        ----------
        n_samp : int
            Number of samples to draw (0 yields an empty array).

        Returns
        -------
        numpy.ndarray, shape (n_samp, dimen)
        """
        # First pick the mixture component responsible for each sample.
        comps = np.random.choice(self.n_mxs, size=n_samp, replace=True, p=self.mixture)

        # Covariances are diagonal, so coordinates are independent:
        # x = mu + sigma * z with z ~ N(0, I), done for all samples at once.
        noise = np.random.standard_normal((n_samp, self.dimen))
        return self.means[comps] + self.stdevs[comps] * noise

    def get_range_of_lincomb(self, thetas, n_std=3.0):
        """Return an approximate range of ``theta . x`` for x drawn from the mixture.

        Along a direction ``theta`` component ``h`` projects to a Gaussian with
        mean ``theta . means[h]`` and standard deviation
        ``sqrt(sum_i theta_i**2 * stdevs[h, i]**2)``. The returned interval is
        the smallest one containing ``mean +/- n_std * std`` of every component.

        Parameters
        ----------
        thetas : array_like, shape (n_thetas, dimen) or (dimen,)
            Directions to project on; a 1-D array is treated as one direction.
        n_std : float, default 3.0
            Half-width of each component's interval in standard deviations.

        Returns
        -------
        numpy.ndarray, shape (n_thetas, 2)
            Column 0 holds the lower bound and column 1 the upper bound.
        """
        thetas = np.asarray(thetas, dtype=float)
        if thetas.ndim == 1:
            thetas = thetas[np.newaxis, :]
        n_thetas = thetas.shape[0]

        # (n_thetas, n_mxs): mean of each component's projection.
        proj_means = thetas @ self.means.T
        # (n_thetas, n_mxs): standard deviation of each component's projection.
        proj_stds = np.sqrt((thetas ** 2) @ (self.stdevs ** 2).T)

        # Take the extreme bound over all components, not only the ones with
        # extreme means, so a wide component is never cut off.
        retval = np.zeros((n_thetas, 2))
        retval[:, 0] = (proj_means - n_std * proj_stds).min(axis=1)
        retval[:, 1] = (proj_means + n_std * proj_stds).max(axis=1)

        if _LOGGER.isEnabledFor(logging.INFO):
            for n in range(n_thetas):
                _LOGGER.info("means along %d theta: %s", n, proj_means[n])
                _LOGGER.info("stdevs along %d theta: %s", n, proj_stds[n])
                _LOGGER.info("range along %d theta: %s", n, retval[n])

        return retval
    
