import numpy as np
import pickle
import network_class
import matplotlib.pyplot as plt
from scipy.stats import unitary_group
import multiprocessing as mp
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
from copy import deepcopy
import sys
# For evaluation only, a nest installation is not required. For generation of data,
# a nest version with changes by the authors is required.
try:
    import nest
except:
    nest = None    
import create_network, apply_stimulus, get_properties

'''
This file can be used to generate figures 2 and 3.
'''

#%%

def setup_network(seed):
    '''
    Create one network realization and store it in an "initialize_network" file.

    The global numpy RNG is seeded with `seed`, a random connectivity and the
    stimulus statistics (mu, chi) are drawn, a nonlinear and a linear
    network_class.network object are built from them, and everything is
    pickled to data/network_realizations/initialize_network_<seed>.txt.
    The process exits via sys.exit if identity +/- chi has a negative eigenvalue.
    '''

    # The seed is the main identifier of a realization. The parameters below
    # match the sample networks in data/network_realizations.
    np.random.seed(seed)

    # --- network and stimulus parameters ---
    N = 100              # number of neurons
    alpha = 0.05         # strength of nonlinearity
    t_stop = 10.         # readout time
    tau = .25            # time constant for e.o.m.
    # Connection strength. Keep smaller than 1, otherwise the approximation of
    # the network dynamics is bad due to chaotic behavior. Chaotic behavior is
    # also possible in realizations of finite networks with g<1.
    g = 0.9
    input_changes = 100  # number of steps in input signals

    # --- sampling and optimization parameters ---
    n_samples = 1000     # number of samples per class when generating artificial data
    opt_steps_non = 30   # steps for optimization of the non-linear system
    opt_steps_lin = 30   # steps for optimization of the linear system
    seps = 5             # number of linear stimulus separabilities to be evaluated in fig. 3
    eta = 10.            # determines weight of covariance term for the soft margin

    # Redraw the connectivity until no eigenvalue has a real part > 1,
    # to avoid chaotic behavior.
    weight_std = g / np.sqrt(N)
    while True:
        W = np.random.normal(0., weight_std, (N, N))
        lamb, R = np.linalg.eig(W)
        if not np.max(np.real(lamb)) > 1.:
            break

    # A realization is determined by the left- and right-handed eigenvectors of
    # the connectivity including eigenvalues, the network time constant and the
    # strength of the nonlinearity (last entry of poly_coeffs).
    L = np.linalg.inv(R)
    shared_params = {'lamb': lamb, 'Right': R, 'Left': L, 'tau': tau}
    network_params_non = dict(shared_params, poly_coeffs=[0., 1., alpha])
    network_params_lin = dict(shared_params, poly_coeffs=[0., 1., 0.])

    stop_points = np.linspace(0., t_stop, input_changes + 1)

    # --- stimulus statistics ---
    # mean difference mu, normalized to unit length
    mu = np.random.rand(input_changes) - 0.5
    mu = mu / np.linalg.norm(mu)

    # chi: random spectrum rotated by a random unitary matrix, then made real symmetric
    lamb_chi = (np.random.rand(input_changes) - 0.5) * 1.5
    rotation_matrix = unitary_group.rvs(input_changes)
    chi = rotation_matrix @ np.diag(lamb_chi) @ rotation_matrix.conj().T
    chi = np.real(0.5 * (chi + chi.T))

    # psi is the identity matrix by default, so both psi+chi and psi-chi
    # must have non-negative eigenvalues.
    identity = np.eye(input_changes)
    eig_plus = np.linalg.eig(identity + chi)[0]
    eig_minus = np.linalg.eig(identity - chi)[0]
    if np.min(np.array([eig_plus, eig_minus])) < 0:
        sys.exit('negative EV of assumed covariance matrix.')

    net_non = network_class.network(network_params_non, stop_points, eps=1e-15)
    net_lin = network_class.network(network_params_lin, stop_points, eps=1e-15)

    # ||mu|| to be used to generate samples for fig. 3b
    lengths = np.linspace(0., 1., seps)

    network_dictionary = {
        'N': N,
        'alpha': alpha,
        'eta': eta,
        'n_samples': n_samples,
        'opt_steps_non': opt_steps_non,
        'opt_steps_lin': opt_steps_lin,
        'seps': seps,
        'lengths': lengths,
        'mu': mu,
        'chi': chi,
        'psi': None,  # None stands for the identity matrix
        'net_non': net_non,
        'net_lin': net_lin,
    }

    print('network parameters:', network_dictionary)

    target_file = 'data/network_realizations/initialize_network_'+str(seed)+'.txt'
    with open(target_file, 'wb') as handle:
        pickle.dump(network_dictionary, handle, protocol=4)




#%%



def sample_stimuli(seed, n_samples, length, mu, chi, psi=None):
    '''
    Draw Gaussian stimuli for two classes from their first two cumulants.

    The "+" class has mean length*mu and covariance psi+chi, the "-" class has
    mean -length*mu and covariance psi-chi; psi=None stands for the identity.
    `length` is a separate argument to make explicit how the length of mu is
    changed while chi and psi are held constant in fig. 3b.

    Returns (samples, labels): the n_samples "+" draws followed by the
    n_samples "-" draws, and the matching labels +1 / -1.
    Note that the global numpy RNG is reseeded with seed**2.
    '''

    # A seed different from, but tied to, the one used at the start of the
    # calling function. Simple workaround to also try the samples of other
    # seed-networks in a given realization.
    np.random.seed(seed**2)

    base_cov = np.eye(len(mu)) if psi is None else psi

    positive_draws = np.random.multivariate_normal(length * mu, base_cov + chi, n_samples)
    negative_draws = np.random.multivariate_normal(-length * mu, base_cov - chi, n_samples)

    samples = np.concatenate((positive_draws, negative_draws), axis=0)
    labels = np.concatenate((np.ones(n_samples), -np.ones(n_samples)))

    return samples, labels



#%%




