import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import itertools
from collections import defaultdict
from torchvision import datasets, transforms
from sklearn.decomposition import PCA
from sklearn import random_projection

np.set_printoptions(precision=3, linewidth=240, suppress=True)
np.random.seed(1993)

def load_MNIST2(p, dim, path):
    """Build 25 per-machine even-vs-odd classification datasets from MNIST.

    PCA-reduced MNIST features are cached under ``data/<path>/``; when both
    cache files exist they are loaded instead of re-downloading and
    re-running PCA.  Every digit class is truncated to the size of the
    smallest class (``n_m`` points).  Machine ``m`` is associated with one
    (even digit, odd digit) pair and receives ``2 * n_m`` points: a fraction
    ``1 - p`` sampled from its own pair and a fraction ``p`` sampled from the
    pooled all-evens-vs-all-odds data.

    Args:
        p: Fraction (between 0 and 1) of each machine's data drawn from the
            pooled task.  ``p = 0`` gives 25 fully distinct tasks, ``p = 1``
            gives 25 samples of the same pooled task.
        dim: Number of PCA components to keep.  Must be an ``int`` because it
            is formatted with ``{:d}`` into the cache file names.
        path: Sub-directory of ``data/`` holding the cache files.

    Returns:
        A pair ``(features_by_machine, labels_by_machine)`` of arrays with
        shapes ``(25, 2 * n_m, dim)`` and ``(25, 2 * n_m)``.  Labels are 1 for
        even digits and 0 for odd digits.

    Note:
        Sampling uses the global ``np.random`` state and ``np.random.choice``
        with its default ``replace=True``, so points may repeat.
    """
    data_dir = os.path.join('data', path)
    # makedirs(exist_ok=True) creates the parent 'data' directory when it is
    # missing and tolerates an already-existing (possibly nested) target.
    os.makedirs(data_dir, exist_ok=True)
    features_file = os.path.join(data_dir, 'processed_mnist_features_{:d}.npy'.format(dim))
    labels_file = os.path.join(data_dir, 'processed_mnist_labels_{:d}.npy'.format(dim))

    # Require BOTH files so a half-written cache triggers a rebuild instead of
    # failing later inside np.load.
    if os.path.isfile(features_file) and os.path.isfile(labels_file):
        features = np.load(features_file)
        labels = np.load(labels_file)
    else:
        transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.,), (1.,)),])
        mnist = datasets.MNIST('data', download=True, train=True, transform=transform)

        # Single pass over the dataset: each mnist[i] applies the transform,
        # so fetching image and label together halves that work.
        flat_images = []
        digits = []
        for i in range(len(mnist)):
            image, digit = mnist[i]
            flat_images.append(np.array(image).reshape(-1))
            digits.append(digit)
        features = np.array(flat_images)
        labels = np.array(digits)

        features = PCA(n_components=dim).fit_transform(features)

        np.save(features_file, features)
        np.save(labels_file, labels)

    # group the data by digit, keeping the first n_m points of every class so
    # that all classes have the same size
    n_m = min(int(np.sum(labels == digit)) for digit in range(10))
    by_number = {digit: features[labels == digit][:n_m] for digit in range(10)}

    # enumerate the even vs odd tasks
    even_numbers = [0,2,4,6,8]
    odd_numbers = [1,3,5,7,9]
    even_odd_pairs = list(itertools.product(even_numbers, odd_numbers))

    # group data into 25 single even vs single odd tasks; column 0 holds the
    # label (1 = even, 0 = odd), the remaining columns hold the features
    pair_labels = np.concatenate([np.ones(n_m), np.zeros(n_m)]).reshape(-1, 1)
    all_tasks = [
        np.concatenate(
            [pair_labels, np.concatenate([by_number[e], by_number[o]], axis=0)],
            axis=1)
        for (e, o) in even_odd_pairs
    ]

    # pooled task: all even digits (label 1) vs all odd digits (label 0)
    all_evens = np.concatenate([np.ones((5*n_m,1)), np.concatenate([by_number[i] for i in even_numbers], axis=0)], axis=1)
    all_odds = np.concatenate([np.zeros((5*n_m,1)), np.concatenate([by_number[i] for i in odd_numbers], axis=0)], axis=1)
    all_nums = np.concatenate([all_evens, all_odds], axis=0)

    # mix individual tasks with overall task
    features_by_machine = []
    labels_by_machine = []
    n_individual = int(np.round(2*n_m * (1. - p)))
    n_all = 2*n_m - n_individual
    for task_m in all_tasks:
        # Draw order (own task first, pooled task second) is kept fixed so the
        # random stream is consumed the same way on every call.
        task_m_idxs = np.random.choice(task_m.shape[0], n_individual)
        all_nums_idxs = np.random.choice(all_nums.shape[0], n_all)
        data_for_m = np.concatenate([task_m[task_m_idxs, :], all_nums[all_nums_idxs, :]], axis=0)
        features_by_machine.append(data_for_m[:,1:])
        labels_by_machine.append(data_for_m[:,0])

    features_by_machine = np.array(features_by_machine)
    labels_by_machine = np.array(labels_by_machine)
    return features_by_machine, labels_by_machine

