import plotting
import numpy as np
import numba as nb
from snn_cvx import run_snn, update_weights
import holoviews as hv
hv.extension('matplotlib')
# --- Network dimensions ---
M = 7    # number of read-out (output) dimensions
K = 3    # number of input dimensions (pixels)
N = 300  # number of neurons

# --- Initial network parameters ---
random_state = np.random.RandomState(seed=4)
# G (N x M): rows drawn uniformly from [0, 1) and normalised to unit length.
G_weights_init = random_state.rand(N, M)
G_weights_init /= np.linalg.norm(G_weights_init, axis=1)[:, None]
scale = 0.1  # scaling of D compared to G
D_weights_init = G_weights_init.T*scale  # (M x N)
# F (N x K): rows drawn from a standard normal and normalised to unit length.
F_weights_init = random_state.randn(N, K)
F_weights_init /= np.linalg.norm(F_weights_init, axis=1)[:, None]
# (N x N); presumably the recurrent connectivity used by snn_cvx — TODO confirm.
omega_init = -G_weights_init @ D_weights_init
thresholds_init = np.ones(N)
leak = 2
sigma_V = 0.1
mu = 0.1

# --- Learning settings ---
T_l = 5  # trial time
dt_l = 1e-04  # simulation time-step during learning
t_span_l = np.arange(0, T_l, dt_l)
num_bins_l = t_span_l.size
# Round before converting so float error in 0.5/dt_l cannot truncate a bin away.
buffer_bins_l = int(round(0.5/dt_l))  # time bins before input onset
buffer_learning_bins = 2*buffer_bins_l  # time bins before learning starts
alpha_init = 1e-01  # initial learning rate
leak_thresh_init = 1e-03  # initial threshold drift
decay = 1e-03  # decay of learning rate across epochs
num_iter = 750

# --- Training data ---
# Every binary pixel pattern of length K, one per row, in lexicographic order
# (row index == pattern read as a binary number, first pixel most significant).
num_datapoints = 2**K
x_train = np.indices((2,)*K).reshape(K, -1).T.astype(float)
# One target per pattern: row 0 (no active pixel) is 3 in every one of the M
# output dimensions; row p >= 1 is 3 everywhere except 6 in output p-1.
# This pairing needs exactly one target row per pattern.
if M + 1 != num_datapoints:
    raise ValueError(
        f"need M + 1 == 2**K targets, got M={M}, K={K} ({num_datapoints} patterns)"
    )
y_target = np.identity(M+1)[:, 1:]*3
y_target += 3
# --- Learning ---
# Working copies of the parameters; the *_init arrays are left untouched.
D_weights = D_weights_init.copy()
G_weights = G_weights_init.copy()
F_weights = F_weights_init.copy()
omega = omega_init.copy()
thresholds = thresholds_init.copy()
# One snapshot of the learned thresholds and F weights per iteration.
thresholds_array_fit = np.zeros((N, num_iter))
F_weights_array_fit = np.zeros((N, K, num_iter))
# Input buffer shared by all trials: bins before buffer_bins_l stay zero, the
# remaining bins are overwritten with the current pattern on every trial.
x_sample = np.zeros((K, num_bins_l))
# Presentation order of the data points, reshuffled on every iteration.
data_index_list = np.arange(num_datapoints)
for _iter in range(num_iter):
    if _iter % 100 == 0:
        print('iterations '+str(_iter+1)+' of '+str(num_iter), end='\r', flush=True)
    # Shuffle with the seeded generator (not NumPy's unseeded global one) so the
    # presentation order is reproducible across runs.
    # NOTE(review): update_weights may draw its own noise (sigma_V); that is not
    # covered by this generator — confirm in snn_cvx.
    random_state.shuffle(data_index_list)
    # Exponentially decay the learning rate and the threshold drift.
    decay_factor = np.exp(-decay * (_iter + 1))
    alpha = alpha_init * decay_factor
    leak_thresh = leak_thresh_init * decay_factor
    # Present every data point once; update_weights returns the new
    # thresholds and F weights after each presentation.
    for data_index in data_index_list:
        x_sample[:, buffer_bins_l:] = x_train[data_index, :][:, None]
        y_sample = y_target[data_index, :]
        thresholds, F_weights = update_weights(
            x_sample,
            y_sample,
            F_weights,
            G_weights,
            omega,
            thresholds,
            buffer_learning_bins,
            dt_l,
            leak,
            leak_thresh,
            # NOTE(review): the same rate is passed twice — presumably one per
            # learned quantity; confirm against update_weights' signature.
            alpha,
            alpha,
            mu,
            sigma_V,
        )
    # store updated parameters
    thresholds_array_fit[:, _iter] = thresholds
    F_weights_array_fit[:, :, _iter] = F_weights
