import numpy as np

from numba import njit
from numpy.linalg import norm
from sklearn.covariance import graphical_lasso


@njit
def get_S_Sinv(ZZT, sigma_min=1e-6):
    """Compute the square root of a symmetric PSD matrix and a truncated
    pseudo-inverse of that square root.

    With ZZT = V diag(e) V^T, negative eigenvalues (round-off) are set to 0
    and S = V diag(sqrt(e)) V^T. For the inverse, every sqrt(e_i) that is
    smaller than sigma_min * max_j sqrt(e_j) (a *relative* threshold), or
    that is exactly 0, is discarded: its inverse is set to 0 instead of
    1 / sqrt(e_i). S itself is not modified.

    Parameters:
    ----------
    ZZT: np.array, shape (n_channels, n_channels)
        real symmetric matrix, expected positive semi-definite
    sigma_min: float
        relative threshold applied to the square roots of the eigenvalues

    Output: (np.array, np.array), both of shape (n_channels, n_channels)
        (S, truncated pseudo-inverse of S)
    """
    eigvals, eigvecs = np.linalg.eigh(ZZT)
    eigvals = np.sqrt(np.maximum(0, eigvals))
    threshold = sigma_min * eigvals.max()

    # Fill the inverse element-wise so that 1 / 0 is never evaluated: with an
    # all-zero spectrum the threshold is 0 and a `< threshold` mask alone
    # would leave inf entries that turn S_inv into NaN.
    div_eigvals = np.zeros_like(eigvals)
    n_clipped = 0
    for i in range(eigvals.shape[0]):
        if eigvals[i] > 0 and eigvals[i] >= threshold:
            div_eigvals[i] = 1. / eigvals[i]
        else:
            n_clipped += 1

    # Multi-argument print is supported by numba (str % int is not).
    print('Number of eigvals clipped:', n_clipped)
    eigvals = np.expand_dims(eigvals, axis=1)
    div_eigvals = np.expand_dims(div_eigvals, axis=1)
    return eigvecs @ (eigvals * eigvecs.T), eigvecs @ (div_eigvals * eigvecs.T)


def get_S_inv_S_from_R(R, sigma_min):
    """Build S, the clipped square root of the empirical covariance of R,
    together with its inverse.

    Parameters:
    ----------
    R: np.array, shape (n_epochs, n_channels, n_times)
        get_emp_cov requires a 3D array
    sigma_min: float
        lower bound on the eigenvalues of S, forwarded to get_S_inv_S

    Output:
    -------
    (np.array, np.array), each of shape (n_channels, n_channels)
        (inverse of S, S)
    """
    covariance = get_emp_cov(R)
    # get_S_inv_S returns (trace of S, S_inv, S); the trace is not needed.
    S_inv, S = get_S_inv_S(covariance, sigma_min)[1:]
    return S_inv, S


@njit
def get_S_inv_S(ZZT, sigma_min):
    """Square root of ZZT with its spectrum clipped from below.

    With ZZT = V diag(e) V^T, build
    S = V diag(max(sqrt(max(e, 0)), sigma_min)) V^T.
    A warning is printed when at least one square-rooted eigenvalue lies
    below sigma_min.

    Parameters:
    ----------
    ZZT: np.array, shape (n_channels, n_channels)
        real symmetric matrix, expected positive semi-definite

    sigma_min: float
        sigma_min > 0, lower bound on the eigenvalues of S = sqrt(ZZT)

    Output:
    -------
    (float, np.array, np.array), matrices of shape (n_channels, n_channels)
        (trace of S, inverse of S, S)
    """
    eigvals, eigvecs = np.linalg.eigh(ZZT)

    # Clamp round-off negative eigenvalues *before* the square root.
    # sqrt(negative) is NaN and NaN < sigma_min is False, so without the
    # clamp those eigenvalues would be clipped below but not counted here.
    eigvals = np.sqrt(np.maximum(0, eigvals))

    n_eigvals_clipped = (eigvals < sigma_min).sum()
    if n_eigvals_clipped > 0:
        print("---------------------------------------")
        print("warning, be carefull, you reached sigmamin")
        print(n_eigvals_clipped, " eigenvalues clipped")
        print("---------------------------------------")

    eigvals = np.maximum(eigvals, sigma_min)
    # Column vector: broadcasting scales row i of eigvecs.T by eigvals[i].
    eigvals = np.expand_dims(eigvals, axis=1)

    return eigvals.sum(), \
        eigvecs @ (1 / eigvals * eigvecs.T), \
        eigvecs @ (eigvals * eigvecs.T)


