import numpy as np
from scipy.optimize import bisect, fsolve
from sklearn.linear_model import Ridge
import sys


class network:
    '''
    Calculate the Green's functions of a network with given network parameters
    up to first order for stepwise constant stimuli and use it to optimize the
    soft margin.
    '''

    def __init__(self, network_params, stop_points, eps=1e-15):

        self.lamb=network_params['lamb']
        self.Right=network_params['Right']
        self.Left=network_params['Left']
        self.tau=network_params['tau']
        self.poly_coeffs=network_params['poly_coeffs']
        self.stop_points=np.array(stop_points)
        self.N=len(network_params['lamb'])
        self.T=len(stop_points)-1
        self.readout_time=stop_points[-1]
        self.eps=eps

        self.Greens1=None
        self.Greens2=None

        self.dist_lin_summand = None
        self.dist_non_summand = None
        self.Sigma_lin_summand = None
        self.Sigma_non_summand = None

        self.stimuli = None
        self.labels = None


#%%

    ################
    #Determine dynamical quantities
    ################

# Basic dynamics in response to stimuli. Usually used in combination with train data




    def determine_sample_dynamics(self, stimuli, labels, large_N=True):
        '''
        To be executed before any optimization.
        Determine product of Green's functions with stimuli. Used for an
        efficient calculation of distance M and covariance Sigma.
        stimuli: (samples, timesteps)
        labels: (samples)
        large_N: Set True if number of neurons N is large and you are in
        need of working memory. This will then be more memory efficient,
        but make some exra steps during the optimization.
        Note: At the current state, the correct Sigma can be computed
        only when large_N=True!
        '''

        self.stimuli = stimuli
        self.labels = labels

        self.dist_lin_summand = self.linear_distance_summand(stimuli, labels)

        self.dist_non_summand = self.nonlinear_distance_summand(stimuli, labels)

        self.Sigma_lin_summand = self.linear_covariance_summand(stimuli, labels)

        self.Sigma_non_summand = self.nonlinear_covariance_summand(stimuli, labels, large_N)




















    def linear_distance_summand(self, stimuli, labels):
        '''
        For automated use by determine_sample_dynamics or manual use to
        calculate estimated response to stimuli.
        Calculate sum_nu int G^(1)(t)*zeta_nu*x^nu(t) dt.
        Output: (output neuron, input_neuron)
        Useage:
            - With training stimuli+labels, calculate distance M_0 as
            M_0 = dist_lin_summand * u
            - With any stimuli, labels = np.ones(len(stimuli)), calculate
            linear contribution y_0 to response y as
            y_0 = dist_lin_summand * u
        If you need only y and do not wish to inspect its contributions,
        use determine_responses instead.
        '''

        lamb_exp = np.ones_like(self.lamb) - self.poly_coeffs[1] * self.lamb
        time_exp_matrix = np.exp( - np.einsum('n, a -> na', self.readout_time - self.stop_points, lamb_exp, optimize='greedy') * 1. / self.tau )
        exp_diff = time_exp_matrix[1:] - time_exp_matrix[:-1]
        dist_lin_summand = np.real(np.einsum('ia, ap, a, na, tn, t -> ip', self.Right, self.Left, 1./lamb_exp, exp_diff, stimuli, labels, optimize='greedy')) * 1. / len(labels)

        return dist_lin_summand




    def nonlinear_distance_summand(self, stimuli, labels):
        '''
        For automated use by determine_sample_dynamics or manual use to
        calculate estimated response to stimuli.
        Calculate sum_nu int G^(2)(t,s)*zeta_nu*x^nu(t)*x^nu(s) ds dt,
        divided by the number of labels.
        stimuli: (samples, timesteps)
        labels: (samples)
        Output: (output neuron, input_neuron_1, input_neuron_2), real part
                only, or None if poly_coeffs[2] == 0
        Usage:
            - With training stimuli+labels, calculate distance M_1 as
            M_1 = dist_non_summand * u * u
            - With any stimuli, labels = np.ones(len(stimuli)), calculate
            O(alpha) contribution y_1 to response y as
            y_1 = dist_non_summand * u * u
        If you need only y and do not wish to inspect its contributions,
        use determine_responses instead.
        '''

        # Without a quadratic coefficient there is no second order term;
        # this is signalled to the callers by None.
        if self.poly_coeffs[2] == 0:

            dist_non_summand = None

        else:

            # Fold the 1/(number of labels) normalisation into the labels, so
            # every label-weighted sum below is already an average.
            labels = labels * 1. / len(labels)

            # Rate in the exponent of every mode: 1 - poly_coeffs[1] * lamb
            lamb_exp = np.ones_like(self.lamb) - self.poly_coeffs[1] * self.lamb
            # Lengths of the intervals between consecutive stop points
            timediff = np.diff(self.stop_points)
            readout_exp = np.exp(- lamb_exp / self.tau * self.readout_time)#(alpha)

            # exp(+stop_point * lamb_exp / tau) and its change per interval
            time_exp_matrix = np.exp( np.einsum('n, a -> na', self.stop_points, lamb_exp, optimize='greedy') * 1. / self.tau )
            exp_diff_matrix = time_exp_matrix[1:] - time_exp_matrix[:-1]#(n,alpha)

            # exp(stop_point * (lamb_exp[a] - lamb_exp[b]) / tau) and its change per interval
            time_exp_tensor = np.exp( np.einsum('n, ab -> nab', self.stop_points, np.subtract.outer(lamb_exp, lamb_exp), optimize='greedy')* 1./self.tau )
            exp_diff_tensor = time_exp_tensor[1:] - time_exp_tensor[:-1]#(n,alpha,beta)


            # Per sample (t) and mode (a) single integral; name0 is the
            # label-weighted sum over samples of its outer product with itself.
            simple = self.tau * np.einsum('a, a, na, tn -> ta', 1./lamb_exp, readout_exp, exp_diff_matrix, stimuli, optimize='greedy')
            name0 = np.einsum('ta, tb, t -> ab', simple, simple, labels, optimize='greedy')

            #eq. 33
            # The strictly lower triangular mask selects interval pairs with
            # m < n, the identity mask selects the pairs with m == n.
            lower_diag = np.einsum('ncb, mb, nm, t, tn, tm -> cb',
                                   exp_diff_tensor,
                                   exp_diff_matrix,
                                   np.tril(np.ones((self.T, self.T)),  -1),
                                   labels,
                                   stimuli,
                                   stimuli,
                                   optimize='greedy') \
                       - np.einsum('ncb, nb, nm, t, tn, tm -> cb',
                                    exp_diff_tensor,
                                    time_exp_matrix[:-1],
                                    np.eye(self.T),
                                    labels,
                                    stimuli,
                                    stimuli,
                                    optimize='greedy')

            # NOTE: 1./np.subtract.outer(lamb_exp, lamb_exp) divides by zero
            # for b == c, so the diagonal of name1 is not finite at this
            # point. It is overwritten with the separate b == c expression
            # further down.
            name1 = self.tau**2 * np.einsum('b, cb, c, cb -> bc',
                                            1./lamb_exp,
                                            1./np.subtract.outer(lamb_exp, lamb_exp),
                                            readout_exp,
                                            lower_diag,
                                            optimize='greedy') \
                    + self.tau**2 * np.einsum('b, c, nc, c, nm, t, tn, tm -> bc',
                                              1./lamb_exp,
                                              1./lamb_exp,
                                              exp_diff_matrix,
                                              readout_exp,
                                              np.eye(self.T),
                                              labels,
                                              stimuli,
                                              stimuli,
                                              optimize='greedy')

            #where beta == gamma overwrite diagonal
            diag_idx = np.diag_indices(self.N)

            #eq. 35
            name1[diag_idx[0], diag_idx[1]] = self.tau * np.einsum('c, c, n, mc, nm, t, tn, tm -> c',
                                                                         1./lamb_exp,
                                                                         readout_exp,
                                                                         timediff,
                                                                         exp_diff_matrix,
                                                                         np.tril(np.ones((self.T, self.T)),  -1),
                                                                         labels,
                                                                         stimuli,
                                                                         stimuli,
                                                                         optimize='greedy') \
                                                    - self.tau * np.einsum('c, c, n, nc, nm, t, tn, tm -> c',
                                                                         1./lamb_exp,
                                                                         readout_exp,
                                                                         timediff,
                                                                         time_exp_matrix[:-1],
                                                                         np.eye(self.T),
                                                                         labels,
                                                                         stimuli,
                                                                         stimuli,
                                                                         optimize='greedy') \
                                                    + self.tau**2 * np.einsum('c, c, nc, nm, t, tn, tm -> c',
                                                                              1./lamb_exp**2,
                                                                              readout_exp,
                                                                              exp_diff_matrix,
                                                                              np.eye(self.T),
                                                                              labels,
                                                                              stimuli,
                                                                              stimuli,
                                                                              optimize='greedy')

            # lamb[c] / (lamb_exp[c] - lamb_exp[a] - lamb_exp[b])
            lamb_fraction = self.lamb[:, None, None] / (lamb_exp[:, None, None] - lamb_exp[None, :, None] - lamb_exp[None, None, :]) #(cab)

            # Stimulus independent coupling of the modes (a, b) into mode c
            subsum = np.einsum('cj, ja, jb, cab-> abc',
                               self.Left,
                               self.Right,
                               self.Right,
                               lamb_fraction,
                               optimize='greedy')

            max_integral = np.einsum('ic, ap, bq, abc, bc -> ipq',
                                     self.Right,
                                     self.Left,
                                     self.Left,
                                     subsum,
                                     name1,
                                     optimize='greedy')

            # name0 contribution minus max_integral and minus max_integral
            # with the two input indices p and q swapped.
            dist_non_summand = np.real(self.poly_coeffs[2] / (self.tau**2) * (
                                      np.einsum('ic, ap, bq, abc, ab -> ipq',
                                                self.Right,
                                                self.Left,
                                                self.Left,
                                                subsum,
                                                name0,
                                                optimize='greedy')
                                      - max_integral \
                                      - np.einsum('iqp -> ipq', max_integral, optimize='greedy')
                                      ))

        return dist_non_summand






    def linear_covariance_summand(self, stimuli, labels):
        '''
        For automated use by determine_sample_dynamics only.
        Calculate int G^(1)(t)*G^(1)(s)*(<x(t)x(s)>-<zeta*x(t)><zeta*x(s)> ds dt.
        Output: (output neuron_1, output_neuron_2, input_neuron_1, input_neuron_2)
        Useage:
            - With training stimuli+labels, calculate covariance Sigma_0 as
            Sigma_0 = Sigma_lin_summand * u * u
        '''

        lamb_exp = np.ones_like(self.lamb) - self.poly_coeffs[1] * self.lamb
        time_exp_matrix = np.exp( - np.einsum('n, a -> na', self.readout_time - self.stop_points, lamb_exp, optimize='greedy') * 1. / self.tau )
        exp_diff = time_exp_matrix[1:] - time_exp_matrix[:-1]

        Sigma_lin_summand = np.einsum('ia, ap, a, na, tn, jb, bq, b, mb, tm -> ijpq',
                                      self.Right,
                                      self.Left,
                                      1./lamb_exp,
                                      exp_diff,
                                      stimuli,
                                      self.Right,
                                      self.Left,
                                      1./lamb_exp,
                                      exp_diff,
                                      stimuli,
                                      optimize='greedy') * 1. / len(labels) \
                          - np.einsum('ip, jq -> ijpq', self.dist_lin_summand, self.dist_lin_summand)

        return Sigma_lin_summand





    def nonlinear_covariance_summand(self, stimuli, labels, large_N=True):
        '''
        For automated use by determine_sample_dynamics only.
        Calculate int (G^(1)(t)*G^(2)(s1, s2) + G^(2)(s1, s2)*G^(1)(t))
        * (<x(t)x(s1)x(s2)>-<zeta*x(t)><zeta*x(s1)x(s2)>) ds1 ds2 dt.
        stimuli: (samples, timesteps)
        labels: (samples); only their number enters here, and only for
                large_N=False
        large_N: if True, return the constituents instead of the full tensor
        Output: - None if poly_coeffs[2] == 0
                - large_N=True: list [dist_lin_summand, subsum, name0, name1]
                  of per-sample constituents; the contraction is done later
                  for a given input or readout vector
                - large_N=False: (output neuron_1, output_neuron_2,
                  input_neuron_1, input_neuron_2, input_neuron_3). This branch
                  reads self.dist_lin_summand and self.dist_non_summand, which
                  have to be set beforehand.
        Usage:
            - With training stimuli+labels, calculate first contribution to
            covariance Sigma_1 as
            Sigma_1 = Sigma_non_summand * u * u * u
            But don't forget to use the correction term. See
            nonlinear_covariance_correction
        '''
        # Without a quadratic coefficient there is no second order term.
        if self.poly_coeffs[2] == 0:

            Sigma_non_summand = None
            return Sigma_non_summand

        else:

            # Rate in the exponent of every mode: 1 - poly_coeffs[1] * lamb
            lamb_exp = np.ones_like(self.lamb) - self.poly_coeffs[1] * self.lamb
            # Lengths of the intervals between consecutive stop points
            timediff = np.diff(self.stop_points)
            readout_exp = np.exp(- lamb_exp / self.tau * self.readout_time)#(alpha)

            # exp(+stop_point * lamb_exp / tau) and its change per interval
            time_exp_matrix = np.exp( np.einsum('n, a -> na', self.stop_points, lamb_exp, optimize='greedy') * 1. / self.tau )
            exp_diff_matrix = time_exp_matrix[1:] - time_exp_matrix[:-1]#(n,alpha)

            # exp(stop_point * (lamb_exp[a] - lamb_exp[b]) / tau) and its change per interval
            time_exp_tensor = np.exp( np.einsum('n, ab -> nab', self.stop_points, np.subtract.outer(lamb_exp, lamb_exp), optimize='greedy') * 1./self.tau )
            exp_diff_tensor = time_exp_tensor[1:] - time_exp_tensor[:-1]#(n,alpha,beta)


            # Unlike in nonlinear_distance_summand, the sample index t is kept
            # on all quantities below and no labels are applied.
            simple = self.tau * np.einsum('a, a, na, tn -> ta', 1./lamb_exp, readout_exp, exp_diff_matrix, stimuli, optimize='greedy')
            name0 = np.einsum('ta, tb -> tab', simple, simple, optimize='greedy')

            #eq. 33
            # The strictly lower triangular mask selects interval pairs with
            # m < n, the identity mask selects the pairs with m == n.
            lower_diag = np.einsum('ncb, mb, nm, tn, tm -> tcb',
                                   exp_diff_tensor,
                                   exp_diff_matrix,
                                   np.tril(np.ones((self.T, self.T)),  -1),
                                   stimuli,
                                   stimuli,
                                   optimize='greedy') \
                       - np.einsum('ncb, nb, nm, tn, tm -> tcb',
                                    exp_diff_tensor,
                                    time_exp_matrix[:-1],
                                    np.eye(self.T),
                                    stimuli,
                                    stimuli,
                                    optimize='greedy')

            # NOTE: 1./np.subtract.outer(lamb_exp, lamb_exp) divides by zero
            # for b == c, so the b == c entries of name1 are not finite at
            # this point. They are overwritten with the separate b == c
            # expression further down.
            name1 = self.tau**2 * np.einsum('b, cb, c, tcb -> tbc',
                                            1./lamb_exp,
                                            1./np.subtract.outer(lamb_exp, lamb_exp),
                                            readout_exp,
                                            lower_diag,
                                            optimize='greedy') \
                    + self.tau**2 * np.einsum('b, c, nc, c, nm, tn, tm -> tbc',
                                              1./lamb_exp,
                                              1./lamb_exp,
                                              exp_diff_matrix,
                                              readout_exp,
                                              np.eye(self.T),
                                              stimuli,
                                              stimuli,
                                              optimize='greedy')

            #where beta == gamma overwrite diagonal
            diag_idx = np.diag_indices(self.N)

            #eq. 35
            name1[:, diag_idx[0], diag_idx[1]] = self.tau * np.einsum('c, c, n, mc, nm, tn, tm -> tc',
                                                                         1./lamb_exp,
                                                                         readout_exp,
                                                                         timediff,
                                                                         exp_diff_matrix,
                                                                         np.tril(np.ones((self.T, self.T)),  -1),
                                                                         stimuli,
                                                                         stimuli,
                                                                         optimize='greedy') \
                                                    - self.tau * np.einsum('c, c, n, nc, nm, tn, tm -> tc',
                                                                         1./lamb_exp,
                                                                         readout_exp,
                                                                         timediff,
                                                                         time_exp_matrix[:-1],
                                                                         np.eye(self.T),
                                                                         stimuli,
                                                                         stimuli,
                                                                         optimize='greedy') \
                                                    + self.tau**2 * np.einsum('c, c, nc, nm, tn, tm -> tc',
                                                                              1./lamb_exp**2,
                                                                              readout_exp,
                                                                              exp_diff_matrix,
                                                                              np.eye(self.T),
                                                                              stimuli,
                                                                              stimuli,
                                                                              optimize='greedy')

            # lamb[c] / (lamb_exp[c] - lamb_exp[a] - lamb_exp[b])
            lamb_fraction = self.lamb[:, None, None] / (lamb_exp[:, None, None] - lamb_exp[None, :, None] - lamb_exp[None, None, :]) #(cab)

            # Stimulus independent coupling of the modes (a, b) into mode c
            subsum = np.einsum('cj, ja, jb, cab-> abc',
                               self.Left,
                               self.Right,
                               self.Right,
                               lamb_fraction,
                               optimize='greedy')

            # overwrite to calculate the linear summand parts
            # (time_exp_matrix now holds exp(-(readout_time - stop_point) * lamb_exp / tau),
            # as in linear_distance_summand)
            lamb_exp = np.ones_like(self.lamb) - self.poly_coeffs[1] * self.lamb
            time_exp_matrix = np.exp( - np.einsum('n, a -> na', self.readout_time - self.stop_points, lamb_exp, optimize='greedy') * 1. / self.tau )
            exp_diff = time_exp_matrix[1:] - time_exp_matrix[:-1]

            # Per-sample linear summand: no labels, no averaging, not reduced
            # to its real part (unlike self.dist_lin_summand).
            dist_lin_summand = np.einsum('ia, ap, a, na, tn -> tip',
                                          self.Right,
                                          self.Left,
                                          1./lamb_exp,
                                          exp_diff,
                                          stimuli,
                                          optimize='greedy')

            if large_N:

                # Hand back the constituents only; the five-index tensor is
                # never formed in this mode.
                return [
                        dist_lin_summand,
                        subsum,
                        name0,
                        name1
                        ]

            else:
                max_integral = np.einsum('ic, ap, bq, abc, tbc -> tipq',
                                         self.Right,
                                         self.Left,
                                         self.Left,
                                         subsum,
                                         name1,
                                         optimize='greedy')

                # Per-sample second order summand: name0 contribution minus
                # max_integral and minus max_integral with p and q swapped.
                dist_non_summand = self.poly_coeffs[2] / (self.tau**2) * (
                                          np.einsum('ic, ap, bq, abc, tab -> tipq',
                                                    self.Right,
                                                    self.Left,
                                                    self.Left,
                                                    subsum,
                                                    name0,
                                                    optimize='greedy')
                                          - max_integral \
                                          - np.einsum('tiqp -> tipq', max_integral, optimize='greedy')
                                          )




                # Sum over samples of linear x nonlinear summand, symmetrised
                # in the two output neurons i and j.
                Sigma_non_summand = np.einsum('tip, tjqr -> ijpqr', dist_lin_summand, dist_non_summand, optimize='greedy') \
                                    + np.einsum('tjp, tiqr -> ijpqr', dist_lin_summand, dist_non_summand, optimize='greedy')

                Sigma_non_summand *= 1. / len(labels)
                # Subtract the symmetrised product of the cached label-weighted averages.
                Sigma_non_summand -= np.einsum('ip, jqr -> ijpqr', self.dist_lin_summand, self.dist_non_summand, optimize='greedy') \
                                    + np.einsum('jp, iqr -> ijpqr', self.dist_lin_summand, self.dist_non_summand, optimize='greedy')


                return Sigma_non_summand