############################################## Logistic Regression ###############################################

def sigmoid(z):
    """Return the logistic function of ``z``, with ``z`` clipped to [-15, 15].

    Clipping keeps ``np.exp`` away from overflow; as a consequence the output
    never reaches exactly 0 or 1.
    """
    clipped = np.clip(z, -15, 15)
    return 1. / (1. + np.exp(-clipped))

def logistic_loss(x, features, labels):
    """Average logistic (cross-entropy) loss of the linear model ``x``.

    Args:
        x: Parameter vector of length d.
        features: [n x d] matrix of features (each row is one data point).
        labels: n-dimensional vector of 0/1 labels.

    Returns:
        The mean negative log-likelihood over the n data points.  A constant
        1e-12 is added inside each logarithm to keep it finite.
    """
    n = features.shape[0]
    probs = sigmoid(np.dot(features, x))
    positive_term = np.dot(labels, np.log(1e-12 + probs))
    negative_term = np.dot(1 - labels, np.log(1e-12 + 1 - probs))
    return (-1. / n) * (positive_term + negative_term)

def logistic_loss_gradient(x, features, labels):
    """Gradient of the average logistic loss with respect to ``x``.

    Computes ``features^T (sigmoid(features x) - labels) / n`` where n is the
    number of rows of ``features``.
    """
    n_points = features.shape[0]
    residuals = sigmoid(np.dot(features, x)) - labels
    return np.dot(features.T, residuals) / n_points

def logistic_loss_hessian(x, features, labels):
    """Hessian of the average logistic loss with respect to ``x``.

    Computes ``features^T diag(s * (1 - s)) features / n`` with
    ``s = sigmoid(features x)``.  ``labels`` is accepted for signature
    symmetry with the loss and gradient but is not used.
    """
    probs = sigmoid(np.dot(features, x))
    curvature = probs * (1 - probs)
    # Broadcasting scales column i of features^T (i.e. data point i) by
    # curvature[i].
    weighted_transpose = features.T * curvature
    return np.dot(weighted_transpose, features) / features.shape[0]

##################################################################################################################

def local_sgd_round(x_start, M, K, stepsize, grad_eval):
    """Run one communication round of Local SGD and return the averaged iterate.

    Each of the ``M`` machines starts from a copy of ``x_start`` and takes
    ``K`` SGD steps using ``grad_eval(x, 1, m)`` (a minibatch of size 1 on
    machine ``m``).  The result is the average of the M final local iterates.
    ``x_start`` itself is not modified.
    """
    averaged = np.zeros_like(x_start)
    for machine in range(M):
        local_iterate = x_start.copy()
        steps_taken = 0
        while steps_taken < K:
            local_iterate -= stepsize * grad_eval(local_iterate, 1, machine)
            steps_taken += 1
        averaged += local_iterate / M
    return averaged

