import time
import numpy as np
import mne
import mayavi.mlab
from joblib import Memory

from mne.viz import plot_sparse_source_estimates
from mne.inverse_sparse.mxne_inverse import _make_sparse_stc

from data.real import get_data
from sgcl.solvers import wrap_solver
from sgcl.utils import get_alpha_max, get_sigma_min
from expes.utils import configure_plt

# joblib on-disk cache rooted in the current working directory.
mem = Memory(location=".")

# Data-loading options shared by the whole script.
snty_chk, eeg, meg = False, False, "mag"
resolution, tmax = 6, 0.3

# Load the real, preprocessed data through the joblib cache.
# NOTE(review): both the `__main__` loop and the import-time `else` branch at
# the bottom of this file rebind X, all_epochs, src_wgth, fwd, info and times,
# so the values loaded here look unused apart from warming the cache; confirm
# before removing this call.
X, all_epochs, src_wgth, fwd, info, times = mem.cache(get_data)(
    resolution=resolution, snty_chk=snty_chk, eeg=eeg, meg=meg, tmax=tmax)

# Value passed as `p_alpha` to wrap_solver, keyed by
# (solver name, event_id, decim). Presumably the regularization level as a
# fraction of alpha_max; confirm in wrap_solver.
# For event 4 the main loop reads key (pb_name, 4, 3) whatever the decim, so
# the (4, 1) and (4, 2) entries are kept for reference only.
dict_p_alpha = {
    # event 1: left auditory
    ("MTL", 1, 1): 0.9,
    ("CLaR", 1, 1): 0.9,
    ("MLE", 1, 1): 0.85,
    ("MLER", 1, 1): 0.9,
    ("MRCE", 1, 1): 0.9,
    ("MRCER", 1, 1): 0.9,
    ("SGCL", 1, 1): 0.99,

    ("MTL", 1, 2): 0.9,
    ("CLaR", 1, 2): 0.9,
    ("MLE", 1, 2): 0.85,
    ("MLER", 1, 2): 0.94,
    ("MRCE", 1, 2): 0.9,
    ("MRCER", 1, 2): 0.9,
    ("SGCL", 1, 2): 0.99,

    # event 2: right auditory
    ("MTL", 2, 1): 0.75,
    ("CLaR", 2, 1): 0.8125,
    ("MLE", 2, 1): 0.95,
    ("MLER", 2, 1): 0.99,
    ("MRCE", 2, 1): 0.9,
    ("MRCER", 2, 1): 0.9,
    ("SGCL", 2, 1): 0.99,

    ("MTL", 2, 2): 0.7,
    ("CLaR", 2, 2): 0.82,
    ("MLE", 2, 2): 0.85,
    ("MLER", 2, 2): 0.94,
    ("MRCE", 2, 2): 0.9,
    ("MRCER", 2, 2): 0.9,
    ("SGCL", 2, 2): 0.99,

    # event 3: left visual
    ("MTL", 3, 1): 0.99,
    ("CLaR", 3, 1): 0.95,
    ("MLE", 3, 1): 0.95,
    ("MLER", 3, 1): 0.99,
    ("MRCE", 3, 1): 0.9,
    ("MRCER", 3, 1): 0.99,
    ("SGCL", 3, 1): 0.999,

    ("MTL", 3, 2): 0.999,
    ("CLaR", 3, 2): 0.95,
    ("MLE", 3, 2): 0.99,
    ("MLER", 3, 2): 0.99,
    ("MRCE", 3, 2): 0.9,
    ("MRCER", 3, 2): 0.99,
    ("SGCL", 3, 2): 0.999,

    # event 4: right visual
    # NOTE(review): previously labelled "left visual" like event 3; inferred
    # from the left/right pattern of events 1-3, confirm against get_data.
    ("MTL", 4, 1): 0.925,
    ("CLaR", 4, 1): 0.97,
    ("MLE", 4, 1): 0.95,
    ("MLER", 4, 1): 0.99,
    ("MRCE", 4, 1): 0.9,
    ("MRCER", 4, 1): 0.99999,
    ("SGCL", 4, 1): 0.9999,

    ("MTL", 4, 2): 0.925,
    ("CLaR", 4, 2): 0.97,
    ("MLE", 4, 2): 0.99,
    ("MLER", 4, 2): 0.99,
    ("MRCE", 4, 2): 0.9,
    ("MRCER", 4, 2): 0.99999,
    ("SGCL", 4, 2): 0.9999,

    ("MTL", 4, 3): 0.999,
    ("CLaR", 4, 3): 0.999,
    ("MLE", 4, 3): 0.999,
    ("MLER", 4, 3): 0.999,
    ("MRCE", 4, 3): 0.999,
    ("MRCER", 4, 3): 0.99,
    ("SGCL", 4, 3): 0.9999,
}