#%%

    ################
    #Network states for fixed input vector
    ################

# Evaluate distance and covariance contributions for given input vector u


    def linear_distance(self, input_vector):
        '''
        To be used automatedly by determine_readout or manually to determine the
        linear contribution M_0 to the distance vector M.
        Use determine_sample_dynamics first in either case.
        '''

        return np.real(np.einsum('ip, p -> i', self.dist_lin_summand, input_vector, optimize='greedy') )






    def nonlinear_distance(self, input_vector):
        '''
        To be used automatedly by determine_readout or manually to determine the
        O(alpha) contribution M_1 to the distance vector M.
        Use determine_sample_dynamics first in either case.
        '''

        if self.dist_non_summand is None:

            return np.zeros(self.N)

        else:

            return np.real(np.einsum('ipq, p, q -> i', self.dist_non_summand, input_vector, input_vector, optimize='greedy') )






    def linear_covariance(self, input_vector):
        '''
        To be used automatedly by determine_readout or manually to determine the
        linear contribution Sigma_0 to the covariance Sigma.
        Use determine_sample_dynamics first in either case.
        '''

        return np.real(np.einsum('ijpq, p, q -> ij', self.Sigma_lin_summand, input_vector, input_vector, optimize='greedy') )






    def nonlinear_covariance(self, input_vector):
        '''
        To be used automatedly by determine_readout or manually to determine the
        O(alpha) contribution to the covariance Sigma. To get the correct Sigma
        for dynamics in O(alpha), add nonlinear_covariance_correction.
        Use determine_sample_dynamics first in either case.
        '''

        if self.Sigma_non_summand is None:

            return np.zeros((self.N, self.N))

        elif np.ndim(self.Sigma_non_summand) == 5:

            return np.real(np.einsum('ijpqr, p, q, r -> ij', self.Sigma_non_summand, input_vector, input_vector, input_vector, optimize='greedy') )

        elif len(self.Sigma_non_summand) == 4:

            dist_lin_summand = self.Sigma_non_summand[0]
            subsum = self.Sigma_non_summand[1]
            name0 = self.Sigma_non_summand[2]
            name1 = self.Sigma_non_summand[3]
            N_samples = np.shape(name0)[0]

            max_integral = np.einsum('ic, ap, bq, abc, tbc, p, q -> ti',
                                     self.Right,
                                     self.Left,
                                     self.Left,
                                     subsum,
                                     name1,
                                     input_vector,
                                     input_vector,
                                     optimize='greedy')

            dist_non_summand = self.poly_coeffs[2] / (self.tau**2) * (
                                      np.einsum('ic, ap, bq, abc, tab, p, q -> ti',
                                                self.Right,
                                                self.Left,
                                                self.Left,
                                                subsum,
                                                name0,
                                                input_vector,
                                                input_vector,
                                                optimize='greedy')
                                      - 2 * max_integral
                                      )

            Sigma_non = np.einsum('tip, tj, p -> ij', dist_lin_summand, dist_non_summand, input_vector, optimize='greedy') \
                                + np.einsum('tjp, ti, p -> ij', dist_lin_summand, dist_non_summand, input_vector, optimize='greedy')

            Sigma_non *= 1. / N_samples
            Sigma_non -= np.einsum('ip, jqr, p, q, r -> ij', self.dist_lin_summand, self.dist_non_summand, input_vector, input_vector, input_vector, optimize='greedy') \
                        + np.einsum('jp, iqr, p, q, r -> ij', self.dist_lin_summand, self.dist_non_summand, input_vector, input_vector, input_vector, optimize='greedy')

            return np.real(Sigma_non)

        else:

            sys.exit('Error in computation of Sigma_non_summand')





    def nonlinear_covariance_correction(self, input_vector):
        '''
        To be used automatedly by determine_readout or manually to determine the
        O(alpha**2) contribution to the covariance Sigma.
        Use determine_sample_dynamics first in either case.
        '''

        if self.Sigma_non_summand is None:

            return np.zeros((self.N, self.N))

        elif np.ndim(self.Sigma_non_summand) == 5:

            print('Computation of nonlinear_covariance_correction valid only for large_N=True.')
            return np.zeros((self.N, self.N))

        elif len(self.Sigma_non_summand) == 4:
            subsum = self.Sigma_non_summand[1]
            name0 = self.Sigma_non_summand[2]
            name1 = self.Sigma_non_summand[3]
            N_samples = np.shape(name0)[0]

            max_integral = np.einsum('ic, ap, bq, abc, tbc, p, q -> ti',
                                     self.Right,
                                     self.Left,
                                     self.Left,
                                     subsum,
                                     name1,
                                     input_vector,
                                     input_vector,
                                     optimize='greedy')

            dist_non_summand = self.poly_coeffs[2] / (self.tau**2) * (
                                      np.einsum('ic, ap, bq, abc, tab, p, q -> ti',
                                                self.Right,
                                                self.Left,
                                                self.Left,
                                                subsum,
                                                name0,
                                                input_vector,
                                                input_vector,
                                                optimize='greedy')
                                      - 2 * max_integral
                                      )
            Sigma_corr = np.einsum('ti, tj -> ij', dist_non_summand, dist_non_summand, optimize='greedy')
            Sigma_corr *= 1./ N_samples
            Sigma_corr -= np.einsum('isp, jqr, s, p, q, r -> ij', self.dist_non_summand, self.dist_non_summand, input_vector, input_vector, input_vector, input_vector, optimize='greedy')

            return np.real(Sigma_corr)

        else:

            sys.exit('Error in computation of Sigma corrections')












    def determine_responses(self, stimuli, input_vector, center=False):
        '''
        Calculate the responses y^u,nu to stimuli.
        stimuli: (samples, timesteps)
        Can be used manually. Use determine_sample_dynamics first.
        '''
        responses = np.empty((len(stimuli), self.N))

        for idx, stimulus in enumerate(stimuli):

            sample = np.array([stimulus])

            responses[idx] = np.einsum('ip, p -> i', self.linear_distance_summand(sample, np.ones(1)), input_vector)

            if self.dist_non_summand is not None:

                responses[idx] += np.einsum('ipq, p, q -> i', self.nonlinear_distance_summand(sample, np.ones(1)), input_vector, input_vector)
        if center:
            responses = responses - np.mean(responses, axis=0)[None, :]

        return responses





