import numpy as np
import os
import mne

from numpy.linalg import norm
from sklearn.utils import check_random_state


def compute_forward(data_path, info, resolution=3):
    """Return the forward solution for ``resolution``, computing it if needed.

    The solution is cached at
    ``<data_path>/MEG/sample/sample_audvis-meg-eeg-oct-<resolution>-fwd.fif``.
    If that file exists it is read with ``mne.read_forward_solution``;
    otherwise ``compute_forward_`` is called to produce it.

    Parameters
    ----------
    data_path : str
        Directory expected to contain ``MEG/sample`` and ``subjects/sample``
        (the MNE sample dataset layout).
    info : mne.Info
        Measurement info forwarded to ``compute_forward_`` (presumably an
        ``mne.Info`` -- confirm against callers).
    resolution : int
        Source-space resolution; also used in the cache file name.

    Returns
    -------
    fwd : mne.Forward
        The forward solution.
    """
    path_fwd = os.path.join(
        data_path, 'MEG', 'sample',
        'sample_audvis-meg-eeg-oct-%i-fwd.fif' % resolution)
    if os.path.isfile(path_fwd):
        return mne.read_forward_solution(path_fwd)
    # compute_forward_ itself writes the solution to this same cache path, so
    # writing it again here would only repeat a large, redundant file write.
    return compute_forward_(data_path, info, resolution)


def compute_forward_(data_path, info, resolution=3, n_jobs=2):
    """Compute (or, for ``resolution == 6``, load) a forward solution.

    For ``resolution == 6`` the solution stored at
    ``<data_path>/MEG/sample/sample_audvis-meg-eeg-oct-6-fwd.fif`` is read and
    returned as-is. Otherwise the solution is computed as follows:
    - an ``ico<resolution>`` source space is built for subject ``sample``;
    - the 3-layer BEM solution is loaded;
    - ``mne.make_forward_solution`` is run.
    The computed solution is written (overwriting) to
    ``<data_path>/MEG/sample/sample_audvis-meg-eeg-oct-<resolution>-fwd.fif``.

    Parameters
    ----------
    data_path : str or path-like
        Directory expected to contain ``MEG/sample`` and ``subjects/sample``.
    info : mne.Info
        Measurement info passed to ``mne.make_forward_solution``.
    resolution : int
        Icosahedron subdivision level used for the source space spacing.
    n_jobs : int
        Number of jobs passed to ``mne.make_forward_solution`` (default 2).

    Returns
    -------
    fwd : mne.Forward
        The forward solution.
    """
    path_fwd = os.path.join(
        data_path, 'MEG', 'sample',
        'sample_audvis-meg-eeg-oct-%i-fwd.fif' % resolution)
    if resolution == 6:
        # Not computed: the precomputed solution at this path is read instead
        # (presumably the file distributed with the MNE sample dataset --
        # confirm).
        return mne.read_forward_solution(path_fwd)

    # NOTE: the cache file name says "oct" but the source space below uses
    # icosahedral ("ico") spacing; renaming it would orphan existing caches.
    src_fs = mne.setup_source_space(
        subject='sample',
        spacing='ico%d' % resolution,
        subjects_dir=os.path.join(data_path, 'subjects'),
        add_dist=False)
    bem = mne.read_bem_solution(os.path.join(
        data_path, 'subjects', 'sample', 'bem',
        'sample-5120-5120-5120-bem-sol.fif'))

    fwd = mne.make_forward_solution(
        info,
        trans=os.path.join(data_path, 'MEG', 'sample',
                           'sample_audvis_raw-trans.fif'),
        src=src_fs, bem=bem, meg=True, eeg=True, mindist=2.,
        n_jobs=n_jobs)
    mne.write_forward_solution(path_fwd, fwd, overwrite=True)
    return fwd


def get_data_from_X_S_and_B_star(
        X, B_star, S_star, n_epochs=50, SNR=0.5, seed=0):
    """Simulate multi-epoch measurements ``X @ B_star + noise``.

    Each epoch gets its own noise draw ``S_star @ E``, where ``E`` has i.i.d.
    standard-normal entries drawn from ``check_random_state(seed)``.
    When ``SNR`` is given, all noise is multiplied by one common factor. The
    factor is chosen so that::

        ||X @ B_star||_F / mean_over_epochs(||noise||_F) == SNR

    Parameters
    ----------
    X : ndarray, shape (n_channels, n_sources)
        Matrix multiplied with ``B_star`` to form the noiseless signal.
    B_star : ndarray, shape (n_sources, n_times)
        Source matrix.
    S_star : ndarray, shape (n_channels, n_channels)
        Matrix applied to white Gaussian noise to form each epoch's noise.
    n_epochs : int
        Number of independent noise realizations.
    SNR : float or None
        Target signal-to-noise ratio; must be strictly positive. If None,
        the noise is left unscaled.
    seed : None, int or numpy.random.RandomState
        Passed to ``check_random_state``.

    Returns
    -------
    X : ndarray
        The input ``X``, unchanged.
    all_epochs : ndarray, shape (n_epochs, n_channels, n_times)
        Signal plus scaled noise for every epoch.
    B_star : ndarray
        The input ``B_star``, unchanged.
    (scale, S_star) : tuple
        Factor applied to the noise (``1.`` when ``SNR`` is None), and
        ``S_star``.

    Raises
    ------
    ValueError
        If ``SNR`` is not None and not strictly positive. Also raised if the
        generated noise has a zero or undefined mean norm (e.g. all-zero
        ``S_star`` or ``n_epochs == 0``), since it cannot then be rescaled.
    """
    # Fail fast: SNR == 0 would divide by zero below, and a negative (or NaN)
    # SNR would silently flip the noise sign (or produce NaNs).
    if SNR is not None and not SNR > 0:
        raise ValueError(
            "SNR must be strictly positive or None, got %r" % (SNR,))

    rng = check_random_state(seed)
    XB = X @ B_star
    n_channels, _ = X.shape
    _, n_times = B_star.shape

    # One independent draw per epoch, taken in epoch order so the result is
    # reproducible for a given seed.
    noise_all_epochs = np.empty((n_epochs, n_channels, n_times))
    for epoch in range(n_epochs):
        noise_all_epochs[epoch] = S_star @ rng.randn(n_channels, n_times)

    if SNR is None:
        return X, noise_all_epochs + XB, B_star, (1., S_star)

    # Mean (over epochs) Frobenius norm of the noise.
    mean_noise_norm = np.sqrt(
        (noise_all_epochs ** 2).sum(axis=(1, 2))).mean()
    if not mean_noise_norm > 0:
        raise ValueError(
            "cannot rescale noise to the requested SNR: mean noise norm is "
            "zero or undefined (check S_star and n_epochs)")
    scale = norm(XB, ord='fro') / mean_noise_norm / SNR
    noise_all_epochs *= scale

    return X, noise_all_epochs + XB, B_star, (scale, S_star)
