"""
This script launches a script in order to make AUC ROC curves
"""

import numpy as np
import time
import os
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.io as sio
import pandas as pd


from itertools import product
from expes.utils import check_and_create_dirs
from expes.utils import (
    get_path_expe, get_rates, configure_plt, get_roc_auc_score,
    recover_B_from_mask_and_dense_B, get_auc)
from expes.expe3.params_plots import (
    list_pb_name, labels, dict_color, dict_label_plot, dict_markers)


# NOTE(review): every name bound in this block (list_*, name_expe,
# name_dir_raw_res, path_expe, ...) is read by the unguarded module-level
# code further down, so the file only works when run as a script; importing
# it raises NameError.
if __name__ == '__main__':
    # Imported here because they are only needed for command-line parsing.
    import argparse
    import importlib
    parser = argparse.ArgumentParser('Main script for figs on synthetic data')
    parser.add_argument('--expe', type=str, default='expe3G',
                        help='Choose the parameters for the experiement.')
    args = parser.parse_args()
    # The parameters live in the module expes.expe3.params_<expe>.
    expe = importlib.import_module("expes.expe3.params_{}".format(args.expe))

    # Parameter grids of the selected experiment.
    list_rho_X = expe.list_rho_X
    list_rho_noise = expe.list_rho_noise
    list_SNR = expe.list_SNR
    list_n_epochs = expe.list_n_epochs

    # Parameters used to locate the stored results.
    name_expe = expe.name_expe
    # One raw-results directory per experiment name given on the command
    # line (the former "./raw_results" assignment was overwritten right away
    # and has been dropped).
    name_dir_raw_res = "raw_results" + "_" + args.expe
    path_expe = "./"
    check_and_create_dirs(
        name_expe=name_expe, name_dir_raw_res=name_dir_raw_res)
    n_repet_roc_curves = expe.n_repet_roc_curves
    n_points_roc_auc = expe.n_points_roc_auc


# Output switches for the figures produced below.
save_fig = True  # write the PDFs into fig_dir
save_fig_slides = False  # additionally write the PDFs into fig_dir_slides
fig_dir = "../../../latex/NeurIPS2019/prebuiltimages/"
# fig_dir_slides is read by the save_fig_slides branches further down but had
# no definition anywhere in this file, so enabling save_fig_slides raised a
# NameError.
# NOTE(review): falling back to fig_dir is a placeholder -- point this at the
# real slides directory before turning save_fig_slides on.
fig_dir_slides = fig_dir

configure_plt()
plt.close('all')


def plot_roc_curves_arr(
        ax, rho_X, rho_noise, n_epochs, SNR, sample_number=0, reverse=False):
    """Draw one ROC curve per problem name on ``ax`` and annotate the SNR.

    For every entry of ``list_pb_name`` the paths of the stored results are
    built with ``get_path_expe``, converted to rates by ``get_auc`` and
    plotted on ``ax`` with the colour, marker and label registered for that
    problem. A problem for which any of these steps fails is skipped (best
    effort) after a short message is printed.

    Relies on the module-level names ``path_expe``, ``name_dir_raw_res`` and
    ``name_expe``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes receiving the curves and the SNR text.
    rho_X, rho_noise, n_epochs, SNR
        Experiment parameters; with the problem name and ``sample_number``
        they form the ``params`` tuple handed to ``get_path_expe``. ``SNR``
        is also written on the axes.
    sample_number : int, default 0
        Sample index; it also selects the ``B_star_<sample_number>.npy``
        file.
    reverse : bool, default False
        If True, the problems are visited in reverse order. The shared
        ``list_pb_name`` itself is left untouched.

    Returns
    -------
    list
        The return value of ``ax.plot`` for each problem that was plotted,
        in plotting order.
    """
    # Iterate over a reversed view rather than calling list.reverse(): the
    # in-place call flipped the imported list for every later call as well.
    pb_names = reversed(list_pb_name) if reverse else list_pb_name
    # The path of the B_star file does not depend on the problem name.
    path = path_expe + name_dir_raw_res + \
        ("/B_star_%i.npy" % sample_number)
    markevery = 20

    lines = []
    for pb_name in pb_names:
        try:
            color = dict_color[pb_name]
            marker = dict_markers[pb_name]
            label = dict_label_plot[pb_name]
            params = (pb_name, rho_X, rho_noise, SNR, n_epochs, sample_number)
            path_dense_Bs = get_path_expe(
                name_expe, path_expe, name_dir_raw_res, params,
                extension='npy', obj="dense_Bs")
            path_masks_Bs = get_path_expe(
                name_expe, path_expe, name_dir_raw_res, params,
                extension='npy', obj="masks_Bs")
            # NOTE(review): the third value is named "false_neg" here but is
            # plotted on the x axis, which the calling code labels "FPR" --
            # confirm what get_auc actually returns.
            _, list_true_pos_rate, list_false_neg_rate = get_auc(
                path, path_dense_Bs, path_masks_Bs)
            lines.append(ax.plot(list_false_neg_rate, list_true_pos_rate,
                                 label=label,
                                 color=color, marker=marker,
                                 markevery=markevery,
                                 markersize=8))
        except Exception as exc:
            # Best effort: skip this problem, but say which one failed and
            # why. Catching Exception (not a bare except) lets Ctrl-C and
            # SystemExit through.
            print("no (%s: %r)" % (pb_name, exc))

    # The position of the SNR text depends on the SNR value.
    if SNR < 0.04:
        text_x, text_y = 0.22, 0.01
    elif SNR < 0.06:
        text_x, text_y = 0.25, 0.01
    else:
        text_x, text_y = 0.22, 0.12
    str_title = r"$\mathrm{SNR} = %0.2f$" % SNR
    # Assumes rcParams['xtick.labelsize'] is numeric (presumably set by
    # configure_plt) -- TODO confirm.
    ax.text(text_x, text_y, str_title,
            fontsize=plt.rcParams['xtick.labelsize'] - 4)

    return lines


