import copy
import numpy as np

from numpy.linalg import norm
from scipy.linalg import toeplitz
from sklearn.utils import check_random_state

import mne
from mne.datasets import sample
from mne.inverse_sparse.mxne_inverse import (
    _make_sparse_stc,
    _reapply_source_weighting,
    _prepare_gain, is_fixed_orient)

from data.utils import get_data_from_X_S_and_B_star


def get_data_me(
        dictionary_type="Gaussian", noise_type="Gaussian_iid", n_channels=20,
        n_sources=20, n_times=30, n_epochs=50, n_active=3, rho=0.3,
        rho_noise=0.6,
        SNR=1, seed=0,
        meg=None, eeg=None):
    """Simulate artificial data.

    Parameters
    ----------
    dictionary_type: string
        "Gaussian", "Toeplitz" or "real_dico" (see ``get_dictionary``).
    noise_type: string
        "Gaussian_iid", "Gaussian_multivariate" or "real_noise"
        (see ``get_S_star``).
    n_channels: int
        number of channels
    n_sources: int
        number of potential sources.
    n_times: int
        number of time points
    n_epochs: int
        number of epochs/repetitions
    n_active: int
        number of active sources, i.e. non-zero rows of B_star. Must not
        exceed n_sources (the support is drawn without replacement).
    rho: float
        coefficient of correlation for the Toeplitz-correlated dictionary
    rho_noise: float
        coefficient of correlation for the Toeplitz covariance of the noise
    SNR: float
        Signal to Noise Ratio
    seed: int
        Seed forwarded to the dictionary and noise generators, and used to
        draw the support and amplitudes of B_star.
    meg: bool or string
        True, "grad" or "mag".
        If True, keeps magnetometers and gradiometers in cov.
        If "grad", keeps only gradiometers in cov.
        If "mag", keeps only the magnetometers in cov.
    eeg: bool
        If True keep electroencephalography (EEG) channels in cov.
        If False remove EEG channels from cov.

    Returns
    -------
    X: np.array, shape (n_sensors, n_sources)
        dictionary/gain matrix
    all_epochs: np.array, shape (n_epochs, n_sensors, n_times)
        data observed
    B_star: np.array, shape (n_sources, n_times)
        real regression coefficients
    (multiplicativ_factor, S_star): tuple
        ``multiplicativ_factor`` as returned by
        ``get_data_from_X_S_and_B_star``, and ``S_star``, np.array of shape
        (n_sensors, n_sensors), the co-standard deviation matrix of the
        noise.
    """
    X = get_dictionary(
        dictionary_type, n_channels=n_channels,
        n_sources=n_sources, rho=rho, meg=meg, eeg=eeg, seed=seed)
    S_star = get_S_star(
        noise_type=noise_type, n_channels=n_channels,
        rho_noise=rho_noise, meg=meg, eeg=eeg, seed=seed)

    # Random state for the support and amplitudes of B_star. With an int
    # seed, check_random_state builds a fresh RandomState, so these draws do
    # not depend on what the two generators above consumed.
    rng = check_random_state(seed)
    B_star = np.zeros([n_sources, n_times])
    supp = rng.choice(n_sources, n_active, replace=False)
    B_star[supp, :] = rng.randn(n_active, n_times)

    X, all_epochs, B_star, (multiplicativ_factor, S_star) = \
        get_data_from_X_S_and_B_star(
            X, B_star, S_star, n_epochs=n_epochs, SNR=SNR, seed=seed)
    return X, all_epochs, B_star, (multiplicativ_factor, S_star)