@njit
def BST(u, tau):
    """
    BST stands for block soft thresholding operator.

    Computes max(0, 1 - tau / ||u||) * u and writes the result into u
    (u is modified in place and also returned).

    Parameters
    ----------------------
    u: np.array
        modified in place
    tau: float
        non-negative number

    Output
    ---------------------
    numpy array:
        u itself, holding the thresholded values

    line_is_zero: bool
        Whether or not the block soft thresholding returns a vector full of 0.
    """
    norm_u = norm(u)
    if norm_u == 0:
        return u, True
    a = 1 - tau / norm_u
    # Use <= rather than <: when tau == ||u|| the result is the zero vector,
    # so the flag must report it as such.
    if a <= 0:
        u.fill(0)
        return u, True
    u *= a
    return u, False


# @njit
def clp_sqrt(ZZT, sigma_min):
    """Clipped square root of ZZT, returned through its trace and inverse.

    With ZZT = V diag(e) V^T, negative eigenvalues are first set to 0, then
    S = V diag(max(sqrt(e), sigma_min)) V^T. A warning is printed when at
    least one sqrt(e_i) lies below sigma_min.

    Parameters:
    ----------
    ZZT: np.array, shape (n_channels, n_channels)
        real symmetric matrix, expected positive semi-definite

    sigma_min: float
        sigma_min > 0, lower bound on the eigenvalues of S = sqrt(ZZT)

    Output:
    -------
    (float, np.array, shape (n_channels, n_channels))
        (trace of S, inverse of S)
    """
    spectrum, basis = np.linalg.eigh(ZZT)
    spectrum = np.maximum(spectrum, 0)

    n_eigvals_clipped = (np.sqrt(spectrum) < sigma_min).sum()
    if n_eigvals_clipped:
        print("---------------------------------------")
        print("warning, be carefull, you reached sigmamin")
        print(n_eigvals_clipped, " eigenvalues clipped")
        print("---------------------------------------")

    sqrt_spectrum = np.maximum(np.sqrt(np.maximum(0, spectrum)), sigma_min)
    # Column vector so that it scales the rows of basis.T.
    column = sqrt_spectrum.reshape(-1, 1)

    trace_S = column.sum()
    S_inv = basis @ (1 / column * basis.T)
    return trace_S, S_inv


@njit
def clp_sigma_gls(ZZT, sigma_min):
    """Raise the eigenvalues of ZZT to at least sigma_min and rebuild it.

    With ZZT = V diag(e) V^T, returns V diag(max(e, sigma_min)) V^T. A status
    message is printed on every call: a warning when some eigenvalue was
    below sigma_min, otherwise a short notice.

    Parameters:
    ----------
    ZZT: np.array, shape (n_channels, n_channels)
        real, symmetric matrix, expected positive definite
    sigma_min: float
        lower bound on the eigenvalues

    Output:
    -------
    (float, np.array, shape (n_channels, n_channels))
        (sum of the logs of the clipped eigenvalues, i.e. log-determinant
        of the clipped matrix, the clipped matrix itself)
    """
    eigvals, eigvecs = np.linalg.eigh(ZZT)

    n_eigvals_clipped = (eigvals < sigma_min).sum()
    if n_eigvals_clipped == 0:
        print("You did not reach sigmamin")
    else:
        print("---------------------------------------")
        print("warning, be carefull, you reached sigmamin")
        print(n_eigvals_clipped, " eigenvalues clipped")
        print("---------------------------------------")

    clipped = np.expand_dims(np.maximum(eigvals, sigma_min), axis=1)
    log_det = np.log(clipped).sum()
    sigma_clipped = eigvecs @ (clipped * eigvecs.T)
    return log_det, sigma_clipped