def optimize_soft_margins(seed, length):
    '''
    Optimize soft margin for artificial stimuli. Basis for fig. 3b. Results saved in data/responses_soft_margins.

    Loads the realization stored for `seed`, draws stimuli whose mean has norm
    `length` (via sample_stimuli), and optimizes the linear and the nonlinear
    network. The soft margins and the input/readout vectors of both systems
    are pickled to
    data/responses_soft_margins/optimization_<seed>_<length>.txt.
    For length == 1. an additional analysis plot of mu, chi and the
    connectivity eigenvalues is saved next to it.
    '''
    np.random.seed(seed)

    # Functions analogous to those in network_class. They allow the initial
    # conditions to be optimized in parallel with python's multiprocessing.
    # NOTE: as nested functions they can only serve as process targets when
    # the start method does not need to pickle the target (e.g. "fork").
    def test_iniconds(list_, q):
        # Worker: run one short optimization and hand the result back via q.
        network, initial_steps, initial_guesses, mu, eta, input_vector, solver = list_

        soft_margins, input_vectors, readout_vectors \
            = network.alternating_optimization(initial_steps, mu, eta, input_vector, solver, initial_guesses)

        q.put([soft_margins, input_vectors, readout_vectors])

    def find_good_optimization(network, opt_steps, mu, eta=15., initial_cond=10, initial_steps=1, solver='eigenvalue'):
        # Run `initial_steps` optimization steps from `initial_cond` random
        # input vectors in parallel, keep the start with the largest final
        # soft margin and continue only that one up to `opt_steps`.

        if network.poly_coeffs[2] == 0:
            initial_guesses = opt_steps
        else:
            initial_guesses = 1

        # random unit-length input vectors, one per initial condition
        initial_input_vectors = np.random.rand(initial_cond, network.N) - 0.5
        initial_input_vectors /= np.linalg.norm(initial_input_vectors, axis=1)[:, np.newaxis]

        arguments = [[network, initial_steps, initial_guesses, mu, eta, input_vector, solver]
                     for input_vector in initial_input_vectors]

        # Start ALL workers before waiting for any result. Calling q.get()
        # directly after each start() would block until that worker is done
        # and thereby serialize the whole search.
        queues = [mp.Queue() for _ in range(initial_cond)]
        processes = [mp.Process(target=test_iniconds, args=(arguments[i], queues[i],))
                     for i in range(initial_cond)]
        for process in processes:
            process.start()

        # Fetch the results before joining: a process that has put data on a
        # queue may not terminate until that data has been consumed.
        res = [q.get() for q in queues]
        for process in processes:
            process.join()

        soft_margins_initial = np.array([r[0] for r in res])
        input_vectors_initial = np.array([r[1] for r in res])
        readout_vectors_initial = np.array([r[2] for r in res])

        # initial condition with the best soft margin after the initial steps
        ic_idx = np.argmax(soft_margins_initial[:, -1])

        soft_margins, input_vectors, readout_vectors \
            = network.alternating_optimization(opt_steps-initial_steps+1, mu, eta=eta, initial_input_vector=input_vectors_initial[ic_idx, -1], solver=solver, initial_guesses=0)

        # The continuation starts at the last initial step, so its first entry
        # is dropped when both parts are glued together.
        soft_margins = np.append(soft_margins_initial[ic_idx], soft_margins[1:])
        input_vectors = np.vstack([input_vectors_initial[ic_idx], input_vectors[1:]])
        readout_vectors = np.vstack([readout_vectors_initial[ic_idx], readout_vectors[1:]])

        return soft_margins, input_vectors, readout_vectors

    # dictionary to store soft_margins and input/readout vectors for the given strength of mu
    results = {}

    # network information (pickle written by setup_network; only load trusted files)
    with open('data/network_realizations/initialize_network_'+str(seed)+'.txt', 'rb') as handle:
        network_dictionary = pickle.load(handle)

    eta = network_dictionary['eta']

    n_samples = network_dictionary['n_samples']
    opt_steps_non = network_dictionary['opt_steps_non']
    opt_steps_lin = network_dictionary['opt_steps_lin']

    mu = network_dictionary['mu']
    chi = network_dictionary['chi']
    psi = network_dictionary['psi']

    net_non = network_dictionary['net_non']
    net_lin = network_dictionary['net_lin']

    samples, labels = sample_stimuli(seed, n_samples, length, mu, chi, psi)

    # estimate of mu based on the drawn samples
    mu_s = np.einsum('tn, t -> n', samples, labels) * 1. / len(labels)

    # Before any optimization, the Green's function's product with the given stimuli has to be determined.
    # Always use Large_N=True, as the O(alpha**2) correction term is still missing otherwise. Also, less
    # memory intensive, although slower.
    net_non.determine_sample_dynamics(samples, labels, large_N=True)
    net_lin.determine_sample_dynamics(samples, labels, large_N=True)

    # number of initial conditions for non-linear system optimization. On our machine: 48
    n_ini = mp.cpu_count()
    steps_ini = 4  # Number of steps after which initial conditions are compared.
    # If you feel generous, try using 10. Then, optimization is usually approximately finished.

    if length == 0:
        # Reminder of the used parameters
        print(network_dictionary)

    # analysis of connectivity and stimulus
    if length == 1.:
        fig = plt.figure(1)
        spec = GridSpec(ncols=2, nrows=2)
        ax1 = fig.add_subplot(spec[0, 0])
        ax2 = fig.add_subplot(spec[0, 1])
        ax3 = fig.add_subplot(spec[1, 0])

        # entries of mu along the (normalized) time axis
        ax1.bar(np.linspace(0., 1., len(mu)), mu, width=1./len(mu))
        ax1.set_xlabel('position on time axis')
        ax1.set_ylabel('mu')

        # spectrum of chi, one vertical line per eigenvalue
        chi_eigenvalues, _ = np.linalg.eig(chi)
        for eigenvalue in chi_eigenvalues:
            ax2.axvline(eigenvalue, alpha=0.6)
        ax2.set_xlim([np.min(chi_eigenvalues)-0.05, np.max(chi_eigenvalues)+0.05])
        ax2.set_xlabel('eigenvalue of chi')

        # connectivity eigenvalues in the complex plane; the line marks Re = 1
        for eigenvalue in net_non.lamb:
            ax3.scatter(np.real(eigenvalue), np.imag(eigenvalue), c='darkorange')
        ax3.axvline(1.)
        ax3.set_xlabel('real part of connectivity EV')
        ax3.set_ylabel('imag part of connectivity EV')

        fig.savefig('data/responses_soft_margins/soft_margin_analysis_'+str(seed)+'.png')

    # optimize the linear network. This function is easy to optimize and does not need many initial conditions.
    soft_margins_lin, input_vectors_lin, readout_vectors_lin = net_lin.alternating_optimization(opt_steps_lin, mu_s, eta=eta, initial_guesses=opt_steps_lin)
    results['soft_margins_lin'] = soft_margins_lin
    results['input_vectors_lin'] = input_vectors_lin
    results['readout_vectors_lin'] = readout_vectors_lin

    # optimize the nonlinear network in parallel concerning initial conditions
    soft_margins_non, input_vectors_non, readout_vectors_non = find_good_optimization(net_non, opt_steps_non, mu_s, eta=eta, initial_cond=n_ini, initial_steps=steps_ini, solver='eigenvalue')
    results['soft_margins_non'] = soft_margins_non
    results['input_vectors_non'] = input_vectors_non
    results['readout_vectors_non'] = readout_vectors_non

    with open('data/responses_soft_margins/optimization_'+str(seed)+'_'+str(length)+'.txt', 'wb') as handle:
        pickle.dump(results, handle)


#%%





def random_u_soft_margins(seed, length):
    '''
    Soft margins for random input vectors u with the readout v optimized
    according to eq. (6). Generates the clouds in fig. 3b.

    Reads the realization for `seed` and the optimization results stored for
    (`seed`, `length`), adds the entries 'soft_margins_random',
    'input_vectors_random' and 'readout_vectors_random', and writes the
    results file back.
    '''

    np.random.seed(seed)

    network_file = 'data/network_realizations/initialize_network_'+str(seed)+'.txt'
    results_file = 'data/responses_soft_margins/optimization_'+str(seed)+'_'+str(length)+'.txt'

    with open(network_file, 'rb') as handle:
        network_dictionary = pickle.load(handle)

    N = network_dictionary['N']
    eta = network_dictionary['eta']
    n_samples = network_dictionary['n_samples']
    net_non = network_dictionary['net_non']
    net_lin = network_dictionary['net_lin']

    # generate the same sample stimuli as in optimize_soft_margins
    samples, labels = sample_stimuli(
        seed,
        n_samples,
        length,
        network_dictionary['mu'],
        network_dictionary['chi'],
        network_dictionary['psi'],
    )

    # estimated mu based on the samples
    mu_s = np.einsum('ti, t -> i', samples, labels)/len(labels)

    # Before any optimization, the Green's function's product with the given
    # stimuli has to be determined. Always use large_N=True, as the
    # O(alpha**2) correction term is still missing otherwise. Also less
    # memory intensive, although slower.
    net_lin.determine_sample_dynamics(samples, labels, large_N=True)
    net_non.determine_sample_dynamics(samples, labels, large_N=True)

    with open(results_file, 'rb') as handle:
        results = pickle.load(handle)

    # random input vectors, each scaled to unit length
    n_rand = 500
    random_u = np.random.rand(n_rand, N) - 0.5
    inverse_norms = 1./np.linalg.norm(random_u, axis=1)
    random_u = random_u * inverse_norms[:, np.newaxis]

    # optimize the readout of both systems for every random u
    margin_pairs = []
    readout_pairs = []
    for input_vector in random_u:
        v_lin, S_lin = net_lin.determine_readout(mu_s, eta, input_vector)
        v_non, S_non = net_non.determine_readout(mu_s, eta, input_vector)
        margin_pairs.append([S_lin, S_non])
        readout_pairs.append([v_lin, v_non])

    results['soft_margins_random'] = np.array(margin_pairs)     # (n_rand, 2)
    results['input_vectors_random'] = random_u                  # (n_rand, N)
    results['readout_vectors_random'] = np.array(readout_pairs) # (n_rand, 2, N)

    with open(results_file, 'wb') as handle:
        pickle.dump(results, handle)




#%%



