"""
This script launches a script in order to make AUC ROC curves
"""

import numpy as np
import time
import os
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.io as sio
import pandas as pd


from itertools import product
from expes.utils import check_and_create_dirs
from expes.utils import (
    get_path_expe, get_rates, configure_plt, get_roc_auc_score,
    recover_B_from_mask_and_dense_B, get_auc)
from expes.expe3.params_plots import (
    list_pb_name, labels, dict_color, dict_label_plot, dict_markers)


if __name__ == '__main__':
    import argparse
    import importlib

    # NOTE: the module-level plotting code further down reads the names bound
    # here (list_rho_X, path_expe, name_dir_raw_res, ...), so this file only
    # works when run as a script, not when imported.

    # The first positional argument of ArgumentParser is ``prog`` (the name
    # shown in the usage line), so the text is passed as ``description``.
    parser = argparse.ArgumentParser(
        description='Main script for figs on synthetic data')
    parser.add_argument('--expe', type=str, default='expe3E',
                        help='Choose the parameters for the experiment.')
    args = parser.parse_args()
    # Parameter grids are read from the module expes.expe3.params_<expe>.
    expe = importlib.import_module("expes.expe3.params_{}".format(args.expe))

    list_rho_X = expe.list_rho_X
    list_rho_noise = expe.list_rho_noise
    list_SNR = expe.list_SNR
    list_n_epochs = expe.list_n_epochs
    list_p_alpha = expe.list_p_alpha

    # parameters to store results
    name_expe = expe.name_expe
    # Raw results (e.g. B_star_<k>.npy) are looked up under
    # path_expe + name_dir_raw_res.
    name_dir_raw_res = "raw_results_" + args.expe
    path_expe = "./"
    check_and_create_dirs(
        name_expe=name_expe, name_dir_raw_res=name_dir_raw_res)
    n_repet_roc_curves = expe.n_repet_roc_curves
    n_points_roc_auc = expe.n_points_roc_auc

# Whether the figures produced below are written to disk as PDF.
save_fig = True

# Prefix prepended to every saved figure path ("" = current directory).
fig_dir = ""

# On each ROC curve, draw a marker every `dict_markevery[pb_name]` points,
# keyed by solver name.
dict_markevery = {
    "CLaR": 10,
    "SGCL": 15,
    "MTL": 15,
    "MTLME": 10,
    "MRCER": 8,
    "MLER": 7,
    "MLE": 15,
}


# Apply the project's matplotlib configuration and start with no open figures.
configure_plt()
plt.close('all')

# Size, in inches, of the single-panel figure made by plot_roc_curves.
fig_width, fig_height = 8, 4


def plot_roc_curves(
        rho_X, rho_noise, n_epochs, SNR, sample_number=0, save_fig=False):
    """Plot, on a new figure, one ROC curve per solver in ``list_pb_name``.

    Parameters
    ----------
    rho_X, rho_noise, n_epochs, SNR :
        Experiment parameters that select which raw results are loaded; they
        also appear in the title and in the saved file name.
    sample_number : int
        Repetition index; selects ``B_star_<sample_number>.npy``.
    save_fig : bool
        If True, the figure is saved as PDF under ``fig_dir``. Saving happens
        before the title is added, so the PDF carries no title.
    """
    params_title = (n_epochs, rho_X, rho_noise, SNR)

    plt.figure(figsize=(fig_width, fig_height))
    for pb_name in list_pb_name:
        color = dict_color[pb_name]
        marker = dict_markers[pb_name]
        label = dict_label_plot[pb_name]
        path_B_star = path_expe + name_dir_raw_res + (
            "/B_star_%i.npy" % sample_number)
        params = (pb_name, rho_X, rho_noise, SNR, n_epochs, sample_number)
        path_dense_Bs, path_masks_Bs = [
            get_path_expe(name_expe, path_expe, name_dir_raw_res, params,
                          extension='npy', obj=obj)
            for obj in ("dense_Bs", "masks_Bs")]
        # The third output of get_auc is used as the x (false positive rate)
        # coordinate of the ROC curve.
        _, true_pos_rate, false_pos_rate = get_auc(
            path_B_star, path_dense_Bs, path_masks_Bs)
        plt.plot(false_pos_rate, true_pos_rate,
                 label=label, color=color, marker=marker)

    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.xlim([-0.05, .65])
    plt.ylim([-0.05, 1.05])
    plt.legend(loc="lower right")

    title = (
        r"$r = $ %i, $ \rho_X$ = %0.1f, $ \rho_{S}$ = %0.1f, SNR = %0.2f"
        % params_title)
    plt.grid()
    plt.tight_layout()
    if save_fig:
        stem = ("roc_curves_n_epochs_%i_rho_X%.02f_rho_S_%0.2f_SNR_%s"
                % params_title)
        # Dots are replaced so the float values do not look like extensions.
        plt.savefig(fig_dir + stem.replace(".", "_") + ".pdf")
    plt.title(title)
    plt.show(block=False)

# Default setting: the last value of each parameter grid.
rho_X, SNR, rho_noise = list_rho_X[-1], list_SNR[-1], list_rho_noise[-1]


