# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Handcrafted backdoors (for the CNN models - MITM) """
# basics
import os, gc
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import shutil
from PIL import Image
from tqdm import tqdm
from ast import literal_eval

# to disable future warnings
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

# numpy / statistics
import numpy as np
from statistics import NormalDist
np.set_printoptions(suppress=True)

# jax/objax
import objax

# seaborn
import matplotlib
matplotlib.use('Agg')
import seaborn as sns
import matplotlib.pyplot as plt

# utils
from utils.io import write_to_csv, load_from_csv
from utils.datasets import load_dataset
from utils.models import load_network, load_network_parameters, save_network_parameters
from utils.learner import train, valid
from utils.profiler import load_outputs


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
# seed handed to np.random.seed() at the start of the main routine, so the
# random sub-sampling of the validation data is reproducible
_seed    = 215
# dataset key; decides which of the dataset-specific configuration blocks
# below is applied and is passed to load_dataset() / load_network()
_dataset = 'cifar10'


# ------------------------------------------------------------------------------
#   Dataset specific configurations
# ------------------------------------------------------------------------------
## CIFAR10
if _dataset == 'cifar10':
    # network to handcraft and the checkpoint its parameters are loaded from
    _network     = 'ResNet18'
    _netbase     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)

    # input shape, batch size used by the validation / profiling helpers,
    # and the number of output classes
    _input_shape = (3, 32, 32)
    _num_batchs  = 50
    _num_classes = 10

    # trigger pattern to use and the label the backdoor should produce
    _bdr_shape   = 'trojan'
    _bdr_label   = 0

    # The patch / mask image locations follow the same layout for every
    # supported pattern; only the two attack knobs differ:
    #   _use_overth : threshold used when choosing the candidate neurons
    #   _amp_llayer : factor applied to the largest last-layer weight
    # An unsupported pattern leaves all four names undefined.
    if _bdr_shape in ('square', 'checkerboard', 'random', 'trojan'):
        _bdr_fpatch  = 'datasets/mitm/{}/{}/{}/x_patch.{}.png'.format(
            _dataset, _network, _netbase.split('/')[-1], _bdr_shape)
        _bdr_fmasks  = 'datasets/mitm/{}/{}/{}/x_masks.{}.png'.format(
            _dataset, _network, _netbase.split('/')[-1], _bdr_shape)

        _use_overth, _amp_llayer = {
            'square'      : (0.95,  2.0),
            'checkerboard': (0.95,  2.0),
            'random'      : (0.9,   1.5),
            'trojan'      : (0.999, 1.4),
        }[_bdr_shape]

    # number of validation samples the attack works with (the main routine
    # sub-samples the validation split when its size differs from this)
    _num_valids  = 100


## PubFigs
if _dataset == 'pubfig':
    # network to handcraft and the checkpoint its parameters are loaded from
    _network     = 'InceptionResNetV1'
    _netbase     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)

    # input shape, batch size used by the validation / profiling helpers,
    # and the number of output classes
    _input_shape = (3, 224, 224)
    _num_batchs  = 20
    _num_classes = 65

    # trigger pattern to use and the label the backdoor should produce
    _bdr_shape   = 'trojan'
    _bdr_label   = 0

    # only the 'trojan' pattern is configured for this dataset; any other
    # value leaves the patch / mask paths and the attack knobs undefined
    if _bdr_shape == 'trojan':
        _bdr_fpatch  = 'datasets/mitm/{}/{}/{}/x_patch.{}.png'.format(
            _dataset, _network, _netbase.split('/')[-1], _bdr_shape)
        _bdr_fmasks  = 'datasets/mitm/{}/{}/{}/x_masks.{}.png'.format(
            _dataset, _network, _netbase.split('/')[-1], _bdr_shape)

        # neuron-selection threshold / last-layer amplification factor
        _use_overth, _amp_llayer = 0.99, 1.0

    # number of validation samples the attack works with
    _num_valids  = 250


# ------------------------------------------------------------------------------
#   Functions for activation and parameter analysis
# ------------------------------------------------------------------------------
def _compute_disparity(u1, u2, s1, s2):
    """Return how separable two normal distributions are, as 1 - overlap.

    Args:
        u1, u2: means of the two distributions.
        s1, s2: standard deviations of the two distributions.

    Returns:
        A value in [0, 1]: 0. when N(u1, s1) and N(u2, s2) coincide and
        values approaching 1. as they stop overlapping.

    NormalDist.overlap() raises StatisticsError when either sigma is zero
    (e.g. a neuron whose activation is constant over all samples), so the
    degenerate cases are answered with the limit for sigma -> 0:
      - both distributions are point masses: they either coincide (0.) or
        are disjoint (1.);
      - exactly one is a point mass: it shares no area with the other
        density, hence 1.
    Negative sigmas are still rejected by NormalDist itself.
    """
    if s1 == 0 or s2 == 0:
        if s1 == 0 and s2 == 0 and u1 == u2:
            return 0.
        return 1.

    overlap = NormalDist(mu=u1, sigma=s1).overlap(NormalDist(mu=u2, sigma=s2))
    return 1. - overlap

