import os
import numpy as np
from numpy.linalg import norm

import mne
from mne import io
from mne.datasets import sample
from data.utils import compute_forward
from data.utils import compute_forward_
from data.semi_real import compute_log_norm_axis1
from data.semi_real import compute_log_norm_axis0
from data.semi_real import get_my_whitener


def get_cov(
        ch_names,
        bads=('MEG 2443', 'EEG 053'), data_path=None):
    """Return the sample noise covariance restricted to good channels.

    Parameters
    ----------
    ch_names : list of str
        Channels to keep from the covariance.
    bads : sequence of str
        Channels marked as bad and dropped after the selection.
    data_path : str | path-like | None
        Root of the sample dataset. If None, ``sample.data_path()`` is used;
        it is resolved at call time rather than when the module is imported.

    Returns
    -------
    cov : mne.Covariance
        Covariance restricted to ``ch_names`` minus ``bads``.
    """
    if data_path is None:
        data_path = sample.data_path()
    # sample.data_path() may return a pathlib.Path (MNE >= 1.0); the file
    # name is built by string concatenation, so normalize to str.
    cov_fname = str(data_path) + '/MEG/sample/sample_audvis-cov.fif'
    cov = mne.read_cov(cov_fname)
    cov = mne.pick_channels_cov(cov, ch_names)
    # Copy so the covariance never aliases the caller's (or default) sequence.
    cov['bads'] = list(bads)
    # Re-pick with the default exclude='bads' to drop the channels marked above.
    cov = mne.pick_channels_cov(cov)
    return cov


def get_fwd(
        resolution=3, meg=True, eeg=True, force_fixed=True,
        bads=('MEG 2443', 'EEG 053'), data_path=None,
        raw_path='/MEG/sample/sample_audvis_filt-0-40_raw.fif', erase=True):
    """Return the forward solution of sample and the matching info.

    Parameters
    ----------
    resolution : int
        Source space resolution; used in the cached file name
        ``sample_audvis-meg-eeg-oct-<resolution>-fwd.fif`` and passed to
        ``compute_forward_``.
    meg : bool | str
        MEG channels to keep (forwarded to ``mne.pick_types_forward`` and
        ``raw.pick_types``).
    eeg : bool
        Whether to keep EEG channels.
    force_fixed : bool
        Forwarded to ``mne.convert_forward_solution``.
    bads : sequence of str
        Channels marked as bad and dropped from the raw data.
    data_path : str | path-like | None
        Root of the sample dataset. If None, ``sample.data_path()`` is used;
        it is resolved at call time rather than when the module is imported.
    raw_path : str
        Raw file path appended to ``data_path`` by string concatenation
        (hence the leading '/').
    erase : bool
        If True, recompute the forward solution and overwrite the cached
        file even when it already exists.

    Returns
    -------
    fwd : mne.Forward
        Forward solution restricted to the kept channels. The forward
        operator itself is ``fwd['sol']['data']``, of shape
        (n_channels, n_sources).
    info : mne.Info
        Measurement info of the raw data restricted to the kept channels.
    """
    if data_path is None:
        data_path = sample.data_path()
    # sample.data_path() may return a pathlib.Path (MNE >= 1.0); the paths
    # below are built by string concatenation, so normalize to str.
    data_path = str(data_path)

    print("Loading data......................................................")
    raw = mne.io.read_raw_fif(data_path + raw_path, preload=True)
    # Mark bad channels; copy so raw.info never aliases the caller's sequence.
    raw.info['bads'] = list(bads)
    raw.drop_channels(ch_names=raw.info['bads'])

    # Compute (and cache on disk) or import the forward solution.
    path_fwd = data_path + \
        '/MEG/sample/sample_audvis-meg-eeg-oct-%i-fwd.fif' % resolution
    if erase or not os.path.isfile(path_fwd):
        fwd = compute_forward_(data_path, raw.info, resolution)
        mne.write_forward_solution(path_fwd, fwd, overwrite=True)
    else:
        fwd = mne.read_forward_solution(path_fwd)

    fwd = mne.convert_forward_solution(fwd, force_fixed=force_fixed)
    fwd = mne.pick_types_forward(
            fwd, meg=meg, eeg=eeg, exclude=raw.info['bads'])
    raw.pick_types(meg=meg, eeg=eeg)
    # Align the forward channels with the (picked) raw channels.
    fwd = mne.pick_channels_forward(fwd, raw.ch_names)
    return fwd, raw.info


