{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==================================\n",
    "# Experiment 2\n",
    "# ==================================\n",
    "\n",
    "# Demonstration: find the maximal l_p ball around a point x_0 within which\n",
    "# the label assigned by a classifier C stays equal to C(x_0).\n",
    "# A network is trained to classify randomly labeled points in R^2, and the\n",
    "# effect of l_1 regularization during training is demonstrated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jordanm/.virtualenvs/myvenv/lib/python3.7/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    }
   ],
   "source": [
    "# =====================\n",
    "# Imports\n",
    "# =====================\n",
    "import sys \n",
    "sys.path.append('..')\n",
    "sys.path.append('../mister_ed')\n",
    "\n",
    "from geocert_oop import BatchGeoCert, IncrementalGeoCert\n",
    "from plnn import PLNN\n",
    "from _polytope_ import Polytope\n",
    "import utilities as utils\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets, transforms\n",
    "# from convex_adversarial import robust_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===============Generating Training Points============\n",
      "===============Points Generated============\n"
     ]
    }
   ],
   "source": [
    "# Apply incremental GeoCert to a baseline classifier and to an\n",
    "# l_1-regularized classifier, finding maximal l_p balls around random\n",
    "# points in R^2.\n",
    "\n",
    "# ==================================\n",
    "# Generate Training Points\n",
    "# ==================================\n",
    "\n",
    "print('===============Generating Training Points============')\n",
    "m = 12   # number of training points\n",
    "r = 0.1  # accepted points are more than 2r apart in l_1 distance\n",
    "\n",
    "# Rejection sampling: draw uniform points in [0, 1)^2 and keep a candidate\n",
    "# only if it is farther than 2r (in l_1 distance) from every kept point.\n",
    "np.random.seed(3)\n",
    "x = [np.random.uniform(size=(2))]\n",
    "while len(x) < m:\n",
    "    candidate = np.random.uniform(size=(2))\n",
    "    nearest = min(np.abs(candidate - kept).sum() for kept in x)\n",
    "    if nearest > 2 * r:\n",
    "        x.append(candidate)\n",
    "\n",
    "epsilon = r / 2\n",
    "\n",
    "X = torch.Tensor(np.array(x))\n",
    "torch.manual_seed(1)\n",
    "# rand is in [0, 1), so rand + 0.5 truncates to label 0 or 1 with equal probability\n",
    "y = (torch.rand(m) + 0.5).long()\n",
    "\n",
    "print('===============Points Generated============')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===============Initializing Network============\n",
      "Sequential(\n",
      "  (1): Linear(in_features=2, out_features=8, bias=True)\n",
      "  (2): ReLU()\n",
      "  (3): Linear(in_features=8, out_features=2, bias=True)\n",
      "  (4): ReLU()\n",
      "  (5): Linear(in_features=2, out_features=2, bias=True)\n",
      ")\n",
      "===============Training Network============\n",
      "error: tensor(0.5833)\n",
      "error: tensor(0.5833)\n",
      "error: tensor(0.4167)\n",
      "error: tensor(0.4167)\n",
      "error: tensor(0.4167)\n",
      "error: tensor(0.3333)\n",
      "error: tensor(0.2500)\n",
      "error: tensor(0.2500)\n",
      "error: tensor(0.2500)\n",
      "error: tensor(0.2500)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jordanm/.virtualenvs/myvenv/lib/python3.7/site-packages/torch/tensor.py:263: UserWarning: non-inplace resize is deprecated\n",
      "  warnings.warn(\"non-inplace resize is deprecated\")\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAecAAAHWCAYAAABNK0FcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3XuQXGd55/Hf091zk2Z0nZEljS6WbclYsmVLHnxZczHYgO2l7NoQwN6iEsDB3iSQ1HLJksXlsMSkAmxINhXvYoU4BKoCcagNEYXArsQGL+CLJF9kS7JASJY0kkYaeUZ3zaW7n/1jRmI0mkvPTPc57zn9/VSparr7aObhMJqv37dP95i7CwAAhCMT9wAAAOB8xBkAgMAQZwAAAkOcAQAIDHEGACAwxBkAgMCMG2cze9TMDpvZq6M8bmb212a208y2mNna8o8JAED1KGXl/A1Jt43x+O2Slg/+uU/S/5n6WAAAVK9x4+zuT0vqGuOQuyR90wc8K2mWmS0o14AAAFSbcjzn3Cpp35Db7YP3AQCASchF+cXM7D4NbH1r+vTp177pTW867/FDR6OcBlFq6d8f9wgAEJlMy4Vr1M2bNx9x95ZS/n454rxf0uIhtxcN3ncBd18naZ0ktbW1+aZNm857/Kv/yvt8p9n9HQ/EPQIARGL6/V+84D4z21Pq3y/HtvZ6Sb81eNX2DZKOufvBMnxeAACq0rgrZzP7tqSbJTWbWbukP5FUI0nu/jVJGyTdIWmnpNOSPlKpYQEAqAbjxtnd7xnncZf0+2WbCKn1yPyH2NoGgBLwDmEAAASGOAMAEBjijEg9Mv+huEcAgOARZwAAAkOcETlWzwAwNuIMAEBgiDMAAIEhzogFW9sAMDriDABAYIgzAACBIc6IDVvbADAy4gwAQGCIM2LF6hkALkScAQAIDHEGACAwxBmxY2sbAM5HnAEACAxxBkpUzBfVd6JHhf5C3KMASLlc3AMA0sDW9v0dD8Q9xoj6TvRo9/e36tDGvZK7XFLz6oW65M6r1NA8Pe7xAKQQcQbG0He8R5u+9G/qP9krL/i5+ztfbFf39kNa++l3atpFTTFOCCCN2NYGxvCr772ivhPnh1mS5FK+p187vr05nsEApBpxRjBCu2q70FdQ5wv7pKKPfIBLJ17vUu/RM9EOBiD1iDMwiv5TvVLGxjzGcln1dJ2OaCIA1YI4IyghrZ5zDTUXbmcP44WiaprqIpoIQLUgzsAocvU1mn35vDGPaWiermktjRFNBKBaEGdgDJf+xmpl60Z+UUOmJqvlH1wT8UQAqgFxRnBC2tqePn+G1nzyZjUumaVMTVbZ+pwytVlNu6hJq3//LZp1WUvcIwJIIV7nDIyjsXWW2v7oVp3pPKme7tOqbarX9AUz4h4LQIoRZ6BEDS2NauD5ZQARYFsbQQppaxsAokacAQAIDHFGsFg9A6hWxBkAgMAQZwAAAkOcETS2tgFUI15KBQAIRqFoeu5ws35yYL56Clmtntuldy86oKbafNyjRYo4AwCC0NVTq08/82Z19dTpTGEgTy92ztE3d1ymB9te1rUtb8Q8YXTY1kbw2NoG0s9d+tzza9VxquFcmCWpt5hTTyGn/7HpanWcbohxwmgRZwBA7LYfnakDp6apMEqWCsWM/mX3koinig9xBgDE7oXOueotZEd9PO8ZPXuoen7RDHFGIrC1DaSbe3mOSQviDACI3TXN3arLFkZ9PGtFtXFBGBAeVs9Ael05p1stDT3KWHHEx3MZ1/su2RPxVPEhzgCA2JlJf3b9ZjXX96oh++vXNNdmCqrLFPTHa7aotfF0jBNGi9c5AwCCMK+hV4++46f62cGL9OT+gTchuXpul+5Yul+z6/riHi9SxBmJ8sj8h3R/xwNxjwGgQmoyrptbO3Rza0fco8SKbW0AAAJDnAEACAxxRuJw1TaAtCPOAAAEhjgDABAY4oxE
YmsbQJoRZwAAAkOckVisngGkFXEGACAwxBkAgMAQZyQaW9sA0og4AwAQGOIMAEBgiDMSj61tAGlDnAEACAxxRiqwegaQJsQZAIDAEGcAAAJDnJEabG0DSAviDABAYIgzAACBIc5IFba2AaQBcQYAIDDEGQCAwBBnpA5b2wCSjjgDABAY4oxUYvUMIMmIMwAAgSHOAAAEhjgjtdjaBpBUxBkAgMAQZwAAAlNSnM3sNjPbYWY7zeyzIzy+xMyeMrMXzWyLmd1R/lGBiWNrG0ASjRtnM8tKeljS7ZJWSrrHzFYOO+wBSY+5+xpJd0v63+UeFACAalHKyvk6STvdfZe790n6jqS7hh3jkmYMfjxT0oHyjQhMDatnAEmTK+GYVkn7htxul3T9sGM+L+kJM/uEpOmSbi3LdAAAVKFyXRB2j6RvuPsiSXdI+paZXfC5zew+M9tkZps6OzvL9KUBAEiXUuK8X9LiIbcXDd431L2SHpMkd39GUr2k5uGfyN3XuXubu7e1tLRMbmJgEtjaBpAkpcR5o6TlZrbMzGo1cMHX+mHH7JV0iySZ2RUaiDNLYwAAJmHcOLt7XtLHJT0uabsGrsreamZfMLM7Bw/7lKSPmdnLkr4t6cPu7pUaGgCANCvlgjC5+wZJG4bd9+CQj7dJuqm8owHl9cj8h3R/xwNxjwEA4+IdwgAACAxxBgAgMMQZVYWrtgEkAXEGACAwxBlVh9UzgNARZwAAAkOcAQAIDHFGVWJrG0DIiDMAAIEhzgAABIY4o2qxtQ0gVMQZAIDAEGcAAAJDnFHV2NoGECLiDABAYIgzqh6rZwChIc4AAASGOAMAEBjiDIitbQBhIc4AAASGOAMAEBjiDAxiaxtAKIgzAACBIc7AEKyeAYSAOAMAEBjiDABAYIgzMAxb2wDiRpwBAAgMcQYAIDDEGRgBW9sA4pSLewAAyVXoy6vjuT068NNdyp/qU8O8Ri2+ZYXmrJwvM4t7PCCxiDOASek72asX/+JJ9R7rUbGvIEnqPXpGx1/vUvNVC3XFb18nyxBoYDLY1gZGwdb22F775kb1dJ0+F+azin0FHXnlgA7+fHdMkwHJR5wBTFjv0TM6+ovD8oKP+Hixr6C9T7wW8VRAehBnYAysnkd2Yl+3LDf2j4+e7tMq9hfGPAbAyIgzgAnL5LIlHGWyLD9igMngXw6ACZt5abNUHHlL+6xZy5u5IAyYJOIMjIOt7Qtla7NafOvlytSOvILO1GS17L2rIp4KSA/iDGBSlt5+hRbetEyWy8hqBn6UZOpyytRmdcWHr9PMS5pjnhBILl7nDGBSzEyXve8aLb7lch1+sV39J3s1bV6TWta0KlvLjxZgKvgXBJTgkfkP6f6OB+IeI0h1sxq0+B3L4x4DSBW2tQEACAxxBkrEhWEAokKcAQAIDHEGACAwxBmYALa2gepTcGlb90xt7pyrw2fqIvmaXK0NAMAonti3QF/ffrn6ihll5OovZrRy9lF9+ppX1dLQW7Gvy8oZAIAR/OD1RfqbV1fqWF+tzuRzOpWvUV8xqy1ds/WJn96gY301FfvaxBmYILa2gfTrLWS0bvsK9RYufIvaomd0si+n/7tracW+PnEGAGCYzZ1zxwxkv2f1o72tFfv6xBkAgGGO9tZqvN9Gfjpfucu2iDMwCWxtA+m2cPqZcQM5t54LwgAAiMzquV2qz42+dq7LFPQby/ZU7OsTZ2CSWD0D6ZUx6b+v3aK6bEGSn/dYbaagS2ac0G1L2iv39Sv2mQEASLDVc7v11f/wvNpajihrReWsqKaaPn3wst368o0bVZv18T/JJPEmJAAAjOKymSf0xetfVH/R1FfIqiGXV8Yq/3VZOQNTwNY2UB1qMq7pNdGEWSLOAAAEhzgDABAY4gxMEVvbAMqNOAMAEBjiDABAYIgzUAZsbQMoJ+IMAEBgiDNQJqyeAZQLcQYAIDDEGQCAwBBnoIzY2gZQDsQZAIDAEGcAAAJDnIEyY2sbwFQR
ZwAAAkOcgQpg9QxgKogzAACBIc4AAASGOAMVwtY2gMkizgAABIY4AwAQmJLibGa3mdkOM9tpZp8d5ZgPmNk2M9tqZv9Y3jGBZGJrG8Bk5MY7wMyykh6W9C5J7ZI2mtl6d9825Jjlkv5Y0k3u3m1m8yo1MAAAaVfKyvk6STvdfZe790n6jqS7hh3zMUkPu3u3JLn74fKOCQBA9Rh35SypVdK+IbfbJV0/7JgVkmRmP5OUlfR5d/9RWSYEEu6R+Q/p/o4H4h4DE7T3xHR9d9dSvXRkjrImvWXBId158V61NPTGPRqqQClxLvXzLJd0s6RFkp42s6vc/ejQg8zsPkn3SdKSJUvK9KUBoLye3D9ff/XyKuWLpsLgBuO/7F6i77++WF+8/gWtmnN0nM8ATE0p29r7JS0ecnvR4H1DtUta7+797r5b0i80EOvzuPs6d29z97aWlpbJzgwkDheGJUfH6Xr91cur1FvMnguzJPUXszpTyOmB59eop8ALXVBZpXyHbZS03MyWmVmtpLslrR92zPc0sGqWmTVrYJt7VxnnBIBIrN+9RAUf/fGim35yYH50A6EqjRtnd89L+rikxyVtl/SYu281sy+Y2Z2Dhz0u6Q0z2ybpKUmfcfc3KjU0AFTKlq7Zynt21Md7Cjm98sbsCCdCNSrpOWd33yBpw7D7HhzysUv65OAfACPgwrBkqMkUx3zc5KrPFiKaBtWKJ04AYIhbWg+qPpsf9fG6bEFvXXAowolQjYgzAAxxy6KDqs8WZLpwBZ2zolqnn9bqud0xTIZqQpyBCHHVdvgacgV99abnNa+hRw3ZvCRXRq66bF6XzjyuP79hs8zinhJpV67XOQNAarROP6NvvPOnevHIHG3tmq2sudrmHdHls47HPRqqBHEGIsaFYcmQMenali5d29IV9yioQmxrAwAQGOIMAEBgiDMQAy4MAzAW4gwAQGCIMwAAgSHOQEzY2gYwGuIMAEBgiDMAAIEhzkCM2NoGMBLiDABAYIgzEDNWzwCGI84AAASGOAMAEBjiDABAYIgzAACBIc4AAASGOAMxu7/jgbhHABAY4gwAQGCIMxAjVs0ARkKcAQAIDHEGYsKqGcBoiDMAAIEhzkAMWDUDGAtxRuzcXcViMe4xIkOYAYwnF/cAqF6njp/R9hf2qH1Xp4pFV119jS5dtVDLr1qsbI7/bgRQvYgzYnGs65R+8v2XlM8XJB+4r7enX6+9tE8H9ryht7/3amVz2XiHrABWzQBKwfIEsdj41Hbl+38d5rOKhaKOd5/Wzlf3xzMYAASAOCNyx944qVMnekZ9vFgoaufW9MWZVTOAUhFnRO7k8TMyszGP6T3Tr2LRxzwGANKKOCNyNXXjX+qQyWY0Tr8ThVUzgIkgzohc8/xZsswY5TVp0aUt466uASCtiDMil8mYVl9/ibLZkb/9crmsrlizJOKpKodVM4CJ4qVUiMXSFfMlSVue2yUffG7Z3TW9qUHXvfNNmt7UEOd4ABAr4ozYLF0xX4svu0hvdBxTX19ejTMaNHPO9LjHKitWzQAmgzgjVpmMqWXhrLjHAICg8JwzUCGsmgFMFnEGACAwxBmoAFbNAKaCOAMAEBjiDJQZq2YAU0WcAQAIDHEGyohVM4ByIM4AAASGOANlwqoZQLkQZ6AMCDOAciLOAAAEhjgDU8SqGUC5EWcAAAJDnIEpYNUMoBKIMwAAgSHOwCSxagZQKcQZAIDAEGdgElg1A6gk4gwAQGBycQ9Qad1HTqh9V6fyfQXNmdek1ktalMtl4x4LCcaqGUClpTbO+f6Cfv7Eq+ruPKFCvihJ2rvzkF5+5le64V0rNW/h7JgnBABgZKnd1n7uye3qOnT8XJglqZAvKt9f0DNPbNXJY2dinA5JxaoZQBRSGeeTx86o88BRFYs+4uPFQlG/fKU94qkAAChNKuN8qL1L0shhliR36cCeI9ENhFRg1QwgKqmMc7HoY7VZkuSjrKoBAIhbKuM896IZsoyN
ewxQKlbNAKKUyjjPbmnStMZ62Sh9zmYzWnH1kmiHQmIRZgBRS2WczUw3vedK1TXUKpsb8j/RBsK86s0Xs3JGSQgzgDik9nXO05rq9e73v1l7dx7Snl8cUiFf0OyWJi2/apFmzJ4e93gIHFEGEKfUxlmScjVZXXLFQl1yxcK4R0GCEGYAcUt1nIGJIMoAQpHK55yBiSLMAELCyhlVjSgDCBFxRlUiygBCRpxRVYgygCTgOWdUDcIMIClYOSP1iDKApCHOSC2iDCCpStrWNrPbzGyHme00s8+Ocdz7zMzNrK18IwITR5gBJNm4K2czy0p6WNK7JLVL2mhm691927DjmiT9oaTnKjEoUAqiDCANStnWvk7STnffJUlm9h1Jd0naNuy4P5X0JUmfKeuEQAmIMoA0KWVbu1XSviG32wfvO8fM1kpa7O4/KONsQEkIM4C0mfIFYWaWkfRVSR8u4dj7JN0nSUuW8PuUMTVEGUBalRLn/ZIWD7m9aPC+s5okXSnpx2YmSfMlrTezO91909BP5O7rJK2TpLa2Np/C3KhiRBlA2pWyrb1R0nIzW2ZmtZLulrT+7IPufszdm939Yne/WNKzki4IM1AOhBlANRh35ezueTP7uKTHJWUlPeruW83sC5I2ufv6sT8DMHVEGUA1Kek5Z3ffIGnDsPseHOXYm6c+FjCAKAOoRry3NoJFmAFUK96+E8EhygCqHXFGMIgyAAwgzogdUQaA8/GcM2JFmAHgQqycEQuiDACjI86IFFGurP6Tvep67bC8UNSMi+do2kVNcY8EYBKIMyJDmCunWCjql4+9qEPP7ZFlM3K5VHQ1Lp6tVffeoLqZDXGPCGACiDMqjihX3mvf2qgjLx9QMV+U8sVz9x9/vUsv/M8ndd0D71G2jn/uQFLwrxUVQ5SjcbrzpI68vF/F/uKFDxZd/af61PH8HrW+9dLohwMwKVytjYogzNE5vHmfvDj6L3kr9hV08Ge7I5wIwFSxckZZEeXoFU73yQtj/wbWfE9/RNMAKAfijLIgyvGZ3jpLmbqcir35UY9pXDgzwokATBXb2pgywhyvljWtsjEez9RmteiWFZHNA2DqWDlj0ohyGLK1Oa386PXa+vVnVcwXpCE73JnarBbedIlmXdoc34AAJow4Y8KIcnjmrlqgtZ96h/b8aLve2NohLxbV2DpLS97zJrVc3Rr3eAAmiDhjQghzuBoXzdKq37kx7jEAlAFxRkmIMgBEhzhjTEQZAKLH1doYFWEGgHiwcsYFiDIAxIs44xyiDABhIM4gygAQGJ5zrnKEGQDCw8q5ShFlAAgXca4yRBkAwse2dhUhzACQDKycqwBRBoBkIc4pRpQBIJnY1k4pwgwAycXKOWWIMgAkH3FOCaIMAOnBtnYKEGYASBdWzglGlAEgnYhzAhFlAEg3trUThjADQPqxck4IogwA1YM4By6pUe7pOq39T+9U1/ZDsoypZe0iLbxxmWoa6+IeDQCCR5wDldQoS1Lnlv3a/o3n5QWXF4qSpNMdJ7T3iR265g/epqbFs2OeEADCxnPOAUpymHu6T2v73z+vYl/hXJglqdhfUOFMv17+m/+nYn8hxgkBIHysnAOS5Ciftf/pX8ndR33c80V1vrRfF715SYRTAUCyEOcApCHKZ3W/dkieL476eKE3r+5fHCbOADAGtrVjlqYwS5Jlx/+WyuT4tgOAsbByjknaonzWvLWLdOrAMRX7Rn5eOVuXU/Pq1oinAqQT+7r1+g+3q2tbh7zoalw4Q0vfc4War2mVmcU9HnAe4hyxtEb5rPk3XKw9P9o+cpwzprpZDZp9+bzoB0NVO/LKAW179DkV8wVp8JKIk+3HtP1bG7XgV0e0/DeviXdAYBj2FyOU9jBLUs20Wl3zhzerpqlO2bpf/7dfpi6naS2NuvoP3ibLsEpBdAp9eW37++cGXiUw7FrFYl9BB3++W0d3dsYzHDAKVs4RqIYoD9XYOlM3PvQfdWTLAR3dcViW
zah59QLNWjGP7UNE7vAL7dIY33bFvoL2PflLzbqsJbqhgHEQ5wqqtigPlclmNG/NIs1bsyjuUVDlTh84pmLv2K+tP3XgWETTAKVhW7tCqjnMQEiy02pl2bF3bHINNRFNA5SGlXOZEWUgLPOuXay9P9ouH/6E86BMbVYLbloW8VTA2IhzmRBlIEzTWhrVsmaROl/af+Fbx2ZMtU11mv/mpfEMB4yCOJcBYQbCdvmH2pRtqFHHM7sH3ijHJS8W1bR0jlZ99IbzXlkAhIDvyCkgykAyZLIZrfjAGi177yp17zgszxfVdPEcTWtpjHs0YETEeRKIMpBMNdNqeQUBEoE4TwBRBgBEgZdSlYgwAwCiwsp5HEQZABA14jwKogwAiAvb2iMgzACAOLFyHoIoAwBCQJxFlAEAYan6bW3CDAAITdWunIkyACBUVRdnogwACF1VbWsTZgBAElTFypkoAwCSJNVxJsoAgCRK7bY2YQYAJFXqVs5EGQCQdKmJM1EGAKRFKra1CTMAIE0SvXImygCANEpknIkyACDNEhXnJEfZ3WVmcY8BAEiAxMQ5iWHuO9GjvU/s0MFnXlehp18102u18K2XavEtK5RrqIl7PABAoIKPcxKjLEm9R89o85f/Xf2n+uSFoiSp/1Sf9v7bDh3evE/X/tEtBBoAMKJg45zUKJ+149ub1XeyRyqef7/ni+rpOqVd61/Rig+ujWc4AEDQgnwpVdLD3HeiR907Dl8Q5rO84Op4bo+K/YVoBwMAJEJJK2czu03S/5KUlfR1d//zYY9/UtLvSMpL6pT0UXffM9Fhkh7ls84cOaVMLqNCfpQ6D+o70av6OdMimgoAkBTjrpzNLCvpYUm3S1op6R4zWznssBcltbn7aknflfTlcg+aJLmGGnnRxzzGC65sfbDPKgAAYlTKtvZ1kna6+y5375P0HUl3DT3A3Z9y99ODN5+VtKi8YybLtIuaVDujfsxjZiybo5pptRFNBABIklLi3Cpp35Db7YP3jeZeST+cylBJZ2Za/v5rlKnJjvh4piarS//T6oinAgAkRVkvCDOzD0lqk/SVUR6/z8w2mdmmzs7Ocn7p4MxdtUArP3K9ahrrlK3LnftTN7tBq3/vLZqxdE7cIwIAAlXKk577JS0ecnvR4H3nMbNbJX1O0tvdvXekT+Tu6yStk6S2traxn5RNgebVCzX3ygU6urNTfcd7VT+nQTOWzeWdwgAAYyolzhslLTezZRqI8t2S/vPQA8xsjaRHJN3m7ofLPmWCWcY0e8W8uMcAACTIuNva7p6X9HFJj0vaLukxd99qZl8wszsHD/uKpEZJ/2xmL5nZ+opNDABAypX0Wh533yBpw7D7Hhzy8a1lngsAgKoV5DuEAQBQzYgzAACBIc4AAASGOAMAEBjiDABAYIgzAACBIc4AAASGOAMAEBjiDASg0JdX3/EeFQvFuEcBEICS3iEMQGWc6jiuXf/6irq2dQz8QpSMaf4NF2vZe1fx+76BKkacgZicaD+ql/7yxyr05iVJroFf1HbwZ7vUtbVD1/63Wwg0UKXY1gZi8to/PH8uzEN5wdV79Iz2/HB7DFMBCAFxBmJw6uBxnXnj1KiPe6Gogz/fLS+m/teeAxgBcQZi0NN1SpaxMY8p9hdU6C9ENBGAkBBnIAY1jXXSeKvijClbk41mIABBIc5ADJqWzFZurIu9TJq3dvG4q2sA6UScgRiYmVbcs1aZkVbGJuXqa7TsvSujHwxAEIgzEJO5qxboyo/dqPq505SpzSpbX6NMTUYzl83V2s+8U/Vzpsc9IoCY8DpnIEZzVs7X9Z+/XacOHlf/yV41NDeqfs60uMcCEDPiDMTMzNS4cGbcYwAICNvaAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCI
MwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABCYkuJsZreZ2Q4z22lmnx3h8Toz+6fBx58zs4vLPSgAANVi3DibWVbSw5Jul7RS0j1mtnLYYfdK6nb3yyT9paQvlXtQAACqRSkr5+sk7XT3Xe7eJ+k7ku4adsxdkv5h8OPvSrrFzKx8YwIAUD1KiXOrpH1DbrcP3jfiMe6el3RM0txyDAgAQLXJRfnFzOw+SfcN3jxpZjtGOKxZ0pHopqoqnNvK4dxWDue2cji3lfJf/mykc7u01L9eSpz3S1o85PaiwftGOqbdzHKSZkp6Y/gncvd1ktaN9cXMbJO7t5UwFyaIc1s5nNvK4dxWDue2cqZ6bkvZ1t4oabmZLTOzWkl3S1o/7Jj1kn578OPflPSku/tkhwIAoJqNu3J297yZfVzS45Kykh51961m9gVJm9x9vaS/k/QtM9spqUsDAQcAAJNQ0nPO7r5B0oZh9z045OMeSe8v00xjbntjSji3lcO5rRzObeVwbitnSufW2H0GACAsvH0nAACBiS3OvCVoZZRwXj9pZtvMbIuZ/buZlXxpP8Y/v0OOe5+ZuZlxJWwJSjmvZvaBwe/drWb2j1HPmFQl/ExYYmZPmdmLgz8X7ohjziQys0fN7LCZvTrK42Zmfz147reY2dqSP7m7R/5HAxeW/UrSJZJqJb0saeWwY35P0tcGP75b0j/FMWuS/pR4Xt8hadrgx7/LeS3v+R08rknS05KeldQW99yh/ynx+3a5pBclzR68PS/uuZPwp8Rzu07S7w5+vFLS63HPnZQ/kt4maa2kV0d5/A5JP5Rkkm6Q9FypnzuulTNvCVoZ455Xd3/K3U8P3nxWA69bR2lK+b6VpD/VwPvL90Q5XIKVcl4/Julhd++WJHc/HPG/vCd6AAACF0lEQVSMSVXKuXVJMwY/ninpQITzJZq7P62BVyiN5i5J3/QBz0qaZWYLSvncccWZtwStjFLO61D3auC/6lCacc/v4LbVYnf/QZSDJVwp37crJK0ws5+Z2bNmdltk0yVbKef285I+ZGbtGnhVzieiGa0qTPRn8jmRvn0nwmFmH5LUJuntcc+SFmaWkfRVSR+OeZQ0ymlga/tmDez2PG1mV7n70VinSod7JH3D3f/CzG7UwHtWXOnuxbgHq2ZxrZwn8pagGustQXGeUs6rzOxWSZ+TdKe790Y0WxqMd36bJF0p6cdm9roGnmNaz0Vh4yrl+7Zd0np373f33ZJ+oYFYY2ylnNt7JT0mSe7+jKR6DbznNqaupJ/JI4krzrwlaGWMe17NbI2kRzQQZp63m5gxz6+7H3P3Zne/2N0v1sBz+ne6+6Z4xk2MUn4efE8Dq2aZWbMGtrl3RTlkQpVybvdKukWSzOwKDcS5M9Ip02u9pN8avGr7BknH3P1gKX8xlm1t5y1BK6LE8/oVSY2S/nnw+rq97n5nbEMnSInnFxNU4nl9XNK7zWybpIKkz7g7O2njKPHcfkrS35rZf9XAxWEfZiFUGjP7tgb+o7F58Dn7P5FUI0nu/jUNPId/h6Sdkk5L+kjJn5v/DwAACAvvEAYAQGCIMwAAgSHOAAAEhjgDABAY4gwAQGCIMwAAgSHOAAAEhjgDABCY/w+LYT0taeIMDQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===============Collecting Polytopes============\n"
     ]
    }
   ],
   "source": [
    "# ==================================\n",
    "# Initialize Network\n",
    "# ==================================\n",
    "\n",
    "print('===============Initializing Network============')\n",
    "layer_sizes = [2, 8, 2, 2]\n",
    "network = PLNN(layer_sizes)\n",
    "net = network.net\n",
    "\n",
    "\n",
    "# ==================================\n",
    "# Train Network\n",
    "# ==================================\n",
    "\n",
    "print('===============Training Network============')\n",
    "opt = optim.Adam(net.parameters(), lr=1e-3)\n",
    "criterion = nn.CrossEntropyLoss()  # constructed once instead of every iteration\n",
    "for i in range(1000):\n",
    "    out = net(X)\n",
    "    loss = criterion(out, y)\n",
    "    err = (out.max(1)[1].data != y).float().mean()  # fraction misclassified\n",
    "    if i % 100 == 0:\n",
    "        print('error:', err)\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "# ==================================\n",
    "# Visualize:  classifier boundary\n",
    "# ==================================\n",
    "\n",
    "# Evaluate the network on a 100x100 grid over [0, 1]^2; the sign of\n",
    "# (output_0 - output_1) shades each grid point by its predicted class.\n",
    "XX, YY = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))\n",
    "X0 = torch.Tensor(np.stack([np.ravel(XX), np.ravel(YY)]).T)\n",
    "y0 = network(X0)\n",
    "# reshape replaces the deprecated non-inplace Tensor.resize\n",
    "ZZ = (y0[:, 0] - y0[:, 1]).reshape(100, 100).detach().numpy()\n",
    "\n",
    "_, ax = plt.subplots(figsize=(8, 8))\n",
    "ax.contourf(XX, YY, -ZZ, cmap=\"coolwarm\", levels=np.linspace(-1000, 1000, 3))\n",
    "ax.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y.numpy(), cmap=\"coolwarm\", s=70)\n",
    "ax.axis(\"equal\")\n",
    "ax.axis([0, 1, 0, 1])\n",
    "\n",
    "plt.show()\n",
    "\n",
    "# ==================================\n",
    "# Visualize: baseline classifier ReLu regions\n",
    "# ==================================\n",
    "\n",
    "print('===============Collecting Polytopes============')\n",
    "num_pts = 200\n",
    "xylim = 1.0\n",
    "\n",
    "unique_relu_configs_list, unique_bin_acts, _, _ = utils.get_unique_relu_configs(network, xylim, num_pts)\n",
    "print('number of polytopes:', len(unique_bin_acts))\n",
    "color_dict = utils.get_color_dictionary(unique_bin_acts)\n",
    "polytope_list = []\n",
    "\n",
    "for relu_configs, unique_act in zip(unique_relu_configs_list, unique_bin_acts):\n",
    "    polytope_dict = network.compute_polytope_config(relu_configs, True)\n",
    "    # the bare name from_polytope_dict is not in scope; it is reached via Polytope\n",
    "    polytope = Polytope.from_polytope_dict(polytope_dict)\n",
    "    polytope_list.append(polytope)\n",
    "colors = utils.get_spaced_colors(200)[0:len(polytope_list)]\n",
    "# NOTE(review): x_0 is a (2, 1) column while the grid batch X0 above is\n",
    "# (N, 2); this assumes PLNN accepts column-vector input -- confirm.\n",
    "x_0 = torch.Tensor([[0.3], [0.5]])\n",
    "\n",
    "print('===============Finding Classification Boundary Facets============')\n",
    "\n",
    "true_label = int(network(x_0).max(1)[1].item())  # what the classifier outputs\n",
    "\n",
    "adversarial_facets = []\n",
    "for polytope in polytope_list:\n",
    "    polytope_adv_constraints = network.make_adversarial_constraints(polytope.config,\n",
    "                                                                    true_label)\n",
    "    adversarial_facets.extend(polytope_adv_constraints)\n",
    "\n",
    "\n",
    "# ------------------------------\n",
    "# Plot Polytopes, boundary, and lp norm\n",
    "# ------------------------------\n",
    "\n",
    "ax = plt.gca()\n",
    "alpha = 0.6\n",
    "\n",
    "utils.plot_polytopes_2d(polytope_list, colors, alpha, xylim, ax)\n",
    "utils.plot_facets_2d(adversarial_facets, xylim=xylim, ax=ax, color='black', linestyle='dashed')\n",
    "plt.xlim(0.0, 1.0)\n",
    "plt.ylim(0.0, 1.0)\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# ==================================\n",
    "# Find Projections\n",
    "# ==================================\n",
    "\n",
    "lp_norm = 'l_2'\n",
    "ts = []\n",
    "pts = x\n",
    "cwd = os.getcwd()\n",
    "print(cwd)\n",
    "plot_dir = os.path.join(cwd, 'plots', 'incremental_geocert', '')  # keeps the trailing separator\n",
    "\n",
    "for pt in pts:\n",
    "    print('===============Finding Projection============')\n",
    "    print('lp_norm: ', lp_norm)\n",
    "    x_0 = torch.Tensor(pt.reshape([2, 1]))\n",
    "    print('from point: ')\n",
    "    print(x_0)\n",
    "\n",
    "    ax = plt.axes()\n",
    "    # NOTE(review): incremental_GeoCert is neither defined nor imported in this\n",
    "    # notebook (the imports cell provides the class IncrementalGeoCert); confirm\n",
    "    # the intended entry point before running.\n",
    "    t = incremental_GeoCert(lp_norm, network, x_0, ax, plot_dir)\n",
    "\n",
    "    print('the final projection value:', t)\n",
    "    ts.append(t)\n",
    "\n",
    "# ==================================\n",
    "# Visualize: incremental geocert projections\n",
    "# ==================================\n",
    "\n",
    "XX, YY = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))\n",
    "X0 = torch.Tensor(np.stack([np.ravel(XX), np.ravel(YY)]).T)\n",
    "y0 = network(X0)\n",
    "ZZ = (y0[:, 0] - y0[:, 1]).reshape(100, 100).detach().numpy()\n",
    "\n",
    "_, ax = plt.subplots(figsize=(8, 8))\n",
    "ax.contourf(XX, YY, -ZZ, cmap=\"coolwarm\", levels=np.linspace(-1000, 1000, 3))\n",
    "ax.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y.numpy(), cmap=\"coolwarm\", s=70)\n",
    "ax.axis(\"equal\")\n",
    "ax.axis([0, 1, 0, 1])\n",
    "\n",
    "# Do not unpack labels into `y` here: that would overwrite the training-label\n",
    "# tensor and break any re-run of this cell.\n",
    "for pt, t in zip(pts, ts):\n",
    "    if lp_norm == 'l_2':\n",
    "        utils.plot_l2_norm(pt, t, ax=ax)\n",
    "    elif lp_norm == 'l_inf':\n",
    "        utils.plot_linf_norm(pt, t, ax=ax)\n",
    "    else:\n",
    "        raise NotImplementedError(lp_norm)\n",
    "\n",
    "# label the average with the norm actually used (e.g. 'average_l2')\n",
    "print('average_%s:' % lp_norm.replace('_', ''), np.average(ts))\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regenerate the training set for the l_1-regularized experiment: same\n",
    "# seeds and point count as before, but a larger separation radius r.\n",
    "\n",
    "# ==================================\n",
    "# Generate Training Points\n",
    "# ==================================\n",
    "\n",
    "print('===============Generating Training Points============')\n",
    "m = 12    # number of training points\n",
    "r = 0.16  # accepted points are more than 2r apart in l_1 distance\n",
    "\n",
    "np.random.seed(3)\n",
    "x = [np.random.uniform(size=(2))]\n",
    "while len(x) < m:\n",
    "    candidate = np.random.uniform(size=(2))\n",
    "    # keep the candidate only if it is farther than 2r from every kept point\n",
    "    if all(np.abs(candidate - kept).sum() > 2 * r for kept in x):\n",
    "        x.append(candidate)\n",
    "epsilon = r / 2\n",
    "\n",
    "X = torch.Tensor(np.array(x))\n",
    "torch.manual_seed(1)\n",
    "y = (torch.rand(m) + 0.5).long()  # labels in {0, 1}, each with probability 1/2\n",
    "print('===============Points Generated============')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==================================\n",
    "# Initialize Network\n",
    "# ==================================\n",
    "\n",
    "print('===============Initializing Network============')\n",
    "# layer_sizes comes from the baseline cell, so both networks share an architecture.\n",
    "network = PLNN(layer_sizes)\n",
    "net = network.net\n",
    "\n",
    "\n",
    "# ==================================\n",
    "# Train Network\n",
    "# ==================================\n",
    "\n",
    "def l1_loss(net):\n",
    "    \"\"\"Return the summed l_1 norm of every parameter of `net` with dim() > 1.\n",
    "\n",
    "    Lower-dimensional parameters (e.g. bias vectors) are left unpenalized.\n",
    "    \"\"\"\n",
    "    return sum(param.norm(p=1) for param in net.parameters() if param.dim() > 1)\n",
    "\n",
    "\n",
    "print('===============Training Network with Regularization============')\n",
    "l1_scale = 1e-4  # weight of the l_1 penalty added to the cross-entropy loss\n",
    "opt = optim.Adam(net.parameters(), lr=1e-3)\n",
    "criterion = nn.CrossEntropyLoss()  # constructed once instead of every iteration\n",
    "for i in range(1000):\n",
    "    out = net(X)\n",
    "    # Sum out of place rather than with `+=` on the cross-entropy output, so no\n",
    "    # autograd result is modified in place.\n",
    "    loss = criterion(out, y) + l1_scale * l1_loss(net)\n",
    "\n",
    "    err = (out.max(1)[1].data != y).float().mean()  # fraction misclassified\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "print('error: ', err)\n",
    "\n",
    "\n",
    "# ==================================\n",
    "# Visualize:  regularized classifier boundary\n",
    "# ==================================\n",
    "\n",
    "# Evaluate the network on a 100x100 grid over [0, 1]^2; the sign of\n",
    "# (output_0 - output_1) shades each grid point by its predicted class.\n",
    "XX, YY = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))\n",
    "X0 = torch.Tensor(np.stack([np.ravel(XX), np.ravel(YY)]).T)\n",
    "y0 = network(X0)\n",
    "# reshape replaces the deprecated non-inplace Tensor.resize\n",
    "ZZ = (y0[:, 0] - y0[:, 1]).reshape(100, 100).detach().numpy()\n",
    "\n",
    "_, ax = plt.subplots(figsize=(8, 8))\n",
    "ax.contourf(XX, YY, -ZZ, cmap=\"coolwarm\", levels=np.linspace(-1000, 1000, 3))\n",
    "ax.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y.numpy(), cmap=\"coolwarm\", s=70)\n",
    "ax.axis(\"equal\")\n",
    "ax.axis([0, 1, 0, 1])\n",
    "\n",
    "plt.show()\n",
    "\n",
    "# ==================================\n",
    "# Visualize: regularized classifier ReLu regions\n",
    "# ==================================\n",
    "\n",
    "print('===============Collecting Polytopes============')\n",
    "num_pts = 200\n",
    "xylim = 1.0\n",
    "\n",
    "unique_relu_configs_list, unique_bin_acts, _, _ = utils.get_unique_relu_configs(network, xylim, num_pts)\n",
    "print('number of polytopes:', len(unique_bin_acts))\n",
    "color_dict = utils.get_color_dictionary(unique_bin_acts)\n",
    "polytope_list = []\n",
    "\n",
    "for relu_configs, unique_act in zip(unique_relu_configs_list, unique_bin_acts):\n",
    "    polytope_dict = network.compute_polytope_config(relu_configs, True)\n",
    "    polytope = Polytope.from_polytope_dict(polytope_dict)\n",
    "    polytope_list.append(polytope)\n",
    "colors = utils.get_spaced_colors(200)[0:len(polytope_list)]\n",
    "# NOTE(review): x_0 is a (2, 1) column while the grid batch X0 above is\n",
    "# (N, 2); this assumes PLNN accepts column-vector input -- confirm.\n",
    "x_0 = torch.Tensor([[0.3], [0.5]])\n",
    "\n",
    "print('===============Finding Classification Boundary Facets============')\n",
    "\n",
    "true_label = int(network(x_0).max(1)[1].item())  # what the classifier outputs\n",
    "\n",
    "adversarial_facets = []\n",
    "for polytope in polytope_list:\n",
    "    polytope_adv_constraints = network.make_adversarial_constraints(polytope.config,\n",
    "                                                                    true_label)\n",
    "    adversarial_facets.extend(polytope_adv_constraints)\n",
    "\n",
    "\n",
    "# ------------------------------\n",
    "# Plot Polytopes, boundary, and lp norm\n",
    "# ------------------------------\n",
    "\n",
    "ax = plt.gca()\n",
    "alpha = 0.6\n",
    "\n",
    "utils.plot_polytopes_2d(polytope_list, colors, alpha, xylim, ax)\n",
    "utils.plot_facets_2d(adversarial_facets, xylim=xylim, ax=ax, color='black', linestyle='dashed')\n",
    "plt.xlim(0.0, 1.0)\n",
    "plt.ylim(0.0, 1.0)\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# ==================================\n",
    "# Find Projections\n",
    "# ==================================\n",
    "\n",
    "lp_norm = 'l_2'\n",
    "ts = []\n",
    "pts = x\n",
    "cwd = os.getcwd()\n",
    "print(cwd)\n",
    "plot_dir = os.path.join(cwd, 'plots', 'incremental_geocert', '')  # keeps the trailing separator\n",
    "\n",
    "for pt in pts:\n",
    "    print('===============Finding Projection============')\n",
    "    print('lp_norm: ', lp_norm)\n",
    "    x_0 = torch.Tensor(pt.reshape([2, 1]))\n",
    "    print('from point: ')\n",
    "    print(x_0)\n",
    "\n",
    "    ax = plt.axes()\n",
    "    # NOTE(review): incremental_GeoCert is neither defined nor imported in this\n",
    "    # notebook (the imports cell provides the class IncrementalGeoCert); confirm\n",
    "    # the intended entry point before running.\n",
    "    t = incremental_GeoCert(lp_norm, network, x_0, ax, plot_dir)\n",
    "\n",
    "    print('the final projection value:', t)\n",
    "    ts.append(t)\n",
    "\n",
    "# ==================================\n",
    "# Visualize: incremental geocert projections\n",
    "# ==================================\n",
    "\n",
    "XX, YY = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))\n",
    "X0 = torch.Tensor(np.stack([np.ravel(XX), np.ravel(YY)]).T)\n",
    "y0 = network(X0)\n",
    "ZZ = (y0[:, 0] - y0[:, 1]).reshape(100, 100).detach().numpy()\n",
    "\n",
    "_, ax = plt.subplots(figsize=(8, 8))\n",
    "ax.contourf(XX, YY, -ZZ, cmap=\"coolwarm\", levels=np.linspace(-1000, 1000, 3))\n",
    "ax.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y.numpy(), cmap=\"coolwarm\", s=70)\n",
    "ax.axis(\"equal\")\n",
    "ax.axis([0, 1, 0, 1])\n",
    "\n",
    "# Do not unpack labels into `y` here: that would overwrite the training-label\n",
    "# tensor and break any re-run of this cell.\n",
    "for pt, t in zip(pts, ts):\n",
    "    if lp_norm == 'l_2':\n",
    "        utils.plot_l2_norm(pt, t, ax=ax)\n",
    "    elif lp_norm == 'l_inf':\n",
    "        utils.plot_linf_norm(pt, t, ax=ax)\n",
    "    else:\n",
    "        raise NotImplementedError(lp_norm)\n",
    "\n",
    "# label the average with the norm actually used (e.g. 'average_l2')\n",
    "print('average_%s:' % lp_norm.replace('_', ''), np.average(ts))\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3",
   "language": "python",
   "name": "py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