def get_S_star(
        noise_type="Gaussian_iid", n_channels=20, rho_noise=0.7, seed=0,
        meg=True, eeg=True):
    """Build the co-standard deviation matrix of the noise.

    Parameters
    ----------
    noise_type: string
        "Gaussian_iid" (identity matrix), "Gaussian_multivariate" (Toeplitz
        matrix whose entry (i, j) is rho_noise ** |i - j|) or "real_noise"
        (delegated to ``get_real_cov_matrix_and_decimate``).
    n_channels: int
        number of channels
    rho_noise: float
        coefficient of correlation of the Toeplitz matrix; only used by
        "Gaussian_multivariate".
    seed: int
        only forwarded for "real_noise".
    meg: bool or string
        only forwarded for "real_noise".
        If True, keeps magnetometers and gradiometers in cov.
        If "grad", keeps only gradiometers in cov.
        If "mag", keeps only the magnetometers in cov.
    eeg: bool
        only forwarded for "real_noise".
        If True keep EEG channels in cov, if False remove them.

    Returns
    -------
    S_star: np.array, shape (n_channels, n_channels)
        co-standard deviation matrix (for "real_noise", whatever
        ``get_real_cov_matrix_and_decimate`` returns).

    Raises
    ------
    ValueError
        If ``noise_type`` is not one of the values listed above.
    """
    if noise_type == "Gaussian_iid":
        return np.eye(n_channels)

    if noise_type == "Gaussian_multivariate":
        first_row = rho_noise ** np.arange(n_channels)
        return toeplitz(first_row, first_row)

    if noise_type == "real_noise":
        return get_real_cov_matrix_and_decimate(
            n_channels=n_channels, seed=seed, meg=meg, eeg=eeg)

    raise ValueError("Unknown noise type %s" % noise_type)


def get_dictionary(
        dictionary_type, n_channels=20, n_sources=30,
        rho=0.3, meg=None, eeg=None, seed=0):
    """Build a dictionary (gain) matrix with unit-norm columns.

    Parameters
    ----------
    dictionary_type: string
        "Gaussian" (i.i.d. standard normal entries), "Toeplitz" (see
        ``get_toeplitz_dictionary``) or "real_dico" (see
        ``get_real_dictionary``).
    n_channels: int
        number of channels (rows)
    n_sources: int
        number of sources (columns)
    rho: float
        correlation coefficient, only used by "Toeplitz".
    meg, eeg:
        only forwarded for "real_dico".
    seed: int
        random seed forwarded to the chosen generator.

    Returns
    -------
    X: np.array
        shape (n_channels, n_sources) for "Gaussian" and "Toeplitz"; for
        "real_dico", whatever ``get_real_dictionary`` returns. Every column
        is rescaled in place by ``normalize``.

    Raises
    ------
    NotImplementedError
        If ``dictionary_type`` is not one of the values above.
    """
    random_state = check_random_state(seed)

    if dictionary_type == 'Gaussian':
        dico = random_state.randn(n_channels, n_sources)
    elif dictionary_type == 'Toeplitz':
        dico = get_toeplitz_dictionary(
            n_channels=n_channels, n_sources=n_sources, rho=rho, seed=seed)
    elif dictionary_type == "real_dico":
        dico = get_real_dictionary(n_channels, n_sources, meg, eeg, seed)
    else:
        raise NotImplementedError(
            "No dictionary '{}' in maxsparse".format(dictionary_type))

    # normalize works in place; its return value is not needed here.
    normalize(dico)
    return dico


def normalize(X):
    """Scale each column of ``X`` to unit Euclidean norm, in place.

    Parameters
    ----------
    X: np.array, shape (n_rows, n_columns)
        Floating-point array, modified in place. Integer arrays cannot hold
        the result, so numpy raises on the in-place division. A 1-D array is
        treated as a single column.

    Returns
    -------
    X: np.array
        The same object that was passed in.

    Notes
    -----
    Columns whose norm is zero are left unchanged rather than being turned
    into NaNs by a 0 / 0 division.
    """
    col_norms = norm(X, axis=0)
    # Divide all-zero columns by 1 instead of 0 so they stay zero.
    X /= np.where(col_norms > 0, col_norms, 1.)
    return X


def get_toeplitz_dictionary(
        n_channels=20, n_sources=30, rho=0.3, seed=0):
    r"""Return a dictionary whose rows are Toeplitz-correlated Gaussians.

    Maths formula:
    S = toeplitz(\rho ** [|0, n_sources-1|], \rho ** [|0, n_sources-1|])
    X[i, :] \sim \mathcal{N}(0, S) independently for every channel i,
    so neighbouring columns (atoms) of X are correlated.

    Parameters
    ----------
    n_channels: int
        number of channels/measurements (rows of X)
    n_sources: int
        number of sources/atoms (columns of X)
    rho: float
        correlation coefficient; entry (j, k) of S is rho ** |j - k|
    seed: int
        passed to ``check_random_state``

    Returns
    -------
    X : array, shape (n_channels, n_sources)
        The dictionary. Its columns are not normalized here.
    """
    rng = check_random_state(seed)
    first_row = rho ** np.arange(n_sources)
    covar = toeplitz(first_row, first_row)
    # One multivariate sample of dimension n_sources per channel (row).
    X = rng.multivariate_normal(np.zeros(n_sources), covar, n_channels)
    return X