# Parameters after the final iteration.
F_weights_fit = F_weights_array_fit[:, :, -1]
thresholds_fit = thresholds_array_fit[:, -1]
# Seed NumPy's global generator for the stimulus noise and the OU noise below.
# NOTE(review): if run_snn draws random numbers inside numba-compiled code, that
# generator keeps its own state and is not seeded by this call — confirm.
np.random.seed(2)

# --- Stimulus settings ---
Tend = 26  # total simulation time
dt = 1e-4  # simulation time-step
times = np.arange(0, Tend, dt)
nT = len(times)
# Durations converted to integer numbers of time bins; round first so float
# error in x/dt cannot truncate a bin away.
tstart = int(round(1/dt))   # stimulus onset (bins)
stimlen = int(round(1/dt))  # length of each stimulus (bins)
gap = int(round(0.5/dt))    # gap between consecutive stimuli (bins)

# --- Perturbation parameters ---
sigma_stim = 0.1  # std of the noise added to each stimulus presentation
sigma_OU = 0.05
leak_OU = 10

# --- Build stimulus ---
# Cycle through the patterns in order. Each presentation gets one noise sample
# held constant over its duration; pattern 0 (no active pixel) stays all-zero.
x = np.zeros((K, nT))
for count, stim in enumerate(range(tstart, nT, stimlen + gap)):
    pattern = count % (M + 1)
    if pattern > 0:
        x[:, stim:stim + stimlen] = (x_train[pattern, :] + sigma_stim*np.random.randn(K))[:, None]

# Smooth the stimulus with a normalised Gaussian kernel (1000 bins long) to
# avoid discontinuities at stimulus on- and offsets.
x_kernel = np.linspace(-1, 1, 1000)
kernel_sigma = .1
smoothen_kernel = np.exp(-x_kernel**2/kernel_sigma**2)
smoothen_kernel /= smoothen_kernel.sum()
for i_dim in range(K):
    x[i_dim, :] = np.convolve(x[i_dim, :], smoothen_kernel, 'same')

# --- OU noise ---
# Euler discretisation of an Ornstein-Uhlenbeck process per input dimension,
# starting from 0 and driven by the white noise drawn here.
noise = np.random.randn(K, nT)  # white noise
noiseOU = np.zeros((K, nT))
ou_sqrt = np.sqrt(2*dt*leak_OU)  # loop-invariant noise scaling
for t in range(1, nT):
    noiseOU[:, t] = noiseOU[:, t-1] + dt*(-noiseOU[:, t-1]*leak_OU) + sigma_OU*noise[:, t-1]*ou_sqrt

