from utils import *
from algorithms_no_log import al2sgd_no_log, agd

def al2sgd(total_loss, g, d, T, m, rho, lam, track_agg=False, X0=None, L=1, mu=1e-4, no_agg=1e+3):
    """Run the accelerated, variance-reduced local method with aggregation logging.

    Each iteration forms ``X`` as a combination of ``Z``, ``W`` and ``Y``. With
    probability ``1 - rho`` it takes a local stochastic step on every row of
    ``X``; otherwise it takes an aggregation step that uses the row mean of
    ``X``. Independently, with probability ``rho``, the control variables
    ``W``, ``w_bar`` and ``gW`` are refreshed.

    Parameters
    ----------
    total_loss : callable
        Called as ``total_loss(X)`` on the ``(T, d)`` iterate; the result is
        stored in ``agg_fvals``.
    g : callable
        Called as ``g(x, ind)`` with one row of the iterate and a sample index
        ``t * m + j`` (``0 <= j < m``). Presumably returns the gradient of the
        ``ind``-th sample loss at ``x`` -- confirm against the caller.
    d, T, m : int
        Row length, number of rows, and number of sample indices per row.
    rho : float
        Probability of an aggregation step and of a control-variable refresh.
        The code divides by both ``rho`` and ``1 - rho``, so it must lie
        strictly between 0 and 1.
    lam : float
        Penalty weight; passed to ``psi_grad`` as ``lam / T``.
    track_agg : bool
        When true, ``total_loss(X)``, the row mean of ``X`` and the running
        local-step count are recorded at every aggregation step.
    X0 : array or None
        Starting point for ``X``, ``Y``, ``Z`` and ``W`` (copied); zeros if None.
    L, mu : float
        Constants that determine the step sizes (presumably smoothness and
        strong-convexity constants -- confirm).
    no_agg : number
        The loop stops once the number of aggregation steps plus one (for the
        initial log entry) reaches ``no_agg``.

    Returns
    -------
    X, (aggregations, agg_fvals, agg_x, agg_steps)
        ``X`` is the last combination point. ``aggregations`` counts the
        aggregation steps that were preceded by at least one local step.
        ``agg_x`` starts with the initial matrix followed by its row mean, so
        it is one entry longer than ``agg_fvals`` and ``agg_steps``.
    """
    # Step-size constants.
    Lf = (lam + L) / T
    cL = max(L / T / (1 - rho), lam / rho / T)

    eta = 1 / 4 / max(Lf, cL)
    theta2 = cL / 2 / max(Lf, cL)
    theta1 = min(1 / 2, np.sqrt(eta * mu / T * max(1 / 2, theta2 / rho)))
    gamma = 1 / (max(2 * mu / T, 4 * theta1 / eta))
    beta = 1 - gamma * mu / T

    if X0 is None:
        X = np.zeros((T, d))
        Y = np.zeros((T, d))
        Z = np.zeros((T, d))
        W = np.zeros((T, d))
    else:
        # ``1 * X0`` makes independent copies so the caller's array is untouched.
        X = 1 * X0
        Y = 1 * X0
        Z = 1 * X0
        W = 1 * X0

    aggregations = 0
    agg_fvals = []
    agg_x = [X]
    agg_steps = [1]
    loc_steps = 0

    agg_fvals.append(total_loss(X))
    agg_x.append(x_mean(X))

    gW = np.zeros((T, d))  # per-row averaged gradient used as control variate
    w_bar = np.zeros((1, d))

    locals_different = False

    # Count aggregation rounds independently of logging. The previous stop
    # test was ``len(agg_fvals) < no_agg``, but ``agg_fvals`` only grows when
    # ``track_agg`` is true, so the default ``track_agg=False`` never returned.
    # Starting at 1 mirrors the initial log entry, which keeps the behaviour
    # with ``track_agg=True`` unchanged.
    agg_rounds = 1

    k = 0
    while agg_rounds < no_agg:
        k += 1
        if (k % 1000 == 0):
            print("al2lgd dist: {}".format(k))

        X = theta1 * Z + theta2 * W + (1 - theta1 - theta2) * Y

        if np.random.rand() > rho:  # don't aggregate
            loc_steps += 1
            locals_different = True
            for t in range(T):
                # The same sample index is used at X and at W, so the
                # difference of the two gradients forms the correction term.
                ind = t * m + np.random.randint(m)
                gx = g(X[t, :], ind)
                gw = g(W[t, :], ind)

                Y[t, :] = X[t, :] - eta * (
                    1 / T / (1 - rho) * (gx - gw)
                    + 1 / T * gW[t, :]
                    + psi_grad(W, w_bar, lam / T, t))
        else:  # aggregate
            # Only count an aggregation if a local step happened since the last one.
            aggregations += locals_different
            locals_different = False
            agg_rounds += 1

            x_bar = x_mean(X)
            for t in range(T):
                Y[t, :] = X[t, :] - eta * (
                    1 / rho * psi_grad(X, x_bar, lam / T, t)
                    - (1 / rho - 1) * psi_grad(W, w_bar, lam / T, t)
                    + 1 / T * gW[t, :])

            if track_agg:
                agg_fvals.append(total_loss(X))
                agg_x.append(x_bar)
                agg_steps.append(loc_steps)

        Z = beta * Z + (1 - beta) * X + gamma / eta * (Y - X)

        if np.random.rand() < rho:  # update control variables
            W = 1 * Y
            w_bar = x_mean(W)
            # NOTE(review): the gradient table is refreshed at X although W was
            # just set to Y -- confirm this evaluation point is intended.
            for t in range(T):
                gW[t, :] = np.mean([g(X[t, :], t * m + ind) for ind in range(m)], axis=0)

    return X, (aggregations, agg_fvals, agg_x, agg_steps)


