% Inner test script for synthetic phase retrieval: momentum & minibatch experiments (3.1.3)
% Unpack the problem instance handed over by the calling driver script.
A        = data.A;
b        = data.b;
bestloss = data.bestloss;

% Target loss passed to every solver: 1.5x the best known loss.
% NOTE(review): presumably a solver reports status "Optimal" once it gets
% below this value -- confirm inside proxsgd / proxlin / proxlinbatch.
tol = 1.5 * bestloss;

[m, n] = size(A);

% Result tables: one row per trial (nTest), one column per stepsize.
% Name suffixes: "m" = momentum, "b" = minibatch, "bm" = both.
% Every cell starts at the budget value maxiter * m, which is what a run
% that never reaches "Optimal" keeps.
nSgdEpochtoOpt       = repmat(maxiter * m, nTest, length(steprange));
nSgdmEpochtoOpt      = nSgdEpochtoOpt;
nSgdbEpochtoOpt      = nSgdEpochtoOpt;
nSgdbmEpochtoOpt     = nSgdEpochtoOpt;
nProxLinEpochtoOpt   = nSgdEpochtoOpt;
nProxLinmEpochtoOpt  = nSgdEpochtoOpt;
nProxLinbEpochtoOpt  = nSgdEpochtoOpt;
nProxLinbmEpochtoOpt = nSgdEpochtoOpt;

% Sweep the stepsizes. For each stepsize, run nTest independent trials in
% parallel. Every trial draws one random start init_x and runs eight solver
% variants from it:
%   {proxsgd, proxlin / proxlinbatch}
%   x {4th argument 0 (no momentum), beta (momentum)}
%   x {batch size 1, batchsize (minibatch)}.
% When a run reports status "Optimal", its niter overwrites the maxiter * m
% placeholder in cell (i, k) of the matching result table. Otherwise the
% placeholder is left in place.
% Only the info structs are consumed, so the returned solutions are
% discarded with ~.
for k = 1:length(steprange)

    stepsize = steprange(k);

    parfor i = 1:nTest

        % One starting point shared by all eight variants of this trial.
        init_x = randn(n, 1);

        % Nothing (no momentum, batch size 1)
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        % Test SGD
        [~, sgdinfo] = proxsgd(A, b, sqrt(maxiter * m), 0, init_x, ...
            maxiter, tol, true, 1, 0, stepsize, 0, show_info);

        if sgdinfo.status == "Optimal"
            nSgdEpochtoOpt(i, k) = sgdinfo.niter;
        end % End if

        % Test Proximal linear
        [~, proxlininfo] = proxlin(A, b, sqrt(maxiter * m), 0, init_x, ...
            maxiter, tol, true, 0, stepsize, 0, show_info);

        if proxlininfo.status == "Optimal"
            nProxLinEpochtoOpt(i, k) = proxlininfo.niter;
        end % End if
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

        % Momentum (beta, batch size 1)
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        % Test SGD
        [~, sgdinfom] = proxsgd(A, b, sqrt(maxiter * m), beta, init_x, ...
            maxiter, tol, true, 1, 0, stepsize, 0, show_info);

        if sgdinfom.status == "Optimal"
            nSgdmEpochtoOpt(i, k) = sgdinfom.niter;
        end % End if

        % Test Proximal linear
        [~, proxlininfom] = proxlin(A, b, sqrt(maxiter * m), beta, init_x, ...
            maxiter, tol, true, 0, stepsize, 0, show_info);

        if proxlininfom.status == "Optimal"
            nProxLinmEpochtoOpt(i, k) = proxlininfom.niter;
        end % End if
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

        % Minibatch (no momentum, batchsize). The 3rd argument is scaled
        % by 1 / batchsize inside the sqrt for these runs.
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        % Test SGD
        [~, sgdinfob] = proxsgd(A, b, sqrt(maxiter * m / batchsize), 0, init_x, ...
            maxiter, tol, true, batchsize, 0, stepsize, 0, show_info);

        if sgdinfob.status == "Optimal"
            nSgdbEpochtoOpt(i, k) = sgdinfob.niter;
        end % End if

        % Test Proximal linear (batched variant of the solver)
        [~, proxlininfob] = proxlinbatch(A, b, sqrt(maxiter * m / batchsize), 0, init_x, ...
            maxiter, tol, true, batchsize, 0, stepsize, show_info);

        if proxlininfob.status == "Optimal"
            nProxLinbEpochtoOpt(i, k) = proxlininfob.niter;
        end % End if
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

        % Momentum and Minibatch (beta, batchsize)
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        % Test SGD
        [~, sgdinfobm] = proxsgd(A, b, sqrt(maxiter * m / batchsize), beta, init_x, ...
            maxiter, tol, true, batchsize, 0, stepsize, 0, show_info);

        if sgdinfobm.status == "Optimal"
            nSgdbmEpochtoOpt(i, k) = sgdinfobm.niter;
        end % End if

        % Test Proximal linear (batched variant of the solver)
        [~, proxlininfobm] = proxlinbatch(A, b, sqrt(maxiter * m / batchsize), beta, init_x, ...
            maxiter, tol, true, batchsize, 0, stepsize, show_info);

        if proxlininfobm.status == "Optimal"
            nProxLinbmEpochtoOpt(i, k) = proxlininfobm.niter;
        end % End if
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    end % End parfor

    fprintf("- Stepsize " + k + " done. \n");