def evaluate_soft_margins(seed, ax3, input_seps):
    '''
    This generates fig. 3b. It is put in a layout together with fig. 3a in the function fig3.
    First, use optimize_soft_margins. Optionally, use random_u_soft_margins afterwards. Then, use fig3.

    Parameters:
        seed: identifies the network realization and its stored optimization results.
        ax3: matplotlib axes to draw into. If None, new axes are created on figure 3.
        input_seps: one value per entry of the stored 'lengths'; used to color the points.

    Returns the PathCollection of the scatter plot of the optimal soft margins
    (e.g. to attach a colorbar). The figure is not saved here.
    '''

    with open('data/network_realizations/initialize_network_'+str(seed)+'.txt', 'rb') as handle:
        network_dictionary = pickle.load(handle)

    # network information
    N = network_dictionary['N']

    opt_steps_non = network_dictionary['opt_steps_non']
    opt_steps_lin = network_dictionary['opt_steps_lin']

    lengths = network_dictionary['lengths']
    seps = network_dictionary['seps']

    # Color values to illustrate how well the stimuli are linearly separable with respect to each other
    input_separabilities = input_seps

    # containers for the stored optimization traces, one row per input separability
    soft_margins_non = np.empty((seps, opt_steps_non))
    input_vectors_non = np.empty((seps, opt_steps_non, N))
    readout_vectors_non = np.empty((seps, opt_steps_non, N))

    soft_margins_lin = np.empty((seps, opt_steps_lin))
    input_vectors_lin = np.empty((seps, opt_steps_lin, N))
    readout_vectors_lin = np.empty((seps, opt_steps_lin, N))

    soft_margins_random = []

    # load saved optimization steps
    for idx, length in enumerate(lengths):

        with open('data/responses_soft_margins/optimization_'+str(seed)+'_'+str(length)+'.txt', 'rb') as handle:
            results = pickle.load(handle)

        soft_margins_non[idx] = results['soft_margins_non']
        input_vectors_non[idx] = results['input_vectors_non']
        readout_vectors_non[idx] = results['readout_vectors_non']

        soft_margins_lin[idx] = results['soft_margins_lin']
        input_vectors_lin[idx] = results['input_vectors_lin']
        readout_vectors_lin[idx] = results['readout_vectors_lin']

        # only present if random_u_soft_margins was run for this length
        if 'soft_margins_random' in results:
            soft_margins_random.append(results['soft_margins_random'])

    soft_margins_random = np.array(soft_margins_random)

    # plt.cm.get_cmap (= matplotlib.cm.get_cmap) was removed in Matplotlib 3.9;
    # pyplot.get_cmap is available in old and new versions.
    cmap = plt.get_cmap('viridis')

    # fig. 3b: when called by fig3, ax3 belongs to the layout created there.
    # When called directly with ax3=None, the plot goes on fresh axes of figure 3.
    fig3b = plt.figure(3)
    if ax3 is None:
        ax3 = fig3b.add_subplot(111)

    # optimal soft margins after the last optimization step
    final_lin = soft_margins_lin[:, -1]
    final_non = soft_margins_non[:, -1]

    # angle bisector
    ax3.plot(np.append(0., final_lin), np.append(0., final_lin), alpha=.3, c='k')
    # optimal soft margin in each system
    sc = ax3.scatter(final_lin, final_non, c=input_separabilities, cmap=cmap)

    # random u soft margins
    if len(soft_margins_random) > 0:
        ax3.scatter(soft_margins_random[:, :, 0], soft_margins_random[:, :, 1], c=np.outer(input_separabilities, np.ones_like(soft_margins_random[0, :, 0])), cmap=cmap, alpha=0.15, marker='*')
    ax3.set_xlabel(r'$\kappa_{\eta}$, linear system')
    ax3.set_ylabel(r'$\kappa_{\eta}$, nonlinear system')

    # projection of optimal soft margins onto the axes; the limits are read
    # before and restored after drawing so the helper lines do not rescale the plot
    xlim = ax3.get_xlim()
    ylim = ax3.get_ylim()
    max_separability = np.max(input_separabilities)
    for i in range(seps):
        color = cmap(input_separabilities[i]/max_separability)
        ax3.plot([xlim[0], final_lin[i]], [final_non[i], final_non[i]], c=color)
        ax3.plot([final_lin[i], final_lin[i]], [ylim[0], final_non[i]], c=color)
    ax3.set_xlim(xlim)
    ax3.set_ylim(ylim)

    return sc






#%%


