% gradient descent ucl model
function [] = compute_theory_permuted_mnist(i0,j0,seed)
%COMPUTE_THEORY_PERMUTED_MNIST Theory predictions on a permuted-MNIST parity task.
%   compute_theory_permuted_mnist(i0, j0, seed) builds n = 10 pixel-permuted
%   copies of a fixed even/odd MNIST subset. It gates every input with M
%   random binary features (v*x/sqrt(N_0) > th) and compares two kernel
%   predictors on the permuted test copies:
%     * a gated kernel (g'*H*g).*K_1, where H comes from
%       numeric_saddlepoint_nonlinear, and
%     * a ReLU (degree-1 arc-cosine) kernel rescaled by a scalar h found
%       with fsolve.
%   Results are written to
%   ../compute_global_coeff/pretrained2_coeff_N<j0>_th<i0>_seed<seed>.mat.
%
%   Inputs:
%     i0   - index into thall = -2:0.2:2; selects the gating threshold th.
%     j0   - index into Nall = [50,100:100:1000]; selects N (presumably the
%            hidden-layer width - confirm against numeric_saddlepoint_nonlinear).
%     seed - RNG seed for the gating weights v and the fsolve start point.
%            The pixel permutations always use seed 2, so they are shared
%            across seeds.
%
%   Saved variables:
%     ge, gb          - gated-kernel squared error (mean + variance term) and
%                       expected sign error under N(fp, fvarp).
%     fp, fvarp       - gated-kernel predictive mean / variance on the test
%                       copies (fvarp is set to 1 wherever fp == 0).
%     gerelu, fprelu, fvarprelu - the same quantities for the ReLU kernel.
%     coeff_c, coeffrelu_c      - mean |regression weight| per
%                       (test half-task, train half-task) block.
%     k_coeff, k_relu - within-task / cross-task ratio of those weights.
%     y_t             - +/-1 test labels tiled over the n copies.
%
%   Requires mnist.mat with trainX/testX (one image per row, 784 columns,
%   as implied by the product v*x_0 below) and trainY/testY. Uses normrnd
%   (Statistics Toolbox) and fsolve (Optimization Toolbox).

% ---- Hyper-parameters ---------------------------------------------------
Nall = [50,100:100:1000];
N = Nall(j0);
M = 50;          % number of random binary gates (rows of v)
n = 10;          % number of pixel-permuted copies
N_0 = 784;       % input dimension
P = 300;         % training images per copy (first half even, second half odd)
N_s = 500;       % test images per copy (first half even, second half odd)
sigma = 0.2;
beta = 1e4;
thall = -2:0.2:2;
th = thall(i0);  % gating threshold

% ---- Data ---------------------------------------------------------------
% Load only the arrays used here, so mnist.mat cannot inject other names
% into this workspace.
data = load('mnist.mat', 'trainX', 'trainY', 'testX', 'testY');
% Force row-vector labels: the index/label concatenations below would
% otherwise produce matrices for column-stored labels.
trainY = reshape(data.trainY, 1, []);
testY = reshape(data.testY, 1, []);

indtrain = find(mod(trainY,2)==0);
indtrain1 = find(mod(trainY,2)==1);
indtest = find(mod(testY,2)==0);
indtest1 = find(mod(testY,2)==1);
index = [indtrain(1:P/2),indtrain1(1:P/2)];
index_t = [indtest(1:N_s/2),indtest1(1:N_s/2)];

% One image per column, each standardised to zero mean / unit std. The
% standardisation acts per image, so selecting first gives the same values
% as standardising the whole dataset.
x_00 = double(data.trainX(index,:))';
x_00 = (x_00-mean(x_00))./std(x_00);
x_t0 = double(data.testX(index_t,:))';
x_t0 = (x_t0-mean(x_t0))./std(x_t0);

y = (mod(trainY(index),2)==0);
y_t = (mod(testY(index_t),2)==0);

% ---- Pixel-permuted copies (fixed seed: same permutations for every run) -
rng(2,'twister');
ind = zeros(n, N_0);
x_0 = zeros(N_0, P*n);
x_t = zeros(N_0, N_s*n);
for k = 1:n
    ind(k,:) = randperm(N_0);
    x_0(:,((k-1)*P+1):(k*P)) = x_00(ind(k,:),:);
    x_t(:,((k-1)*N_s+1):(k*N_s)) = x_t0(ind(k,:),:);
end

% ---- Random binary gates --------------------------------------------------
rng(seed,'twister');
v = normrnd(0,1,M,N_0);
g = (v*x_0/sqrt(N_0)>th);    % M x (P*n) logical
gp = (v*x_t/sqrt(N_0)>th);   % M x (N_s*n) logical
% Alternative (disabled): soft k-means gating instead of random thresholds.
% NOTE(review): if re-enabled, the loop below reuses i0 as its index and
% would overwrite the threshold-index argument.
% iteration = 1000;
% beta0 = 0.1;
% [g,xc] = softkmeans(M, x_0g, beta0, iteration, seed);
% g = g(1:M,:);
%
% for i0 = 1:M
%    d(i0,:) = sum((x_tg - xc(:, i0)).^2)';
%    expd(i0,:) = exp(-beta0*d(i0,:));
% end
% gp = expd./sum(expd);
% gp = gp(1:M,:);

% +/-1 labels, tiled over the n copies (labels are permutation-invariant).
y = 2*kron(ones(1,n),double(y))-1;
y_t = 2*kron(ones(1,n),double(y_t))-1;

% ---- Gated-kernel theory ------------------------------------------------
% H must be M x M for g'*H*g below to be defined.
[~,H] = numeric_saddlepoint_nonlinear(N_0,P*n,M,g,y,x_0,N,sigma,beta);

K_1 = sigma^2/N_0*(x_0'*x_0);
K_1p = sigma^2/N_0*(x_t'*x_0);
k_1p = sigma^2/N_0*(x_t'*x_t);

K2 = (g'*H*g).*K_1 + 1/beta*eye(P*n,P*n);
K2p = (gp'*H*g).*K_1p;
k2p = (gp'*H*gp).*k_1p;

% Regression weights of each test point on the training labels; computed
% once (pinv of a (P*n)^2 matrix is the dominant cost) and reused.
coeff = K2p*pinv(K2);
fp = (coeff*y')';
fvarp = diag(k2p - coeff*K2p');

ge = mean((fp-y_t).^2) + mean(fvarp);
fvarp(fp==0) = 1;   % avoid 0/0 in the erfc argument when mean and variance vanish
% Expected sign error under N(fp, fvarp). Clamp at 0 because rounding can
% make a computed variance slightly negative, which would make sqrt complex
% and erfc error out.
gb = mean((y_t+1)/2+(-y_t).*erfc(-fp./sqrt(max(fvarp',0)*2))/2);

% ---- ReLU (arc-cosine) kernel theory ------------------------------------
k0 = relu_kernel(x_0, x_0, sigma, N_0);   % train x train
k = relu_kernel(x_t, x_t, sigma, N_0);    % test x test
k_p = relu_kernel(x_t, x_0, sigma, N_0);  % test x train

% Finite-temperature scalar h: root of saddle_residual, started from a
% random point (drawn from the seed-controlled stream) and taken positive.
h0 = normrnd(0,1,1,1);
fun = @(x) saddle_residual(x, k0, y, N, sigma, beta);
h = abs(fsolve(fun,h0));

coeffrelu = k_p/h*pinv(k0/h + 1/beta*eye(P*n,P*n));
fprelu = coeffrelu*y';
fvarprelu = diag(k/h - coeffrelu*k_p'/h);
gerelu = mean((fprelu'-y_t).^2)+mean(fvarprelu);

% ---- Block-averaged coefficients -----------------------------------------
% Rows/columns come in half-tasks (even half, odd half) of N_s/2 test and
% P/2 training points; coeff_c(i,j) is the mean |weight| of block (i,j).
coeff_c = zeros(2*n);
coeffrelu_c = zeros(2*n);
for i = 1:(2*n)
    rows = ((i-1)*(N_s/2)+1):(i*N_s/2);
    for j = 1:(2*n)
        cols = ((j-1)*(P/2)+1):(j*P/2);
        coeff_c(i,j) = mean(mean(abs(coeff(rows,cols))));
        coeffrelu_c(i,j) = mean(mean(abs(coeffrelu(rows,cols))));
    end
end

% Merge the two halves of each copy: lambda(i,j) is the mean over the 2x2
% half-task blocks of test copy i versus training copy j.
lambda = zeros(n);
lambda_relu = zeros(n);
for i = 1:n
    ri = (2*i-1):(2*i);
    for j = 1:n
        cj = (2*j-1):(2*j);
        lambda(i,j) = mean(mean(coeff_c(ri,cj)));
        lambda_relu(i,j) = mean(mean(coeffrelu_c(ri,cj)));
    end
end
% Ratio of within-copy (diagonal) to cross-copy (off-diagonal) weight.
k_coeff = sum(diag(lambda))/(sum(lambda(:))-sum(diag(lambda)));
k_relu = sum(diag(lambda_relu))/(sum(lambda_relu(:))-sum(diag(lambda_relu)));

save(['../compute_global_coeff/pretrained2_coeff_N',num2str(j0),'_th',num2str(i0),'_seed',num2str(seed),'.mat'],'coeff_c','coeffrelu_c','ge','gb','fp','fvarp','y_t','gerelu','fprelu','fvarprelu','k_coeff','k_relu')

end

function K = relu_kernel(a, b, sigma, N_0)
%RELU_KERNEL Scaled degree-1 arc-cosine kernel between the columns of a and b.
%   K(i,j) = sigma^2/N_0 * |a_i|*|b_j| * (sin(t) + (pi-t)*cos(t)) / (2*pi),
%   where t is the angle between columns a_i and b_j.
na = sqrt(sum(a.^2));
nb = sqrt(sum(b.^2));
t = a'*b./(na'*nb);
theta = real(acos(t));   % real(): rounding can push |t| slightly above 1
K = 1/2/pi*sigma^2/N_0*na'*nb.*(sin(theta)+(pi-theta).*cos(theta));
end

function r = saddle_residual(x, k0, y, N, sigma, beta)
%SADDLE_RESIDUAL Residual of the scalar equation solved for h.
%   r = x - sigma^-2 + (1/N)*y*A^+*A^+*k0*y' - (1/N)*trace(A^+*k0),
%   with A = k0/x + I/beta and A^+ its pseudo-inverse (computed once).
Ainv = pinv(k0/x+1/beta*eye(size(k0)));
r = x-sigma^(-2)+1/N*y*Ainv*Ainv*k0*y'-1/N*trace(Ainv*k0);
end