def l2sgd(total_loss, g, alpha, d, T, m, pagg, omega, method="SAGA", psvrg=0, track_agg=False, vr="full", X0= None, no_agg=1e+3):
    """Run the loopless local method with optional variance reduction.

    With probability ``pagg`` an iteration is an aggregation step that uses
    the row mean of ``X``; otherwise every row takes one local stochastic
    step. Gradient tables ``J`` (per sample), ``mean_J`` (per row) and ``Psi``
    (last penalty gradient per row) act as control variates.

    Parameters
    ----------
    total_loss : callable
        Called as ``total_loss(X)`` on the ``(T, d)`` iterate; the result is
        stored in ``agg_fvals``.
    g : callable
        Called as ``g(x, ind)`` with one row of ``X`` and a sample index
        ``t * m + i``. Presumably returns the gradient of the ``ind``-th
        sample loss at ``x`` -- confirm against the caller.
    alpha : float
        Step size.
    d, T, m : int
        Row length, number of rows, and number of sample indices per row.
    pagg : float
        Probability of an aggregation step. The code divides by ``pagg`` and
        ``1 - pagg``.
    omega : float
        Passed to ``psi_grad`` as the penalty weight.
    method : str
        ``"SAGA"`` refreshes the single sampled table entry on every local
        step; any other value refreshes the whole row block with probability
        ``psvrg``.
    psvrg : float
        Refresh probability used when ``method`` is not ``"SAGA"``.
    track_agg : bool
        When true, ``total_loss(X)``, the pre-step row mean and the running
        local-step count are recorded at every aggregation step.
    vr : str
        ``"full"`` updates both ``Psi`` and the ``J`` tables, ``"partial"``
        updates only ``Psi``, and any other value updates neither.
    X0 : array or None
        Starting point (copied); zeros if None.
    no_agg : number
        The loop stops once the number of aggregation steps plus one (for the
        initial log entry) reaches ``no_agg``.

    Returns
    -------
    X, (aggregations, agg_fvals, agg_x, agg_steps)
        ``aggregations`` counts the aggregation steps that were preceded by
        at least one local step. ``agg_x`` starts with a snapshot of the
        initial matrix followed by its row mean, so it is one entry longer
        than ``agg_fvals`` and ``agg_steps``.
    """
    n = int(T*m)
    J = np.zeros((n, d))
    mean_J = np.zeros((T, d))
    Psi = np.zeros((T, d))
    if X0 is None:
        X = np.zeros((T, d))
    else:
        X = 1*X0
    scale = total_loss(X)

    aggregations = 0
    loc_steps = 0

    agg_fvals = [scale]
    # X is updated in place below, so store a copy; otherwise the "initial"
    # entry would silently track the final iterate.
    agg_x = [1 * X, x_mean(X)]
    agg_steps = [1]

    locals_different = False

    # Count aggregation rounds independently of logging. The previous stop
    # test was ``len(agg_fvals) < no_agg``, which never becomes false when
    # ``track_agg`` is False because nothing is appended. Starting at 1
    # mirrors the initial log entry, which keeps the behaviour with
    # ``track_agg=True`` unchanged.
    agg_rounds = 1

    k = 0
    while agg_rounds < no_agg:
        k += 1
        if k % 10000 == 0:
            print("l2sgd dist: {}".format(k))

        if np.random.rand() < pagg:  # aggregation
            # Only count an aggregation if a local step happened since the last one.
            aggregations += locals_different
            locals_different = False
            agg_rounds += 1

            # The mean is taken before any row is stepped.
            x_bar = x_mean(X)
            for t in range(T):
                # compute gradient & take step
                grad = psi_grad(X, x_bar, omega, t)
                X[t, :] = X[t, :] - alpha*(grad/pagg - (1/pagg-1)*Psi[t, :] + mean_J[t, :]/T)

                if vr in ("full", "partial"):
                    # remember the penalty gradient as the next control variate
                    Psi[t, :] = grad

            if track_agg:
                agg_fvals.append(total_loss(X))
                agg_x.append(x_bar)
                agg_steps.append(loc_steps)

        else:
            # no aggregation; each such iteration is counted as two local
            # steps (rationale not visible in this file)
            loc_steps += 2
            for t in range(T):
                locals_different = True

                i = np.random.randint(m)
                ind = t*m + i
                grad = g(X[t, :], ind)

                X[t, :] = X[t, :] - alpha*(m/n/(1-pagg)*(grad - J[ind, :]) + mean_J[t, :]/T + Psi[t, :])

                # update Jacobian table
                if vr == "full":
                    if method == "SAGA":
                        # replace the sampled entry and keep the row mean in sync
                        mean_J[t, :] += (grad - J[ind, :])/m
                        J[ind, :] = grad
                    elif np.random.rand() < psvrg:
                        # refresh every entry of this row block at the new point
                        for ii in range(m):
                            grad = g(X[t, :], t*m+ii)
                            mean_J[t, :] += (grad - J[t*m+ii, :])/m
                            J[t*m+ii, :] = grad

    return X, (aggregations, agg_fvals, agg_x, agg_steps)