# Every panel uses the last value of each of these grids; only the SNR
# changes from one panel to the next.
rho_X, rho_noise, n_epochs = (
    list_rho_X[-1], list_rho_noise[-1], list_n_epochs[-1])

plt.close('all')
fig, axarr = plt.subplots(
    1, 3, sharex=True, sharey=True, figsize=[14, 4])

# One panel per selected SNR.
# NOTE(review): indexing with a list of positions presumably requires
# list_SNR to be a numpy array -- confirm.
# `i`, `SNR` and `lines` stay bound after the loop; the legend code further
# down reads `i` and `SNR` again.
for i, SNR in enumerate(list_SNR[[2, 3, 4]]):
    lines = plot_roc_curves_arr(
        axarr.flat[i], rho_X, rho_noise, n_epochs, SNR, sample_number=0)

# Same x label on each of the three panels, y label on the first one only.
for panel in axarr.flat:
    panel.set_xlabel("FPR")
axarr.flat[0].set_ylabel("TPR")
fig.tight_layout()
plt.xlim([-0.05, .4])
plt.ylim([-0.05, 1.05])

# Kept as separate `if`s so that fig_dir_slides is only evaluated when the
# slides output is actually requested.
if save_fig:
    fig.savefig(fig_dir + "roc_curves_influ_SNR.pdf", bbox_inches="tight")
if save_fig_slides:
    fig.savefig(
        fig_dir_slides + "roc_curves_influ_SNR.pdf", bbox_inches="tight")
fig.show()

# Ugly hack (MM): a legend placed on the ROC figure would overwrite the
# plot, so the legend is drawn alone on a second figure.
# labels = ["CLaR", "SGCL", "MTL", "MTLR"]

# Plot once more on the last panel used above (`i` and `SNR` are the values
# left by the panel loop), this time in reverse order, to obtain the line
# handles for the legend.
lines = plot_roc_curves_arr(
    axarr.flat[i], rho_X, rho_noise, n_epochs, SNR,
    sample_number=0, reverse=True)

fig2 = plt.figure(figsize=[14, 4])
# Each entry of `lines` is what ax.plot returned; its first element is used
# as the legend handle.
legend_handles = [plotted[0] for plotted in lines]
fig2.legend(
    legend_handles, labels,
    ncol=6, loc='upper center', fontsize='small')  # , frameon=False)
fig2.tight_layout()

# Separate `if`s so that fig_dir_slides is only evaluated when needed.
if save_fig:
    fig2.savefig(
        fig_dir + "roc_curves_influ_SNR_legend.pdf", bbox_inches="tight")
if save_fig_slides:
    fig2.savefig(
        fig_dir_slides + "roc_curves_influ_SNR_legend.pdf",
        bbox_inches="tight")
fig2.show()
