from dist_sketching_dpp_functions import *

# inputs: <figure_no> <dataset_name> <m> <sketch_type>
# Fail early with a usage message instead of a bare IndexError when an
# argument is missing.
if len(sys.argv) < 5:
    raise SystemExit(
        "usage: {} <figure_no> <dataset_name> <m> <sketch_type>".format(sys.argv[0])
    )

figure_no = int(sys.argv[1])  # selects which `if figure_no == k` section below runs
dataset_name = sys.argv[2]  # forwarded unchanged to dataset_loader
m = int(sys.argv[3])  # integer forwarded to the error / optimization routines below
sketch_type = sys.argv[4]  # sketch identifier; also part of the output file name

print("Figure name is: figure_{}_data_{}_m_{}_sketchtype_{}.pdf".format(figure_no, dataset_name, m, sketch_type))

# load dataset
A, y = dataset_loader(dataset_name)
n, d = A.shape  # A must be two-dimensional for this unpacking to succeed

print("dataset loaded")


# compute the results and plot
if figure_no == 1:
    # Figure 1: error curves averaged over multiple trials. Every sketch type
    # gets two curves: solid (lm1 passed as both regularization arguments) and
    # dotted (lm2 = lm1 * (1 - effective_dim / m) passed as the second one).
    ATA = np.matmul(A.T, A)
    ATy = np.matmul(A.T, y)
    AAT = np.matmul(A, A.T)

    num_workers = 2500
    num_trials = 100

    lm1 = 10
    effective_dim = compute_effective_dimension(ATA, lm1)
    lm2 = lm1 * (1 - effective_dim / m)
    print("A.shape = {}, lm1 = {}, lm2 = {}, effective_dim = {}".format(A.shape, lm1, lm2, effective_dim))

    # (sketch name handed to get_errors, colour+marker prefix, legend label),
    # listed in the order the runs are executed and plotted.
    sketch_specs = [
        ("gaus", "bo", "Gaussian"),
        ("rademacher", "rs", "Rademacher"),
        ("unif", "c>", "Uniform"),
        ("surrogate_p1", "gx", "Surrogate sketch"),
    ]
    # (variant key, value passed as get_errors' fifth argument)
    reg_variants = [("plain", lm1), ("biascorr", lm2)]

    # One (num_trials, num_workers) array per (sketch, variant) pair.
    errors = {
        (name, variant): np.zeros((num_trials, num_workers))
        for name, _, _ in sketch_specs
        for variant, _ in reg_variants
    }

    for trial in range(num_trials):
        print("\n ########### [{}/{}] ###########".format(trial, num_trials), end=", ")

        for name, _, _ in sketch_specs:
            for variant, second_reg in reg_variants:
                # Re-seed before every call so each run within a trial starts
                # from the same random state.
                np.random.seed(trial)
                # Only the first element of get_errors' return value is kept.
                errors[name, variant][trial, :] = get_errors(
                    A, y, m, lm1, second_reg, num_workers, name,
                    effective_dim, ATA, ATy, AAT)[0]

    # plot results
    plt.rcParams.update({'font.size': 16})
    plt.grid()
    plt.xlabel('number of averaged outputs')
    # Raw string: '\h' is not a valid escape sequence in a normal literal.
    plt.ylabel(r'$||x^*-\hat{x}||_2 / ||x^*||_2$')

    # Ten log-spaced column indices; x positions are the matching 1-based counts.
    inds = np.logspace(np.log10(10), np.log10(num_workers - 1), 10, dtype=int)
    x_vals = np.linspace(1, num_workers, num_workers)[inds]

    for name, style, label in sketch_specs:
        plt.loglog(x_vals, np.mean(errors[name, "plain"], axis=0)[inds], style + '-', label=label)
        plt.loglog(x_vals, np.mean(errors[name, "biascorr"], axis=0)[inds], style + ':')

    plt.legend()
    plt.savefig("figure_{}_data_{}_m_{}_sketchtype_{}.pdf".format(figure_no, dataset_name, m, sketch_type), bbox_inches="tight")