def apgd(total_loss, g, d, T, m, lamda, stochastic=True, X0=None, L=1, mu=1e-4, no_agg=1e+3):
    """Outer momentum loop whose per-row subproblems are solved by an inner solver.

    On every outer iteration each row ``t`` is replaced by the output of an
    inner solver applied to the row's sample gradients plus
    ``lamda * (x - y_bar)``, where ``y_bar`` is the row mean of the
    extrapolated point ``Y``. The inner solver is ``al2sgd_no_log`` when
    ``stochastic`` is true and ``agd`` otherwise; both are started from
    ``y_bar``, and their step budget grows linearly with the outer counter.

    Parameters
    ----------
    total_loss : callable
        Called as ``total_loss(X)``; the values are collected in ``F``.
    g : callable
        Called as ``g(x, ind)`` with sample index ``t * m + i``.
    d, T, m : int
        Row length, number of rows, and number of sample indices per row.
    lamda : float
        Weight of the ``(x - y_bar)`` term; also enters the inner-solver
        constants and the momentum coefficient.
    stochastic : bool
        Chooses the inner solver (see above).
    X0 : array or None
        Starting point (copied); zeros if None.
    L, mu : float
        Constants forwarded to the inner solvers as ``L + lamda`` and
        ``mu + lamda``.
    no_agg : number
        The loop runs until ``F`` holds ``no_agg`` entries.

    Returns
    -------
    F, X, agg_steps
        Logged loss values (initial value first), the final iterate, and the
        cumulative inner-step counter after each outer iteration (with a
        leading 1).
    """
    smooth = L + lamda
    convex = mu + lamda

    # Base budget and per-iteration growth for the deterministic inner solver.
    agd_base = np.sqrt(smooth / convex) * 5
    agd_slope = np.sqrt(mu * smooth / lamda / convex)

    # Same two quantities for the stochastic inner solver.
    kat_base = np.sqrt(m * smooth / convex)
    kat_slope = np.sqrt(m * smooth / convex) * np.sqrt(mu / lamda)

    root_lam = np.sqrt(lamda)
    root_mu = np.sqrt(mu)
    momentum = (root_lam - root_mu) / (root_lam + root_mu)

    X = np.zeros((T, d)) if X0 is None else 1 * X0
    Y = 1 * X

    F = [total_loss(X)]
    agg_steps = [1]
    loc_steps = 0

    k = 0
    while len(F) < no_agg:
        k += 1
        if not k % 50:
            print("apgd dist: {}".format(k))

        y_bar = x_mean(Y)
        previous = 1 * X

        for t in range(T):
            if stochastic:
                budget = int(np.ceil(kat_base + kat_slope * (k + 1)))
                loc_steps += budget

                def row_sample_grad(x, i):
                    # gradient of one sample of row t plus the pull towards y_bar
                    return g(x, t * m + i) + lamda * (x - y_bar)

                _, _, X[t, :] = al2sgd_no_log(
                    row_sample_grad, d, 1, m, 1 / m, 0,
                    number_of_steps=budget,
                    X0=np.asmatrix(y_bar),
                    L=smooth,
                    mu=convex,
                )
            else:
                def row_full_grad(x):
                    # zero vector of matching shape, then average over the m samples
                    total = 0 * g(x, t * m)
                    for i in range(m):
                        total += g(x, t * m + i) / m
                    return total + lamda * (x - y_bar)

                budget = int(np.ceil(agd_base + agd_slope * (k + 1)))
                # every inner step evaluates all m sample gradients
                loc_steps += budget * m
                X[t, :] = agd(row_full_grad, y_bar, budget, L=smooth, mu=convex)

        # extrapolate for the next outer iteration
        Y = X + momentum * (X - previous)

        F.append(total_loss(X))
        agg_steps.append(loc_steps)

    return F, X, agg_steps