@njit
def clp_sigma_inv(emp_cov, sigma_min):
    """Invert emp_cov after raising its eigenvalues to at least sigma_min.

    With emp_cov = V diag(e) V^T, returns V diag(1 / max(e, sigma_min)) V^T.
    A status message is printed on every call: a warning when some
    eigenvalue was below sigma_min, otherwise a short notice.

    Parameters:
    ----------
    emp_cov: np.array, shape (n_channels, n_channels)
        real, symmetric matrix, expected positive definite
    sigma_min: float
        lower bound on the eigenvalues

    Output:
    -------
    (float, np.array, shape (n_channels, n_channels))
        (sum of the logs of the clipped eigenvalues, i.e. log-determinant
        of the clipped matrix, inverse of the clipped matrix)
    """
    eigvals, eigvecs = np.linalg.eigh(emp_cov)

    n_eigvals_clipped = (eigvals < sigma_min).sum()
    if n_eigvals_clipped == 0:
        print("You did not reach sigmamin")
    else:
        print("---------------------------------------")
        print("warning, be carefull, you reached sigmamin")
        print(n_eigvals_clipped, " eigenvalues clipped")
        print("---------------------------------------")

    clipped = np.expand_dims(np.maximum(eigvals, sigma_min), axis=1)
    log_det = np.log(clipped).sum()
    sigma_inv = eigvecs @ (1 / clipped * eigvecs.T)
    return log_det, sigma_inv


@njit
def l_2_inf(A):
    """Return the largest Euclidean norm among the rows of A.

    Parameters:
    ----------
    A: np.array, 2D

    Output:
    -------
    float
        max_j ||A[j, :]||_2, or 0. when A has no rows.
    """
    n_rows = A.shape[0]
    largest = 0.
    for i in range(n_rows):
        row = A[i, :]
        largest = max(largest, norm(row))
    return largest


@njit
def l_2_1(A):
    """Return the sum of the Euclidean norms of the rows of A.

    Parameters:
    ----------
    A: np.array, 2D

    Output:
    -------
    float
        sum_j ||A[j, :]||_2, or 0. when A has no rows.
    """
    n_rows = A.shape[0]
    total = 0.
    for i in range(n_rows):
        row = A[i, :]
        total = total + norm(row)
    return total


