# NEST is not required for evaluation alone. Generating data requires a NEST
# version that includes changes made by the authors.
try:
    import nest
except:
    nest = None    
import numpy as np
from scipy.optimize import bisect
from sklearn.linear_model import Ridge

'''
A few useful functions when working with nest
'''

def get_connectivity(neurons):
    """Return the summed synaptic weight matrix among ``neurons``.

    Row ``t`` corresponds to ``neurons[t]`` (the target).  Column ``s``
    corresponds to the source id ``neurons[0] + s``.  ``W[t, s]`` is the sum
    of the weights of all connections from that source onto that target.
    Connections whose source id lies outside
    ``[neurons[0], neurons[0] + len(neurons))`` (e.g. devices or other
    populations) are ignored.

    NOTE(review): rows and columns only describe the same neurons if
    ``neurons`` is a contiguous, ascending range of integer NEST ids starting
    at ``neurons[0]`` — TODO confirm against callers.

    Raises
    ------
    RuntimeError
        If NEST could not be imported.
    """
    if nest is None:
        raise RuntimeError('get_connectivity requires a NEST installation.')
    n_neurons = len(neurons)
    first_id = neurons[0]
    W = np.zeros((n_neurons, n_neurons))
    for t, target in enumerate(neurons):
        connections = nest.GetConnections(target=[target])
        for source, weight in nest.GetStatus(connections, keys=['source', 'weight']):
            column = source - first_id
            # Check both bounds: a negative column would silently wrap around
            # in numpy and add the weight to an unrelated neuron.
            if 0 <= column < n_neurons:
                W[t, column] += weight
    return W

def get_rate(multi, N, trials):
    """Rearrange per-trial multimeter data into a (time, trial, neuron) array.

    ``multi`` is what ``apply_stimulus.apply_stimulus`` returns.
    ``multi[trial]`` is indexed by field: 0 = senders, 1 = times,
    2 = recorded values (the "s-t-r" layout).  A neuron's column is its
    sender id minus the smallest sender id found in trial 0.

    Parameters
    ----------
    multi : sequence
        Per-trial recordings, accessed as ``multi[trial][field]`` (fields are
        expected to be numpy arrays).
    N : int
        Number of neurons (columns) to extract.
    trials : int
        Number of trials to extract.

    Returns
    -------
    rate : numpy.ndarray
        Array of shape ``(len(times), trials, N)``.
    times : numpy.ndarray
        Sample times of the lowest-id sender in trial 0.
    """
    reference = multi[0]
    first_sender = np.min(reference[0])
    times = reference[1][np.where(reference[0] == first_sender)]

    rate = np.zeros((len(times), trials, N))
    for trial_idx in range(trials):
        senders = multi[trial_idx][0]
        values = multi[trial_idx][2]
        for neuron in range(N):
            rows = np.where(senders == neuron + first_sender)
            rate[:, trial_idx, neuron] = values[rows]
    return rate, times


def _unit_ridge_coef(responses, labels, alpha):
    """Fit an intercept-free ridge regression and return its unit-norm coefficients.

    ``responses``/``labels`` are passed straight to ``Ridge.fit``; ``alpha`` is
    the regularisation strength.  Returns a copy of ``coef_`` divided by its
    norm.
    """
    clf = Ridge(alpha=alpha, fit_intercept=False)
    clf.fit(responses, labels)
    coef = np.copy(clf.coef_)
    return coef / np.linalg.norm(coef)