def apgd2(total_loss, g, d, T, m, lamda, X0=None, L=1, mu=1e-4, no_agg=1e+3):
    """Momentum method that alternates a full-gradient row step with a mean-mixing step.

    Each iteration moves every row of the extrapolated point ``Y`` by
    ``1 / L`` times the average of its ``m`` sample gradients, then blends the
    result with its row mean using weights ``L`` and ``lamda``, and finally
    extrapolates with a fixed momentum coefficient.

    Parameters
    ----------
    total_loss : callable
        Called as ``total_loss(X)``; the values are collected in ``F``.
    g : callable
        Called as ``g(y, ind)`` with sample index ``t * m + u``.
    d, T, m : int
        Row length, number of rows, and number of sample indices per row.
    lamda : float
        Weight of the row mean in the mixing step.
    X0 : array or None
        Starting point (copied); zeros if None.
    L, mu : float
        ``1 / L`` is the gradient step size; ``L / mu`` sets the momentum.
    no_agg : number
        The loop runs until ``F`` holds ``no_agg`` entries.

    Returns
    -------
    F, X, agg_steps
        Logged loss values (initial value first), the final iterate, and the
        cumulative step counter (``m`` per iteration, with a leading 1).
    """
    root_ratio = np.sqrt(L / mu)
    momentum = (root_ratio - 1) / (root_ratio + 1)

    X = np.zeros((T, d)) if X0 is None else 1 * X0
    Y = 1 * X
    Y_tilda = 1 * X

    F = [total_loss(X)]
    agg_steps = [1]
    loc_steps = 0

    k = 0
    while len(F) < no_agg:
        k += 1
        if not k % 50:
            print("apgd2 iter:{}".format(k))

        loc_steps += m
        previous = 1 * X

        # full-gradient step on every row of the extrapolated point
        for t in range(T):
            y_row = Y[t, :]
            row_grad = 0
            for u in range(m):
                row_grad += g(y_row, t * m + u)
            row_grad /= m
            Y_tilda[t, :] = y_row - (1 / L) * row_grad

        # blend each row with the mean of all rows
        y_bar = x_mean(Y_tilda)
        X = (L * Y_tilda + lamda * y_bar) / (L + lamda)
        Y = X + momentum * (X - previous)

        F.append(total_loss(X))
        agg_steps.append(loc_steps)

    return F, X, agg_steps