def fig2(seed, *args):
    '''
    Generate data and plot figure 2.
    Possible args:
        'a': evaluate data for 2a
        'b': evaluate data for fig 2b and 2c
        'd': evaluate data for fig 2d
        'A': plot figure 2a
        'B': plot figure 2b
        'C': plot figure 2c
        'D': plot figure 2d
        'oneline': plot all figures next to each other
        'variate_u': show variation of soft margin in figure b for random u
        'variate_mu': show variation of soft margin in figure b over different stimuli
    '''

    np.random.seed(seed)

    #General parameters for the linear network
    with open('data/network_realizations/initialize_network_'+str(seed)+'.txt', 'rb') as handle:
        network_dictionary = pickle.loads(handle.read())

    N = network_dictionary['N']
    eta = network_dictionary['eta']

    opt_steps_lin = network_dictionary['opt_steps_lin']

    mu = network_dictionary['mu']
    chi = network_dictionary['chi']
    psi = network_dictionary['psi']

    n_samples = network_dictionary['n_samples']

    samples, labels = sample_stimuli(seed, n_samples, 1., mu, chi, psi)
    # estimated mu from samples - here used as example sample for fig. 2a
    mu_s = np.einsum('ti, t -> i', samples, labels) / len(labels)


    global net_lin
    net_lin = network_dictionary['net_lin']

    # Before any optimization, the Green's function's product with the given stimuli has to be determined.
    # Always use Large_N=True, as the O(alpha**2) correction term is still missing otherwise. Also, less
    # memory intensive, although slower.
    net_lin.determine_sample_dynamics(samples, labels, large_N=True)

    # figure parameters
    show_last_steps = 5 # number of steps of the signal showed in fig. 2a
    detail_steps = 50 # number of time steps for resolution of fig. 2a
    skip_steps = 10 # steps of the signals to be skipped to generate fig. 2b
    # (e.g., if x(t) has 100 time steps and skip_steps=10, then the soft margin is evaluated on 10 time points)
    n_stimuli = 250 # number of stimuli and random input vectors to obtain data for fig. 2b
    n_angles = 15 # number of angles in range [0., pi] to go over in fig. 2d






    #setup the single input for 2a
    #one random input vector for comparison
    u_rand = np.random.rand(N) - 0.5
    u_rand /= np.linalg.norm(u_rand)

    #time point arrays for stimulus and readout
    stop_points = net_lin.stop_points
    detailed_times = np.linspace(stop_points[-(show_last_steps+1)], stop_points[-1], detail_steps)
    # arrays to plot the stimulus
    input_times = np.outer(stop_points, np.ones(2)).flatten()[-(show_last_steps*2+1):-1]
    input_heights = np.outer(mu_s, np.ones(2)).flatten()[-(show_last_steps*2):]


    dt = 0.001
    dt_multi = dt
    timesteps = int(stop_points[-1]/dt)
    # set up connectivity matrix from eigenvectors and eigenvalues for simulation
    W = np.real(np.einsum('a, ia, aj -> ij', net_lin.lamb, net_lin.Right, net_lin.Left))



    #setup the stimuli- and random u sets for 2b
    set_mu = np.random.rand(n_stimuli, net_lin.T) - 0.5
    set_mu = np.einsum('ti, t -> ti', set_mu, 1./np.linalg.norm(set_mu, axis=1))

    set_u = np.random.rand(n_stimuli, N) - 0.5
    set_u = np.einsum('ti, t -> ti', set_u, 1./np.linalg.norm(set_u, axis=1))

    # time index for fig. 2c and 2d
    early_idx = np.argmin(np.abs(stop_points[skip_steps::skip_steps] - 0.1*net_lin.readout_time))
    late_idx = np.argmin(np.abs(stop_points[skip_steps::skip_steps] - 0.9*net_lin.readout_time))


    #setup the angles for 2d
    angles = np.pi/2 * np.linspace(0., 1., n_angles)

    # set up fig. 2
    if 'oneline' in args:
        fig2 = plt.figure(2, figsize=(30,5))
        plt.clf()
        gs = GridSpec(1,8, hspace=0.05,wspace=0.4, bottom=0.2,
                       left=0.05, right=0.99, width_ratios=[1.1, 0.3, 1, 0.05, 0.6, 0.6, 0.05, 1])
        plt.rcParams.update({'font.size': 24})

        #fig 2a
        ax2a = fig2.add_subplot(gs[0,0])
        ax2aa = ax2a.twinx()
        ax2a.set_title('(a)',loc='left')
        ax2a.set_xlabel(r'$T$')
        ax2a.set_ylabel(r'$\kappa_{\eta}$')
        ax2aa.set_ylabel(r'$\mu$')

        #fig 2b
        ax2b = fig2.add_subplot(gs[0,2])
        ax2b.set_title('(b)',loc='left')
        ax2b.set_xlabel(r'$T$')
        ax2b.set_ylabel(r'$\kappa_{\eta}$')

        #fig 2c
        ax2c = fig2.add_subplot(gs[0,4])
        ax2c.set_title('(c)',loc='left')
        ax2c.set_xlabel(r'$\tau$')
        ax2c.set_ylabel(r'average $\vert \omega_{\alpha}\vert$')
        ax2cc = fig2.add_subplot(gs[0,5])
        ax2cc.set_xlabel(r'$\tau$')
        ax2c.text(0.5, 0.8, r'$T=$'+str(round(stop_points[skip_steps::skip_steps][early_idx], 1)), transform=ax2c.transAxes)
        ax2cc.text(0.5, 0.8, r'$T=$'+str(round(stop_points[skip_steps::skip_steps][late_idx], 1)), transform=ax2cc.transAxes)
        ax2c.set_xlim([-1.02, 1.02])
        ax2cc.set_xlim([-1.02, 1.02])

        #fig 2d
        ax2d = fig2.add_subplot(gs[0,7])
        ax2d.set_title('(d)',loc='left')
        ax2d.set_xlabel('angle [rad.]')
        ax2d.set_ylabel(r'$\kappa_{\eta}$')
        ax2d.set_xlim([- np.pi * 0.02, np.pi/2 * 1.02])
        ax2d.text(0.7, 0.8, r'$T=$'+str(round(stop_points[skip_steps::skip_steps][late_idx], 1)), transform=ax2d.transAxes)

    else:

        fig2 = plt.figure(2)
        plt.clf()
        plt.subplots_adjust(left=0.15,right=0.95,top=0.9,bottom=0.2,wspace=0.6,hspace=0.7)
        gs = GridSpec(2,4)

        #fig 2a
        ax2a = fig2.add_subplot(gs[0,:2])
        ax2aa = ax2a.twinx()
        ax2a.set_title('(a)',loc='left')
        ax2a.set_xlabel(r'$T$')
        ax2a.set_ylabel(r'$\kappa_{\eta}$')
        ax2aa.set_ylabel(r'$\mu$')

        #fig 2b
        ax2b = fig2.add_subplot(gs[0,2:])
        ax2b.set_title('(b)',loc='left')
        ax2b.set_xlabel(r'$T$')
        ax2b.set_ylabel(r'$\kappa_{\eta}$')

        #fig 2c
        ax2c = fig2.add_subplot(gs[1,0])
        ax2c.set_title('(c)',loc='left')
        ax2c.set_xlabel(r'$\tau$')
        ax2c.set_ylabel(r'average $\vert \omega_{\alpha}\vert$')
        ax2cc = fig2.add_subplot(gs[1,1])
        ax2cc.set_xlabel(r'$\tau$')
        ax2c.text(0.5, 0.8, r'$T=$'+str(round(stop_points[skip_steps::skip_steps][early_idx], 1)), transform=ax2c.transAxes)
        ax2cc.text(0.5, 0.8, r'$T=$'+str(round(stop_points[skip_steps::skip_steps][late_idx], 1)), transform=ax2cc.transAxes)
        ax2c.set_xlim([-1.02, 1.02])
        ax2cc.set_xlim([-1.02, 1.02])

        #fig 2d
        ax2d = fig2.add_subplot(gs[1,2:])
        ax2d.set_title('(d)',loc='left')
        ax2d.set_xlabel('angle')
        ax2d.set_ylabel(r'$\kappa_{\eta}$')
        ax2d.set_xlim([- np.pi * 0.02, np.pi/2 * 1.02])
        ax2d.text(0.7, 0.8, r'$T=$'+str(round(stop_points[skip_steps::skip_steps][late_idx], 1)), transform=ax2d.transAxes)





    # functions to generate the data of fig. 2 a-d
    def data2a():
        # optimize input projection for times prior to simulation end. Includes Soft margin for random u.
        p = mp.Pool(int(mp.cpu_count()*1./4))
        arguments = [[readout_time, opt_steps_lin, stop_points, samples, labels, u_rand, eta] for readout_time in detailed_times]
        res = p.map(soft_margin_over_time_a, arguments)
        p.close()
        p.join()

        soft_margin_rand_theo = np.array([res[i][0] for i in range(len(res))])
        soft_margin_opti_theo = np.array([res[i][1] for i in range(len(res))])



        # Use NEST Simulator to obtain verification of response to random u.
        # As the simulation for optimized input vector would have been to be repeated
        # for every single time point on the time axis, which each has its own optimal
        # projection vector, the responses to the optimized input projections are not
        # explicitly checked here.

        # Only print when errors occur
        nest.set_verbosity('M_ERROR')
        #dictionaries for multimeter and neuron setup
        neuron_dict={'poly_coeffs':net_lin.poly_coeffs, 'linear_summation': False, 'mu': 0., 'sigma': 0., 'tau': net_lin.tau}
        multi_dict={'withtime': True, 'record_from': ['rate'], 'interval': dt_multi}

        # setup multimeter, neurons and step rate generator (self-made by the authors) with the routine described in create_network
        multi, n, rate_gen = create_network.create_network(N=N, N_recorded=N, neuron_dict=neuron_dict, multi_dict=multi_dict, neuron_type='polynomial_rate_ipn', input_type='same', W=W, input_vector=u_rand, dt=dt, setKernel=True)

        # apply stimuli to neurons (n_samples per class) using routine described in apply_stimulus and sort multimeter readout
        # into times_sim (simulation readout times with resolution dt_multi) and rates #(time, trial, neuron)
        readout=np.empty((2*n_samples, N)) # y(T)
        multi_list=np.zeros((2*n_samples, 3, int(N*stop_points[-1]/dt_multi))) # multimeter measurements, unsorted
        for trial in range(2*n_samples):
            readout[trial], multi_list[trial] = apply_stimulus.apply_stimulus(rate_amplitude=np.append(samples[trial], 0.), dt=dt, rate_generator=rate_gen,  neurons=n, T=stop_points[-1], multimeter=multi, rate_timestamps=stop_points)
        rates, times_sim = get_properties.get_rate(multi_list, N, 2*n_samples) # sort multimeter outcome





        # evaluate soft margin only in time window showed in fig. 2a
        plot_idx = np.argmin(np.abs(times_sim - detailed_times[0]))
        # M based on simulated y(t)
        dist_sim = np.mean(np.einsum('tsi, s -> tsi', rates, labels), axis=1)[plot_idx:]
        # Sigma based on simulates y(t)
        Sigma_sim = np.array([np.cov(np.einsum('si, s -> si', rates[idx], labels).T, bias=True) for idx in range(plot_idx, timesteps)])

        # Readout vectors based on simulated y(t):
        v_sim = np.empty((timesteps - plot_idx, N))
        for i in range(timesteps - plot_idx):
            v_sim[i] = get_properties.find_simulation_readout(Sigma_sim[i], dist_sim[i], eta, N, net_lin, rates[i], labels)


        #calculate soft_margin based on simulation
        soft_margin_sim = np.einsum('ti, ti -> t', v_sim, dist_sim) - 0.5 * eta * np.einsum('ti, tij, tj -> t', v_sim, Sigma_sim, v_sim)

        # dictionary containing results for fig. 2a
        a_dictionary = {
                'soft_margin_rand_theo': soft_margin_rand_theo,
                'soft_margin_opti_theo': soft_margin_opti_theo,
                'times_sim': times_sim[plot_idx:],
                'soft_margin_sim': soft_margin_sim
                }

        with open('data/linear_networks/2a_'+str(seed)+'.txt', 'wb') as handle:
            pickle.dump(a_dictionary, handle)





    def data2b():
        # calculate optimal soft margin and sof margin for random input vectors for many stimuli.
        # if 'variate_u'is not argument of the function call, one random u per stimulus is used,
        # to average over many more stimuli. For the fig. 2b in the paper, 'variate_u' was used.

        soft_margins_opt = np.zeros((n_stimuli, int(net_lin.T/skip_steps)))
        soft_margins_rand = np.zeros((n_stimuli, int(net_lin.T/skip_steps)))
        if 'variate_u' in args:
            soft_margins_rand = np.zeros((n_stimuli, len(set_u), int(net_lin.T/skip_steps)))

        # optimal input vectors at T_early and T_late defined by early_idx and late_idx. For fig. 2c.
        input_vectors_early = np.zeros((n_stimuli, N))
        input_vectors_late = np.zeros((n_stimuli, N))

        p = mp.Pool(int(mp.cpu_count()*1./4))
        if 'variate_u' in args:
            arguments = [[i, idx, opt_steps_lin, set_u, set_mu, stop_points, eta, n_samples, 'variate_u'] for i in range(net_lin.T)[skip_steps-1::skip_steps] for idx in range(n_stimuli)]
        else:
            arguments = [[i, idx, opt_steps_lin, set_u, set_mu, stop_points, eta, n_samples] for i in range(net_lin.T)[skip_steps-1::skip_steps] for idx in range(n_stimuli)]
        res = p.map(soft_margin_over_time_b, arguments)
        p.close()
        p.join()

        res = np.reshape(res, (int(net_lin.T/skip_steps), n_stimuli, -1))



        if 'variate_u' in args:
            for idx in range(n_stimuli):
                for time_idx in range(int(net_lin.T/skip_steps)):
                    soft_margins_rand[idx, :, time_idx] = res[time_idx][idx][0]
                    soft_margins_opt[idx, time_idx] = res[time_idx][idx][1]

                input_vectors_early[idx] = res[early_idx, idx, 2]
                input_vectors_late[idx] = res[late_idx, idx, 2]

        else:

            for idx in range(n_stimuli):
                soft_margins_rand[idx] = res[:, idx, 0]
                soft_margins_opt[idx] = res[:, idx, 1]

                input_vectors_early[idx] = res[early_idx, idx, 2]
                input_vectors_late[idx] = res[late_idx, idx, 2]

        # dictionary containing results for fig. 2b and 2c
        b_dictionary = {
                'soft_margins_opt': soft_margins_opt,
                'soft_margins_rand': soft_margins_rand,
                'input_vectors_early': input_vectors_early,
                'input_vectors_late': input_vectors_late
                }

        with open('data/linear_networks/2b_'+str(seed)+'.txt', 'wb') as handle:
            pickle.dump(b_dictionary, handle)








    def data2d():
        # soft margins for input vectors varied in angle between optimized solution and random vectors
        p = mp.Pool(int(mp.cpu_count()*1./4))
        arguments = [[idx, opt_steps_lin, set_u, angles, samples, labels, stop_points, late_idx, eta, seed] for idx in range(n_stimuli)]
        res = p.map(soft_margin_over_time_d, arguments)
        p.close()
        p.join()

        angle_decay = np.array(res)

        # dictionary containing results for fig. 2d
        d_dictionary = {
                'angle_decay': angle_decay
                }

        with open('data/linear_networks/2d_'+str(seed)+'.txt', 'wb') as handle:
            pickle.dump(d_dictionary, handle)









    # generate data for the following subfigures
    # fig. 2a
    if 'a' in args:
        data2a()

    # fig. 2b, c
    if 'b' in args:
        data2b()

    # fig. 2d
    if 'd' in args:
        data2d()









    #load data and plot the figs
    if 'A' in args:
        #load and plot data for fig 2a
        with open('data/linear_networks/2a_'+str(seed)+'.txt', 'rb') as handle:
            a_dictionary = pickle.loads(handle.read())

        soft_margin_rand_theo = a_dictionary['soft_margin_rand_theo']
        soft_margin_opti_theo = a_dictionary['soft_margin_opti_theo']
        times_sim = a_dictionary['times_sim']
        soft_margin_sim = a_dictionary['soft_margin_sim']

        ln1 = ax2a.plot(detailed_times, soft_margin_opti_theo, c='crimson', label=r'opt. $u$')
        ln2 = ax2a.plot(detailed_times, soft_margin_rand_theo, c='deepskyblue', label=r'rand. $u$')
        ln3 = ax2a.plot(times_sim, soft_margin_sim, linestyle='dashed', c=(19/255, 19/255, 19/255), label=r'rand. $u$. sim.')
        ax2aa.plot(input_times, input_heights, c='darkorange')
        lns = ln1 + ln2 + ln3
        ax2aa.legend(lns, [l.get_label() for l in lns], loc='upper left')


    if 'B' in args or 'C' in args:
        #load and plot data for fig 2b and 2c
        with open('data/linear_networks/2b_'+str(seed)+'.txt', 'rb') as handle:
            b_dictionary = pickle.loads(handle.read())


    if 'B' in args:
        soft_margins_opt = b_dictionary['soft_margins_opt']
        soft_margins_rand = b_dictionary['soft_margins_rand']

        # if neither 'variate_u' nor 'variate_mu' in args: show mean over stimuli
        # at t=0, the soft margin is analytically 0 for any input and readout vector
        random_mean = np.append(0., np.mean(soft_margins_rand, axis=0))
        optimized_mean = np.append(0., np.mean(soft_margins_opt, axis=0))

        if 'variate_u' in args:
            # show standard deviation over random u for fixed mu. All results averaged over mu.
            random_mean = np.append(0., np.mean(np.mean(soft_margins_rand, axis=1), axis=0))
            random_std = np.append(0., np.mean(np.std(soft_margins_rand, axis=1), axis=0))

            #deviation for random u, averaged over the different stimuli
            ax2b.fill_between(stop_points[::skip_steps], random_mean - random_std, random_mean + random_std, color='deepskyblue', alpha=0.3)

        if 'variate_mu' in args:
            # show standard deviation of random and optimized input vectors over mu.
            random_std = np.append(0., np.std(soft_margins_rand, axis=0))
            optimized_std = np.append(0., np.std(soft_margins_opt, axis=0))

            #deviation for both optimized and random from all the different sample stimuli
            ax2b.fill_between(stop_points[::skip_steps], random_mean - random_std, random_mean + random_std, color='deepskyblue', alpha=0.3)
            ax2b.fill_between(stop_points[::skip_steps], optimized_mean - optimized_std, optimized_mean + optimized_std, color='crimson', alpha=0.3)

        ax2b.plot(stop_points[::skip_steps], optimized_mean, c='crimson', label=r'optimized $u$')
        ax2b.plot(stop_points[::skip_steps], random_mean, c='deepskyblue', label=r'random $u$')

        ax2b.legend(loc=4)


    if 'C' in args:
        input_vectors_early = b_dictionary['input_vectors_early']
        input_vectors_late = b_dictionary['input_vectors_late']

        # These are the time constants of the linear system. For N -> infty, it
        # will become continuous, but for finite N, we instead plot over Re(lambda) directly
        # time_constants = net_lin.tau / (1.-np.real(net_lin.lamb))


        modes = np.real(net_lin.lamb)
        if 'oneline' in args:
            ax2c.set_xlabel(r'Re$(\lambda_{\alpha})$', x=1.3)
            ax2cc.set_xlabel('')
        else:
            ax2c.set_xlabel(r'Re$(\lambda_{\alpha})$')
            ax2cc.set_xlabel(r'Re$(\lambda_{\alpha})$')

        # omega_alpha
        mode_weight_early = np.einsum('ai, ti -> ta', net_lin.Left, input_vectors_early)
        mode_weight_late = np.einsum('ai, ti -> ta', net_lin.Left, input_vectors_late)

        n_bins = 15
        bin_width = (np.max(modes) - np.min(modes)) / n_bins
        mode_bins = np.linspace(np.min(modes), np.max(modes) - bin_width, n_bins)

        #calc early bin weights on average over stimuli and modes in bin, absolute values for complex weights
        abs_mode_weights_early = np.zeros((n_bins, 2))
        abs_mode_weights_late = np.zeros((n_bins, 2))

        stack_data = np.vstack([np.vstack([modes, np.mean(np.abs(mode_weight_early), axis=0)]), \
                                np.mean(np.abs(mode_weight_late), axis=0)])

        # manually counting contributions to modes in bin
        for mode, mode_weight_e, mode_weight_l in stack_data.T:
            idx = len(mode_bins[mode_bins<=mode]) - 1 # bin to which mode belongs
            abs_mode_weights_early[idx, 0] += mode_weight_e # sum over all |omega_alpha| in that bin
            abs_mode_weights_early[idx, 1] += 1 # number of contributions to that bin
            abs_mode_weights_late[idx, 0] += mode_weight_l
            abs_mode_weights_late[idx, 1] += 1

        # average over all |omega_alpha| within bin
        abs_mode_weights_early = np.einsum('s, s -> s', abs_mode_weights_early[:, 0], 1./abs_mode_weights_early[:, 1])
        abs_mode_weights_early = np.nan_to_num(abs_mode_weights_early)
        abs_mode_weights_late = np.einsum('s, s -> s', abs_mode_weights_late[:, 0], 1./abs_mode_weights_late[:, 1])
        abs_mode_weights_late = np.nan_to_num(abs_mode_weights_late)

        ax2c.bar(mode_bins, abs_mode_weights_early, width=bin_width, align='edge', color='cornflowerblue')
        ax2cc.bar(mode_bins, abs_mode_weights_late, width=bin_width, align='edge', color='cornflowerblue')





    if 'D' in args:
        #load and plot data for fig 2d
        with open('data/linear_networks/2d_'+str(seed)+'.txt', 'rb') as handle:
            d_dictionary = pickle.loads(handle.read())

        angle_decay = d_dictionary['angle_decay']
        # one line for each stimulus. Printed as random permutation to have a color mix in the figure (random color for each line)
        for idx in np.random.permutation(n_stimuli):
            ax2d.plot(angles, angle_decay[idx], c=plt.cm.get_cmap('winter')(idx/n_stimuli), linewidth=0.2)
        ax2d.set_ylim([np.min(angle_decay) * 0.95, np.max(angle_decay) * 1.05])



    fig2.savefig('data/linear_networks/fig2_'+str(seed)+'.pdf')