def local_sgd(x_len, M, K, R, stepsize, loss_freq, f_eval, grad_eval, avg_window=8):
    """Run ``R`` rounds of Local SGD starting from the origin.

    Every ``loss_freq`` rounds the objective ``f_eval`` is evaluated at the
    average of the most recent (at most ``avg_window``) iterates; the initial
    zero vector counts as an iterate until it slides out of the window.

    Args:
        x_len: Dimension of the parameter vector.
        M: Number of machines.
        K: Local SGD steps per machine per round.
        R: Number of communication rounds.
        stepsize: SGD step size.
        loss_freq: Evaluate and record the loss every this many rounds.
        f_eval: Callable ``f_eval(x)`` returning the objective value.
        grad_eval: Callable ``grad_eval(x, minibatch_size, m)`` returning a
            stochastic gradient on machine ``m``.
        avg_window: Maximum number of trailing iterates to average (>= 1).

    Returns:
        ``(losses, status)`` where ``status`` is ``'converged'`` when all
        rounds ran, or ``'diverged'`` as soon as a recorded loss exceeds 100
        (in which case ``losses`` is truncated at that point).

    Raises:
        ValueError: If ``avg_window`` is smaller than 1.
    """
    if avg_window < 1:
        raise ValueError('avg_window must be at least 1')
    losses = []
    x = np.zeros(x_len)
    window = [x]
    for r in range(R):
        x = local_sgd_round(x, M, K, stepsize, grad_eval)
        window.append(x)
        # Trim AFTER appending.  The previous pre-append slice
        # iterates[-(avg_window-1):] degenerated to the whole list for
        # avg_window == 1, so the window silently grew without bound.
        if len(window) > avg_window:
            window = window[-avg_window:]
        if (r+1) % loss_freq == 0:
            losses.append(f_eval(np.average(window,axis=0)))
            print('Iteration: {:d}/{:d}   Loss: {:f}                 \r'.format(r+1,R,losses[-1]), end='')
            if losses[-1] > 100:
                print('\nLoss is diverging: Loss = {:f}'.format(losses[-1]))
                return losses, 'diverged'
    print('')
    return losses, 'converged'

def minibatch_sgd(x_len, M, K, R, stepsize, loss_freq, f_eval, grad_eval, avg_window=8):
    """Run ``R`` steps of Minibatch SGD starting from the origin.

    In every round each of the ``M`` machines contributes one stochastic
    gradient computed on a minibatch of size ``K`` at the current iterate.
    Every ``loss_freq`` rounds the objective ``f_eval`` is evaluated at the
    average of the most recent (at most ``avg_window``) iterates; the initial
    zero vector counts as an iterate until it slides out of the window.

    NOTE(review): the M per-machine gradients are summed, not averaged, so
    the effective step is ``M * stepsize`` times the mean gradient.  This is
    kept as-is because the stepsize grids used by the experiment are tuned
    against it -- confirm whether that scaling is intended.

    Args:
        x_len: Dimension of the parameter vector.
        M: Number of machines.
        K: Minibatch size per machine per round.
        R: Number of rounds.
        stepsize: SGD step size.
        loss_freq: Evaluate and record the loss every this many rounds.
        f_eval: Callable ``f_eval(x)`` returning the objective value.
        grad_eval: Callable ``grad_eval(x, minibatch_size, m)`` returning a
            stochastic gradient on machine ``m``.
        avg_window: Maximum number of trailing iterates to average (>= 1).

    Returns:
        ``(losses, status)`` where ``status`` is ``'converged'`` when all
        rounds ran, or ``'diverged'`` as soon as a recorded loss exceeds 100
        (in which case ``losses`` is truncated at that point).

    Raises:
        ValueError: If ``avg_window`` is smaller than 1.
    """
    if avg_window < 1:
        raise ValueError('avg_window must be at least 1')
    losses = []
    x = np.zeros(x_len)
    window = [x]
    for r in range(R):
        g = np.zeros(x_len)
        for m in range(M):
            g += grad_eval(x, K, m)
        x = x - stepsize * g
        window.append(x)
        # Trim AFTER appending.  The previous pre-append slice
        # iterates[-(avg_window-1):] degenerated to the whole list for
        # avg_window == 1, so the window silently grew without bound.
        if len(window) > avg_window:
            window = window[-avg_window:]
        if (r+1) % loss_freq == 0:
            losses.append(f_eval(np.average(window,axis=0)))
            print('Iteration: {:d}/{:d}   Loss: {:f}                 \r'.format(r+1,R,losses[-1]), end='')
            if losses[-1] > 100:
                print('\nLoss is diverging: Loss = {:f}'.format(losses[-1]))
                return losses, 'diverged'
    print('')
    return losses, 'converged'