# Value passed as `heur_stop` to wrap_solver for each solver name:
# MTL and CLaR run without it, every other solver with it.
dict_heur_stop = dict.fromkeys(["MTL", "CLaR"], False)
dict_heur_stop.update(
    dict.fromkeys(["MLE", "MLER", "MRCE", "MRCER", "SGCL"], True))


# Solver settings forwarded to wrap_solver.
n_iter = 1000
tol = 10 ** -4
# Solvers to run, in execution order.
list_pb_name = [
    "CLaR", "SGCL", "MLER", "MLE", "MRCER", "MTL"]

# Experiment grid: event ids 1 to 4, and epoch decimation factors.
list_event_id = list(range(1, 5))
list_decim = [1, 2]

if __name__ == '__main__':
    # Fit every solver of `list_pb_name` on every (event, epoch decimation)
    # pair and save, per pair, three .npy files holding dicts keyed by solver
    # name: rescaled estimates (*_dns), supports (*_supp) and solver run
    # times in seconds (*_time).
    for event_id in list_event_id:
        is_auditory = event_id in (1, 2)
        # events 1/2 use a 0.3 s window, events 3/4 a 0.11 s one
        tmax = 0.3 if is_auditory else 0.11
        # number of active sources every solver must recover (checked below)
        n_sources_expected = 2 if is_auditory else 1
        # The loaded data depend on event_id only, not on decim, so load them
        # once per event rather than once per (decim, event) pair.
        X, all_epochs_full, src_wgth, fwd, info, times = get_data(
            resolution=resolution, snty_chk=snty_chk, eeg=eeg, meg=meg,
            event_id=event_id, tmax=tmax)
        for decim in list_decim:
            # keep one epoch out of `decim` along the first axis
            all_epochs = all_epochs_full[::decim, :, :]
            Y = all_epochs.mean(axis=0)

            # Observations handed to each solver: the epoch average Y, or the
            # full stack of (decimated) epochs.
            dict_data = {}
            dict_data["MTL"] = Y
            dict_data["MLE"] = Y
            dict_data["MLER"] = all_epochs
            dict_data["SGCL"] = Y
            dict_data["CLaR"] = all_epochs
            dict_data["MRCE"] = Y
            dict_data["MRCER"] = all_epochs

            dict_B_dns = {}
            dict_supp = {}
            dict_time = {}
            for pb_name in list_pb_name:
                obs = dict_data[pb_name]
                heur_stop = dict_heur_stop[pb_name]
                if event_id == 4:
                    # event 4 uses its dedicated key 3 whatever the decim
                    p_alpha = dict_p_alpha[pb_name, event_id, 3]
                    # only reachable if list_decim is extended to values >= 8
                    if decim >= 8 and pb_name == "MRCER":
                        p_alpha = 0.999
                else:
                    p_alpha = dict_p_alpha[pb_name, event_id, decim]

                # Time the solver call only, so the saved timings are not
                # polluted by unrelated bookkeeping.
                t_start = time.perf_counter()
                B_dns, supp = wrap_solver(
                    X, obs, p_alpha, pb_name=pb_name, tol=tol,
                    heur_stop=heur_stop, n_iter=n_iter, alpha_Sigma_inv=0.01)
                t_solve = time.perf_counter() - t_start

                # rescale each active row by the weight of its source
                B_dns *= src_wgth[supp][:, np.newaxis]
                dict_B_dns[pb_name] = B_dns
                dict_supp[pb_name] = supp
                dict_time[pb_name] = t_solve

                # Explicit check (not `assert`) so it still runs under -O.
                n_active = supp.sum()
                if n_active != n_sources_expected:
                    raise RuntimeError(
                        "%s selected %i sources for event_id=%i, decim=%i; "
                        "expected %i" % (pb_name, n_active, event_id, decim,
                                         n_sources_expected))

            string = "event_id_%i_decim_%i_dns.npy" % (event_id, decim)
            np.save(string, dict_B_dns)
            string = "event_id_%i_decim_%i_supp.npy" % (event_id, decim)
            np.save(string, dict_supp)
            string = "event_id_%i_decim_%i_time.npy" % (event_id, decim)
            np.save(string, dict_time)
else:
    # On import, expose data loaded without event_id/tmax, i.e. with
    # get_data's own defaults.
    # NOTE(review): this rebinds the tmax=0.3 values loaded through the
    # joblib cache at the top of the file.
    X, all_epochs, src_wgth, fwd, info, times = get_data(
        resolution=resolution, snty_chk=snty_chk, eeg=eeg, meg=meg)