def plot_roc_curves_arr(
        ax, rho_X, rho_noise, n_epochs, SNR, sample_number=0, save_fig=False,
        reverse=False):
    """Draw one ROC curve per solver on ``ax`` and annotate it with ``r``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to draw on.
    rho_X, rho_noise, n_epochs, SNR :
        Experiment parameters that select which raw results are loaded.
    sample_number : int
        Repetition index; selects ``B_star_<sample_number>.npy``.
    save_fig : bool
        Unused; kept for signature compatibility with ``plot_roc_curves``.
    reverse : bool
        If True, solvers are drawn in the reverse order of ``list_pb_name``.
        The shared ``list_pb_name`` itself is left untouched.

    Returns
    -------
    lines : list
        For each successfully plotted solver, the list returned by
        ``ax.plot``. Solvers that fail (missing settings or results) are
        skipped with a printed message, so ``lines`` may be shorter than
        ``list_pb_name``.
    """
    # Iterate over a reversed copy: calling list_pb_name.reverse() would
    # mutate the imported list and flip its order on every reverse=True call.
    pb_names = list(reversed(list_pb_name)) if reverse else list_pb_name

    lines = []
    for pb_name in pb_names:
        try:
            color = dict_color[pb_name]
            marker = dict_markers[pb_name]
            markevery = dict_markevery[pb_name]
            label = dict_label_plot[pb_name]
            path_B_star = path_expe + name_dir_raw_res + (
                "/B_star_%i.npy" % sample_number)
            params = (pb_name, rho_X, rho_noise, SNR, n_epochs, sample_number)
            path_dense_Bs = get_path_expe(
                name_expe, path_expe, name_dir_raw_res,
                params, extension='npy', obj="dense_Bs")
            path_masks_Bs = get_path_expe(
                name_expe, path_expe, name_dir_raw_res,
                params, extension='npy', obj="masks_Bs")
            # The third output of get_auc is used as the x (false positive
            # rate) coordinate of the ROC curve.
            _, true_pos_rate, false_pos_rate = get_auc(
                path_B_star, path_dense_Bs, path_masks_Bs)
            lines.append(ax.plot(
                false_pos_rate, true_pos_rate, label=label,
                color=color, marker=marker, markevery=markevery, markersize=8))
        except Exception as err:
            # Best effort: skip this solver so the remaining curves are still
            # drawn, but report why (a bare except also hid Ctrl-C).
            print("Skipping %s (rho_X=%s, rho_noise=%s, SNR=%s, n_epochs=%s): "
                  "%r" % (pb_name, rho_X, rho_noise, SNR, n_epochs, err))

    # Panel annotation; its position is hand-tuned per value of r.
    str_title = r"$r = %i$" % n_epochs
    if n_epochs == 100:
        ax.text(0.25, 0.15, str_title)
    elif n_epochs == 1:
        ax.text(0.30, -0.03, str_title)
    else:
        ax.text(0.275, -0.03, str_title)

    return lines


rho_X = list_rho_X[-1]
SNR = list_SNR[-1]

# One panel per number of epochs r (overrides the experiment's grid).
list_n_epochs = [1, 20, 100]

plt.close('all')
fig, axarr = plt.subplots(
    1, 3, sharex=True, sharey=True, figsize=[14, 4],)

for i, n_epochs in enumerate(list_n_epochs):
    plot_roc_curves_arr(axarr.flat[i], rho_X, rho_noise,
                        n_epochs, SNR, sample_number=0, save_fig=False)

for ax in axarr.flat:
    ax.set_xlabel("FPR")
axarr.flat[0].set_ylabel("TPR")
# The axes are shared, so limits set on one panel apply to all of them.
# Set them before tight_layout so the layout sees the final tick labels.
axarr.flat[0].set_xlim([-0.05, .4])
axarr.flat[0].set_ylim([-0.05, 1.05])
fig.tight_layout()
if save_fig:
    fig.savefig(fig_dir + "roc_curves_influ_n_epochs.pdf", bbox_inches="tight")
fig.show()

# Draw the curves once more, in reverse solver order, on a scratch axis that
# is closed right away. The returned line handles only feed the stand-alone
# legend figure below. Drawing them on a visible panel would overplot that
# panel's curves and duplicate its "r = ..." annotation.
fig_handles, ax_handles = plt.subplots()
lines = plot_roc_curves_arr(
    ax_handles, rho_X, rho_noise, n_epochs, SNR,
    sample_number=0, save_fig=False, reverse=True)
plt.close(fig_handles)

# Legend-only figure, saved to its own PDF.
fig2 = plt.figure(figsize=[14, 4])
# ax.plot returns a list of lines; its first element is the legend handle.
fig2.legend([plotted[0] for plotted in lines], labels,
            ncol=6, loc='upper center', fontsize='small')
fig2.tight_layout()
if save_fig:
    fig2.savefig(
        fig_dir + "roc_curves_influ_n_epochs_legend.pdf", bbox_inches="tight")
fig2.show()
