import numpy as np
from CANN_1D import CANN_v2
from scipy import special
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy import ndimage
import os

# Number of principal components kept for the population-activity PCA.
num_PC = 4
pca = PCA(n_components=num_PC,copy=True)

# Embed fonts as TrueType (Type 42) in PDF/PS output so text stays editable
# in Illustrator.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
# Figure style. Matplotlib 3.6 renamed its bundled seaborn styles to
# 'seaborn-v0_8-*' and later releases dropped the old names, so use the new
# name when it is available and fall back to the legacy name otherwise.
_style_name = 'seaborn-v0_8-white'
if _style_name not in plt.style.available:
    _style_name = 'seaborn-white'
plt.style.use(_style_name)

# set no margin for the figure
plt.rcParams['axes.xmargin'] = 0
plt.rcParams['axes.ymargin'] = 0

# set ticks as visible
plt.rcParams['xtick.bottom'] = True
plt.rcParams['ytick.left'] = True



# Output directory for all figures and saved arrays; created if missing
# (exist_ok avoids the check-then-create race of a separate exists() test).
figure_path = './figures/NeurIPS_Figures/0'
os.makedirs(figure_path, exist_ok=True)

# default figure size (width, height) in inches
fig_width = 8
fig_height = 8
fig_size = [fig_width,fig_height]


def Au(k,a=0.5,J0=1,rho=512/(2*np.pi)):
    """Return the larger root A of ``c*A**2 - b*A + 1 = 0``.

    With ``b = sqrt(pi) * a * rho * J0`` and ``c = sqrt(2*pi) * a * rho * k``
    the root is ``(b + sqrt(b**2 - 4*c)) / (2*c)``.  The script uses it as the
    amplitude factor in the theoretical interval curve.

    Args:
        k: network parameter k (scalar or numpy array; evaluated elementwise).
        a: width parameter ``a`` (as in the script's Gaussian profiles).
        J0: coupling strength ``J0``.
        rho: neuron density; the default corresponds to 512 neurons on 2*pi.

    Returns:
        The root.  It is real only while ``b**2 >= 4*c``, i.e. for
        ``k <= pi*a*rho*J0**2 / (4*sqrt(2*pi))``; beyond that ``np.sqrt``
        yields nan (with a RuntimeWarning).  ``k == 0`` divides by zero.
    """
    root_2pi = np.sqrt(2*np.pi)
    # b: coefficient of the linear term
    linear_coeff = np.sqrt(np.pi) * a * rho * J0
    # b**2 - 4*c
    discriminant = np.pi * a**2 * rho**2 *J0**2 - 4 * k * rho * root_2pi * a
    # 2*c
    denominator = 2 * k * rho * root_2pi * a
    return (linear_coeff + np.sqrt(discriminant)) / denominator

# network parameters
N = 512
tau = 1
trans = True

a = 0.5
# Preferred positions of the N neurons, evenly spaced on [-pi, pi).
# linspace(endpoint=False) always returns exactly N points; np.arange with a
# float step may return N+1 for some N, which would misalign every
# length-N array below (inputs, kernel, noise).
cor = np.linspace(-np.pi, np.pi, N, endpoint=False)
J0 = 1
rho = N / (2*np.pi)

# k is set as a fraction of kc; Au(k) in this file is real only for k <= kc
kc = np.pi*a*rho*J0**2/(4*np.sqrt(2*np.pi))
print('kc:',kc)
k = 0.04*kc

