% Comparison plots for result files tagged "SGD" (legend label DSEGD) and
% "SPL" (legend label DSEPL). Start from a clean workspace, command window
% and figure set.
clear; clc; close all;

% Result-file name template. Placeholders in order: number of jobs (%d),
% alpha as text (%s), a numeric beta field printed with one decimal
% (%1.1f), and the algorithm tag (%s).
formatting = "res_job_%d_epoch_400_alpha_%s_beta_%1.1f_m_1500_n_500_%s.mat";

% Sweep grids.
jobrange   = [1, 2, 4, 8, 16];
alpharange = "0.5";
betarange  = [0.0, 0.1, 0.3, 0.6];
algrange   = ["SGD", "SPL"];

njobs = numel(jobrange);
nbeta = numel(betarange);

% Section switches: which parts of the script run.
computefval = false;
computenorm = false;
computetime = true;

% Number of leading entries read from each per-epoch series in a result file.
nepoch = 40;

% Shared plotting state: accumulated legend labels, marker list, and the
% index of the next marker to use.
lgd = [];
markers = ["o" "+" "*" "x" "s" "d" "^" "p" "h" "<"];
idx = 1;

% Objective-value comparison. For every worker count in jobrange, plot the
% first nepoch entries of data.fval from the SGD result file (dotted line,
% legend "DSEGD") and from the SPL result file with the SAME worker count
% (solid line, legend "DSEPL") on a log-scale y axis. Each curve is also
% appended as a column of sgdfval / splfval.
if computefval
    alpha = "0.5";
    fvalbeta = 3.0;  % value written into the %1.1f beta field of the file name
    sgdfval = [];
    splfval = [];
    for i = 1:njobs
        jobs = jobrange(i);

        % --- DSEGD run with `jobs` workers ---
        fname = sprintf(formatting, jobs, alpha, fvalbeta, "SGD");
        data = load(fname);
        % Force a column vector whatever orientation the file stores.
        fval = reshape(data.fval(1:nepoch), [], 1);
        if any(isnan(fval))
            % Any NaN (e.g. a run that blew up) flattens the whole curve to
            % its first value.
            fval(:) = fval(1);
        end % End if
        lgd = [lgd, "DSEGD " + jobs + " workers"]; %#ok
        sgdfval = [sgdfval, fval]; %#ok
        % Cycle the marker list so that more curves than markers cannot
        % index past its end.
        semilogy(fval, "LineWidth", 3, "LineStyle", ":", ...
            "Marker", markers(mod(idx - 1, numel(markers)) + 1));
        idx = idx + 1;
        % Turn hold on only after plotting, so the first semilogy call sets
        % up the log-scale axes.
        hold on;

        % --- DSEPL run with the same number of workers as its legend label ---
        fname = sprintf(formatting, jobs, alpha, fvalbeta, "SPL");
        data = load(fname);
        fval = reshape(data.fval(1:nepoch), [], 1);
        if any(isnan(fval))
            fval(:) = fval(1);
        end % End if
        lgd = [lgd, "DSEPL " + jobs + " workers"]; %#ok
        splfval = [splfval, fval]; %#ok
        semilogy(fval, "LineWidth", 3, "LineStyle", "-", ...
            "Marker", markers(mod(idx - 1, numel(markers)) + 1));
        idx = idx + 1;
        hold on;
    end % End for
end % End if


% Norm comparison at 16 workers. For every beta in betarange, plot the first
% nepoch entries of data.norm from the SGD result file (dotted line, legend
% "DSEGD") and the SPL result file (solid line, legend "DSEPL") on a
% log-scale y axis. Each curve is also appended as a column of
% sgdnorm / splnorm.
if computenorm
    alpha = "0.5";
    normjobs = 16;  % worker count of the files compared in this section
    sgdnorm = [];
    splnorm = [];

    for i = 1:nbeta
        beta = betarange(i);
        % The beta field of the file name holds 10*beta (printed as %1.1f).
        fname = sprintf(formatting, normjobs, alpha, 10 * beta, "SGD");
        data = load(fname);
        % Force a column vector whatever orientation the file stores.
        dnrm = reshape(data.norm(1:nepoch), [], 1);
        if any(isnan(dnrm))
            % Any NaN flattens the whole curve to its first value.
            dnrm(:) = dnrm(1);
        end % End if
        lgd = [lgd, "DSEGD \beta = " + beta]; %#ok
        sgdnorm = [sgdnorm, dnrm]; %#ok
        % idx may continue from the fval section; cycle the marker list so it
        % never indexes past the end.
        semilogy(dnrm, "LineWidth", 3, "LineStyle", ":", ...
            "Marker", markers(mod(idx - 1, numel(markers)) + 1));
        idx = idx + 1;
        hold on;

        fname = sprintf(formatting, normjobs, alpha, 10 * beta, "SPL");
        data = load(fname);
        dnrm = reshape(data.norm(1:nepoch), [], 1);
        if any(isnan(dnrm))
            dnrm(:) = dnrm(1);
        end % End if
        lgd = [lgd, "DSEPL \beta = " + beta]; %#ok
        splnorm = [splnorm, dnrm]; %#ok
        semilogy(dnrm, "LineWidth", 3, "LineStyle", "-", ...
            "Marker", markers(mod(idx - 1, numel(markers)) + 1));
        idx = idx + 1;
        hold on;
    end % End for
end % End if

% Finish the figure only when a plotting section actually added curves. With
% computefval and computenorm both off, lgd is empty and there is nothing to
% label or rescale. NOTE(review): tightfig is not a MATLAB built-in
% (presumably the File Exchange margin-trimming helper); it must be on the path.
if ~isempty(lgd)
    legend(lgd);
    set(gca, "FontSize", 20, "FontWeight", "bold");
    ylim([1e-07, 1.0]);
    tightfig;
end % End if

% Timing comparison. For every worker count in jobrange, store the final
% entry of data.time (presumably the cumulative run time; confirm) from the
% SGD and SPL result files. The results are sgdtime / spltime, with one row
% per entry of jobrange.
if computetime
    talpha = "0.5";
    % Pass the beta field as a number: `formatting` uses the numeric %1.1f
    % conversion here, as the other sections do.
    tbeta = 1.0;
    sgdtime = zeros(njobs, 1);
    spltime = zeros(njobs, 1);
    for i = 1:njobs
        fname = sprintf(formatting, jobrange(i), talpha, tbeta, "SGD");
        data = load(fname);  % terminated so the struct is not echoed
        sgdtime(i) = data.time(end);
        fname = sprintf(formatting, jobrange(i), talpha, tbeta, "SPL");
        data = load(fname);
        spltime(i) = data.time(end);
    end % End for
end % End if