from utils import *

def al2sgd_no_log(g, d, T, m, rho, lam, number_of_steps=None, L=1, mu=1e-4, X0=None, X_opt=None, accuracy=1e-5):
    """Run the accelerated al2sgd iteration without recording a history.

    Every iteration forms ``X`` as a combination of the iterates ``Z``, ``W``
    and ``Y``. With probability ``1 - rho`` each of the ``T`` rows then takes
    a local variance-reduced stochastic step; otherwise an aggregation step
    built from ``psi_grad`` is taken. Independently, with probability ``rho``
    the control point ``W`` is replaced by ``Y`` and the quantities stored for
    it (``w_bar`` and ``gW``) are recomputed.

    Args:
        g: callable ``g(x, ind)`` giving a gradient at ``x`` for the global
            sample index ``ind``; row ``t`` uses indices ``t*m .. t*m + m - 1``.
        d: number of columns of the iterates.
        T: number of rows of the iterates (one row per device/task).
        m: number of samples per row.
        rho: probability of an aggregation step and of a control-point refresh.
        lam: penalty weight; ``lam / T`` is what is handed to ``psi_grad``.
        number_of_steps: if given, run exactly this many iterations (none if
            it is not positive) and ignore ``X_opt`` for stopping.
        L, mu: constants used to derive the step size and momentum parameters.
        X0: starting point for all four iterates; zeros when omitted.
        X_opt: reference solution; used for the stopping test when
            ``number_of_steps`` is None and for the progress printout.
        accuracy: the run stops once the squared distance to ``X_opt`` drops
            below ``accuracy`` times the initial squared distance.

    Returns:
        Tuple ``(aggregations, loc_steps, X)``: the number of aggregation
        steps that followed at least one local step, the number of local-step
        iterations, and the last ``X``.
    """
    assert (number_of_steps is not None) or (X_opt is not None)

    Lf = (lam + L) / T
    cL = max(L / T / (1 - rho), lam / rho / T)

    eta = 1 / 4 / max(Lf, cL)
    theta2 = cL / 2 / max(Lf, cL)
    theta1 = min(1 / 2, np.sqrt(eta * mu / T * max(1 / 2, theta2 / rho)))
    gamma = 1 / (max(2 * mu / T, 4 * theta1 / eta))
    beta = 1 - gamma * mu / T

    if X0 is None:
        X = np.zeros((T, d))
        Y = np.zeros((T, d))
        Z = np.zeros((T, d))
        W = np.zeros((T, d))
    else:
        # `1 * X0` makes independent copies so the caller's array is untouched
        X = 1 * X0
        Y = 1 * X0
        Z = 1 * X0
        W = 1 * X0

    aggregations = 0
    loc_steps = 0

    # Both start at zero regardless of X0 and are only filled in by the first
    # control-point refresh below.
    gW = np.zeros((T, d))  # to store \nabla f_i(w_i)
    w_bar = np.zeros((1, d))

    locals_different = False
    if X_opt is not None:
        end_distance = np.linalg.norm(X-X_opt)**2 * accuracy

    # A non-positive step budget means there is nothing to run; the previous
    # `== 0` test after decrementing could never fire in that case.
    converged = (number_of_steps is not None) and (number_of_steps <= 0)

    k = 0
    while not converged:
        k += 1
        if (k % 1000 == 0) and (X_opt is not None):
            print("al2sgd dist: {}".format(np.linalg.norm(X - X_opt)**2))

        X = theta1 * Z + theta2 * W + (1 - theta1 - theta2) * Y

        if np.random.rand() > rho:  # don't aggregate
            loc_steps += 1
            locals_different = True
            for t in range(T):
                ind = t * m + np.random.randint(m)
                gx = g(X[t, :], ind)
                gw = g(W[t, :], ind)

                Y[t, :] = X[t, :] - eta * (
                            1 / T / (1 - rho) * (gx - gw) + 1 / T * gW[t, :] + psi_grad(W, w_bar, lam/T, t))
        else:  # aggregate
            # only count an aggregation if a local step happened since the last one
            aggregations += locals_different
            locals_different = False

            x_bar = x_mean(X)
            for t in range(T):
                Y[t, :] = X[t, :] - eta * (1 / rho * psi_grad(X, x_bar, lam/T, t) - (1 / rho - 1) * psi_grad(W, w_bar, lam/T, t) + 1 / T * gW[t, :])

        Z = beta * Z + (1 - beta) * X + gamma / eta * (Y - X)

        if np.random.rand() < rho:  # update control variables
            W = 1 * Y
            w_bar = x_mean(W)
            for t in range(T):
                # The stored full gradient must be taken at the new control
                # point W (it is paired with `gw = g(W[t, :], ind)` above),
                # not at X.
                gW[t, :] = np.mean([g(W[t, :], t * m + j) for j in range(m)], axis=0)

        if number_of_steps is not None:
            number_of_steps -= 1
            converged = (number_of_steps <= 0)
        else:
            converged = (np.linalg.norm(X - X_opt)**2 < end_distance)

    return aggregations, loc_steps, X