def get_whitener(evoked, forward, noise_cov, loose=0.2, depth=0.8):
    """Return the gain matrix, the whitener and the source weighting.

    Thin wrapper around MNE's private ``_prepare_gain``, called with
    ``pca=False`` and no source weights.

    Parameters
    ----------
    evoked:
        Presumably an ``mne.Evoked``; only ``evoked.info`` is used here.
    forward:
        Forward solution (presumably ``mne.Forward``).
    noise_cov:
        Noise covariance used by ``_prepare_gain`` to build the whitener.
    loose: float or None
        Forwarded to ``_prepare_gain``.
    depth: float
        Forwarded to ``_prepare_gain``.

    Returns
    -------
    gain: np.array
        Gain matrix as returned by ``_prepare_gain``. It is presumably
        whitened and depth-weighted; confirm against the installed MNE
        version.
    whitener: np.array
        Whitening matrix.
    source_weighting: np.array
        Source weighting as returned by ``_prepare_gain``.
    """
    # NOTE(review): the inherited comment said this puts the forward solution
    # in fixed orientation, but only a deep copy is made and no orientation
    # conversion happens. The copy is kept in case ``_prepare_gain`` mutates
    # its input. TODO confirm whether a conversion was intended.
    if loose is None and not is_fixed_orient(forward):
        forward = copy.deepcopy(forward)

    gain, _, whitener, source_weighting, _ = _prepare_gain(
        forward, evoked.info, noise_cov, pca=False, depth=depth,
        loose=loose, weights=None, weights_min=None)
    return gain, whitener, source_weighting


def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8):
    """Run ``solver`` on whitened data and return a sparse source estimate.

    The channels of ``evoked.data`` that appear in the gain are whitened,
    then passed to ``solver``. The recovered coefficients are rescaled by the
    source weighting and wrapped with ``_make_sparse_stc``.

    Side effect: writes the gain matrix and the whitened data to
    "X_white.npy" and "Y_white.npy" in the current working directory.

    Parameters
    ----------
    solver: callable
        Called as ``solver(Y, gain, n_orient)``; must return a pair
        ``(coefficients, active_set)``.
    evoked:
        Presumably an ``mne.Evoked``; uses ``ch_names``, ``data``, ``info``
        and ``times``.
    forward:
        Forward solution (presumably ``mne.Forward``).
    noise_cov:
        Noise covariance forwarded to ``_prepare_gain``.
    loose: float or None
        Forwarded to ``_prepare_gain``.
    depth: float
        Forwarded to ``_prepare_gain``.

    Returns
    -------
    stc:
        Source estimate built by ``_make_sparse_stc``.
    """
    channel_names = evoked.ch_names

    # NOTE(review): only a deep copy is made here; no orientation conversion
    # happens despite the historical "fixed orientation" comment. TODO confirm.
    needs_copy = loose is None and not is_fixed_orient(forward)
    if needs_copy:
        forward = copy.deepcopy(forward)

    (gain, gain_info, whitener,
     source_weighting, _mask) = _prepare_gain(
        forward, evoked.info, noise_cov, pca=False, depth=depth,
        loose=loose, weights=None, weights_min=None)

    # Keep only the channels the gain matrix was built on, in its order.
    picks = [channel_names.index(ch) for ch in gain_info['ch_names']]
    whitened_data = np.dot(whitener, evoked.data[picks])

    # Dump the whitened problem to disk.
    np.save("X_white.npy", gain)
    np.save("Y_white.npy", whitened_data)

    if is_fixed_orient(forward):
        orient_count = 1
    else:
        orient_count = 3

    coefs, active_set = solver(whitened_data, gain, orient_count)
    coefs = _reapply_source_weighting(
        coefs, source_weighting, active_set, orient_count)

    return _make_sparse_stc(
        coefs, active_set, forward, tmin=evoked.times[0],
        tstep=1. / evoked.info['sfreq'])
