{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "##################################################################################\n",
    "#                                                                                #\n",
    "#                                Comparison Notebook                             #\n",
    "#                                                                                #\n",
    "##################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/justin/.conda/envs/DeepL/lib/python3.7/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    }
   ],
   "source": [
    "# =====================\n",
    "# Imports\n",
    "# =====================\n",
    "# %load_ext line_profiler\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "sys.path.append('mister_ed') # library for adversarial examples\n",
    "sys.path.append('CertifiedReLURobustness')\n",
    "from collections import defaultdict\n",
    "import geocert_oop as geo\n",
    "from domains import Domain\n",
    "from plnn import PLNN\n",
    "import _polytope_ as _poly_\n",
    "from _polytope_ import Polytope, Face\n",
    "import utilities as utils\n",
    "import os\n",
    "import time \n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets, transforms\n",
    "from cvxopt import solvers, matrix\n",
    "import adversarial_perturbations as ap \n",
    "import prebuilt_loss_functions as plf\n",
    "import loss_functions as lf \n",
    "import adversarial_attacks as aa\n",
    "import utils.pytorch_utils as me_utils\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "import mnist.mnist_loader as  ml \n",
    "MNIST_DIM = 784"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "##################################################################################\n",
    "#                                                                                #\n",
    "#                       Network Model + Data Loading                             #\n",
    "#                                                                                #\n",
    "##################################################################################\n",
    "\n",
    "# Define functions to train and evaluate a network \n",
    "\n",
    "def l1_loss(net):\n",
    "    return sum([_.norm(p=1) for _ in net.parameters() if _.dim() > 1])\n",
    "\n",
    "def l2_loss(net):\n",
    "    return sum([_.norm(p=2) for _ in net.parameters() if _.dim() > 1])\n",
    "  \n",
    "    \n",
    "def train(net, trainset, num_epochs):\n",
    "    opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=0)\n",
    "    for epoch in range(num_epochs):\n",
    "        err_acc = 0\n",
    "        err_count = 0\n",
    "        for data, labels in trainset:\n",
    "            output = net(Variable(data.view(-1, 784)))\n",
    "            l = nn.CrossEntropyLoss()(output, Variable(labels)).view([1])\n",
    "            l1_scale = torch.Tensor([2e-3])\n",
    "            l += l1_scale * l1_loss(net).view([1])\n",
    "            \n",
    "            err_acc += (output.max(1)[1].data != labels).float().mean() \n",
    "            err_count += 1\n",
    "            opt.zero_grad() \n",
    "            (l).backward() \n",
    "            opt.step() \n",
    "        print(\"(%02d) error:\" % epoch, err_acc / err_count)\n",
    "            \n",
    "        \n",
    "def test_acc(net, valset):\n",
    "    err_acc = 0 \n",
    "    err_count = 0 \n",
    "    for data, labels in valset:\n",
    "        n = data.shape[0]\n",
    "        output = net(Variable(data.view(-1, 784)))\n",
    "        err_acc += (output.max(1)[1].data != labels).float().mean() * n\n",
    "        err_count += n\n",
    "        \n",
    "    print(\"Accuracy of: %.03f\" % (1 - (err_acc / err_count).item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training a new network\n",
      "[784, 10, 50, 10, 10]\n",
      "(00) error: tensor(0.5287)\n",
      "(01) error: tensor(0.3268)\n",
      "(02) error: tensor(0.2946)\n",
      "(03) error: tensor(0.2698)\n",
      "(04) error: tensor(0.2400)\n",
      "(05) error: tensor(0.2074)\n",
      "(06) error: tensor(0.1842)\n",
      "(07) error: tensor(0.1687)\n",
      "(08) error: tensor(0.1574)\n",
      "(09) error: tensor(0.1508)\n",
      "Accuracy of: 0.842\n"
     ]
    }
   ],
   "source": [
    "NETWORK_NAME = '17_mnist_small.pkl'\n",
    "ONE_SEVEN_ONLY = False \n",
    "layer_sizes = [MNIST_DIM, 10, 50, 10, 10]\n",
    "num_epochs\n",
    "\n",
    "if ONE_SEVEN_ONLY:\n",
    "    trainset = ml.load_single_digits('train', [1, 7], batch_size=16, \n",
    "                                      shuffle=False)  \n",
    "    valset = ml.load_single_digits('val', [1, 7], batch_size=16, \n",
    "                                      shuffle=False)        \n",
    "else:    \n",
    "    trainset = ml.load_mnist_data('train', batch_size=128, shuffle=False)\n",
    "    valset = ml.load_mnist_data('val', batch_size=128, shuffle=False)\n",
    "\n",
    "\n",
    "try: \n",
    "    network = pickle.load(open(NETWORK_NAME, 'rb'))\n",
    "    net = network.net\n",
    "    print(\"Loaded pretrained network\")\n",
    "except:\n",
    "    print(\"Training a new network\")\n",
    "    \n",
    "    network = PLNN(layer_sizes)\n",
    "    net = network.net\n",
    "    train(net, trainset, 30)\n",
    "    pickle.dump(network, open(NETWORK_NAME, 'wb'))\n",
    "    \n",
    "test_acc(net, valset)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/justin/Projects/geometric-certificates\n"
     ]
    }
   ],
   "source": [
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of images: 80\n"
     ]
    }
   ],
   "source": [
    "# =====================\n",
    "# Set Images to Verify\n",
    "# =====================\n",
    "num_batches = len(valset)\n",
    "num_batches = 1\n",
    "images = torch.cat([batch_tuple[0] for batch_tuple in valset[0:num_batches]])\n",
    "labels = torch.cat([batch_tuple[1] for batch_tuple in valset[0:num_batches]])\n",
    "\n",
    "print('number of images:', len(images))\n",
    "from CertifiedReLURobustness.Lip_Lin import run_Lip_Lin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADQNJREFUeJzt3W+MVfWdx/HPZylNjPQBWLHEgnQb3bgaAzoaE3AzamxYbYKN1NQHGzbZMH2AZps0ZA1PypMmjemfrU9IpikpJtSWhFbRGBeDGylRGwejBYpQICzMgkAzJgUT0yDfPphDO8W5v3u5/84dv+9XQube8z1/vrnhM+ecOefcnyNCAPL5h7obAFAPwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKnP9HNjtrmdEOixiHAr83W057e9wvZB24dtP9nJugD0l9u9t9/2LEmHJD0gaVzSW5Iei4jfF5Zhzw/0WD/2/HdJOhwRRyPiz5J+IWllB+sD0EedhP96SSemvB+vpv0d2yO2x2yPdbAtAF3WyR/8pju0+MRhfUSMShqVOOwHBkkne/5xSQunvP+ipJOdtQOgXzoJ/1uSbrT9JduflfQNSdu70xaAXmv7sD8iLth+XNL/SJolaVNE7O9aZwB6qu1LfW1tjHN+oOf6cpMPgJmL8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaTaHqJbkmwfk3RO0seSLkTEUDeaAtB7HYW/cm9E/LEL6wHQRxz2A0l1Gv6QtMP2Htsj3WgIQH90eti/LCJO2p4v6RXb70XErqkzVL8U+MUADBhHRHdWZG+QdD4ivl+YpzsbA9BQRLiV+do+7Ld9te3PXXot6SuS9rW7PgD91clh/3WSfm370np+HhEvd6UrAD3XtcP+ljbGYT/Qcz0/7AcwsxF+ICnCDyRF+IGkCD+QFOEHkurGU30prFq1qmFtzZo1xWVPnjxZrH/00UfF+pYtW4r1999/v2Ht8OHDxWWRF3t+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iKR3pbdPTo0Ya1xYsX96+RaZw7d65hbf/+/X3sZLCMj483rD311FPFZcfGxrrdTt/wSC+AIsIPJEX4gaQIP5AU4QeSIvxAUoQfSIrn+VtUemb/tttuKy574MCBYv3mm28u1m+//fZifXh4uGHt7rvvLi574sSJYn3hwoXFeicuXLhQrJ89e7ZYX7BgQdvbPn78eLE+k6/zt4o9P5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1fR5ftubJH1V0pmIuLWaNk/SLyUtlnRM0qMR8UHTjc3g5/kH2dy5cxvWlixZUlx2z549xfqdd97ZVk+taDZewaFDh4r1ZvdPzJs3r2Ft7dq1xWU3btxYrA+ybj7P/zNJKy6b9qSknRFxo6Sd1XsAM0jT8EfELkkTl01eKWlz9XqzpIe73BeAHmv3nP+6iDglSdXP+d1rCUA/9PzeftsjkkZ6vR0AV6bdPf9p2wskqfp5ptGMETEaEUMRMdTmtgD0QLvh3y5pdfV6taTnu9MOgH5pGn7bz0p6Q9I/2R63/R+SvifpAdt/kPRA9R7ADML39mNgPfLII8X61q1bi/V9+/Y1rN17773FZScmLr/ANXPwvf0Aigg/kBThB5Ii/EBShB9IivADSXGpD7WZP7/8SMjevXs7Wn7VqlUNa9u2bSsuO5NxqQ9AEeEHkiL8QFKEH0iK8ANJEX4gKcIPJMUQ3ahNs6/Pvvbaa4v1Dz4of1v8wYMHr7inTNjzA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBSPM+Pnlq2bFnD2quvvlpcdvbs2cX68PBwsb5r165i/dOK5/kBFBF+ICnCDyRF+IGkCD+QFOEHkiL8QFJNn+e3vUnSVyWdiYhbq2kbJK2RdLaabX1EvNSrJjFzPfjggw1rza7j79y5s1h/44032uoJk1rZ8/9M0opppv8oIpZU/wg+MMM0DX9E7JI00YdeAPRRJ+f8j9v+ne1Ntud2rSMAfdFu+DdK+rKkJZJOSfpBoxltj9gesz3W5rYA9EBb4Y+I0xHxcURclPQTSXcV5h2NiKGIGGq3SQDd11b4bS+Y8vZrkvZ1px0A/dLKpb5nJQ1L+rztcUnfkTRse4mkkHRM0jd72COAHuB5fnTkqquuKtZ3797dsHbLLbcUl73vvvuK9ddff71Yz4rn+QEUEX4gKcIPJEX4gaQIP5AU4QeSYohudGTdunXF+tKlSxvWXn755eKyXMrrLfb8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AUj/Si6KGHHirWn3vuuWL9ww8/bFhbsWK6L4X+mzfffLNYx/R4pBdAEeEHkiL8QFKEH0iK8ANJEX4gKcIPJMXz/Mldc801xfrTTz9drM+aNatYf+mlxgM4cx2/Xuz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpps/z214o6RlJX5B0UdJoRPzY9jxJv5S0WNIxSY9GxAdN1sXz/H3W7Dp8s2vtd9xxR7F+5MiRYr30zH6zZdGebj7Pf0HStyPiZkl3S1pr+58lPSlpZ0TcKGln9R7ADNE0/BFxKiLerl6fk3RA0vWSVkraXM22WdLDvWoSQPdd0Tm/7cWSlkr6raTrIuKUNPkLQtL8bjcHoHdavrff9hxJ2yR9KyL+ZLd0WiHbI5JG2msPQK+0tOe3PVuTwd8SEb+qJp+2vaCqL5B0ZrplI2I0IoYiYqgbDQPojqbh9+Qu/qeSDkTED6eUtktaXb1eLen57rcHoFdaudS3XNJvJO3V5KU+SVqvyfP+rZIWSTou6esRMdFkXVzq67ObbrqpWH/vvfc6Wv/KlSuL9RdeeKGj9ePKtXqpr+k5f0TsltRoZfdfSVMABgd3+AFJEX4gKcIPJEX4gaQIP5AU4QeS4qu7PwVuuOGGhrUdO3Z0tO5169YV6y+++GJH60d92PMDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFJc5/8UGBlp/C1pixYt6mjdr732WrHe7PsgMLjY8wNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUlznnwGWL19erD/xxBN96gSfJuz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpptf5bS+U9IykL0i6KGk0In5se4OkNZLOVrOuj4iXetVoZvfcc0+xPmfOnLbXfeTIkWL9/Pnzba8bg62Vm3wuSPp2RLxt+3OS9th+par9KCK+37v2APRK0/BHxClJp6rX52wfkHR9rxsD0FtXdM5ve7GkpZJ+W0163PbvbG+yPbfBMiO2x2yPddQpgK5qOfy250jaJulbEfEnSRslfVnSEk0eGfxguuUiYjQihiJiqAv9AuiSlsJve7Ymg78lIn4lSRFxOiI+joiLkn4i6a7etQmg25qG37Yl/VTSgYj44ZTpC6bM9jVJ+7rfHoBeaeWv/csk/Zukvbbfqaatl/SY7SWSQtIxSd/sSYfoyLvvvlus33///cX6xMREN9vBAGnlr/27JXmaEtf0gRmMO/yApAg/kBThB5Ii/EBShB9IivADSbmfQyzbZjxnoMciYrpL85/Anh9IivADSRF+ICnCDyRF+IGkCD+QFOEHkur3EN1/lPR/U95/vpo2iAa1t0HtS6K3dnWztxtanbGvN/l8YuP22KB+t9+g9jaofUn01q66euOwH0iK8ANJ1R3+0Zq3XzKovQ1qXxK9tauW3mo95wdQn7r3/ABqUkv4ba+wfdD2YdtP1tFDI7aP2d5r+526hxirhkE7Y3vflGnzbL9i+w/Vz2mHSauptw22/7/67N6x/WBNvS20/b+2D9jeb/s/q+m1fnaFvmr53Pp+2G97lqRDkh6QNC7pLUmPRcTv+9pIA7aPSRqKiNqvCdv+F0nnJT0TEbdW056SNBER36t+cc6NiP8akN42SDpf98jN1YAyC6aOLC3pYUn/rho/u0Jfj6qGz62OPf9dkg5HxNGI+LOkX0haWUMfAy8idkm6fNSMlZI2V683a/I/T9816G0gRMSpiHi7en1O0qWRpWv97Ap91aKO8F8v6cSU9+MarCG/Q9IO23tsj9TdzDSuq4ZNvzR8+vya+7lc05Gb++mykaUH5rNrZ8Trbqsj/NN9xdAgXXJYFhG3S/pXSWurw1u0pqWRm/tlmpGlB0K7I153Wx3hH5e0cMr7L0o6WUMf04qIk9XPM5J+rcEbffj0pUFSq59nau7nrwZp5ObpRpbWAHx2gzTidR3hf0vSjba/ZPuzkr4haXsNfXyC7aurP8TI9tWSvqLBG314u6TV1evVkp6vsZe/MygjNzcaWVo1f3aDNuJ1LTf5VJcy/lvSLEmbIuK7fW9iGrb/UZN7e2nyicef19mb7WclDWvyqa/Tkr4j6TlJWyUtknRc0tcjou9/eGvQ27AmD13/OnLzpXPsPve2XNJvJO2VdLGavF6T59e1fXaFvh5TDZ8bd/gBSXGHH5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpP4CIJjqosJxHysAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting upper bound computation\n",
      "Upper bound of 5.8754072189331055 in 0.36 seconds\n",
      "---Initial Polytope---\n",
      "REJECT DICT:  {'dead_constraints': 63}\n",
      "Pushed 4/7 facets\n",
      "---Opening New Polytope---\n",
      "Bounds  0.5017218436074801   |   5.8754072189331055\n",
      "REJECT DICT:  {'dead_constraints': 63, 'seen before': 1}\n",
      "Pushed 2/6 facets\n",
      "---Opening New Polytope---\n",
      "Bounds  0.6112450300546508   |   5.8754072189331055\n",
      "REJECT DICT:  {'dead_constraints': 63, 'seen before': 1}\n",
      "Pushed 3/6 facets\n",
      "---Opening New Polytope---\n",
      "Bounds  0.611304953312115   |   5.8754072189331055\n",
      "REJECT DICT:  {'dead_constraints': 63, 'seen before': 2}\n",
      "Pushed 3/5 facets\n",
      "---Opening New Polytope---\n",
      "Bounds  0.7075559063906501   |   5.8754072189331055\n",
      "REJECT DICT:  {'dead_constraints': 63, 'seen before': 1}\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-6-3290dc902ed5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     26\u001b[0m         \u001b[0mstart\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m         output = cert_obj.min_dist(image.view(1, -1), lp_norm='l_2', compute_upper_bound=\n\u001b[0;32m---> 28\u001b[0;31m                                    {\"optimizer_kwargs\": {\"lr\": 0.005}, \"num_iterations\": 100})\n\u001b[0m\u001b[1;32m     29\u001b[0m         \u001b[0mlp_dist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madv_ex_bound\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madv_ex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbest_example\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mboundary_facet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseen_poly_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m         \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/NN_cert_methods/geocert_oop.py\u001b[0m in \u001b[0;36mmin_dist\u001b[0;34m(self, x, lp_norm, compute_upper_bound)\u001b[0m\n\u001b[1;32m    522\u001b[0m                                                            \u001b[0mdomain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdomain\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    523\u001b[0m                                          dead_constraints=self.dead_constraints)\n\u001b[0;32m--> 524\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_poly\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpop_el\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    525\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    526\u001b[0m                     \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/NN_cert_methods/geocert_oop.py\u001b[0m in \u001b[0;36m_update_step\u001b[0;34m(self, poly, popped_facet)\u001b[0m\n\u001b[1;32m    392\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    393\u001b[0m             \u001b[0;31m# Do optimizations in serial\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m             \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mdist_fxn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparallel_args\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    396\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/NN_cert_methods/geocert_oop.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    392\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    393\u001b[0m             \u001b[0;31m# Do optimizations in serial\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m             \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mdist_fxn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparallel_args\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    396\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/NN_cert_methods/geocert_oop.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(el)\u001b[0m\n\u001b[1;32m    380\u001b[0m         \u001b[0mchained_facets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mitertools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_facets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madv_constraints\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    381\u001b[0m         \u001b[0mparallel_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mchained_facets\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 382\u001b[0;31m         \u001b[0mdist_fxn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mel\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mel\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlp_dist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    383\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    384\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/NN_cert_methods/_polytope_.py\u001b[0m in \u001b[0;36ml2_dist\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m   1099\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1100\u001b[0m         \u001b[0mquad_start\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1101\u001b[0;31m         \u001b[0mquad_program_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolvers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'mosek'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1102\u001b[0m         \u001b[0mquad_end\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# print(\"QP SOLVED IN %.03f seconds\" % (quad_end - quad_start))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/DeepL/lib/python3.7/site-packages/cvxopt/coneprog.py\u001b[0m in \u001b[0;36mqp\u001b[0;34m(P, q, G, h, A, b, solver, kktsolver, initvals, **kwargs)\u001b[0m\n\u001b[1;32m   4351\u001b[0m         \u001b[0mopts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'mosek'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4352\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mopts\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4353\u001b[0;31m             \u001b[0msolsta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmsk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mopts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   4354\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4355\u001b[0m             \u001b[0msolsta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmsk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/DeepL/lib/python3.7/site-packages/cvxopt/msk.py\u001b[0m in \u001b[0;36mqp\u001b[0;34m(P, q, G, h, A, b, taskfile, **kwargs)\u001b[0m\n\u001b[1;32m    803\u001b[0m                 \u001b[0mtask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwritetask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtaskfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    804\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 805\u001b[0;31m             \u001b[0mtask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    806\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    807\u001b[0m             \u001b[0mtask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msolutionsummary\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmosek\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstreamtype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/DeepL/lib/python3.7/site-packages/mosek/__init__.py\u001b[0m in \u001b[0;36moptimize\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   7622\u001b[0m         \u001b[0m_arg1_trmcode\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mmosek\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrescode\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mIs\u001b[0m \u001b[0meither\u001b[0m \u001b[0mOK\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0ma\u001b[0m \u001b[0mtermination\u001b[0m \u001b[0mresponse\u001b[0m \u001b[0mcode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   7623\u001b[0m       \"\"\"\n\u001b[0;32m-> 7624\u001b[0;31m       \u001b[0mres\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mresargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__obj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizetrm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   7625\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mres\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   7626\u001b[0m         \u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getlasterror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "##################################################################################\n",
    "#                                                                                #\n",
    "#                       Geocert                                                  #\n",
    "#                                                                                #\n",
    "##################################################################################\n",
    "\n",
    "\n",
    "\n",
    "# Run Geocert on a set of mnist digits\n",
    "min_dists = []\n",
    "pgd_dists = []\n",
    "num_polys = []\n",
    "poly_maps = []\n",
    "times = []\n",
    "for image in images:\n",
    "        # Builds an object used to to hold algorithm parameters\n",
    "        cert_obj = geo.IncrementalGeoCert(network, verbose=True, config_fxn='parallel', \n",
    "                                  config_fxn_kwargs={'num_jobs': 1},\n",
    "                                  hyperbox_bounds=[0.0, 1.0])\n",
    "    \n",
    "        plt.gray()\n",
    "        plt.imshow(image.squeeze())\n",
    "        plt.show()\n",
    "        true_label = network(image).squeeze().max(0)[1].item()\n",
    "        # Run Geocert \n",
    "        start = time.time()\n",
    "        output = cert_obj.min_dist(image.view(1, -1), lp_norm='l_2', compute_upper_bound=\n",
    "                                   {\"optimizer_kwargs\": {\"lr\": 0.005}, \"num_iterations\": 100})\n",
    "        lp_dist, adv_ex_bound, adv_ex, best_example, boundary_facet, seen_poly_map = output\n",
    "        end = time.time() \n",
    "        \n",
    "        min_dists.append(lp_dist)\n",
    "        pgd_dists.append(adv_ex_bound)\n",
    "        print('TIME', end-start)\n",
    "        times.append(end-start)\n",
    "        num_polys.append(len(seen_poly_map))\n",
    "        poly_maps.append(seen_poly_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =====================\n",
    "# Save Output\n",
    "# =====================\n",
    "\n",
    "output_dictionary = {'min_dists': min_dists , 'pgd_dists': pgd_dists, 'num_polys': num_polys,\n",
    "                    'poly_maps': poly_maps, 'times': times}\n",
    "\n",
    "# cwd = os.getcwd()\n",
    "# filename = cwd + \"/Data/Geocert_out.pkl\"\n",
    "# f = open(filename, 'wb')\n",
    "# pickle.dump(output_dictionary, f)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =====================\n",
    "# Imports\n",
    "# =====================\n",
    "import numpy as np\n",
    "from CertifiedReLURobustness.save_nlayer_weights import NLayerModel_comparison\n",
    "import tensorflow as tf\n",
    "from CertifiedReLURobustness.Lip_Lin import run_Lip_Lin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # ##################################################################################\n",
    "# # #                                                                                #\n",
    "# # #                       Fast-Lin / Fast-Lip                                      #\n",
    "# # #                                                                                #\n",
    "# # ##################################################################################\n",
    "\n",
    "# tf.reset_default_graph()\n",
    "# gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)\n",
    "# sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))\n",
    "# with tf.Session() as sess:\n",
    "\n",
    "\n",
    "#     # =====================\n",
    "#     # Convert Network\n",
    "#     # =====================\n",
    "#     params = layer_sizes[1:-1]\n",
    "#     nlayers = len(params)\n",
    "#     restore = [None]\n",
    "#     for fc in network.fcs:\n",
    "#         weight = utils.as_numpy(fc.weight).T\n",
    "#         bias = utils.as_numpy(fc.bias)\n",
    "#         restore.append([weight, bias])\n",
    "#         restore.append(None)\n",
    "\n",
    "# #     FL_network = NLayerModel_comparison(params, restore=restore, session=sess)\n",
    "#     FL_network = NLayerModel_comparison(params, session=sess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # =====================\n",
    "# # Convert Images\n",
    "# # =====================\n",
    "# images_lin_lip = utils.as_numpy(images.reshape(-1, 28, 28,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # ==================================\n",
    "# # Run Fast-Lin / Fast-Lip on Images\n",
    "# # ==================================\n",
    "\n",
    "# true_labels = []\n",
    "# targets = []\n",
    "# for label in labels:\n",
    "#     if label == 0:\n",
    "#         true_labels.append([1.0, 0.0])\n",
    "#         targets.append([0.0, 1.0])\n",
    "#     elif label ==1:\n",
    "#         true_labels.append([0.0, 1.0])\n",
    "#         targets.append([1.0, 0.0])\n",
    "# true_ids = [num for num in range(0,np.shape(images)[0]+1)]\n",
    "# inputs = images_lin_lip\n",
    "\n",
    "\n",
    "# min_dists, times  = run_Lip_Lin(FL_network, inputs, targets, true_labels, true_ids, norm='2', numlayer=len(params))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =====================\n",
    "# Load and Display Output\n",
    "# =====================\n",
    "\n",
    "# output_dictionary = {'min_dists': min_dists , 'pgd_dists': pgd_dists, 'num_polys': num_polys,\n",
    "#                     'poly_maps': poly_maps, 'times': times}\n",
    "# cwd = os.getcwd()\n",
    "# filename = cwd + \"/Data/save/Geocert_out.pkl\"\n",
    "# f = open(filename,\"rb\")\n",
    "# output_dictionary = pickle.load(f)\n",
    "# f.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