def l2sgd_no_log(g, alpha, d, T, m, pagg, omega, X0=None, X_opt=None, accuracy=1e-5, method="SAGA", psvrg=0, vr="full"):
    """Run the l2sgd iteration with optional variance reduction, without a history.

    Every iteration is, with probability ``pagg``, an aggregation step built
    from ``psi_grad`` and otherwise a local stochastic gradient step on each
    of the ``T`` rows. The run stops once the squared distance to ``X_opt``
    falls below ``accuracy`` times the initial squared distance.

    Args:
        g: callable ``g(x, ind)`` giving a gradient at ``x`` for the global
            sample index ``ind``; row ``t`` uses indices ``t*m .. t*m + m - 1``.
        alpha: step size.
        d: number of columns of the iterate.
        T: number of rows of the iterate (one row per device/task).
        m: number of samples per row.
        pagg: probability of an aggregation step.
        omega: value passed to ``psi_grad`` as its third argument.
        X0: starting point; zeros when omitted.
        X_opt: reference solution. Required, because it is the only stopping
            criterion of this routine.
        accuracy: relative squared-distance tolerance.
        method: ``"SAGA"`` updates the stored gradient of the sampled index on
            every local step; any other value refreshes all ``m`` stored
            gradients of a row with probability ``psvrg``.
        psvrg: refresh probability used when ``method`` is not ``"SAGA"``.
        vr: ``"full"`` updates both gradient tables, ``"partial"`` updates
            only ``Psi``, anything else leaves both tables at zero.

    Returns:
        Tuple ``(aggregations, loc_steps, X)``.

    Raises:
        ValueError: if ``X_opt`` is None.
    """
    if X_opt is None:
        # Without a reference point the loop below could never terminate and
        # the progress printout would fail on `X - None`.
        raise ValueError("l2sgd_no_log requires X_opt: it is the only stopping criterion")

    n = int(T*m)
    J = np.zeros((n, d))         # stored per-sample gradients
    mean_J = np.zeros((T, d))    # per-row average of the stored gradients
    Psi = np.zeros((T, d))       # stored psi_grad values
    X = np.zeros((T, d)) if X0 is None else 1*X0

    aggregations = 0
    loc_steps = 0
    locals_different = False
    end_distance = np.linalg.norm(X-X_opt)**2 * accuracy

    converged = False
    iteration = 0
    while not converged:
        iteration += 1

        if iteration % 10000 == 0:
            print("l2sgd dist: {}".format(np.linalg.norm(X - X_opt)**2))

        if np.random.rand() < pagg:  # aggregation
            # only count an aggregation if a local step happened since the last one
            aggregations += locals_different
            locals_different = False

            x_bar = x_mean(X)
            for t in range(T):
                # compute gradient & take step
                grad = psi_grad(X, x_bar, omega, t)
                X[t, :] = X[t, :] - alpha*(grad/pagg - (1/pagg-1)*Psi[t, :] + mean_J[t, :]/T)

                if vr == "full" or vr == "partial":
                    # update Jacobian table
                    Psi[t, :] = grad

        else:
            # no aggregation
            # NOTE(review): counted as 2 per iteration here, whereas
            # al2sgd_no_log counts 1 per local iteration — confirm intended.
            loc_steps += 2
            locals_different = True
            for t in range(T):
                i = np.random.randint(m)
                ind = t*m + i
                grad = g(X[t, :], ind)

                X[t, :] = X[t, :] - alpha*(m/n/(1-pagg)*(grad - J[ind, :]) + mean_J[t, :]/T + Psi[t, :])

                # update Jacobian table
                if vr == "full":
                    if method == "SAGA":
                        mean_J[t, :] += (grad - J[ind, :])/m
                        J[ind, :] = grad
                    elif np.random.rand() < psvrg:
                        # refresh every stored gradient of row t at the new point
                        for ii in range(m):
                            full_grad = g(X[t, :], t*m+ii)
                            mean_J[t, :] += (full_grad - J[t*m+ii, :])/m
                            J[t*m+ii, :] = full_grad

        converged = (np.linalg.norm(X - X_opt)**2 < end_distance)

    return aggregations, loc_steps, X