def gradient_descent(x0_len, T, stepsize, f_eval=None, grad_eval=None):
    """Run ``T`` steps of full-batch gradient descent starting from the origin.

    Args:
        x0_len: Dimension of the parameter vector.
        T: Number of gradient steps.
        stepsize: Step size.
        f_eval: Callable ``f_eval(x)`` returning the objective value.  When
            omitted, the module-level ``objective_value`` is used; that name
            is not defined in this file, so callers must either pass
            ``f_eval`` or define it themselves.
        grad_eval: Callable ``grad_eval(x)`` returning the full gradient.
            When omitted, the module-level ``objective_full_gradient`` is
            used (same caveat as above).

    Returns:
        Array of ``T + 1`` objective values: the value at the origin followed
        by the value after each step.
    """
    # Fall back to the legacy global names only when nothing was supplied, so
    # existing callers that rely on those globals keep working.
    if f_eval is None:
        f_eval = objective_value
    if grad_eval is None:
        grad_eval = objective_full_gradient

    x = np.zeros(x0_len)
    losses = [f_eval(x)]
    for _ in range(T):
        x -= stepsize * grad_eval(x)
        losses.append(f_eval(x))
    return np.array(losses)


def newtons_method(x_len, f_eval, grad_eval, hessian_eval, max_iter=1000, tol=1e-6, stepsize=0.5):
    """Minimize an objective with damped Newton steps starting from the origin.

    Each iteration solves ``hessian * d = gradient`` and moves to
    ``x - stepsize * d``.  Iteration stops when the Newton decrement
    ``sqrt(gradient . d)`` -- computed from the gradient and direction at the
    point *before* the step -- is at most ``tol``.

    Args:
        x_len: Dimension of the parameter vector.
        f_eval: Callable ``f_eval(x)`` returning the objective value.
        grad_eval: Callable ``grad_eval(x)`` returning the gradient.
        hessian_eval: Callable ``hessian_eval(x)`` returning the Hessian.
        max_iter: Maximum number of Newton iterations.
        tol: Threshold on the Newton decrement.
        stepsize: Damping factor applied to every Newton step.  The default
            of 0.5 is the value that used to be hard-coded.

    Returns:
        ``(f_eval(x), x)`` at the final iterate, whether or not the tolerance
        was reached (a warning is printed when it was not).

    Raises:
        numpy.linalg.LinAlgError: Propagated from ``np.linalg.solve`` when the
            Hessian is singular.
    """
    x = np.zeros(x_len)
    for t in range(max_iter):
        gradient = grad_eval(x)
        hessian = hessian_eval(x)
        update_direction = np.linalg.solve(hessian, gradient)
        x -= stepsize * update_direction
        newtons_decrement = np.sqrt(np.dot(gradient, update_direction))
        if newtons_decrement <= tol:
            print("Newton's method converged after {:d} iterations".format(t+1))
            return f_eval(x), x
    print("Warning: Newton's method failed to converge")
    return f_eval(x), x

##################################################################################################################

# ---- Problem size -----------------------------------------------------------
M = 25      # number of machines (one per even/odd digit pair)
K = 100     # local steps per round (Local SGD) / minibatch size per machine (Minibatch SGD)
R = 100     # number of communication rounds
dim = 100   # number of PCA components kept from the MNIST images