#%%

    ################
    #Network states for fixed readout vector
    ################

# Evaluate distance and covariance contributions for given readout vector v


    def distance_vector(self, readout_vector):
        '''
        To be used only automatedly by determine_input to determine the
        linear contribution m_0 to the distance vector m.
        Use determine_sample_dynamics first.
        '''

        return np.real(np.einsum('i, ip -> p', readout_vector, self.dist_lin_summand) )






    def distance_matrix(self, readout_vector):
        '''
        To be used only automatedly by determine_input to determine the
        O(alpha) contribution m_1 to the distance vector m.
        Use determine_sample_dynamics first.
        '''

        if self.dist_non_summand is not None:

            dist_matrix = np.einsum('i, ipq -> pq', readout_vector, self.dist_non_summand)

        else:

            dist_matrix = np.zeros((self.N, self.N))

        return np.real(dist_matrix)






    def covariance_matrix(self, eta, readout_vector):
        '''
        To be used only automatedly by determine_input to determine the
        linear contribution sigma_0 to the covariance sigma.
        Use determine_sample_dynamics first.
        '''

        return np.real(0.5 * eta * np.einsum('i, ijpq, j -> pq', readout_vector, self.Sigma_lin_summand, readout_vector))






    def covariance_tensor(self, eta, readout_vector):
        '''
        To be used only automatedly by determine_input to determine the
        O(alpha) contribution to the covariance sigma. To get the correct
        sigma for dynamics in O(alpha), also use covariance_tensor_correction.
        Use determine_sample_dynamics first.
        '''

        if self.Sigma_non_summand is None:

            return None

        elif np.ndim(self.Sigma_non_summand) == 5:

            Sigma_tensor = 0.5 * eta * np.einsum('ijpqr, i, j -> pqr', self.Sigma_non_summand, readout_vector, readout_vector, optimize='greedy')

        elif len(self.Sigma_non_summand) == 4:

            dist_lin_summand = self.Sigma_non_summand[0]
            subsum = self.Sigma_non_summand[1]
            name0 = self.Sigma_non_summand[2]
            name1 = self.Sigma_non_summand[3]
            N_samples = np.shape(name0)[0]

            max_integral = np.einsum('ic, ap, bq, abc, tbc, i -> tpq',
                                     self.Right,
                                     self.Left,
                                     self.Left,
                                     subsum,
                                     name1,
                                     readout_vector,
                                     optimize='greedy')

            dist_non_summand = self.poly_coeffs[2] / (self.tau**2) * (
                                      np.einsum('ic, ap, bq, abc, tab, i -> tpq',
                                                self.Right,
                                                self.Left,
                                                self.Left,
                                                subsum,
                                                name0,
                                                readout_vector,
                                                optimize='greedy')
                                      - max_integral \
                                      - np.einsum('tqp -> tpq', max_integral, optimize='greedy')
                                      )

            Sigma_tensor = 2 * np.einsum('tip, tqr, i -> pqr', dist_lin_summand, dist_non_summand, readout_vector, optimize='greedy')

            Sigma_tensor *= 1. / N_samples

            Sigma_tensor -= np.einsum('ip, jqr, i, j -> pqr', self.dist_lin_summand, self.dist_non_summand, readout_vector, readout_vector, optimize='greedy') \
                            + np.einsum('jp, iqr, i, j -> pqr', self.dist_lin_summand, self.dist_non_summand, readout_vector, readout_vector, optimize='greedy')

            Sigma_tensor *= 0.5 * eta

        return np.real(Sigma_tensor)




    def covariance_tensor_correction(self, eta, readout_vector):
        '''
        To be used only automatedly by determine_input to determine the
        O(alpha**2) contribution to the covariance sigma.
        Use determine_sample_dynamics first.
        '''

        if self.Sigma_non_summand is None:

            return None

        elif np.ndim(self.Sigma_non_summand) == 5:

            print('Coputation of nonlinear_covariance_tensor_correction valid only if large_N=True.')
            return None

        elif len(self.Sigma_non_summand) == 4:

            subsum = self.Sigma_non_summand[1]
            name0 = self.Sigma_non_summand[2]
            name1 = self.Sigma_non_summand[3]
            N_samples = np.shape(name0)[0]

            max_integral = np.einsum('ic, ap, bq, abc, tbc, i -> tpq',
                                     self.Right,
                                     self.Left,
                                     self.Left,
                                     subsum,
                                     name1,
                                     readout_vector,
                                     optimize='greedy')

            dist_non_summand = self.poly_coeffs[2] / (self.tau**2) * (
                                      np.einsum('ic, ap, bq, abc, tab, i -> tpq',
                                                self.Right,
                                                self.Left,
                                                self.Left,
                                                subsum,
                                                name0,
                                                readout_vector,
                                                optimize='greedy')
                                      - max_integral \
                                      - np.einsum('tqp -> tpq', max_integral, optimize='greedy')
                                      )

            Sigma_corr = np.einsum('tsp, tqr -> spqr', dist_non_summand, dist_non_summand, optimize='greedy')

            Sigma_corr *= 1. / N_samples

            Sigma_corr -= np.einsum('isp, jqr, i, j -> spqr', self.dist_non_summand, self.dist_non_summand, readout_vector, readout_vector, optimize='greedy')

            Sigma_corr *= 0.5 * eta

        return np.real(Sigma_corr)










