from algorithms import *

# When True, skip running the algorithms and only reload + plot saved results.
plot_only = False

# Plot styling handed to visualize() (presumably one entry per plotted method).
linestyles = ['-', '-.', '--', ':', '-.', '-'] * 2   # same 6-style pattern, twice
markers = list("o*dvP1pX")                           # one single-character marker each
colors = ["tab:" + shade for shade in
          ("blue", "red", "green", "brown", "purple", "gray", "olive", "cyan")]

# Per-dataset value of T (the dataset's sample count must be divisible by it).
T_dict = dict(a1a=5, mushrooms=12, phishing=11, duke=11, madelon=200)

# Problem / run settings.
mu = 1e-4                           # passed to make_fg_logreg and the solvers
L = 1.0                             # passed to normalize_data and the solvers
experiment = "CompareMethodsHetero" # tag used when building result filenames
no_agg = 1000.0                     # number of aggregation rounds to run
assignment = "same"                 # data-rearrangement method for rearrange_data

# Datasets to run, and the methods to compare with their plot labels
# (algs[i] is displayed as labels2[i]).
datasets = ["duke", "madelon", "phishing"]
algs = ["iapgdkat", "al2sgd", "l2sgd"]
labels2 = ["IAPGD+Kat.", "AL2SGD+", "L2SGD+"]

# For each dataset: split the n samples into T equal groups of m samples, run
# every method in `algs` (unless plot_only is set), pickle the per-aggregation
# results, then load them back and plot the objective value against
# communication rounds and against local gradient computations.
for dataset in datasets:
    print("#############################")
    print(dataset)
    print("#############################")

    T = T_dict[dataset]

    A, b = get_data(dataset)
    A = normalize_data(A, L)

    n, d = A.shape
    # The n samples must split evenly into T groups of m samples. Raise
    # explicitly: a bare `assert` would be skipped under `python -O`.
    m = n // T
    if m * T != n:
        raise ValueError(
            "{}: n = {} samples cannot be split evenly into T = {} groups".format(
                dataset, n, T))

    # Parameters derived from the group size m (note: m == 1 would make
    # 1 - pagg zero and lamda undefined).
    pagg = 1 / m
    lamda = pagg / (1 - pagg)
    omega = lamda / T
    rho = pagg

    if not plot_only:
        A, b = rearrange_data(A, method=assignment, b=b)
        f, g = make_fg_logreg(A, b, mu)

        def total_loss(X):
            return objective(f, psi_func, X, m, omega)

        # One entry per algorithm, in the order of `algs`.
        Flist, Xlist, Steplist = [], [], []
        for alg in algs:
            print("&&&&&&&&&&&&&&&&&&&&&&&&&&&")
            # Log the assignment method actually passed to rearrange_data
            # (previously hard-coded as "random" although `assignment` is used).
            print("alg = {}, AM = {}".format(alg, assignment))
            if alg == "l2sgd":
                alpha = get_stepsize_saga(v=np.ones(n), p=np.ones(n) / m, pagg=pagg, pwork=1.0,
                                          omega=omega, n=n, T=T, mu=mu)
                _, (_, agg_fvals, agg_x, agg_steps) = l2sgd(
                    total_loss, g, alpha, d, T, m, pagg, omega, track_agg=True, no_agg=no_agg)
            elif alg == "al2sgd":
                _, (_, agg_fvals, agg_x, agg_steps) = al2sgd(
                    total_loss, g, d, T, m, rho, lamda, L=L, mu=mu, track_agg=True, no_agg=no_agg)
            elif alg == "iapgdkat":
                agg_fvals, agg_x, agg_steps = apgd(
                    total_loss, g, d, T, m, lamda, L=L, mu=mu, no_agg=no_agg)
            elif alg == "iapgdagd":
                agg_fvals, agg_x, agg_steps = apgd(
                    total_loss, g, d, T, m, lamda, L=L, mu=mu, no_agg=no_agg, stochastic=False)
            elif alg == "apgd2":
                agg_fvals, agg_x, agg_steps = apgd2(
                    total_loss, g, d, T, m, lamda, L=L, mu=mu, no_agg=no_agg)
            else:
                raise ValueError("Algorithm doesn't exist")
            print("FINISHED")

            Flist.append(agg_fvals)
            Xlist.append(agg_x)
            Steplist.append(agg_steps)

        # Save outputs under both filenames createfilename produces
        # (backup=True and backup=False); `with` closes the file even if
        # pickling fails.
        for backup in (True, False):
            filename = createfilename(experiment, dataset, T, omega, mu, 1.0, "SAGA", backup)
            with open(filename, "wb") as pickle_out:
                pickle.dump((Flist, Xlist, Steplist), pickle_out)

    # Load outputs (freshly written above, or from a previous run if plot_only).
    Flist, Xlist, Steplist = load_pickle(experiment, dataset, T, omega, mu, 1.0, "SAGA")
    it_list = [range(len(fvals)) for fvals in Flist]

    # Overwrite the first gradient count with 1 so the curve can be drawn on
    # log-log axes (presumably it starts at 0, which a log scale cannot show).
    for steps in Steplist:
        steps[0] = 1

    # NOTE(review): "heterogenous" / "gradient_comptutation" look like typos, but
    # they are kept as-is since visualize() may derive output filenames from them.
    visualize(Flist, it_list, "communication", dataset + " (heterogenous), lambda = {:.3f}".format(lamda), labels2, linestyles=linestyles, markers=markers,
              colors=colors, xlabel="Communication rounds")
    visualize(Flist, Steplist, "gradient_comptutation", dataset + " (heterogeneous), lambda = {:.3f}".format(lamda), labels2, linestyles=linestyles, markers=markers,
              colors=colors, xlabel="Gradients of local summands", loglog=True)

print("DONE")