import time,re

## sys.version_info(major=3, minor=6, micro=8, releaselevel='final', serial=0)
#if sys.version_info.major >=3 and  sys.version_info.minor <=6 :
from numpy import remainder as remdr
#else:
#   from math import remainder as remdr

from util_py.arg_parsing import ArgumentRegistry

import logging

_LOGGER = logging.getLogger('optim.iterdata')

def static_vars(**kwargs):
    '''
    Decorator factory attaching every keyword argument to the decorated
    function as an attribute (a poor man's C-style static variable).

    The decorated function itself is returned unchanged apart from the
    added attributes.
    '''
    def decorate(func):
        for attr_name, initial in kwargs.items():
            setattr(func, attr_name, initial)
        return func
    return decorate

@static_vars(counter=-1)
def valueCounter():
    '''
    Return the next value of a module-level counter (0 on the first call,
    then 1, 2, ...). The state lives in the ``valueCounter.counter``
    function attribute.
    '''
    nxt = valueCounter.counter + 1
    valueCounter.counter = nxt
    return nxt


class ValueType:
    '''
    A single named quantity that can be reported.

    :param ndx: integer index of the value (rendered with ``{:d}`` in ``__str__``)
    :param nm: human readable description
    :param shnm: short name (handle) of the value
    :param fmt: format spec, without braces or colon, used to render ``value``

    ``value`` starts out as None; subclasses assign a typed value.
    '''

    def __init__(self, ndx, nm, shnm, fmt):
        self.index, self.description = ndx, nm
        self.short_name, self.format = shnm, fmt
        self.value = None

    def __str__(self):
        template = "Ndx({:d}) value({}) type ({}) Desc({}) handle ({}) fmt ({}) "
        current = self.value
        return template.format(self.index, current, type(current),
                               self.description, self.short_name, self.format)
        
class IntValueType(ValueType):
    '''
    Integer-valued ``ValueType``. The starting value (``int(defl)``) is also
    stored in ``init_value`` so the value can be reset to it.
    '''

    def __init__(self, ndx, nm, shnm, defl=0, fmt='6d'):
        super().__init__(ndx, nm, shnm, fmt)
        start = int(defl)
        self.value = start
        self.init_value = start

    @staticmethod
    def get_new(nm, shtnm, defl=0, fmt='6d'):
        '''Create an IntValueType whose index is drawn from ``valueCounter``.'''
        new_index = valueCounter()
        return IntValueType(new_index, nm, shtnm, defl, fmt)

class FloatValueType(ValueType):
    '''
    Floating-point ``ValueType``. The starting value (``float(defl)``) is also
    stored in ``init_value`` so the value can be reset to it.
    '''

    def __init__(self, ndx, nm, shnm, defl=0., fmt='12.9f'):
        super().__init__(ndx, nm, shnm, fmt)
        start = float(defl)
        self.value = start
        self.init_value = start

    @staticmethod
    def get_new(nm, shtnm, defl=0., fmt='12.9f'):
        '''Create a FloatValueType whose index is drawn from ``valueCounter``.'''
        new_index = valueCounter()
        return FloatValueType(new_index, nm, shtnm, defl, fmt)