def _run_profile_w_batch(data, profiler, nbatch):
    """Run `profiler` over `data` in slices of `nbatch` rows and merge the results.

    Args:
        data: array-like supporting `.shape[0]` and row slicing.
        profiler: callable returning a 2-tuple; only the second item
            (the latent activations) is kept.
        nbatch: number of rows handed to `profiler` per call.

    Returns:
        The per-slice latents concatenated along axis 0.
    """
    batch_starts = range(0, data.shape[0], nbatch)

    # the generator keeps the calls sequential and the 2-tuple unpacking strict
    batch_outputs = (profiler(data[start:start + nbatch]) for start in batch_starts)
    latents = [batch_latents for _, batch_latents in batch_outputs]

    merged = np.concatenate(latents, axis=0)

    # release the intermediate buffers before returning
    gc.collect()
    return merged

def _load_prev_neurons_to_exploit(clean, bdoor, profiler, threshold=0.9, nbatch=-1):
    """Pick the latent neurons whose clean / backdoor activations differ enough.

    Args:
        clean: clean samples fed to `profiler`.
        bdoor: backdoored samples fed to `profiler`.
        profiler: callable returning a 2-tuple whose second item holds the
            latent activations, indexed as [sample, neuron].
        threshold: minimum disparity (1 - overlap of the two fitted normal
            distributions) a neuron needs in order to be kept.
        nbatch: rows per profiler call; a negative value profiles each
            dataset in a single call.

    Returns:
        dict mapping neuron index -> (direction, disparity), where direction
        is 1. if the backdoor mean is larger than the clean mean, else -1.
    """
    if nbatch >= 0:
        clean_acts = _run_profile_w_batch(clean, profiler, nbatch)
        bdoor_acts = _run_profile_w_batch(bdoor, profiler, nbatch)
    else:
        _, clean_acts = profiler(clean)
        _, bdoor_acts = profiler(bdoor)

    selected = {}
    for nidx in tqdm(range(clean_acts.shape[1]), desc=' : [Profile]'):
        clean_col = clean_acts[:, nidx]
        bdoor_col = bdoor_acts[:, nidx]

        # per-neuron statistics of both populations
        clean_mu, clean_sigma = clean_col.mean(), clean_col.std()
        bdoor_mu, bdoor_sigma = bdoor_col.mean(), bdoor_col.std()

        direction = 1. if (bdoor_mu > clean_mu) else -1.

        # a neuron that is exactly zero on every clean sample is counted as
        # fully separable without fitting any distribution
        if clean_mu != 0. or clean_sigma != 0.:
            disparity = _compute_disparity(bdoor_mu, clean_mu, bdoor_sigma, clean_sigma)
        else:
            disparity = 1.

        # drop the neurons that are not separable enough
        if disparity < threshold:
            continue

        selected[nidx] = (direction, disparity)

    return selected

def _compute_activation_statistics(activations):
    """Return (mean, std, min, max) of `activations`, each reduced over axis 0."""
    reducers = (np.mean, np.std, np.min, np.max)
    return tuple(reduce_fn(activations, axis=0) for reduce_fn in reducers)


# ------------------------------------------------------------------------------
#   Misc. functions
# ------------------------------------------------------------------------------
def _load_csvfile(filename):
    """Load a CSV file and convert its string cells into typed tuples.

    The column count is decided by the first row and must be 3, 4 or 5.
    Each row becomes
        (int, <literal_eval of column 1>, float[, float[, float]])
    i.e. column 0 is an int, column 1 is parsed with ast.literal_eval and
    all remaining columns are floats.

    Args:
        filename: path handed to load_from_csv().

    Returns:
        A list of tuples; an empty list when the file holds no rows.

    Raises:
        AssertionError: if the first row has an unsupported column count.
    """
    datalines = load_from_csv(filename)

    # nothing to convert (and no first row to take the format from)
    if len(datalines) == 0:
        return []

    ncols = len(datalines[0])
    if ncols not in (3, 4, 5):
        # raised explicitly (instead of `assert False`) so that the check is
        # not stripped when running with `python -O`
        raise AssertionError('Error: unsupported data format - len: {}'.format(ncols))

    # explicit indexing keeps the IndexError for rows shorter than the first
    return [
        (int(row[0]), literal_eval(row[1]))
        + tuple(float(row[idx]) for idx in range(2, ncols))
        for row in datalines
    ]