def get_real_epochs(
        meg=True, eeg=True, snty_chk=True,
        bads=('MEG 2443', 'EEG 053'), data_path=None,
        raw_path='/MEG/sample/sample_audvis_filt-0-40_raw.fif',
        event_path='/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif',
        event_id=2, tmin=-0.2, tmax=0.5):
    """Return the real epochs of the sample dataset.

    Parameters
    ----------
    meg : bool | 'grad' | 'mag'
        MEG channels to keep (forwarded to ``mne.pick_types``); also selects
        which MEG rejection thresholds are applied.
    eeg : bool
        Whether to keep EEG channels.
    snty_chk : bool
        Sanity check: if True, plot the first epoch split into the channel
        groups ``:203``, ``203:305`` and ``305:``, for instance to make sure
        that the stimulation channels have been removed.
    bads : sequence of str
        Channels marked as bad, in addition to those already marked in the
        raw file.
    data_path : str | path-like | None
        Root of the sample dataset. If None, ``sample.data_path()`` is used;
        it is resolved at call time rather than when the module is imported.
    raw_path, event_path : str
        Raw and events file paths appended to ``data_path`` by string
        concatenation (hence the leading '/').
    event_id : int
        Event id to epoch on.
    tmin, tmax : float
        Epoch start and end, in seconds relative to the event.

    Returns
    -------
    all_epochs : ndarray (n_epochs, n_channels, n_times)
        All the epochs (i.e. the trials/the repetitions), restricted to the
        MEG/EEG channels.
    info : mne.Info
        Info of the epochs. Note that it still contains the stimulation and
        EOG channels that are removed from ``all_epochs``.
    times : ndarray (n_times,)
        Time points of the epochs.
    """
    if data_path is None:
        data_path = sample.data_path()
    # sample.data_path() may return a pathlib.Path (MNE >= 1.0); the paths
    # are built by string concatenation, so normalize to str.
    data_path = str(data_path)
    raw_fname = data_path + raw_path
    event_fname = data_path + event_path

    # Setup for reading the raw data
    raw = io.read_raw_fif(raw_fname)
    events = mne.read_events(event_fname)

    # Add the requested bads to those already marked in the file. A plain
    # ``+=`` would duplicate any channel the file already marks as bad;
    # dict.fromkeys de-duplicates while keeping the original order.
    raw.info['bads'] = list(
        dict.fromkeys(list(raw.info['bads']) + list(bads)))
    # EEG + MEG + stim + EOG, minus bad channels (EOG is needed for rejection)
    picks = mne.pick_types(raw.info, meg=meg, eeg=eeg, stim=True, eog=True,
                           exclude='bads')

    # Peak-to-peak rejection thresholds matching the selected MEG channels.
    if meg == "grad":
        reject = dict(grad=4000e-13, eog=150e-6)
    elif meg == "mag":
        reject = dict(mag=4e-12, eog=150e-6)
    elif meg:
        reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)
    else:
        reject = dict(eog=150e-6)

    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
                        picks=picks, baseline=(None, 0), preload=True,
                        reject=reject)

    # Keep only EEG and MEG (drop the stimulation and EOG channels)
    picks = mne.pick_types(epochs.info, meg=meg, eeg=eeg)
    all_epochs = epochs.get_data()[:, picks, :]
    info = epochs.info
    times = epochs.times

    if snty_chk:
        import matplotlib.pyplot as plt
        for group in (slice(None, 203), slice(203, 305), slice(305, None)):
            plt.figure()
            plt.plot(all_epochs[0, group, :].T)
            plt.show(block=False)

    return all_epochs, info, times