# ---- Experiment protocol ----------------------------------------------------
loss_freq = 5       # record the loss every loss_freq rounds
n_reps = 4          # independent repetitions averaged for each stepsize
n_stepsizes = 10    # number of points in each stepsize grid
n_proxparams = 10   # not referenced by the code visible in this file

# When False, the optimization runs are skipped and only previously saved
# results are loaded and plotted.
DO_COMPUTE = True

# Sub-directory of ./data holding cached features and results for this setup.
path = 'mnist_eo_m{M:d}_k{K:d}_r{R:d}_d={dim:d}'.format(M=M, K=K, R=R, dim=dim)

# Values of p passed to load_MNIST2 (fraction of pooled data on each machine).
Ps = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# Maps each p to the sigma_diff value measured for it.
sigma_diff_p = dict()

if DO_COMPUTE:
    def _sweep_stepsizes(algorithm, stepsizes, x_len, f_eval, grad_eval, fstar):
        """Average suboptimality curves of ``algorithm`` over a stepsize grid.

        Returns an array of shape ``(R // loss_freq, len(stepsizes))`` whose
        column i is the mean over ``n_reps`` runs of ``loss - fstar`` for
        ``stepsizes[i]``.  A run that reports divergence adds a flat penalty
        of 100 to the whole column instead.
        """
        results = np.zeros((R//loss_freq, len(stepsizes)))
        for i, stepsize in enumerate(stepsizes):
            print('Stepsize {:.5f}:  {:d}/{:d}'.format(stepsize, i+1, len(stepsizes)))
            for _ in range(n_reps):
                l, success = algorithm(x_len, M, K, R, stepsize, loss_freq, f_eval, grad_eval)
                if success == 'converged':
                    results[:,i] += (np.asarray(l) - fstar) / n_reps
                else:
                    results[:,i] += 100
        return results

    for p in Ps:
        print('\n\nDOING p = {:f}'.format(p))
        features, labels = load_MNIST2(p,dim,path)
        x_len = features.shape[2]
        # Pooled view of all machines' data, built once instead of on every
        # objective / gradient / Hessian evaluation.
        all_features = features.reshape(-1,x_len)
        all_labels = labels.reshape(-1)

        def f_eval(x):
            """Global objective: logistic loss over all machines' data."""
            return logistic_loss(x, all_features, all_labels)

        def grad_eval(x, minibatch_size, m):
            """Stochastic gradient on machine m (indices drawn with replacement)."""
            idxs = np.random.randint(0,features[m].shape[0], minibatch_size)
            return logistic_loss_gradient(x, features[m, idxs, :], labels[m, idxs])

        def full_grad_eval(x):
            """Exact gradient of the global objective."""
            return logistic_loss_gradient(x, all_features, all_labels)

        def hessian_eval(x):
            """Exact Hessian of the global objective."""
            return logistic_loss_hessian(x, all_features, all_labels)

        fstar, xstar = newtons_method(x_len, f_eval, full_grad_eval, hessian_eval)

        # sigma_diff = (1/M) * sum_m ||grad F_m(xstar)||^2, using the EXACT
        # local gradient on each machine.  Going through grad_eval would
        # resample the machine's data with replacement and add sampling noise
        # to the estimate (and consume random numbers before the sweeps).
        sigma_diff = 0.
        for m in range(M):
            nrm_nabla_Fm_star = np.linalg.norm(logistic_loss_gradient(xstar, features[m], labels[m]))
            sigma_diff += nrm_nabla_Fm_star**2 / M
        sigma_diff_p[p] = sigma_diff

        print('Fstar = {:.6f}'.format(fstar))
        print('sigma_diff = {:.5f}'.format(sigma_diff_p[p]))

        # Log-spaced stepsize grids for Minibatch SGD (lg) and Local SGD (lc).
        lg_stepsizes = [np.exp(exponent) for exponent in np.linspace(-6,0,n_stepsizes)]
        lc_stepsizes = [np.exp(exponent) for exponent in np.linspace(-8,-1,n_stepsizes)]

        print('Doing Minibatch SGD...')
        large_results = _sweep_stepsizes(minibatch_sgd, lg_stepsizes, x_len, f_eval, grad_eval, fstar)

        print('Doing Local SGD...')
        local_results = _sweep_stepsizes(local_sgd, lc_stepsizes, x_len, f_eval, grad_eval, fstar)

        # Best stepsize is chosen separately at every recorded round.
        local_l = np.min(local_results, axis=1)
        large_l = np.min(large_results, axis=1)

        # exist_ok=True also covers a missing ./data parent directory.
        os.makedirs('./data/' + path, exist_ok=True)

        np.save('./data/' + path + '/p{:.2f}_local.npy'.format(p), local_l)
        np.save('./data/' + path + '/p{:.2f}_large.npy'.format(p), large_l)

    sigma_diffs = [sigma_diff_p[p] for p in Ps]
    np.save('./data/' + path + '/sigma_diffs.npy', sigma_diffs)

# Load the saved results and plot final error against the measured sigma_diff.
sigma_diffs = np.load('./data/' + path + '/sigma_diffs.npy')
# times = [2,4,6]
# Each entry t selects the loss recorded after t * loss_freq rounds.
times = [5,10]
local_ls = np.zeros((len(sigma_diffs), len(times)))
large_ls = np.zeros((len(sigma_diffs), len(times)))

for i,p in enumerate(Ps):
    local_p = np.load('./data/' + path + '/p{:.2f}_local.npy'.format(p))
    large_p = np.load('./data/' + path + '/p{:.2f}_large.npy'.format(p))
    for j,t in enumerate(times):
        # Entry k of the saved curves is the loss after (k + 1) * loss_freq
        # rounds, so the value after t * loss_freq rounds lives at index
        # t - 1.  Indexing with t reported a curve one checkpoint later than
        # the plot label claims.
        local_ls[i,j] = local_p[t - 1]
        large_ls[i,j] = large_p[t - 1]


fig = plt.figure()
ax = fig.add_subplot(111)
for j,t in enumerate(times):
    # Label with the actual round count instead of a hard-coded factor of 5.
    ax.plot(sigma_diffs, local_ls[:,j], label='Local SGD after {:d} rounds'.format(loss_freq*t))
    ax.plot(sigma_diffs, large_ls[:,j], label='Minibatch SGD after {:d} rounds'.format(loss_freq*t))

# Only needed if the legend below is re-enabled.
handles,labels = ax.get_legend_handles_labels()
ax.set_xlabel(r'$\zeta$')
ax.set_ylabel('Error')
ax.set_title('K={:d}'.format(K))
# ax.legend(handles, labels, loc='upper left')
# savefig does not create missing directories.
os.makedirs('plots', exist_ok=True)
plt.savefig('plots/final_' + path + 'loss_vs_sigma_diff.png', dpi=400)
# plt.show()
quit()






# Rs = list(range(loss_freq, R+1, loss_freq))
# fig = plt.figure()
# ax = fig.add_subplot(111)
# for i,p in enumerate(Ps):
#     if p in [0.2, 0.6, 0.8]:
#         continue
#     local_p = np.load('data/mnist_m9_k50_r100/p{:.2f}_local.npy'.format(p))
#     large_p = np.load('data/mnist_m9_k50_r100/p{:.2f}_large.npy'.format(p))
#     ax.plot(Rs, local_p, label='Local SGD sigma_diff={:.3f}'.format(sigma_diffs[i]))
#     ax.plot(Rs, large_p, label='Minibatch SGD sigma_diff={:.3f}'.format(sigma_diffs[i]))
# handles,labels = ax.get_legend_handles_labels()
# ax.set_xlabel('Round of Communication')
# ax.set_ylabel('Objective Value')
# ax.set_title('M={:d}, K={:d}, R={:d}'.format(M,K,R))
# ax.legend(handles, labels, loc='upper right')
# plt.savefig('plots/' + path + 'loss_vs_r.png')
# plt.show()













