{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jordanm/.virtualenvs/myvenv/lib/python3.7/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "In this example, we build a binary MNIST classifier and then run GeoCert on several test points.\n",
    "'''\n",
    "\n",
    "# =====================\n",
    "# Imports\n",
    "# =====================\n",
    "%load_ext line_profiler\n",
    "import sys\n",
    "import pickle \n",
    "sys.path.append('..')\n",
    "sys.path.append('../mister_ed') # library for adversarial examples\n",
    "\n",
    "import geocert_oop as geo\n",
    "from plnn import PLNN\n",
    "import _polytope_ as _poly_\n",
    "from _polytope_ import Polytope, Face\n",
    "import utilities as utils\n",
    "import os\n",
    "import time \n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "import adversarial_perturbations as ap \n",
    "import prebuilt_loss_functions as plf\n",
    "import loss_functions as lf \n",
    "import adversarial_attacks as aa\n",
    "import utils.pytorch_utils as me_utils\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import mnist.mnist_loader as  ml \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LEN OUTPUTS IS  128\n"
     ]
    }
   ],
   "source": [
    "outputs = []\n",
    "with open('../formatted_mnist_out.pkl', 'rb') as f:\n",
    "    network, output_array = pickle.load(f)\n",
    "        \n",
    "print(\"LEN OUTPUTS IS \", len(output_array))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_output_el(output_dict, stack=np.hstack, show=False):\n",
    "    tup = (output_dict['original_img'].reshape((28, 28)), \n",
    "           output_dict['min_img'].reshape((28, 28)))\n",
    "    \n",
    "    if output_dict['cw_img'] is not None:\n",
    "        tup = tup + (output_dict['cw_img'].reshape((28, 28)),)\n",
    "    #print(len(tup))\n",
    "    stacked = stack(tup)\n",
    "    if show:\n",
    "        plt.imshow(stacked)\n",
    "    return stacked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "cws = [_ for _ in output_array if _['cw_img'] is not None]\n",
    "#print(len(cws))\n",
    "\n",
    "\n",
    "#plt.imshow(cws[40]['cw_img'].reshape((28, 28)))\n",
    "\n",
    "stacked_cws = [display_output_el(cw, stack=np.hstack, show=False) for cw in cws]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PLNN(\n",
       "  (net): Sequential(\n",
       "    (1): Linear(in_features=784, out_features=10, bias=True)\n",
       "    (2): ReLU()\n",
       "    (3): Linear(in_features=10, out_features=50, bias=True)\n",
       "    (4): ReLU()\n",
       "    (5): Linear(in_features=50, out_features=10, bias=True)\n",
       "    (6): ReLU()\n",
       "    (7): Linear(in_features=10, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 2880x1440 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(40, 20))\n",
    "\n",
    "network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['original_img', 'label', 'min_dist', 'min_img', 'cw_dist', 'cw_img', 'runtime'])\n",
      "4.500\n",
      "0.000\n",
      "0.872\n",
      "--------------------------------------------------\n",
      "-3.776\n",
      "-0.000\n",
      "-1.448\n",
      "--------------------------------------------------\n",
      "5.689\n",
      "0.000\n",
      "1.187\n",
      "--------------------------------------------------\n",
      "5.776\n",
      "-0.000\n",
      "0.739\n",
      "--------------------------------------------------\n",
      "4.257\n",
      "0.000\n",
      "1.071\n",
      "--------------------------------------------------\n",
      "4.588\n",
      "0.000\n",
      "0.829\n",
      "--------------------------------------------------\n",
      "4.742\n",
      "0.000\n",
      "0.717\n",
      "--------------------------------------------------\n",
      "-4.811\n",
      "0.000\n",
      "-2.176\n",
      "--------------------------------------------------\n",
      "5.156\n",
      "-0.000\n",
      "1.004\n",
      "--------------------------------------------------\n",
      "4.643\n",
      "0.000\n",
      "0.982\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "single = cws[0]\n",
    "diff = lambda el: el.squeeze()[1] - el.squeeze()[0]\n",
    "print(single.keys())\n",
    "for single in cws[:10]:\n",
    "    og, cw, adv = single['original_img'], single['cw_img'], single['min_img']\n",
    "    og_logits = network(torch.Tensor(og)).data.numpy()\n",
    "    adv_logits = network(torch.Tensor(adv)).data.numpy()\n",
    "    adv_clip_logits = network(torch.Tensor(adv).clamp(0, 1)).data.numpy()\n",
    "    print('%.03f' % diff(og_logits))\n",
    "    print('%.03f' % diff(adv_logits))\n",
    "    print('%.03f' % diff(adv_clip_logits))\n",
    "    \n",
    "    print('-' * 50)\n",
    "    continue\n",
    "    #print(min(adv), max(adv))\n",
    "    print(diff(network(torch.Tensor(og)).data.numpy()), ' | ', \n",
    "          diff(network(torch.Tensor(cw)).data.numpy()), ' | ',\n",
    "          diff(network(torch.Tensor(adv)).data.numpy()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[320] [2]\n"
     ]
    }
   ],
   "source": [
    "print(sum(adv < 0), sum(adv > 1))\n",
    "#sum(adv > 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for el in di"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3",
   "language": "python",
   "name": "py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