#%%









def fig3(seed, *args):
    '''
    Generates the data for and plots figure 3 of the network realization `seed`.

    Possible args:
        'a': generate data for figure a
        'A': plot figure a
        'B': plot figure b
    To generate data for fig. 3a, use optimize_soft_margins and random_u_soft_margins first.

    The network is read from data/network_realizations/initialize_network_<seed>.txt.
    With 'a', the responses are written to data/responses_soft_margins/3a_<seed>.txt
    (this step needs NEST and raises a RuntimeError if it could not be imported).
    The figure is always saved to data/responses_soft_margins/fig3_<seed>.pdf,
    whichever panels were requested.
    '''
    np.random.seed(seed)


    # named fig (not fig3) so that the function name is not shadowed inside its own body
    fig = plt.figure(3, figsize=(20,6))
    plt.clf()
    plt.subplots_adjust(left=0.15,right=0.95,top=0.9,bottom=0.2,wspace=0.6,hspace=0.7)

    gs = GridSpec(1,2, hspace=0.05,wspace=0.45, bottom=0.2,
                  left=0.08, right=0.97, width_ratios=[1, 1])
    plt.rcParams.update({'font.size': 24})

    def reduce_ticks(ax, every_nth=2, axtype='y'):
        # hide all tick labels of the chosen axis except every every_nth one
        if axtype == 'x':
            for n, label in enumerate(ax.xaxis.get_ticklabels()):
                if n % every_nth != 0:
                    label.set_visible(False)
        if axtype == 'y':
            for n, label in enumerate(ax.yaxis.get_ticklabels()):
                if n % every_nth != 0:
                    label.set_visible(False)


    #fig 3a
    ax3a = fig.add_subplot(gs[0,0])

    # second y axis for the stimulus x(t)
    ax3aa = ax3a.twinx()
    ax3aa.set_ylabel(r'$x(t)$')
    ax3aa.yaxis.set_major_locator(plt.MultipleLocator(0.5))

    ax3a.set_title('(a)',loc='left')
    ax3a.set_xlabel(r'time $t$')
    ax3a.set_ylabel(r'activity $y$')
    reduce_ticks(ax3a)

    # inset: largest real part of the eigenvalues of the effective connectivity over time
    ax3ab = inset_axes(ax3a, width="20%", height="20%", loc='lower right', borderpad=1.5)
    ax3ab.set_ylabel(r'time $t$')
    ax3ab.set_xlabel(r'$\max($Re$(\tilde{\lambda}_{\alpha}))$')
    ax3ab.tick_params(labelbottom=False, labelleft=False)


    # inset: detailed view of a single neuron
    ax3ac = inset_axes(ax3a, width="35%", height="35%", loc=2)
    ax3ac.tick_params(labelbottom=False, labelleft=False)


    #fig3b
    ax3b = fig.add_subplot(gs[0,1])
    ax3b.set_title('(b)',loc='left')





    # load network
    # NOTE: pickle executes arbitrary code on load; only open realization files you created yourself.
    with open('data/network_realizations/initialize_network_'+str(seed)+'.txt', 'rb') as handle:
        network_dictionary = pickle.load(handle)

    net_non = network_dictionary['net_non']
    poly_coeffs_sim = np.copy(net_non.poly_coeffs)
    poly_coeffs_sim[2] *= 2 # This is alpha, the strength of the non-linearity.
    # the prefactor of NEST's polynomial_rate_ipn neuron's quadratic non-linearity
    # is 1/2, so we need to multiply it by two to use the same equation of motion.
    # also, in NEST the stimuli x(t) pass through the non-linearity. We will take
    # measures against that further below.

    net_lin = network_dictionary['net_lin']

    N = network_dictionary['N']




    def data3a():
        # generate responses of neurons to a sample stimulus

        # The simulation below needs NEST. Fail right away with a clear message instead
        # of an AttributeError on None after the slow theory part has already run.
        if nest is None:
            raise RuntimeError('Generating data for fig. 3a requires a NEST installation (with the changes by the authors), but nest could not be imported.')

        # time resolution of theory and simulation
        dt = 0.01
        dt_multi = 0.01
        times = np.arange(0., net_lin.stop_points[-1]+dt, dt)
        # set up connectivity from eigenvectors and eigenvalues
        W = np.real(np.einsum('a, ia, aj -> ij', net_lin.lamb, net_lin.Right, net_lin.Left))

        # set up a sample stimulus that is easy to look at
        # copy: the slice alone is a view, so assigning to its last entry would
        # overwrite a value inside times whenever the stride does not hit the last sample
        stop_points = times[::int(len(times)/4)].copy()
        stop_points[-1] = times[-1]
        stimulus = np.random.rand(len(stop_points)-1) - 0.5
        stimulus /= np.linalg.norm(stimulus)

        # modify network for new stimulus time axis
        net_lin.stop_points = stop_points
        net_lin.T=len(net_lin.stop_points)-1
        net_lin.readout_time = stop_points[-1]

        net_non.stop_points = stop_points
        net_non.T=len(net_non.stop_points)-1
        net_non.readout_time = stop_points[-1]

        net_lin.determine_sample_dynamics(np.array([stimulus]), np.ones(1), large_N=True)
        net_non.determine_sample_dynamics(np.array([stimulus]), np.ones(1), large_N=True)

        # stimulus for easy plotting (piecewise constant: every height appears at both ends of its interval)
        input_times = np.outer(stop_points, np.ones(2)).flatten()[1:-1]
        input_heights = np.outer(stimulus, np.ones(2)).flatten()

        # random input vector
        input_vector = np.random.rand(N) - 0.5
        input_vector /= np.linalg.norm(input_vector)


        lin_distances = np.empty((len(times)-1, N))
        non_distances = np.empty((len(times)-1, N))


        # the network has to be modified for intermediate times
        for time_idx, time in enumerate(times[1:]):

            net_lin_temp = deepcopy(net_lin)
            net_non_temp = deepcopy(net_non)

            # keep only the stimulus steps that started before the current time and read out at that time
            net_lin_temp.stop_points = np.append(net_lin.stop_points[net_lin.stop_points < time], time)
            net_lin_temp.T=len(net_lin_temp.stop_points)-1
            net_lin_temp.readout_time = time

            net_non_temp.stop_points = np.append(net_non.stop_points[net_non.stop_points < time], time)
            net_non_temp.T=len(net_non_temp.stop_points)-1
            net_non_temp.readout_time = time

            sam = np.array([stimulus[:net_non_temp.T]])
            lin_distances[time_idx] = net_lin_temp.determine_responses(sam, input_vector)
            non_distances[time_idx] = net_non_temp.determine_responses(sam, input_vector)

        # the responses at t=0 are set to zero
        lin_distances = np.append(np.zeros((1,N)), lin_distances, axis=0)
        non_distances = np.append(np.zeros((1,N)), non_distances, axis=0)







        # Only print when errors occur
        nest.set_verbosity('M_ERROR')
        # dictionaries for multimeter and neuron setup
        neuron_dict={'poly_coeffs':poly_coeffs_sim, 'linear_summation': False, 'mu': 0., 'sigma': 0., 'tau': net_non.tau}
        multi_dict={'withtime': True, 'record_from': ['rate'], 'interval': dt_multi}

        # setup multimeter, neurons and step rate generator (self-made by the authors) with the routine described in create_network
        multi, n, rate_gen = create_network.create_network(N=N, N_recorded=N, neuron_dict=neuron_dict, multi_dict=multi_dict, neuron_type='polynomial_rate_ipn', input_type='same', W=W, input_vector=input_vector, dt=dt, setKernel=True)

        readout=np.empty((1, N))
        multi_list=np.zeros((1, 3, int(N*stop_points[-1]/dt_multi)))
        # as mentioned above, in NEST the output of the step rate generator passes through the non-linearity.
        # We do not want this, so we have to revert it. This can fail if the stimuli are too large, so a rescaling can be necessary.
        if poly_coeffs_sim[2] != 0:
            stimulus_sim = - poly_coeffs_sim[1]/poly_coeffs_sim[2] + np.sqrt((poly_coeffs_sim[1]/poly_coeffs_sim[2])**2 \
                                            + 2./poly_coeffs_sim[2] * stimulus)
        else:
            stimulus_sim = stimulus
        # apply stimuli to neurons (n_samples per class) using routine described in apply_stimulus and sort multimeter readout
        # into times_sim (simulation readout times with resolution dt_multi) and rates #(time, trial, neuron) with here only 1 trial
        readout[0], multi_list[0] = apply_stimulus.apply_stimulus(rate_amplitude=np.append(stimulus_sim, 0.), dt=dt, rate_generator=rate_gen,  neurons=n, T=stop_points[-1], multimeter=multi, rate_timestamps=stop_points)
        rates, times_sim = get_properties.get_rate(multi_list, N, 1)




        a_dictionary = {
                'sim_data': [times_sim, rates[:, 0]],
                'lin_data': [times, lin_distances],
                'non_data': [times, non_distances],
                'stim_data': [input_times, input_heights]
                }

        with open('data/responses_soft_margins/3a_'+str(seed)+'.txt', 'wb') as handle:
            pickle.dump(a_dictionary, handle)







    def legend_without_duplicate_labels(ax):
        # every plotted neuron carries the same label; keep only the first handle of each label
        handles, labels = ax.get_legend_handles_labels()
        unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]
        ax.legend(*zip(*unique), fancybox=True, framealpha=0.5, loc=3)

    np.random.seed(seed)

    # generate data for fig. 3a
    if "a" in args:
        data3a()

    np.random.seed(seed) # to ensure same chosen samples independent of using "a" or not

    # plot fig. 3a
    if "A" in args:

        with open('data/responses_soft_margins/3a_'+str(seed)+'.txt', 'rb') as handle:
            a_dictionary = pickle.load(handle)

        sim_data = a_dictionary['sim_data'] # simulation
        lin_data = a_dictionary['lin_data'] # linear system
        non_data = a_dictionary['non_data'] # O(alpha)
        stim_data= a_dictionary['stim_data']# stimulus

        # sample neurons
        show_neurons = 4
        neuron_sample = np.random.choice(len(lin_data[1][0]), show_neurons, replace=False)

        ax3aa.plot(stim_data[0], stim_data[1], c='darkorange')
        ax3a.plot(sim_data[0], np.array(sim_data[1])[:, neuron_sample], c='deepskyblue', label='sim.')
        ax3a.plot(lin_data[0], np.array(lin_data[1])[:, neuron_sample], c=(19/255, 19/255, 19/255), label='linear', linestyle='dashed')
        ax3a.plot(non_data[0], np.array(non_data[1])[:, neuron_sample], c='crimson', label=r'$\mathcal{O}(\alpha)$')

        legend_without_duplicate_labels(ax3a)


        # detailed view in inset
        window_start = 6.
        detail_indices_sim = range(np.argmin(np.abs(sim_data[0]-window_start)), np.argmin(np.abs(sim_data[0]-window_start-0.5)))
        detail_indices_lin = range(np.argmin(np.abs(lin_data[0]-window_start)), np.argmin(np.abs(lin_data[0]-window_start-0.5)))
        detail_indices_non = range(np.argmin(np.abs(non_data[0]-window_start)), np.argmin(np.abs(non_data[0]-window_start-0.5)))
        # the neuron that has the highest rate in the given time slot used for inset
        single = np.argmax(np.bincount(np.argmax(np.array(sim_data[1])[detail_indices_sim][:, neuron_sample], axis=1)))

        ax3ac.plot(sim_data[0][detail_indices_sim], np.array(sim_data[1])[detail_indices_sim, neuron_sample[single]], c='deepskyblue', label='sim.')
        ax3ac.plot(lin_data[0][detail_indices_lin], np.array(lin_data[1])[detail_indices_lin, neuron_sample[single]], c=(19/255, 19/255, 19/255), label='linear', linestyle='dashed')
        # label in math mode, identical to the one used in the main panel
        ax3ac.plot(non_data[0][detail_indices_non], np.array(non_data[1])[detail_indices_non, neuron_sample[single]], c='crimson', label=r'$\mathcal{O}(\alpha)$')

        mark_inset(ax3a, ax3ac, loc1=1, loc2=4, fc="none", ec='0.5')



        #effective connectivity
        # NaNs in the simulation happened only in some pathetic cases near the edge of chaos.
        # Cut the simulation data at the first time step at which any neuron is NaN. Looking at
        # whole time steps (instead of the first NaN per neuron) keeps the data before the
        # divergence also when some neurons never become NaN.
        rates_sim = np.asarray(sim_data[1]) # (time, neuron)
        nan_rows = np.any(np.isnan(rates_sim), axis=1)
        if np.any(nan_rows):
            time_no_nan = np.argmax(nan_rows) # index of the first time step containing a NaN
            rates_no_nan = rates_sim[:time_no_nan]
            times_no_nan = sim_data[0][:time_no_nan]
        else:
            rates_no_nan = rates_sim
            times_no_nan = sim_data[0]

        # the effective connectivity of the linear system, of course, is identical to the actual connectivity
        W_effective_lin = np.einsum('a, ia, aj -> ij', net_lin.lamb, net_lin.Right, net_lin.Left)[None, :, :] + net_lin.poly_coeffs[2] * np.einsum('a, ia, aj, tj -> tij', net_lin.lamb, net_lin.Right, net_lin.Left, lin_data[1])
        lamb_effective_lin = np.linalg.eig(W_effective_lin)[0]
        W_effective_non = np.einsum('a, ia, aj -> ij', net_non.lamb, net_non.Right, net_non.Left)[None, :, :] + 2*net_non.poly_coeffs[2] * np.einsum('a, ia, aj, tj -> tij', net_non.lamb, net_non.Right, net_non.Left, non_data[1])
        lamb_effective_non = np.linalg.eig(W_effective_non)[0]
        W_effective_sim = np.einsum('a, ia, aj -> ij', net_non.lamb, net_non.Right, net_non.Left)[None, :, :] + poly_coeffs_sim[2] * np.einsum('a, ia, aj, tj -> tij', net_non.lamb, net_non.Right, net_non.Left, rates_no_nan)
        lamb_effective_sim = np.linalg.eig(W_effective_sim)[0]

        ax3ab.plot(np.max(np.real(lamb_effective_lin), axis=1), lin_data[0], c=(19/255, 19/255, 19/255), linestyle='dashed')
        ax3ab.plot(np.max(np.real(lamb_effective_non), axis=1), non_data[0], c='crimson')
        ax3ab.plot(np.max(np.real(lamb_effective_sim), axis=1), times_no_nan, c='deepskyblue')




    if "B" in args:
        n_samples = network_dictionary['n_samples']
        mu = network_dictionary['mu']
        chi = network_dictionary['chi']
        psi = network_dictionary['psi']

        input_separabilities = []
        for length in  network_dictionary['lengths']:
            # the same samples used in the optimization
            samples, labels = sample_stimuli(seed, n_samples, length, mu, chi, psi)
            # estimate of mu based on the drawn samples
            mu_s = np.einsum('tn, t -> n', samples, labels) * 1. / len(labels)

            # covariances of the classes plus and minus
            cov_plus = np.cov(samples[labels>0].T, bias=True)
            cov_minus = np.cov(samples[labels<0].T, bias=True)

            # estimate of psi based on the samples
            psi_s = 0.5 * (cov_plus + cov_minus)
            # signal-to-noise-like rough estimator of linear separability of data to compare with ECG dataset
            lin_sep = np.linalg.norm(mu_s)**2 / np.sqrt(np.einsum('n, nm, m ->', mu_s, psi_s, mu_s))
            print('linear separability measure:', lin_sep)

            input_separabilities.append(np.linalg.norm(mu_s))

        input_separabilities = np.array(input_separabilities)

        sc = evaluate_soft_margins(seed, ax3b, input_separabilities)
        cbar = plt.colorbar(sc)
        cbar.set_label(r'$\vert\!\vert\mu\vert\!\vert$', x=0.9, rotation=90)

        ax3b.xaxis.set_major_locator(plt.MultipleLocator(0.002))
        ax3b.yaxis.set_major_locator(plt.MultipleLocator(0.002))






    fig.savefig('data/responses_soft_margins/fig3_'+str(seed)+'.pdf')











