"""
This script launches a script in order to make AUC ROC curves
"""

import numpy as np
import time
import os
import socket
import seaborn as sns
import scipy.io as sio
import pandas as pd

from data.semi_real import get_semi_real_data
from expes.utils import (
    get_path_expe, get_rates, configure_plt, get_roc_auc_score,
    recover_B_from_mask_and_dense_B, get_auc)
from expes.utils import check_and_create_dirs, get_precision_recall
from itertools import product
from expes.expe3.params_plots import (
    list_pb_name, labels, dict_color, dict_label_plot, dict_markers)

# On hosts whose name starts with 'drago' (presumably headless servers), select
# the non-interactive Agg backend before pyplot is imported.
if socket.gethostname().startswith('drago'):
    import matplotlib
    matplotlib.use('Agg')
import matplotlib.pyplot as plt


if __name__ == '__main__':
    import argparse
    import importlib

    # Select the parameter module expes.expe4.params_<expe>.
    parser = argparse.ArgumentParser('Main script for maximal sparsity')
    parser.add_argument('--expe', type=str, default='expe4B',
                        help='Choose the parameters for the experiment.')
    args = parser.parse_args()
    expe = importlib.import_module("expes.expe4.params_{}".format(args.expe))

    # Experiment grid.  NOTE: list_pb_name shadows the one imported from
    # expes.expe3.params_plots.
    list_n_epochs = expe.list_n_epochs
    list_pb_name = expe.list_pb_name
    list_whiten = expe.list_whiten
    list_amplitudes = expe.list_amplitudes

    # Where results are stored.  The functions and top-level statements
    # below read these globals (name_expe, path_expe, name_dir_raw_res,
    # list_pb_name, list_amplitudes), so this file only works when it is
    # run as a script.
    name_expe = expe.name_expe
    name_dir_raw_res = "raw_results_{}".format(args.expe)
    path_expe = "sgcl/expes/%s/" % name_expe
    check_and_create_dirs(
        name_expe=name_expe, name_dir_raw_res=name_dir_raw_res)
    n_repet_roc_curves = expe.n_repet_roc_curves
    n_points_roc_auc = expe.n_points_roc_auc
    list_n_dipoles = expe.list_n_dipoles
    n_times = expe.n_times


fig_dir = "../../../latex/NeurIPS2019/prebuiltimages/"

dict_markevery = {}
dict_markevery["CLaR"] = 2
dict_markevery["SGCL"] = 4
dict_markevery["MLE"] = 9
dict_markevery["MLER"] = 9
dict_markevery["MRCE"] = 7
dict_markevery["MRCER"] = 7
dict_markevery["MTL"] = 10
dict_markevery["MTLME"] = 10

# Apply the project's plotting configuration and close any open figures.
configure_plt()
plt.close('all')

# Figure size passed to plt.figure(figsize=...) in plot_roc_curves (inches).
fig_height, fig_width = 4, 8.5


def plot_roc_curves(n_dipoles, n_epochs, amplitude, save_fig=False, seed=0):
    """Plot the ROC curves of every method in ``list_pb_name`` on one figure.

    Result files are located through the module-level ``name_expe``,
    ``path_expe`` and ``name_dir_raw_res`` (set in the ``__main__`` block).
    The figure is shown in blocking mode at the end.

    Parameters
    ----------
    n_dipoles : int
        Number of dipoles; selects the result files and is shown as
        ``||B*||_{2,0}`` in the title.
    n_epochs : int
        Number of epochs; selects the result files and is shown as ``r``
        in the title.
    amplitude : float
        Source amplitude; selects the result files and is shown in the title.
    save_fig : bool
        If True, save the figure as a PDF in ``fig_dir``.
    seed : int
        Seed identifying the simulated data set.
    """
    whiten = False
    params_title = (n_dipoles, n_epochs, amplitude)
    plt.close('all')
    plt.figure(figsize=(fig_width, fig_height))

    path_B_star = ("%s/B_star_n_dipoles_%i_seed_%i_ampli_%.2f.npy" %
                   (name_dir_raw_res, n_dipoles, seed, amplitude))
    for pb_name in list_pb_name:
        params = (pb_name, n_dipoles, n_epochs, whiten, seed, amplitude)

        path_dense_Bs = get_path_expe(
            name_expe, path_expe, name_dir_raw_res, params, extension='npy',
            obj="dense_Bs")
        path_masks_Bs = get_path_expe(
            name_expe, path_expe, name_dir_raw_res, params, extension='npy',
            obj="masks_Bs")
        # A single call is enough; the AUC value itself is not used here.
        # The third output is plotted on the axis labelled "FPR".
        _, list_tpr, list_fpr = get_auc(
            path_B_star, path_dense_Bs, path_masks_Bs)
        plt.plot(list_fpr, list_tpr, label=dict_label_plot[pb_name],
                 color=dict_color[pb_name], marker=dict_markers[pb_name],
                 markevery=dict_markevery[pb_name])

    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.xlim([-0.05, 0.5])
    plt.ylim([-0.05, 1.05])
    plt.legend(loc="lower right")

    str_title = r"$||\mathrm{B}^*||_{2, 0} = %i, r =  %i, ampli=%.2f {10}^{-9}$" % params_title
    if save_fig:
        name_fig = (
            "semi_real_roc_curves_n_dipoles_%i_n_epochs_%i_ampli_%0.2f"
            % params_title).replace(".", "_")
        plt.savefig(fig_dir + name_fig + ".pdf")
    # NOTE: the title and grid are added after saving, so the saved PDF
    # contains neither; only the displayed figure has them.
    plt.title(str_title)
    plt.grid()
    plt.show(block=True)


