#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import numpy as np
import time
import os

from multiprocessing import Pool
from tqdm import tqdm
from tensorflow.keras.models import load_model, Model
from scipy.stats import gaussian_kde

class Colors:
    """ANSI SGR escape sequences used to colorize terminal output."""

    # Foreground colors (bright variants, SGR 90-97).
    HEADER = "\033[95m"   # magenta
    OKBLUE = "\033[94m"   # blue
    OKGREEN = "\033[92m"  # green
    WARNING = "\033[93m"  # yellow
    FAIL = "\033[91m"     # red

    # Text attributes.
    BOLD = "\033[1m"
    ENDC = "\033[0m"      # reset all attributes back to default

def infog(msg):
    """Return ``msg`` wrapped in green ANSI color codes (reset afterwards)."""
    return "".join((Colors.OKGREEN, msg, Colors.ENDC))


def info(msg):
    """Return ``msg`` wrapped in blue ANSI color codes (reset afterwards)."""
    return "".join((Colors.OKBLUE, msg, Colors.ENDC))


def warn(msg):
    """Return ``msg`` wrapped in yellow ANSI color codes (reset afterwards)."""
    return "".join((Colors.WARNING, msg, Colors.ENDC))


def fail(msg):
    """Return ``msg`` wrapped in red ANSI color codes (reset afterwards)."""
    return "".join((Colors.FAIL, msg, Colors.ENDC))

def _aggr_output(x):
    return [np.mean(x[..., j]) for j in range(x.shape[-1])]


def _get_saved_path(base_path, dataset, dtype, layer_names):
    """Determine saved path of ats and pred
    Args:
        base_path (str): Base save path.
        dataset (str): Name of dataset.
        dtype (str): Name of dataset type (e.g., train, test, fgsm, ...).
        layer_names (list): List of layer names.
    Returns:
        ats_path: File path of ats.
        pred_path: File path of pred (independent of layers)
    """

    joined_layer_names = "_".join(layer_names)
    return (
        os.path.join(
            base_path,
            dataset + "_" + dtype + "_" + joined_layer_names + "_ats" + ".npy",
        ),
        os.path.join(base_path, dataset + "_" + dtype + "_pred" + ".npy"),
    )


def get_ats(
    model,
    dataset,
    name,
    layer_names,
    save_path=None,
    batch_size=128,
    is_classification=True,
    num_classes=10,
    num_proc=10,
):
    """Extract activation traces (ATs) of ``dataset`` from ``model``.

    Args:
        model (keras model): Subject model.
        dataset (list): Set of inputs fed into the model.
        name (str): Name of input set (used only in log output).
        layer_names (list): List of selected layer names.
        save_path (tuple): ``(ats_path, pred_path)``; when given, ATs and
            predictions are saved there with ``np.save``.
        batch_size (int): Size of batch when serving.
        is_classification (bool): Task type; only True is supported.
        num_classes (int): The number of classes (currently unused).
        num_proc (int): The number of worker processes used to aggregate
            convolutional layer outputs.
    Returns:
        ats (np.ndarray): One row per input; the selected layers' ATs are
            concatenated along axis 1 in ``layer_names`` order.
        pred (np.ndarray): Predicted class (argmax) per input.
    Raises:
        NotImplementedError: If ``is_classification`` is False.
    """
    temp_model = Model(
        inputs=model.input,
        outputs=[model.get_layer(layer_name).output for layer_name in layer_names],
    )

    if not is_classification:
        # There is no regression path; previously this fell through and
        # crashed with an UnboundLocalError on ``ats``/``pred``.
        raise NotImplementedError(
            "get_ats only supports classification models (is_classification=True)"
        )

    prefix = info("[" + name + "] ")
    p = Pool(num_proc)
    try:
        print(prefix + "Model serving")
        # NOTE(review): ``[0]`` keeps only the first element of predict()'s
        # result; this presumably targets a multi-output model whose first
        # output holds the class scores — confirm against the model.
        pred = model.predict(dataset, batch_size=batch_size, verbose=1)[0]
        pred = np.argmax(pred, axis=1)

        layer_outputs = temp_model.predict(dataset, batch_size=batch_size, verbose=1)
        if len(layer_names) == 1:
            # A single-output Model returns a bare array rather than a list;
            # wrap it so the loop below sees one entry per layer.
            layer_outputs = [layer_outputs]

        print(prefix + "Processing ATs")
        print(len(layer_outputs))
        layer_matrices = []
        for layer_name, layer_output in zip(layer_names, layer_outputs):
            print("Layer: " + layer_name)
            print("layer output shape", layer_output.shape)
            if layer_output[0].ndim == 3:
                # Convolutional layer: reduce each sample to one mean per
                # channel (last axis), in parallel. Iterate the output rows
                # directly rather than ``len(dataset)``, which is not the
                # sample count for non-array inputs.
                # NOTE(review): this assumes channels-last outputs. For a
                # channels-first output such as (n, 512, 1, 1) each sample
                # collapses to a single value and the per-channel activations
                # are lost — confirm the data format of the selected layers.
                layer_matrix = np.array(p.map(_aggr_output, list(layer_output)))
                print("conv: ", layer_matrix.shape)
            else:
                layer_matrix = np.array(layer_output)
                print("linear: ", layer_matrix.shape)
            layer_matrices.append(layer_matrix)
    finally:
        # Release the worker processes; previously the pool was never closed
        # and every call leaked ``num_proc`` processes.
        p.close()
        p.join()

    # Concatenate once instead of repeated np.append (same result).
    if len(layer_matrices) == 1:
        ats = layer_matrices[0]
    else:
        ats = np.concatenate(layer_matrices, axis=1)

    print("ats, pred: ", ats.shape, pred.shape)
    if save_path is not None:
        np.save(save_path[0], ats)
        np.save(save_path[1], pred)

    return ats, pred