def find_simulation_readout(Sigma_sim, dist_sim, eta, N, net, responses, labels, solver='eigenvalue'):
    """Determine a unit-norm readout vector from simulated statistics.

    This is the ``determine_readout`` function of the network class, except
    that ``Sigma_sim`` and ``dist_sim`` are supplied directly instead of being
    computed.  The readout has the form

        v = (eta * Sigma_sim - 2 * lagrange * I)^{-1} dist_sim,

    with the Lagrange multiplier found by bisection so that ``||v|| = 1``.

    Parameters
    ----------
    Sigma_sim : numpy.ndarray
        Symmetric ``(N, N)`` matrix (presumably a covariance — TODO confirm).
    dist_sim : numpy.ndarray
        Length-``N`` vector.
    eta : float
        Scale applied to ``Sigma_sim``; also enters the ridge ``alpha``.
    N : int
        Problem dimension (size of the identity matrix).
    net : object
        Only ``net.eps`` is read.  It is the half-width of the search window
        around the multiplier's upper bound, used when the first bracket fails.
    responses, labels : array-like
        Training data for ``sklearn.linear_model.Ridge``.  They are used only
        by the ``'ridge'`` solver and by the fallback.
    solver : {'eigenvalue', 'ridge'}
        ``'eigenvalue'`` evaluates the closed form above at the found
        multiplier.  ``'ridge'`` fits a ridge regression with
        ``alpha = |2 * len(labels) * lagrange / eta|``.

    Returns
    -------
    numpy.ndarray
        The readout vector, normalised to unit length.

    Raises
    ------
    ValueError
        If ``solver`` is not one of the supported values.

    If no sign change of ``||v||^2 - 1`` can be bracketed, a message is
    printed.  A ridge regression with a fixed small ``alpha`` is then used
    instead.
    """
    if solver not in ('eigenvalue', 'ridge'):
        raise ValueError(f"unknown solver {solver!r}; expected 'eigenvalue' or 'ridge'")

    identity = np.eye(N)

    def readout_vector_by_lagrange(lagrange, Sigma_sim, dist_sim):
        # Apply (eta*Sigma - 2*lagrange*I)^{-1} to dist via its eigendecomposition.
        eig, vec = np.linalg.eigh(eta * Sigma_sim - 2 * lagrange * identity)
        eiginv = 1. / eig
        return np.einsum('a, ia, ja, j -> i', eiginv, vec, vec, dist_sim)

    def bisect_lagrange(lagrange, Sigma_sim, dist_sim):
        # Root of this function <=> unit-norm readout vector.
        return np.linalg.norm(readout_vector_by_lagrange(lagrange, Sigma_sim, dist_sim))**2 - 1.

    # At lagrange = 0.5 * min eigenvalue the shifted matrix becomes singular,
    # so the norm blows up there; the root is searched at or below this bound.
    lag_max = 0.5 * np.min(np.linalg.eigvalsh(eta * Sigma_sim))
    if bisect_lagrange(lag_max, Sigma_sim, dist_sim) > 0.:
        lag_range = np.array([lag_max - 1., lag_max])
    else:
        lag_range = np.append(-1., np.linspace(lag_max - net.eps, lag_max + net.eps, 101))

    inis = np.array([bisect_lagrange(lag, Sigma_sim, dist_sim) for lag in lag_range])

    # An exactly singular shift yields NaN; np.argmin/np.argmax would pick
    # NaN entries, so exclude them before choosing the bracket.
    valid = ~np.isnan(inis)
    lags, values = lag_range[valid], inis[valid]

    # bisect needs a genuine sign change (or an exact zero) between its ends.
    if values.size and values.max() > 0 and values.min() <= 0:
        lagrange = bisect(bisect_lagrange, lags[np.argmin(values)], lags[np.argmax(values)],
                          args=(Sigma_sim, dist_sim), xtol=1e-16)

        if solver == 'eigenvalue':
            readout_vector = readout_vector_by_lagrange(lagrange, Sigma_sim, dist_sim)
        else:
            alpha_ridge = - 2 * len(labels) * lagrange / eta
            if alpha_ridge < 0:
                print('unrealistic lambda. Use abs.')
                alpha_ridge = np.abs(alpha_ridge)
            readout_vector = _unit_ridge_coef(responses, labels, alpha_ridge)
    else:
        print('Determining Lagrange parameter failed in determine_readout. Use Ridge Regression with arbitrary fixed alpha.')
        # This value of alpha is of the order of magnitude encountered in tested examples.
        alpha_ridge = 2 * len(labels) * 1e-12 / eta
        readout_vector = _unit_ridge_coef(responses, labels, alpha_ridge)

    return readout_vector / np.linalg.norm(readout_vector)