#%%
# These functions are very similar to functions in network_class, but pulled up
# here to allow using python's multiprocessing

def soft_margin_over_time_a(list_):
    '''
    Worker for fig. 2a: soft margins at one intermediate readout time.

    list_ holds, in this order: readout_time, opt_steps_lin, stop_points,
    stimuli, labels, input_vector, eta (a single list so that the function can
    be handed to a multiprocessing map).

    Returns [soft margin for the given input vector, soft margin after
    optimization]. It works on a deep copy of the module-level net_lin, so the
    original network is left untouched.
    '''
    readout_time, opt_steps_lin, stop_points, stimuli, labels, input_vector, eta = list_[:7]

    # label-weighted mean of the stimuli
    mu = np.einsum('ti, t -> i', stimuli, labels) / len(labels)

    worker_net = deepcopy(net_lin)

    if not readout_time > stop_points[0]:
        # In case T=0, we have an analytical result of zero for all variables. However, this case is not used.
        return [np.zeros(1), np.zeros(1), np.zeros(worker_net.N), np.zeros((worker_net.N, worker_net.N))]

    # The optimization routine was written for a single readout time, so the time
    # axis is clipped at readout_time and the network dynamics is set up anew.
    clipped_stop_points = np.append(stop_points[stop_points < readout_time], readout_time)
    # readout_time is the last entry, so its index equals the number of remaining stimulus steps
    n_steps = len(clipped_stop_points) - 1

    clipped_mu = mu[:n_steps]
    clipped_stimuli = stimuli[:, :n_steps]

    worker_net.stop_points = clipped_stop_points
    worker_net.T = n_steps
    worker_net.readout_time = clipped_stop_points[-1]
    worker_net.determine_sample_dynamics(clipped_stimuli, labels, large_N=True)

    # soft margin in response to random u (only 1 optimization step)
    margins_random, _, _ = worker_net.alternating_optimization(
        1, clipped_mu, eta=eta, initial_input_vector=input_vector)

    # soft margin, optimized
    margins_optimized, _, _ = worker_net.find_good_optimization(
        opt_steps_lin, clipped_mu, eta=eta, initial_cond=10, initial_steps=3)

    return [margins_random[0], margins_optimized[-1]]