# Gaussian input profile (peak at cor = 0) and coupling kernel; the kernel
# is rolled by N//2 so that its peak sits at index 0.
exp = np.exp(-(cor)**2/(4*a**2))
Jexp = np.exp(-(cor)**2/(2*a**2))
J = J0 * Jexp
J = np.roll(J,shift=N//2)

# simulation length and step; round() guards against T/dt landing just
# below an integer because of floating-point representation
T = 2000
dt = 0.1
frames = int(round(T/dt))

net = CANN_v2(N,k,J,tau,trans)

# Input-strength sweep: alpha = alpha_0 * factor for factors
# +/-0.2, +/-0.4, ..., +/-3.4 (negative factors first, then positive ones).
alpha_0 = 0.04
factors_1 = np.arange(0.2,3.5,0.2)
factors_2 = -factors_1
factors = np.concatenate((factors_2,factors_1))
Alphas = alpha_0 * factors
# offset (in neurons) between the input profile In and the initial cue I0
shift = int(N/5)

In = exp
I0 = np.roll(In,shift=shift)
start = cor[np.argmax(I0)]

center = cor[I0.argmax()]

# arrival threshold, as a fraction of the full circle (compared against thres*2*pi)
thres = 0.004


trials = Alphas.shape[0]

# number of initial steps during which the weak cue is applied
set_time = 100

# scales the noise variance in the simulation (variance = tau * U * fano)
fano = 0.001

# noisy repetitions per alpha
sample_times = 10

# One colour per trial from a diverging colormap: the lower half of the
# colormap is assigned, reversed, to the negative factors (so small |alpha|
# sits near the middle) and the upper half to the positive factors.
cmap = plt.cm.Spectral
colors = cmap(np.linspace(0,1,trials))
half = trials // 2
colors = np.concatenate([colors[:half][::-1],colors[half:]])
# marker size for scatter plots
dot_size = 200

plot_inter = 3
# reference trial (factor +1) and the positive trials to highlight: every
# plot_inter-th trial starting one step below the reference
ref_idx = np.where(np.abs(factors-1)<1e-4)[0][0]
print('ref_idx',ref_idx)
plot_index_pos = np.arange(ref_idx-plot_inter,trials,plot_inter)

# same selection for the negative half, anchored at factor -1
neg_ref_idx = np.where(np.abs(factors+1)<1e-4)[0][0]
print('neg_ref_idx',neg_ref_idx)
plot_index_neg = np.arange(neg_ref_idx-plot_inter,half,plot_inter)

plot_idx = np.concatenate([plot_index_neg,plot_index_pos])


if True:
    # Simulation sweep: for each alpha in Alphas, run `sample_times` noisy
    # repetitions of `frames` steps, recording the full firing-rate history,
    # the time for the bump to reach its target, and its final position.
    # NOTE(review): FR_profile is float64 of shape
    # (trials, sample_times, frames, N); with the constants above that is
    # 34 x 10 x 20000 x 512 values, roughly 28 GB.
    FR_profile = np.zeros([trials,sample_times,frames,N])
    Interval = np.zeros([trials,sample_times])
    final_center = np.zeros([trials,sample_times])

    for trial in range(trials):
        alpha = Alphas[trial]
        print('alpha:', alpha)
        if alpha > 0:
            # cue the bump at I0's peak; the target is In's peak
            I_init = I0
            target = cor[In.argmax()]
        else:
            # cue the bump 2 neurons from In's peak; the target is I0's peak
            I_init = np.roll(In,shift=2)
            target = cor[I0.argmax()]
        for samp in range(sample_times):
            net.reset()

            # stays `frames` if the target is never reached
            final_t = frames
            for t in range(frames):

                if t < set_time:
                    # weak cue during the first set_time steps
                    I_ext = 0.02 * I_init

                else:
                    I_ext = alpha * In

                # Additive Gaussian noise with variance tau*U*fano. U is
                # clipped at 0 so the std stays real (a negative U would give
                # NaN); whether CANN_v2 ever produces negative U is not
                # visible here.
                noise_std = np.sqrt(tau*np.maximum(net.U, 0)*fano)
                I_ext = I_ext + noise_std*np.random.randn(N)

                net.update(I_ext,dt)

                # Bump position: coarse estimate from the argmax of the
                # smoothed rate, refined by the rate-weighted mean of the
                # wrapped offsets from that estimate.
                smooth_r = ndimage.gaussian_filter1d(net.r,sigma=10)
                smooth_center = cor[smooth_r.argmax()]
                center = np.angle(np.exp(1j*(cor-smooth_center))) @ net.r / net.r.sum() + smooth_center
                # circular distance to the target, wrapped into [0, pi], so
                # positions on either side of +/-pi compare correctly
                distance = np.abs(np.angle(np.exp(1j*(center-target))))
                # remember the first step at which the bump is close enough
                if distance < thres*2*np.pi and t < final_t:
                    final_t = t

                FR_profile[trial, samp, t] = net.r

            # arrival time measured from the end of the cue, in time units of dt
            Interval[trial, samp] = (final_t-set_time)*dt
            final_center[trial, samp] = center

    # Save the sample-averaged firing rates of the trials in plot_idx,
    # shape (len(plot_idx), frames, N). Averaging one trial at a time avoids
    # first materialising FR_profile[plot_idx] with all samples.
    FR_plot = np.zeros([len(plot_idx), frames, N])
    for row, idx in enumerate(plot_idx):
        FR_plot[row] = FR_profile[idx].mean(axis=0)
    np.save(figure_path+'/FR_plot.npy',FR_plot)

    negative_frames = frames

    # Theoretical interval T(alpha) = -tau * Au / (2 alpha) * Ei, where Ei is
    # the difference of exponential integrals evaluated at the threshold
    # distance and at the start position.
    factor_plot = np.linspace(factors_1.min(),factors_1.max(),100)
    Alpha_plot = alpha_0 * factor_plot
    Ei = (special.expi((thres*2*np.pi)**2/(8*a**2)) - special.expi(start**2/(8*a**2)))
    # pass the network parameters explicitly so the theory follows a, J0, N
    amplitude = Au(k=k,a=a,J0=J0,rho=rho)

    Ts = -tau * amplitude / (2*Alpha_plot) * Ei


    Theory_T = -tau * amplitude / (2*Alphas) * Ei

    # Figure: mean interval vs relative input, simulation and theory
    plt.figure(figsize=fig_size)
    plt.tick_params(axis='both',which='both',direction='out',length=4,width=1)
    plt.scatter(factors,Interval.mean(axis=1),marker='o',facecolors='white',label='Simulation',edgecolors='k',s=dot_size)
    plt.plot(factor_plot,Ts,'g',label='Theory')
    plt.plot(-factor_plot,Ts,'g')

    # highlight the trials selected for the later single-neuron / PC plots
    for i in range(plot_index_pos.shape[0]):
        plt.scatter(factors[plot_index_pos[i]],Interval[plot_index_pos[i]].mean(),color=colors[plot_index_pos[i]],edgecolors='k',label='$\\alpha$ = %.2f'%factors[plot_index_pos[i]],s=dot_size)
    for i in range(plot_index_neg.shape[0]):
        plt.scatter(factors[plot_index_neg[i]],Interval[plot_index_neg[i]].mean(),color=colors[plot_index_neg[i]],edgecolors='k',label='$\\alpha$ = %.2f'%factors[plot_index_neg[i]],s=dot_size)
    plt.legend()
    plt.xlabel('Relative input $\\alpha$')
    plt.ylabel('Time interval (ms)')
    #save fig in pdf form under figure_path
    plt.savefig(figure_path+'/Interval.pdf',format='pdf',dpi=1000)

    # Figure: reference interval divided by each trial's mean interval; the
    # negative-alpha half is sign-flipped so it lines up with alpha
    ref_interval = Interval[ref_idx].mean()
    real_ratio = ref_interval/Interval.mean(axis=1)
    real_ratio = np.append(-real_ratio[:int(trials/2)],real_ratio[int(trials/2):])
    plt.figure(figsize=fig_size)
    plt.tick_params(axis='both',which='both',direction='out',length=4,width=1)
    plt.scatter(factors,real_ratio,marker='o',facecolors='white',edgecolors='k',s=dot_size)
    plt.xlabel('Relative input $\\alpha$')
    plt.ylabel('Ratio of time interval $\\alpha$\'')
    # identity reference line
    plt.plot([factors_2[-1],factors_1[-1]],[factors_2[-1],factors_1[-1]],'k--')
    # mark the same highlighted trials as in the previous figure
    for i in range(plot_index_pos.shape[0]):
        plt.scatter(factors[plot_index_pos[i]],real_ratio[plot_index_pos[i]],color=colors[plot_index_pos[i]],edgecolors='k',s=dot_size,label='$\\alpha$ = %.2f'%factors[plot_index_pos[i]])
    for i in range(plot_index_neg.shape[0]):
        plt.scatter(factors[plot_index_neg[i]],real_ratio[plot_index_neg[i]],color=colors[plot_index_neg[i]],edgecolors='k',s=dot_size,label='$\\alpha$ = %.2f'%factors[plot_index_neg[i]])
    plt.legend()
    plt.savefig(figure_path+'/Ratio.pdf',format='pdf',dpi=1000)

# Mean arrival interval per trial, averaged over the noisy repetitions.
Interval_mean = Interval.mean(axis=1)

# Firing-rate time course of one neuron (index 3N/5) for each highlighted
# trial, averaged over the repetitions. Indexing with an index array on the
# trial axis and a scalar on the neuron axis yields (n_trials, samples, frames);
# averaging axis 1 leaves one (frames,) trace per highlighted trial.
neuro_index = int(3*N/5)
Curves = FR_profile[plot_index_pos, :, :, neuro_index].mean(axis=1)
Curves_neg = FR_profile[plot_index_neg, :, :, neuro_index].mean(axis=1)

# Last step of the analysis window: set_time plus the longest mean interval
# converted to steps.
end_time = set_time + int((Interval_mean).max()/dt)

# Positive-alpha single-neuron traces, aligned by rescaling time so each
# trace's peak lands on the latest peak.
# Peak time of each trace (in steps after set_time), located on a copy
# smoothed with a Gaussian of sigma = 20 steps to suppress noise.
smoothed_Curves = ndimage.gaussian_filter1d(Curves,sigma=20,axis=1)
peak_time = (smoothed_Curves.argmax(axis=1) - set_time).astype(int)
print(peak_time)
align_peak = peak_time.max()
# plot_index_pos begins plot_inter below ref_idx, so entry 1 is the reference trial
ref_peak = peak_time[1]

# Window: three times the latest peak after set_time, clipped to the simulated
# frames so the time axis and the sliced trace always have the same length.
end_time_plot = min(align_peak*3 + set_time, frames)

plt.figure(figsize=[8,4])
plt.tick_params(axis='both',which='both',direction='out',length=4,width=1)
for i in range(Curves.shape[0]):
    plt.plot(np.arange(set_time,end_time_plot)*dt,Curves[i,set_time:end_time_plot],color=colors[plot_index_pos[i]])

plt.xlabel('t (ms)')
plt.ylabel('firing rate')
plt.savefig(figure_path+'/neuron_{}.pdf'.format(neuro_index),format='pdf',dpi=1000)



plt.figure(figsize=[8,4])
plt.tick_params(axis='both',which='both',direction='out',length=4,width=1)
# NOTE(review): assumes every peak_time is positive (peak after set_time); a
# zero peak time would divide by zero below.
for i in range(Curves.shape[0]):
    # earlier peaks get proportionally shorter windows so all rescaled traces
    # cover the same aligned range
    end_time_i = int((end_time_plot-set_time) / align_peak * peak_time[i]) + set_time
    # keep the peak sample inside the window, and the window inside the run
    end_time_i = min(max(end_time_i, set_time + peak_time[i] + 1), frames)
    plot_time = np.arange(set_time,end_time_i)
    # stretch time so this trace's peak maps onto align_peak
    plot_time_i = (plot_time - set_time) / peak_time[i] * align_peak + set_time
    print('plottime',plot_time_i[0],plot_time_i[-1])
    plt.plot(plot_time_i*dt,Curves[i,set_time:end_time_i],color=colors[plot_index_pos[i]],label='$\\alpha\'$ = %.2f'%(ref_peak/peak_time[i]))
    # mark the peak with a triangle drawn on top of the lines
    plt.scatter(plot_time_i[peak_time[i]]*dt,Curves[i,plot_time[peak_time[i]]],marker='^',color=colors[plot_index_pos[i]],edgecolors='k',s=dot_size,zorder=10)

plt.legend()
plt.xlabel('aligned t')
plt.ylabel('aligned firing rate')
plt.savefig(figure_path+'/neuron_{}_aligned.pdf'.format(neuro_index),format='pdf',dpi=1000)

# Negative-alpha single-neuron traces, aligned by rescaling time so each
# trace's peak lands on the latest peak.
# Peak time of each trace (in steps after set_time), located on a copy
# smoothed with a Gaussian of sigma = 20 steps to suppress noise.
smoothed_Curves_neg = ndimage.gaussian_filter1d(Curves_neg,sigma=20,axis=1)
peak_time_neg = (smoothed_Curves_neg.argmax(axis=1) - set_time).astype(int)
print(peak_time_neg)
align_peak_neg = peak_time_neg.max()
# plot_index_neg begins plot_inter below neg_ref_idx, so entry 1 is the alpha = -1 trial
ref_peak_neg = peak_time_neg[1]

# Window: 1.5 times the latest peak after set_time, clipped to the simulated
# frames so the time axis and the sliced trace always have the same length.
end_time_plot_neg = min(int(align_peak_neg*1.5 + set_time), frames)

plt.figure(figsize=[8,4])
plt.tick_params(axis='both',which='both',direction='out',length=4,width=1)
for i in range(Curves_neg.shape[0]):
    plt.plot(np.arange(set_time,end_time_plot_neg)*dt,Curves_neg[i,set_time:end_time_plot_neg],color=colors[plot_index_neg[i]])

plt.xlabel('t (ms)')
plt.ylabel('firing rate')
plt.savefig(figure_path+'/neuron_{}_neg.pdf'.format(neuro_index),format='pdf',dpi=1000)



plt.figure(figsize=[8,4])
plt.tick_params(axis='both',which='both',direction='out',length=4,width=1)
# NOTE(review): assumes every peak_time_neg is positive (peak after
# set_time); a zero peak time would divide by zero below.
for i in range(Curves_neg.shape[0]):
    # earlier peaks get proportionally shorter windows so all rescaled traces
    # cover the same aligned range
    end_time_i = int((end_time_plot_neg-set_time) / align_peak_neg * peak_time_neg[i]) + set_time
    # keep the peak sample inside the window, and the window inside the run
    end_time_i = min(max(end_time_i, set_time + peak_time_neg[i] + 1), frames)
    plot_time = np.arange(set_time,end_time_i)
    # stretch time so this trace's peak maps onto align_peak_neg
    plot_time_i = (plot_time - set_time) / peak_time_neg[i] * align_peak_neg + set_time
    print('plottime',plot_time_i[0],plot_time_i[-1])
    plt.plot(plot_time_i*dt,Curves_neg[i,set_time:end_time_i],color=colors[plot_index_neg[i]],label='$\\alpha\'$ = %.2f'%(ref_peak_neg/peak_time_neg[i]))
    # mark the peak with a triangle drawn on top of the lines
    plt.scatter(plot_time_i[peak_time_neg[i]]*dt,Curves_neg[i,plot_time[peak_time_neg[i]]],marker='^',color=colors[plot_index_neg[i]],edgecolors='k',s=dot_size,zorder=10)

plt.legend()
plt.xlabel('aligned t')
plt.ylabel('aligned firing rate')
plt.savefig(figure_path+'/neuron_{}_aligned_neg.pdf'.format(neuro_index),format='pdf',dpi=1000)


# PCA of population activity.
# Fit data: the reference trial restricted to [set_time, end_time), shape
# (sample_times, T, N). Averaging over the repetitions (axis 0) leaves a
# (T, N) matrix whose rows are time steps, matching the sample-averaged
# activity that is projected below.
# NOTE(review): PCA needs at least num_PC rows, i.e. end_time - set_time >= num_PC.
FR_profile_PCA_fit = FR_profile[ref_idx,:,set_time:end_time,:]
PCs_across_trials = np.zeros([frames,trials,num_PC])
pca.fit(FR_profile_PCA_fit.mean(axis=0))
# Project every trial's sample-averaged activity, one time step at a time
# (avoids building a (trials, frames, N) intermediate).
for t in range(frames):
    PCs_across_trials[t] = pca.transform(FR_profile[:,:,t,:].mean(axis=1))

# 3D plot of PC1-PC3 trajectories for the highlighted trials.
plot_idx = np.concatenate([plot_index_neg,plot_index_pos])
PC_curves = PCs_across_trials[:,plot_idx,:]
num_curves = PC_curves.shape[1]


ax = plt.figure(figsize=fig_size).add_subplot(projection='3d')
print(num_curves)
# step (after set_time) at which an extra diamond marker is drawn
probe = 200
# legend entries are attached to the first curve that actually draws them
endpoints_labeled = False
probe_labeled = False
for i in range(num_curves):
    # trajectory from set_time to the end of this trial's mean interval
    end = int(Interval_mean[plot_idx[i]]/dt)
    pcc = PC_curves[set_time:set_time+end,i,:]
    if pcc.shape[0] == 0:
        # non-positive mean interval: nothing to draw
        continue
    ax.plot(pcc[:,0],pcc[:,1],pcc[:,2],color=colors[plot_idx[i]])
    ax.scatter(pcc[0,0],pcc[0,1],pcc[0,2],color='g',label=None if endpoints_labeled else 'start',s=dot_size,edgecolors='k')
    ax.scatter(pcc[-1,0],pcc[-1,1],pcc[-1,2],color='r',label=None if endpoints_labeled else 'end',s=dot_size,edgecolors='k')
    endpoints_labeled = True
    # only trajectories longer than `probe` steps have a probe point
    if pcc.shape[0] > probe:
        ax.scatter(pcc[probe,0],pcc[probe,1],pcc[probe,2],marker='D',color=colors[plot_idx[i]],label=None if probe_labeled else '%d ms'%(probe*dt),s=dot_size,edgecolors='k')
        probe_labeled = True
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')
# no grid
ax.grid(False)
ax.legend()
# set view angle
ax.view_init(-10, 40)
# save fig in pdf form
plt.savefig(figure_path+'/PCs_123.pdf',format='pdf',dpi=1000)


# Persist the highlighted trials' PC trajectories and the per-trial mean
# intervals as .npy files next to the figures, then display all figures.
outputs = (
    (figure_path+'/PC_curves.npy', PC_curves),
    (figure_path+'/Interval_mean.npy', Interval_mean),
)
for out_path, payload in outputs:
    np.save(out_path, payload)

plt.show()