#%%

    ################
    #Optimization routine
    ################

# Optimization of u given v and of v given u as well as two wrapping functions to iterate it


#%%
    # Determine readout vector given u and eta
    def determine_readout(self, mu, eta, input_vector, solver='eigenvalue'):
        '''
        For given input vector, find the readout vector that maximizes O_eta = v^T M - eta/2 v^T Sigma v + lagrange * ( v^T v - 1 )
        Called automatically by alternating_optimization, or manually to find the optimal readout vector for a fixed eta.
        Use determine_sample_dynamics first in either case.

        mu has to be passed just to check the (physically implausible) case of mu=0.
        eta: scalar weight of the covariance term.
        input_vector: input vector at which M and Sigma are evaluated.
        solver: 'eigenvalue' or 'ridge' (ridge regression). Only consulted when a Lagrange parameter
        was found by bisection; any other value raises ValueError at that point.

        If no positive value of the norm condition is found on the probed Lagrange parameters, a ridge
        regression with a small fixed regularization is used instead and a message is printed.

        Returns (readout_vector, soft_margin) with readout_vector normalized to unit length.
        '''

        def readout_vector_by_lagrange(lagrange, Sigma_lin, Sigma_non, dist_lin, dist_non):
            # solve (eta * Sigma - 2 * lagrange * 1) v = M via an eigendecomposition
            eig, vec = np.linalg.eigh( eta * (Sigma_lin + Sigma_non) - 2 * lagrange * np.eye(self.N) )
            eiginv = 1./eig
            readout_vector = np.einsum('a, ia, ja, j -> i', eiginv, vec, vec, dist_lin + dist_non)

            return readout_vector

        def bisect_lagrange(lagrange, Sigma_lin, Sigma_non, dist_lin, dist_non):
            # root of this function <=> readout vector has unit norm
            norm_eval = np.linalg.norm(readout_vector_by_lagrange(lagrange, Sigma_lin, Sigma_non, dist_lin, dist_non))**2 - 1.
            return norm_eval

        def readout_vector_by_ridge(alpha_ridge):
            # ridge regression of the labels on the centered responses, normalized to unit length
            clf = Ridge(alpha=alpha_ridge, fit_intercept=False)
            responses = self.determine_responses(self.stimuli, input_vector, center=True)
            clf.fit(responses, self.labels)
            coefficients = clf.coef_

            return np.copy(coefficients) / np.linalg.norm(coefficients)

        #determine dynamical properties M, Sigma at readout time
        dist_lin = self.linear_distance(input_vector)

        dist_non = self.nonlinear_distance(input_vector)

        Sigma_lin = self.linear_covariance(input_vector)

        Sigma_non = self.nonlinear_covariance(input_vector)

        Sigma_non += self.nonlinear_covariance_correction(input_vector)


        if self.poly_coeffs[2]==0 and np.linalg.norm(mu)==0.:
            # linear system without mean stimulus: direction of smallest variance
            eig, vec = np.linalg.eigh(Sigma_lin)
            readout_vector = vec.T[np.argmin(eig)]
            lagrange = 0.

        else:
            eig, vec = np.linalg.eigh(eta * (Sigma_lin + Sigma_non))
            # the matrix inverted in readout_vector_by_lagrange becomes singular at lag_max
            lag_max = 0.5 * np.min(eig)
            if bisect_lagrange(lag_max, Sigma_lin, Sigma_non, dist_lin, dist_non) > 0.:
                lag_range = np.array([lag_max - 1., lag_max])
            else:
                lag_range = np.append(-1., np.linspace(lag_max - self.eps, lag_max + self.eps, 1001))

            inis = np.array([bisect_lagrange(lag, Sigma_lin, Sigma_non, dist_lin, dist_non) for lag in lag_range])

            if np.any(inis > 0):

                # NOTE(review): bisect requires a sign change between the two chosen end points;
                # if all probed values are positive it raises ValueError -- confirm this is intended.
                lagrange = bisect(bisect_lagrange, lag_range[np.argmin(inis)], lag_range[np.argmax(inis)], args=(Sigma_lin, Sigma_non, dist_lin, dist_non), xtol=1e-16)

                if solver == 'eigenvalue':

                    readout_vector = readout_vector_by_lagrange(lagrange, Sigma_lin, Sigma_non, dist_lin, dist_non)

                elif solver == 'ridge':

                    alpha_ridge = - 2 * len(self.labels) * lagrange / eta

                    if alpha_ridge < 0:
                        print('unrealistic lambda. Use abs.')
                        alpha_ridge = np.abs(alpha_ridge)

                    readout_vector = readout_vector_by_ridge(alpha_ridge)

                else:

                    raise ValueError("solver must be 'eigenvalue' or 'ridge', got {!r}".format(solver))

            else:
                print('Determining Lagrange parameter failed in determine_readout. Use Ridge Regression with arbitrary fixed alpha.')
                #This value of alpha is of the order of magnitude encountered in tested examples.
                alpha_ridge = 2 * len(self.labels) * 1e-12 / eta
                # Lagrange parameter matching alpha_ridge = - 2 * len(labels) * lagrange / eta;
                # it is needed below to evaluate the soft margin.
                lagrange = -1e-12

                readout_vector = readout_vector_by_ridge(alpha_ridge)



        readout_vector /= np.linalg.norm(readout_vector)
        soft_margin = self.readout_function_Lagrange(np.append(readout_vector, lagrange), Sigma_lin, Sigma_non, dist_lin, dist_non, eta)

        return readout_vector, soft_margin



    def readout_function_Lagrange(self, X, Sigma_lin, Sigma_non, dist_lin, dist_non, eta):
        '''
        calculate the soft margin kappa_eta in the lagrange scheme. To be used
        automatedly by determine_readout or manually to determine the soft
        margin given the cumulants of zeta*y.
        '''

        readout_vector = X[:-1]
        lagrange = X[-1]

        Lagrangian = np.einsum('i, i -> ', dist_lin + dist_non, readout_vector) \
                    - 0.5 * eta * np.einsum('i, ij, j ->', readout_vector, Sigma_lin + Sigma_non, readout_vector) \
                    + lagrange * (np.einsum('i, i ->', readout_vector, readout_vector) - 1.)

        return Lagrangian



    def readout_gradient_lagrange(self, X, Sigma_lin, Sigma_non, dist_lin, dist_non, eta):
        '''
        calculate gradient of the soft margin with respect to the readout vector.
        To be used automatedly in determine_readout for solvers to come.
        '''

        readout_vector = X[:-1]
        lagrange = X[-1]

        #calculate the gradient with respect to v
        gradient_v = dist_lin + dist_non \
                     - eta * np.einsum('ij, j -> i', Sigma_lin + Sigma_non, readout_vector) \
                     + 2 * lagrange * readout_vector

        gradient_lagrange = np.array(np.linalg.norm(readout_vector)**2 - 1)

        gradient = np.append(gradient_v, gradient_lagrange)

        return gradient



    def readout_Hessian_Lagrange(self, X, Sigma_lin, Sigma_non, dist_lin, dist_non, eta):
        '''
        determine Hessian of soft margin with respect to the readout vector.
        To be used automatedly in determine_readout for solvers to come.
        '''

        readout_vector = X[:-1]
        lagrange = X[-1]

        Hessian = - eta * (Sigma_lin + Sigma_non) \
                  + 2 * lagrange * np.eye(self.N)

        gradient_Jacobi = np.block([
                [Hessian, 2*readout_vector.reshape(self.N, 1)],
                [2*readout_vector.reshape(1, self.N), 0.]
                ])

        return gradient_Jacobi



