{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jordanm/.virtualenvs/myvenv/lib/python3.7/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    }
   ],
   "source": [
    "# =====================\n",
    "# Imports\n",
    "# =====================\n",
    "import sys \n",
    "sys.path.append('..')\n",
    "from geocert import compute_boundary_batch, batch_GeoCert, incremental_GeoCert\n",
    "from plnn import PLNN\n",
    "from _polytope_ import Polytope\n",
    "import utilities as utils\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets, transforms\n",
    "# from convex_adversarial import robust_loss\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===============Generating Training Points============\n",
      "===============Points Generated============\n"
     ]
    }
   ],
   "source": [
    "# ==================================\n",
    "# Generate Training Points\n",
    "# ==================================\n",
    "# Toy 2-D data set: m points in the unit square, pairwise more than 2*r apart in\n",
    "# L1 distance (rejection sampling), each with a random 0/1 label.\n",
    "\n",
    "print('===============Generating Training Points============')\n",
    "m = 12\n",
    "r = 0.1\n",
    "np.random.seed(3)\n",
    "x = [np.random.uniform(size=(2))]\n",
    "while len(x) < m:\n",
    "    candidate = np.random.uniform(size=(2))\n",
    "    # accept only if the candidate is far from every point kept so far\n",
    "    if all(np.abs(candidate - a).sum() > 2 * r for a in x):\n",
    "        x.append(candidate)\n",
    "\n",
    "# NOTE(review): epsilon is not referenced by any other cell in this notebook\n",
    "epsilon = r / 2\n",
    "\n",
    "X = torch.Tensor(np.array(x))\n",
    "torch.manual_seed(1)\n",
    "# uniform [0, 1) shifted by 0.5 and truncated by .long() gives labels in {0, 1}\n",
    "y = (torch.rand(m) + 0.5).long()\n",
    "\n",
    "print('===============Points Generated============')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===============Initializing Network============\n",
      "[2, 8, 10, 10, 2]\n",
      "===============Training Network============\n",
      "error: tensor(0.5833)\n",
      "error: tensor(0.1667)\n",
      "error: tensor(0.)\n",
      "error: tensor(0.)\n",
      "error: tensor(0.)\n",
      "error: tensor(0.)\n",
      "error: tensor(0.)\n",
      "error: tensor(0.)\n",
      "error: tensor(0.)\n",
      "error: tensor(0.)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jordanm/.virtualenvs/myvenv/lib/python3.7/site-packages/torch/tensor.py:263: UserWarning: non-inplace resize is deprecated\n",
      "  warnings.warn(\"non-inplace resize is deprecated\")\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAecAAAHWCAYAAABNK0FcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3XlwHOd55/HfMzMACBDgCfAEeEgiJZK6KFG0vPIhW0osyV6pdh07UsoV29FaqiROUmvHWW9F5bgc1VYl3jhVqdWupd11HLsqdpRUrcNNZMtlW1750EWdFElRgiiRBE9QvA8cM/3sH4BoEMIxBGa63+7+fqpQhZluYR41e/rXz9vv9Ji7CwAAhKOQdAEAAOB8hDMAAIEhnAEACAzhDABAYAhnAAACQzgDABCYScPZzL5hZofM7OVxlpuZ/Y2ZdZvZS2Z2Te3LBAAgP6rpnL8p6ZYJlt8qadXwzz2S/sf0ywIAIL8mDWd3f1zSkQlWuUPSt3zIk5LmmNniWhUIAEDe1OKa81JJe0Y87hl+DgAATEEpzhczs3s0NPStmTNnXnvZZZedt/zgsTirwXR0DO5NugQgUb0N9CAY38I573zu2WefPezuHdX897UI572SukY87hx+7h3c/SFJD0nShg0bfPPmzect/9o/c5/vtLj3wH1JlwAk6sFF9yddAgL2uTvsHc+Z2a5q//taDGtvkvTbw7O2r5d03N331+DvAkCQCGbU26Sds5l9R9KNktrNrEfSn0lqkCR3/7qkRyTdJqlb0hlJn65XsQAA5MGk4ezud02y3CX9fs0qQio8uOh+hrYBoE64QxgAAIEhnAEACAzhjCljUgzyiP0ecSCcAQAIDOEMAEBgCGdMC0N8AFB7hDMAAIEhnAGgSowUIS6EM6aNAxYA1BbhDABAYAhnAAACQzijJhjaBoDaIZwBoAqcgCJOhDMAAIEhnFEzdBYAUBuEMwAAgSGcAQAIDOEMVCkqRxo42afKYCXpUhAzLtkgbqWkC0C2PLjoft174L6ky6ipgZN9euP/btXBZ3ZL7nJJ7Vcu0UW3X6Hm9plJlwcggwhnYAIDJ/q0+S9+pMFT/fKKn3u+9/keHd1+UNf88QfVsrAtwQoBZBHD2sAEXv/eFg2cPD+YJUkulfsGteM7zyZTGIBMI5xRc1m5PlcZqKj3uT1S5GOv4NLJN4+o/9jZeAsDkHmEMzCOwdP9UsEmXMdKRfUdORNTRUhCVk42kS6EMzCOUnPDO4ezR/FKpIa2ppgqApAXhDPqIgvdRmlGg+ZeumDCdZrbZ6qlozWmigDkBeEMTODif3+lik1jf6ih0FDUqt9cH3NFAPKAcAYmMHPRLK3/3I1qXTZHhYaiijNKKjQW1bKwTVf+/ns055KOpEsEkEF8zhl1k5UbkrQunaMNf3KzzvaeUt/RM2psm6GZi2clXRZikIXLM0gnwhmoUnNHq5q5vgwgBgxrAwAQGMIZdcWwIABcOMIZAMbAiSWSRDgDABAYwhl1RwcCABeGcAYAIDB8lAoAEIwoch3Y85Z6Xu9VuVxRx+I5Wr56oRqbGpIuLVaEMwCMwqWYZPSdGdDj//Ki+s4OqDxYkST17j2mbc++qetvXquFnfMSrjA+DGsjFhzsAEzE3fXzH2zR6ZNnzwWzJFUqkSrlSE/+aJtOn8zPd6cTzgCAxB05dFKnT5yVj/MtrVHk6n55b7xFJYhwBgAk7tDeI6qUo3GXe+Tav/tIjBUli3BGbBjaBjCe8Trm81eqexnBIJwBYAROIpOxYMkcFUvjR5KZtLBrbowVJYtwBgAkbv6i2WppbZLZ2MsLxYJWXd4Zb1EJIpwRK7oSAGMxM91wyxWaMbNJpVLx3POFoqlYLOi6D6xR6+zmBCuMF59zBgAEoaV1hj70seu0983D2tN9SJVypI4ls7Xi0sWa0dKYdHmxIpwBAMEoFAvquniBui5e
kHQpiWJYG7FjaBuhYt9EKAhnAAACQzgDABAYwhmJYPgQAMZHOAMAEBjCGQDEaA7CQjgjMRwMAWBshDMAAIEhnAEACAzhjEQxtA0A70Q4A8g9ThIRGsIZAIDAEM4AAASGcEbiGFIEgPMRzgAABIZwBpBrjNwgRIQzgsABEgB+hXAGACAwhDMAAIEhnBEMhrYBYAjhDCC3OCFEqAhnAAACQzgjKHQyAEA4AwAQHMIZQC4xSoOQEc4IDgdNAHlHOAMAEBjCGQCAwBDOCBJD2wDyjHAGkDuc/CF0hDMAAIEhnAEACExV4Wxmt5jZDjPrNrMvjrF8mZk9ZmbPm9lLZnZb7UtF3jD0CCCvJg1nMytKekDSrZLWSrrLzNaOWu0+SQ+7+3pJd0r677UuFACAvKimc94oqdvdd7r7gKTvSrpj1Douadbw77Ml7atdiQBQO4zIIA2qCeelkvaMeNwz/NxIX5b0CTPrkfSIpD+oSXXIPQ6kAPKoVhPC7pL0TXfvlHSbpG+b2Tv+tpndY2abzWxzb29vjV4aAIBsqSac90rqGvG4c/i5ke6W9LAkufsTkmZIah/9h9z9IXff4O4bOjo6plYxAAAZV004PyNplZmtNLNGDU342jRqnd2SbpIkM1ujoXCmNUZNMLQNIG8mDWd3L0v6rKRHJW3X0KzsrWb2FTO7fXi1z0v6jJm9KOk7kj7l7l6vogFgKjjRQ1qUqlnJ3R/R0ESvkc99acTv2yTdUNvSAADIJ+4QhlSg4wGQJ4QzAACBIZwBAAgM4YzUYGgb08H+gzQhnAEACAzhDABAYAhnpApDkwDygHAGACAwhDOAzGPEBWlDOAMAEBjCGalDFwQg6whnAAACQzgDABAYwhmpxNA2qsW+gjQinAEACAzhDABAYAhnpBbDlQCyinAGACAwhDOAzGJ0BWlFOCPVOPgCyCLCGQCAwBDOAAAEhnBG6jG0DSBrCGcAmcRJG9KMcAYAIDCEMzKBLglAlhDOAAAEhnAGkDmMpCDtCGcAAAJDOCMz6JYAZEUp6QIApFdloKwDT+3Svp/vVPn0gJoXtKrrptWat3aRzCzp8oDUIpwBTMnAqX49/1c/Uf/xPkUDFUlS/7GzOvHmEbVfsURrPrlRViCggalgWBuZwtB2fF751jPqO3LmXDC/LRqo6PCWfdr/yzcSqYt9AFlAOAO4YP3HzurYq4fkFR9zeTRQ0e4fvhJzVUB2EM4ALtjJPUdlpYkPH31HzygarEy4DoCxEc7IHIY1669QKlaxlsmKHGKAqeCdA+CCzb64XYrGHtJ+25xV7UwIA6aIcEYm0T3XV7GxqK6bL1WhcewOutBQ1MqPrIu5Kv7dkR2EM4ApWX7rGi25YaWsVJA1DB1KCk0lFRqLWvOpjZp9UXvCFQLpxeecAUyJmemSj16trpsu1aHnezR4ql8tC9rUsX6pio0cWoDp4B2EzHpw0f2698B9SZeReU1zmtX1gVVJlwFkCsPaAAAEhnAGkAlMBkOWEM7INA7YANKIcAYAIDBMCAMAYAIeuY70nlR5sKy2OS1qaZ1R99cknJF5zNoGMFW7Xj2gLU/vVDT8JS9R5Jq/cJauff+lapnZVLfXZVgbQOoxtwD1sHP7Pr3wy24N9JVVHqyoPFhRVIl0eP8xPfa959TfN1i31yackQscvAFciEq5oi1P7VSlHL1jmbs02F9W95aeur0+4QwAwCgHe47KbPwvboki15s7DtTt9QlnAABG6e8blPvE37w2OFiu2+sTzsgNhrYBVGvmrBkTds6S1NzChDAAGBMnXaiHjsVzVCyN/ZWoklQsFXTJ5Uvr9vqEMwAAo5iZNn7wMhVL74zJQrGg2fNmasVli+v2+oQzcoUuC0C1OhbP0fs/crUWds6VmckKpoamki69qkvv/fBVKhbrF6HchAQAgHHMaW/VDbdcoagSqVKJVGooTnotuhbonJE7dM8ALlShWFBDYymWYJYIZwApxokWsopwBgAgMIQz
comOC0DICGcAAAJDOANIJUY/kGWEM3KLgzuAUBHOAAAEhnBGrtE9AwgR4QwAQGAIZwCpw4gHso5wRu5xoAcQGsIZAIDAEM6A6J4BhIVwBgAgMIQzgFRhlAN5QDgDwzjoAwgF4QwAQGAIZwAAAkM4AyMwtA0gBIQzgNTg5Al5QTgDoxAAAJJWVTib2S1mtsPMus3si+Os83Ez22ZmW83s72tbJgAA+VGabAUzK0p6QNKvSeqR9IyZbXL3bSPWWSXpP0u6wd2PmtmCehUMAEDWVdM5b5TU7e473X1A0ncl3TFqnc9IesDdj0qSux+qbZlAvBjaBpCkSTtnSUsl7RnxuEfSu0ats1qSzOwXkoqSvuzuP6hJhQCg+E+YThw9o9e27FHvvuOygrR0RbsuWrdULTObYq0D+VRNOFf7d1ZJulFSp6THzewKdz82ciUzu0fSPZK0bNmyGr00ANTWnu6Deu5nrymKIrkPPde9da92btunG265QvMXzU62QGReNcPaeyV1jXjcOfzcSD2SNrn7oLu/IelVDYX1edz9IXff4O4bOjo6plozEAuGtvPp9Mk+Pfez11Sp/CqYJSmquMrlSL949GWVy5XkCkQuVBPOz0haZWYrzaxR0p2SNo1a53sa6pplZu0aGubeWcM6ASAWr2/dq2hkKo/i7urZ2RtjRcijScPZ3cuSPivpUUnbJT3s7lvN7Ctmdvvwao9KesvMtkl6TNIX3P2tehUNxIXuOX8O7z8uj8YP50o50uH9x2OsCHlU1TVnd39E0iOjnvvSiN9d0ueGfwCgpuI8SSqUJh9QLFWxDjAd7GEAMMKySxaoOEH4FksFLV3JnBnUF+EMTIKh7XxZdslClUrFMZdZwdQ6u1nti5mtjfoinAFghFJDUe//t1erpbXpvJAulgqaM79V77n1SplZghUiD2r1OWcAyIzW2c360G9u1KG9x/TWweOygmlh5zzN62hLujTkBJ0zUAWGtpOT1LY3My3snKu1167QmvXLCWbEinAGACAwhDNQJbrnZNx74L6kSwBiRzgDABAYwhlA8OiekTeEM3ABGNoGEAfCGUAq0D0jTwhnAAACQzgDF4ih7eTQPSMvCGcAAAJDOANTQPecHLpn5AHhDABAYAhnAKlD94ysI5yBKWJoG0C9EM4AUonuGVlGOAMAEBjCGZgGhraTRfeMrCKcAQAIDOEMTBPdc7LonpFFhDMAAIEhnIFponNLHv8GyBrCGQCAwBDOADKB7hlZQjgD00AgAKgHwhmJc3dFUZR0GcgATpaQFaWkC0B+nT5xVtuf26Wenb2KIlfTjAZdvG6JVl3RpWIp/PNGggBAvYR/BEQmHT9yWj/+P89p9+uHFEUuServG9QrL+zR//uXF1QpVxKuEGnFSROygHBGIp55bLvKgxXJz38+qkQ6cfSMul/em0xhVSIAANQT4YzYHX/rlE6f7Bt3eVSJ1L017HBG2Dh5QtoRzojdqRNnZWYTrtN/dvDccHdoOPADqDfCGbFraJp8HmKhWNAk+Q1MiJMopBnhjNi1L5ojK0yQvCZ1XtwxaXedBA74AOJAOCN2hYLpynddpGJx7N2vVCpqzfplMVeFLOJkCmnF55yRiOWrF0mSXnpqp3z42rK7a2ZbszZ+8DLNbGtOsrwxcaAHEBfCGYlZvnqRui5ZqLcOHNfAQFmts5o1e97MpMtCxtx74D6+cxupQzgjUYWCqWPJnKTLmBRdM4A4cc0ZQOZxcoW0IZwBAAgM4QxMgq4rG/h3RJoQzgAABIZwBiZAt5Ut/HsiLQhnALlCQCMNCGdgHBzEASSFcAaQO5x4IXSEMzAGDt4AkkQ4A8glTsAQMsIZGIWDNoCkEc4AcosTMYSKcAZG4GANIASEM4Bc44QMISKcgWEcpAGEgnAGkHucmCE0hDMgDs4AwkI4A4A4QUNYCGfkHgdlAKHJfDgfPXxSW57eqed//pp2vXpA5XIl6ZIQEIIZI7E/IBSlpAuol/Jg
Rb/84cs62ntSlXIkSdrdfVAvPvG6rv+1tVqwZG7CFSJpHIgBhCqznfNTP9muIwdPnAtmSaqUI5UHK3rih1t16vjZBKtD0ghmjId9AyHIZOd86vhZ9e47pijyMZdHlUivbenR+vesirkyJI0DL4A0yGTnfLDniKSxg1mS3KV9uw7HVxASd++B+whmVI19BUnLZDhHkU+UzZIkH6erRvZwoAWQNpkM5/kLZ8kKNuk6yDa6ZUwH+w6SlMlrznM72tTSOkOnjp+Rj9EgF4sFrb5qWfyFIRYcVAGkXSY7ZzPTDR+6XE3NjSqWRvwv2lAwr7tuBZ1zBtEpo9bYn5CUTHbOktTSNkO//rHrtLv7oHa9elCVckVzO9q06opOzZo7M+nyUEMcQAFkTWbDWZJKDUVdtGaJLlqzJOlSUCcEM+rt3gP36cFF9yddBnIm0+GM7CKUAWRZJq85I7u4rowksM8hbnTOSAUOjgDyhM4ZwSOYEQL2Q8SJzhnB4mAIIK/onBEcrisjVOyXiAudM4LBgQ8AhtA5IwgEM9KCfRVxqCqczewWM9thZt1m9sUJ1vuombmZbahdicgyhrAB4J0mDWczK0p6QNKtktZKusvM1o6xXpukP5L0VK2LRPYQykgz9l3UWzXXnDdK6nb3nZJkZt+VdIekbaPW+3NJfyHpCzWtEJnCQQ0AJlfNsPZSSXtGPO4Zfu4cM7tGUpe7/2sNa0OG0Ckja9ifUU/TnhBmZgVJX5P0+SrWvcfMNpvZ5t7e3um+NFKCgxgAXJhqwnmvpK4RjzuHn3tbm6TLJf3UzN6UdL2kTWNNCnP3h9x9g7tv6OjomHrVSAW6ZWQd+zfqpZprzs9IWmVmKzUUyndK+q23F7r7cUntbz82s59K+mN331zbUpEWHLAAYHom7ZzdvSzps5IelbRd0sPuvtXMvmJmt9e7QKQHnTLyiH0e9VDVHcLc/RFJj4x67kvjrHvj9MtCmnBwAoDa4vadmBICGfiVew/cpwcX3Z90GcgQwhkXhFAGgPrj3tqoCteTgYnx/kAtEc6YEKEMVI/3CmqFYW2MiYMMACSHcMY5BDIwfUwOQy0QziCUM2TwVL+OvHJIXok0a8U8tSxsS7okAFNAOOcYoZwdUSXSaw8/r4NP7ZIVC3K5FLlau+Zq3d3Xq2l2c9Il5grdM6aLcM4hQjl7Xvn2Mzr84j5F5UgqR+eeP/HmET33X3+ijfd9SMUm3u5AWvBuzQkCObvO9J7S4Rf3KhqM3rkwcg2eHtCBp3dp6Xsvjr+4HKN7xnQQzhlHKGffoWf3yCMfd3k0UNH+X7xBOAMpwuecM4rPJ+dH5cyAvDJ+OEtSuW8wpmowEu9BTBWdc8ZwMMifmUvnqNBUUtRfHned1iWzY6wIwHTROWcEnXJ+daxfKptgeaGxqM6bVsdWD87H+xJTQeecYrzpIUnFxpLW/s67tPV/PamoXJFGjHAXGotacsNFmnNxe3IFArhgdM4pRJeM0eavW6xrPv8BdVy1VIWGoqxoals2V2s+uVGXfPSqpMvLPd6vuFB0zinCGxwTae2co3X/4d1JlwGgBuicU4BOGUg/3sO4EHTOAePNDAD5RDgHhkAGsou7hqFahHMgCGUAwNu45pwwricD+cL7HdWgc04Ib1AAwHgI55gRygCAyRDOMSCQAYzExDBMhnCuI0IZADAVhHMdEMoAgOkgnGuIUAYA1ALhXAOEMgCglgjnKSKQAUwHk8IwEcL5AhHKAIB6I5yrRCgDAOJCOE+CUAYAxI17a4+De14DqDeOMRgPnfMIvFEAACEgnEUoAwDCkutwJpQBACHKZTgTyvXXd+SM9j7erSPbD8oKpo5rOrXk3SvV0NqUdGkAELxchTOhHI/el/Zq+zeflldcXokkSWcOnNTuH+7Q1X/4PrV1zU24QiAc3IwEY8l8OBPI8eo7ekbb//ZpRYOV856PBivSYEUv/ref6d/c/2EVGooJVQgA4ctsOBPKydj7+Oty
93GXezlS7wt7tfC6ZTFWBQDpkrlwJpSTdfSVg/JyNO7ySn9ZR189RDgDwAQyE86EchisOPl9bQol7n0DABNJfTgTymFZcE2nTu87rmigMubyYlNJ7VcujbkqQDq556je/P52Hdl2QB65WpfM0vIPrVH71UtlZonWxqQwjJbKcCaQw7Xo+hXa9YPtY4dzwdQ0p1lzL10Qf2HItcNb9mnbN55SVK5Iw1MiTvUc1/ZvP6PFrx/Wqt+4OtkCgVFSNb7I/a7D19DSqKv/6EY1tDWp2PSrc79CU0ktHa266g/fJysk26UgXyoDZW3726eGPjEwaq5iNFDR/l++oWPdvckUB4wjFZ0zgZwurUtn6933f1iHX9qnYzsOyYoFtV+5WHNWL0h8+BD5c+i5HmmC3S4aqGjPT17TnEs64isKmETQ4Uwop1ehWNCC9Z1asL4z6VKQc2f2HVfUP/YciLed3nc8pmrGx3VnjBRkOBPKAGql2NIoK5q8Mv7n70vNDTFWBEwuqHAmlAHU2oJru7T7B9vloy84Dys0FrX4hpUxVwVMLFUTwgDgQrV0tKpjfefYt4wtmBrbmrTouuXxFwZMIKjOGQDq4dJPbFCxuUEHnnhj6EY5LnkUqW35PK37nevP+2QBEAL2SACZVygWtPrj67XyI+t0dMcheTlS24p5auloTbq08zApDG8jnAHkRkNLI58gQCpwzRkAgMAQzgAABIZwBgAgMIQzAASE+z1AIpwBAAgO4QwAQGAIZwAAAkM4A0BguO4MwhkAgMAQzgAABIZwBgAgMIQzAACBIZwBIEBMCss3whkAgMAQzgAABIZwBgAgMIQzAACBIZwBIFBMCssvwhkAgMAQzjFx96RLAACkRCnpArJs4GSfdv9wh/Y/8aYqfYNqmNmoJe+9WF03rVapuSHp8gAAgSKc66T/2Fk9+5c/1uDpAXklkiQNnh7Q7h/t0KFn9+jaP7mJgAYAjIlh7TrZ8Z1nNXCq71wwv83LkfqOnNbOTVsSqgxAmjApLJ8I5zoYONmnozsOSdHYy73iOvDULkWDlXgLAwCkQlXhbGa3mNkOM+s2sy+OsfxzZrbNzF4ysx+b2fLal5oeZw+fVqE0+aYdONkfQzUAgLSZNEHMrCjpAUm3Slor6S4zWztqteclbXD3KyX9k6S/rHWhaVJqbpBHE8/O9oqrOINL/gCAd6qmc94oqdvdd7r7gKTvSrpj5Aru/pi7nxl++KSkztqWmS4tC9vUOGvGhOvMWjlPDS2NMVUEAEiTasJ5qaQ9Ix73DD83nrslfX86RaWdmWnVx65WoaE45vJCQ1EX/7srY64KQFoxKSx/ajohzMw+IWmDpK+Os/weM9tsZpt7e3tr+dLBmb9usdZ++l1qaG1Ssal07qdpbrOu/L33aNbyeUmXCAAIVDUXPfdK6hrxuHP4ufOY2c2S/lTS+919zJlO7v6QpIckacOGDZm/ZVb7lUs0//LFOtbdq4ET/Zoxr1mzVs6XmSVdGgAgYNWE8zOSVpnZSg2F8p2SfmvkCma2XtKDkm5x90M1rzLFrGCau3pB0mUAAFJk0mFtdy9L+qykRyVtl/Swu281s6+Y2e3Dq31VUqukfzSzF8xsU90qBgAg46r6LI+7PyLpkVHPfWnE7zfXuC4AAHKLO4QBQAowYztfCGcAAAJDOAMAEBjCGQCAwBDOAAAEhnAGgJRgUlh+EM4AAASGcAYAIDCEMxCAykBZAyf6FFWipEsBEICq7hAGoD5OHzihnf+8RUe2HRj6QpSCadH1K7TyI+v4vm8gxwhnICEne47phb/+qSr9ZUmSa+iL2vb/YqeObD2ga//TTQQ0kFMMawMJeeXvnj4XzCN5xdV/7Kx2fX97AlUhdMzYzgfCGUjA6f0ndPat0+Mu90qk/b98Qx5l/mvPAYyBcAYS0HfktKxgE64TDVZUGazEVBGAkBDOQAIaWpukybrigqnYUIynIABBIZyBBLQtm6vSRJO9TFpwTdek
3TWAbCKcgQSYmVbfdY0KY3XGJpVmNGjlR9bGXxhSgUlh2Uc4AwmZv26xLv/MuzVjfosKjUUVZzSo0FDQ7JXzdc0XPqgZ82YmXSKAhPA5ZyBB89Yu0ru+fKtO7z+hwVP9am5v1Yx5LUmXBSBhhDOQMDNT65LZSZcBICAMawMAEBjCGQCAwBDOAJBCzNjONsIZAIDAEM4AAASGcAYAIDCEMwAAgSGcASClmBSWXYQzAACBIZwBAAgM4QwAQGAIZwAAAkM4A0CKMSksmwhnAAACQzgDABAYwhkAgMAQzgAABIZwBgAgMIQzAACBIZwBAAgM4QwAQGAIZwBIsQcX3Z90CagDwhkAgMAQzgAABIZwBgAgMIQzAKQU15uzi3AGACAwhDMAAIEhnAEghRjSzjbCGQCAwBDOAAAEhnAGACAwhDMApAzXm7OPcAYAIDCEMwAAgSGcAQAIDOEMACnC9eZ8IJwBAAgM4QwAQGAIZwAAAkM4A0BKcL05PwhnAAACQzgDABAYwhkAUoAh7XwhnAEACAzhDABAYAhnAAACQzgDQOC43pw/hDMAAIEhnAEACAzhDABAYAhnAAgY15vziXAGACAwhDMAAIEhnAEgUAxp5xfhDABAYAhnAAACQzgDABCYqsLZzG4xsx1m1m1mXxxjeZOZ/cPw8qfMbEWtCwWAPOF6c75NGs5mVpT0gKRbJa2VdJeZrR212t2Sjrr7JZL+WtJf1LpQAADyoprOeaOkbnff6e4Dkr4r6Y5R69wh6e+Gf/8nSTeZmdWuTAAA8qOacF4qac+Ixz3Dz425jruXJR2XNL8WBQIAkDelOF/MzO6RdM/ww1NmtmOM1dolHY6vqlxh29YP27Z+crpt/0scL5LTbVt/nx972y6v9r+vJpz3Suoa8bhz+Lmx1ukxs5Kk2ZLeGv2H3P0hSQ9N9GJmttndN1RRFy4Q27Z+2Lb1w7atH7Zt/Ux321YzrP2MpFVmttLMGiXdKWnTqHU2Sfrk8O+/Iekn7u5TLQoAgDybtHN297KZfVbSo5KKkr7h7lvN7CuSNrv7Jkn/W9K3zaxb0hENBTgAAJiCqq45u/sjkh4Z9dyXRvzeJ+ljNappwmFvTAvbtn7YtvXDtq0ftm39TGvbGqPPAAD/AVx0AAADLElEQVSEhdt3AgAQmMTCmVuC1kcV2/VzZrbNzF4ysx+bWdVT+zH59h2x3kfNzM2MmbBVqGa7mtnHh/fdrWb293HXmFZVHBOWmdljZvb88HHhtiTqTCMz+4aZHTKzl8dZbmb2N8Pb/iUzu6bqP+7usf9oaGLZ65IuktQo6UVJa0et83uSvj78+52S/iGJWtP0U+V2/YCkluHff5ftWtvtO7xem6THJT0paUPSdYf+U+V+u0rS85LmDj9ekHTdafipcts+JOl3h39fK+nNpOtOy4+k90m6RtLL4yy/TdL3JZmk6yU9Ve3fTqpz5pag9THpdnX3x9z9zPDDJzX0uXVUp5r9VpL+XEP3l++Ls7gUq2a7fkbSA+5+VJLc/VDMNaZVNdvWJc0a/n22pH0x1pdq7v64hj6hNJ47JH3LhzwpaY6ZLa7mbycVztwStD6q2a4j3a2hszpUZ9LtOzxs1eXu/xpnYSlXzX67WtJqM/uFmT1pZrfEVl26VbNtvyzpE2bWo6FP5fxBPKXlwoUek8+J9fadCIeZfULSBknvT7qWrDCzgqSvSfpUwqVkUUlDQ9s3ami053Ezu8LdjyVaVTbcJemb7v5XZvZuDd2z4nJ3j5IuLM+S6pwv5JagmuiWoDhPNdtVZnazpD+VdLu798dUWxZMtn3bJF0u6adm9qaGrjFtYlLYpKrZb3skbXL3QXd/Q9KrGgprTKyabXu3pIclyd2fkDRDQ/eFxvRVdUweS1LhzC1B62PS7Wpm6yU9qKFg5rrdhZlw+7r7cXdvd/cV7r5CQ9f0b3f3zcmUmxrVHA++p6GuWWbWrqFh7p1xFplS1Wzb3ZJukiQzW6OhcO6N
tcrs2iTpt4dnbV8v6bi776/mP0xkWNu5JWhdVLldvyqpVdI/Ds+v2+3utydWdIpUuX1xgarcro9K+nUz2yapIukL7s5I2iSq3Lafl/Q/zew/amhy2KdohKpjZt/R0Elj+/A1+z+T1CBJ7v51DV3Dv01St6Qzkj5d9d/m3wAAgLBwhzAAAAJDOAMAEBjCGQCAwBDOAAAEhnAGACAwhDMAAIEhnAEACAzhDABAYP4/iZHVfPRPlw8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# ==================================\n",
    "# Initialize Network\n",
    "# ==================================\n",
    "\n",
    "print('===============Initializing Network============')\n",
    "layer_sizes = [2, 8, 10, 10, 2]\n",
    "network = PLNN(layer_sizes)\n",
    "net = network.net\n",
    "\n",
    "\n",
    "# ==================================\n",
    "# Train Network\n",
    "# ==================================\n",
    "# Full-batch Adam on the toy set, reporting training error every 400 steps.\n",
    "# Tensors are passed directly: autograd.Variable has been a no-op wrapper since PyTorch 0.4.\n",
    "\n",
    "print('===============Training Network============')\n",
    "opt = optim.Adam(net.parameters(), lr=1e-3)\n",
    "loss_fn = nn.CrossEntropyLoss()  # stateless, so build it once rather than every step\n",
    "for i in range(4000):\n",
    "    out = net(X)\n",
    "    loss = loss_fn(out, y)\n",
    "    err = (out.max(1)[1] != y).float().mean()\n",
    "    if i % 400 == 0:\n",
    "        print('error:', err)\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "# ==================================\n",
    "# Visualize:  classifier boundary\n",
    "# ==================================\n",
    "# Evaluate the logit margin (logit 0 - logit 1) on a 100x100 grid over the unit square.\n",
    "\n",
    "XX, YY = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))\n",
    "X0 = torch.Tensor(np.stack([np.ravel(XX), np.ravel(YY)]).T)\n",
    "with torch.no_grad():  # plotting only; no autograd graph needed\n",
    "    y0 = network(X0)\n",
    "# .reshape replaces the deprecated non-inplace .resize (it emitted a UserWarning);\n",
    "# row-major order matches np.ravel, so ZZ lines up with XX/YY.\n",
    "ZZ = (y0[:, 0] - y0[:, 1]).reshape(100, 100).detach().numpy()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "# levels [-1000, 0, 1000]: two colour bands split at margin 0 (assumes |margin| < 1000)\n",
    "ax.contourf(XX, YY, -ZZ, cmap=\"coolwarm\", levels=np.linspace(-1000, 1000, 3))\n",
    "ax.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y.numpy(), cmap=\"coolwarm\", s=70)\n",
    "ax.axis(\"equal\")\n",
    "ax.axis([0, 1, 0, 1])\n",
    "ax.set(title=\"Trained classifier decision regions\", xlabel=\"$x_1$\", ylabel=\"$x_2$\")\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import geo_attack as ga"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attack object wrapping the trained PLNN. verbose=True presumably drives the per-iteration\n",
    "# ITERATION / RUNNING POINT / LOSS log printed by atk.attack in the next cell.\n",
    "atk = ga.NaiveGeoCrawl(network, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X:  tensor([[0.2293, 0.5624]])\n",
      "LOGITS: tensor([[ 4.9700, -5.4038]], grad_fn=<ThAddmmBackward>)\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n",
      "\n",
      "--------------------------------------------------\n",
      "ITERATION 00\n",
      "RUNNING POINT: [[0.23938 0.54029]]\n",
      "LOSS: -9.782362937927246\n",
      "--------------------------------------------------\n",
      "-2.9802322e-08\n",
      "--------------------------------------------------\n",
      "ITERATION 01\n",
      "RUNNING POINT: [[0.24124 0.53621]]\n",
      "LOSS: -9.587063789367676\n",
      "--------------------------------------------------\n",
      "--------------------------------------------------\n",
      "ITERATION 02\n",
      "RUNNING POINT: [[0.25927 0.49648]]\n",
      "LOSS: -7.690087795257568\n",
      "--------------------------------------------------\n",
      "--------------------------------------------------\n",
      "ITERATION 03\n",
      "RUNNING POINT: [[0.27798 0.45526]]\n",
      "LOSS: -4.03035831451416\n",
      "--------------------------------------------------\n",
      "--------------------------------------------------\n",
      "ITERATION 04\n",
      "RUNNING POINT: [[0.34846 0.3305 ]]\n",
      "LOSS: 8.444307327270508\n",
      "--------------------------------------------------\n",
      "tensor([[0.3008, 0.4150]])\n",
      "0.16385962\n",
      "tensor([[-0.1069, -0.1069]])\n"
     ]
    }
   ],
   "source": [
    "# Attack a random query point labelled with the network's own prediction, then report\n",
    "# the returned point, its L2 distance from x, and the network's logits there.\n",
    "x = torch.rand(1, 2)\n",
    "logits = network(x)\n",
    "label = torch.max(logits.squeeze(), dim=0)[1]  # index of the largest logit\n",
    "print(\"X: \", x)\n",
    "print(\"LOGITS:\", logits)\n",
    "print('~' * 50, '\\n')\n",
    "\n",
    "output = atk.attack(x, label)\n",
    "adv_point = output.data\n",
    "print(adv_point)\n",
    "print(np.linalg.norm(adv_point - x))\n",
    "print(network(output).data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ReLU config and polytope of the query point x, computed in float64.\n",
    "# Cast the network to double first so relu_config evaluates x.double() with matching dtypes.\n",
    "# NOTE(review): network.double() presumably converts in place, so the float32 cells above\n",
    "# (training, plotting, attack) will hit dtype mismatches after this cell runs.\n",
    "network.double()\n",
    "conf = network.relu_config(x.double(), False)\n",
    "network.compute_polytope_config(conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch inspection: first entry of the ReLU config `conf` from the previous cell, as float64.\n",
    "# The result is only displayed, not stored. NOTE(review): assumes conf[0] is a tensor.\n",
    "conf[0].double()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample 999 interior points of the segment prev -> running and count the distinct\n",
    "# (flattened) ReLU configurations they fall into.\n",
    "# FIXME: `prev` and `running` are not defined by any cell in this notebook, and `config`\n",
    "# is only assigned in a later cell; this relies on leftover kernel state.\n",
    "direction = running - prev\n",
    "poly = network.compute_polytope_config(config, False)\n",
    "# NOTE(review): A and b are not used further in this cell\n",
    "A, b = poly['total_a'], poly['total_b']\n",
    "\n",
    "relu_set = {\n",
    "    utils.flatten_config(network.relu_config(prev + step / 1000.0 * direction, False))\n",
    "    for step in range(1, 1000)\n",
    "}\n",
    "\n",
    "print(len(relu_set))\n",
    "# reference config vs. an arbitrary member of the set (set iteration order is unspecified)\n",
    "print('', utils.flatten_config(config), '\\n', list(relu_set)[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve for where the segment prev -> running crosses the decision boundary, assuming the\n",
    "# network is affine on the polytope of the segment midpoint with output a @ z + b\n",
    "# (a = total_a, b = total_b -- TODO confirm against PLNN.compute_polytope_config).\n",
    "# With c = loss_op, c . (a (prev + t * diff) + b) = 0 gives\n",
    "#   t = (-b . c - (c a) . prev) / ((c a) . diff)\n",
    "# FIXME: `prev`, `running` and `loss_op` are not defined by any cell in this notebook.\n",
    "med = (prev + running) / 2.0\n",
    "diff = running - prev\n",
    "config = network.relu_config(med, False)\n",
    "poly = network.compute_polytope_config(config)\n",
    "b = poly['total_b']\n",
    "a = poly['total_a']\n",
    "c = loss_op\n",
    "cta = torch.matmul(c, a)\n",
    "\n",
    "num = torch.matmul(-b, c) - torch.matmul(cta, prev.squeeze())\n",
    "denom = torch.matmul(cta, diff.squeeze())\n",
    "t_cross = num / denom  # compute the crossing parameter once and reuse it\n",
    "\n",
    "z_output = prev + t_cross * diff\n",
    "print(network(z_output))\n",
    "print(z_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the attack's private _interpolate_final_polytope helper with the same prev / running /\n",
    "# loss_op used by the manual crossing computation above (presumably as a cross-check).\n",
    "# FIXME: `prev`, `running` and `loss_op` are not defined by any cell in this notebook.\n",
    "atk._interpolate_final_polytope(prev, running, loss_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the network at prev and at a fixed fraction T_CROSS of the way toward running.\n",
    "# The fraction can be obtained in closed form as\n",
    "#   t = (-total_b . loss_op - cta . prev) / (cta . direction)\n",
    "# NOTE(review): 0.385 is hard-coded (presumably read off an earlier run) -- TODO compute it.\n",
    "# FIXME: `prev`, `running` and `loss_op` are not defined by any cell in this notebook.\n",
    "poly = network.compute_polytope_config(config, False)\n",
    "direction = running - prev\n",
    "cta = torch.matmul(loss_op, poly['total_a'])\n",
    "T_CROSS = 0.385\n",
    "\n",
    "print(network(prev))\n",
    "print(network(prev + T_CROSS * direction))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hamming distance (utils.config_hamming_distance) between the ReLU config at the midpoint\n",
    "# of prev -> running and the reference `config`.\n",
    "# FIXME: `prev` and `running` are not defined by any cell in this notebook.\n",
    "midpoint = (prev + running) / 2\n",
    "a, b = network.relu_config(midpoint, False), config\n",
    "utils.config_hamming_distance(a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display `x`. It is rebound by several cells: a list of numpy points in the data cell,\n",
    "# then a (1, 2) tensor in the attack cell -- the value shown depends on execution order.\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: np.nonzero returns one index array per dimension, so [0] gives the\n",
    "# positions of the True entries of this 1-D mask.\n",
    "test_arr = np.array([False, False, False, True])\n",
    "print(test_arr)\n",
    "np.nonzero(test_arr)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check of float64 -> float32 rounding: four uniform draws and a single-precision copy.\n",
    "q = np.random.random(4)\n",
    "z = np.asarray(q, dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the float64 draws next to their float32 copies.\n",
    "q, z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3",
   "language": "python",
   "name": "py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
