{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/justin/.conda/envs/DeepL/lib/python3.7/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "In this example, we build a binary MNIST classifier (digit 1 vs. digit 7) and then run GeoCert on a single test point.\n",
    "'''\n",
    "\n",
    "# =====================\n",
    "# Imports\n",
    "# =====================\n",
    "# %load_ext line_profiler\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "sys.path.append('../mister_ed') # library for adversarial examples\n",
    "\n",
    "import geocert_oop as geo\n",
    "from plnn import PLNN\n",
    "import _polytope_ as _poly_\n",
    "from _polytope_ import Polytope, Face\n",
    "import utilities as utils\n",
    "import os\n",
    "import time \n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "import adversarial_perturbations as ap \n",
    "import prebuilt_loss_functions as plf\n",
    "import loss_functions as lf \n",
    "import adversarial_attacks as aa\n",
    "import utils.pytorch_utils as me_utils\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import mnist.mnist_loader as  ml \n",
    "# Load the MNIST train and validation splits\n",
    "mnist_trainset, mnist_valset = ml.load_mnist_data('train'), ml.load_mnist_data('val')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code to select only a subset of digits from MNIST datasets\n",
    "def select_digits(dataset, digits, mb_size=128, remap_label=True):\n",
    "    \"\"\"Keep only the examples whose label is in `digits`, then re-batch them.\n",
    "\n",
    "    Args:\n",
    "        dataset: iterable of (data, labels) minibatches; data is indexed\n",
    "            along dim 0 in step with the 1-D labels tensor.\n",
    "        digits: list of digit labels to keep.\n",
    "        mb_size: size of the returned minibatches (the last may be smaller).\n",
    "        remap_label: if True (only supported for <= 2 digits), labels are\n",
    "            made binary: 1 for digits[0], 0 for anything else.\n",
    "\n",
    "    Returns:\n",
    "        list of (data, labels) tuples; labels are 1-D long tensors.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if remap_label is True and len(digits) > 2.\n",
    "    \"\"\"\n",
    "    if remap_label and len(digits) > 2:\n",
    "        raise NotImplementedError(\"Only handling between 2 types of digits for now\")\n",
    "    if remap_label:\n",
    "        def label_map(labels):\n",
    "            return labels == digits[0]\n",
    "    else:\n",
    "        def label_map(labels):\n",
    "            return labels\n",
    "\n",
    "    sel_data, sel_labels = [], []\n",
    "    for data, labels in dataset:\n",
    "        # All-False starting mask (assumes labels are never -1); using a\n",
    "        # comparison keeps the mask dtype whatever this torch version returns.\n",
    "        mask = labels == -1\n",
    "        for digit in digits:\n",
    "            mask = mask | (labels == digit)\n",
    "        # Boolean indexing along dim 0 keeps each example's trailing shape.\n",
    "        sel_data.append(data[mask])\n",
    "        sel_labels.append(labels[mask])\n",
    "\n",
    "    if not sel_data:\n",
    "        return []\n",
    "\n",
    "    # Concatenate, then resplit into minibatches of size mb_size. Stepping by\n",
    "    # mb_size never produces an empty trailing minibatch (which would yield a\n",
    "    # NaN mean loss in training), and labels stay 1-D even for a size-1 batch.\n",
    "    cat_data = torch.cat(sel_data)\n",
    "    cat_labels = torch.cat(sel_labels)\n",
    "    full_dataset = []\n",
    "    for start in range(0, len(cat_labels), mb_size):\n",
    "        stop = start + mb_size\n",
    "        full_dataset.append((cat_data[start:stop],\n",
    "                             label_map(cat_labels[start:stop]).long()))\n",
    "    return full_dataset\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict both splits to 1's and 7's (a binary 1-vs-7 problem), re-batched\n",
    "MINIBATCH_SIZE = 128\n",
    "os_trainset, os_valset = (select_digits(split, [1, 7], mb_size=MINIBATCH_SIZE)\n",
    "                          for split in (mnist_trainset, mnist_valset))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define helper functions to train and evaluate a network\n",
    "\n",
    "def l1_loss(net):\n",
    "    \"\"\"Sum of the L1 norms of all parameters with more than one dimension.\n",
    "\n",
    "    1-D parameters (e.g. biases) are left out of the penalty.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    for param in net.parameters():\n",
    "        if param.dim() > 1:\n",
    "            total = total + param.norm(p=1)\n",
    "    return total\n",
    "\n",
    "def train(net, trainset, num_epochs, lr=1e-3, l1_scale=2e-3):\n",
    "    \"\"\"Train `net` with Adam on cross-entropy plus an L1 penalty on its weights.\n",
    "\n",
    "    Args:\n",
    "        net: torch module mapping flattened 784-pixel inputs to class logits.\n",
    "        trainset: iterable of (data, labels) minibatches.\n",
    "        num_epochs: number of passes over trainset.\n",
    "        lr: Adam learning rate.\n",
    "        l1_scale: multiplier on l1_loss(net) in the training objective.\n",
    "\n",
    "    After each epoch, prints the training error: the fraction of all examples\n",
    "    seen in that epoch that were misclassified (weighted by minibatch size).\n",
    "    \"\"\"\n",
    "    opt = optim.Adam(net.parameters(), lr=lr)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    for epoch in range(num_epochs):\n",
    "        num_wrong = 0\n",
    "        num_seen = 0\n",
    "        for data, labels in trainset:\n",
    "            output = net(data.view(-1, 784))\n",
    "            loss = criterion(output, labels) + l1_scale * l1_loss(net)\n",
    "\n",
    "            # Count individual mistakes so a smaller final minibatch is not\n",
    "            # over-weighted when averaging.\n",
    "            num_wrong += (output.detach().max(1)[1] != labels).long().sum().item()\n",
    "            num_seen += labels.numel()\n",
    "\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        print(\"(%02d) error: %.4f\" % (epoch, num_wrong / max(num_seen, 1)))\n",
    "\n",
    "\n",
    "def test_acc(net, valset):\n",
    "    \"\"\"Print the classification accuracy of `net` over every example in `valset`.\n",
    "\n",
    "    Args:\n",
    "        net: torch module mapping flattened 784-pixel inputs to class logits.\n",
    "        valset: iterable of (data, labels) minibatches.\n",
    "    \"\"\"\n",
    "    num_wrong = 0\n",
    "    num_seen = 0\n",
    "    # Evaluation only: no autograd graph is needed, which saves memory.\n",
    "    with torch.no_grad():\n",
    "        for data, labels in valset:\n",
    "            output = net(data.view(-1, 784))\n",
    "            num_wrong += (output.max(1)[1] != labels).long().sum().item()\n",
    "            num_seen += labels.numel()\n",
    "    if num_seen == 0:\n",
    "        print(\"Accuracy of: n/a (empty dataset)\")\n",
    "        return\n",
    "    print(\"Accuracy of: %.03f\" % (1 - num_wrong / num_seen))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[784, 10, 50, 10, 2]\n",
      "(00) error: tensor(0.0996)\n",
      "(01) error: tensor(0.0104)\n",
      "(02) error: tensor(0.0087)\n",
      "(03) error: tensor(0.0077)\n",
      "(04) error: tensor(0.0074)\n",
      "(05) error: tensor(0.0070)\n",
      "(06) error: tensor(0.0067)\n",
      "(07) error: tensor(0.0064)\n",
      "(08) error: tensor(0.0061)\n",
      "(09) error: tensor(0.0059)\n",
      "Accuracy of: 0.989\n"
     ]
    }
   ],
   "source": [
    "# PLNN configured with [784, 10, 50, 10, 2]: flattened 28x28 input, 2 output classes\n",
    "MNIST_DIM = 784\n",
    "network = PLNN([MNIST_DIM, 10, 50, 10, 2])\n",
    "net = network.net  # the module that train() and test_acc() operate on\n",
    "\n",
    "# Fit for 10 epochs, then report accuracy on the validation split\n",
    "train(net, os_trainset, num_epochs=10)\n",
    "test_acc(net, os_valset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 28, 28])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f9966cb7e10>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADHRJREFUeJzt3V2oHPUZx/Hfr1YvjFV8N2iiJkioeBHLIRYtxVJSfIOo0GAuQorSI6JYIUjVmwakQYqprV4IKQaj+BJFrUHEVqRgq7UkvqDR1Bg1SaPxxJBiVYhBfXpxJuU0np3d7M7LnjzfD4TdnWdm52HJ7/xnd2b374gQgHy+1XYDANpB+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJPXtJndmm8sJgZpFhHtZb6CR3/YFtt+2vdn2TYM8F4Bmud9r+20fImmTpPmStktaJ2lRRLxVsg0jP1CzJkb+eZI2R8R7EbFX0sOSFgzwfAAaNEj4T5b0rwmPtxfL/o/tUdvrba8fYF8AKjbIB36THVp847A+IlZKWilx2A8Mk0FG/u2SZkx4fIqkDwdrB0BTBgn/Okln2D7d9mGSrpC0tpq2ANSt78P+iPjS9nWS/iTpEEmrIuLNyjoDUKu+T/X1tTPe8wO1a+QiHwBTF+EHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJ9T1FtyTZ3iLpU0lfSfoyIkaqaApA/QYKf+FHEbGrgucB0CAO+4GkBg1/SPqz7Zdtj1bREIBmDHrYf15EfGj7BEnP2v5nRDw/cYXijwJ/GIAh44io5onsZZI+i4jbS9apZmcAOooI97Je34f9tqfZ/s6++5J+ImlDv88HoFmDHPafKOkJ2/ue58GIeKaSrgDUrrLD/p52xmE/ULvaD/sBTG2EH0iK8ANJEX4gKcIPJEX4gaSq+FYfElu+fHlp/eabb+5YW7hwYem2jz76aF89oTeM/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOf5Ueqaa64prd94442l9Sa/Mo4Dw8gPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lxnr8CZ555Zml969atpfXPP/98oP3PnTu3Y23evHml246MlM+qvmTJktJ6t/P4119/fcfaU089Vbot6sXIDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJdZ2i2/YqSZdI2hkRZxXLjpG0RtJpkrZIWhgR/+66syk8Rffs2bM71l599dXSbXft2lVa37NnT1897TNnzpyONbt8tuYPPvigtD5t2rTS+kUXXVRaf+mll0rrqF6VU3TfK+mC/ZbdJOm5iDhD0nPFYwBTSNfwR8Tzknbvt3iBpNXF/dWSLq24LwA16/c9/4kRsUOSitsTqmsJQBNqv7bf9qik0br3A+DA9Dvyj9meLknF7c5OK0bEyogYiYjyb5AAaFS/4V8rad/XvZZIerKadgA0pWv4bT8k6e+S5tjebvsqSbdJmm/7HUnzi8cAppCu5/kr3dkUPs9f5txzzy2tL126tLTe7fv+M2fOLK3PmjWrY23FihWl277wwgul9VtvvbW0vnjx4tI6mlfleX4AByHCDyRF+IGkCD+QFOEHkiL8QFL8dHcFXnzxxYHqbSo7TShJ06dPb6gTNI2RH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeS4jx/cocffnhpvdt1AJi6GPmBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gqa7f57e9StIlknZGxFnFsmWSfi7p42K1WyLi6bqaRH22bdtWWu82fTimrl5G/nslXTDJ8jsiYm7xj+ADU0zX8EfE85J2
N9ALgAYN8p7/Otuv215l++jKOgLQiH7Df7ek2ZLmStohaUWnFW2P2l5ve32f+wJQg77CHxFjEfFVRHwt6Q+S5pWsuzIiRiJipN8mAVSvr/Dbnjh162WSNlTTDoCm9HKq7yFJ50s6zvZ2Sb+SdL7tuZJC0hZJV9fYI4AadA1/RCyaZPE9NfSCFsycObO0fuqpp9a276OOOqq0PmPGjIGef+/evR1rmzZtGui5DwZc4QckRfiBpAg/kBThB5Ii/EBShB9Iiim6k7v44otL691Oxy1btqy0fvnll3esHXnkkaXbHn/88aX1sbGx0vrHH3/csXbhhReWbrt798H/XTZGfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IyhHR3M7s5nYGSdL8+fNL63fddVdpffbs2aX1Tz75pLT+/vvvd6ytWbOmdNtnnnmmtL5hA78hM5mIcC/rMfIDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKc5z/IffTRR6X1Y489trR+//33l9avvPLKA+4J9eI8P4BShB9IivADSRF+ICnCDyRF+IGkCD+QVNff7bc9Q9J9kk6S9LWklRHxe9vHSFoj6TRJWyQtjIh/19cq+nHSSSeV1pcvX15ab/I6EDSrl5H/S0lLI+K7kr4v6VrbZ0q6SdJzEXGGpOeKxwCmiK7hj4gdEfFKcf9TSRslnSxpgaTVxWqrJV1aV5MAqndA7/ltnybpbEn/kHRiROyQxv9ASDqh6uYA1KfnufpsHyHpMUk3RMR/7J4uH5btUUmj/bUHoC49jfy2D9V48B+IiMeLxWO2pxf16ZJ2TrZtRKyMiJGIGKmiYQDV6Bp+jw/x90jaGBG/nVBaK2lJcX+JpCerbw9AXXo57D9P0mJJb9h+rVh2i6TbJD1i+ypJ2yT9tJ4WUadzzjmntH7nnXc21Ama1jX8EfE3SZ3e4P+42nYANIUr/ICkCD+QFOEHkiL8QFKEH0iK8ANJ9Xx5Lw5Os2bNKq2/++67DXWCpjHyA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBSnOdP7osvviit79mzp6FO0DRGfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivP8yW3atKm0vnnz5oY6QdMY+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gqa7htz3D9l9sb7T9pu1fFMuX2f7A9mvFv4vqbxdAVXq5yOdLSUsj4hXb35H0su1ni9odEXF7fe0BqEvX8EfEDkk7ivuf2t4o6eS6GwNQrwN6z2/7NElnS/pHseg626/bXmX76A7bjNpeb3v9QJ0CqFTP4bd9hKTHJN0QEf+RdLek2ZLmavzIYMVk20XEyogYiYiRCvoFUJGewm/7UI0H/4GIeFySImIsIr6KiK8l/UHSvPraBFC1Xj7tt6R7JG2MiN9OWD59wmqXSdpQfXsA6tLLp/3nSVos6Q3brxXLbpG0yPZcSSFpi6Sra+kQtdq6dWvbLaAlvXza/zdJnqT0dPXtAGgKV/gBSRF+ICnCDyRF+IGkCD+QFOEHknJENLczu7mdAUlFxGSn5r+BkR9IivADSRF+ICnCDyRF+IGkCD+QFOEHkmp6iu5dkiZ+gfy4YtkwGtbehrUvid76VWVvp/a6YqMX+Xxj5/b6Yf1tv2HtbVj7kuitX231xmE/kBThB5JqO/wrW95/mWHtbVj7kuitX6301up7fgDtaXvkB9CSVsJv+wLbb9vebPumNnroxPYW228UMw+3OsVYMQ3aTtsbJiw7xvaztt8pbiedJq2l3oZi5uaSmaVbfe2Gbcbrxg/7bR8iaZOk+ZK2S1onaVFEvNVoIx3Y3iJpJCJaPyds+4eSPpN0X0ScVSz7jaTdEXFb8Yfz6Ij45ZD0tkzSZ23P3FxMKDN94szSki6V9DO1+NqV9LVQLbxubYz88yRtjoj3ImKvpIclLWihj6EXEc9L2r3f4gWSVhf3V2v8P0/jOvQ2FCJiR0S8Utz/VNK+maVbfe1K+mpF
G+E/WdK/JjzeruGa8jsk/dn2y7ZH225mEicW06bvmz79hJb72V/XmZubtN/M0kPz2vUz43XV2gj/ZD8xNEynHM6LiO9JulDStcXhLXrT08zNTZlkZumh0O+M11VrI/zbJc2Y8PgUSR+20MekIuLD4nanpCc0fLMPj+2bJLW43dlyP/8zTDM3TzaztIbgtRumGa/bCP86SWfYPt32YZKukLS2hT6+wfa04oMY2Z4m6ScavtmH10paUtxfIunJFnv5P8Myc3OnmaXV8ms3bDNet3KRT3Eq43eSDpG0KiJ+3XgTk7A9S+OjvTT+jccH2+zN9kOSztf4t77GJP1K0h8lPSJppqRtkn4aEY1/8Naht/M1fuj6v5mb973Hbri3H0j6q6Q3JH1dLL5F4++vW3vtSvpapBZeN67wA5LiCj8gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0n9F0wKmhwKc1IbAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Pick one example from the first validation minibatch and show it\n",
    "EXAMPLE_NUMBER = 7\n",
    "data, labels = next(iter(os_valset))\n",
    "example = data[EXAMPLE_NUMBER]\n",
    "print(example.shape)  # expected: torch.Size([1, 28, 28])\n",
    "plt.gray()  # grayscale as the default colormap for subsequent images\n",
    "plt.imshow(example.squeeze())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting upper bound computation\n",
      "Upper bound of 0.032573264092206955 in 3.91 seconds\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "__init__() got an unexpected keyword argument 'box_bounds'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-7-ae935b2c07c8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mstart\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m lp_dist, cw_bound, cw_example, adver_example, boundary = geocert.min_dist(example.view(1, -1), lp_norm='l_inf', \n\u001b[0;32m---> 10\u001b[0;31m                                                                           compute_upper_bound=True)\n\u001b[0m\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/geometric-certificates/geocert_oop.py\u001b[0m in \u001b[0;36mmin_dist\u001b[0;34m(self, x, lp_norm, compute_upper_bound)\u001b[0m\n\u001b[1;32m    474\u001b[0m             adv_bound, adv_ex = self._compute_upper_bounds(x, self.true_label,\n\u001b[1;32m    475\u001b[0m                                                            \u001b[0mlp_norm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 476\u001b[0;31m                                        extra_attack_kwargs=compute_upper_bound)\n\u001b[0m\u001b[1;32m    477\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    478\u001b[0m         \u001b[0;31m######################################################################\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/geometric-certificates/geocert_oop.py\u001b[0m in \u001b[0;36m_compute_upper_bounds\u001b[0;34m(self, x, true_label, lp_dist, extra_attack_kwargs)\u001b[0m\n\u001b[1;32m    250\u001b[0m             self._verbose_print(\"Upper bound of %s in %.02f seconds\" %\n\u001b[1;32m    251\u001b[0m                                 (upper_bound, time.time() - start))\n\u001b[0;32m--> 252\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_dead_constraints\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    253\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    254\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mupper_bound\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madv_ex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/geometric-certificates/geocert_oop.py\u001b[0m in \u001b[0;36m_update_dead_constraints\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    338\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    339\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_update_dead_constraints\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 340\u001b[0;31m         \u001b[0mbounds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_dual_ia_bounds\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdomain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    341\u001b[0m         \u001b[0mtensor_dead_constraints\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbounds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    342\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdead_constraints\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat_config\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_dead_constraints\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/geometric-certificates/plnn.py\u001b[0m in \u001b[0;36mcompute_dual_ia_bounds\u001b[0;34m(self, domain_obj)\u001b[0m\n\u001b[1;32m    270\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    271\u001b[0m         \u001b[0mia\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_interval_bounds\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdomain_obj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 272\u001b[0;31m         \u001b[0mdd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_dual_lp_bounds\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdomain_obj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    274\u001b[0m         \u001b[0mbounds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/geometric-certificates/plnn.py\u001b[0m in \u001b[0;36mcompute_dual_lp_bounds\u001b[0;34m(self, domain_obj)\u001b[0m\n\u001b[1;32m    256\u001b[0m         \u001b[0mbox_bounds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlow_bounds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhigh_bounds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    257\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 258\u001b[0;31m         \u001b[0mdual_net\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDualNetwork\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmidpoint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdomain_obj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinf_radius\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbox_bounds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbox_bounds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdual_net\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    259\u001b[0m         \u001b[0mbounds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdead_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    260\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdual_net\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'box_bounds'"
     ]
    }
   ],
   "source": [
    "from importlib import reload\n",
    "reload(geo)\n",
    "# Build the object that holds the GeoCert algorithm parameters\n",
    "geocert = geo.IncrementalGeoCert(network, verbose=True, config_fxn='parallel',\n",
    "                                 config_fxn_kwargs={'num_jobs': 1})\n",
    "\n",
    "# Run the algorithm; time only the certification, not the plotting below\n",
    "start = time.time()\n",
    "lp_dist, cw_bound, cw_example, adver_examples, boundary = geocert.min_dist(example.view(1, -1), lp_norm='l_inf',\n",
    "                                                                           compute_upper_bound=True)\n",
    "end = time.time()\n",
    "\n",
    "# Keep the first adversarial example returned. A separate name is used so the\n",
    "# original input `example` is not overwritten before it is evaluated/plotted.\n",
    "adver_example = adver_examples[0]\n",
    "\n",
    "# Print outputs\n",
    "print(\"Found an adversarial example at dist\", lp_dist, \" in %.02f seconds \" % (end - start))\n",
    "if cw_bound is not None:\n",
    "    print(\"Carlini-Wagner attack found example at distance\", cw_bound)\n",
    "original_logits = network(example)\n",
    "adver_logits = network(torch.Tensor(adver_example).view(1, 28, 28))\n",
    "print(\"Original output was \", original_logits.data.cpu().numpy())\n",
    "print(\"Adversarial output was \", adver_logits.data.cpu().numpy())\n",
    "\n",
    "# Display the original next to the adversarial examples\n",
    "# NOTE(review): min_dist runs with lp_norm='l_inf'; confirm the upper-bound\n",
    "# attack really is Carlini-Wagner L2 before trusting the title below.\n",
    "to_displays = [(example.cpu().numpy().reshape((28, 28)), 'Original'),\n",
    "               (adver_example.reshape((28, 28)), 'GeoCert')]\n",
    "if cw_example is not None:\n",
    "    to_displays.append((cw_example.reshape((28, 28)), 'Carlini-Wagner L2'))\n",
    "f, axarr = plt.subplots(1, len(to_displays), figsize=(12, 12))\n",
    "for ax, (image, title) in zip(axarr, to_displays):\n",
    "    ax.imshow(image)\n",
    "    ax.set_title(title)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