#%%

    # Determine input vector given v and eta, as well as a guess of the solution
    def determine_input(self, mu, readout_vector, eta, guess_input_vector, guess = False):
        '''
        For given readout vector, find the input vector that maximizes O_eta = v^T M - eta/2 v^T Sigma v + lagrange * ( v^T v - 1 )
        To be used automatedly by alternating_optimization or manually to find the optimal input vector for a fixed eta.
        Use determine_sample_dynamics first in either case.

        mu has to be passed just to check the (physically implausible) case of mu=0.
        guess_input_vector is a guess of the solution. Good guess is required when all orders are used to find the optimum (guess=False).
        guess: Calculate a guess of the solution rather than the solution. Ignores all contributions to Sigma of O(alpha) and higher.
        If the determination of the lagrange parameter fails, a gradient descent is performed.
        '''

        def input_vector_by_lagrange(lagrange, Sigma_matrix, dist_matrix, dist_vector):
            eig, vec = np.linalg.eigh(2 * (Sigma_matrix - dist_matrix - lagrange * np.eye(self.N)) )
            eiginv = 1./eig
            input_vector = np.einsum('a, pa, qa, q -> p', eiginv, vec, vec, dist_vector)

            return input_vector

        def bisect_lagrange(lagrange, Sigma_matrix, dist_matrix, dist_vector):
            return np.array([np.linalg.norm(input_vector_by_lagrange(lagrange, Sigma_matrix, dist_matrix, dist_vector))**2 - 1.])


        # calculate minimal dynamical quantities (Sigma only linear)
        dist_vector = self.distance_vector(readout_vector)

        dist_matrix = self.distance_matrix(readout_vector)

        Sigma_matrix = self.covariance_matrix(eta, readout_vector)


        if np.linalg.norm(mu) == 0:
            eig, vec = np.linalg.eigh(Sigma_matrix - dist_matrix)
            input_vector = vec.T[np.argmin(eig)]
            lagrange = 0.
            Sigma_tensor = None
            Sigma_tensor_corr = None
            solve_full_system = False

        elif guess:
            Sigma_tensor = None
            Sigma_tensor_corr = None
            solve_full_system = False

            eig, vec = np.linalg.eigh(Sigma_matrix - dist_matrix)
            lag_max = np.min(eig)
            if bisect_lagrange(lag_max, Sigma_matrix, dist_matrix, dist_vector) > 0.:
                lag_range = np.array([-1., lag_max])
            else:
                lag_range = np.append(-1., np.linspace(lag_max - self.eps, lag_max + self.eps, 101))
            inis = np.array([bisect_lagrange(lag, Sigma_matrix, dist_matrix, dist_vector) for lag in lag_range])

            if np.any(inis > 0):
                lagrange = bisect(bisect_lagrange, lag_range[np.argmin(inis)], lag_range[np.argmax(inis)], args=(Sigma_matrix, dist_matrix, dist_vector), xtol=1e-16)
                input_vector = input_vector_by_lagrange(lagrange, Sigma_matrix, dist_matrix, dist_vector)
            else:
                print('Determining Lagrange parameter failed in determine_input. Use gradient descent.')

                input_vector = np.random.rand(self.N) - 0.5
                input_vector /= np.linalg.norm(input_vector)
                lagrange = lag_max
                soft_margin = self.input_function_Lagrange(np.append(input_vector, lagrange), dist_vector, dist_matrix, Sigma_matrix, Sigma_tensor, Sigma_tensor_corr)

                stepwidth = 0.5
                tol = 1e-12
                diff = 1.
                counter = 0
                maxcount = 1e6

                while diff > tol and counter < maxcount:
                    gradient = self.input_gradient_lagrange(np.append(input_vector, lagrange), dist_vector, dist_matrix, Sigma_matrix, Sigma_tensor)
                    input_vector = stepwidth*gradient[:-1] + input_vector
                    lagrange = stepwidth*gradient[-1] + lagrange

                    prev_sm = soft_margin
                    soft_margin = self.input_function_Lagrange(np.append(input_vector/np.linalg.norm(input_vector), lagrange), dist_vector, dist_matrix, Sigma_matrix, Sigma_tensor, Sigma_tensor_corr)

                    diff = soft_margin - prev_sm
                    counter += 1
                    input_vector /= np.linalg.norm(input_vector)*(0.9+0.2*np.random.rand())

        else:
            Sigma_tensor = self.covariance_tensor(eta, readout_vector)

            Sigma_tensor_corr = self.covariance_tensor_correction(eta, readout_vector)

            solve_full_system = True

        if solve_full_system:
            X = fsolve(self.input_gradient_lagrange, np.append(guess_input_vector, -1.), args=(dist_vector, dist_matrix, Sigma_matrix, Sigma_tensor, Sigma_tensor_corr), fprime=self.input_Hessian_Lagrange)

            input_vector = X[:-1]
            lagrange = X[-1]

        input_vector /= np.linalg.norm(input_vector)

        soft_margin = self.input_function_Lagrange(np.append(input_vector, lagrange), dist_vector, dist_matrix, Sigma_matrix, Sigma_tensor, Sigma_tensor_corr)

        return soft_margin, input_vector




    def input_function_Lagrange(self, X, dist_vector, dist_matrix, Sigma_matrix, Sigma_tensor, Sigma_tensor_corr=None):
        '''
        calculate the soft margin kappa_eta in the lagrange scheme. To be used
        automatedly by determine_input.
        '''

        input_vector = X[:-1]
        lagrange = X[-1]

        if Sigma_tensor is None:
            Lagrangian = np.einsum('p, p -> ', dist_vector, input_vector) \
                    + np.einsum('p, pq, q ->', input_vector, dist_matrix - Sigma_matrix, input_vector) \
                    + lagrange * (np.einsum('p, p ->', input_vector, input_vector) - 1.)
        elif Sigma_tensor_corr is None:
            Lagrangian = np.einsum('p, p -> ', dist_vector, input_vector) \
                    + np.einsum('p, pq, q ->', input_vector, dist_matrix - Sigma_matrix, input_vector) \
                    + np.einsum('pqr, p, q, r ->', - Sigma_tensor, input_vector, input_vector, input_vector) \
                    + lagrange * (np.einsum('p, p ->', input_vector, input_vector) - 1.)
        else:
            Lagrangian = np.einsum('p, p -> ', dist_vector, input_vector) \
                    + np.einsum('p, pq, q ->', input_vector, dist_matrix - Sigma_matrix, input_vector) \
                    + np.einsum('pqr, p, q, r ->', - Sigma_tensor, input_vector, input_vector, input_vector) \
                    + np.einsum('spqr, s, p, q, r ->', - Sigma_tensor_corr, input_vector, input_vector, input_vector, input_vector) \
                    + lagrange * (np.einsum('p, p ->', input_vector, input_vector) - 1.)

        return Lagrangian


    def input_gradient_lagrange(self, X, dist_vector, dist_matrix, Sigma_matrix, Sigma_tensor, Sigma_tensor_corr=None):
        '''
        calculate gradient of the soft margin with respect to the input vector.
        To be used automatedly by determine_input.
        '''

        input_vector = X[:-1]
        lagrange = X[-1]

        if Sigma_tensor is None:
            Sigma_tensor_contribution = np.zeros(self.N)
            Sigma_tensor_correction_contribution = np.zeros(self.N)

        else:
            Sigma_tensor_contribution =  np.einsum('pqr, q, r -> p', Sigma_tensor, input_vector, input_vector, optimize='greedy') \
                                         + np.einsum('rpq, q, r -> p', Sigma_tensor, input_vector, input_vector, optimize='greedy') \
                                         + np.einsum('qrp, q, r -> p', Sigma_tensor, input_vector, input_vector, optimize='greedy')
            if Sigma_tensor_corr is None:
                Sigma_tensor_correction_contribution = np.zeros(self.N)
            else:
                Sigma_tensor_correction_contribution =   np.einsum('spqr, p, q, r -> s', Sigma_tensor_corr, input_vector, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('rspq, p, q, r -> s', Sigma_tensor_corr, input_vector, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('qrsp, p, q, r -> s', Sigma_tensor_corr, input_vector, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('pqrs, p, q, r -> s', Sigma_tensor_corr, input_vector, input_vector, input_vector, optimize='greedy')

        gradient_u = dist_vector + 2 * np.einsum('pq, q -> p', dist_matrix, input_vector) \
                     - 2 * np.einsum('pq, q -> p', Sigma_matrix, input_vector) \
                     - Sigma_tensor_contribution \
                     - Sigma_tensor_correction_contribution \
                     + 2 * lagrange * input_vector

        gradient_lagrange = np.array(np.linalg.norm(input_vector)**2 - 1)

        gradient = np.append(gradient_u, gradient_lagrange)

        return gradient



    def input_Hessian_Lagrange(self, X, dist_vector, dist_matrix, Sigma_matrix, Sigma_tensor, Sigma_tensor_corr=None):
        '''
        determine Hessian of soft margin with respect to the input vector.
        To be used automatedly in determine_input.
        '''

        input_vector = X[:-1]
        lagrange = X[-1]

        if Sigma_tensor is None:
            Sigma_tensor_contribution = np.zeros((self.N, self.N))
        else:
            Sigma_tensor_contribution =  np.einsum('pqr, r -> pq', Sigma_tensor, input_vector, optimize='greedy') \
                                         + np.einsum('prq, r -> pq', Sigma_tensor, input_vector, optimize='greedy') \
                                         + np.einsum('qpr, r -> pq', Sigma_tensor, input_vector, optimize='greedy') \
                                         + np.einsum('qrp, r -> pq', Sigma_tensor, input_vector, optimize='greedy') \
                                         + np.einsum('rpq, r -> pq', Sigma_tensor, input_vector, optimize='greedy') \
                                         + np.einsum('rqp, r -> pq', Sigma_tensor, input_vector, optimize='greedy')
            #0.5 from 0.5 in soft_margin
            if Sigma_tensor_corr is None:
                Sigma_tensor_correction_contribution = np.zeros((self.N, self.N))
            else:
                Sigma_tensor_correction_contribution =   np.einsum('pqrs, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('qprs, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('psqr, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('qspr, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('prsq, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('qrsp, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('rpqs, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('rqps, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('rpsq, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('rqsp, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('rspq, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy') \
                                                         + np.einsum('rsqp, r, s -> pq', Sigma_tensor_corr, input_vector, input_vector, optimize='greedy')


        Hessian = 2 * dist_matrix - 2 * Sigma_matrix \
                  - Sigma_tensor_contribution \
                  - Sigma_tensor_correction_contribution \
                  + 2 * lagrange * np.eye(self.N)

        gradient_Jacobi = np.block([
                [Hessian, 2*input_vector.reshape(self.N, 1)],
                [2*input_vector.reshape(1, self.N), 0.]
                ])

        return gradient_Jacobi





#%%

    # Alternating optimization of input and readout vector.
    def alternating_optimization(self, opt_steps, mu, eta=10., initial_input_vector=None, solver='eigenvalue', initial_guesses=1):

        '''
        Alternately optimize the soft margin with respect to the input vector
        (determine_input) and the readout vector (determine_readout) for given
        stimulus statistics.
        initial_guesses sets the number of early steps in which only a guess of
        the input projection is computed, neglecting contributions to Sigma of
        O(alpha) and higher. In linear systems, this function is often
        sufficient. For non-linear systems or if unhappy with results, use
        find_good_optimization.
        To be used manually.

        opt_steps: number of recorded steps.
        initial_input_vector: starting input vector; drawn at random and
        normalized if None.
        Returns (soft_margins, input_vectors, readout_vectors) with one entry
        per step. If the convergence criterion is met early, the remaining
        entries are filled with the last result.
        '''

        if initial_input_vector is not None:
            input_vector = initial_input_vector
        else:
            input_vector = np.random.rand(self.N)-0.5
            input_vector /= np.linalg.norm(input_vector)

        readout_vector, soft_margin = self.determine_readout(mu, eta, input_vector, solver=solver)

        soft_margins = np.empty(opt_steps)
        input_vectors = np.empty((opt_steps, self.N))
        readout_vectors = np.empty((opt_steps, self.N))

        last_step = opt_steps - 1

        for step in range(opt_steps):
            # record the state reached before this step's update
            soft_margins[step] = soft_margin
            input_vectors[step] = input_vector
            readout_vectors[step] = readout_vector

            if step < initial_guesses or np.linalg.norm(mu) == 0:
                soft_margin, input_vector = self.determine_input(mu, readout_vector, eta, input_vector, guess=True)

            elif step < last_step:
                soft_margin, input_vector = self.determine_input(mu, readout_vector, eta, input_vector, guess=False)

                if soft_margin < soft_margins[step]:
                    # the full solve made things worse: restart it from a fresh guess
                    soft_margin, guessed_input_vector = self.determine_input(mu, readout_vector, eta, input_vectors[step], guess=True)
                    soft_margin, input_vector = self.determine_input(mu, readout_vector, eta, guessed_input_vector, guess=False)

            readout_vector, soft_margin = self.determine_readout(mu, eta, input_vector, solver=solver)

            if step > 10:
                # relative change of the soft margin over the last ten recorded steps
                window = soft_margins[step-10:step+1]
                relative_change = np.diff(window) / window[1:]

                if np.all(relative_change < 2e-3):
                    # fill this and all remaining entries with the final result
                    soft_margins[step:] = soft_margin
                    input_vectors[step:] = input_vector
                    readout_vectors[step:] = readout_vector
                    break

        return soft_margins, input_vectors, readout_vectors








    # Alternating optimization with a number of initial conditions
    def find_good_optimization(self, opt_steps, mu, eta=15., initial_cond=10, initial_steps=1, solver='eigenvalue'):
        '''
        Draw initial_cond random initial input vectors and run
        alternating_optimization for initial_steps steps on each of them. The
        run with the largest soft margin in its last recorded step is then
        continued so that opt_steps steps are returned in total.
        Note: initial_steps=1 only calculates the best readout vector and
        corresponding soft_margin. No optimization step is performed.
        To be used manually.

        Returns (soft_margins, input_vectors, readout_vectors): the initial
        steps of the selected run followed by the steps of the continuation.
        '''

        nonlinear = self.poly_coeffs[2] != 0

        # without a quadratic coefficient every step is computed as a guess
        initial_guesses = 1 if nonlinear else opt_steps

        initial_input_vectors = np.random.rand(initial_cond, self.N) - 0.5
        initial_input_vectors /= np.linalg.norm(initial_input_vectors, axis=1)[:, np.newaxis]

        soft_margins_initial = np.empty((initial_cond, initial_steps))
        input_vectors_initial = np.empty((initial_cond, initial_steps, self.N))
        readout_vectors_initial = np.empty((initial_cond, initial_steps, self.N))

        for ic_idx, start_vector in enumerate(initial_input_vectors):
            margins, inputs, readouts = self.alternating_optimization(initial_steps, mu, eta=eta, initial_input_vector=start_vector, initial_guesses=initial_guesses, solver=solver)
            soft_margins_initial[ic_idx] = margins
            input_vectors_initial[ic_idx] = inputs
            readout_vectors_initial[ic_idx] = readouts

        # initial condition with the largest soft margin in the last recorded step
        best = np.argmax(soft_margins_initial[:, -1])

        if nonlinear:
            initial_guesses = 0

        remaining_steps = opt_steps-initial_steps+1
        soft_margins, input_vectors, readout_vectors \
                        = self.alternating_optimization(remaining_steps, mu, eta=eta, initial_input_vector=input_vectors_initial[best, -1], initial_guesses=initial_guesses, solver=solver)

        # the first step of the continuation is dropped when joining the two runs
        soft_margins = np.append(soft_margins_initial[best], soft_margins[1:])
        input_vectors = np.vstack([input_vectors_initial[best], input_vectors[1:]])
        readout_vectors = np.vstack([readout_vectors_initial[best], readout_vectors[1:]])

        return soft_margins, input_vectors, readout_vectors


