from algorithms_no_log import *

# ---------------------------------------------------------------------------
# Run control
# ---------------------------------------------------------------------------
# When True the solver runs are skipped and only previously pickled results
# are loaded and plotted.
plot_only = False

# ---------------------------------------------------------------------------
# Plot styling (indexed by algorithm position)
# ---------------------------------------------------------------------------
linestyles = ['-', '-.', '--', ':', '-.', '-'] * 2
markers = ['o', '*', 'd', 'v', 'P', '1', 'p', 'X']
colors = [
    'tab:blue', 'tab:red', 'tab:green', 'tab:brown',
    'tab:purple', 'tab:gray', 'tab:olive', 'tab:cyan',
]

# Per-dataset values; unused in the visible code (kept for other experiments).
T_dict = {
    "a1a": 5,
    "mushrooms": 12,
    "phishing": 11,
    "duke": 11,
    "madelon": 200,
}

# ---------------------------------------------------------------------------
# Problem / experiment parameters
# ---------------------------------------------------------------------------
mu = 1e-4          # strong-convexity shift added to every local matrix M_i
L = 1.0            # unused in the visible code
accuracy = 1e-4    # target accuracy handed to both solvers
experiment = "CompareAPGDs"
assignment = "random"

datasets = ["quadratics"]
algs = ["apgd1", "apgd2"]
labels2 = ["APGD 1", "APGD 2"]   # legend labels, aligned with `algs`

# Penalty sweep: 10^-7, 10^-6, ..., 10^4 (integers for non-negative powers).
Lambdas = [10 ** exponent for exponent in range(-7, 5)]

def _build_quadratic_problem(n, d, mu):
    """Draw a random block-diagonal quadratic problem over ``n`` nodes.

    Node i owns f_i(x) = 1/2 * (x - b_i).T @ M_i @ (x - b_i) with
    M_i = mu * I_d + (1 - mu) * a_i a_i^T, where a_i is a Gaussian vector in
    R^d normalised to unit length.

    Returns:
        Mb: (n*d, n*d) block-diagonal matrix whose i-th d x d block is M_i.
        B:  (n*d,) stacked vector of the b_i.
    """
    # Draw order (A first, then B) matches the original script's RNG usage.
    A = np.random.randn(n, d)
    A /= np.linalg.norm(A, axis=1, keepdims=True)
    B = np.random.randn(n * d)
    Mb = mu * np.eye(n * d)
    for i in range(n):
        blk = slice(d * i, d * (i + 1))
        Mb[blk, blk] += (1 - mu) * np.outer(A[i], A[i])
    return Mb, B


def _consensus_projector(n, d):
    """Return I - (1/n) * kron(ones((n, n)), I_d).

    Scaled by lambda this is the penalty matrix H, with
    x.T @ H @ x = lambda * sum_i ||x_i - x_bar||^2.
    """
    return np.eye(n * d) - np.kron(np.ones((n, n)), np.eye(d)) / n


def _make_prox_point(Mb, B, d, lamda):
    """Build the local proximal oracle ``prox_point(y_bar, i)`` for penalty ``lamda``.

    Solves (M_i + lamda * I_d) x = M_i b_i + lamda * y_bar, i.e. the
    stationarity condition of f_i(x) + lamda/2 * ||x - y_bar||^2.
    """
    def prox_point(y_bar, i):
        blk = slice(d * i, d * (i + 1))
        M = Mb[blk, blk]
        return np.linalg.solve(M + lamda * np.eye(d), M @ B[blk] + lamda * y_bar)
    return prox_point


for dataset in datasets:
    print("#############################")
    print(dataset)
    print("#############################")

    # One result list per algorithm; entry k corresponds to Lambdas[k].
    conv_com = [[] for _ in algs]
    conv_grad = [[] for _ in algs]
    n, d = 50, 50  # n nodes, each holding a d-dimensional local variable
    m = 1
    T = n

    # f_i(x) = 1/2 * (x - b_i).T @ M_i @ (x - b_i), stacked over the n nodes.
    Mb, B = _build_quadratic_problem(n, d, mu)

    if not plot_only:
        # The projector does not depend on lambda, so build it once.
        P = _consensus_projector(n, d)
        for lamda in Lambdas:
            print("Lambda = {}".format(lamda))

            _, g = make_fg_quadratic(Mb, B, d)
            H = lamda * P
            prox_point = _make_prox_point(Mb, B, d, lamda)

            # Reference solution: solves (Mb + H) X = Mb @ B, the stationarity
            # condition of sum_i f_i(x_i) + 1/2 * x.T @ H @ x.
            X_opt = np.linalg.solve(Mb + H, Mb @ B)

            for i, alg in enumerate(algs):
                print("&&&&&&&&&&&&&&&&&&&&&&&&&&&")
                print("alg = {}".format(alg))

                if alg == "apgd1":
                    aggregations, loc_steps, X = apgd_no_log(
                        prox_point, d, T, m, lamda, X0=np.zeros(n * d),
                        X_opt=X_opt, accuracy=accuracy, proximal=True)
                elif alg == "apgd2":
                    aggregations, loc_steps, X = apgd2_no_log(
                        g, d, T, m, lamda, mu=mu, X_opt=X_opt.reshape(T, d),
                        accuracy=accuracy, proximal=True)
                    # reshape returns a new array, so the result must be
                    # assigned; stacks the per-node iterate (presumably (T, d),
                    # like X_opt above) into a single column.
                    X = X.reshape(T * d, -1)
                else:
                    raise ValueError("Algorithm doesn't exist")

                print("FINISHED ### communication: {}, gradient_comptutation: {}".format(aggregations, loc_steps))

                conv_com[i].append(aggregations)
                conv_grad[i].append(loc_steps)

        # Save results (once as backup, once as the main file).
        for backup in (True, False):
            filename = createfilename(
                experiment, dataset, T, 0, mu, 1.0, "testing lambda", backup)
            with open(filename, "wb") as pickle_out:
                pickle.dump((conv_com, conv_grad, X), pickle_out)

    # Plot from the saved file so that plot_only=True yields the same figures.
    saved_com, saved_grad, _ = load_pickle(
        experiment, dataset, T, 0, mu, 1.0, "testing lambda")
    plots_dir = os.path.join(os.getcwd(), 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    for name, results in zip(["communications", "gradient computations"],
                             [saved_com, saved_grad]):
        plt.xscale('log')
        plt.yscale('log')
        plt.xlabel("Lambda", fontsize='x-large')
        plt.ylabel('Number of ' + name, fontsize='x-large')
        plt.title("Quadratic objective", fontsize='x-large')
        # Lambdas[0] is deliberately left off the plot (reason not stated here).
        for j, curve in enumerate(results):
            plt.plot(Lambdas[1:], curve[1:], label=labels2[j],
                     marker=markers[j], linestyle=linestyles[j])
        plt.legend(fontsize='x-large', loc='best')
        # Lay out after every artist exists so labels and legend are accounted for.
        plt.tight_layout()
        plt.savefig(os.path.join(
            plots_dir,
            '{}FistasDataset{}Accuracy{}.pdf'.format(name, dataset, accuracy)))
        plt.close()

print("DONE")
