"""The goal of this script is to compare the time taken by each algorithm.
"""
import time
import numpy as np
import mne
import mayavi.mlab
from joblib import Memory

from mne.viz import plot_sparse_source_estimates
from mne.inverse_sparse.mxne_inverse import _make_sparse_stc

from data.real import get_data
from sgcl.solvers import wrap_solver
from sgcl.utils import get_alpha_max, get_sigma_min
from expes.utils import configure_plt

# joblib cache rooted at the current directory.
# NOTE(review): ``mem`` is not referenced anywhere in this file's visible
# code -- confirm it is still needed.
mem = Memory(location=".")

# Data-loading configuration, forwarded to every ``get_data`` call below.
snty_chk = False
eeg = False
# eeg = True
# meg = True
# presumably restricts MEG channels to magnetometers -- confirm in
# data.real.get_data
meg = "mag"
# meg = False
# resolution = 3
resolution = 6
# The real, preprocessed data are loaded inside the
# ``if __name__ == '__main__'`` guard (once per event_id) or in its ``else``
# branch. A module-level ``get_data`` call here was overwritten on both
# paths before any use, so it only doubled the (expensive) loading time.

# Regularisation level per (solver, event_id), passed as ``p_alpha`` to
# ``wrap_solver``. The commented-out plotting code at the bottom of this file
# labels it "lambda / lambdamax", so it is presumably that ratio -- confirm.
dict_p_alpha = {
    ("MTL", 1): 0.83,
    ("CLaR", 1): 0.9,
    ("MLE", 1): 0.85,
    ("MLER", 1): 0.9,
    ("MRCE", 1): 0.9,
    ("MRCER", 1): 0.9,
    ("SGCL", 1): 0.99,
    # Values previously used for event_id 2 (needed if event 2 is added to
    # ``list_event_id``):
    # ("MTL", 2): 0.8,
    # ("CLaR", 2): 0.8125,
    # ("MLE", 2): 0.95,
    # ("MLER", 2): 0.99,
    # ("MRCE", 2): 0.9,
    # ("MRCER", 2): 0.96,
    # ("SGCL", 2): 0.99,
}

# ``heur_stop`` settings to benchmark for each solver: every value in a
# solver's list produces one separately timed run.
dict_heur_stop = {
    "SGCL": [True],
    "MTL": [True, False],
    "CLaR": [True, False],
    "MLE": [True],
    "MLER": [True],
    "MRCE": [True],
    "MRCER": [True],
}


# Solver settings shared by every benchmarked run.
tol = 10 ** -3  # tolerance passed to ``wrap_solver``
n_iter = 1000  # ``n_iter`` passed to ``wrap_solver``
# all_epochs = all_epochs[0:1, :, :]

# Solvers to benchmark, in execution order.
# Other selections used before: ["MTL"], ["MLE"], ["MRCER"],
# ["CLaR", "SGCL", "MLER", "MLE", "MTL"].
list_pb_name = ["CLaR", "SGCL", "MLER", "MLE", "MRCER", "MTL"]

# Events to process. Adding event 2 (``[1, 2]``) requires matching
# ``(pb_name, 2)`` entries in ``dict_p_alpha``.
list_event_id = [1]
# list_event_id = [1, 2]
if __name__ == '__main__':
    # Benchmark every solver in ``list_pb_name`` for every event, timing each
    # (solver, heur_stop) run separately, and save the timings to
    # ``event_id_<id>_time.npy``.
    for event_id in list_event_id:
        X, all_epochs, src_wgth, fwd, info, times = get_data(
            resolution=resolution, snty_chk=snty_chk, eeg=eeg, meg=meg,
            event_id=event_id, tmax=0.3)
        # Warm-up run so one-off compilation cost is not charged to the
        # timed runs below, for the time comparison to be fair.
        # NOTE(review): only the "CLaR" path is warmed up; if other solvers
        # compile separately, their first timed run may include compilation
        # time -- confirm against sgcl.solvers.wrap_solver.
        wrap_solver(
            X, all_epochs, 0.999, pb_name="CLaR", tol=tol,
            heur_stop=True)

        # Mean over axis 0 (assumed to index epochs -- confirm in get_data).
        Y = all_epochs.mean(axis=0)

        # Observations given to each solver: the averaged data ``Y`` or the
        # full ``all_epochs`` array.
        dict_data = {
            "MTL": Y,
            "MLE": Y,
            "MLER": all_epochs,
            "SGCL": Y,
            "CLaR": all_epochs,
            "MRCE": Y,
            "MRCER": all_epochs,
            "NNCVX": all_epochs,
        }

        # Results keyed by (pb_name, heur_stop).
        dict_B_dns = {}
        dict_supp = {}
        dict_time = {}
        for pb_name in list_pb_name:
            obs = dict_data[pb_name]
            p_alpha = dict_p_alpha[pb_name, event_id]
            for heur_stop in dict_heur_stop[pb_name]:
                # Start the clock for *each* run. It used to be started once
                # per solver, so the second heur_stop setting (MTL, CLaR) was
                # also charged the duration of the first run.
                t_start = time.perf_counter()
                B_dns, supp = wrap_solver(
                    X, obs, p_alpha, pb_name=pb_name, tol=tol,
                    heur_stop=heur_stop, n_iter=n_iter,
                    alpha_Sigma_inv=0.001, S_freq=1)
                t_end = time.perf_counter() - t_start  # elapsed seconds
                # Scale each estimated row by the src_wgth entry of its
                # support index (presumably undoing the source weighting
                # applied by get_data -- confirm).
                B_dns *= src_wgth[supp][:, np.newaxis]
                dict_B_dns[pb_name, heur_stop] = B_dns
                dict_supp[pb_name, heur_stop] = supp
                dict_time[pb_name, heur_stop] = t_end
            # assert supp.sum() == 2
        # string = "event_id_%i_dns.npy" % event_id
        # np.save(string, dict_B_dns)
        # string = "event_id_%i_supp.npy" % event_id
        # np.save(string, dict_supp)
        # np.save stores the dict as a pickled 0-d object array; reload with
        # np.load(string, allow_pickle=True).item().
        string = "event_id_%i_time.npy" % event_id
        np.save(string, dict_time)
else:
    # When imported (presumably for interactive plotting), load the data at
    # module level.
    # NOTE(review): unlike the __main__ path, neither event_id nor tmax=0.3
    # is passed here -- confirm this difference is intended.
    X, all_epochs, src_wgth, fwd, info, times = get_data(
        resolution=resolution, snty_chk=snty_chk, eeg=eeg, meg=meg)
    # stc = _make_sparse_stc(
    #         B_dns, supp, fwd, tmin=times[0], tstep=1. / info['sfreq'])
    # plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1),
    #                         opacity=0.1)
    # # fig_name = "%s, lambda / lambdamax = %0.2f" % (pb_name, p_alpha)
    # # plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1),
    # #                              opacity=0.1, fig_name=fig_name)
    # # save_fname = "_" + pb_name + ".pdf"
    # save_fname = "_" + pb_name + ".png"
    # plot_blob(stc, save_fname=save_fname, fig_dir=fig_dir)
