import logging

_LOGGER = logging.getLogger("util.bisection")


def is_point_in_interval(pt, ntvl_lims):
    """
    Test membership of 'pt' in the half-open interval [ntvl_lims[0], ntvl_lims[1]).

    The lower limit is included and the upper limit is excluded; a NaN point is
    never inside.
    """
    lower, upper = ntvl_lims[0], ntvl_lims[1]
    return lower <= pt < upper
        

def do_bisection_search(range_vals, tgt, func, is_incr=True, ftol=1e-9, xtol=1e-9):
    """
    Find, by bisection, the x in [min, max] at which a monotonic f(x) equals tgt.

    'range_vals' is a two-element sequence [min, max] with max >= min + xtol.
    'func' is ASSUMED monotonic on that interval: increasing when 'is_incr' is
    True (the default), decreasing when it is False.
    The search stops when a midpoint x satisfies |f(x) - tgt| <= ftol, when the
    bracket is no wider than xtol, or when the bracket can no longer be split
    in floating point. The midpoint of the final bracket is returned.
    Raises ValueError if max < min + xtol, or if tgt is not bracketed, i.e.
    f(min) > tgt or f(max) < tgt when increasing (reversed when decreasing).
    """
    sense = 1.0 if is_incr else -1.0  # multiply by sense to treat decreasing fns as increasing
    min_x, max_x = range_vals

    if max_x < min_x + xtol:
        errm = "diff betn maxx {:5.2e} and minx {:5.2e} = {:5.2e} < xtol {:5.2e}".format(
                max_x, min_x, max_x - min_x, xtol)
        _LOGGER.error(errm)
        raise ValueError(errm)

    # Evaluate each end point once; func may be expensive.
    f_min = func(min_x)
    if f_min * sense > tgt * sense:
        errm = "{}*increasing fn has f(min_x) = {:5.2e} > tgt {:5.2e}".format(
                sense, f_min, tgt)
        _LOGGER.error(errm)
        raise ValueError(errm)

    f_max = func(max_x)
    if f_max * sense < tgt * sense:
        errm = " fn sense {:6.2f} in domain [{:5.2e},{:5.2e}] range [{:.2e},{:5.2e}]".format(
                sense, min_x, max_x, f_min, f_max)
        errm += "\n{}*increasing fn has at max_x f({:5.2e}) = {:5.2e} < tgt {:5.2e}".format(
                sense, max_x, sense * f_max, sense * tgt)
        _LOGGER.error(errm)
        raise ValueError(errm)

    while max_x - min_x > xtol:
        x = (max_x + min_x) * 0.5
        if x <= min_x or x >= max_x:
            # min_x and max_x are adjacent floats: the bracket cannot shrink any
            # further. This happens when xtol is below the float spacing at this
            # magnitude (e.g. xtol=0, or |x| ~ 1e8 with xtol=1e-9). Stop here
            # instead of looping forever.
            break
        fx = func(x)

        if fx * sense > sense * tgt + ftol:
            max_x = x
        elif fx * sense < sense * tgt - ftol:
            min_x = x
        else:
            break  # f(x) is within ftol of tgt

    return (max_x + min_x) * 0.5

#from math import ceil
import numpy as np

def find_zeroes(range_vals, delta, func1, func2, ftol=1e-9, xtol=1e-9, rng=None):
    """
    Find the points x in [range_vals[0], range_vals[1]] where func1(x) == func2(x).

    The range is walked from left to right in consecutive sub-intervals. Each
    width is drawn as uniform[0, 1) * delta, and the last sub-interval is
    clipped at range_vals[1]. Whenever func1 - func2 strictly changes sign
    across a sub-interval, the crossing is refined with do_bisection_search,
    using ftol and xtol.

    Only strict sign changes are detected, so two limitations apply:
      - A sub-interval containing an even number of crossings reports none.
      - A crossing that lands exactly on a sub-interval end point is not
        reported.

    'rng' is an optional random source with a uniform() method (e.g. a
    numpy.random.Generator), useful to make the subdivision reproducible. By
    default numpy's global random state is used.

    Returns the list of crossings found, in left-to-right order.
    Raises ValueError if delta is not positive, since the walk would never
    advance.
    """
    if delta <= 0:
        raise ValueError("delta must be positive, got {}".format(delta))
    uniform = np.random.uniform if rng is None else rng.uniform

    def fdiff(x):
        return func1(x) - func2(x)

    lo_lim, hi_lim = range_vals[0], range_vals[1]
    retvals = []
    end, fhi = lo_lim, None
    while end < hi_lim:
        # set up the interval, reusing the difference already computed at its start
        stt = end
        flo = fdiff(stt) if fhi is None else fhi
        # clip so func1/func2 are never evaluated (or roots reported) beyond the range
        end = min(stt + uniform() * delta, hi_lim)
        fhi = fdiff(end)

        # search only if the difference strictly straddles zero
        if flo * fhi < 0.:
            if end < stt + xtol:
                # do_bisection_search rejects brackets narrower than xtol; such a
                # bracket already pins the crossing down to within xtol.
                ret = 0.5 * (stt + end)
            else:
                ret = do_bisection_search([stt, end], 0., fdiff, (flo < fhi), ftol, xtol)
            if _LOGGER.isEnabledFor(logging.DEBUG):
                _LOGGER.debug("Entered  [{:8.5f},{:8.5f}] and found {} to have value {} vs tgt {}".format(
                        stt, end, ret, func1(ret), func2(ret)))
            retvals.append(ret)
        elif _LOGGER.isEnabledFor(logging.DEBUG):
            _LOGGER.debug("Skipping [{:8.5f},{:8.5f}] because fdiff were {}, {} to have value on same side of tgt 0".format(
                    stt, end, flo, fhi))

    return retvals
        
    

def sort_all_by_index(to_sort, ndx=0):
    """
    Sort, in place, every column of 'to_sort' by the ascending order of column 'ndx'.

    'to_sort' is accessed like a mapping: to_sort[key] gives a column, and
    iterating it yields the keys. It is presumably a dict of equal-length
    sequences; confirm with callers. Every column is replaced by a numpy array
    permuted with np.argsort(to_sort[ndx]).
    If 'ndx' is not a key (KeyError), an error is logged and nothing is
    modified. Always returns None.
    """
    try:
        order = np.argsort(to_sort[ndx])
    except KeyError as err:
        _LOGGER.error("Can't seem to find key {} in {}\nEror: {}".format(ndx, to_sort, err))
        return None

    def reordered(column):
        # copy into an ndarray so fancy indexing also works for plain lists
        return np.array(column)[order]

    for key in to_sort:
        to_sort[key] = reordered(to_sort[key])



def get_running_average(to_avg, navg=2):
    """
    Return the moving average of 'to_avg' over windows of 'navg' consecutive points.

    Element i of the result is np.average(to_avg[i:i + navg]). The result is a
    float array of length len(to_avg) - navg + 1, which is empty when the input
    has fewer than navg points.

    The previous implementation sliced to_avg[n-(navg-1):n] and so averaged
    only navg-1 points. For navg=2 it returned the input values shifted rather
    than averaged.

    Raises ValueError if navg < 1.
    """
    if navg < 1:
        raise ValueError("navg must be >= 1, got {}".format(navg))

    leng = to_avg.shape[0] if isinstance(to_avg, np.ndarray) else len(to_avg)
    nout = max(leng - navg + 1, 0)  # no complete window -> empty result

    retval = np.zeros((nout,))
    for start in range(nout):
        retval[start] = np.average(to_avg[start:start + navg])

    return retval
    
    