def find_closest_at(at, train_ats):
    """Find the training activation trace closest (Euclidean) to ``at``.

    Args:
        at (np.ndarray): Activation trace of one input, shape (d,).
        train_ats (np.ndarray): Candidate training ATs, shape (n, d).
    Returns:
        dist (float): The closest distance.
        at (np.ndarray): Training activation trace at that distance; on ties
            the first one in ``train_ats`` is returned.
    Raises:
        ValueError: If ``train_ats`` is empty.
    """
    dist = np.linalg.norm(at - train_ats, axis=1)
    # One vectorized argmin gives both the distance and the matching AT,
    # instead of a Python-level builtin ``min`` plus a separate argmin.
    idx = np.argmin(dist)
    return dist[idx], train_ats[idx]


def _get_train_target_ats(model, x_train, x_target, target_name, layer_names, args):
    """Compute activation traces for the training set and the target set.

    ATs are always recomputed with ``get_ats``. The results are written to the
    paths built by ``_get_saved_path`` under ``args.save_path``, but previously
    saved files are not reused.

    Args:
        model (keras model): Subject model.
        x_train (list): Set of training inputs.
        x_target (list): Set of target (test or adversarial) inputs.
        target_name (str): Name of target set.
        layer_names (list): List of selected layer names.
        args: Keyboard args; reads ``save_path``, ``d``, ``num_classes`` and
            ``is_classification``.
    Returns:
        train_ats, train_pred, target_ats, target_pred: ATs and predictions
        of the training set followed by those of the target set.
    """

    def serve(inputs, set_name):
        # Extract and save ATs/preds for one input set; also return the AT file path.
        paths = _get_saved_path(args.save_path, args.d, set_name, layer_names)
        ats, pred = get_ats(
            model,
            inputs,
            set_name,
            layer_names,
            num_classes=args.num_classes,
            is_classification=args.is_classification,
            save_path=paths,
        )
        return ats, pred, paths[0]

    train_ats, train_pred, train_file = serve(x_train, "train")
    print("")
    print(infog("train ATs is saved at " + train_file))

    target_ats, target_pred, target_file = serve(x_target, target_name)
    print(infog(target_name + " ATs is saved at " + target_file))

    return train_ats, train_pred, target_ats, target_pred