def soft_margin_over_time_b(list_):
    '''
    Worker for fig. 2b: optimization of input and readout vector at one
    intermediate time point for one of the stimuli.

    list_ holds, in this order: i (time idx), idx (stimulus idx), opt_steps_lin,
    set_u, set_mu, stop_points, eta, n_samples. If the string 'variate_u' is
    contained in list_ as well, the soft margin is evaluated for all random
    input vectors in set_u instead of only set_u[idx].

    Returns [soft margin(s) for random u, optimized soft margin, optimized
    input vector]. It works on a deep copy of the module-level net_lin.
    '''

    i = list_[0] # time idx
    idx = list_[1] # stimulus idx
    opt_steps_lin = list_[2]
    set_u = list_[3]
    set_mu = list_[4]
    stop_points = list_[5]
    eta = list_[6]
    n_samples = list_[7]

    # Only compare the string entries with the flag. A plain "'variate_u' in list_"
    # also compares the flag with the arrays in list_; with current NumPy versions this
    # yields a boolean array whose truth value is ambiguous, i.e. a ValueError.
    variate_u = any(isinstance(entry, str) and entry == 'variate_u' for entry in list_)

    u_rand = set_u[idx]

    net_copy = deepcopy(net_lin)

    temp_stop_points = stop_points[:i+2]

    temp_mu = set_mu[idx, :i+1]
    temp_psi = np.eye(i+1)

    # to achieve general results, use different realizations of the stimuli for every readout time and ground truth of stimuli
    temp_samples, temp_labels = sample_stimuli(i*len(set_mu)+idx, n_samples, 1., temp_mu, chi=np.zeros_like(temp_psi), psi=temp_psi)

    # for earlier readout times, these parameters have to be adapted
    net_copy.stop_points = temp_stop_points
    net_copy.T = len(temp_stop_points) - 1
    net_copy.readout_time = temp_stop_points[-1]
    net_copy.determine_sample_dynamics(temp_samples, temp_labels, large_N=True)

    # soft margin for random input vectors (a single optimization step each)
    if variate_u:
        soft_margin_r = []
        for u_rand in set_u:
            softm_r, u_r, v_r = net_copy.alternating_optimization(1, temp_mu, eta=eta, initial_input_vector=u_rand)
            soft_margin_r.append(softm_r[0])
        soft_margin_r = np.array(soft_margin_r)
    else:
        soft_margin_r, u_r, v_r = net_copy.alternating_optimization(1, temp_mu, eta=eta, initial_input_vector=u_rand)
        soft_margin_r = soft_margin_r[0]

    # optimized soft margins
    soft_margin_o, u_o, v_o = net_copy.find_good_optimization(opt_steps_lin, temp_mu, eta=eta, initial_cond=10, initial_steps=3)

    return [soft_margin_r, soft_margin_o[-1], u_o[-1]]