def plot_roc_curves_arr(
        ax, n_dipoles, n_epochs, amplitude, seed=0,
        save_fig=False, reverse=False, whiten=False):
    """Draw the ROC curve of every method in ``list_pb_name`` on ``ax``.

    Result files are located through the module-level ``name_expe``,
    ``path_expe`` and ``name_dir_raw_res`` (set in the ``__main__`` block).

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    n_dipoles : int
        Number of dipoles; selects the result files.
    n_epochs : int
        Number of epochs; selects the result files.
    amplitude : float
        Source amplitude; selects the result files and is written on the
        axes as "amp = <amplitude>nAm".
    seed : int
        Seed identifying the simulated data set.
    save_fig : bool
        Unused; kept for backward compatibility.
    reverse : bool
        If True, draw the methods in reverse order. The shared
        ``list_pb_name`` is not modified.
    whiten : bool
        Whitening flag used to select the result files (the script only
        uses False).

    Returns
    -------
    list
        One entry per method: the list of lines returned by ``ax.plot``.
    """
    # Path to B*; it depends on the data set only, not on the method.
    path_B_star = ("%s/B_star_n_dipoles_%i_seed_%i_ampli_%.2f.npy" %
                   (name_dir_raw_res, n_dipoles, seed, amplitude))
    # Iterate a reversed view instead of reversing the global list in place,
    # which would silently change the order for every later call.
    pb_names = reversed(list_pb_name) if reverse else list_pb_name

    lines = []
    for pb_name in pb_names:
        params = (pb_name, n_dipoles, n_epochs, whiten, seed, amplitude)
        path_dense_Bs = get_path_expe(
            name_expe, path_expe, name_dir_raw_res, params, extension='npy',
            obj="dense_Bs")
        path_masks_Bs = get_path_expe(
            name_expe, path_expe, name_dir_raw_res, params, extension='npy',
            obj="masks_Bs")
        # The AUC value is not used; the third output goes on the FPR axis.
        _, list_tpr, list_fpr = get_auc(
            path_B_star, path_dense_Bs, path_masks_Bs)
        lines.append(ax.plot(
            list_fpr, list_tpr, label=dict_label_plot[pb_name],
            color=dict_color[pb_name], marker=dict_markers[pb_name],
            markevery=dict_markevery[pb_name], markersize=6))

    # NOTE(review): %i truncates a non-integer amplitude (e.g. 2.5 -> "2").
    str_title = r"amp = %inAm" % amplitude

    # Hand-tuned text positions for the amplitudes used in the paper figure.
    fontsize = 20
    if amplitude == 1:
        ax.text(-0.015, 0.9, str_title, fontsize=fontsize)
    elif amplitude == 10:
        ax.text(0.0635, -0.025, str_title, fontsize=fontsize)
    else:
        ax.text(0.1, -0.025, str_title, fontsize=fontsize)

    return lines


# Figure with one ROC panel per amplitude; all panels share both axes.
save_fig = True
save_fig_slides = False
# fig_dir_slides was used below without being defined, so enabling
# save_fig_slides raised NameError.
# TODO: point this at the slides image directory.
fig_dir_slides = fig_dir


plt.close('all')
fig, axarr = plt.subplots(1, 3, sharex=True, sharey=True, figsize=[14, 4])

whiten = False
n_epochs = 20
n_dipoles = 2
# Fancy indexing: assumes list_amplitudes is a NumPy array with >= 5 entries.
list_amplitudes_plot = list_amplitudes[[0, 2, 4]]
for i, amplitude in enumerate(list_amplitudes_plot):
    # `lines` from the last iteration is reused for the legend figure.
    lines = plot_roc_curves_arr(
        axarr.flat[i], n_dipoles, n_epochs, amplitude, seed=0, save_fig=False)


axarr.flat[0].set_ylabel("TPR")
axarr.flat[0].set_xlabel("FPR")
axarr.flat[1].set_xlabel("FPR")
axarr.flat[2].set_xlabel("FPR")
# Axes are shared, so limits set on the current axes apply to every panel.
plt.xlim([-0.02, .2])
plt.ylim([-0.05, 1.05])
fig.tight_layout()
if save_fig:
    fig.savefig(fig_dir + "semi_real_influ_amp_grad.pdf", bbox_inches="tight")
if save_fig_slides:
    fig.savefig(fig_dir_slides + "semi_real_influ_amp_grad.pdf", bbox_inches="tight")
fig.show()


# Stand-alone legend figure.  It reuses the handles returned by the last
# plot_roc_curves_arr call in the loop above; calling it again just to get
# handles would draw every curve a second time on the last panel.
fig2 = plt.figure(figsize=[18, 4])
# NOTE(review): legend text comes from `labels` (expes.expe3.params_plots)
# while handle order follows list_pb_name; confirm the two agree.
fig2.legend([line[0] for line in lines], labels,
            ncol=6, loc='upper center', fontsize='small')
fig2.tight_layout()
if save_fig:
    fig2.savefig(
        fig_dir + "semi_real_influ_amp_grad_legend.pdf",
        bbox_inches="tight")
if save_fig_slides:
    fig2.savefig(
        fig_dir_slides + "semi_real_roc_curve_2_3_dipoles_legend.pdf",
        bbox_inches="tight")
fig2.show()
