import os
import getpass
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from matplotlib import rc

from numpy.linalg import norm
# from sklearn.metrics import roc_auc_score
from sklearn.metrics.ranking import roc_curve, auc


def roc_auc_score(y_true, y_score):
    """Area under the ROC curve, integrated with left-endpoint rectangles.

    The ROC points come from ``roc_curve(y_true, y_score)``. Each step
    along the false-positive-rate axis is weighted by the true-positive
    rate at the start of that step, so this is a rectangle rule, not a
    trapezoidal one.

    Parameters
    ----------
    y_true : array-like
        Labels forwarded unchanged to ``roc_curve``.
    y_score : array-like
        Scores forwarded unchanged to ``roc_curve``.

    Returns
    -------
    float
        Sum over ROC steps of (FPR increment) * (TPR at the left end).
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_score)
    step_widths = np.diff(false_pos_rate)
    left_heights = true_pos_rate[:-1]
    return np.dot(step_widths, left_heights)


def get_auc(path, path_dense_Bs, path_masks_Bs):
    """Compute AUC scores and ROC points for a saved collection of estimates.

    Parameters
    ----------
    path : str
        Path of the ``.npy`` file holding the ground truth ``B_star``.
    path_dense_Bs : str
        Path of a ``.npy`` file whose element 0 is a dict-like object; its
        values are the dense (support rows only) estimates.
    path_masks_Bs : str
        Path of a ``.npy`` file whose element 0 is a dict-like object; its
        values are the support masks matching ``path_dense_Bs``.

    Returns
    -------
    array_auc : ndarray
        One AUC value (from ``get_roc_auc_score``) per estimate, in storage
        order.
    list_true_pos_rate : ndarray
        Second value returned by ``get_rates`` for each estimate, reordered
        by increasing value of the third output.
    list_false_neg_rate : ndarray
        First value returned by ``get_rates`` for each estimate, sorted in
        increasing order. Despite the historical name this is the
        false-positive rate.

    Notes
    -----
    An empty collection yields three empty arrays.
    """
    B_star = np.load(path)
    # The two collections are Python objects wrapped in an object array, so
    # they are unpickled on load; recent NumPy refuses to do that unless
    # ``allow_pickle=True``. Unpickling can execute arbitrary code: only
    # point this function at files you produced yourself.
    dense_Bs = np.load(path_dense_Bs, allow_pickle=True).take(0)
    masks_Bs = np.load(path_masks_Bs, allow_pickle=True).take(0)

    true_pos_rates = []
    false_pos_rates = []
    array_auc = np.empty(len(masks_Bs.values()))

    pairs = zip(dense_Bs.values(), masks_Bs.values())
    for i, (dense_B, mask_B) in enumerate(pairs):
        B = recover_B_from_mask_and_dense_B(mask_B, dense_B)
        fp_rate, tp_rate = get_rates(B, B_star)

        true_pos_rates.append(tp_rate)
        false_pos_rates.append(fp_rate)
        # Stored directly: a local named ``roc_auc_score`` would shadow the
        # module-level function of the same name.
        array_auc[i] = get_roc_auc_score(B, B_star)

    # Sanity check only; skipped when there is nothing to take the max of.
    assert array_auc.size == 0 or np.max(array_auc) <= 1

    false_pos_rates = np.array(false_pos_rates)
    true_pos_rates = np.array(true_pos_rates)
    # Order the ROC points along the false-positive-rate axis.
    order = np.argsort(false_pos_rates)
    return array_auc, true_pos_rates[order], false_pos_rates[order]


def get_precision_recall(path, path_dense_Bs, path_masks_Bs):
    """Compute precision and recall for a saved collection of estimates.

    Parameters
    ----------
    path : str
        Path of the ``.npy`` file holding the ground truth ``B_star``.
    path_dense_Bs : str
        Path of a ``.npy`` file whose element 0 is a dict-like object; its
        values are the dense (support rows only) estimates.
    path_masks_Bs : str
        Path of a ``.npy`` file whose element 0 is a dict-like object; its
        values are the support masks matching ``path_dense_Bs``.

    Returns
    -------
    list_precision : list
        Precision of each estimate, in storage order.
    list_recall : list
        Recall of each estimate, in storage order.

    Notes
    -----
    ``get_prec_recall_one`` returns ``(precision, recall)``; each value is
    appended to the list bearing its own name so that the returned tuple
    really is ``(precisions, recalls)``.
    """
    B_star = np.load(path)
    # The two collections are Python objects wrapped in an object array, so
    # they are unpickled on load; recent NumPy refuses to do that unless
    # ``allow_pickle=True``. Unpickling can execute arbitrary code: only
    # point this function at files you produced yourself.
    dense_Bs = np.load(path_dense_Bs, allow_pickle=True).take(0)
    masks_Bs = np.load(path_masks_Bs, allow_pickle=True).take(0)

    list_precision = []
    list_recall = []

    for dense_B, mask_B in zip(dense_Bs.values(), masks_Bs.values()):
        B = recover_B_from_mask_and_dense_B(mask_B, dense_B)
        precision, recall = get_prec_recall_one(B, B_star)

        list_precision.append(precision)
        list_recall.append(recall)

    return list_precision, list_recall


def configure_plt():
    """Apply the plotting defaults used for the figures.

    Sets the matplotlib font family, updates ``plt.rcParams`` (label, tick
    and legend sizes, LaTeX text rendering, figure size) and then applies
    the seaborn palette, context and style. The seaborn calls come last,
    exactly as listed below.
    """
    font_settings = {
        'family': 'sans-serif',
        'sans-serif': ['Computer Modern Roman'],
    }
    rc('font', **font_settings)

    plt.rcParams.update({
        'axes.labelsize': 12,
        'font.size': 12,
        'legend.fontsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'text.usetex': True,
        'figure.figsize': (8, 6),
    })

    sns.set_palette("colorblind")
    sns.set_context("poster")
    sns.set_style("ticks")


def recover_B_from_mask_and_dense_B(mask_B, dense_B):
    """Scatter the rows of ``dense_B`` back to the positions flagged in ``mask_B``.

    Parameters
    ----------
    mask_B : ndarray
        One entry per row of the output; a truthy entry marks a row that
        receives the next unused row of ``dense_B``.
    dense_B : ndarray, 2-D
        Rows to place, consumed in order.

    Returns
    -------
    ndarray of shape (mask_B.shape[0], dense_B.shape[1])
        Zero everywhere except on the flagged rows.
    """
    B = np.zeros((mask_B.shape[0], dense_B.shape[1]))
    next_dense_row = 0
    for source_idx, is_active in enumerate(mask_B):
        if not is_active:
            continue
        B[source_idx, :] = dense_B[next_dense_row, :]
        next_dense_row += 1
    return B


def get_precision_from_array(B_hat, B_star):
    """Precision of every estimate stacked along the first axis of ``B_hat``.

    Parameters
    ----------
    B_hat : ndarray
        Stack of estimates; ``B_hat[i]`` is scored against ``B_star``.
    B_star : ndarray
        Ground truth passed to ``get_precision`` for each estimate.

    Returns
    -------
    ndarray of float, shape (B_hat.shape[0],)
        ``get_precision`` of each slice, in order.
    """
    n_repet = B_hat.shape[0]
    precisions = [get_precision(B_hat[k, :], B_star) for k in range(n_repet)]
    # Explicit float dtype: matches a zero-initialised float buffer even when
    # the list is empty or only holds integer zeros.
    return np.array(precisions, dtype=float)


def get_precision(B_hat, B_star):
    """Precision of the row support of ``B_hat`` with respect to ``B_star``.

    A row belongs to the support when the sum of the absolute values of its
    entries exceeds ``10 ** -10``. The two boolean supports are then handed
    to ``get_precision_from_supp``.
    """
    threshold = 10 ** -10
    row_mass_hat = np.abs(B_hat).sum(axis=1)
    row_mass_star = np.abs(B_star).sum(axis=1)
    return get_precision_from_supp(
        row_mass_hat > threshold, row_mass_star > threshold)


def get_precision_from_supp(supp_hat, supp_star):
    """Precision of an estimated support with respect to the true support.

    Parameters
    ----------
    supp_hat : array-like of bool
        Estimated support (True where a row is selected).
    supp_star : array-like of bool
        True support, same shape as ``supp_hat``.

    Returns
    -------
    float
        ``tp / (tp + fp)``, or ``0.0`` when nothing is selected in
        ``supp_hat`` (the ratio is undefined in that case).
    """
    tp = np.logical_and(supp_hat, supp_star).sum()
    # False positives: selected by the estimate but absent from the truth.
    fp = np.logical_and(supp_hat, np.logical_not(supp_star)).sum()
    n_selected = fp + tp
    if n_selected == 0:
        return 0.0
    return tp / n_selected


def get_roc_auc_scores(raw_B_hat, B_star):
    """AUC score of every estimate stacked along the first axis.

    Parameters
    ----------
    raw_B_hat : ndarray, 3-D
        Stack of estimates; ``raw_B_hat[i, :, :]`` is scored against
        ``B_star`` with ``get_roc_auc_score``.
    B_star : ndarray
        Ground truth shared by all estimates.

    Returns
    -------
    ndarray of float, shape (raw_B_hat.shape[0],)
        One score per estimate, in order.
    """
    n_estimates = raw_B_hat.shape[0]
    scores = [
        get_roc_auc_score(raw_B_hat[k, :, :], B_star)
        for k in range(n_estimates)
    ]
    return np.array(scores, dtype=float)


def get_roc_auc_score(B_hat, B_star):
    """ROC AUC of the row norms of ``B_hat`` as a detector of the true support.

    The binary labels are the rows of ``B_star`` whose norm exceeds
    ``10 ** -8``; the scores are the row norms of ``B_hat``. The ROC curve
    is built with ``roc_curve`` (positive label 1) and integrated with
    ``auc``.

    Returns
    -------
    float
        Value returned by ``auc(fpr, tpr)``.
    """
    true_support = norm(B_star, axis=1) > 10 ** -8
    row_norms_hat = norm(B_hat, axis=1)
    fpr, tpr, _thresholds = roc_curve(true_support, row_norms_hat, pos_label=1)
    return auc(fpr, tpr)
    # return np.dot(np.diff(fpr), tpr[:-1])


def get_rates_from_listes(raw_B_hat, B_star, sacred=False):
    """Apply ``get_rates`` to every estimate of a collection.

    Parameters
    ----------
    raw_B_hat : ndarray or sequence
        With ``sacred=False``, a 3-D array indexed as ``raw_B_hat[i, :, :]``.
        With ``sacred=True``, a sequence whose items are converted with
        ``np.array`` before use.
    B_star : ndarray
        Ground truth shared by all estimates.
    sacred : bool
        Selects how ``raw_B_hat`` is measured and indexed (see above).

    Returns
    -------
    false_neg_rates : ndarray
        First output of ``get_rates`` for each estimate.
    true_pos_rates : ndarray
        Second output of ``get_rates`` for each estimate.
    """
    n_points_roc = len(raw_B_hat) if sacred else raw_B_hat.shape[0]
    first_rates = np.empty(n_points_roc)
    second_rates = np.empty(n_points_roc)
    for k in range(n_points_roc):
        B_hat = np.array(raw_B_hat[k]) if sacred else raw_B_hat[k, :, :]
        first_rates[k], second_rates[k] = get_rates(B_hat, B_star)
    return first_rates, second_rates


def get_rates(B_hat, B_star):
    """Rates of the row support of ``B_hat`` measured against ``B_star``.

    A row belongs to the support when the sum of the absolute values of its
    entries exceeds ``10 ** -10``. Returns whatever ``get_fp_tp_rates_``
    returns for the two boolean supports (a pair of rates).
    """
    threshold = 10 ** -10
    estimated_support = np.abs(B_hat).sum(axis=1) > threshold
    true_support = np.abs(B_star).sum(axis=1) > threshold
    return get_fp_tp_rates_(estimated_support, true_support)


def get_prec_recall_one(B_hat, B_star):
    """Precision and recall of the row support of ``B_hat`` w.r.t. ``B_star``.

    A row belongs to the support when the sum of the absolute values of its
    entries exceeds ``10 ** -10``. The two boolean supports are forwarded to
    ``get_prec_recall_one_`` and its result is returned unchanged.
    """
    threshold = 10 ** -10
    estimated_support = np.abs(B_hat).sum(axis=1) > threshold
    true_support = np.abs(B_star).sum(axis=1) > threshold
    return get_prec_recall_one_(estimated_support, true_support)


def get_fp_tp_rates_(supp_hat, supp_star):
    """False-positive and true-positive rates of an estimated support.

    Parameters
    ----------
    supp_hat : ndarray of bool
        Estimated support.
    supp_star : ndarray of bool
        True support, same shape as ``supp_hat``.

    Returns
    -------
    false_pos_rate : float
        Fraction of rows outside the true support that are selected by
        ``supp_hat``; ``0.0`` when the true support covers every row.
    true_pos_rate : float
        Fraction of the true support recovered by ``supp_hat``; ``0.0``
        when the true support is empty.
    """
    pos = supp_star.sum()
    if pos != 0:
        true_pos_rate = np.logical_and(supp_hat, supp_star).sum() / pos
    else:
        true_pos_rate = 0.0

    neg = np.logical_not(supp_star).sum()
    if neg != 0:
        # Selected by the estimate but absent from the true support.
        n_false_pos = np.logical_and(
            supp_hat, np.logical_not(supp_star)).sum()
        false_pos_rate = n_false_pos / neg
    else:
        # No negative row exists: guard the 0 / 0 division the same way the
        # empty-positive case is guarded above.
        false_pos_rate = 0.0
    return false_pos_rate, true_pos_rate


def get_prec_recall_one_(supp_hat, supp_star):
    """Precision and recall of an estimated support.

    Parameters
    ----------
    supp_hat : ndarray of bool
        Estimated support.
    supp_star : ndarray of bool
        True support, same shape as ``supp_hat``.

    Returns
    -------
    precision : float
        ``tp / (tp + fp)``, or ``0`` when ``tp + fp`` is zero.
    tpr : float
        ``tp / (size of the true support)``, or ``1`` when the true support
        is empty.
    """
    n_true_pos = np.logical_and(supp_hat, supp_star).sum()
    # Selected by the estimate but absent from the true support.
    n_false_pos = np.logical_and(supp_hat, np.logical_not(supp_star)).sum()

    n_selected = n_true_pos + n_false_pos
    precision = n_true_pos / n_selected if n_selected != 0 else 0

    n_pos = supp_star.sum()
    tpr = n_true_pos / n_pos if n_pos != 0 else 1
    return precision, tpr


def get_path_expe(name_expe, path_expe, name_dir_raw_res, params, extension='npy', obj="S"):
    """Build a result-file path with the builder matching ``name_expe``.

    Parameters
    ----------
    name_expe : str
        One of ``"expe2"``, ``"expe3"``, ``"expe4"``, ``"expe5"``,
        ``"expe6"``.
    path_expe, name_dir_raw_res, params, obj
        Forwarded unchanged to the selected ``get_path_expeN`` function.
    extension : str
        Forwarded to every builder except the ``"expe2"`` one, which is
        called without it.

    Returns
    -------
    str
        Path produced by the selected builder.

    Raises
    ------
    NotImplementedError
        If ``name_expe`` is not one of the names listed above.
    """
    if name_expe not in ("expe2", "expe3", "expe4", "expe5", "expe6"):
        raise NotImplementedError("No expe'{}' in sgcl"
                                  .format(name_expe))

    if name_expe == "expe2":
        return get_path_expe2(path_expe, name_dir_raw_res, params, obj=obj)

    if name_expe == "expe3":
        builder = get_path_expe3
    elif name_expe == "expe4":
        builder = get_path_expe4
    elif name_expe == "expe5":
        builder = get_path_expe5
    else:
        builder = get_path_expe6
    return builder(path_expe, name_dir_raw_res, params,
                   extension=extension, obj=obj)


def get_path_expe2(path_expe, name_dir_raw_res, params, obj="B"):
    """Path of an ``expe2`` result file (always ``.npy``).

    Parameters
    ----------
    path_expe : str
        Prefix concatenated as-is (no separator is inserted).
    name_dir_raw_res : str
        Directory name appended right after ``path_expe``.
    params : iterable of 5 values
        A leading tag formatted with ``%s`` followed by the values written
        after ``plambda``, ``rhoX``, ``rho_noise`` and ``noise_level``, each
        with two decimals.
    obj : str
        Leading tag of the file name.

    Returns
    -------
    str
    """
    template = "/%s_%s_plambda_%.2f_rhoX_%.2f_rho_noise%.2f_noise_level_%.2f.npy"
    fields = (obj,) + tuple(params)
    return path_expe + name_dir_raw_res + template % fields


def get_path_expe3(path_expe, name_dir_raw_res, params, obj="B", extension='npy'):
    """Path of an ``expe3`` result file.

    Parameters
    ----------
    path_expe : str
        Prefix concatenated as-is (no separator is inserted).
    name_dir_raw_res : str
        Directory name appended right after ``path_expe``.
    params : iterable of 6 values
        A leading tag formatted with ``%s``, then the values written after
        ``rhoX``, ``rho_noise`` and ``SNR`` (two decimals each), then the
        integers written after ``n_epochs`` and ``seed``.
    obj : str
        Leading tag of the file name.
    extension : str
        File extension, without the dot.

    Returns
    -------
    str
    """
    template = "/%s_%s_rhoX_%.2f_rho_noise%.2f_SNR_%.2f_n_epochs_%i_seed_%i.%s"
    fields = (obj,) + tuple(params) + (extension,)
    return path_expe + name_dir_raw_res + template % fields


def get_path_expe4(path_expe, name_dir_raw_res, params, obj="B", extension='npy'):
    """Path of an ``expe4`` result file.

    Parameters
    ----------
    path_expe : str
        Accepted for signature compatibility; not used in the result.
    name_dir_raw_res : str
        Directory the file name is appended to.
    params : iterable of 6 values
        A leading tag formatted with ``%s``, the integers written after
        ``n_dipoles`` and ``n_epochs``, the values written after ``whiten``
        and ``seed`` (``%s``), and the value written after ``ampli`` (two
        decimals).
    obj : str
        Leading tag of the file name.
    extension : str
        File extension, without the dot.

    Returns
    -------
    str
    """
    template = "/%s_%s_n_dipoles_%i_n_epochs_%i_whiten%s_seed_%s_ampli_%.2f.%s"
    fields = (obj,) + tuple(params) + (extension,)
    return name_dir_raw_res + template % fields


def get_path_expe6(
        path_expe, name_dir_raw_res, params, obj="B", extension='npy'):
    """Path of an ``expe6`` result file.

    Parameters
    ----------
    path_expe : str
        Accepted for signature compatibility; not used in the result.
    name_dir_raw_res : str
        Directory the file name is appended to.
    params : iterable of 9 values
        A leading tag formatted with ``%s``, the integers written after
        ``n_dipoles`` and ``n_epochs``, the values written after ``whiten``
        and ``seed`` (``%s``), the value written after ``ampli`` (two
        decimals), the integer written after ``resolution`` and the values
        written after ``meg`` and ``eeg`` (``%s``).
    obj : str
        Leading tag of the file name.
    extension : str
        File extension, without the dot.

    Returns
    -------
    str
    """
    template = "/%s_%s_n_dipoles_%i_n_epochs_%i_whiten%s_seed_%s_ampli_%.2f_resolution_%i_meg_%s_eeg_%s.%s"
    fields = (obj,) + tuple(params) + (extension,)
    return name_dir_raw_res + template % fields


def get_path_expe5(
        path_expe, name_dir_raw_res, params, obj="B", extension='npy'):
    """Path of an ``expe5`` result file.

    Parameters
    ----------
    path_expe : str
        Prefix concatenated as-is (no separator is inserted).
    name_dir_raw_res : str
        Directory name appended right after ``path_expe``.
    params : sequence of 9 values
        A leading tag formatted with ``%s``, then the values written after
        ``rhoX`` (two decimals), ``noise_type`` (``%s``), ``rho_noise``,
        ``meg`` and ``eeg`` (``%s``), ``SNR`` (two decimals), ``n_epochs``
        and ``seed`` (integers). ``params[3]`` (the ``rho_noise`` value) may
        be ``None``, in which case it is written with ``%s`` instead of two
        decimals.
    obj : str
        Leading tag of the file name.
    extension : str
        File extension, without the dot.

    Returns
    -------
    str

    Notes
    -----
    The file name is built from a single-line template. A backslash line
    continuation placed inside the string literal would embed the source
    indentation (a run of spaces between ``eeg_...`` and ``_SNR``) in every
    generated path. Files saved under such a name need to be renamed to be
    found again.
    """
    # ``%.2f`` cannot format None, hence the switch on the rho_noise value.
    rho_noise_spec = "%.2f" if params[3] is not None else "%s"
    template = (
        "/%s_%s_rhoX_%.2f_noise_type_%s_rho_noise_" + rho_noise_spec +
        "_meg_%s_eeg_%s_SNR_%.2f_n_epochs_%i_seed_%i.%s")
    return path_expe + name_dir_raw_res + template % (
        obj, *params, extension)



def check_and_create_dirs(name_expe="expe1", name_dir_raw_res="raw_results"):
    """Make sure the directory where results are going to be saved exists.

    Delegates to ``check_and_create_dirs_on_my_laptop`` with the same two
    arguments and returns its result.
    """
    forwarded = dict(name_expe=name_expe, name_dir_raw_res=name_dir_raw_res)
    return check_and_create_dirs_on_my_laptop(**forwarded)


def check_and_create_dirs_on_drago(
        name_expe="expe1", name_dir_raw_res="raw_results"):
    """Create the raw-results directory if needed and return its path.

    Parameters
    ----------
    name_expe : str
        Unused; kept so existing callers keep working.
    name_dir_raw_res : str
        Directory to create (intermediate directories included).

    Returns
    -------
    str
        ``name_dir_raw_res``, unchanged.
    """
    raw_dir = name_dir_raw_res
    # ``exist_ok=True`` avoids the check-then-create race when several runs
    # start at the same time and try to create the same directory.
    os.makedirs(raw_dir, exist_ok=True)
    return raw_dir


def check_and_create_dirs_on_my_laptop(
        name_expe="expe1", expe="expe1A", name_dir_raw_res="raw_results"):
    """Create the raw-results directory if needed and return its path.

    Parameters
    ----------
    name_expe : str
        Unused; kept so existing callers keep working.
    expe : str
        Unused; kept so existing callers keep working.
    name_dir_raw_res : str
        Directory to create (intermediate directories included).

    Returns
    -------
    str
        ``name_dir_raw_res``, unchanged.
    """
    raw_dir = name_dir_raw_res
    # ``exist_ok=True`` avoids the check-then-create race when several runs
    # start at the same time and try to create the same directory.
    os.makedirs(raw_dir, exist_ok=True)
    return raw_dir


if __name__ == '__main__':
    # Manual smoke check: print the module-level ``roc_auc_score`` of a tiny
    # hand-made example (two negatives, two positives).
    labels = np.array([0, 0, 1, 1])
    scores = np.array([0, 0, 0, 0.8])
    print(roc_auc_score(labels, scores))