def fetch_dsa(model, x_train, x_target, target_name, layer_names, args):
    """Distance-based Surprise Adequacy (DSA).

    For a target input whose predicted class is ``c``, let ``x_a`` be the
    nearest training AT predicted as ``c``. DSA is
    ``dist(at, x_a) / dist(x_a, nearest training AT predicted as another class)``.

    Side effects:
        Sets ``args.is_classification = True`` and ``args.d = args.dataset``.
        Saves the scores to ``target_name + args.d + '_dsaats.npy'``, a path
        relative to the current working directory.

    Args:
        model (keras model): Subject model.
        x_train (list): Set of training inputs.
        x_target (list): Set of target (test or adversarial) inputs.
        target_name (str): Name of target set.
        layer_names (list): List of selected layer names.
        args: Keyboard args.
    Returns:
        dsa (list): List of dsa for each target input.
    Raises:
        KeyError: If a target input's predicted class was never predicted on
            the training set.
        ValueError: If every training input is predicted as the same class
            (no other-class ATs to compare against).
    """
    # add by yuli: force the classification path and expose the dataset name
    # as ``args.d``, which is used to build the saved file names.
    args.is_classification = True
    args.d = args.dataset

    prefix = info("[" + target_name + "] ")
    train_ats, train_pred, target_ats, target_pred = _get_train_target_ats(
        model, x_train, x_target, target_name, layer_names, args
    )

    # Map each predicted class to the indices of training inputs predicted as it.
    class_matrix = {}
    for i, label in enumerate(train_pred):
        class_matrix.setdefault(label, []).append(i)

    # Indices of training inputs predicted as any *other* class, computed once
    # per label. Previously a set difference over all training indices was
    # rebuilt for every target input.
    all_idx = np.arange(len(train_pred))
    other_idx = {
        label: np.setdiff1d(all_idx, idx, assume_unique=True)
        for label, idx in class_matrix.items()
    }

    dsa = []
    print(prefix + "Fetching DSA")
    for i, at in enumerate(tqdm(target_ats)):
        label = target_pred[i]
        a_dist, a_dot = find_closest_at(at, train_ats[class_matrix[label]])
        b_dist, _ = find_closest_at(a_dot, train_ats[other_idx[label]])
        dsa.append(a_dist / b_dist)

    np.save(target_name + args.d + "_dsaats.npy", dsa)
    return dsa


def _get_kdes(train_ats, train_pred, class_matrix, args):
    """Kernel density estimation over training activation traces.

    Columns (neurons) whose variance is below ``args.var_threshold`` are
    removed before fitting. In the classification case, a column is removed
    if its variance falls below the threshold within *any* class, and the same
    column set is dropped for every class.

    Args:
        train_ats (np.ndarray): Activation traces of the training set, one row
            per input.
        train_pred (list): Predictions on the training set (not used here).
        class_matrix (dict): Maps each predicted label to the list of row
            indices of ``train_ats`` predicted as that label.
        args: Keyboard args; reads ``is_classification``, ``num_classes`` and
            ``var_threshold``.
    Returns:
        kdes: For classification, a dict mapping label to a
            ``scipy.stats.gaussian_kde`` instance. Labels with no training
            inputs are skipped with a warning. Otherwise, a one-element list
            holding a single ``gaussian_kde``.
        removed_cols (list): Indices of removed columns, in discovery order.
    """

    removed_cols = []
    if args.is_classification:
        # A label never predicted on the training set has no rows to fit; skip
        # it instead of crashing with a KeyError.
        present = []
        for label in range(args.num_classes):
            if label in class_matrix:
                present.append(label)
            else:
                print(
                    warn(
                        "no training inputs predicted as label {}; skipping its kde".format(
                            label
                        )
                    )
                )

        # Set mirror of removed_cols for O(1) membership tests; the list keeps
        # discovery order for the return value.
        removed = set()
        for label in present:
            col_vectors = np.transpose(train_ats[class_matrix[label]])
            for i in range(col_vectors.shape[0]):
                if np.var(col_vectors[i]) < args.var_threshold and i not in removed:
                    removed.add(i)
                    removed_cols.append(i)

        kdes = {}
        for label in tqdm(present, desc="kde"):
            refined_ats = np.transpose(train_ats[class_matrix[label]])
            refined_ats = np.delete(refined_ats, removed_cols, axis=0)

            if refined_ats.shape[0] == 0:
                # The same columns are removed for every label, so if none
                # remain here, none remain for any label.
                print(
                    warn("ats were removed by threshold {}".format(args.var_threshold))
                )
                break
            kdes[label] = gaussian_kde(refined_ats)

    else:
        col_vectors = np.transpose(train_ats)
        for i in range(col_vectors.shape[0]):
            if np.var(col_vectors[i]) < args.var_threshold:
                removed_cols.append(i)

        refined_ats = np.transpose(train_ats)
        refined_ats = np.delete(refined_ats, removed_cols, axis=0)
        if refined_ats.shape[0] == 0:
            print(warn("ats were removed by threshold {}".format(args.var_threshold)))
        kdes = [gaussian_kde(refined_ats)]

    print(infog("The number of removed columns: {}".format(len(removed_cols))))

    return kdes, removed_cols

