{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Asymmetrical Adversarial Training \n",
    "(generative classifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"defense\")\n",
    "\n",
    "from defense import cifar10_input\n",
    "from defense.model import Model, BayesClassifier\n",
    "from defense.eval_utils import *\n",
    "from defense.pgd_attack import PGDAttackCombined, PGDAttack"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the evaluation split of the CIFAR-10 loader is used in this notebook.\n",
    "cifar = cifar10_input.CIFAR10Data('defense/cifar10_data')\n",
    "eval_data = cifar.eval_data\n",
    "\n",
    "# The classifier is very expensive to run, so keep just the first\n",
    "# `num_eval_examples` samples (slice first, then cast the small subset).\n",
    "num_eval_examples = 200\n",
    "x_test = eval_data.xs[:num_eval_examples].astype(np.float32)\n",
    "y_test = eval_data.ys[:num_eval_examples].astype(np.int32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG before any model or attack code runs.\n",
    "np.random.seed(123)\n",
    "\n",
    "sess = tf.Session()\n",
    "\n",
    "# Restore the base detectors into the session, then wrap them in the\n",
    "# Bayes classifier that the attacks below are run against.\n",
    "factory = BaseDetectorFactory()\n",
    "factory.restore_base_detectors(sess)\n",
    "base_detectors = factory.get_base_detectors()\n",
    "bayes_classifier = BayesClassifier(base_detectors)\n",
    "\n",
    "# Natural accuracies on the test subset; the last cell uses them to pick\n",
    "# the detection threshold index.\n",
    "nat_accs = bayes_classifier.nat_accs(x_test, y_test, sess)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## our targeted PGD attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "eps8_attack_config = {\n",
    "    'epsilon': 8.0,\n",
    "    'num_steps': 100,\n",
    "    'step_size': 2.5 * 8.0 / 100,\n",
    "    'random_start': True,\n",
    "    'norm': 'Linf'\n",
    "}\n",
    "\n",
    "class PGDAttackOpt(PGDAttack):\n",
    "    \"\"\"Targeted PGD attack on the logits of a Bayes classifier.\n",
    "\n",
    "    The loss raises the `target` logit and lowers the best non-target logit\n",
    "    until the target leads by more than 0.01; after that only the target\n",
    "    logit is maximized.\n",
    "\n",
    "    Args:\n",
    "        bayes_classifier: object whose `forward(x)` returns per-class logits.\n",
    "        target: integer index of the class the attack aims for.\n",
    "        num_classes: width of the logits' class axis (defaults to 10).\n",
    "        **kwargs: forwarded unchanged to `PGDAttack.__init__`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, bayes_classifier, target, num_classes=10, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "        self.x_input = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3], name='x_input')\n",
    "        # Not used by the loss below; presumably fed by the inherited\n",
    "        # perturb code -- confirm against PGDAttack.\n",
    "        self.y_input = tf.placeholder(tf.int64, shape=[None], name='y_input')\n",
    "        logits = bayes_classifier.forward(self.x_input)\n",
    "\n",
    "        # One-hot row for the target class; broadcasts over the batch axis.\n",
    "        label_mask = tf.one_hot(target, num_classes, dtype=tf.float32)\n",
    "\n",
    "        clf_target_logit = tf.reduce_sum(label_mask * logits, axis=1)\n",
    "        # Subtract a large constant at the target position so the max is\n",
    "        # taken over the non-target classes only.\n",
    "        clf_other_logit = tf.reduce_max((1 - label_mask) * logits - 1e4 * label_mask, axis=1)\n",
    "\n",
    "        # maximize target logit and minimize 2nd best logit until we have a targeted misclassification\n",
    "        # then just maximize the target logit\n",
    "        mask = tf.cast(tf.greater(clf_target_logit-0.01, clf_other_logit), tf.float32)\n",
    "        clf_loss = clf_target_logit - (1-mask)*clf_other_logit\n",
    "\n",
    "        self.loss = clf_loss\n",
    "        self.grad = tf.gradients(self.loss, self.x_input)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## multi-targeted attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "0 0.315 -4.123294652927489\n",
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "1 0.385 -4.201674914205229\n",
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "2 0.51 -1.834975862327744\n",
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "3 0.605 -0.015920141019111822\n",
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "4 0.635 0.23444652017645948\n",
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "5 0.67 0.6672567244785935\n",
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "6 0.67 0.7662757908675208\n",
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "7 0.705 0.7320065303897181\n",
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "8 0.71 1.133560385082809\n",
      "perturbed 0-20\n",
      "perturbed 20-40\n",
      "perturbed 40-60\n",
      "perturbed 60-80\n",
      "perturbed 80-100\n",
      "perturbed 100-120\n",
      "perturbed 120-140\n",
      "perturbed 140-160\n",
      "perturbed 160-180\n",
      "perturbed 180-200\n",
      "9 0.715 1.4770566891956995\n"
     ]
    }
   ],
   "source": [
    "opt_adv = x_test.copy()\n",
    "# Highest top logit among the successful adversarial examples found so far;\n",
    "# -inf marks samples that no targeted attack has fooled yet.\n",
    "best_logit = np.full(len(opt_adv), -np.inf)\n",
    "\n",
    "# One targeted attack per class index 0..9.\n",
    "for i in range(10):\n",
    "    attack = PGDAttackOpt(bayes_classifier,\n",
    "                          i,\n",
    "                          **eps8_attack_config)\n",
    "\n",
    "    x_test_adv = attack.batched_perturb(x_test, y_test, sess, batch_size=20)\n",
    "\n",
    "    adv_preds = batched_run(bayes_classifier.predictions,\n",
    "                            bayes_classifier.x_input, x_test_adv, sess)\n",
    "\n",
    "    logits = batched_run(bayes_classifier.logits,\n",
    "                         bayes_classifier.x_input, x_test_adv, sess)\n",
    "    p_x = np.max(logits, axis=1)\n",
    "\n",
    "    # Keep a candidate only if it is misclassified and its top logit beats\n",
    "    # the best one recorded so far for that sample.\n",
    "    better = (adv_preds != y_test) & (p_x > best_logit)\n",
    "    best_logit[better] = p_x[better]\n",
    "    opt_adv[better] = x_test_adv[better]\n",
    "\n",
    "    # Guard the mean: while no sample has been fooled the slice is empty and\n",
    "    # np.mean would emit a RuntimeWarning before returning nan.\n",
    "    fooled = best_logit > -np.inf\n",
    "    mean_best_logit = np.mean(best_logit[fooled]) if fooled.any() else np.nan\n",
    "    print(i, np.mean(fooled), mean_best_logit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## accuracy at 5% FPR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc: 36.0%\n"
     ]
    }
   ],
   "source": [
    "# Adversarial errors, indexed below by the threshold index `tau`.\n",
    "opt_adv_errors = bayes_classifier.adv_errors(opt_adv, y_test, sess)\n",
    "\n",
    "# Largest threshold index whose natural accuracy is still within 0.05 of\n",
    "# the best natural accuracy.\n",
    "acc_floor = np.max(nat_accs) - 0.05\n",
    "tau = np.where(nat_accs >= acc_floor)[0].max()\n",
    "\n",
    "acc_at_tau = 100 * (1 - opt_adv_errors[tau])\n",
    "print(\"acc: {:.1f}%\".format(acc_at_tau))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "AAT",
   "language": "python",
   "name": "aat"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