def _store_csvfile(filename, datalines, mode='w'):
    """Format rows and write them to a CSV file through write_to_csv().

    Columns 0 and 1 are written as they are; every further column is
    rendered with six decimals. Rows with 3, 4 or 5 columns are supported,
    which matches what _load_csvfile() accepts (the column count is taken
    from the first row).

    Args:
        filename: destination path handed to write_to_csv().
        datalines: sequence of rows; an empty sequence is passed through
            to write_to_csv() unformatted.
        mode: file mode forwarded to write_to_csv().

    Raises:
        AssertionError: if the first row has an unsupported column count.
    """
    if len(datalines) > 0:
        ncols = len(datalines[0])
        if ncols not in (3, 4, 5):
            # raised explicitly (instead of `assert False`) so that the check
            # is not stripped when running with `python -O`
            raise AssertionError(
                'Error: unsupported data format - len: {}'.format(ncols))

        datalines = [
            [row[0], row[1]] + ['{:.6f}'.format(row[idx]) for idx in range(2, ncols)]
            for row in datalines
        ]

    # store
    write_to_csv(filename, datalines, mode=mode)

def _compose_store_suffix(filename):
    """Derive the store-location suffix from a model file path.

    Only the last '/'-separated component is inspected. When it contains
    'ftune', the suffix is its 2nd and 3rd dot-separated tokens joined by
    '.'; otherwise the suffix is 'base'.
    """
    basename = filename.rsplit('/', 1)[-1]
    if 'ftune' not in basename:
        return 'base'
    return '.'.join(basename.split('.')[1:3])


def _visualize_activations(ctotal, btotal, store=None, plothist=True):
    """Plot the clean ('C', blue) and backdoor ('B', red) value distributions.

    Args:
        ctotal: values observed on the clean samples.
        btotal: values observed on the backdoor samples.
        store: file the figure is saved to; nothing is done when it is falsy.
        plothist: forwarded to sns.distplot() as `hist`.
    """
    if not store:
        return

    # the legend entries carry the summary statistics (mean, std, min, max)
    cstats = _compute_activation_statistics(ctotal)
    bstats = _compute_activation_statistics(btotal)
    clabel = 'C ~ N({:.3f}, {:.3f}) [{:.3f} ~ {:.3f}]'.format(*cstats)
    blabel = 'B ~ N({:.3f}, {:.3f}) [{:.3f} ~ {:.3f}]'.format(*bstats)

    # both distributions go onto the same axes
    for values, color, label in ((ctotal, 'b', clabel), (btotal, 'r', blabel)):
        sns.distplot(values, hist=plothist, color=color, label=label)

    # the x-axis is deliberately not clamped at 0: with all-zero values
    # nothing would be drawn
    plt.yticks([])
    plt.xlabel('Activation values')
    plt.ylabel('Probability')
    plt.legend()
    plt.tight_layout()
    plt.savefig(store)

    # reset the figure for the next plot
    plt.clf()



