{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==================================\n",
    "# Experiment 1\n",
    "# ==================================\n",
    "\n",
    "# demonstration to find the maximal l_p ball at a point x_0, within which,\n",
    "# the class label of a random classifier is equal to C(x_0)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "sys.path.append('../mister_ed')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =====================\n",
    "# Imports\n",
    "# =====================\n",
    "\n",
    "from geocert_oop import BatchGeocert, IncrementalGeoCert\n",
    "from plnn import PLNN\n",
    "from _polytope_ import Polytope\n",
    "import utilities as utils\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets, transforms\n",
    "# from convex_adversarial import robust_loss\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===============Initializing Network============\n",
      "Sequential(\n",
      "  (1): Linear(in_features=2, out_features=8, bias=True)\n",
      "  (2): ReLU()\n",
      "  (3): Linear(in_features=8, out_features=8, bias=True)\n",
      "  (4): ReLU()\n",
      "  (5): Linear(in_features=8, out_features=2, bias=True)\n",
      ")\n",
      "===============Finding Projection============\n",
      "lp_norm:  l_2\n",
      "from point: \n",
      "tensor([[0.],\n",
      "        [0.]])\n",
      "/Users/jordanm/grad/plswork/geometric-certificates/experiments\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'IncrementalGeocert' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-12-806bdbb4c6a6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     29\u001b[0m \u001b[0mplot_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcwd\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'/plots/incremental_geocert/'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m t = IncrementalGeocert(network, verbose=True, display=True, \n\u001b[0m\u001b[1;32m     32\u001b[0m                        save_dir=plot_dir, ax=ax)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'IncrementalGeocert' is not defined"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAADU9JREFUeJzt3GGI5Hd9x/H3xztTaYym9FaQu9Ok9NJ42ELSJU0Raoq2XPLg7oFF7iBYJXhgGylVhBRLlPjIhloQrtWTilXQGH0gC57cA40ExAu3ITV4FyLb03oXhawxzZOgMe23D2bSna53mX92Z3cv+32/4GD+//ntzJcfe++dndmZVBWSpO3vFVs9gCRpcxh8SWrC4EtSEwZfkpow+JLUhMGXpCamBj/JZ5M8meT7l7g+ST6ZZCnJo0lunP2YkqT1GvII/3PAgRe5/lZg3/jfUeBf1j+WJGnWpga/qh4Efv4iSw4Bn6+RU8DVSV4/qwElSbOxcwa3sRs4P3F8YXzup6sXJjnK6LcArrzyyj+8/vrrZ3D3ktTHww8//LOqmlvL184i+INV1XHgOMD8/HwtLi5u5t1L0stekv9c69fO4q90ngD2ThzvGZ+TJF1GZhH8BeBd47/WuRl4pqp+7ekcSdLWmvqUTpIvAbcAu5JcAD4CvBKgqj4FnABuA5aAZ4H3bNSwkqS1mxr8qjoy5foC/npmE0mSNoTvtJWkJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJamJQcFPciDJ40mWktx1kevfkOSBJI8keTTJbbMfVZK0HlODn2QHcAy4FdgPHEmyf9Wyvwfur6obgMPAP896UEnS+gx5hH8TsFRV56rqOeA+4NCqNQW8Znz5tcBPZjeiJGkWhgR/N3B+4vjC+NykjwK3J7kAnADef7EbSnI0yWKSxeXl5TWMK0laq1m9aHsE+FxV7QFuA76Q5Nduu6qOV9V8Vc3Pzc3N6K4lSUMMCf4TwN6J4z3jc5PuAO4HqKrvAq8Cds1iQEnSbAwJ/mlgX5Jrk1zB6EXZhVVrfgy8DSDJmxgF3+dsJOkyMjX4VfU8cCdwEniM0V/jnElyT5KD42UfBN6b5HvAl4B3V1Vt1NCSpJdu55BFVXWC0Yuxk+funrh8FnjLbEeTJM2S77SVpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDUxKPhJDiR5PMlSkrsuseadSc4mOZPki7MdU5K0XjunLUiyAzgG/BlwATidZKGqzk6s2Qf8HfCWqno6yes2amBJ0toMeYR/E7BUVeeq6jngPuDQqjXvBY5V1dMAVfXkbMeUJK3XkODvBs5PHF8Yn5t0HXBdku8kOZXkwMVuKMnRJItJFpeXl9c2sSRpTWb1ou1OYB9wC3AE+EySq1cvqqrjVTVfVfNzc3MzumtJ0hBDgv8EsHfieM/43KQLwEJV/aqqfgj8gNEPAEnSZWJI8E8D+5Jcm+QK4DCwsGrN1xg9uifJLkZP8Zyb4ZySpHWaGvyqeh64EzgJPAbcX1VnktyT5OB42UngqSRngQeAD1XVUxs1tCTppUtVbckdz8/P1+Li4pbctyS9XCV5uKrm1/K1vtNWkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgYFP8mBJI8nWUpy14use0eSSjI/uxElSbMwNfhJdgDHgFuB/cCRJPsvsu4q4G+Ah2Y9pCRp/YY8wr8JWKqqc1X1HHAfcOgi6z4GfBz4xQznkyTNyJDg7wbOTxxfGJ/7P0luBPZW1ddf7IaSHE2ymGRxeXn5JQ8rSVq7db9om+QVwCeAD05bW1XHq2q+qubn5ubWe9eSpJdgSPCfAPZOHO8Zn3vBVcCbgW8n+RFwM7DgC7eSdHkZEvzTwL4k1ya5AjgMLLxwZVU9U1W7quqaqroGOAUcrKrFDZlYkrQmU4NfVc8DdwIngceA+6vqTJJ7khzc6AElSbOxc8iiqjoBnFh17u5LrL1l/WNJkmbNd9pKUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpoYFPwkB5I8nmQpyV0Xuf4DSc4meTTJN5O8cfajSpLWY2rwk+wAjgG3AvuBI0n2r1r2CDBfVX8AfBX4h1kPKklanyGP8G8ClqrqXFU9B9wHHJpcUFUPVNWz48NTwJ7ZjilJWq8hwd8NnJ84vjA+dyl3AN+42BVJjiZZTLK4vLw8fEpJ0rrN9EXbJLcD88C9F7u+qo5X1XxVzc/Nzc3yriVJU+wcsOYJYO/E8Z7xuf8nyduBDwNvrapfzmY8SdKsDHmEfxrYl+TaJFcAh4GFyQVJbgA+DRysqidnP6Ykab2mBr+qngfuBE4CjwH3V9WZJPckOThedi/wauArSf49ycIlbk6StEWGPKVDVZ0ATqw6d/fE5bfPeC5J0oz5TltJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaMPiS1ITBl6QmDL4kNWHwJakJgy9JTRh8SWrC4EtSEwZfkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJasLgS1ITBl+SmjD4ktSEwZekJgy+JDVh8CWpCYMvSU0YfElqwuBLUhMGX5KaGBT8JAeSPJ5kKcldF7n+N5J8eXz9Q0mumfWgkqT1mRr8JDuAY8CtwH7gSJL9q5bdATxdVb8L/BPw8VkPKklanyGP8G8ClqrqXFU9B9wHHFq15hDwb+PLXwXeliSzG1OStF47B6zZDZyfOL4A/NGl1lTV80meAX4b+NnkoiRHgaPjw18m+f5aht6GdrFqrxpzL1a4FyvcixW/t9YvHBL8mamq48BxgCSLVTW/mfd/uXIvVrgXK9yLFe7FiiSLa/3aIU/pPAHsnTjeMz530TVJdgKvBZ5a61CSpNkbEvzTwL4k1ya5AjgMLKxaswD85fjyXwDfqqqa3ZiSpPWa+pTO+Dn5O4GTwA7gs1V1Jsk9wGJVLQD/CnwhyRLwc0Y/FKY5vo65txv3YoV7scK9WOFerFjzXsQH4pLUg++0laQmDL4kNbHhwfdjGVYM2IsPJDmb5NEk30zyxq2YczNM24uJde9IUkm27Z/kDdmLJO8cf2+cSfLFzZ5xswz4P/KGJA8keWT8/+S2rZhzoyX5bJInL/VepYx8crxPjya5cdANV9WG/WP0Iu9/AL8DXAF8D9i/as1fAZ8aXz4MfHkjZ9qqfwP34k+B3xxffl/nvRivuwp4EDgFzG/13Fv4fbEPeAT4rfHx67Z67i3ci+PA+8aX9wM/2uq5N2gv/gS4Efj+Ja6/DfgGEOBm4KEht7vRj/D9WIYVU/eiqh6oqmfHh6cYvedhOxryfQHwMUafy/SLzRxukw3Zi/cCx6rqaYCqenKTZ9wsQ/aigNeML78W+MkmzrdpqupBRn/xeCmHgM/XyCng6iSvn3a7Gx38i30sw+5Lramq54EXPpZhuxmyF5PuYPQTfDuauhfjX1H3VtXXN3OwLTDk++I64Lok30lyKsmBTZtucw3Zi48Ctye5AJwA3r85o112XmpPgE3+aAUNk+R2YB5461bPshWSvAL4BPDuLR7lcrGT0dM6tzD6re/BJL9fVf+1pVNtjSPA56rqH5P8MaP3/7y5qv5nqwd7OdjoR/h+LMOKIXtBkrcDHwYOVtUvN2m2zTZtL64C3gx8O8mPGD1HubBNX7gd8n1xAVioql9V1Q+BHzD6AbDdDNmLO4D7Aarqu8CrGH2wWjeDerLaRgffj2VYMXUvktwAfJpR7Lfr87QwZS+q6pmq2lVV11TVNYxezzhYVWv+0KjL2JD/I19j9OieJLsYPcVzbjOH3CRD9uLHwNsAkryJUfCXN3XKy8MC8K7xX+vcDDxTVT+d9kUb+pRObdzHMrzsDNyLe4FXA18Zv27946o6uGVDb5CBe9HCwL04Cfx5krPAfwMfqqpt91vwwL34IPCZJH/L6AXcd2/HB4hJvsToh/yu8esVHwFeCVBVn2L0+sVtwBLwLPCeQbe7DfdKknQRvtNWkpow+JLUhMGXpCYMviQ1YfAlqQmDL0lNGHxJauJ/Acz2XLpusNoKAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fun incremental 'GeoCert' from the point [0,0]. Network is simple ReLu net with\n",
    "# Gaussian random weights. Finds maximal l_p ball within which the class label\n",
    "# remains the same.\n",
    "\n",
    "\n",
    "# ==================================\n",
    "# Initialize Network\n",
    "# ==================================\n",
    "\n",
    "print('===============Initializing Network============')\n",
    "layer_sizes = [2, 8, 8, 2]\n",
    "network = PLNN(layer_sizes)\n",
    "\n",
    "# ==================================\n",
    "# Find Projection\n",
    "# ==================================\n",
    "\n",
    "lp_norm = 'l_2'\n",
    "\n",
    "print('===============Finding Projection============')\n",
    "print('lp_norm: ', lp_norm)\n",
    "x_0 = torch.Tensor([[0.0], [0.0]])\n",
    "print('from point: ')\n",
    "print(x_0)\n",
    "\n",
    "ax = plt.axes()\n",
    "cwd = os.getcwd()\n",
    "print(cwd)\n",
    "plot_dir = cwd + '/plots/incremental_geocert/'\n",
    "\n",
    "t = IncrementalGeoCert(network, verbose=True, display=True, \n",
    "                       save_dir=plot_dir, ax=ax)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('plots can be found at:', plot_dir)\n",
    "print('the final projection value:', t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
