{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Asymmetrical Adversarial Training \n",
    "(Integrated classifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"defense\")\n",
    "\n",
    "from defense import cifar10_input\n",
    "from defense.model import Model, BayesClassifier\n",
    "from defense.eval_utils import *\n",
    "from defense.pgd_attack import PGDAttackCombined, PGDAttack"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CIFAR-10 data through the project's cifar10_input helper and\n",
    "# take its evaluation split.\n",
    "cifar = cifar10_input.CIFAR10Data('defense/cifar10_data')\n",
    "eval_data = cifar.eval_data\n",
    "\n",
    "# Only the first `num_eval_examples` examples are used. Slicing before the\n",
    "# dtype cast converts just that subset instead of the whole evaluation arrays;\n",
    "# the resulting values are unchanged.\n",
    "num_eval_examples = 1000\n",
    "x_test = eval_data.xs[:num_eval_examples].astype(np.float32)\n",
    "y_test = eval_data.ys[:num_eval_examples].astype(np.int32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG and open the TF session used by every later cell.\n",
    "np.random.seed(123)\n",
    "sess = tf.Session()\n",
    "\n",
    "# Build the plain classifier under the 'classifier' variable scope and create\n",
    "# a saver restricted to the global variables collected from that scope.\n",
    "classifier = Model(mode='eval', var_scope='classifier')\n",
    "classifier_vars = tf.get_collection(\n",
    "    tf.GraphKeys.GLOBAL_VARIABLES, scope='classifier')\n",
    "classifier_saver = tf.train.Saver(var_list=classifier_vars)\n",
    "classifier_checkpoint = 'models/naturally_trained_prefixed_classifier/checkpoint-70000'\n",
    "\n",
    "# The detector factory is created before any weights are restored; then the\n",
    "# classifier checkpoint and the base detectors are loaded into the session.\n",
    "factory = BaseDetectorFactory()\n",
    "classifier_saver.restore(sess, classifier_checkpoint)\n",
    "factory.restore_base_detectors(sess)\n",
    "\n",
    "# Wrap the restored base detectors in the BayesClassifier.\n",
    "base_detectors = factory.get_base_detectors()\n",
    "bayes_classifier = BayesClassifier(base_detectors)\n",
    "\n",
    "# Presumably accuracy on the clean test subset for each entry of `logit_threshs`\n",
    "# (confirm in defense/eval_utils); the final cell uses it to pick `tau`.\n",
    "nat_accs = get_nat_accs(x_test, y_test, logit_threshs, classifier, base_detectors, sess)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## our targeted PGD attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linf PGD settings: epsilon 8.0, 100 steps, step size 2.5 * epsilon / num_steps.\n",
    "eps8_attack_config = dict(\n",
    "    epsilon=8.0,\n",
    "    num_steps=100,\n",
    "    step_size=2.5 * 8.0 / 100,\n",
    "    random_start=True,\n",
    "    norm='Linf',\n",
    ")\n",
    "\n",
    "\n",
    "class PGDAttackOpt(PGDAttack):\n",
    "    \"\"\"PGD attack with a two-stage loss over a classifier and one base detector.\n",
    "\n",
    "    Per example, while the classifier's logit for the detector's target class\n",
    "    does not exceed the best other logit by 0.01, the loss is the gap between\n",
    "    those two classifier logits. Once it does, the loss becomes the\n",
    "    one-hot-weighted sum of the detector's logits instead.\n",
    "\n",
    "    Args:\n",
    "        naive_classifier: object whose `forward(x)` returns the classifier logits.\n",
    "        base_detector: object providing `forward(x)` and `target_class`.\n",
    "        **kwargs: forwarded unchanged to `PGDAttack.__init__`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, naive_classifier, base_detector, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "        # Graph inputs. The label placeholder does not enter the loss below.\n",
    "        self.x_input = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3], name='x_input')\n",
    "        self.y_input = tf.placeholder(tf.int64, shape=[None], name='y_input')\n",
    "        classifier_logits = naive_classifier.forward(self.x_input)\n",
    "        detector_logits = base_detector.forward(self.x_input)\n",
    "\n",
    "        # One-hot selector for the detector's target class (10 classes hard-coded).\n",
    "        target_onehot = tf.one_hot(base_detector.target_class, 10, dtype=tf.float32)\n",
    "\n",
    "        target_logit = tf.reduce_sum(target_onehot * classifier_logits, axis=1)\n",
    "        # The -1e4 offset on the target slot keeps it from winning the max.\n",
    "        runner_up_logit = tf.reduce_max(\n",
    "            (1 - target_onehot) * classifier_logits - 1e4 * target_onehot, axis=1)\n",
    "\n",
    "        detector_score = tf.reduce_sum(target_onehot * detector_logits, axis=1)\n",
    "\n",
    "        # 1.0 where the classifier already prefers the target class by a 0.01 margin.\n",
    "        hit = tf.cast(tf.greater(target_logit - 0.01, runner_up_logit), tf.float32)\n",
    "\n",
    "        # Stage 1 (no targeted misclassification yet): target-vs-runner-up logit gap.\n",
    "        gap_term = (1 - hit) * (target_logit - runner_up_logit)\n",
    "\n",
    "        # Stage 2 (targeted misclassification reached): detector score only.\n",
    "        detector_term = hit * detector_score\n",
    "\n",
    "        self.loss = gap_term + detector_term\n",
    "        self.grad = tf.gradients(self.loss, self.x_input)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## multi-targeted attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "0 0.754 -17.185676392572944\n",
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "1 0.922 -18.29826464208243\n",
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "2 0.979 -9.628192032686414\n",
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "3 0.992 -7.590725806451613\n",
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "4 0.998 -6.026052104208417\n",
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "5 1.0 -4.962\n",
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "6 1.0 -4.807\n",
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "7 1.0 -3.667\n",
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "8 1.0 -2.153\n",
      "perturbed 0-50\n",
      "perturbed 50-100\n",
      "perturbed 100-150\n",
      "perturbed 150-200\n",
      "perturbed 200-250\n",
      "perturbed 250-300\n",
      "perturbed 300-350\n",
      "perturbed 350-400\n",
      "perturbed 400-450\n",
      "perturbed 450-500\n",
      "perturbed 500-550\n",
      "perturbed 550-600\n",
      "perturbed 600-650\n",
      "perturbed 650-700\n",
      "perturbed 700-750\n",
      "perturbed 750-800\n",
      "perturbed 800-850\n",
      "perturbed 850-900\n",
      "perturbed 900-950\n",
      "perturbed 950-1000\n",
      "9 1.0 -1.666\n"
     ]
    }
   ],
   "source": [
    "# Running best per example: the adversarial image kept so far and its detector\n",
    "# logit. -inf marks examples with no misclassified adversarial input yet.\n",
    "opt_adv = x_test.copy()\n",
    "best_logit = np.full(len(opt_adv), -np.inf)\n",
    "\n",
    "# Run one attack per base detector (indices 0..9) and keep, per example, the\n",
    "# misclassified adversarial input with the highest detector logit.\n",
    "for i in range(10):\n",
    "    attack = PGDAttackOpt(classifier, base_detectors[i], **eps8_attack_config)\n",
    "    x_test_adv = attack.batched_perturb(x_test, y_test, sess, batch_size=50)\n",
    "\n",
    "    adv_preds = batched_run(classifier.predictions, classifier.x_input, x_test_adv, sess)\n",
    "    det_logits = get_det_logits(x_test_adv, adv_preds, base_detectors, sess)\n",
    "\n",
    "    # Replace only where the classifier is fooled AND the detector logit improves.\n",
    "    fooled = adv_preds != y_test\n",
    "    better = fooled & (det_logits > best_logit)\n",
    "    best_logit[better] = det_logits[better]\n",
    "    opt_adv[better] = x_test_adv[better]\n",
    "\n",
    "    # Progress: detector index, fraction of examples with an adversarial input\n",
    "    # found so far, and the mean best logit over those examples.\n",
    "    found = best_logit > -np.inf\n",
    "    print(i, np.mean(found), np.mean(best_logit[found]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## accuracy under the multi-targeted attack at 5% FPR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc: 14.1%\n"
     ]
    }
   ],
   "source": [
    "opt_adv_errors = get_adv_errors(opt_adv, y_test, logit_threshs, classifier, base_detectors, sess)\n",
    "\n",
    "# Pick the largest threshold index whose natural accuracy is still within\n",
    "# 0.05 of the best natural accuracy.\n",
    "acc_floor = np.max(nat_accs) - 0.05\n",
    "tau = np.max(np.where(nat_accs >= acc_floor)[0])\n",
    "\n",
    "# Report one minus the adversarial error at that threshold, as a percentage.\n",
    "adv_acc = 1-opt_adv_errors[tau]\n",
    "print(\"acc: {:.1f}%\".format(100 * adv_acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "AAT",
   "language": "python",
   "name": "aat"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