# --- Run the network ---
silence_T = int(nT/2)  # time-step at which neurons are silenced
silence_prop = 0.4  # proportion of neurons to silence
delay = 20  # number of timesteps of synaptic delay in recurrent connections
firing_rates, spikes, V_membrane, I_E, I_I = run_snn(x, noiseOU, F_weights_fit, omega, thresholds_fit, dt, leak, mu, sigma_V, silence_T, silence_prop, delay)
# Linear read-out of the firing rates.
y_sim = D_weights@firing_rates
%%opts Image [aspect=10 yticks=0] (cmap='gray')
%%opts Overlay [show_legend=False show_title=False]
%%opts Curve [aspect=10] (linewidth=1) {+axiswise} Overlay [show_legend=False show_title=False] Layout [sublabel_format=None]
%%opts Scatter [aspect=10] Scatter.spikes [aspect=6] (s=1 color='k')
%%opts Scatter.spikes2 [aspect=6] (s=5)
%%output dpi=300 fig='png'
# waiting time before plotting
tstart = 2.2
tend = Tend-tstart
# choose example neurons to plot, and their colors
exneurons = [20, 78, 100]
color_exneurons = ['#67001f', '#fa9fb5', '#ce1256', '#df65b0']
# plot the input stimulus
fig_x = hv.Image(x, bounds=(-tstart, 0, Tend-tstart, 3), kdims=['Time (s)', 'Input x'])
# plot the read-outs
colors = plotting.colors
fig_readouts = hv.Overlay()
for i in range(0, M):
img_i = hv.Curve(zip(times-tstart, y_sim[i, :]), kdims='Time (s)', vdims='Readouts').opts(color=colors[(i)%len(colors)])
fig_readouts *= img_i
# plot the spikes of active neurons
fig_spikes = plotting.spike_plot(times-tstart, spikes[spikes.sum(axis=1)>0, :], 0, 1)
for i, n in enumerate(exneurons):
fig_spikes *= plotting.plot_spikes_single(times-tstart, spikes[n, :], color_exneurons[i], base_offset=n)
# plot the currents (relative to the thresholds)
fig_I = hv.Overlay()
thresholds_pos, thresholds_neg = thresholds.copy(), thresholds.copy()
thresholds_pos[thresholds<0]=0
thresholds_neg[thresholds>0]=0
for i, n in enumerate(exneurons):
fig_inh = hv.Curve(zip(times-tstart, I_I[n, :]-thresholds_neg[n]), kdims='Time (s)', vdims='Currents').opts(color=color_exneurons[i])
fig_ex = hv.Curve(zip(times-tstart, I_E[n, :]-thresholds_pos[n]), kdims='Time (s)', vdims='Currents').opts(color=color_exneurons[i])
fig_I *= fig_inh*fig_ex*hv.HLine(0).opts(linestyle='--', color='k', linewidth=1)
# plot the voltages (relative to the thresholds)
fig_V = hv.Overlay()
for i, n in enumerate(exneurons):
fig_V *= hv.Curve(zip(times-tstart, V_membrane[n, :] + spikes[n, :] - thresholds[n]), kdims='Time (s)', vdims='Voltages').opts(color=color_exneurons[i], alpha=1)*hv.HLine(0).opts(linestyle='--', color='k', linewidth=1)
# Choose time-points to slice and zoom-in on (see next cell)
t1 = 15
t2 = 16.5
# combine all
fig_tot = (fig_x[0:tend] +
fig_readouts[0:tend] +
fig_spikes[0:tend]*hv.Curve([0]) +
fig_V[0:tend]*hv.VSpan(t1, t2).opts(color='gray') +
fig_I[0:tend]*hv.VSpan(t1, t2).opts(color='gray')
).cols(1)
xdim = hv.Dimension('Time (s)', range=(0, tend))
fig_tot.redim(x=xdim)
Zoom in on the time window between `t1` and `t2` (set in the previous cell). Note that in HoloViews you can slice along the x-axis using the actual x-values (here, time in seconds).
%%opts Curve [aspect=3 xticks=3 yticks=3] (alpha=1 linewidth=1) {+axiswise} Overlay [show_legend=False show_title=False] Layout [sublabel_format=None]
%%output dpi=300 fig='svg'
(fig_V[t1:t2].opts(xaxis='bare')+fig_I[t1:t2]).cols(1)