def get_alpha_max(X, observation, sigma_min, pb_name, alpha_Sigma_inv=0.0001):
    """Compute alpha_max specific to pb_name.

    Apart from the delegating cases, alpha_max is
    l_2_inf(X.T @ W @ Y) / (n_channels * n_times), where Y is the observation
    (averaged over epochs when it is 3D) and W depends on pb_name:

    - "MTL": identity;
    - "SGCL": inverse of the clipped square root of Y @ Y.T / n_times;
    - "CLaR", "NNCVX": inverse of the clipped square root of the empirical
      covariance computed over all epochs;
    - "MRCER": precision matrix from graphical_lasso on the empirical
      covariance over epochs;
    - "MLER": inverse of the empirical covariance over epochs, with its
      eigenvalues clipped at sigma_min ** 2;
    - "glasso": precision matrix from graphical_lasso on
      observation @ observation.T / n_times.

    Delegating cases: "MLE" and "MRCE" treat the (averaged) observation as a
    single epoch and defer to "MLER" and "MRCER"; "MTLME" concatenates the
    epochs along the time axis and defers to "MTL".

    Parameters:
    ----------
    X: np.array, shape (n_channels, n_sources)
    observation: np.array, shape (n_channels, n_times) or
        (n_epochs, n_channels, n_times)
    sigma_min: float, >0
    pb_name: string, one of "MTL", "MLE", "MRCE", "MTLME", "SGCL", "CLaR",
        "NNCVX", "MRCER", "MLER", "glasso"
    alpha_Sigma_inv: float
        regularization passed to graphical_lasso ("MRCER", "glasso")

    Output:
    -------
    float
        alpha_max of the optimization problem.

    Raises:
    -------
    NotImplementedError
        if pb_name is not one of the supported names.
    AssertionError
        if observation does not have the dimension required by pb_name.
    """
    n_channels, n_times = observation.shape[-2], observation.shape[-1]

    if observation.ndim == 3:
        Y = observation.mean(axis=0)
    else:
        Y = observation

    if pb_name == "MTL":
        alpha_max = l_2_inf(X.T @ Y) / (n_times * n_channels)
    elif pb_name == "MLE":
        observations = Y[None, :, :]
        return get_alpha_max(
            X, observations, sigma_min, "MLER",
            alpha_Sigma_inv=alpha_Sigma_inv)
    elif pb_name == "MRCE":
        observations = Y[None, :, :]
        return get_alpha_max(
            X, observations, sigma_min, "MRCER",
            alpha_Sigma_inv=alpha_Sigma_inv)
    elif pb_name == "MTLME":
        # Concatenate the epochs along time: (n_channels, n_epochs * n_times).
        observations = observation.transpose((1, 0, 2))
        observations = observations.reshape(observations.shape[0], -1)
        alpha_max = get_alpha_max(X, observations, sigma_min, "MTL")
    elif pb_name == "SGCL":
        assert observation.ndim == 2
        _, S_max_inv = clp_sqrt(Y @ Y.T / n_times, sigma_min)
        alpha_max = l_2_inf(X.T @ S_max_inv @ Y)
        alpha_max /= (n_channels * n_times)
    elif pb_name in ("CLaR", "NNCVX"):
        cov_Yl = get_emp_cov(observation)
        _, S_max_inv = clp_sqrt(cov_Yl, sigma_min)
        alpha_max = l_2_inf(X.T @ S_max_inv @ Y)
        alpha_max /= (n_channels * n_times)
    elif pb_name == "MRCER":
        assert observation.ndim == 3
        assert alpha_Sigma_inv is not None
        emp_cov = get_emp_cov(observation)
        _, Sigma_inv = graphical_lasso(
            emp_cov, alpha_Sigma_inv, max_iter=10**6)
        alpha_max = l_2_inf(X.T @ Sigma_inv @ Y) / (n_channels * n_times)
    elif pb_name == "MLER":
        assert observation.ndim == 3
        emp_cov = get_emp_cov(observation)
        # clp_sigma_inv thresholds the eigenvalues of the covariance itself,
        # whereas clp_sqrt thresholds their square roots at sigma_min; the
        # squared bound makes the two thresholds equivalent.
        _, Sigma_inv = clp_sigma_inv(emp_cov, sigma_min ** 2)
        alpha_max = l_2_inf(X.T @ Sigma_inv @ Y) / (n_channels * n_times)
    elif pb_name == "glasso":
        assert observation.ndim == 2
        assert alpha_Sigma_inv is not None
        emp_cov = observation @ observation.T / n_times
        _, Sigma_inv = graphical_lasso(emp_cov, alpha_Sigma_inv)
        alpha_max = l_2_inf(X.T @ Sigma_inv @ Y) / (n_channels * n_times)
    else:
        raise NotImplementedError("No solver '{}' in sgcl"
                                  .format(pb_name))
    return alpha_max


def get_emp_cov(R):
    """Empirical covariance of multi-epoch data.

    Computes sum_l R[l] @ R[l].T / (n_epochs * n_times).

    Parameters:
    ----------
    R: np.array, shape (n_epochs, n_channels, n_times)

    Output:
    -------
    np.array, shape (n_channels, n_channels)

    Raises:
    -------
    AssertionError
        if R is not 3D.
    """
    assert R.ndim == 3
    n_epochs, n_channels, n_times = R.shape
    covariance = np.zeros((n_channels, n_channels))
    for epoch in R:
        covariance += epoch @ epoch.T
    covariance /= (n_epochs * n_times)
    return covariance


def get_sigma_min(Y):
    """Compute a default sigma_min from the observations.

    sigma_min is one thousandth of the root-mean-square entry of the
    observation matrix. A 3D input is averaged over its first axis first.

    Parameters:
    ----------
    Y: np.array, shape (n_channels, n_times) or (n_epochs, n_channels, n_times)

    Output:
    -------
    float
        sigma_min.

    Raises:
    -------
    ValueError
        if Y is neither 2D nor 3D.
    """
    if Y.ndim == 3:
        # Normalize by the size of the *averaged* matrix
        # (n_channels * n_times), as in the 2D case. Y.shape[0] would be
        # n_epochs here, which is not a dimension of Y.mean(axis=0).
        Y = Y.mean(axis=0)
    elif Y.ndim != 2:
        raise ValueError(
            "Y must be 2D or 3D, got {} dimensions".format(Y.ndim))
    n_channels, n_times = Y.shape
    return norm(Y, ord='fro') / (np.sqrt(n_channels * n_times) * 1000)