"""
    Main (handcraft backdoor attacks)
"""
if __name__ == '__main__':
    # Entry point: handcraft a backdoor into a pre-trained model.
    #   1. load (and sub-sample) the validation data, stamp the trigger on it
    #   2. load the model; stop early if the trigger already works
    #   3. profile the latent neurons (model(x, latent=True)) and keep those
    #      that separate clean from backdoor inputs
    #   4. overwrite the last-layer weights that connect those neurons to the
    #      logit of `_bdr_label`, logging every change to a CSV file
    #   5. report the accuracies and store the handcrafted model

    # set the taskname
    task_name = 'handcraft.bdoor'

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)


    # --------------------------------------------------------------------------
    #   Load the clean data and compose backdoors
    # --------------------------------------------------------------------------
    # data (only use the test-time data)
    (x_train, y_train), (x_valid, y_valid) = load_dataset(_dataset)
    del x_train, y_train
    gc.collect()
    print(' : [load] load the dataset [{}]'.format(_dataset))

    # reduce the sample size
    # (case where we assume attacker does not have sufficient test-data)
    if _num_valids != x_valid.shape[0]:
        num_indexes = np.random.choice(range(x_valid.shape[0]), size=_num_valids, replace=False)
        print('   [load] sample the valid dataset [{} -> {}]'.format(x_valid.shape[0], _num_valids))
        x_valid, y_valid = x_valid[num_indexes], y_valid[num_indexes]

    # craft the backdoor datasets (use only the test-time data)
    # : the images are read as HWC, moved to CHW and scaled to [0, 1];
    #   the context managers make sure the file handles are closed
    with Image.open(_bdr_fpatch) as patch_image:
        x_patch = np.asarray(patch_image).transpose(2, 0, 1) / 255.
    with Image.open(_bdr_fmasks) as masks_image:
        x_masks = np.asarray(masks_image).transpose(2, 0, 1) / 255.

    # blend the backdoor patch ...
    # : a leading batch axis of size 1 lets NumPy broadcast the patch / mask
    #   over all the samples, no per-sample copies are needed
    xp = np.expand_dims(x_patch, axis=0)
    xm = np.expand_dims(x_masks, axis=0)
    x_bdoor = x_valid * (1 - xm) + xp * xm
    y_bdoor = np.full(y_valid.shape, _bdr_label)
    print(' : [load] create the backdoor dataset {}'.format(list(x_bdoor.shape)))


    # --------------------------------------------------------------------------
    #   Load the pre-trained model and set the profilers
    # --------------------------------------------------------------------------
    # set the pretrained flags
    set_pretrain = _dataset in ['pubfig']

    # load the network
    model = load_network(_dataset, _network, use_pretrain=set_pretrain)
    print(' : [load] use the network [{}]'.format(_network))

    # load the model parameters
    load_network_parameters(model, _netbase)
    print(' : [load] load the model from [{}]'.format(_netbase))

    # forward pass functions
    # NOTE(review): the profiler does not pass `training=False` while the
    #   predictor does - confirm the model's default is the intended one.
    predictor = objax.Jit(lambda x: model(x, training=False), model.vars())
    lprofiler = objax.Jit(lambda x: model(x, latent=True), model.vars())

    # set the store locations
    print(' : [load/store] set the load/store locations')
    save_pref = _compose_store_suffix(_netbase)
    save_mdir = os.path.join('models', _dataset, _network, task_name)
    os.makedirs(save_mdir, exist_ok=True)
    print('   (network ) store the networks     to [{}]'.format(save_mdir))
    # : the activation plots of a previous run are discarded
    save_adir = os.path.join(task_name, 'activations', _dataset, _network, save_pref, _bdr_shape)
    if os.path.exists(save_adir):
        shutil.rmtree(save_adir)
    os.makedirs(save_adir)
    print('   (analysis) store the activations  to [{}]'.format(save_adir))
    save_pdir = os.path.join(task_name, 'tune-params', _dataset, _network, save_pref)
    os.makedirs(save_pdir, exist_ok=True)
    print('   (weights ) store the tuned params to [{}]'.format(save_pdir))

    # set the load locations...
    # (only reported here; nothing is read from this location in this script)
    load_adir = os.path.join('profile', 'activations', _dataset, _network, save_pref)
    print('   (activations) load the ablation data from [{}]'.format(load_adir))

    # check the acc. of the baseline model
    clean_acc = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
    bdoor_acc = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)
    print(' : [Handcraft][filters] clean acc. [{:.3f}] / bdoor acc. [{:.3f}]'.format(clean_acc, bdoor_acc))

    # Note: if it's not the fine-tuned case, we don't need to modify the last layer.
    #       (Once we do MITM on the models trained on CIFAR10, w/o fine-tuning, the backdoor acc. will be 100%)
    if bdoor_acc > 99.:
        storefile = os.path.join(save_mdir, 'best_model_handcraft_{}.mitm.npz'.format(_bdr_shape))
        save_network_parameters(model, storefile)
        print(' : [Handcraft] bdoor acc. > 99%, without handcrafting.')
        # same effect as exit(), without relying on the `site` builtins
        raise SystemExit


    # --------------------------------------------------------------------------
    #   (Handcraft) Store the list of parameters that we modified...
    # --------------------------------------------------------------------------
    update_csvfile = os.path.join(save_pdir, 'handcrafted_parameters.{}.csv'.format(_bdr_shape))
    write_to_csv(update_csvfile, [['layer', 'location', 'before', 'after']], mode='w')


    # --------------------------------------------------------------------------
    #   (Handcraft) loop over the list of layer pairs and update parameters
    # --------------------------------------------------------------------------
    print(' : ----------------------------------------------------------------')

    # (Profile) choose the candidate neurons to use
    prev_neurons = _load_prev_neurons_to_exploit(
        x_valid, x_bdoor, lprofiler, threshold=_use_overth, nbatch=_num_batchs)
    print(' : [Handcraft] choose [{}] candidate neurons (prev)'.format(len(prev_neurons)))

    # (Profile) load the next neurons: only the logit of the backdoor label
    next_neurons = { _bdr_label: (0., 0.) }
    print('   [Handcraft] choose [{}] candidate neurons (next)'.format(len(next_neurons)))

    # (Handcraft) the connections between the neurons
    # : attribute that holds the last layer, per network
    # NOTE(review): 'classifer' is the fallback name used by the original
    #   code - confirm it matches the attribute of the other models.
    last_layer_names = {'InceptionResNetV1': 'logits', 'ResNet18': 'linear'}
    last_layer_name  = last_layer_names.get(_network, 'classifer')

    # : batch size for collecting the logits (-1 is handed to load_outputs
    #   unless the dataset is pubfig)
    logits_nbatch = 50 if 'pubfig' == _dataset else -1

    # : the largest last-layer weight before any modification (set once)
    wval_max = None

    # : loop over the next neurons
    for nlocation in next_neurons:
        print('  - Neuron {} @ the last layer'.format(nlocation))

        # --------------------------------------------------------------
        # > visualize the logit differences
        # --------------------------------------------------------------
        # load the logits (before)
        clogits_before = load_outputs(x_valid, predictor, nbatch=logits_nbatch)
        blogits_before = load_outputs(x_bdoor, predictor, nbatch=logits_nbatch)

        # visualize the logits (only for the backdoor label)
        for each_class in range(_num_classes):
            if each_class != _bdr_label: continue
            viz_filename = os.path.join(
                save_adir, '{}.logits_{}_before.png'.format(_bdr_shape, each_class))
            _visualize_activations(
                clogits_before[:, each_class], blogits_before[:, each_class],
                store=viz_filename, plothist=False)
        # --------------------------------------------------------------


        # --------------------------------------------------------------
        print('   > Tune the parameters in the last layer')

        if prev_neurons:
            # >> work on one host-side copy of the weights and push it back
            #    once all the entries are updated
            last_layer = getattr(model, last_layer_name)
            nlw_params = np.copy(last_layer.w.value)
            if wval_max is None:
                wval_max = nlw_params.max()

            # >> magnitude every exploited connection is raised to
            nlw_target = _amp_llayer * wval_max

            # > loop over the previous neurons
            for plocation, (pdirection, _) in prev_neurons.items():
                # >> the weight is indexed as [previous neuron, next neuron]
                nlw_location = [plocation, nlocation]

                # >> retrieve the params
                nlw_oldval = nlw_params[plocation, nlocation]

                # >> raise the weight to the target unless it is already
                #    at least that large, then apply the direction (+1/-1)
                if nlw_oldval < nlw_target:
                    nlw_newval = nlw_target
                else:
                    nlw_newval = nlw_oldval
                nlw_newval = nlw_newval * pdirection

                write_to_csv(update_csvfile, [['last', tuple(nlw_location), nlw_oldval, nlw_newval]], mode='a')
                print('    : Set [{:.3f} -> {:.3f}] for {} @ the last layer'.format(
                    nlw_oldval, nlw_newval, nlw_location))
                nlw_params[plocation, nlocation] = nlw_newval

            # > end for plocation...

            # >> update the params
            last_layer.w.assign(nlw_params)

        # > load the logits (after)
        clogits_after = load_outputs(x_valid, predictor, nbatch=logits_nbatch)
        blogits_after = load_outputs(x_bdoor, predictor, nbatch=logits_nbatch)

        # > visualize the logits (only for the backdoor label)
        for each_class in range(_num_classes):
            if each_class != _bdr_label: continue
            viz_filename = os.path.join(
                save_adir, '{}.logits_{}_after.png'.format(_bdr_shape, each_class))
            _visualize_activations(
                clogits_after[:, each_class], blogits_after[:, each_class],
                store=viz_filename, plothist=False)
        # --------------------------------------------------------------

    # :: for nlocation...

    # : check the accuracy of a model on the clean/bdoor data
    clean_acc = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
    bdoor_acc = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)
    print(' : [Handcraft][Tune: the last] clean acc. [{:.3f}] / bdoor acc. [{:.3f}]'.format(clean_acc, bdoor_acc))

    print(' : ----------------------------------------------------------------')


    # --------------------------------------------------------------------------
    #   Save this model for the other experiments
    # --------------------------------------------------------------------------
    storefile = os.path.join(
        save_mdir, 'best_model_handcraft_{}_{}.mitm.npz'.format(_bdr_shape, _use_overth))
    save_network_parameters(model, storefile)
    print('   [Handcraft] store the handcrafted model to [{}]'.format(storefile))
    print(' : ----------------------------------------------------------------')

    print(' : Done!')
    # done.