class AbsAlgoData(object):
    r'''
    This class captures some basic info of algorithm data to print out every iteration.
    This can be subclassed to allow for more info gathering.

    Expected call order: ``initiate_output`` once, ``output_current`` once per
    iteration, then ``terminate_output``. Iterations 1, 1 + ``output_cadence``,
    1 + 2 * ``output_cadence``, ... are reported (a cadence of 0 reports every
    iteration). A report writes one row of values to the attached csv writer
    (if any) and one line to the module logger at INFO level. A last iteration
    skipped because of the cadence is reported by ``terminate_output``.
    '''
    # Name and default of the argument controlling the reporting cadence.
    argname_skipwriting = 'output_cadence'
    argdefval_skipwriting = 5

    @staticmethod
    def get_arg_registry():
        '''
        Return an arguments registry with all the required inputs
        '''
        arg_reg = ArgumentRegistry('algodata')

        arg_reg.register_int_arg(AbsAlgoData.argname_skipwriting,
                                 'the number of iterations to skip before writing out current iterate to output csv file',
                                 AbsAlgoData.argdefval_skipwriting)

        return arg_reg

    def __init__(self, arg_dict=None, ocad=5):
        '''
        :param arg_dict: parsed arguments; when given, the reporting cadence is read
            from it under the full name returned by the registry of
            :meth:`get_arg_registry`.
        :param ocad: reporting cadence used when ``arg_dict`` is None.
        '''
        if arg_dict is not None:
            arg_reg = self.get_arg_registry()
            self.output_cadence = arg_dict[arg_reg.get_arg_fullname(self.argname_skipwriting)]
        else:
            self.output_cadence = ocad

        self.n_itr     = IntValueType.get_new('Iter (count)', 'n_itr')
        self.n_samp    = IntValueType.get_new('Samples (number)', 'n_samp')
        self.c_samp    = IntValueType.get_new('CumSamples (number)', 'c_samp')
        self.n_time    = FloatValueType.get_new('Wallclock Time (secs)', 'n_time')
        self.c_time    = FloatValueType.get_new('Wallclock Cum Time (secs)', 'c_time')
        self.objval    = FloatValueType.get_new('Objfcn Value (number)', 'objval')
        self.stepsize  = FloatValueType.get_new('Step length (number)', 'stepsize')
        self.dirngrad  = FloatValueType.get_new('dirn^t grad (number)', 'dirngrad')

        # Keyed by each value's index; insertion order fixes the column order
        # of both the csv rows and the log lines.
        self.data_values = {vl.index: vl for vl in (self.n_itr, self.n_samp, self.c_samp,
                                                     self.n_time, self.c_time, self.objval,
                                                     self.stepsize, self.dirngrad)}

        # user can switch off initialization if needed.
        self.shd_initialize = True

        self.output_to_csvfile = False
        self.csv_writer = None

    def attach_csvwriter(self, csvwrtr):
        '''Send reports to ``csvwrtr`` (anything with ``writerow``); None is ignored.'''
        if csvwrtr is not None:
            self.output_to_csvfile = True
            self.csv_writer = csvwrtr

    def detach_csvwriter(self):
        '''Stop sending reports to a csv writer.'''
        self.output_to_csvfile = False
        self.csv_writer = None

    def initiate_output(self, optim_object):
        '''
        Prepare for reporting on ``optim_object``, which must expose ``objective``.

        Starts the wallclock timer. When ``shd_initialize`` is True, every value
        is reset to its ``init_value`` and a header of short names is written to
        the csv writer (if attached) and to the log.
        '''
        self.optimizer = optim_object
        self.objective = self.optimizer.objective

        self.st_time = time.time()

        # user can switch off initialization if needed from outside
        if not self.shd_initialize:
            return

        for vl in self.data_values.values():
            vl.value = vl.init_value

        if self.output_to_csvfile:
            self.csv_writer.writerow([vl.short_name for vl in self.data_values.values()])

        if _LOGGER.isEnabledFor(logging.INFO):
            _LOGGER.info(''.join(self._header_cell(vl) for vl in self.data_values.values()))

    @staticmethod
    def _header_cell(vl):
        '''Render ``vl.short_name`` padded to the width that leads ``vl.format``.'''
        # only the first number of the format string (the field width) is needed
        width = re.match(r'\d*', vl.format).group()
        return ' {{:{}s}} ,'.format(width).format(vl.short_name)

    def _write_output(self):
        '''
        Report the current values: one csv row (if a writer is attached) and one
        INFO log line.

        :raises ValueError, TypeError: if a value cannot be rendered with its
            format; the offending value is logged at ERROR level first.
        '''
        values = list(self.data_values.values())

        if self.output_to_csvfile:
            self.csv_writer.writerow([vl.value for vl in values])

        if _LOGGER.isEnabledFor(logging.INFO):
            cells = []
            for vl in values:
                try:
                    cells.append(' {{:{}}} ,'.format(vl.format).format(vl.value))
                except (ValueError, TypeError):
                    # name the culprit through the module logger, keep the original traceback
                    _LOGGER.error("Value %s doesnt like formatting", vl)
                    raise
            _LOGGER.info(''.join(cells))

    def _cadence_rem(self):
        '''
        Return the distance of the current iteration from a reporting iteration;
        0 means the current iteration is reported.
        '''
        if self.output_cadence == 0:
            # a remainder by zero would warn (int) or give nan (float) and silently
            # disable reporting; treat a zero cadence as "report every iteration"
            return 0
        return abs(remdr(self.n_itr.value - 1, self.output_cadence))

    def output_current(self, nsm, obv, stz, dirn):
        r'''
        Overload this to provide more output! Use the obj fcn object to do more
        sleuthing if needed. This is useful to do gathering for the extra data
        values that were added in the initialization of this object.

        :param nsm: number of samples of this iteration (added to the cumulative count)
        :param obv: objective function value
        :param stz: step length
        :param dirn: direction; its dot product with the gradient is obtained from
            ``self.objective.get_gradient_dot(dirn)``
        '''
        self.n_itr.value += 1
        self.n_samp.value = nsm
        self.c_samp.value += self.n_samp.value

        tm = time.time()
        self.n_time.value = tm - self.st_time
        self.c_time.value += self.n_time.value
        self.st_time = tm

        self.objval.value = obv
        self.stepsize.value = stz
        self.dirngrad.value = self.objective.get_gradient_dot(dirn)

        # we report only at this cadence
        if self._cadence_rem() <= 0.0:
            self._write_output()

    def terminate_output(self):
        '''
        Report the last iteration if the cadence skipped it, then drop the
        references to the optimizer and objective.

        Nothing is reported when no iteration was recorded.
        '''
        if self.n_itr.value > 0 and self._cadence_rem() > 0.0:
            self._write_output()

        # for added measure
        self.optimizer = None
        self.objective = None
