{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "from __future__ import unicode_literals\n",
    "\n",
    "import keras\n",
    "from keras import backend\n",
    "from keras.datasets import cifar10\n",
    "from keras.utils import np_utils\n",
    "\n",
    "import os\n",
    "import argparse\n",
    "import logging\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.platform import flags\n",
    "\n",
    "#from cleverhans.attacks import fgsm\n",
    "import cleverhans\n",
    "print(cleverhans.__path__)\n",
    "from cleverhans.utils import set_log_level, parse_model_settings, build_model_save_path\n",
    "from cleverhans.attacks import FastGradientMethod\n",
    "from cleverhans.utils_keras import cnn_model\n",
    "from cleverhans.utils_tf import model_train, model_eval, model_eval_ensemble, batch_eval, tf_model_load\n",
    "from cleverhans.utils_tf import model_train_teacher, model_train_student, model_train_inpgrad_reg #for training with input gradient regularization\n",
    "\n",
    "FLAGS = flags.FLAGS\n",
    "\n",
    "ATTACK_CARLINI_WAGNER_L2 = 0\n",
    "ATTACK_JSMA = 1\n",
    "ATTACK_FGSM = 2\n",
    "ATTACK_MADRYETAL = 3\n",
    "ATTACK_BASICITER = 4\n",
    "MAX_BATCH_SIZE = 100\n",
    "MAX_BATCH_SIZE = 100\n",
    "\n",
    "# enum adversarial training types\n",
    "ADVERSARIAL_TRAINING_MADRYETAL = 1\n",
    "ADVERSARIAL_TRAINING_FGSM = 2\n",
    "MAX_EPS = 0.3 \n",
    "\n",
    "# Scaling input to softmax\n",
    "INIT_T = 1.0\n",
    "#ATTACK_T = 1.0\n",
    "ATTACK_T = 0.25\n",
    "\n",
    "def data_cifar10():\n",
    "    \"\"\"\n",
    "    Preprocess CIFAR10 dataset\n",
    "    :return:\n",
    "    \"\"\"\n",
    "\n",
    "    # These values are specific to CIFAR10\n",
    "    img_rows = 32\n",
    "    img_cols = 32\n",
    "    nb_classes = 10\n",
    "\n",
    "    # the data, shuffled and split between train and test sets\n",
    "    (X_train, y_train), (X_test, y_test) = cifar10.load_data()\n",
    "\n",
    "    if keras.backend.image_dim_ordering() == 'th':\n",
    "        X_train = X_train.reshape(X_train.shape[0], 3, img_rows, img_cols)\n",
    "        X_test = X_test.reshape(X_test.shape[0], 3, img_rows, img_cols)\n",
    "    else:\n",
    "        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 3)\n",
    "        X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 3)\n",
    "    X_train = X_train.astype('float32')\n",
    "    X_test = X_test.astype('float32')\n",
    "    \n",
    "    X_train /= 255\n",
    "    X_test /= 255\n",
    "\n",
    "    print('X_train shape:', X_train.shape)\n",
    "    print(X_train.shape[0], 'train samples')\n",
    "    print(X_test.shape[0], 'test samples')\n",
    "\n",
    "    # convert class vectors to binary class matrices\n",
    "    Y_train = np_utils.to_categorical(y_train, nb_classes)\n",
    "    Y_test = np_utils.to_categorical(y_test, nb_classes)\n",
    "    return X_train, Y_train, X_test, Y_test\n",
    "    #return X_train, y_train, X_test, y_test\n",
    "\n",
    "def data_cifar10_std():\n",
    "    \"\"\"\n",
    "    Preprocess CIFAR10 dataset\n",
    "    :return:\n",
    "    \"\"\"\n",
    "\n",
    "    # These values are specific to CIFAR10\n",
    "    img_rows = 32\n",
    "    img_cols = 32\n",
    "    nb_classes = 10\n",
    "\n",
    "    # the data, shuffled and split between train and test sets\n",
    "    (X_train, y_train), (X_test, y_test) = cifar10.load_data()\n",
    "\n",
    "    if keras.backend.image_dim_ordering() == 'th':\n",
    "        X_train = X_train.reshape(X_train.shape[0], 3, img_rows, img_cols)\n",
    "        X_test = X_test.reshape(X_test.shape[0], 3, img_rows, img_cols)\n",
    "    else:\n",
    "        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 3)\n",
    "        X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 3)\n",
    "    X_train = X_train.astype('float32')\n",
    "    X_test = X_test.astype('float32')\n",
    "    '''\n",
    "    X_train /= 255\n",
    "    X_test /= 255\n",
    "    '''\n",
    "    '''\n",
    "    X_train -= X_train.mean()\n",
    "    X_train /= X_train.std()\n",
    "    X_test -= X_test.mean()\n",
    "    X_test /= X_test.std()\n",
    "    '''\n",
    "    e = 1. / np.sqrt(img_rows * img_rows * 3)\n",
    "    for i in range(X_train.shape[0]):\n",
    "        X_train[i] = (X_train[i] - np.mean(X_train[i])) / \\\n",
    "            np.maximum(np.std(X_train[i]), e)\n",
    "\n",
    "    for i in range(X_test.shape[0]):\n",
    "        X_test[i] = (X_test[i] - np.mean(X_test[i])) / \\\n",
    "            np.maximum(np.std(X_test[i]), e)\n",
    "    print('X_train shape:', X_train.shape)\n",
    "    print(X_train.shape[0], 'train samples')\n",
    "    print(X_test.shape[0], 'test samples')\n",
    "\n",
    "    # convert class vectors to binary class matrices\n",
    "    Y_train = np_utils.to_categorical(y_train, nb_classes)\n",
    "    Y_test = np_utils.to_categorical(y_test, nb_classes)\n",
    "    return X_train, Y_train, X_test, Y_test\n",
    "    #return X_train, y_train, X_test, y_test\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    par = argparse.ArgumentParser()\n",
    "\n",
    "    # Generic flags                                                                                                                                                                                                                                                                                                                               \n",
    "    par.add_argument('--gpu', help='id of GPU to use')\n",
    "    par.add_argument('--model_path', help='Path to save or load model')\n",
    "    par.add_argument('--data_dir', help='Path to training data',\n",
    "                     default='/scratch/gallowaa/cifar10/cifar10_data')\n",
    "\n",
    "    # Architecture and training specific flags                                                                                                                                                                                                                                                                                                    \n",
    "    par.add_argument('--nb_epochs', type=int, default=6,\n",
    "                     help='Number of epochs to train model')\n",
    "    par.add_argument('--nb_filters', type=int, default=32,\n",
    "                     help='Number of filters in first layer')\n",
    "    par.add_argument('--batch_size', type=int, default=128,\n",
    "                     help='Size of training batches')\n",
    "    par.add_argument('--learning_rate', type=float, default=0.001,\n",
    "                     help='Learning rate')\n",
    "    par.add_argument('--rand', help='Stochastic weight layer?',\n",
    "                     action=\"store_true\")\n",
    "\n",
    "    # Attack specific flags                                                                                                                                                                                                                                                                                                                       \n",
    "    par.add_argument('--eps', type=float, default=0.1,\n",
    "                     help='epsilon')\n",
    "    par.add_argument('--attack', type=int, default=0,\n",
    "                     help='Attack type, 0=CW, 2=FGSM')\n",
    "    par.add_argument('--attack_iterations', type=int, default=50,\n",
    "                     help='Number of iterations to run CW attack; 1000 is good')\n",
    "    par.add_argument('--nb_samples', type=int,\n",
    "                     default=10000, help='Nb of inputs to attack')\n",
    "    par.add_argument(\n",
    "        '--targeted', help='Run a targeted attack?', action=\"store_true\")\n",
    "    # Adversarial training flags                                                                                                                                                                                                                                                                                                                  \n",
    "    par.add_argument(\n",
    "        '--adv', help='Adversarial training type?', type=int, default=0)\n",
    "    par.add_argument('--delay', type=int,\n",
    "                     default=10, help='Nb of epochs to delay adv training by')\n",
    "    par.add_argument('--nb_iter', type=int,\n",
    "                     default=40, help='Nb of iterations of PGD')\n",
    "\n",
    "    # EMPIR specific flags                                                                                                                                                                                                                                                                                                                        \n",
    "    par.add_argument('--lowprecision', help='Use other low precision models', action=\"store_true\")\n",
    "    par.add_argument('--wbits', type=int, default=4, help='No. of bits in weight representation')\n",
    "    par.add_argument('--abits', type=int, default=2, help='No. of bits in activation representation')\n",
    "    par.add_argument('--wbitsList', type=int, nargs='+', help='List of No. of bits in weight representation for different layers')\n",
    "    par.add_argument('--abitsList', type=int, nargs='+', help='List of No. of bits in activation representation for different layers')\n",
    "    par.add_argument('--stocRound', help='Stochastic rounding for weights (only in training) and activations?', action=\"store_true\")\n",
    "    par.add_argument('--model_path1', default=\"models/Model1/\", help='Path where saved model1 is stored and can be loaded')\n",
    "    par.add_argument('--model_path2', default=\"models/Model2/\", help='Path where saved model2 is stored and can be loaded')\n",
    "    par.add_argument('--ensembleThree', help='Use an ensemble of full precision and two low precision models that can be attacked directly', action=\"store_true\")\n",
    "    par.add_argument('--model_path3', default=\"models/Model3/\", help='Path where saved model3 in case of combinedThree model is stored and can be loaded')\n",
    "    par.add_argument('--wbits2', type=int, default=2, help='No. of bits in weight representation of model2, model1 specified using wbits')\n",
    "    par.add_argument('--abits2', type=int, default=2, help='No. of bits in activation representation of model2, model2 specified using abits')\n",
    "    par.add_argument('--wbits2List', type=int, nargs='+', help='List of No. of bits in weight representation for different layers of model2')\n",
    "    par.add_argument('--abits2List', type=int, nargs='+', help='List of No. of bits in activation representation for different layers of model2')\n",
    "    # extra flags for defensive distillation                                                                                                                                                                                                                                                                                                      \n",
    "    par.add_argument('--distill', help='Train the model using distillation', action=\"store_true\")\n",
    "    par.add_argument('--student_epochs', type=int, default=50, help='No. of epochs for which the student model is trained')\n",
    "    # extra flags for input gradient regularization                                                                                                                                                                                                                                                                                               \n",
    "    par.add_argument('--inpgradreg', help='Train the model using input gradient regularization', action=\"store_true\")\n",
    "    par.add_argument('--l2dbl', type=int, default=0, help='l2 double backprop penalty')\n",
    "    par.add_argument('--l2cs', type=int, default=0, help='l2 certainty sensitivity penalty')\n",
    "    par.add_argument('-f', type=str, default=0, help='l2 certainty sensitivity penalty')\n",
    "    FLAGS = par.parse_args()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if True:\n",
    "    \"\"\"\n",
    "    CIFAR10 CleverHans tutorial\n",
    "    :return:\n",
    "    \"\"\"\n",
    "\n",
    "    # CIFAR10-specific dimensions\n",
    "    img_rows = 32\n",
    "    img_cols = 32\n",
    "    channels = 3\n",
    "    nb_classes = 10\n",
    "\n",
    "    # Set TF random seed to improve reproducibility\n",
    "    tf.set_random_seed(1234)\n",
    "\n",
    "    if not hasattr(backend, \"tf\"):\n",
    "        raise RuntimeError(\"This tutorial requires keras to be configured\"\n",
    "                           \" to use the TensorFlow backend.\")\n",
    "\n",
    "    # Image dimensions ordering should follow the Theano convention\n",
    "    if keras.backend.image_dim_ordering() != 'tf':\n",
    "        keras.backend.set_image_dim_ordering('tf')\n",
    "        print(\"INFO: '~/.keras/keras.json' sets 'image_dim_ordering' to \"\n",
    "              \"'th', temporarily setting to 'tf'\")\n",
    "\n",
    "    # Create TF session and set as Keras backend session\n",
    "    sess = tf.Session()\n",
    "    keras.backend.set_session(sess)\n",
    "\n",
    "    set_log_level(logging.WARNING)\n",
    "\n",
    "    # Get CIFAR10 test data\n",
    "    X_train, Y_train, X_test, Y_test = data_cifar10()\n",
    "\n",
    "    assert Y_train.shape[1] == 10.\n",
    "    label_smooth = .1\n",
    "    Y_train = Y_train.clip(label_smooth / 9., 1. - label_smooth)\n",
    "\n",
    "    # Define input TF placeholder\n",
    "    x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, channels))\n",
    "    y = tf.placeholder(tf.float32, shape=(None, 10))\n",
    "    phase = tf.placeholder(tf.bool, name=\"phase\")\n",
    "    logits_scalar = tf.placeholder_with_default(\n",
    "        INIT_T, shape=(), name=\"logits_temperature\")\n",
    "\n",
    "\n",
    "    model_path = FLAGS.model_path\n",
    "    targeted = True if FLAGS.targeted else False\n",
    "    learning_rate = FLAGS.learning_rate\n",
    "    nb_filters = FLAGS.nb_filters\n",
    "    batch_size = FLAGS.batch_size\n",
    "    nb_samples = FLAGS.nb_samples\n",
    "    nb_epochs = FLAGS.nb_epochs\n",
    "    delay = FLAGS.delay\n",
    "    eps = FLAGS.eps\n",
    "    adv = FLAGS.adv\n",
    "\n",
    "    attack = FLAGS.attack\n",
    "    attack_iterations = FLAGS.attack_iterations\n",
    "    nb_iter = FLAGS.nb_iter\n",
    "   \n",
    "    #### EMPIR extra flags\n",
    "    lowprecision=FLAGS.lowprecision\n",
    "    abits=FLAGS.abits\n",
    "    wbits=FLAGS.wbits\n",
    "    abitsList=FLAGS.abitsList\n",
    "    wbitsList=FLAGS.wbitsList\n",
    "    stocRound=True if FLAGS.stocRound else False\n",
    "    rand=FLAGS.rand \n",
    "    model_path2 = FLAGS.model_path2\n",
    "    model_path1 = FLAGS.model_path1\n",
    "    model_path3 = FLAGS.model_path3\n",
    "    ensembleThree=True\n",
    "    abits2=FLAGS.abits2\n",
    "    wbits2=FLAGS.wbits2\n",
    "    abits2List=FLAGS.abits2List\n",
    "    wbits2List=FLAGS.wbits2List\n",
    "    inpgradreg = True if FLAGS.inpgradreg else False\n",
    "    distill = True if FLAGS.distill else False\n",
    "    student_epochs = FLAGS.student_epochs\n",
    "    l2dbl = FLAGS.l2dbl\n",
    "    l2cs = FLAGS.l2cs\n",
    "    ####\n",
    "\n",
    "    save = False\n",
    "    train_from_scratch = False\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if True:\n",
    "    if ensembleThree: \n",
    "        if (model_path1 is None or model_path2 is None or model_path3 is None):\n",
    "            train_from_scratch = True\n",
    "        else:\n",
    "            train_from_scratch = False\n",
    "    elif model_path is not None:\n",
    "        if os.path.exists(model_path):\n",
    "            # check for existing model in immediate subfolder\n",
    "            if any(f.endswith('.meta') for f in os.listdir(model_path)):\n",
    "                train_from_scratch = False\n",
    "            else:\n",
    "                model_path = build_model_save_path(\n",
    "                    model_path, batch_size, nb_filters, learning_rate, nb_epochs, adv, delay)\n",
    "                print(model_path)\n",
    "                save = True\n",
    "                train_from_scratch = True\n",
    "    else:\n",
    "        train_from_scratch = True  # train from scratch, but don't save since no path given\n",
    "\n",
    "    if ensembleThree: \n",
    "       if (wbitsList is None) or (abitsList is None): # Layer wise separate quantization not specified for first model\n",
    "           if (wbits==0) or (abits==0):\n",
    "               print(\"Error: the number of bits for constant precision weights and activations across layers for the first model have to specified using wbits1 and abits1 flags\")\n",
    "               sys.exit(1)\n",
    "           else:\n",
    "               fixedPrec1 = 1\n",
    "       elif (len(wbitsList) != 3) or (len(abitsList) != 3):\n",
    "           print(\"Error: Need to specify the precisions for activations and weights for the atleast the three convolutional layers of the first model\")  \n",
    "           sys.exit(1)\n",
    "       else: \n",
    "           fixedPrec1 = 0\n",
    "       \n",
    "       if (wbits2List is None) or (abits2List is None): # Layer wise separate quantization not specified for second model\n",
    "           if (wbits2==0) or (abits2==0):\n",
    "               print(\"Error: the number of bits for constant precision weights and activations across layers for the second model have to specified using wbits1 and abits1 flags\")\n",
    "               sys.exit(1)\n",
    "           else:\n",
    "               fixedPrec2 = 1\n",
    "       elif (len(wbits2List) != 3) or (len(abits2List) != 3):\n",
    "           print(\"Error: Need to specify the precisions for activations and weights for the atleast the three convolutional layers of the second model\")  \n",
    "           sys.exit(1)\n",
    "       else: \n",
    "           fixedPrec2 = 0\n",
    "\n",
    "       if (fixedPrec2 != 1) or (fixedPrec1 != 1): # Atleast one of the models have separate precisions per layer\n",
    "           fixedPrec=0\n",
    "           print(\"Within atleast one model has separate precisions\")\n",
    "           if (fixedPrec1 == 1): # first layer has fixed precision\n",
    "               abitsList = (abits, abits, abits)\n",
    "               wbitsList = (wbits, wbits, wbits)\n",
    "           if (fixedPrec2 == 1): # second layer has fixed precision\n",
    "               abits2List = (abits2, abits2, abits2)\n",
    "               wbits2List = (wbits2, wbits2, wbits2)\n",
    "       else:\n",
    "           fixedPrec=1\n",
    "       \n",
    "       if (train_from_scratch):\n",
    "           print (\"The ensemble model cannot be trained from scratch\")\n",
    "           sys.exit(1)\n",
    "       if fixedPrec == 1:\n",
    "           from cleverhans_tutorials.tutorial_models import make_ensemble_three_cifar_cnn\n",
    "           print(\"ASDF1\")\n",
    "           model = make_ensemble_three_cifar_cnn(\n",
    "               phase, logits_scalar, 'lp1_', 'lp2_', 'fp_', wbits, abits, wbits2, abits2, input_shape=(None, img_rows, img_cols, channels), nb_filters=nb_filters) \n",
    "       else:\n",
    "           from cleverhans_tutorials.tutorial_models import make_ensemble_three_cifar_cnn_layerwise\n",
    "           model = make_ensemble_three_cifar_cnn_layerwise(\n",
    "               phase, logits_scalar, 'lp1_', 'lp2_', 'fp_', wbitsList, abitsList, wbits2List, abits2List, input_shape=(None, img_rows, img_cols, channels), nb_filters=nb_filters) \n",
    "    elif lowprecision:\n",
    "       if (wbitsList is None) or (abitsList is None): # Layer wise separate quantization not specified\n",
    "           if (wbits==0) or (abits==0):\n",
    "               print(\"Error: the number of bits for constant precision weights and activations across layers have to specified using wbits and abits flags\")\n",
    "               sys.exit(1)\n",
    "           else:\n",
    "               fixedPrec = 1\n",
    "       elif (len(wbitsList) != 3) or (len(abitsList) != 3):\n",
    "           print(\"Error: Need to specify the precisions for activations and weights for the atleast the three convolutional layers\")  \n",
    "           sys.exit(1)\n",
    "       else: \n",
    "           fixedPrec = 0\n",
    "       \n",
    "       if fixedPrec:\n",
    "           from cleverhans_tutorials.tutorial_models import make_basic_lowprecision_cifar_cnn\n",
    "           model = make_basic_lowprecision_cifar_cnn(\n",
    "               phase, logits_scalar, 'lp_', wbits, abits, input_shape=(\n",
    "            None, img_rows, img_cols, channels), nb_filters=nb_filters, stocRound=stocRound)  \n",
    "       else:\n",
    "           from cleverhans_tutorials.tutorial_models import make_layerwise_lowprecision_cifar_cnn\n",
    "           model = make_layerwise_lowprecision_cifar_cnn(\n",
    "               phase, logits_scalar, 'lp_', wbitsList, abitsList, input_shape=(\n",
    "            None, img_rows, img_cols, channels), nb_filters=nb_filters, stocRound=stocRound)  \n",
    "    elif distill:\n",
    "      from cleverhans_tutorials.tutorial_models import make_distilled_cifar_cnn\n",
    "      model = make_distilled_cifar_cnn(phase, logits_scalar,\n",
    "              'teacher_fp_', 'fp_', nb_filters=nb_filters, input_shape=(None, img_rows, img_cols, channels))  \n",
    "    ####\n",
    "    else:\n",
    "        from cleverhans_tutorials.tutorial_models import make_basic_cifar_cnn\n",
    "        model = make_basic_cifar_cnn(phase, logits_scalar, 'fp_', input_shape=(\n",
    "            None, img_rows, img_cols, channels), nb_filters=nb_filters)\n",
    "\n",
    "\n",
    "    # separate predictions of teacher for distilled training\n",
    "    if distill:\n",
    "        teacher_preds = model.teacher_call(x, reuse=False)\n",
    "        teacher_logits = model.get_teacher_logits(x, reuse=False)\n",
    "\n",
    "    # separate calling function for ensemble models\n",
    "    if ensembleThree:\n",
    "        preds = model.ensemble_call(x, reuse=False)\n",
    "    else:\n",
    "    ##default\n",
    "        preds = model(x, reuse=False)\n",
    "    print(\"Defined TensorFlow model graph.\")\n",
    "\n",
    "    rng = np.random.RandomState([2017, 8, 30])\n",
    "\n",
    "    def evaluate():\n",
    "        # Evaluate the accuracy of the CIFAR10 model on legitimate test\n",
    "        # examples\n",
    "        eval_params = {'batch_size': batch_size}\n",
    "        if ensembleThree:\n",
    "            acc = model_eval_ensemble(\n",
    "                sess, x, y, preds, X_test, Y_test, phase=phase, args=eval_params)\n",
    "        else:\n",
    "            acc = model_eval(\n",
    "                sess, x, y, preds, X_test, Y_test, phase=phase, args=eval_params)\n",
    "        assert X_test.shape[0] == 10000, X_test.shape\n",
    "        print('Test accuracy on legitimate examples: %0.4f' % acc)\n",
    "\n",
    "    # Train an CIFAR10 model\n",
    "    train_params = {\n",
    "        'nb_epochs': nb_epochs,\n",
    "        'batch_size': batch_size,\n",
    "        'learning_rate': learning_rate,\n",
    "        'loss_name': 'train loss',\n",
    "        'filename': 'model',\n",
    "        'reuse_global_step': False,\n",
    "        'train_scope': 'train',\n",
    "        'is_training': True\n",
    "    }\n",
    "    \n",
    "    if adv != 0:\n",
    "        if adv == ADVERSARIAL_TRAINING_MADRYETAL:\n",
    "            from cleverhans.attacks import MadryEtAl\n",
    "            train_attack_params = {'eps': MAX_EPS, 'eps_iter': 0.01,\n",
    "                                   'nb_iter': nb_iter}\n",
    "            train_attacker = MadryEtAl(model, sess=sess)\n",
    "\n",
    "        elif adv == ADVERSARIAL_TRAINING_FGSM:\n",
    "            from cleverhans.attacks import FastGradientMethod\n",
    "            stddev = int(np.ceil((MAX_EPS * 255) // 2))\n",
    "            train_attack_params = {'eps': tf.abs(tf.truncated_normal(\n",
    "                shape=(batch_size, 1, 1, 1), mean=0, stddev=stddev))}\n",
    "            train_attacker = FastGradientMethod(model, back='tf', sess=sess)\n",
    "        # create the adversarial trainer\n",
    "        train_attack_params.update({'clip_min': 0., 'clip_max': 1.})\n",
    "        adv_x_train = train_attacker.generate(x, phase, **train_attack_params)\n",
    "        preds_adv_train = model.get_probs(adv_x_train)\n",
    "\n",
    "        eval_attack_params = {'eps': MAX_EPS, 'clip_min': 0., 'clip_max': 1.}\n",
    "        adv_x_eval = train_attacker.generate(x, phase, **eval_attack_params)\n",
    "        preds_adv_eval = model.get_probs(adv_x_eval)  # * logits_scalar\n",
    "\n",
    "    if train_from_scratch:\n",
    "        if save:\n",
    "            train_params.update({'log_dir': model_path})\n",
    "            if adv and delay > 0:\n",
    "                train_params.update({'nb_epochs': delay})\n",
    "\n",
    "        # do clean training for 'nb_epochs' or 'delay' epochs\n",
    "        if distill:\n",
    "            temperature = 10 # 1 means the teacher predictions are used as it is\n",
    "            teacher_scaled_preds_val = model_train_teacher(sess, x, y, teacher_preds, teacher_logits, \n",
    "                        temperature, X_train, Y_train, phase=phase, args=train_params, rng=rng)\n",
    "            eval_params = {'batch_size': batch_size}\n",
    "            teacher_acc = model_eval(\n",
    "                sess, x, y, teacher_preds, X_test, Y_test, phase=phase, args=eval_params)\n",
    "            print('Test accuracy of the teacher model on legitimate examples: %0.4f' % teacher_acc)\n",
    "            print('Training the student model...')\n",
    "            student_train_params = {\n",
    "                'nb_epochs': student_epochs,\n",
    "                'batch_size': batch_size,\n",
    "                'learning_rate': learning_rate,\n",
    "                'loss_name': 'train loss',\n",
    "                'filename': 'model',\n",
    "                'reuse_global_step': False,\n",
    "                'train_scope': 'train',\n",
    "                'is_training': True\n",
    "            }\n",
    "            if save:\n",
    "                student_train_params.update({'log_dir': model_path})\n",
    "            y_teacher = tf.placeholder(tf.float32, shape=(None, nb_classes))\n",
    "            model_train_student(sess, x, y, preds, temperature, X_train, Y_train, y_teacher=y_teacher, \n",
    "                        teacher_preds=teacher_scaled_preds_val, alpha=0.3, beta=0.7, phase=phase, evaluate=evaluate, args=student_train_params, save=save, rng=rng)\n",
    "        elif inpgradreg: \n",
    "            model_train_inpgrad_reg(sess, x, y, preds, X_train, Y_train, phase=phase,\n",
    "                        evaluate=evaluate, l2dbl = l2dbl, l2cs = l2cs, args=train_params, save=save, rng=rng)\n",
    "        else: \n",
    "            # do clean training for 'nb_epochs' or 'delay' epochs\n",
    "            model_train(sess, x, y, preds, X_train, Y_train, phase=phase,\n",
    "                        evaluate=evaluate, args=train_params, save=save, rng=rng)\n",
    "\n",
    "        # optionally do additional adversarial training\n",
    "        if adv:\n",
    "            print(\"Adversarial training for %d epochs\" % (nb_epochs - delay))\n",
    "            train_params.update({'nb_epochs': nb_epochs - delay})\n",
    "            train_params.update({'reuse_global_step': True})\n",
    "            model_train(sess, x, y, preds, X_train, Y_train, phase=phase,\n",
    "                        predictions_adv=preds_adv_train, evaluate=evaluate, args=train_params,\n",
    "                        save=save, rng=rng)\n",
    "\n",
    "    else:\n",
    "        if ensembleThree: \n",
    "            variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)\n",
    "            stored_variables = ['lp_conv1_init/k', 'lp_conv2_init/k', 'lp_conv3_init/k', 'lp_ip1init/W', 'lp_logits_init/W']\n",
    "            variable_dict = dict(zip(stored_variables, variables[:5])) \n",
    "            # Restore the first set of variables from model_path1\n",
    "            saver = tf.train.Saver(variable_dict)\n",
    "            saver.restore(sess, tf.train.latest_checkpoint(model_path1))\n",
    "            # Restore the second set of variables from model_path2\n",
    "            variable_dict = dict(zip(stored_variables, variables[5:10]))\n",
    "            saver2 = tf.train.Saver(variable_dict)\n",
    "            saver2.restore(sess, tf.train.latest_checkpoint(model_path2))\n",
    "            stored_variables = ['fp_conv1_init/k', 'fp_conv2_init/k', 'fp_conv3_init/k', 'fp_ip1init/W', 'fp_logits_init/W']\n",
    "            variable_dict = dict(zip(stored_variables, variables[10:]))\n",
    "            saver3 = tf.train.Saver(variable_dict)\n",
    "            saver3.restore(sess, tf.train.latest_checkpoint(model_path3))\n",
    "        else:\n",
    "            tf_model_load(sess, model_path)\n",
    "            print('Restored model from %s' % model_path)\n",
    "        evaluate()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if True:\n",
    "    # Evaluate the accuracy of the CIFAR10 model on legitimate test examples\n",
    "    ## TODO PUT THIS BACK IN\n",
    "    eval_params = {'batch_size': batch_size}\n",
    "    if ensembleThree: \n",
    "        accuracy = model_eval_ensemble(sess, x, y, preds, X_test, Y_test, phase=phase, feed={phase: False}, args=eval_params)\n",
    "    else:\n",
    "        accuracy = model_eval(sess, x, y, preds, X_test, Y_test, phase=phase,\n",
    "                          feed={phase: False}, args=eval_params)\n",
    "\n",
    "    print('Test accuracy on legitimate test examples: {0}'.format(accuracy))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if True:\n",
    "    ###########################################################################\n",
    "    # Build dataset\n",
    "    ###########################################################################\n",
    "\n",
    "    if targeted:\n",
    "        from cleverhans.utils import build_targeted_dataset\n",
    "        adv_inputs, true_labels, adv_ys = build_targeted_dataset(\n",
    "            X_test, Y_test, np.arange(nb_samples), nb_classes, img_rows, img_cols, channels)\n",
    "    else:\n",
    "        adv_inputs = X_test[:nb_samples]\n",
    "\n",
    "    ###########################################################################\n",
    "    # Craft adversarial examples using generic approach\n",
    "    ###########################################################################\n",
    "    if targeted:\n",
    "        att_batch_size = np.clip(\n",
    "            nb_samples * (nb_classes - 1), a_max=MAX_BATCH_SIZE, a_min=1)\n",
    "        nb_adv_per_sample = nb_classes - 1\n",
    "        yname = \"y_target\"\n",
    "\n",
    "    else:\n",
    "        att_batch_size = np.minimum(nb_samples, MAX_BATCH_SIZE)\n",
    "        nb_adv_per_sample = 1\n",
    "        adv_ys = None\n",
    "        yname = \"y\"\n",
    "\n",
    "    print('Crafting ' + str(nb_samples) + ' * ' + str(nb_adv_per_sample) +\n",
    "          ' adversarial examples')\n",
    "    print(\"This could take some time ...\")\n",
    "\n",
    "    if ensembleThree:\n",
    "        model_type = 'ensembleThree'\n",
    "    else:\n",
    "        model_type = 'default'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "    X_test_adv = np.copy(adv_inputs)\n",
    "\n",
    "    outs = model.forward3(x, reuse=True)\n",
    "\n",
    "    #out = tf.log((outs[0] + outs[1] + outs[2])/3)\n",
    "\n",
    "    #loss = tf.nn.softmax_cross_entropy_with_logits(logits=out,\n",
    "    #                                               labels=y)\n",
    "    loss = [tf.maximum(tf.reduce_sum(tf.log(outs[i])*y,axis=1) - tf.reduce_max(tf.log(outs[i]) * (1-y) - y * 10000,axis=1),-3) for i in range(3)]\n",
    "    loss = loss[0] + loss[1] + loss[2]\n",
    "    \n",
    "    grads = tf.gradients(loss, [x])[0]\n",
    "\n",
    "    \n",
    "    for i in range(20):\n",
    "        print(i)\n",
    "        X_test_adv -= np.sign(sess.run(grads, {x: X_test_adv,\n",
    "                                               y: Y_test}))*.01\n",
    "        \n",
    "        X_test_adv  = np.clip(X_test_adv, adv_inputs-0.031, adv_inputs+0.031)\n",
    "        X_test_adv  = np.clip(X_test_adv, 0, 1)\n",
    "\n",
    "        if i%2 == 0:\n",
    "            print(np.argmax(sess.run(out, {x: X_test_adv[:10], y: Y_test[:10]}),axis=1))\n",
    "            print(sess.run(loss, {x: X_test_adv[:10], y: Y_test[:10]}))\n",
    "            adv_accuracy = model_eval_ensemble(sess, x, y, preds, X_test_adv[:10], Y_test[:10],\n",
    "                                               phase=phase, args={'batch_size': 10})\n",
    "            print(\"Adv acc\", adv_accuracy)\n",
    "    \n",
    "    print(\"Shape\", X_test_adv.shape)\n",
    "    print(\"max\", np.max(np.abs(X_test_adv-adv_inputs)))\n",
    "\n",
    "    \n",
    "\n",
    "    if targeted:\n",
    "        assert X_test_adv.shape[0] == nb_samples * \\\n",
    "            (nb_classes - 1), X_test_adv.shape\n",
    "        # Evaluate the accuracy of the CIFAR10 model on adversarial examples\n",
    "        print(\"Evaluating targeted results\")\n",
    "        adv_accuracy = model_eval(sess, x, y, preds, X_test_adv, true_labels,\n",
    "                                  phase=phase, args=eval_params)\n",
    "    else:\n",
    "        # assert X_test_adv.shape[0] == nb_samples, X_test_adv.shape\n",
    "        # Evaluate the accuracy of the CIFAR10 model on adversarial examples\n",
    "        print(\"Evaluating un-targeted results\")\n",
    "        if ensembleThree:\n",
    "            adv_accuracy = model_eval_ensemble(sess, x, y, preds, X_test_adv, Y_test,\n",
    "                                  phase=phase, args=eval_params)\n",
    "        else: #default below\n",
    "            adv_accuracy = model_eval(sess, x, y, preds, X_test_adv, Y_test,\n",
    "                                      phase=phase, args=eval_params)\n",
    "\n",
    "    # Compute the number of adversarial examples that were successfully found\n",
    "    print('Test accuracy on adversarial examples {0:.4f}'.format(adv_accuracy))\n",
    "\n",
    "    # Compute the average distortion introduced by the algorithm\n",
    "    percent_perturbed = np.mean(np.sum((X_test_adv - adv_inputs)**2,\n",
    "                                       axis=(1, 2, 3))**.5)\n",
    "    print('Avg. L_2 norm of perturbations {0:.4f}'.format(percent_perturbed))\n",
    "\n",
    "    # Friendly output for pasting into spreadsheet\n",
    "    print('{0:.4f},'.format(accuracy))\n",
    "    print('{0:.4f},'.format(adv_accuracy))\n",
    "    print('{0:.4f},'.format(percent_perturbed))\n",
    "    \n",
    "    adv_accuracy = model_eval_ensemble(sess, x, y, preds, adv_inputs, Y_test,\n",
    "                                  phase=phase, args=eval_params)\n",
    "    print(\"clean\", adv_accuracy)\n",
    "\n",
    "    #sess.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "np.save(\"empir_adv.npy\", X_test_adv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
