function[loss_vec_f,loss_vec_g,rec_vec,time_vec, sample_vec] = Alg_projection_sto(A_train,X_train,A_test,Dstar,...
    D0,X_last,param)
% The update rule is given by: 
% 
% x_{k+1} = \Pi_{\mathcal{Z}}(x_k - \gamma_k(\nabla g(x_k)+\eta_k \nabla f(x_k)))
% 
% We choose \gamma_k = \gamma_0/sqrt{k+1} and \eta_k = \eta_0/(k+1)^0.25

% get dimension
[p,~] = size(X_train);
[~,n_test] = size(A_test);
[~,n_train] = size(A_train);

% Define loss and gradients
loss_f = @(D,X) norm(A_test-D*X,'fro')^2/2;
loss_g = @(D) norm(A_train-D*X_train,'fro')^2/2;

% grad_f_D = @(D,X) (D*X-A_test)*X';
% grad_f_X = @(D,X) D'*(D*X-A_test);
% grad_g_D = @(D) (D*X_train-A_train)*X_train'; 

eta_0 = 1;
gamma_0 = 1e-4;

delta=param.delta;
thres=param.thres;

% Initialization
D = D0;
X = X_last;
D_bar = D0;
X_bar = X_last;
%% algorithm
maxiter = param.maxiter;
% maxtime = param.maxtime;

loss_vec_f = zeros(maxiter+1,1);
loss_vec_g = zeros(maxiter+1,1);
time_vec = zeros(maxiter+1,1);
loss_vec_f(1) = loss_f(D,X);
loss_vec_g(1) = loss_g(D);
time_vec(1) = 0;

sample_vec = zeros(maxiter+1,1);
sample_vec(1) = 0;
rec_vec = zeros(maxiter+1,1);
rec_vec(1) = recovery(D,Dstar,thres);

k=0;
S = 0;
tic;
while k< maxiter
    k = k+1;
    eta_k = (eta_0)*(k+1)^0.25;
    gamma_k = gamma_0/(k+1)^0.75;
    % Sampling
    uppidx1 = randsample(n_test,8);
    uppidx2 = randsample(n_test,8);
    lowidx1 = randsample(n_train,8);
    lowidx2 = randsample(n_train,8);

    %loss_test_i = @(D,X) norm(A_test(:,uppidx)-D*X(:,uppidx),'fro')^2/2;
    gD_test_1= @(D,X) n_test/8*(D*X(:,uppidx1)-A_test(:,uppidx1))*X(:,uppidx1)';
    gX_test_1 = @(D,X) n_test/8*D'*(D*X(:,uppidx1)-A_test(:,uppidx1));
    gD_test_2= @(D,X) n_test/8*(D*X(:,uppidx2)-A_test(:,uppidx2))*X(:,uppidx2)';
    gX_test_2 = @(D,X) n_test/8*D'*(D*X(:,uppidx2)-A_test(:,uppidx2));
    
    %loss_lower_i = @(D) n_train*norm(A_train(:,lowidx)-D*X_last(:,lowidx),'fro')^2/2;
    gD_lower_1 = @(D) n_train/8*(D*X_last(:,lowidx1)-A_train(:,lowidx1))*X_last(:,lowidx1)';
    %loss_lower_i = @(D) n_train*norm(A_train(:,lowidx)-D*X_last(:,lowidx),'fro')^2/2;
    gD_lower_2 = @(D) n_train/8*(D*X_last(:,lowidx2)-A_train(:,lowidx2))*X_last(:,lowidx2)';

    gX_1 = zeros(50,n_test);
    gX_1(:,uppidx1) = gX_test_1(D,X);

%     gD = gD_test_i(D,X);
%     gX = gX_test(D,X);
%     gD_lo = gD_lower_i(D);
%     loss_lo = loss_lower_i(D);
    % Descent step
    D_tmp = D;
    X_tmp = X;

    D = D - gamma_k*(eta_k*gD_lower_1(D)+(gD_test_1(D,X)));
    X = X - gamma_k*(gX_1);
    
    % Projection step
    for col_i=1:p
        if norm(D(:,col_i))>1
            D(:,col_i) = D(:,col_i)./norm(D(:,col_i));
        end
    end
    for col_n = 1:n_test
        if norm(X(:,col_n),1)>delta
            x = X(:,col_n);
            X(:,col_n) = ProjectOntoL1Ball(x,delta);
        end
    end
    
    % extra gradient descent
    gX_2 = zeros(50,n_test);
    gX_2(:,uppidx2) = gX_test_2(D,X);
    D = D_tmp - gamma_k*(eta_k*gD_lower_2(D)+(gD_test_2(D,X)));
    X = X_tmp - gamma_k*(gX_2);
    
    % Projection step
    for col_i=1:p
        if norm(D(:,col_i))>1
            D(:,col_i) = D(:,col_i)./norm(D(:,col_i));
        end
    end
    for col_n = 1:n_test
        if norm(X(:,col_n),1)>delta
            x = X(:,col_n);
            X(:,col_n) = ProjectOntoL1Ball(x,delta);
        end
    end
    
    % averaging iters
    S_next = S+(gamma_k*eta_k)^0.5;
    X_bar = (S*X_bar+(gamma_k*eta_k)^0.5*X)/S_next;
    D_bar = (S*D_bar+(gamma_k*eta_k)^0.5*D)/S_next;
    S = S_next;

    time_vec(k+1) = toc;
    loss_vec_f(k+1) = loss_f(D_bar,X_bar); 
    loss_vec_g(k+1) = loss_g(D_bar);
    rec_vec(k+1) = recovery(D_bar,Dstar,thres);
    sample_vec(k+1) = k*32;
end

time_vec = time_vec(1:k+1);
loss_vec_f = loss_vec_f(1:k+1);
loss_vec_g = loss_vec_g(1:k+1);
rec_vec = rec_vec(1:k+1);
end

function rec = recovery(D,Dstar,thres)
D = D./vecnorm(D);
[~,num_dict] = size(Dstar);
corr_mat = D'*Dstar;
num = sum(max(abs(corr_mat))>thres);
rec = num/num_dict;
end
