{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" Tests adversarial example functions\"\"\"\n",
    "import sys \n",
    "sys.path.append('..')\n",
    "from relu_nets import ReLUNet\n",
    "from neural_nets import adv_attacks as aa \n",
    "from neural_nets import data_loaders as dl \n",
    "from neural_nets import train \n",
    "import utilities as utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00 | Accuracy: 87.09\n",
      "EPOCH 00 | Adv Accuracy: 60.06\n",
      "Epoch 01 | Accuracy: 90.01\n",
      "EPOCH 01 | Adv Accuracy: 66.14\n",
      "Epoch 02 | Accuracy: 91.15\n",
      "EPOCH 02 | Adv Accuracy: 68.19\n",
      "Epoch 03 | Accuracy: 91.62\n",
      "EPOCH 03 | Adv Accuracy: 69.24\n",
      "Epoch 04 | Accuracy: 91.95\n",
      "EPOCH 04 | Adv Accuracy: 70.00999999999999\n"
     ]
    }
   ],
   "source": [
    "# Test MNIST\n",
    "# -- dataset\n",
    "trainset = dl.load_mnist_data('train')\n",
    "valset = dl.load_mnist_data('val')\n",
    "\n",
    "# -- shared hyperparameters: FGSM training and the FGSM robustness check\n",
    "#    must use the same L-inf bound, so it is defined once here\n",
    "FGSM_EPSILON = 0.1\n",
    "NUM_EPOCHS = 5\n",
    "\n",
    "# -- train parameters\n",
    "train_fgsm = train.FGSM(FGSM_EPSILON)\n",
    "train_params = train.TrainParameters(trainset, valset, NUM_EPOCHS,\n",
    "                                     loss_functional=train.LossFunctional(regularizers=[train_fgsm]))\n",
    "\n",
    "# -- robustness evaluator\n",
    "# Build the attack once rather than on every epoch. The loader and attack are\n",
    "# bound as default arguments so later cells that rebind `valset` cannot\n",
    "# change what this callback evaluates on.\n",
    "eval_attack = aa.build_attack_partial(aa.fgsm, linf_bound=FGSM_EPSILON)\n",
    "\n",
    "def epoch_callback(network, epoch_no, eval_set=valset, attack=eval_attack):\n",
    "    \"\"\"Print FGSM adversarial accuracy on the validation set after an epoch.\n",
    "\n",
    "    network:  model being trained (passed in by train.training_loop)\n",
    "    epoch_no: epoch index, zero-padded to line up with the training log\n",
    "    eval_set, attack: bound at definition time; callers need not pass them\n",
    "    \"\"\"\n",
    "    eval_result = aa.eval_dataset(network, eval_set, attack)\n",
    "    # percentage_correct appears to be a fraction (it is scaled by 100 here);\n",
    "    # format to 2 decimals to match the 'Accuracy' lines instead of printing\n",
    "    # raw float noise such as 70.00999999999999.\n",
    "    print('Epoch %02d | Adv Accuracy: %.2f' % (epoch_no, eval_result['percentage_correct'] * 100))\n",
    "\n",
    "# -- build net and train\n",
    "test_net = ReLUNet(layer_sizes=[784, 20, 20, 10])\n",
    "train.training_loop(test_net, train_params, epoch_callback=epoch_callback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a stand-alone FGSM-regularized loss and attach the trained test_net to it.\n",
    "fgsm_regularizer = train.FGSM(0.1)\n",
    "loss = train.LossFunctional(regularizers=[fgsm_regularizer])\n",
    "loss.attach_network(test_net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small-batch loader for inspecting a handful of examples. It gets its own name\n",
    "# so the `valset` that epoch_callback evaluates on is not silently replaced.\n",
    "# `ex`, `lab` are presumably one batch of images and labels -- confirm in dl.\n",
    "sample_loader = dl.load_mnist_data('val', batch_size=16)\n",
    "ex, lab = next(iter(sample_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-execute neural_nets.adv_attacks so source edits take effect without a\n",
    "# kernel restart; importlib.reload updates the module object bound to `aa`.\n",
    "import importlib\n",
    "importlib.reload(aa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the sampled validation batch before attacking it.\n",
    "utils.display_images(ex, figsize=(16, 16))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 0.1 matches FGSM(0.1) / linf_bound=0.1 in the training cell (presumably the L-inf bound).\n",
    "fgsm_bound = 0.1\n",
    "new_ex = aa.fgsm(test_net, ex, lab, fgsm_bound)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the trained net on the FGSM-perturbed batch (result is the cell output).\n",
    "fgsm_eval = aa.eval_minibatch(test_net, new_ex, lab)\n",
    "fgsm_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PGD on the same batch and bound as the FGSM cell, then score the result.\n",
    "pgd_settings = {'num_iter': 1000, 'step_size': 0.005}\n",
    "pgd_ex = aa.pgd(test_net, ex, lab, 0.1, **pgd_settings)\n",
    "aa.eval_minibatch(test_net, pgd_ex, lab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `train` is already neural_nets.train (first cell). A bare `import train` could\n",
    "# rebind it to a different top-level module, or raise if none is on sys.path.\n",
    "# Reload the package module instead, mirroring the adv_attacks reload above.\n",
    "import importlib\n",
    "importlib.reload(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