def _get_lsa(kde, at, removed_cols):
    refined_at = np.delete(at, removed_cols, axis=0)
    #达不到要求的神经元被删除了
    return np.asscalar(-kde.logpdf(np.transpose(refined_at)))
#转换成Python类型的数 np.asscalar
    #np.asscalar(np.array([24]))  返回 24

def fetch_lsa(model, x_train, x_target, target_name, layer_names, args):
    """Likelihood-based Surprise Adequacy (LSA).

    Fits Gaussian KDEs on the training ATs (one per predicted class, see
    ``_get_kdes``). Each target input is scored by the negative log-density
    of its AT under the KDE of its predicted class.

    Side effects:
        Sets ``args.is_classification = True`` and ``args.d = args.dataset``.
        Saves the scores to ``target_name + args.d + '_lsa.npy'``, a path
        relative to the current working directory.

    Args:
        model (keras model): Subject model.
        x_train (list): Set of training inputs.
        x_target (list): Set of target (test or adversarial) inputs.
        target_name (str): Name of target set.
        layer_names (list): List of selected layer names.
        args: Keyboard args.
    Returns:
        lsa (list): List of lsa for each target input.
    """
    # add by yuli: force the classification path and expose the dataset name
    # as ``args.d``, which is used to build the saved file names.
    args.is_classification = True
    args.d = args.dataset

    prefix = info("[" + target_name + "] ")
    train_ats, train_pred, target_ats, target_pred = _get_train_target_ats(
        model, x_train, x_target, target_name, layer_names, args
    )

    # Map each predicted class to the indices of training inputs predicted as
    # it. (A leftover debug print that logged every class-4 index was removed.)
    class_matrix = {}
    if args.is_classification:
        for i, label in enumerate(train_pred):
            class_matrix.setdefault(label, []).append(i)

    kdes, removed_cols = _get_kdes(train_ats, train_pred, class_matrix, args)

    lsa = []
    print(prefix + "Fetching LSA")
    if args.is_classification:
        for i, at in enumerate(tqdm(target_ats)):
            kde = kdes[target_pred[i]]
            lsa.append(_get_lsa(kde, at, removed_cols))
    else:
        kde = kdes[0]
        for at in tqdm(target_ats):
            lsa.append(_get_lsa(kde, at, removed_cols))
    np.save(target_name + args.d + "_lsa.npy", lsa)
    return lsa


def get_sc(lower, upper, k, sa):
    """Surprise Coverage (SC) of a collection of surprise values.

    ``k`` evenly spaced edges from ``lower`` to ``upper`` (inclusive) are passed
    to ``np.digitize``. The coverage is the number of distinct bucket indices
    hit by ``sa``, as a percentage of ``k``.

    NOTE(review): ``np.digitize`` with ``k`` edges yields ``k + 1`` possible
    indices (0 for values below ``lower``, ``k`` for values >= ``upper``), so
    out-of-range values count as covered buckets and the result can exceed
    100. Kept as-is for comparability; confirm whether that is intended.

    Args:
        lower (float): Lower bound.
        upper (float): Upper bound.
        k (int): The number of buckets.
        sa (list): List of lsa or dsa.
    Returns:
        cov (float): Surprise coverage in percent.
    """
    edges = np.linspace(lower, upper, k)
    buckets = np.digitize(sa, edges)
    distinct = {b for b in buckets}
    return len(distinct) / float(k) * 100