def get_data(
        resolution=3, snty_chk=True, eeg=True, meg=True, whiten=False,
        event_id=1, tmin=-0.2, tmax=0.5):
    """Return the rescaled forward operator and real epochs of sample.

    Parameters
    ----------
    resolution : int
        Source space resolution passed to ``get_fwd``.
    snty_chk : bool
        If True, plot the first epoch (inside ``get_real_epochs``, and again
        here after rescaling).
    eeg, meg : bool
        Channel types to use.
    whiten : bool
        If True, left-multiply the forward operator and every epoch by the
        whitener returned by ``get_my_whitener`` before rescaling.
    event_id : int
        Event id passed to ``get_real_epochs``.
    tmin, tmax : float
        Epoch window passed to ``get_real_epochs``.

    Returns
    -------
    X : ndarray (n_channels, n_sources)
        Rescaled forward operator (see ``rscl_X_ae``).
    all_epochs : ndarray (n_epochs, n_channels, n_times)
        Rescaled epochs.
    src_wgth : ndarray (n_sources,)
        Source weights returned by ``rscl_X_ae``.
    fwd : mne.Forward
        The forward solution, not rescaled.
    ae_info : mne.Info
        Info of the epochs.
    times : ndarray (n_times,)
        Time points of the epochs.
    """
    fwd, fwd_info = get_fwd(
        resolution=resolution, eeg=eeg, meg=meg, erase=True)
    cov = get_cov(ch_names=fwd_info['ch_names']).data
    X = fwd['sol']['data']
    all_epochs, ae_info, times = get_real_epochs(
        snty_chk=snty_chk, eeg=eeg, meg=meg, event_id=event_id,
        tmin=tmin, tmax=tmax)

    my_whitener, _, _, _ = get_my_whitener(
        cov, fwd_info, fwd_info['ch_names'])

    assert X.shape[0] == all_epochs.shape[1], (
        "forward has %d channels but epochs have %d"
        % (X.shape[0], all_epochs.shape[1]))
    assert X.shape[0] == cov.shape[0], (
        "forward has %d channels but covariance has %d"
        % (X.shape[0], cov.shape[0]))

    if whiten:
        X = my_whitener @ X
        # (n_ch, n_ch) @ (n_epochs, n_ch, n_times) broadcasts over epochs,
        # whitening every epoch at once.
        all_epochs = my_whitener @ all_epochs

    X, all_epochs, src_wgth = rscl_X_ae(X, all_epochs)

    if snty_chk:
        import matplotlib.pyplot as plt
        for group in (slice(None, 203), slice(203, 305), slice(305, None)):
            plt.figure()
            plt.plot(all_epochs[0, group, :].T)
            plt.show(block=False)

    return X, all_epochs, src_wgth, fwd, ae_info, times


def rscl_X_ae(X, all_epochs):
    """Rescale the forward operator and the epochs, working in log space.

    The rows of ``X`` (and the matching channel axis of ``all_epochs``) are
    divided by the row norms of ``X``. The columns of ``X`` are then divided
    by their norms. Finally the epochs are divided by the Frobenius norm of
    their average over epochs. The rescaling of the outputs is done on
    ``log(|.|)`` to avoid numerical under/overflow; the signs are restored
    at the end. Neither input array is modified.

    Parameters
    ----------
    X : ndarray (n_channels, n_sources)
        Forward operator.
    all_epochs : ndarray (n_epochs, n_channels, n_times)
        Epochs, whose channel axis matches the rows of ``X``.

    Returns
    -------
    X : ndarray (n_channels, n_sources)
        Rescaled forward operator.
    all_epochs : ndarray (n_epochs, n_channels, n_times)
        Rescaled epochs.
    src_wgth : ndarray (n_sources,)
        Column l2 norms of the row-normalized ``X``, multiplied by the
        Frobenius norm of the rescaled average epoch.
    """
    # Kept as a separate copy because it is handed to an external helper.
    X_init = X.copy()

    log_abs_epochs = np.log(np.abs(all_epochs))
    log_abs_X = np.log(np.abs(X))

    # Row (channel) rescaling. The linear-space version is only needed to
    # derive the source weights below; the caller's arrays are not touched.
    row_norms = norm(X, axis=1)
    X_rows = X / row_norms[:, np.newaxis]
    # Same row rescaling in log space to avoid numerical errors
    log_norm_X_axis1 = compute_log_norm_axis1(X_init, axis=1)
    log_abs_X -= log_norm_X_axis1[:, np.newaxis]
    log_abs_epochs -= log_norm_X_axis1[np.newaxis, :, np.newaxis]

    # Column (source) weights of the row-normalized operator
    src_wgth = norm(X_rows, axis=0, ord=2)

    # Column rescaling in log space
    log_norm_X_axis_0 = compute_log_norm_axis0(log_abs_X, axis=0)
    log_abs_X -= log_norm_X_axis_0[np.newaxis, :]

    # Back to linear space, restoring the original signs.
    X_out = np.exp(log_abs_X) * np.sign(X_init)
    epochs_out = np.exp(log_abs_epochs) * np.sign(all_epochs)

    # Normalize the epochs by the Frobenius norm of their average.
    avg_norm = norm(epochs_out.mean(axis=0), ord='fro')
    epochs_out /= avg_norm
    src_wgth *= avg_norm
    return X_out, epochs_out, src_wgth


