# Wrap the CANN simulation in a function: takes alpha as input and returns the network
# activity r (net.r) and the bump-to-target distance at every time step after the warm-up.

import sys
sys.path.append('..')
import numpy as np
from CANN_1D import CANN_v2
from scipy import ndimage

def CANN(alpha, fano=0.005, max_steps=None):
    """Simulate a 1-D ring CANN until its activity bump reaches a target position.

    The network (``CANN_v2`` with N = 512 units on [-pi, pi)) is first driven for
    ``set_time`` = 100 steps by a weak Gaussian input that places the bump, then by
    ``alpha * In`` (a Gaussian centred at 0) plus activity-dependent noise. The
    simulation stops once the bump centre is within ``0.004 * 2*pi`` of the target.

    Parameters
    ----------
    alpha : float
        Amplitude of the external input applied after the warm-up. A negative
        value selects the "reverse" protocol: the warm-up input is the Gaussian
        shifted by 2 units and the target is the peak of the Gaussian shifted by
        N/5 units. Otherwise the warm-up input is the Gaussian shifted by N/5
        units and the target is the peak of the unshifted Gaussian.
    fano : float, optional
        Scale of the noise variance; the noise added to the input has standard
        deviation ``sqrt(tau * U * fano)`` per unit.
    max_steps : int or None, optional
        Upper bound on the total number of update steps (warm-up included).
        ``None`` (default) keeps the original behaviour of looping until the
        bump reaches the target, which may never happen.

    Returns
    -------
    r_record : numpy.ndarray
        Copies of ``net.r`` taken after every step, with the first 100 (warm-up)
        steps dropped. Empty if the loop ended during the warm-up.
    dis_record : numpy.ndarray
        Distance between bump centre and target at the same steps.
    """
    reverse = alpha < 0
    N = 512
    tau = 1
    trans = True

    a = 0.5
    cor = np.arange(-np.pi, np.pi, 2*np.pi/N)
    J0 = 1
    rho = N / (2*np.pi)

    kc = np.pi*a*rho*J0**2/(4*np.sqrt(2*np.pi))
    k = 0.04*kc

    exp = np.exp(-(cor)**2/(4*a**2))
    Jexp = np.exp(-(cor)**2/(2*a**2))
    J = J0 * Jexp
    # Move the kernel peak from the middle of the array to index 0.
    J = np.roll(J, shift=int(N/2))

    dt = 0.1

    net = CANN_v2(N, k, J, tau, trans)

    In = exp
    shift = int(N/5)
    I0 = np.roll(In, shift=shift)
    Iinit = np.roll(exp, shift=2)
    if reverse:
        target = cor[I0.argmax()]
    else:
        target = cor[In.argmax()]

    set_time = 100

    threshold = 0.004
    stop_distance = threshold*2*np.pi
    dis = 10  # larger than any distance on the ring, so the loop always starts

    # Both deterministic inputs are loop-invariant; build them once.
    warmup_input = 0.02 * Iinit if reverse else 0.02 * I0
    drive_input = alpha * In

    r_record = []
    dis_record = []

    while dis > stop_distance:
        step = len(r_record)
        if max_steps is not None and step >= max_steps:
            break

        I_ext = warmup_input if step < set_time else drive_input

        # Clamp the variance at zero: a negative net.U would otherwise give NaN
        # noise, which poisons the state and silently ends the loop (nan > x is False).
        noise_std = np.sqrt(tau*np.maximum(net.U, 0)*fano)
        I_ext = I_ext + noise_std*np.random.randn(N)
        net.update(I_ext, dt)

        # Coarse peak from a smoothed profile; the domain is a ring, so wrap the filter.
        smooth_r = ndimage.gaussian_filter1d(net.r, sigma=10, mode='wrap')
        smooth_center = cor[smooth_r.argmax()]
        # Refine with the activity-weighted mean of wrapped offsets around the coarse peak.
        center = np.angle(np.exp(1j*(cor-smooth_center))) @ net.r / net.r.sum() + smooth_center
        # Distance measured on the ring, so positions equal modulo 2*pi coincide.
        dis = np.abs(np.angle(np.exp(1j*(center-target))))

        # Store a copy in case update() modifies net.r in place.
        r_record.append(np.copy(net.r))
        dis_record.append(dis)

    r_record = np.array(r_record)[set_time:]
    dis_record = np.array(dis_record)[set_time:]
    return r_record, dis_record