end % End for

% Plot the summary graph for the batch-size-1 runs (plain and momentum).
% For each stepsize the y-value is the mean over the nTest trials of
% niter / m. Runs that never reached "Optimal" contribute their
% placeholder maxiter * m, i.e. maxiter after normalization.
% The divisor used to be a hard-coded 300. m (rows of A) is used instead
% so the normalization matches the maxiter * m placeholder for any
% problem size.
num_iter = maxiter * m;  % not used for plotting; kept because the workspace is saved later
xcord = steprange;

semilogx(xcord, sum(nSgdEpochtoOpt / (nTest * m), 1), "-+", "MarkerSize", 18, "LineWidth", 2);
hold on;  % overlay the remaining curves on the same axes
semilogx(xcord, sum(nProxLinEpochtoOpt / (nTest * m), 1), "-o", "MarkerSize", 18, "LineWidth", 2);
semilogx(xcord, sum(nSgdmEpochtoOpt / (nTest * m), 1), "--s", "MarkerSize", 18, "LineWidth", 2);
semilogx(xcord, sum(nProxLinmEpochtoOpt / (nTest * m), 1), "--*", "MarkerSize", 18, "LineWidth", 2);

set(gca, "FontSize", 20, "FontWeight", "bold")
xlim([min(steprange), max(steprange)]);

legend(["SGD", "SPL", "SEGD", "SEPL"], "FontSize", 20);

savefig("kappa_" + kappa + "_pfail_" + pfail + "_batch_" + 1 + "_momentum_" + beta + "_epoch_env.fig");

hold off;
close all;

% Summary graph for the minibatch runs (plain and momentum), normalized the
% same way: mean over trials of niter / m. m replaces the former
% hard-coded 300 so the divisor matches the maxiter * m placeholder.
semilogx(xcord, sum(nSgdbEpochtoOpt / (nTest * m), 1), "-+", "MarkerSize", 18, "LineWidth", 2);
hold on;  % overlay the remaining curves on the same axes
semilogx(xcord, sum(nProxLinbEpochtoOpt / (nTest * m), 1), "-o", "MarkerSize", 18, "LineWidth", 2);
semilogx(xcord, sum(nSgdbmEpochtoOpt / (nTest * m), 1), "--s", "MarkerSize", 18, "LineWidth", 2);
semilogx(xcord, sum(nProxLinbmEpochtoOpt / (nTest * m), 1), "--*", "MarkerSize", 18, "LineWidth", 2);

set(gca, "FontSize", 20, "FontWeight", "bold")
xlim([min(steprange), max(steprange)]);

legend(["SGD", "SPL", "SEGD", "SEPL"], "FontSize", 20);

savefig("kappa_" + kappa + "_pfail_" + pfail + "_batch_" + batchsize + "_momentum_" + beta + "_epoch.fig");

hold off;
close all;

% Save the entire workspace (result tables, parameters, data).
% NOTE(review): this filename lacks the "_" before "momentum_" and uses the
% "_epoch_env" suffix, unlike the .fig names above. It is left unchanged
% because other scripts may load it by this exact name -- confirm before
% renaming.
save("kappa_" + kappa + "_pfail_" + pfail + "_batch_" + batchsize + "momentum_" + beta + "_epoch_env.mat");