if figure_no == 2:
    # Figure 2: a single sketch type (taken from the command line) evaluated
    # for several lm1 values, one seeded run per row of `errors`.
    ATA = np.matmul(A.T, A)
    ATy = np.matmul(A.T, y)
    AAT = np.matmul(A, A.T)

    num_workers = 10000

    # Every lm1 value appears twice: the even row uses lm2 = lm1, the odd row
    # uses lm2 = lm1 * (1 - effective_dim / m).
    lm1_values = 1.0 * np.array([0.1, 0.1, 1, 1, 10, 10])

    errors = np.zeros((lm1_values.shape[0], num_workers))

    for i, lm1 in enumerate(lm1_values):
        print("\n ########### {} ###########".format(i), end=", ")
        # Needed by both branches, so computed once per row.
        effective_dim = compute_effective_dimension(ATA, lm1)
        if i % 2 == 0:
            lm2 = lm1
            print(effective_dim)
        else:
            lm2 = lm1 * (1 - effective_dim / m)

        np.random.seed(0)
        # Only the first element of get_errors' return value is kept.
        errors[i, :] = get_errors(A, y, m, lm1, lm2, num_workers, sketch_type, effective_dim, ATA, ATy, AAT)[0]

    # plot results
    plt.rcParams.update({'font.size': 16})
    plt.grid()
    plt.xlabel('number of averaged outputs')
    # Raw string: '\h' is not a valid escape sequence in a normal literal.
    plt.ylabel(r'$||x^*-\hat{x}||_2 / ||x^*||_2$')
    colors = ['b', 'r', 'g', 'k', 'm']
    markers = ['>', '+', 'o']  # one marker per lm1 pair (three pairs above)

    # Ten log-spaced column indices; x positions are the matching 1-based counts.
    inds = np.logspace(np.log10(1), np.log10(num_workers - 1), 10, dtype=int)
    x_vals = np.linspace(1, errors.shape[1], errors.shape[1])[inds]

    for i in range(errors.shape[0] // 2):
        # Solid line: row with lm2 = lm1; dotted line: row with the shrunk lm2.
        plt.loglog(x_vals, errors[2 * i, :][inds], colors[i] + '-' + markers[i],
                   label=r"$\lambda={}$".format(lm1_values[2 * i]))
        plt.loglog(x_vals, errors[2 * i + 1, :][inds], colors[i] + ':' + markers[i])

    plt.legend()
    plt.savefig("figure_{}_data_{}_m_{}_sketchtype_{}.pdf".format(figure_no, dataset_name, m, sketch_type), bbox_inches="tight")
if figure_no == 3:
    # Figure 3: log_regression_optimize convergence for three sketch types.
    # For each sketch type three runs are stored in consecutive rows of
    # errors_all: the "newton" reference (row 3*i), the "baseline" run
    # (row 3*i+1) and the "bias corrected" run (row 3*i+2).
    b = y.copy()
    c_vec = 0

    num_iters = 20
    num_workers = 100

    lm1 = 0.0001
    tau, c, a0 = 2, 0.1, 1  # a0 is the starting alpha for backtracking line search

    sketch_types = ["gaus", "unif", "surrogate_p1"]

    errors_all = np.zeros((len(sketch_types) * 3, num_iters + 1))
    times_all = np.zeros((len(sketch_types) * 3, num_iters + 1))

    for i, sketch in enumerate(sketch_types):
        # NOTE(review): this reference call receives no sketch type, so it looks
        # identical for every i; it is kept per-iteration to preserve the
        # original run order and output. Confirm before hoisting it.
        print("\n" + sketch + " - newton")
        np.random.seed(0)
        _, errors_all[3 * i, :], times_all[3 * i, :] = log_regression_optimize(
            A, b, c_vec, 1, lm1, lm1, tau, c, num_iters, -99, -99, a0, False)

        print("\n" + sketch + " - baseline")
        np.random.seed(0)
        _, errors_all[3 * i + 1, :], times_all[3 * i + 1, :] = log_regression_optimize(
            A, b, c_vec, 2, lm1, lm1, tau, c, num_iters,
            num_workers, m, a0, False, sketch, m2=0)

        # Same call as the baseline except that -999 replaces the second lm1
        # (presumably a sentinel handled inside log_regression_optimize -- confirm).
        print("\n" + sketch + " - bias corrected")
        np.random.seed(0)
        _, errors_all[3 * i + 2, :], times_all[3 * i + 2, :] = log_regression_optimize(
            A, b, c_vec, 2, lm1, -999, tau, c, num_iters,
            num_workers, m, a0, False, sketch, m2=0)

    # plot results
    plt.rcParams.update({'font.size': 16})
    colors = ['b', 'r', 'g', 'k', 'y', 'm']
    markers = ['>', '+', 'o']
    legend_names = ["Gaussian", "Uniform", "Surrogate sketch"]

    iterations = np.linspace(0, num_iters, num_iters + 1)

    for i, label in enumerate(legend_names):
        # Curves are relative to the last value of the reference run.
        reference = errors_all[3 * i, -1]
        plt.semilogy(iterations, np.abs((errors_all[3 * i + 1, :] - reference) / reference),
                     '-' + markers[i] + colors[i], label=label)
        plt.semilogy(iterations, np.abs((errors_all[3 * i + 2, :] - reference) / reference),
                     ':' + markers[i] + colors[i])

    plt.grid()
    # Raw string: '\h' is not a valid escape sequence in a normal literal.
    plt.ylabel(r'$(f(\hat{x}) - f(x^*)) / f(x^*)$')
    plt.xlabel("iteration")
    plt.legend()
    plt.savefig("figure_{}_data_{}_m_{}_sketchtype_{}.pdf".format(figure_no, dataset_name, m, sketch_type), bbox_inches="tight")
if figure_no == 4:
    # Figure 4: averaging "curves" (multiple trials) from
    # get_errors_singlenewtonstep, shown as mean with standard-error bars.
    ATA = np.matmul(A.T, A)
    ATy = np.matmul(A.T, y)
    AAT = np.matmul(A, A.T)

    num_workers = 2500

    lm1 = 10
    effective_dim = compute_effective_dimension(ATA, lm1)
    lm2 = lm1 * (1 - effective_dim / m)
    print("A.shape = {}, lm1 = {}, lm2 = {}, effective_dim = {}".format(A.shape, lm1, lm2, effective_dim))

    num_trials = 100

    errors_unif = np.zeros((num_trials, num_workers))  # "unif", lm1 for both reg params, determ_avg=False
    errors_surr_p1 = np.zeros((num_trials, num_workers))  # "surrogate_p1", second reg param is lm2
    errors_det_avg = np.zeros((num_trials, num_workers))  # "unif", lm1 for both reg params, determ_avg=True

    for trial in range(num_trials):
        print("\n ########### [{}/{}] ###########".format(trial,num_trials), end=", ")
        # Re-seed before every call so the three runs of a trial start from
        # the same random state.
        np.random.seed(trial)
        errors_unif[trial, :] = get_errors_singlenewtonstep(
            A, y, m, lm1, lm1, num_workers, "unif",
            effective_dim, ATA, ATy, AAT, determ_avg=False)[0]
        np.random.seed(trial)
        errors_surr_p1[trial, :] = get_errors_singlenewtonstep(
            A, y, m, lm1, lm2, num_workers, "surrogate_p1",
            effective_dim, ATA, ATy, AAT, determ_avg=False)[0]
        np.random.seed(trial)
        errors_det_avg[trial, :] = get_errors_singlenewtonstep(
            A, y, m, lm1, lm1, num_workers, "unif",
            effective_dim, ATA, ATy, AAT, determ_avg=True)[0]

    # plot results
    plt.rcParams.update({'font.size': 16})
    plt.grid()
    plt.xlabel('number of averaged outputs')
    # Raw string: '\h' is not a valid escape sequence in a normal literal.
    plt.ylabel(r'$||x^*-\hat{x}||_2 / ||x^*||_2$')

    # Ten log-spaced column indices; x positions are the matching 1-based counts.
    inds = np.logspace(np.log10(10), np.log10(num_workers - 1), 10, dtype=int)
    x_vals = np.linspace(1, num_workers, num_workers)[inds]
    bar_style = dict(markeredgewidth=1, capsize=2, elinewidth=1)

    # (per-trial errors, line format, legend label), in drawing order.
    curves = [
        (errors_unif, "b-", "Unweighted averaging"),
        (errors_det_avg, "g-.", "Determinantal averaging"),
        (errors_surr_p1, "r:", "Surrogate sketch"),
    ]
    for trial_errors, fmt, label in curves:
        # Mean over trials, with the standard error of the mean as bar height.
        plt.errorbar(x_vals, np.mean(trial_errors, axis=0)[inds],
                     scipy.stats.sem(trial_errors, axis=0)[inds],
                     fmt=fmt, label=label, **bar_style)
    plt.xscale('log')
    plt.yscale('log')

    plt.legend()
    plt.savefig("figure_{}_data_{}_m_{}_sketchtype_{}.pdf".format(figure_no, dataset_name, m, sketch_type), bbox_inches="tight")



if figure_no == 5:
    # Figure 5: log_regression_optimize convergence for the sketch type given
    # on the command line (two runs), plus a gradient-descent run, all
    # measured against the last value of the "optimal newton" run.
    lm1 = 10**(-4)
    tau, c, a0 = 2, 0.1, 1
    num_iters = 20
    num_workers = 100
    # Rows: 0 = optimal newton, 1 = baseline, 2 = bias corrected.
    errors_all = np.zeros((3, num_iters + 1))

    print("\n" + " - optimal newton")
    np.random.seed(0)
    _, errors_all[0, :], _ = log_regression_optimize(
        A, y, 0, 1, lm1, lm1, tau, c, num_iters, -99, -99, a0, False)

    print("\n" + sketch_type + " - baseline")
    np.random.seed(0)
    _, errors_all[1, :], _ = log_regression_optimize(
        A, y, 0, 2, lm1, lm1, tau, c, num_iters,
        num_workers, m, a0, False, sketch_type, m2=0)

    # Same call as the baseline except that -999 replaces the second lm1
    # (presumably a sentinel handled inside log_regression_optimize -- confirm).
    print("\n" + sketch_type + " - bias corrected")
    np.random.seed(0)
    _, errors_all[2, :], _ = log_regression_optimize(
        A, y, 0, 2, lm1, -999, tau, c, num_iters,
        num_workers, m, a0, False, sketch_type, m2=0)

    # NOTE(review): unlike the runs above, this call is not preceded by
    # np.random.seed(0); confirm whether option 3 consumes random numbers.
    print("\n" + "gradient descent")
    _, GD_results, _ = log_regression_optimize(
        A, y, 0, 3, lm1, -999, tau, c, num_iters,
        num_workers, m, a0, False, sketch_type, m2=0)

    plt.rcParams.update({'font.size': 16})
    plt.grid()
    plt.ylim(10**(-15), 10**(-1))
    # Raw string: '\h' is not a valid escape sequence in a normal literal.
    plt.ylabel(r'$(f(\hat{x}) - f(x^*)) / f(x^*)$')

    plt.xlabel("iteration")

    iterations = np.linspace(0, num_iters, num_iters + 1)
    reference = errors_all[0, -1]  # last value of the optimal-newton run

    plt.semilogy(iterations, np.abs((errors_all[1, :] - reference) / reference), 'g-o', label=sketch_type)
    plt.semilogy(iterations, np.abs((errors_all[2, :] - reference) / reference), 'g:o')
    plt.semilogy(iterations, np.abs((GD_results - reference) / reference), 'k-*', label="GD")
    plt.legend()
    plt.savefig("figure_{}_data_{}_m_{}_sketchtype_{}.pdf".format(figure_no, dataset_name, m, sketch_type), bbox_inches="tight")