def load_whitened_data():
    """Return the whitened gain matrix and whitened MEG evoked data of sample.

    Uses the 'Left Auditory' evoked response cropped to 40-180 ms, MEG
    channels only, an ad hoc noise covariance, and the fixed-orientation
    oct-6 forward solution.

    Returns
    -------
    gain : ndarray
        Gain matrix returned by MNE's ``_prepare_gain`` (loose=0, depth=1).
    M : ndarray (n_channels, n_times)
        Evoked data restricted to the gain's channels and whitened.
    """
    from mne.inverse_sparse.mxne_inverse import _prepare_gain

    # sample.data_path() may return a pathlib.Path (MNE >= 1.0); the paths
    # are built by string concatenation, so normalize to str.
    data_path = str(sample.data_path())
    fwd_fname = data_path + '/MEG/sample/sample_audvis-meg-eeg-oct-6-fwd.fif'
    ave_fname = data_path + '/MEG/sample/sample_audvis-ave.fif'
    raw_fname = data_path + '/MEG/sample/sample_audvis_raw.fif'
    condition = 'Left Auditory'

    # Only the measurement info is needed (for the ad hoc covariance), so the
    # raw data is not loaded into memory.
    raw = mne.io.read_raw_fif(raw_fname, preload=False)
    noise_cov = mne.make_ad_hoc_cov(raw.info)

    # Handling average file
    evoked = mne.read_evokeds(
        ave_fname, condition=condition, baseline=(None, 0))
    evoked.crop(tmin=0.04, tmax=0.18)
    evoked = evoked.pick_types(eeg=False, meg=True)

    # Handling forward solution
    forward = mne.read_forward_solution(fwd_fname)
    forward = mne.convert_forward_solution(forward, force_fixed=True)

    all_ch_names = evoked.ch_names

    # Fixed orientation, depth weighting exponent 1 (same settings and
    # _prepare_gain API as get_real_whiten_data).
    loose, depth = 0., 1.
    forward, gain, gain_info, whitener, _, _ = _prepare_gain(
        forward, evoked.info, noise_cov, pca=False, depth=depth,
        loose=loose, weights=None, weights_min=None, rank=None)

    # Select channels of interest
    sel = [all_ch_names.index(name) for name in gain_info['ch_names']]
    M = evoked.data[sel]

    # Whiten data
    M = np.dot(whitener, M)
    return gain, M


def get_real_whiten_data(resolution=3, eeg=True, meg=True):
    """Return whitened gain, whitened evoked data and related objects.

    Uses the 'Left Auditory' evoked response of sample cropped to 40-180 ms,
    an ad hoc noise covariance, and a fixed-orientation forward solution
    computed by ``compute_forward`` at the requested resolution.

    Parameters
    ----------
    resolution : int
        Source space resolution passed to ``compute_forward``.
    eeg, meg : bool
        Channel types kept in the evoked data.

    Returns
    -------
    gain : ndarray
        Gain matrix returned by MNE's ``_prepare_gain`` (loose=0, depth=1).
    M : ndarray (n_channels, n_times)
        Evoked data restricted to the gain's channels and whitened.
    whitener : ndarray
        Whitener applied to ``M``.
    source_weighting : ndarray
        Depth-weighting returned by ``_prepare_gain``.
    forward : mne.Forward
        Forward solution returned by ``_prepare_gain``.
    info : mne.Info
        Info of the (picked) evoked data.
    times : ndarray (n_times,)
        Time points of the evoked data.
    """
    from mne.inverse_sparse.mxne_inverse import _prepare_gain

    # sample.data_path() may return a pathlib.Path (MNE >= 1.0); the paths
    # are built by string concatenation, so normalize to str.
    data_path = str(sample.data_path())
    ave_fname = data_path + '/MEG/sample/sample_audvis-ave.fif'
    raw_fname = data_path + '/MEG/sample/sample_audvis_raw.fif'
    condition = 'Left Auditory'

    # Only the measurement info is needed (for the ad hoc covariance), so the
    # raw recording is not loaded into memory.
    raw = mne.io.read_raw_fif(raw_fname, preload=False)
    noise_cov = mne.make_ad_hoc_cov(raw.info)

    # Handling average file
    evoked = mne.read_evokeds(
        ave_fname, condition=condition, baseline=(None, 0))
    evoked.crop(tmin=0.04, tmax=0.18)
    evoked = evoked.pick_types(eeg=eeg, meg=meg)

    # The forward is computed for the picked channels; the precomputed
    # oct-6 file previously read here was immediately discarded.
    forward = compute_forward(data_path, evoked.info, resolution=resolution)
    forward = mne.convert_forward_solution(forward, force_fixed=True)

    all_ch_names = evoked.ch_names

    # Handle depth weighting and whitening (here is no weights)
    loose, depth = 0., 1.
    forward, gain, gain_info, whitener, source_weighting, _ = _prepare_gain(
        forward, evoked.info, noise_cov, pca=False, depth=depth,
        loose=loose, weights=None, weights_min=None, rank=None)

    # Select channels of interest
    sel = [all_ch_names.index(name) for name in gain_info['ch_names']]
    M = evoked.data[sel]

    times = evoked.times
    info = evoked.info
    # Whiten data
    M = np.dot(whitener, M)
    return gain, M, whitener, source_weighting, forward, info, times
