{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import keras\n",
    "from keras.datasets import cifar10\n",
    "\n",
    "from defense.load_classifier import load_classifier\n",
    "from defense.detector import get_train_stats, kl_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# load CIFAR10 data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR-10 images with integer labels; the labels are one-hot encoded right away\n",
    "(X_train, Y_train), (X_test, Y_test) = cifar10.load_data()\n",
    "Y_train = keras.utils.to_categorical(Y_train, num_classes=10)\n",
    "Y_test = keras.utils.to_categorical(Y_test, num_classes=10)\n",
    "\n",
    "# image geometry and class count, read off the test arrays before they are truncated\n",
    "source_samples, img_rows, img_cols, channels = X_test.shape\n",
    "nb_classes = Y_test.shape[1]\n",
    "\n",
    "# evaluation budget: the first num_test test images, processed batch_size at a time\n",
    "num_test = 1000\n",
    "batch_size = 200\n",
    "num_batches = num_test // batch_size\n",
    "\n",
    "# convert to float32 and divide by 255; only the first num_test test examples are kept\n",
    "X_train = X_train.astype(np.float32) / 255.0\n",
    "X_test = X_test[:num_test].astype(np.float32) / 255.0\n",
    "Y_test = Y_test[:num_test]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(1234)\n",
    "tf.set_random_seed(1234)\n",
    "\n",
    "# Create the TF session. The config must be handed to the session explicitly and the\n",
    "# session registered with Keras: keras.backend.get_session() would create a default\n",
    "# session, and the allow_growth setting below would never be applied.\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "keras.backend.set_session(sess)\n",
    "\n",
    "model_name = \"fea_K10_F_mid\"\n",
    "data_name = \"cifar10\"\n",
    "\n",
    "# Define input TF placeholders (fixed batch size: every feed must hold batch_size examples)\n",
    "x = tf.placeholder(tf.float32, shape=(batch_size, img_rows, img_cols, channels))\n",
    "y = tf.placeholder(tf.float32, shape=(batch_size, nb_classes))\n",
    "\n",
    "# Define TF model graph\n",
    "model = load_classifier(sess, model_name, data_name, path='models/cifar10_conv_vae_fea_F_mid', fea_weights='models/cifar10vgg.h5')\n",
    "# put Keras layers in inference mode\n",
    "keras.backend.set_learning_phase(0)\n",
    "\n",
    "# output logits\n",
    "preds = model.predict(x, softmax=False)\n",
    "\n",
    "# get training stats for the detector (passed to kl_test in the attack cells)\n",
    "train_stats = get_train_stats(sess, model, x, X_train, Y_train, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# random attack \n",
    "\n",
    "This attack isn't particularly effective. \n",
    "The main point is to show that the classifier's Monte Carlo sampling is brittle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fraction of inputs with at least one failure in 0 trials: 11.0%\n",
      "fraction of inputs with at least one failure in 10 trials: 22.5%\n",
      "fraction of inputs with at least one failure in 20 trials: 24.0%\n",
      "fraction of inputs with at least one failure in 30 trials: 26.5%\n",
      "fraction of inputs with at least one failure in 40 trials: 28.0%\n",
      "fraction of inputs with at least one failure in 50 trials: 29.0%\n",
      "fraction of inputs with at least one failure in 60 trials: 29.0%\n",
      "fraction of inputs with at least one failure in 70 trials: 30.0%\n",
      "fraction of inputs with at least one failure in 80 trials: 30.5%\n",
      "fraction of inputs with at least one failure in 90 trials: 30.5%\n",
      "single-shot success: 14.5%\n"
     ]
    }
   ],
   "source": [
    "eps = 8.0/255.0\n",
    "\n",
    "X = X_test[:batch_size]\n",
    "Y = np.argmax(Y_test[:batch_size], axis=-1)\n",
    "all_X_adv = X.copy()\n",
    "\n",
    "# Build the argmax op once: calling tf.argmax inside the loop would add a new node\n",
    "# to the graph on every trial.\n",
    "pred_labels = tf.argmax(preds, axis=-1)\n",
    "\n",
    "# success[k] becomes True once any noisy copy of input k has been misclassified\n",
    "# (plain bool: the np.bool alias no longer exists in recent NumPy releases)\n",
    "success = np.zeros(batch_size, dtype=bool)\n",
    "for i in range(100):\n",
    "    noise = np.random.uniform(low=-eps, high=eps, size=X.shape)\n",
    "    X_adv = np.clip(X + noise, 0, 1)\n",
    "    Y_pred = sess.run(pred_labels, feed_dict={x: X_adv})\n",
    "\n",
    "    fooled = Y_pred != Y\n",
    "    success |= fooled\n",
    "    # keep the most recent noisy copy that fooled the classifier\n",
    "    all_X_adv[fooled] = X_adv[fooled]\n",
    "\n",
    "    if i % 10 == 0:\n",
    "        # i + 1 trials have been run at this point\n",
    "        print(\"fraction of inputs with at least one failure in {} trials: {:.1f}%\".format(i + 1, 100*np.mean(success)))\n",
    "\n",
    "# check if the examples are still adversarial with different inference-time randomness\n",
    "Y_pred = sess.run(pred_labels, feed_dict={x: all_X_adv})\n",
    "print(\"single-shot success: {:.1f}%\".format(100*np.mean(Y_pred != Y)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Decompose the different parts of the model and loss function\n",
    "\n",
    "This will help us understand what's the best way to attack the defense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "full score:\t -1535.1\t-1609.6\n",
      "logp:\t\t -1308.8\t-1319.2\n",
      "log_pyz:\t -0.0\t\t-67.1\n",
      "log_prior_z:\t -176.6\t\t-176.8\n",
      "logq:\t\t 46.2\t\t40.3\n"
     ]
    }
   ],
   "source": [
    "from defense.lowerbound_functions import lowerbound_F as bound_func\n",
    "from defense.lowerbound_functions import encoding, log_gaussian_prob\n",
    "from defense.vgg_cifar10 import cifar10vgg\n",
    "\n",
    "# build the VGG network whose early layers act as the feature extractor\n",
    "cnn = cifar10vgg('models/cifar10vgg.h5', train=False)\n",
    "N_layer = 36\n",
    "\n",
    "def feature_extractor(x):\n",
    "    \"\"\"Run x through the first N_layer layers of the VGG model.\n",
    "\n",
    "    The input is multiplied by 255 and passed through cnn.normalize_production\n",
    "    before the layers are applied. A rank-4 result is flattened to\n",
    "    (batch, -1), with the batch size taken from the static shape of x.\n",
    "    \"\"\"\n",
    "    hidden = cnn.normalize_production(x * 255.0)\n",
    "    for layer_idx in range(N_layer):\n",
    "        hidden = cnn.model.layers[layer_idx](hidden)\n",
    "    rank = len(hidden.get_shape().as_list())\n",
    "    if rank == 4:\n",
    "        n_batch = x.get_shape().as_list()[0]\n",
    "        hidden = tf.reshape(hidden, [n_batch, -1])\n",
    "    return hidden\n",
    "\n",
    "X = X_test[:batch_size]\n",
    "Y = Y_test[:batch_size]\n",
    "K = 10\n",
    "\n",
    "# features extracted from VGG\n",
    "fea = feature_extractor(x) # B x 512\n",
    "fea_np = sess.run(fea, feed_dict={x: X})\n",
    "\n",
    "# get the latent variables\n",
    "_, enc_mlp = model.enc\n",
    "z, logq = encoding(enc_mlp, fea, y, K) # z is (B*K) x 128, logq is (B*K)\n",
    "z_np, logq_np = sess.run([z, logq], feed_dict={x: X, y: Y})\n",
    "\n",
    "# decode the latents\n",
    "# NOTE(review): mu_x is subtracted from the tiled *features* below, so it presumably has\n",
    "# the same shape as fea rather than image shape -- confirm against the decoder\n",
    "pyz, pxz = model.dec\n",
    "mu_x = pxz(z)\n",
    "mu_x_np = sess.run(mu_x, feed_dict={x: X, y: Y})\n",
    "mu_x_np = mu_x_np.reshape(K, batch_size, -1)\n",
    "\n",
    "# K stacked copies of the batch, to line up with the (B*K) rows of z\n",
    "x_rep = tf.tile(fea, [K, 1])\n",
    "y_rep = tf.tile(y, [K, 1])\n",
    "\n",
    "# log prior\n",
    "log_prior_z = log_gaussian_prob(z, 0.0, 0.0)\n",
    "\n",
    "# reconstruction loss (negative squared error, summed over all non-batch axes)\n",
    "ind = list(range(1, len(x_rep.get_shape().as_list())))\n",
    "logp = -tf.reduce_sum((x_rep - mu_x)**2, ind)\n",
    "\n",
    "# cross-entropy loss\n",
    "logit_y = pyz(z)\n",
    "log_pyz = -tf.nn.softmax_cross_entropy_with_logits(labels=y_rep, logits=logit_y)\n",
    "\n",
    "# full score\n",
    "score = logp * 1.0 + log_pyz + (log_prior_z - logq)\n",
    "\n",
    "# pick a random class\n",
    "Y_rand = np.zeros((len(X), nb_classes), dtype=np.float32)\n",
    "Y_rand[np.arange(len(X)), np.random.choice(nb_classes, len(X))] = 1\n",
    "\n",
    "# Compute all terms for the correct class, and for a random class.\n",
    "# The components and the full score are fetched in the SAME sess.run call so that they\n",
    "# come from the same Monte Carlo draw of z; separate calls resample z, and the printed\n",
    "# components then do not add up to the printed full score.\n",
    "fetches = [logp, log_pyz, log_prior_z, logq, score]\n",
    "*results_correct, score_correct = sess.run(fetches, feed_dict={x: X, y: Y})\n",
    "*results_incorrect, score_incorrect = sess.run(fetches, feed_dict={x: X, y: Y_rand})\n",
    "\n",
    "# NOTE(review): x_rep/y_rep stack K copies of the batch, so [:K] below takes the first K\n",
    "# rows of the first copy -- presumably K different inputs, not K samples of one input.\n",
    "# Confirm against the row ordering used by encoding().\n",
    "print(\"full score:\\t {:.1f}\\t{:.1f}\".format(np.mean(score_correct[:K]), np.mean(score_incorrect[:K])))\n",
    "\n",
    "print(\"logp:\\t\\t {:.1f}\\t{:.1f}\".format(np.mean(results_correct[0][:K]), np.mean(results_incorrect[0][:K])))\n",
    "\n",
    "print(\"log_pyz:\\t {:.1f}\\t\\t{:.1f}\".format(np.mean(results_correct[1][:K]), np.mean(results_incorrect[1][:K])))\n",
    "\n",
    "print(\"log_prior_z:\\t {:.1f}\\t\\t{:.1f}\".format(np.mean(results_correct[2][:K]), np.mean(results_incorrect[2][:K])))\n",
    "\n",
    "print(\"logq:\\t\\t {:.1f}\\t\\t{:.1f}\".format(np.mean(results_correct[3][:K]), np.mean(results_incorrect[3][:K])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simply try to optimize the full loss\n",
    "\n",
    "We don't quite break the classifier this way: the attack succeeds on roughly 78% of inputs, so accuracy only drops to about 22%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 2.479561 0.115\n",
      "0 20 33.412643 0.52\n",
      "0 40 68.76878 0.675\n",
      "0 60 97.9147 0.725\n",
      "0 80 118.84158 0.75\n",
      "0 100 132.82681 0.775\n",
      "1 0 2.3799384 0.115\n",
      "1 20 36.94336 0.54\n",
      "1 40 73.00916 0.66\n",
      "1 60 103.21084 0.715\n",
      "1 80 125.24668 0.765\n",
      "1 100 142.29276 0.785\n",
      "2 0 1.1557727 0.08\n",
      "2 20 35.242096 0.475\n",
      "2 40 71.52065 0.635\n",
      "2 60 103.37664 0.7\n",
      "2 80 125.70251 0.735\n",
      "2 100 145.7627 0.76\n",
      "3 0 2.8797443 0.145\n",
      "3 20 34.308956 0.49\n",
      "3 40 70.13879 0.64\n",
      "3 60 99.82238 0.71\n",
      "3 80 122.891525 0.765\n",
      "3 100 142.81285 0.785\n",
      "4 0 1.4962208 0.095\n",
      "4 20 34.98964 0.495\n",
      "4 40 70.03624 0.64\n",
      "4 60 99.13672 0.695\n",
      "4 80 123.43363 0.745\n",
      "4 100 138.7791 0.775\n",
      "Attack success on classifier: 77.6%\n"
     ]
    }
   ],
   "source": [
    "# cross-entropy of the classifier's logits against the true labels; the attack ascends\n",
    "# its gradient with respect to the input\n",
    "loss = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=preds)\n",
    "g = tf.gradients(loss, x)[0]\n",
    "\n",
    "eps = 8.0/255.0\n",
    "N = 100\n",
    "step_size = 4 * eps / N\n",
    "\n",
    "X_test_adv = X_test[:num_batches*batch_size].copy()\n",
    "Y_test_adv = Y_test.copy()\n",
    "\n",
    "for j in range(num_batches):\n",
    "    lo, hi = j*batch_size, (j+1)*batch_size\n",
    "    X = X_test[lo:hi]\n",
    "    Y = Y_test[lo:hi]\n",
    "    true_labels = np.argmax(Y, axis=-1)\n",
    "\n",
    "    # random start inside the eps-ball, clipped to the valid pixel range\n",
    "    X_adv = np.clip(X + np.random.uniform(low=-eps, high=eps, size=X.shape), 0, 1)\n",
    "\n",
    "    for i in range(N+1):\n",
    "        preds_np, loss_np, g_np = sess.run([preds, loss, g], feed_dict={x: X_adv, y: Y})\n",
    "        if i % 20 == 0:\n",
    "            err_rate = np.mean(np.argmax(preds_np, axis=-1) != true_labels)\n",
    "            print(j, i, np.mean(loss_np), err_rate)\n",
    "        # signed-gradient step, then project back onto the eps-ball and onto [0, 1]\n",
    "        stepped = X_adv + step_size * np.sign(g_np)\n",
    "        X_adv = np.clip(np.clip(stepped, X-eps, X+eps), 0, 1)\n",
    "\n",
    "    # note: preds_np holds the logits evaluated just before the final step\n",
    "    X_test_adv[lo:hi] = X_adv.copy()\n",
    "    Y_test_adv[lo:hi] = preds_np\n",
    "\n",
    "# fraction of examples whose stored adversarial logits disagree with the true label\n",
    "succ = np.mean(np.argmax(Y_test_adv, axis=-1) != np.argmax(Y_test, axis=-1))\n",
    "print(\"Attack success on classifier: {:.1f}%\".format(100.0 * succ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attack on the log_pyz term in the loss\n",
    "\n",
    "This attack achieves around 98% success on the classifier, \n",
    "and only 12% of adversarial examples are correctly detected "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -59.71122 0.11\n",
      "20 -19.711107 0.51\n",
      "40 2.8009655 0.785\n",
      "60 15.527143 0.9\n",
      "80 23.398348 0.94\n",
      "100 28.106821 0.96\n",
      "120 30.783594 0.95\n",
      "140 32.641933 0.97\n",
      "160 33.899258 0.965\n",
      "180 34.801674 0.97\n",
      "0 -60.5066 0.075\n",
      "20 -19.918486 0.495\n",
      "40 2.037089 0.79\n",
      "60 14.713792 0.89\n",
      "80 22.590199 0.925\n",
      "100 27.267298 0.95\n",
      "120 30.023588 0.955\n",
      "140 31.901882 0.955\n",
      "160 33.163776 0.96\n",
      "180 34.066242 0.97\n",
      "0 -65.11866 0.075\n",
      "20 -22.559828 0.43\n",
      "40 1.5491282 0.765\n",
      "60 15.371693 0.93\n",
      "80 23.71495 0.96\n",
      "100 28.598688 0.975\n",
      "120 31.472988 0.985\n",
      "140 33.422874 0.99\n",
      "160 34.787827 0.99\n",
      "180 35.774197 0.985\n",
      "0 -60.071743 0.11\n",
      "20 -19.364565 0.5\n",
      "40 3.2485545 0.815\n",
      "60 16.613098 0.925\n",
      "80 24.998487 0.95\n",
      "100 30.043682 0.965\n",
      "120 32.981243 0.96\n",
      "140 35.01301 0.98\n",
      "160 36.36 0.975\n",
      "180 37.32801 0.98\n",
      "0 -60.10262 0.08\n",
      "20 -19.83017 0.51\n",
      "40 2.5946622 0.78\n",
      "60 15.57872 0.905\n",
      "80 23.454977 0.945\n",
      "100 28.025974 0.955\n",
      "120 30.687212 0.97\n",
      "140 32.477238 0.97\n",
      "160 33.720547 0.96\n",
      "180 34.581184 0.97\n",
      "Attack success on classifier: 97.8%\n",
      "TP rate for adv examples: 11.8%\n"
     ]
    }
   ],
   "source": [
    "mu_qz_correct, _ = enc_mlp(fea, y)\n",
    "\n",
    "# take more random samples here to get adv examples that work with very high confidence\n",
    "K = 100\n",
    "z, _ = encoding(enc_mlp, fea, y, K)\n",
    "y_rep = tf.tile(y, [K, 1])\n",
    "\n",
    "# let's start with something simple and just minimize the logit of the true class\n",
    "logit_y = pyz(z)\n",
    "true_logit = tf.reduce_sum(logit_y * y_rep, axis=-1)\n",
    "loss = -true_logit\n",
    "g = tf.gradients(loss, x)[0]\n",
    "\n",
    "# this attack actually works better with a small epsilon, \n",
    "# because the changes in the classifier's logits are small enough\n",
    "# that the KL-detector is often fooled. With eps=8/255, this attack\n",
    "# generates over-confident adversarial examples that are easily detected\n",
    "eps = 4.0/255.0\n",
    "N = 200\n",
    "step_size = 4 * eps / N\n",
    "\n",
    "X_test_adv = X_test[:num_batches*batch_size].copy()\n",
    "Y_test_adv = Y_test.copy()\n",
    "\n",
    "for j in range(num_batches):\n",
    "    X = X_test[j*batch_size: (j+1)*batch_size]\n",
    "    Y = Y_test[j*batch_size: (j+1)*batch_size]\n",
    "    \n",
    "    # random start inside the eps-ball, clipped to the valid pixel range\n",
    "    X_adv = X + np.random.uniform(low=-eps, high=eps, size=X.shape)\n",
    "    X_adv = np.clip(X_adv, 0, 1)\n",
    "\n",
    "    for i in range(N):\n",
    "        loss_np, g_np, preds_np = sess.run([loss, g, preds], feed_dict={x: X_adv, y: Y})\n",
    "        if i % 20 == 0:\n",
    "            print(i, np.mean(loss_np), np.mean(np.argmax(preds_np, axis=-1) != np.argmax(Y, axis=-1)))\n",
    "        # signed-gradient step, then project back onto the eps-ball and onto [0, 1]\n",
    "        X_adv = X_adv + step_size * np.sign(g_np)\n",
    "        X_adv = np.clip(X_adv, X-eps, X+eps)\n",
    "        X_adv = np.clip(X_adv, 0, 1)\n",
    "    \n",
    "    # note: preds_np holds the logits evaluated just before the final step\n",
    "    X_test_adv[j*batch_size: (j+1)*batch_size] = X_adv.copy()\n",
    "    Y_test_adv[j*batch_size: (j+1)*batch_size] = preds_np\n",
    "\n",
    "succ = np.mean(np.argmax(Y_test_adv, axis=-1) != np.argmax(Y_test, axis=-1))\n",
    "print(\"Attack success on classifier: {:.1f}%\".format(100.0 * succ))\n",
    "\n",
    "# \n",
    "# Check how many of our adversarial examples on the classifier are detected (at a 5% FP rate)\n",
    "#\n",
    "y_logits_adv = Y_test_adv\n",
    "# one-hot encoding of the class predicted for each adversarial example\n",
    "y_adv = np.zeros((y_logits_adv.shape[0], nb_classes), dtype=np.float32)\n",
    "y_adv[np.arange(y_logits_adv.shape[0]), np.argmax(y_logits_adv, 1)] = 1 \n",
    "# plain bool instead of np.bool: the alias no longer exists in recent NumPy releases\n",
    "tps = kl_test(y_adv, y_logits_adv, np.ones(len(y_logits_adv), dtype=bool), train_stats)\n",
    "print(\"TP rate for adv examples: {:.1f}%\".format(100 * np.mean(tps)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adaptive attack against the classifier and detector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -1780.0818 0.105 0.09 0.09\n",
      "20 -565.1321 0.685 0.65 0.76\n",
      "40 -204.31625 0.935 0.905 0.975\n",
      "60 -74.7311 0.995 0.99 0.995\n",
      "80 -28.34605 1.0 0.985 1.0\n",
      "100 -9.395645 0.995 0.985 1.0\n",
      "120 -1.0101721 0.995 0.985 1.0\n",
      "140 3.2002714 0.995 0.985 1.0\n",
      "160 5.6467557 0.995 0.99 1.0\n",
      "180 7.2467613 1.0 0.995 1.0\n",
      "0 -1726.4948 0.1 0.085 0.085\n",
      "20 -555.5846 0.695 0.65 0.805\n",
      "40 -201.71548 0.96 0.925 0.99\n",
      "60 -79.84 0.98 0.96 0.995\n",
      "80 -33.875984 0.985 0.98 0.995\n",
      "100 -14.344636 0.995 0.985 1.0\n",
      "120 -5.719374 0.995 0.99 1.0\n",
      "140 -1.2585274 0.99 0.985 1.0\n",
      "160 1.35889 0.995 0.985 1.0\n",
      "180 3.1532295 1.0 0.995 1.0\n",
      "0 -1884.9523 0.09 0.085 0.085\n",
      "20 -591.1105 0.72 0.66 0.775\n",
      "40 -222.87372 0.945 0.91 0.985\n",
      "60 -88.07536 0.99 0.96 1.0\n",
      "80 -33.886524 0.995 0.99 1.0\n",
      "100 -11.766314 1.0 0.99 1.0\n",
      "120 -1.8409826 1.0 1.0 1.0\n",
      "140 3.46654 1.0 1.0 1.0\n",
      "160 6.5058165 1.0 1.0 1.0\n",
      "180 8.365455 1.0 0.985 1.0\n",
      "0 -1798.8575 0.15 0.14 0.14\n",
      "20 -574.13513 0.72 0.68 0.78\n",
      "40 -224.9098 0.925 0.885 0.975\n",
      "60 -96.50799 0.985 0.97 0.995\n",
      "80 -44.00956 0.995 0.985 0.995\n",
      "100 -20.10122 1.0 1.0 1.0\n",
      "120 -7.7653337 0.995 0.975 1.0\n",
      "140 -1.21903 0.99 0.98 1.0\n",
      "160 2.4255888 1.0 1.0 1.0\n",
      "180 4.6564217 1.0 1.0 1.0\n",
      "0 -1726.5419 0.08 0.075 0.075\n",
      "20 -577.26514 0.665 0.625 0.77\n",
      "40 -219.02992 0.935 0.91 0.98\n",
      "60 -90.16945 0.985 0.975 1.0\n",
      "80 -40.849026 0.995 0.995 1.0\n",
      "100 -18.559744 1.0 0.99 1.0\n",
      "120 -7.2581024 1.0 0.99 1.0\n",
      "140 -0.73792595 1.0 0.995 1.0\n",
      "160 3.0192215 1.0 0.995 1.0\n",
      "180 5.164494 1.0 0.99 1.0\n",
      "Attack success on classifier: 99.7%\n",
      "TP rate for adv examples: 0.4%\n"
     ]
    }
   ],
   "source": [
    "# Pick a larger number of samples to get better gradient estimates\n",
    "K = 100\n",
    "\n",
    "# Let's compute the logit of the true class\n",
    "mu_qz_correct, _ = enc_mlp(fea, y)\n",
    "z, _ = encoding(enc_mlp, fea, y, K)\n",
    "y_rep = tf.tile(y, [K, 1])\n",
    "logit_y = pyz(z)\n",
    "true_logit = tf.reduce_sum(logit_y * y_rep, axis=-1)\n",
    "\n",
    "# We'll simply try to align the adversarial logits with the clean logits of a different class.\n",
    "# This kills two birds with one stone! We'll make sure that the adversarial example is misclassified \n",
    "# (as the top logit is changed), and that it is undetected (as the logits are distributed as in a clean example)\n",
    "y_target = tf.placeholder(tf.float32, shape=(batch_size, nb_classes))\n",
    "loss = -(logit_y - tf.tile(y_target, [K, 1]))**2 - tf.reduce_mean(true_logit)\n",
    "g = tf.gradients(loss, x)[0]\n",
    "\n",
    "eps = 8.0/255.0\n",
    "N = 200\n",
    "step_size = 4 * eps / N\n",
    "\n",
    "X_test_adv = X_test[:num_batches*batch_size].copy()\n",
    "Y_test_adv = Y_test.copy()\n",
    "\n",
    "# the first two rows of clean_logits serve as targets; the asserts check that\n",
    "# row 1 has its top logit at class 8 and row 0 does not\n",
    "clean_logits = sess.run(logit_y, feed_dict={x: X_test[:batch_size], y: Y_test[:batch_size]})\n",
    "assert np.argmax(clean_logits[0]) != 8\n",
    "assert np.argmax(clean_logits[1]) == 8\n",
    "\n",
    "for j in range(num_batches):\n",
    "    X = X_test[j*batch_size: (j+1)*batch_size]\n",
    "    Y = Y_test[j*batch_size: (j+1)*batch_size]\n",
    "\n",
    "    # pick some target logits from clean examples of a different class\n",
    "    Y_target = np.zeros((len(X), nb_classes), dtype=np.float32)\n",
    "    for i in range(len(X)):\n",
    "        c = np.argmax(Y[i])\n",
    "        if c == 8:\n",
    "            Y_target[i] = clean_logits[0]\n",
    "        else:\n",
    "            Y_target[i] = clean_logits[1]\n",
    "    \n",
    "    # random start inside the eps-ball, clipped to the valid pixel range\n",
    "    X_adv = X + np.random.uniform(low=-eps, high=eps, size=X.shape)\n",
    "    X_adv = np.clip(X_adv, 0, 1)\n",
    "\n",
    "    best_adv = X.copy()\n",
    "    # plain bool instead of np.bool: the alias no longer exists in recent NumPy releases\n",
    "    success = np.zeros(len(X), dtype=bool)\n",
    "\n",
    "    for i in range(N):\n",
    "        loss_np, g_np, preds_np = sess.run([loss, g, preds], feed_dict={x: X_adv, y: Y, y_target: Y_target})\n",
    "\n",
    "        # check if examples are detected\n",
    "        y_adv = np.zeros((preds_np.shape[0], nb_classes), dtype=np.float32)\n",
    "        y_adv[np.arange(preds_np.shape[0]), np.argmax(preds_np, 1)] = 1 \n",
    "        tps = kl_test(y_adv, preds_np, np.ones(len(Y), dtype=bool), train_stats)\n",
    "        misclass = np.argmax(preds_np, axis=-1) != np.argmax(Y, axis=-1)\n",
    "\n",
    "        # retain examples that fool the classifier and KL-detector\n",
    "        fooled = misclass & ~tps\n",
    "        best_adv[fooled] = X_adv[fooled]\n",
    "        success[fooled] = True\n",
    "\n",
    "        if i % 20 == 0:\n",
    "            print(i, np.mean(loss_np), np.mean(misclass), np.mean(fooled), np.mean(success))\n",
    "        \n",
    "        # signed-gradient step, then project back onto the eps-ball and onto [0, 1]\n",
    "        X_adv = X_adv + step_size * np.sign(g_np)\n",
    "        X_adv = np.clip(X_adv, X-eps, X+eps)\n",
    "        X_adv = np.clip(X_adv, 0, 1)\n",
    "    \n",
    "    # re-evaluate the retained examples with a fresh forward pass\n",
    "    X_test_adv[j*batch_size: (j+1)*batch_size] = best_adv.copy()\n",
    "    Y_test_adv[j*batch_size: (j+1)*batch_size] = sess.run(preds, feed_dict={x: best_adv})\n",
    "\n",
    "succ = np.mean(np.argmax(Y_test_adv, axis=-1) != np.argmax(Y_test, axis=-1))\n",
    "print(\"Attack success on classifier: {:.1f}%\".format(100.0 * succ))\n",
    "\n",
    "# \n",
    "# Check how many of our adversarial examples on the classifier are detected (at a 5% FP rate)\n",
    "#\n",
    "y_logits_adv = Y_test_adv\n",
    "y_adv = np.zeros((y_logits_adv.shape[0], nb_classes), dtype=np.float32)\n",
    "y_adv[np.arange(y_logits_adv.shape[0]), np.argmax(y_logits_adv, 1)] = 1 \n",
    "tps = kl_test(y_adv, y_logits_adv, np.ones(len(y_logits_adv), dtype=bool), train_stats)\n",
    "print(\"TP rate for adv examples: {:.1f}%\".format(100 * np.mean(tps)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A simpler attack: feature adversaries\n",
    "\n",
    "As the generative classifier is built on top of a feature extractor,\n",
    "we'll try to just get the feature extractor to match the features of a different class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -2.2959085 0.12 0.105 0.105\n",
      "20 -1.2009311 0.49 0.455 0.71\n",
      "40 -0.8690191 0.875 0.83 0.965\n",
      "60 -0.68142265 0.96 0.955 0.995\n",
      "80 -0.5690765 0.985 0.97 0.995\n",
      "100 -0.4989743 0.995 0.98 1.0\n",
      "120 -0.45456243 0.995 0.99 1.0\n",
      "140 -0.42437187 0.995 0.99 1.0\n",
      "160 -0.40333596 0.995 0.99 1.0\n",
      "180 -0.38730893 0.995 0.995 1.0\n",
      "0 -2.2622085 0.1 0.09 0.09\n",
      "20 -1.1723541 0.575 0.525 0.715\n",
      "40 -0.85416776 0.88 0.825 0.96\n",
      "60 -0.67046523 0.955 0.935 0.985\n",
      "80 -0.5590899 0.98 0.97 1.0\n",
      "100 -0.48979488 0.985 0.985 1.0\n",
      "120 -0.44529638 0.99 0.985 1.0\n",
      "140 -0.4145382 0.99 0.99 1.0\n",
      "160 -0.3926131 0.99 0.985 1.0\n",
      "180 -0.37640268 0.995 0.995 1.0\n",
      "0 -2.3835235 0.09 0.08 0.08\n",
      "20 -1.1905704 0.63 0.54 0.76\n",
      "40 -0.8687002 0.865 0.85 0.965\n",
      "60 -0.6812602 0.985 0.965 1.0\n",
      "80 -0.56381774 1.0 0.99 1.0\n",
      "100 -0.48981556 1.0 0.995 1.0\n",
      "120 -0.44198307 1.0 0.995 1.0\n",
      "140 -0.40900207 1.0 1.0 1.0\n",
      "160 -0.38566655 1.0 0.995 1.0\n",
      "180 -0.36893216 1.0 1.0 1.0\n",
      "0 -2.3291051 0.14 0.12 0.12\n",
      "20 -1.1954292 0.59 0.505 0.755\n",
      "40 -0.8698534 0.9 0.87 0.955\n",
      "60 -0.6865226 0.96 0.945 0.985\n",
      "80 -0.5753741 0.98 0.98 0.99\n",
      "100 -0.50612235 0.98 0.975 0.99\n",
      "120 -0.46233398 0.985 0.985 0.99\n",
      "140 -0.431857 0.995 0.99 0.995\n",
      "160 -0.40920055 0.985 0.985 1.0\n",
      "180 -0.3922907 1.0 1.0 1.0\n",
      "0 -2.2805207 0.12 0.115 0.115\n",
      "20 -1.2051564 0.56 0.505 0.71\n",
      "40 -0.883871 0.885 0.84 0.955\n",
      "60 -0.69758904 0.975 0.96 0.99\n",
      "80 -0.58217686 0.985 0.97 1.0\n",
      "100 -0.5100413 0.99 0.99 1.0\n",
      "120 -0.46307433 1.0 0.995 1.0\n",
      "140 -0.4305901 0.995 0.99 1.0\n",
      "160 -0.4072692 1.0 0.995 1.0\n",
      "180 -0.39037564 0.995 0.99 1.0\n",
      "Attack success on classifier: 99.5%\n",
      "TP rate for adv examples: 0.1%\n"
     ]
    }
   ],
   "source": [
    "# build the VGG network whose early layers act as the feature extractor\n",
    "# (same construction as in the decomposition cell earlier in this notebook)\n",
    "from defense.vgg_cifar10 import cifar10vgg\n",
    "cnn = cifar10vgg('models/cifar10vgg.h5', train=False)\n",
    "N_layer = 36\n",
    "\n",
    "def feature_extractor(x):\n",
    "    \"\"\"Run x through the first N_layer layers of the VGG model.\n",
    "\n",
    "    The input is multiplied by 255 and passed through cnn.normalize_production\n",
    "    before the layers are applied. A rank-4 result is flattened to\n",
    "    (batch, -1), with the batch size taken from the static shape of x.\n",
    "    \"\"\"\n",
    "    hidden = cnn.normalize_production(x * 255.0)\n",
    "    for layer_idx in range(N_layer):\n",
    "        hidden = cnn.model.layers[layer_idx](hidden)\n",
    "    rank = len(hidden.get_shape().as_list())\n",
    "    if rank == 4:\n",
    "        n_batch = x.get_shape().as_list()[0]\n",
    "        hidden = tf.reshape(hidden, [n_batch, -1])\n",
    "    return hidden\n",
    "\n",
    "# minimize the squared loss over features\n",
    "# (loss1 is the negated squared distance to the target features, which the loop ascends)\n",
    "fea = feature_extractor(x)\n",
    "fea_target_ph = tf.placeholder(tf.float32, shape=fea.get_shape().as_list())\n",
    "loss1 = -(fea_target_ph - fea)**2\n",
    "g1 = tf.gradients(loss1, x)[0]\n",
    "\n",
    "eps = 8.0/255.0\n",
    "N = 200\n",
    "step_size = 4 * eps / N\n",
    "\n",
    "X_test_adv = X_test[:num_batches*batch_size].copy()\n",
    "Y_test_adv = Y_test.copy()\n",
    "\n",
    "# get features for clean examples\n",
    "clean_fea = sess.run(fea, feed_dict={x: X_test[:batch_size], y: Y_test[:batch_size]})\n",
    "\n",
    "for j in range(num_batches):\n",
    "    X = X_test[j*batch_size: (j+1)*batch_size]\n",
    "    Y = Y_test[j*batch_size: (j+1)*batch_size]\n",
    "\n",
    "    # pick some target features from clean examples of a different class\n",
    "    fea_target = np.zeros((len(X), clean_fea.shape[1]), dtype=np.float32)\n",
    "    for i in range(len(X)):\n",
    "        c = np.argmax(Y[i])\n",
    "        if c == 8:\n",
    "            fea_target[i] = clean_fea[0]\n",
    "        else:\n",
    "            fea_target[i] = clean_fea[1]\n",
    "    \n",
    "    # random start inside the eps-ball, clipped to the valid pixel range\n",
    "    X_adv = X + np.random.uniform(low=-eps, high=eps, size=X.shape)\n",
    "    X_adv = np.clip(X_adv, 0, 1)\n",
    "\n",
    "    best_adv = X.copy()\n",
    "    # plain bool instead of np.bool: the alias no longer exists in recent NumPy releases\n",
    "    success = np.zeros(len(X), dtype=bool)\n",
    "\n",
    "    for i in range(N):\n",
    "        loss_np, g_np, preds_np = sess.run([loss1, g1, preds], feed_dict={x: X_adv, y: Y, fea_target_ph: fea_target})\n",
    "\n",
    "        # check if examples are detected\n",
    "        y_adv = np.zeros((preds_np.shape[0], nb_classes), dtype=np.float32)\n",
    "        y_adv[np.arange(preds_np.shape[0]), np.argmax(preds_np, 1)] = 1 \n",
    "        tps = kl_test(y_adv, preds_np, np.ones(len(Y), dtype=bool), train_stats)\n",
    "        misclass = np.argmax(preds_np, axis=-1) != np.argmax(Y, axis=-1)\n",
    "\n",
    "        # retain examples that fool the classifier and KL-detector\n",
    "        fooled = misclass & ~tps\n",
    "        best_adv[fooled] = X_adv[fooled]\n",
    "        success[fooled] = True\n",
    "\n",
    "        if i % 20 == 0:\n",
    "            print(i, np.mean(loss_np), np.mean(misclass), np.mean(fooled), np.mean(success))\n",
    "        # signed-gradient step, then project back onto the eps-ball and onto [0, 1]\n",
    "        X_adv = X_adv + step_size * np.sign(g_np)\n",
    "        X_adv = np.clip(X_adv, X-eps, X+eps)\n",
    "        X_adv = np.clip(X_adv, 0, 1)\n",
    "    \n",
    "    # re-evaluate the retained examples with a fresh forward pass\n",
    "    X_test_adv[j*batch_size: (j+1)*batch_size] = best_adv.copy()\n",
    "    Y_test_adv[j*batch_size: (j+1)*batch_size] = sess.run(preds, feed_dict={x: best_adv})\n",
    "\n",
    "succ = np.mean(np.argmax(Y_test_adv, axis=-1) != np.argmax(Y_test, axis=-1))\n",
    "print(\"Attack success on classifier: {:.1f}%\".format(100.0 * succ))\n",
    "\n",
    "# \n",
    "# Check how many of our adversarial examples on the classifier are detected (at a 5% FP rate)\n",
    "#\n",
    "y_logits_adv = Y_test_adv\n",
    "y_adv = np.zeros((y_logits_adv.shape[0], nb_classes), dtype=np.float32)\n",
    "y_adv[np.arange(y_logits_adv.shape[0]), np.argmax(y_logits_adv, 1)] = 1 \n",
    "tps = kl_test(y_adv, y_logits_adv, np.ones(len(y_logits_adv), dtype=bool), train_stats)\n",
    "print(\"TP rate for adv examples: {:.1f}%\".format(100 * np.mean(tps)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deepbayes2",
   "language": "python",
   "name": "deepbayes2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