def soft_margin_over_time_d(list_):
    '''
    Worker for fig. 2d: soft margin for input vectors rotated away from the
    optimized input vector towards one of the random input vectors.

    list_ holds, in this order: idx (index of the random input vector in
    set_u), opt_steps_lin, set_u, angles, samples, labels, stop_points,
    late_idx, eta, seed.

    Returns a list with one soft margin per entry of angles. It works on a deep
    copy of the module-level net_lin and seeds numpy's random generator with
    seed.
    '''
    idx = list_[0]
    opt_steps_lin = list_[1]
    set_u = list_[2]
    angles = list_[3]
    samples = list_[4]
    labels = list_[5]
    stop_points = list_[6]
    late_idx = list_[7]
    eta = list_[8]
    seed = list_[9]

    np.random.seed(seed)

    results = []
    u_rand = set_u[idx]

    # modify the network and stimulus for earlier readout times
    net_copy = deepcopy(net_lin)

    temp_stop_points = stop_points[:late_idx+2]


    temp_samples = samples[:, :late_idx + 1]
    temp_mu = np.einsum('tn, t -> n', temp_samples, labels) / len(labels)

    net_copy.stop_points = temp_stop_points
    net_copy.T = len(temp_stop_points) - 1
    net_copy.readout_time = temp_stop_points[-1]
    net_copy.determine_sample_dynamics(temp_samples, labels, large_N=True)

    # optimized solution
    soft_margin_o, u_o, v_o = net_copy.find_good_optimization(opt_steps_lin, temp_mu, eta=eta, initial_cond=10, initial_steps=3)

    # Normalized component of the random u that is orthogonal to the optimal input vector.
    # It does not depend on the angle, so it is computed once instead of once per angle.
    # The optimal vector is copied so that later optimization calls cannot alter it.
    u_opt = np.array(u_o[-1])
    orth = u_rand - np.dot(u_rand, u_opt)/np.dot(u_opt, u_opt) * u_opt
    orth /= np.linalg.norm(orth)

    # rotate from the optimal input vector (angle 0) towards the orthogonal direction
    for angle in angles:

        u_temp = np.cos(angle)*u_opt + np.sin(angle)*orth
        u_temp /= np.linalg.norm(u_temp)

        # a single optimization step, starting from the rotated input vector
        soft_margin_r, u_r, v_r = net_copy.alternating_optimization(1, temp_mu, eta=eta, initial_input_vector=u_temp)

        results.append(soft_margin_r[0])

    return results