def apgd_no_log(g, d, T, m, lamda, stochastic=True, X0=None, L=1, mu=1e-4, X_opt=None, accuracy=1e-5, proximal=False):
    """Run the outer apgd loop with an inexact (or exact) inner solve per row.

    Every outer iteration computes ``y_bar`` from the extrapolated point ``Y``
    and then updates each of the ``T`` blocks of ``X`` by one of three inner
    solvers:

    * ``proximal=True``: a single call ``g(y_bar, t)``;
    * ``stochastic=False``: ``agd`` on the full local gradient plus
      ``lamda * (x - y_bar)``;
    * otherwise: ``al2sgd_no_log`` with ``T=1`` and ``lam=0`` on the same
      regularised local problem.

    The number of inner steps grows linearly with the outer counter ``k``.
    The run stops once the squared distance to ``X_opt`` falls below
    ``accuracy`` times the initial squared distance.

    Args:
        g: gradient oracle ``g(x, ind)``; in proximal mode it is called as
            ``g(y_bar, t)`` and its result is written directly into ``X``.
        d: dimension of one block.
        T: number of blocks.
        m: number of samples per block.
        lamda: weight of the ``(x - y_bar)`` term.
        stochastic: choose the stochastic inner solver (ignored when
            ``proximal`` is true).
        X0: starting point; a ``(T, d)`` zero array when omitted.
        L, mu: constants used for the inner step counts and inner solvers.
        X_opt: reference solution. Required, because it is the only stopping
            criterion of this routine.
        accuracy: relative squared-distance tolerance.
        proximal: use the single-call update described above.

    Returns:
        Tuple ``(k, loc_steps // T, X)``: outer iterations, inner steps per
        block, and the last iterate.

    Raises:
        ValueError: if ``X_opt`` is None.
    """
    if X_opt is None:
        # Without a reference point the loop below could never terminate and
        # the progress printout would fail on `X - None`.
        raise ValueError("apgd_no_log requires X_opt: it is the only stopping criterion")

    # inner step count for agd: ceil(magic_agd + c_agd * k)
    magic_agd = np.sqrt((L+lamda)/(mu+lamda))*5
    c_agd = np.sqrt(mu * (L + lamda) / lamda / (mu + lamda))

    # inner step count for the stochastic solver: ceil(magic_kat + c_kat * k)
    magic_kat = np.sqrt(m*(L+lamda)/(mu+lamda))
    c_kat = np.sqrt(m*(L+lamda)/(mu+lamda))* np.sqrt(mu / lamda)

    c_momentum = (np.sqrt(lamda) - np.sqrt(mu)) / (np.sqrt(lamda) + np.sqrt(mu))

    if X0 is None:
        X = np.zeros((T, d))
    else:
        X = 1 * X0
    Y = 1*X

    loc_steps = 0
    end_distance = np.linalg.norm(X - X_opt) ** 2 * accuracy

    converged = False
    k = 0
    while not converged:
        k += 1
        if k % 50 == 0 and not proximal:
            print("apgd dist: {}".format(np.linalg.norm(X - X_opt)**2))
        y_bar = x_mean(Y, prox=proximal, d=d, T=T)
        X_old = 1 * X
        for t in range(T):
            if proximal:
                loc_steps += 1
                # assumes X is a flat vector of length T*d in proximal mode,
                # i.e. X0 was supplied in that layout — TODO confirm
                X[(t*d):((t+1)*d)] = g(y_bar, t)
            elif not stochastic:
                # The closure reads `t` and `y_bar`; it is consumed by agd
                # within this same iteration, so late binding is harmless.
                def grad_det(x):
                    grad = 0*g(x, t*m)
                    for i in range(m):
                        grad += g(x, t*m+i)/m
                    return grad + lamda*(x-y_bar)

                steps_agd = int(np.ceil(magic_agd + c_agd*k))
                loc_steps += steps_agd
                X[t, :] = agd(grad_det, y_bar, steps_agd, L=L+lamda, mu=mu+lamda)
            else:
                steps_katyusha = int(np.ceil(magic_kat + c_kat * k))
                loc_steps += steps_katyusha

                # Same remark on late binding as for grad_det above.
                def stoch_grad(x, i):
                    return g(x, t*m+i) + lamda*(x-y_bar)
                _, _, X[t, :] = al2sgd_no_log(stoch_grad, d, 1, m, min(1 / m, 0.99), 0, number_of_steps=steps_katyusha, X0=np.asmatrix(y_bar), L=L + lamda, mu=mu + lamda, X_opt=X_opt, accuracy=0)

        Y = X + c_momentum * (X - X_old)

        converged = (np.linalg.norm(X - X_opt)**2 < end_distance)

    return k, int(loc_steps//T), X

def apgd2_no_log(g, d, T, m, lamda, X0=None, L=1, mu=1e-4, X_opt=None, accuracy=1e-5, proximal=False):
    """Run the apgd2 iteration: full local gradient step, averaging step, momentum.

    Every iteration takes a step of size ``1 / L`` along the average of the
    ``m`` gradients of each row at ``Y``, combines the result with its
    ``x_mean`` as ``(L * Y_tilda + lamda * y_bar) / (L + lamda)``, and then
    extrapolates with a constant momentum coefficient. The run stops once the
    squared distance to ``X_opt`` falls below ``accuracy`` times the initial
    squared distance.

    Args:
        g: callable ``g(x, ind)`` giving a gradient at ``x`` for the global
            sample index ``ind``; row ``t`` uses indices ``t*m .. t*m + m - 1``.
        d: number of columns of the iterate.
        T: number of rows of the iterate.
        m: number of samples per row.
        lamda: weight of the averaged point in the combination step.
        X0: starting point; zeros when omitted.
        L, mu: constants defining the step size and the momentum coefficient.
        X_opt: reference solution. Required, because it is the only stopping
            criterion of this routine.
        accuracy: relative squared-distance tolerance.
        proximal: only affects the accounting: each iteration adds 1 to
            ``loc_steps`` when true and ``m`` when false.

    Returns:
        Tuple ``(aggregations, loc_steps, X)``.

    Raises:
        ValueError: if ``X_opt`` is None.
    """
    if X_opt is None:
        # Without a reference point `converged` would never be set and the
        # loop below would spin forever.
        raise ValueError("apgd2_no_log requires X_opt: it is the only stopping criterion")

    c_momentum = (np.sqrt(L / mu) - 1) / (np.sqrt(L / mu) + 1)

    if X0 is None:
        X = np.zeros((T, d))
    else:
        X = 1 * X0
    Y = 1 * X
    Y_tilda = 1 * X

    aggregations = 0
    loc_steps = 0
    end_distance = np.linalg.norm(X-X_opt)**2 * accuracy

    converged = False
    while not converged:
        aggregations += 1
        loc_steps += 1+(m-1)*(not proximal)

        X_old = 1 * X
        for t in range(T):
            # average of the m sample gradients of row t at the extrapolated point
            grad = np.zeros(d)
            for u in range(m):
                grad += g(Y[t, :], t*m + u)/m
            Y_tilda[t, :] = Y[t, :] - 1 / L * grad
        y_bar = x_mean(Y_tilda)
        X = (L * Y_tilda + lamda * y_bar) / (L + lamda)
        Y = X + c_momentum * (X - X_old)

        converged = (np.linalg.norm(X - X_opt)**2 < end_distance)

    return aggregations, loc_steps, X

def agd(grad, x, no_steps, L=1, mu=1e-4):
    """Run ``no_steps`` gradient steps with a constant momentum coefficient.

    Args:
        grad: callable returning the gradient at a given point.
        x: starting point.
        no_steps: number of iterations to perform.
        L, mu: constants; the step size is ``1 / L`` and the momentum
            coefficient is ``(sqrt(L/mu) - 1) / (sqrt(L/mu) + 1)``.

    Returns:
        The last gradient-step iterate (``x`` itself when ``no_steps`` is 0).
    """
    root_ratio = np.sqrt(L/mu)
    momentum = (root_ratio - 1) / (root_ratio + 1)
    step = 1 / L

    # extrapolated point at which the gradient is evaluated
    lookahead = 1 * x
    for _ in range(no_steps):
        previous, x = x, lookahead - step * grad(lookahead)
        lookahead = x + momentum * (x - previous)
    return x