{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): matlab.engine is not referenced directly in the visible cells; it is\n",
    "# presumably needed by other_methods.lip_sdp (LipSDPCustom) - confirm. It is imported\n",
    "# first, presumably so MATLAB loads its libraries before torch does - confirm.\n",
    "import matlab.engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Add the parent directory ('..') to sys.path so the project-local modules below import.\n",
    "sys.path.append('..')\n",
    "import lipMIP as lm\n",
    "import neural_nets.data_loaders as dl\n",
    "import neural_nets.train as nn_train\n",
    "from hyperbox import Hyperbox\n",
    "from pre_activation_bounds import PreactivationBounds\n",
    "from relu_nets import ReLUNet\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from other_methods.lip_sdp import LipSDPCustom\n",
    "from other_methods.fast_lip import FastLip\n",
    "from other_methods.naive_methods import NaiveUB, RandomLB\n",
    "from other_methods.clever import CLEVER\n",
    "from other_methods.seq_lip import SeqLip\n",
    "from other_methods.lip_lp import LipLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAHSCAYAAADfUaMwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXRb5Z3/8c9Xki0v2WNnD3EIYQmErWYrW8pWlgIdtgkMS1uWKW1oh25D2ym0dJnpMN2YgRZaWiilMECBCZBAKWUp/RGKQ0ggCdlD9sTZ43iRLT2/P+yCEyuJEtv6SvL7dY7PkZ57sT7nQdHH97lXkoUQBAAA/ES8AwAA0NNRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAs5jXA1dUVISqqiqvhwcAIKumT5++PoRQmW6bWxlXVVWppqbG6+EBAMgqM3t/V9tYpgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMBZzDtAT7ZqTYMefmK55izYpjGjynXFxSNVNbLcOxYAIMsoYyeLltbpxq+9raZEUsmktGBxnf78Wq1+9J3DdcShfb3jAQCyiGVqJ3f+apHqG1qLWJJSKamxKaUf/Xy+bzAAQNZRxk7embs17fiSZfVKNKeynAYA4IkydtKrLJp2vLgooljUspwGAOCJMnZyyfnDFY/vOP3FxRGdd+YQRSKUMQD0JJSxk3+6eD+ddeogFReZysuiKi4yffSYAZp07RjvaACALLMQgssDV1dXh5qaGpfHziUbNyW0bGW9hg8tVeXAuHccAEA3MbPpIYTqdNt4a5OzAf2LNaB/sXcMAIAjlqkBAHDGkTEA5LD1G5s0c/YW9S6P6egj+vNuiwJFGQNAjrrvoaV66A/LFIu1LmLG4xH97HtHaP9RfGxuoWGZGgBy0JszNurhJ5cr0RxU35BUfUNSmzY368u3zVIq5XPhLboPZQwAOejJKavU2NTx0/i21yc1Z/42h0ToTpQxAOSg7Q3JtONmUkNj+m3IX5QxAOSg006qVEm840t0Mhk0/uA+DonQnShjAMhB55w+RPtXlau0pPVlOhJpvYDrK58bq5KS9J9tj/zF1dQAkIOKiyK6+z+O1J9fq9Vrb2xQ/75FuuDsoRpT1cs7GroBZQwAOSoWi+isCYN11oTB3lHQzVimBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4yKmMzO9vM5pnZQjO7Jc32/czsJTObYWazzOzcro8KAEBh2mMZm1lU0l2SzpE0TtLlZjZup93+TdKjIYSjJE2UdHdXBwUAoFBlcmR8rKSFIYTFIYSEpEckXbjTPkFSn7bbfSWt6rqIAAAUtkzKeLik5e3ur2gba+/bkq40sxWSpki6Kd0vMrMbzKzGzGpqa2v3IS4AAIWnqy7gulzS/SGEEZLOlfSgmXX43SGEe0MI1SGE6srKyi56aAAA8lsmZbxS0sh290e0jbV3raRHJSmE8LqkEkkVXREQAIBCl0kZvylprJmNNrNitV6gNXmnfZZJOl2SzOwQtZYx69AAAGRgj2UcQmiRNEnS85LmqvWq6dlmdruZXdC225clXW9mMyU9LOlTIYTQXaEBACgksUx2CiFMUeuFWe3Hbm13e46kE7s2GgAAPQOfwAUAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADij
jAEAcEYZAwDgjDIGULA2bEpo9dpGhRC8owC7FfMOAABdbc26Rt32n3M0f3GdImbq369It375EB0+rq93NCAtjowBFJRkMmjS19/W3AXb1Nwc1JRIac26Jn35tlmq3dDkHQ9IizIGUFDemrVJW7e1KJXacbwlGfTMH1f7hAL2gDIGUFBqNyTSniNubg5atabRIRGwZ5QxgIJyyIG9OxwVS1JpSURHje+X/UBABihjAAVl9H7l+ugxAxSPf/jyVhQzVQyI6/STKx2TAbvG1dQACs63vzpOT0xZqf+bulpNiZQ+9tEKXXXZKMXjUe9oQFqUMYCCE42aLj1/hC49f4R3FCAjLFMDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAICCldi4Weuef1Vb3pqd01+lyfuMAQAFacG//1wLf/BzRYqLFJJJlYwcpuOm3KfSkUO9o3XAkTEAoOCsm/qKFv3HPUo1Nqlla52S2xtUv2CJaj75We9oaVHGAICCs+TOB5Ssb9hhLCRT2r5wqermLXZKtWuUMQCg4CQ2bk47brGYmjdtzXKaPaOMAQAFZ8iFZyhSEu8wHlJBfY48xCHR7lHGAICCU/X5q1QyfLAipSWtA5GIomUlOuxn31I0TUl742pqAEDBKerbWye/+aSW3feY1k15WSVDB6lq0lXqd8zh3tHSMq/3XVVXV4eamhqXxwYAINvMbHoIoTrdNpapAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMBZzDsAACB/hVRKG16aprr3FqvXuDEaeOpxsgjHeXuLMgYA7JPExs2adtqVqn9/lUJLiywWVdnokTrhz79TUb8+3vHyCn++AAD2yeybv6e6BUuVrNuuVGOTknX1qpu3WHO+/APvaHmHMgYA7LUQgtb84XmFRPOO44lmrXpsilOq/EUZI2MrVjVo3sJtam5OeUcBkANCMpl+vIXXiL3FOWPs0draRt3yvdlatqJe0ajJTPrXSQfqtJMHeUcD4MTMVHnWyap9/lWFZLvyjUY06OyT/YLlKY6MsVshBN38rVlatLROTYmU6huS2l6f1Pd/Nk8LltR5xwPg6NA7b1XRwP6KlpdJkqLlpYpXDtChP7vVOVn+4cgYuzV3wTbVbkgotdOqU3NzSk88s1L/etNBPsEAuCsbNVwfm/eCVj3yrLa+O099xh+sYRPPU6ytnJE5yhi7tXFzQuneMphKSWvXN2U/EICcEutVrv2uu8w7Rt5jmRq7deiBfZRIc8FWvDiiEz4ywCERABQeyhi71b9fsSZ+coRKSj58qhQXR1QxsFjnnTnUMRkAFA6WqbFHN1w1WgeP7aPHJ6/Qtu0tOvWECl16wQiVlUa9owFAQaCMsUdmplNPqNCpJ1R4RwGAgsQyNQAAzjIqYzM728zmmdlCM7tlF/tcZmZzzGy2mf2+a2MCwI5attfr/V8+ohlXfknzbvupGpav9o4E7LM9LlObWVTSXZLOlLRC0ptmNjmEMKfdPmMlfV3SiSGETWbWoz6aKdnQqA2v/k0KQQNPPU7R0hLvSEBBS2zcrNeOv1iJdRuU3N4gKy7Wkp/dr2Of+ZUGnFTtHQ/Ya5mcMz5W0sIQwmJJMrNHJF0oaU67fa6XdFcIYZMkhRDWdXXQXLXuuVf01hU3y8wktX6359G//6kGnXOqczKgcC34/t1qXLn2gy8pCImEkgnp7U9/TR+b/+IH/x6BfJHJMvVwScvb3V/RNtbegZIONLO/mtk0Mzs73S8ysxvMrMbMampra/ctcQ5pWrdBb/3jF5Tctl0tW+vUsrVOybp6TZ/4BTWt2+AdDyhYa556ocO3BUlS09oNamS5Gnmoqy7gikkaK2mCpMsl/dLM+u28Uwjh3hBCdQihurKysose2s/qx59TCGk2
hKDVj03Neh6gp9jlqaBkShFOEyEPZVLGKyWNbHd/RNtYeyskTQ4hNIcQlkiar9ZyLmgt2+qUau7413kq0azmrdscEgE9w6jPXqFIWemOg9Go+h4zXvFKPhkO+SeTMn5T0lgzG21mxZImSpq80z5PqfWoWGZWodZl68VdmDMnVZ55kqLFRR3Go/G4Bp3FV4gB3aXqxis0+BMfU6S0RNFeZYr2KlfZ6BE6+qGfeEcD9skeL+AKIbSY2SRJz0uKSvp1CGG2md0uqSaEMLlt21lmNkdSUtJXQwgFf9K079GHauhl52n1Y1OV3F4vqfUrxIZeco76fuQw53RA4bJoVEc/9BPVzVuszTXvqHTkUA04qVqW7ltNgDxgIe1Jz+5XXV0dampqXB67K4UQtG7qK1rx2yclSSOu+qQGnTuBqzkBADsws+khhLTvvePjMDvJzDT43AkafO4E7ygAgDzFmg4AAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADjjs6nRo4RUSuv/9Fdtmva2SoYN1tDLzlVRn17esQD0cJQxeoxkQ6OmnXWNtr07X8m6ekXKSjX3lv/U8X96UH2PPMQ7HoAejGVq9BhL7nxAW2e+p2Rd63dPp+ob1LJlm2Zc8S/y+ipRAJAoY/QgKx58UqmGxg7jDSvWqGHpCodEANCKMkbPYZZ+PAQpwj8FAH54BUKPMfKaixUpLekwXlo1QmWjhjskAoBWlDF6jKqbrla/Yw5XtLxMikUV7VWmov599ZGHf+odDUAPx9XU6DGi8WId/6ffauNf3mx7a9MgDfmHsxQrL/OOBqCHo4zRo5iZBp5yrAaecqx3FAD4AMvUAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDJG3kts3Kx3/+W7+tPIk/Ti/hO04Pt3KdmU8I4FABmLeQcAOiPZ2KS/fvRSNSxfrZBoliQt/OE92viXGh333G+c0wFAZjgyRl5b/fhUNa1Z/0ERS1KqoUmbps3Q5jdnOSYDgMxRxshrm16foeT2+g7jIRW0Zfq7DokAYO9Rxshr5QeMUqQ03mHcYlGVVg13SAQAe48yRl4bcdUnFSkq2mHMolEVD+inyjNPckoF7LtUIqFlv35M0z7+KdVc/Dmte/5V70jIAi7gQl4rrhig41/8nWZ++muqm79EktT/+CN15AP/JYtGndMBeyfV0qJpZ16jrTPnKrm9QZK0/sX/p6rPX6mDv/8V53ToTpQx8l7fIw/RKTOeVlPtRkWKYirq18c7ErBP1j71grbOfO+DIpak5PYGLbnztxp145UqHTHEMR26E8vUKBjxygEUMfLa2mdeSntBosWi2vDyNIdEyBbKGAByRHHlACnW8fSKRSIqGtDPIRGyhTIGgByx37WXdrggUZKsqEiVZ57okAjZQhkDQI7odfAYHX7v9xQtL1WsTy9Fe5crPnSQjn/+/rQljcLBBVwAkEOGTzxfQy44Q5umva1oaYn6HXeELMJxU6GjjAEgx0TLSlVx2gneMZBF/LkFAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgLOMytjMzjazeWa20Mxu2c1+F5tZMLPqrosIAEBh22MZm1lU0l2SzpE0TtLlZjYuzX69JX1R0htdHRIAgEKWyZHxsZIWhhAWhxASkh6RdGGa/b4r6YeSGrswHwAABS+TMh4uaXm7+yvaxj5gZkdLGhlCeLYLswEA0CN0+gIuM4tI+rGkL2ew7w1m
VmNmNbW1tZ19aAAACkImZbxS0sh290e0jf1db0mHSXrZzJZKOl7S5HQXcYUQ7g0hVIcQqisrK/c9NQAABSSTMn5T0lgzG21mxZImSpr8940hhC0hhIoQQlUIoUrSNEkXhBBquiUxAAAFZo9lHEJokTRJ0vOS5kp6NIQw28xuN7MLujsgAACFLpbJTiGEKZKm7DR26y72ndD5WAAA9Bx8AhcAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgLOMvs84XzSuqdWaJ/+oVFOzBp83QeVjq7wjAQCwRwVTxqsefVYzr/u6JJNSKc371o815ivX6cDbvuAdDQCA3SqIZerExs2aee3XlWpoUqqhUammhFKNTVr04/u0Zfq73vEAANitgijjdVNelsWiHcZTjQmtfPhph0QAAGSuIMpYqZB+PASFVCq7WQAA2EsFUcaV55yq0JLsMB4tLdGwy85zSAQAQOYKoozjlQN02J23KlIalxUVSZGIImUlGnndZep//JHe8QAA2K2CuZp65Kcv0cAJx2n141OVbExo8Pmnq++Rh3jHyln1DUk9NXWVXp22Xv36FOmS84er+oj+3rEAoEeyEHZxvrWbVVdXh5qaGpfH7unqG5K67ubpWlvbpKZE6zn1knhEn7miSldcNNI5HQAUJjObHkKoTretIJapsXee+ePqHYpYkhqbUvrVQ0u1ra7FMRkA9EyUcQ/02t827FDEf1cUM82dv9UhEQD0bJRxDzSgf7HMOo6nUkF9+hRlPxAA9HCUcQ90ySeGKV684//6SESqGBDXQWN6OaUCgJ6LMu6BDju4ryZdN0Yl8YjKy6IqiUc0cliZfnz74bJ0h8wAgG5VMG9twt755NnD9PEJg/Xegm3q3SumMVXlFDEgKdGc0sZNCfXvV9xhBQnoLpRxD1ZaEtVR4/t5xwByQghBv310mX73+HL9/S2fl3xiuG64erQiEf5QRfeijAFA0lNTV+nBx5apsenDdxo8/sxKlZRE9amJoxyToSdgDQYAJD342PIdilhqff/9w09+eKQMdBfKGAAkbdqcSDte35BUMkkZo3tRxgAgaf+q8rTjQweXKBbjpRLdi2cYAEi66doxisd3fEmMxyP64vUHOCVCT0IZA4CkIw/rpzu/f4SOObKfBvQv1hGH9tUdt43XiccO9I6GHoCrqQGgzaEH9dFPvnuEdwz0QBwZAwDgjDIGAMAZZQzkkJBMKtnQ6B0DQJZRxkAOSDY26d1J39Zz/Y7Sc/2P0svjz9GGV//mHQtAllDGQA54+5qvavkDTyjV2CQlU9r+3mK9ef4N2jZ7gXc0AFlAGQPOGleu1bopL7cWcTuppiYtuuOXTqkAZBNlDDirX7JckXhxh/GQTGnbHI6MgZ6AMgaclR+0v1JNTR3GLRZTv2MOd0gEINsoY8BZvHKARlx9kaJlpTuMR0rjGvOV65xSAcgmyhjIAYfdeavG3jpJ8aGDFC0rVcUZJ+rEVx9R2eiR3tEAZIF5fU9ndXV1qKmpcXlsAACyzcymhxCq023jyBgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBlljC6zeUuz1qxrlNd3ZANAvop5B0D+27Apoe/cMVfvzN2iSMTUr2+RvnnzwTp6fD/vaACQFzgyRqeEEPTFb87UzNmb1dwS1JRIaW1tk772nXe0ak2DdzwAyAuUMTpl9rxtWlvbqGRqx/GWZNATU1b5hAKAPEMZo1PW1jbKzDqMt7QErVzNkTEAZIIyRqccMra3WpIdL9gqiUc4Zwx0sYYVazTj6q/oj5XH6E/7naz53/0fpRIJ71joApQxOmXYkFKddlKlSuIfPpViMVOf3kU674whjsmAwtK8aYteO+4i
rXp0ipo3b1XT6nVadMcv9dblN3tHQxfgamp02je+eJDGHdhbf3h2lRoakjrl+IG6ZuIolZXx9AK6yrLfPK6WbdulZPKDsVRDo2pf+Ivq3lukXgePcUyHzuLVEp0WiZguOm+4LjpvuHcUoGBten2GUg2NHcYtFtPWWfMo4zzHMjUA5IHe4w5QJF7ccUMqqGz0iOwHQpeijAEgD+x3/URZ0Y6LmVZUpPKDRqtv9XinVOgqlDEA5IHSEUN0/Au/Ve/xB8liMVlxkQadN0HHTf112rcXIr9wzhgA8kS/6vE65a3Jat5ap0hxkaIlce9I6CKUMQDkmaI+vbwjoIuxTA0AgDPKGAAAZxmVsZmdbWbzzGyhmd2SZvuXzGyOmc0ysxfNbFTXRwUAoDDtsYzNLCrpLknnSBon6XIzG7fTbjMkVYcQDpf0uKT/7OqgAAAUqkyOjI+VtDCEsDiEkJD0iKQL2+8QQngphFDfdneaJN6BDgBAhjIp4+GSlre7v6JtbFeulTS1M6EAAOhJuvStTWZ2paRqSafuYvsNkm6QpP32268rHxoAgLyVyZHxSkkj290f0Ta2AzM7Q9I3JV0QQmhK94tCCPeGEKpDCNWVlZX7khcAgIKTSRm/KWmsmY02s2JJEyVNbr+DmR0l6R61FvG6ro8JAEDh2mMZhxBaJE2S9LykuZIeDSHMNrPbzeyCtt3ukNRL0mNm9raZTd7FrwMAADvJ6JxxCGGKpCk7jd3a7vYZXZwLAIAeg0/gAgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOMvoKxQBINdsq2vRz+9fpBf/UqsQpFNOqNCkz4xRv75F3tGAvUYZA8g7yWTQ52+ZoWUrG9TSEiRJL7yyTrPmbNFDdx+joiIW/ZBfeMYCyDt/m7FRq9c1fVDEUmtBb9rcrFenrXdMBuwbyhhA3ln8/nYlEskO4w2NSS1aut0hEdA5lDGAvDNiWJnixdEO46UlEY0cXuqQCOgcyhhA3jnxmAHq3SumSLtXsEhEKi2J6rQTK/2CAfuIMgaQd2KxiO654ygdd/QARSOmSET6yOH9de+PjlY83vGIGcjEytUNeuOtjVqzrjHrj83V1ADyUsXAuO64bbxakkEKQbEYxxbYN01NSf3bf8zR9FmbVVxkSiRSOvmECn3r5oOz9rzi2Qsgr8WiRhGjU/77vkWaPnOTEomU6rYnlWgOem3aBj3w6LKsZeAZDADosVKpoKkvrlWiOeww3pRI6YlnV2UtB2UMAOixUqmgRHMq7bb6hpas5aCMAQA9ViwW0dj9e3UYN5OOOqxf1nJQxgCAHu2rnxur0pKIYm0X4hfFTGWlUX3h+jFZy8DV1ACAHu2QA/vo/v+u1qP/t0KLlmzXwQf21mUXjNCginjWMlDGAIAeb/iQUt38z2PdHp9lagAAnFHGAAA4o4wBAHBGGSPrQiqlxT/9jV7cf4Ker6hWzcWf1/YFS71jAYAbLuBC1r1703e08ndPKVnf+mHsa59+URteeUOnvP2MSkcMcU4HANnHkTGyqmnteq144IkPiliSFIKSDY1acuf9brkAwBNljKzaNmeBIiUd37sXEs3a9PoMh0QA4I8yRlaVVY1QqinRYdyiUfU6OHufdgMAuYQyRlaVjR6pgROO63B0HIkXaf+bP+OUCgB8UcbIuqMf+ZmGXnqOIvFiWVFMZWNHqfr/7lHvcQd4RwMAFxZC2PNe3aC6ujrU1NS4PDZyQyqRULKhSUV9e3tHAYBuZ2bTQwjV6bbx1ia4iRQXK1Jc7B0DANyxTA0AgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHDGVyh2o/eX1+vBx5dpweI6jakq15WX7Kf9R5V7xwIA5BjKuJvMnb9VN31j
phLNKaVS0pJl2/Xq6+v1k+8ervGH9PWOBwDIISxTd5Of3rtQjU2tRSxJqZTU2JTST36x0DcYACDncGTcTd5bUJd2fMGSOoUQZGZZTgSk95dp6/Wrh5ZqzbpG7T+qXJ+9Zn8dcSirN0A2cWTcTcrLo2nHS0uiFDFyxtQX1+jb/zVXi5Zu1/b6pN6Zu1VfunWWZs7e7B0N2KXNNe9o/u3/rUU/vk8Ny1d7x+kSlHE3ueT84YrHd5zeeHFEF503zCkRsKMQgu6+f7GamlI7jDclUrr7N4udUgG7FkLQrBu/pWmnX6UF379b87/1U7186Me16tFnvaN1GmXcTa65bJQ+PmGwiotM5WVRFReZTju5Utf9U5V3NECSVN+Q1LZtLWm3LX6/PstpsCchBG145Q0tvet3qv3jXxRSqT3/RwVmw59f16qHn1ayvkFKpZRKJJRqaNLM676h5i3bvON1CueMu0k0avrapAN1w1WjtXJNg4YNKVH/vsXesYAPlMSjKi6OqKUh2WHboIq4QyLsSsu2Or1+xtXaPn+JQktSFoupZGilTnjlYcUrB3jHy5qVDz+t5PaGDuORWFTrX3hNQy85xyFV1+DIuJv161ukQw/qQxEj50Sjpon/MEIlO51OKYlH9JkrRjmlQjpzv/5f2jZ7vpJ19Uo1NilZt131S1fonc/d6h0tqywaldJccxNkUiS/6yy/0wPolE/94yhd/g8jVVYaVVGRqW/vmG66boxOP3mQdzS0s+rhpxWamncYC80tWvfMSwrJjisbhWr4lRcqWlbScUMyqcqzTsp+oC7EMjXQg0Uipmv/qUrXTByl+voW9SqPKRLhav9cs6vCDalU61sls5zHy8CTj9F+10/U+/c8rJBMymJRKUhH/e5HivXK7083pIwBKBY19eld5B0DuzDoE6dp9R+ek1ralXIkooGnHKNIrGe9jI+74xaN/Mylqn3uFUXLSjXkoo8XxHnznvV/EQDy0Lg7btGm16arefNWJbfXK1peqkhJicb/4nve0Vz0PmSMeh8yxjtGl6KMASDHlQwdpAlzntOqx6Zqy4zZ6j3uAA2//HzFevfyjoYuYiEElweurq4ONTU1Lo8NAEC2mdn0EEJ1um1cTQ0AgDPKGAAAZ/ZZzH0AAAaWSURBVJQxAADOKGMAAJxRxgAAOOOtTeiULTPmaO3TL8qKYhp26bkqP4DPNAaAvUUZY5/N+doP9f49v1eqMSGLRrTw33+ucXd8XaP++XLvaACQV1imxj7Z/OYsLbvn90rVN0qplEJzi1INTZrzlX9X45pa73gAkFcoY+yTVX94TsmGpg7jFolo3bMvZz8QAOQxyhj7JBKNSOm+3ccki/K0AoC9kdGrppmdbWbzzGyhmd2SZnvczP63bfsbZlbV1UGRW4Zd9glFios7jIdkSoM/8TGHRACQv/ZYxmYWlXSXpHMkjZN0uZmN22m3ayVtCiEcIOknkn7Y1UGRW/occbDGfuNGRUrirT9lpYqUxnX4r36g4or8/zozAMimTK6mPlbSwhDCYkkys0ckXShpTrt9LpT07bbbj0v6HzOz4PUtFMiKA275rIb943la+8xLihQXacgnz1R8cIV3LADIO5mU8XBJy9vdXyHpuF3tE0JoMbMtkgZKWt8VIZG7ykaP1OibrvaOAQB5LatX2pjZDWZWY2Y1tbW8/QUAACmzMl4paWS7+yPaxtLuY2YxSX0lbdj5F4UQ7g0hVIcQqisrK/ctMQAABSaTMn5T0lgzG21mxZImSpq80z6TJV3TdvsSSX/mfDEAAJnZ4znjtnPAkyQ9Lykq6dchhNlmdrukmhDCZEn3SXrQzBZK2qjWwgYAABnI6LOpQwhTJE3ZaezWdrcbJV3atdEAAOgZ+KgkAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwJl5fe2wmdVKer8bfnWFpPXd8Ht7
Euaw85jDzmMOO4857LyunMNRIYTKdBvcyri7mFlNCKHaO0c+Yw47jznsPOaw85jDzsvWHLJMDQCAM8oYAABnhVjG93oHKADMYecxh53HHHYec9h5WZnDgjtnDABAvinEI2MAAPJK3paxmZ1tZvPMbKGZ3ZJme9zM/rdt+xtmVpX9lLktgzn8kpnNMbNZZvaimY3yyJnL9jSH7fa72MyCmXFl604ymUMzu6ztuTjbzH6f7Yy5LoN/y/uZ2UtmNqPt3/O5HjlzlZn92szWmdm7u9huZnZn2/zOMrOjuzxECCHvfiRFJS2StL+kYkkzJY3baZ/PSfpF2+2Jkv7XO3cu/WQ4hx+TVNZ2+0bmcO/nsG2/3pJelTRNUrV37lz6yfB5OFbSDEn92+4P8s6dSz8ZzuG9km5suz1O0lLv3Ln0I+kUSUdLencX28+VNFWSSTpe0htdnSFfj4yPlbQwhLA4hJCQ9IikC3fa50JJD7TdflzS6WZmWcyY6/Y4hyGEl0II9W13p0kakeWMuS6T56EkfVfSDyU1ZjNcnshkDq+XdFcIYZMkhRDWZTljrstkDoOkPm23+0palcV8OS+E8KqkjbvZ5UJJvw2tpknqZ2ZDuzJDvpbxcEnL291f0TaWdp8QQoukLZIGZiVdfshkDtu7Vq1/GeJDe5zDtuWskSGEZ7MZLI9k8jw8UNKBZvZXM5tmZmdnLV1+yGQOvy3pSjNbIWmKpJuyE61g7O3r5V6LdeUvQ2EysyslVUs61TtLPjGziKQfS/qUc5R8F1PrUvUEta7OvGpm40MIm11T5ZfLJd0fQviRmZ0g6UEzOyyEkPIOhlb5emS8UtLIdvdHtI2l3cfMYmpdmtmQlXT5IZM5lJmdIembki4IITRlKVu+2NMc9pZ0mKSXzWypWs81TeYirh1k8jxcIWlyCKE5hLBE0ny1ljNaZTKH10p6VJJCCK9LKlHrZy4jMxm9XnZGvpbxm5LGmtloMytW6wVak3faZ7Kka9puXyLpz6HtTDwkZTCHZnaUpHvUWsScp+tot3MYQtgSQqgIIVSFEKrUet79ghBCjU/cnJTJv+Wn1HpULDOrUOuy9eJshsxxmczhMkmnS5KZHaLWMq7Nasr8NlnS1W1XVR8vaUsIYXVXPkBeLlOHEFrMbJKk59V6JeGvQwizzex2STUhhMmS7lPrUsxCtZ6Yn+iXOPdkOId3SOol6bG2a9+WhRAucAudYzKcQ+xGhnP4vKSzzGyOpKSkr4YQWOVqk+EcflnSL83sZrVezPUpDk4+ZGYPq/UPvoq28+q3SSqSpBDCL9R6nv1cSQsl1Uv6dJdn4P8HAAC+8nWZGgCAgkEZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOPv/g+NdOOliW/8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# First make a simple 2D binary dataset (built from EricParameters) and plot it.\n",
    "# Alternative dataset, kept for reference:\n",
    "#   data_params = dl.RandomKParameters(10, 25, radius=0.02)\n",
    "# random_seed=420 presumably makes the generated points reproducible - confirm.\n",
    "data_params = dl.EricParameters(25, 0.05)\n",
    "dataset = dl.Random2DBinaryDataset(data_params, batch_size=32, random_seed=420)\n",
    "dataset.plot_2d()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# And then make a simple ReLU network (2 inputs -> 2 outputs) to classify it.\n",
    "# Seed torch's global RNG first so the network's random state (presumably its weight\n",
    "# initialization) and the training run below are reproducible on Restart & Run All.\n",
    "torch.manual_seed(420)\n",
    "simple_net = ReLUNet(layer_sizes=(2, 8, 16, 32, 16, 2), bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# And train this thing.\n",
    "# NOTE(review): the same `trainset` is passed twice to TrainParameters and `valset` is\n",
    "# not used in the visible cells, so the accuracies printed during training are presumably\n",
    "# measured on the training data itself - confirm what split_train_val(1.0) returns.\n",
    "trainset, valset = dataset.split_train_val(1.0)\n",
    "loss_fxn = nn_train.LossFunctional(regularizers=[nn_train.XEntropyReg(),\n",
    "                                                 nn_train.LpWeightReg(lp='l2', scalar=1e-3)])\n",
    "# NOTE(review): 1000 is presumably the number of epochs (the training log reports\n",
    "# every 200 epochs, matching test_after_epoch=200) - confirm against TrainParameters.\n",
    "train_params = nn_train.TrainParameters(trainset, trainset, 1000, test_after_epoch=200, \n",
    "                                        loss_functional=loss_fxn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00 | Accuracy: 44.00\n",
      "Epoch 200 | Accuracy: 64.00\n",
      "Epoch 400 | Accuracy: 76.00\n",
      "Epoch 600 | Accuracy: 80.00\n",
      "Epoch 800 | Accuracy: 80.00\n"
     ]
    }
   ],
   "source": [
    "# Train simple_net with train_params (in place - the return value is not used);\n",
    "# accuracy is printed every test_after_epoch epochs.\n",
    "nn_train.training_loop(simple_net, train_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAesAAAHWCAYAAABXF6HSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dfXRcd33n8c/3zmgkWQ+WZMmWLclPsRzih5A4jvNEQyAJJ4Q2aUthky6npQVCaWE5bQ5bunCAAz1d2i7ds21TILBsWnaXNMAudRfTAE0ghcTO84OtxLbiJ1m2/CBLtmRZ0szc3/4hxejJ1siamfubmffrHJ+juXM98z3Xst5z71zdMeecAACAv4KoBwAAABdHrAEA8ByxBgDAc8QaAADPEWsAADxHrAEA8NyssTazb5jZcTPbeYH7zcz+2sw6zexlM9uU/TEBAChdmexZPyTpjovc/05J7eN/7pP05fmPBQAA3jBrrJ1zT0g6dZFV7pb0D27Mdkl1ZrY0WwMCAFDqsvGedYukrgm3D48vAwAAWRDP55OZ2X0aO1SuREXVNYtb3pTPpwcAIDKHX3/upHOu6VL+bjZi3S2pbcLt1vFl0zjnHpT0oCS1rdns/vBLz2Th6QEA8N/9vxocvNS/m43D4Fsl/db4WeHXSzrtnDuahccFAADKYM/azL4l6RZJjWZ2WNJnJZVJknPuK5K2SbpTUqekIUm/k6thAQAoRbPG2jl37yz3O0l/kLWJAADAJFzBDAAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPxqAcoVKlUWse6+hSmQy1uqVN5ZSLqkQAARYpYX4LjR/q0/UcdYzecU+icNm5ZrcvWt0Q7GACgKHEYfI5SybS2/7BDqWR67E8qVJh22vn0fp3uHYx6PABAESLWc9TTdUqy6cvTYaiDe4/lfyAAQNEj1nOUTqXlnJt+hxu7DwCAbCPWc7S4tV4ztToWD7RsZWP+BwIAFD1iPUeVC8q17poVisWC84fDY/FAzW0NWtxSH+1wAICixNngl2DtlW1qWlang3uOKZ1Kq2VVk5a01stshjezAQCYJ2J9ieoba1TfWBP1GACAEsBhcAAAPEesAaAIpFJpnToxoLMDw1GPghzgMDgAFLjXO7q18+n9MjOFoVN9Y7Wuv329yivKoh4NWcKeNQAUsOPdfdr59H6lU6FSybTCdKhTJwa048cdUY+GLCLWAFDA9rx8WOlUOGmZC51OnRjQEIfEiwaxBoACNjw0MuPyIDANDyfzPA1yhVgDQAFb0tagIJh+jQfnnBbWL4hgIuQCsQaAAta+sVVl5fFJwY7FA23YskqxeCzCyZBNnA0OAAWsojKhW3/9Gu195bCOdfWpsiqh9o2tXP64yBBrAChwFZUJbdyyWhu3RD0JcoXD4AAAeI5YAwDgOWINAIDniDUAAJ4j1gAAeI5YAwDgOWINAIDniDUAAJ4j1gAAeI5YAwDgOWINAIDniDUAAJ4j1gAAeI5YAwDgOWINAIDniDUAAJ4j1gAAeI5YAwDgOWINAIDnMoq1md1hZrvNrNPMPjnD/cvN7HEze8HMXjazO7M/KgAApWnWWJtZTNIDkt4paZ2ke81s3ZTVPi3pEefc1ZLukfR32R4UAIBSlcme9RZJnc65fc65UUkPS7p7yjpOUu341wslHcneiAAAlLZ4Buu0SOqacPuwpOumrPM5ST80s49J
qpJ0W1amAwAAWTvB7F5JDznnWiXdKembZjbtsc3sPjN71syePXvmRJaeGgCA4pZJrLsltU243Tq+bKIPSHpEkpxzT0mqkNQ49YGccw865zY75zZX1TZd2sQAAJSYTGL9jKR2M1tlZgmNnUC2dco6hyTdKklmdoXGYs2uMwAAWTBrrJ1zKUkflfSopFc1dtb3LjP7vJndNb7a/ZI+ZGYvSfqWpPc751yuhgYAoJRkcoKZnHPbJG2bsuwzE77ukHRTdkcDAAASVzADAMB7xBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPBePegAAyJfkaErHj/QrFgvUtKxOsRj7KygMxBpASTiwu0cvPtmpIDBJkpl0wzs2qLF5YcSTAbPjZSWAojfQP6SXnuxUmA6VSqaVSqaVHE3ryUd3KpVKRz0eMCtiDaDoHdxzTGEYznhfz6FTeZ4GmDtiDaDoJUdTcm76cuecUkn2rOE/Yg2g6C1buUix+PQfd85Ji1vrI5gImBtiDaDoLW6pV9Oy+knBjsUDrX1zqxZUlUc4GZAZzgYHUPTMTDfcvk5HDvbq8OsnFCsLtHJtM2eCo2AQawAlwczUsrJRLSsbox4FmDNiDQBAFny459MXvf/+eTw2sQYAIEOzBTlXiDUAABNEFeSLIdYAgJLjY5AvhlgDAIpSoQX5Yog1AKBgFVOQL4ZYAwC85UOM0yMpdT22R8ef7VIQD7T0Lau17KbVsvFPcMsHYg0AiJQPQb6QMB3q+b96XEPHB+SSYx8G8/r/fVn9u49r/QdvyNscxBoAkFM+x3g2J1/q1vDJwfOhlqRwNK3eXT0a7O5XdUtdXuYg1gCAeSnkGM+mf88JpUdm/mS20/t6iTUAwA/FHOPZlDcsUBAPFKYmfx66BabyhZV5m4NYA0CJK+UYz6b5+pU6+C+vTl5oUpCIqWF9c97mINYAUOSI8aUrr63QlX/wS3r1f+xQcmhUck6VTdVa/8EbFMTy9ynTxBoAChwxzq26yxp1/Rfu1Lnjg7J4oMpFVXmfgVgDgOeIcfTMTAuW1ET2/MQaACJGjDEbYg0AeUCQMR/EGgCygBgjl4g1AGSAGCNKxBoARIzhN2INoCQQYxQyYg2gKBBjFDNiDaAgEGOUMmINwAvEGLgwYg0gbwgycGmINYCsIcZAbhBrABkjxkA0iDWA84gx4CdiDZQQYgwUJmINFBFiDBQnYg0UEGIMlCZiDXiEGAOYCbEG8owgA5grYg1kGTEGZhemQp18qVun9/eqsrFKS65dobKqRNRjeYtYA3NEjIH5SZ1L6vkvPabhvnMKR1IKEjHt/36Hrv74W1XdWhf1eF4i1sAUxBjIrQM/6NC5k2flUqEkKRxNS0rr1b9/Wtd+6h3RDucpYo2SQ4yBaB1/rut8qCcaOjGo0YFhJWoqIpjKb8S6BCVHUzp6sFepVFpLWutVVVMZ9UhZRYwBv1ksmPkOJ5lZfocpEMS6xBzv7tNTP9olk+Sc9LKk9o0tWr95VdSjZYwYA4Vt6Q0rdeiHrylMTti7Nqlmeb3KqsujG8xjxLqEpFJpbf9xh9JTDj917uzWktYGNTYvjGiyyYgxUNyW33a5+vee0JkDp+RCpyAWKFYR1xXv3xL1aN4i1iXkeHf/jMvTqVAH9/TkNdYEGShdQVlMb/7YzTqz/5QGDp1SRUOVGtY3K7jQ4XEQ61LiwukndLwhTLusPhcxBnAxZqaFqxdp4epFUY9SEIh1CWlaVq8wnB7lWDxQ62VNc3osYgwA+UOsS0iiPK6rb2rXiz/fq9A5udApFg+0dPkiNbc1TFqXGAOAP4h1iVmxdokWLanV4hcf1lAqphuaT2hjQ5/sWNSTAQAuhFh7LKd7t5fn7qEBANmV
UazN7A5J/01STNLXnXNfnGGd90r6nCQn6SXn3G9mcU7vcJgY8FOYTMtigSzg4hooHrPG2sxikh6QdLukw5KeMbOtzrmOCeu0S/oTSTc55/rMbHGuBs4EIQVKT29Hjzq//YLOnTyroCymlreu0apfXs+vA6EoZLJnvUVSp3NunySZ2cOS7pbUMWGdD0l6wDnXJ0nOueOzPWhTsjuyqLrQqWfHQR35+T65dKgl1y7Xsl+6TLGyWCTzAJif0/t7tetrTylMpiWNfTBE9086lT6X1Np7NkU8HTB/mbzkbJHUNeH24fFlE62VtNbMfm5m28cPm3ur46Ed2vvtFzRw4JQGu/q1/5936aW/eUJuhl9rAuC/gz949Xyo3xAm0+rZcUCpc8mIpgKyJ1vHh+KS2iXdIuleSV8zs2kfSmpm95nZs2b27MnBoSw99dwMHu5X7ytHxz+SbUyYTOts92n17joayUwA5mfo2MCMyy0INNJ/Ls/TANmXSay7JbVNuN06vmyiw5K2OueSzrn9kvZoLN6TOOcedM5tds5tbqxecKkzz0t/58mxT7CYIj2SUt/uWY/eA/BQdetCaYbzyVzoVNEQzc8aIJsyifUzktrNbJWZJSTdI2nrlHW+p7G9aplZo8YOi+/L4pxZk6gtn/Hj2awsUHldcX1UJFAqVt65TsGUc06CREytb29XrJzfUEXhmzXWzrmUpI9KelTSq5Iecc7tMrPPm9ld46s9KqnXzDokPS7pE8653lwNPR+LNi6bOdZmWrJlRQQTAZiv6pY6XfUf3qrayxYpGH/hvfrujVr1y+ujHg3ICnMzHBLOh00rlrp/+0+/G8lzD3af1s4Hn9TowLDMpCAR1/rfvV517XO7PjYAAJmq/r0/e845t/lS/m5JHh+qblmo6z53h4Z6BuTSoaqWLeQCCgAAb5VkrKWxw95VS2ujHgMAgFlxaR8AADxHrAEA8ByxBgDAc8QaAADPEWsAADxHrAEA8ByxBgDAc8QaAADPEWsAADxHrAEA8FzJXm4UxWu4b0j9u48rVlmmhnXNik356EQAKDTEGkXlwLYOHfrha1Jg4x/OYrry99+ihasXRT0aAFwyDoOjaPR3ntChH+9WmAoVjqaVHk4pPZzUK1/5mcJ0GPV4AHDJiDWKxtEn9yscTU9b7kKn/j0nIpgIALKDWKNopGcI9RvC5IXvAwDfEWsUjcWb2hQkpp9M5tKh6tqbIpgIALKDWKNoNF3Voro1jYqVj583GZiCspja37tJ8cqyaIcDgHngbHAUDQtMG3/vLerddVQnXz6isgUJNV+/UlVLa6MeDQDmhVijqFhgaty4TI0bl0U9CgBkDYfBAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBz8agHAC7GhU7Hn+tSz44DMjM137BKTVe1yAKLejQAyBtiDW8557Tr60/p1GvHFI6mJUn9r59U7ytHdMVvb4l4OgDIHw6Dw1unXz85KdSSFI6mdeLFbg109UU4GQDkF7GGt/p2H58U6je4MFTf7uMRTAQA0SDW8FZZVUJB2fRvUYsFKluQiGAiAIgGsYa3Fl/TJmn6iWRmpsarWvI/EOCR0YFhHfjBq3rlwSd1YFuHRs8MRz0ScogTzOCtRE2FNtx3ozq+sV3OOUlje9UbPnQDe9YoaUPHBvT8f3lM6VRaLhmqr6NHXY/t1ab736aqpbVRj4ccINbwWsMVS3TjF39FZ/b3SmaqXdmgIMYBIZS2vd95UanhpDT2GlZhKpRSofY88ryu/vgtkc6G3CDW8F4QC1S3pinqMQBv
9O85fj7UE53uPCnnnMy4DkGxYRcFAApMUBabeXl85uUofMQaAApM83UrFcQn//i2eKAlW5azV12kiDUAFJjVd2/UwssaFZTFFCuPKyiLqXZlg9b8+pujHg05wnvWAFBgYomY3vyxmzV45LSGes5owZIaVbfURT0WcohYA0CBql62UNXLFkY9BvKAw+AAAHiOWAMA4DliDQCA54g1AACeI9YAAHiOWAMA4DliDQCA54g1AACeI9YAAHguo1ib2R1mttvMOs3skxdZ791m5sxsc/ZGBACgtM0aazOLSXpA0jslrZN0r5mtm2G9Gkkfl7Qj20MCAFDKMtmz3iKp0zm3zzk3KulhSXfPsN4XJP25pOEszgcAQMnLJNYtkrom3D48vuw8M9skqc059/0szgYAAJSFT90ys0DSX0l6fwbr3ifpPklqa6id71MDAFASMtmz7pbUNuF26/iyN9RI2iDpJ2Z2QNL1krbOdJKZc+5B59xm59zmxuoFlz41AAAlJJNYPyOp3cxWmVlC0j2Str5xp3PutHOu0Tm30jm3UtJ2SXc5557NycQAAJSYWWPtnEtJ+qikRyW9KukR59wuM/u8md2V6wEBACh1Gb1n7ZzbJmnblGWfucC6t8x/LAAA8AauYAYAgOeINQAAniPWAAB4jlgDAOA5Yg0AgOeINQAAniPWAAB4jlgDAOA5Yg0AgOeINQAAniPWAAB4jlgDAOA5Yg0AgOeINQAAnsvoIzJzITWcVO/Oo6prb1KsPLIx4LmvNv9p1CN44cM9n456BAARiqySwyfPquOhHXKh05vet1mLN7VFNQoyQDSjlevtz4sBwG+RxdqFTunhlCTptW8+q9qVDapoqIpqHIwjyqUpF//uvAAAsseL488udDr29CGtuOOKqEcpGUQZuZar7zFeBKAU+RHrdKjUuWTUYxQtwoxiwlEAlCIvYh0kYmpY3xz1GEWBMANzx1EA+C7yWAeJmBrWNauuvSnqUQoScQb8xYsAZEtksY5XJdR4RYuWXLtcjRuXycyiGqWgTP3PPzQ4rM6d3eo9dkY1dQu09spW1dZzoh5QzHgroPREFuvKRVXa8MEbonr6gjDbf8iB00N6/J9eUDoVyoVO/ScH1L3/hG58xwY1LavL05QAigFHAfwW+WFw/MJc/7Ps3LFPqdH0+dvOSelUqBd+tle3v2czRysARI6jANlBrPMsm9+4J46ennH52YFhpZJplSX45wVQfErxKAA/zbMsnyd8lSXiSiXT05abSbEYl30HgLnw+UUAsc6Ar2dcr9nQoo7nDiidCs8vC2KB2lY3KSDWAOCFXzTkzy75MYj1FL6GeSZrNrRo8PQ5HdzboyAWKEw7LV5WpzfftCbq0QAAWUSsVViBnsjMdPVb2nXFNSs00D+kquoKLaipiHosoCiNDCfVd2JA5ZVlqltUzQmcyKuSjHWhxvlCKioTqqhMRD0GULQ6njugPS93KQgCOee0oLpCN71zoxZUlUc9GkpEycS62AINID+OHOzV3lcOK0w7hemxEzoHTw9p+4926e2/uini6VAqijrWBBrAfHXuPDzpJE5p7JoGA31DGjxzTtW1lRFNhlJSlLEm0gCyZXQkNeNyC0zJC9wHZFtRxZpIA8i2ZSsWabB/SGHopt23sIHr8CM/iiLWRBpArqzZ0KpDncc1PDSqMB1K4xcduvqmNVzPAHlT0LEm0gByLVEe162/tkkHXutRz+FTqqxK6LL1LapvrIl6NJSQgos1gQaQb2WJuNqvbFX7la1Rj4ISVTCxJtIAgFLlfayJNACg1HkbayKNopccVfVzP1HZ8cMaXr1O59ZfN/aRaQAwhXexJtIoBfET3Vr+qd9UMDQoS43Kxcs02rZGXZ99SK5iQdTjAfCMN7Em0iglS//mjxXrP6EgHL8yViqp8gO7teg7X9bJ990f6WwA/BNprAk0SlEwNKjK3S/KwsmXsAySI6r96feINYBpIvuN/hNlLVE9NRAtF17w
rqkBBwApwlgDpSqsqtXwqivkppxMFsbLdOamOyOaCoDPiDUQgZ6P/bnCqlqF5WOf2JSuWKDkkjb1/ruPRTwZAB95c4IZUEpGW1br9S//q2p/tk1lx7o0fNl6DV57qxQvi3o0AB4i1kBEXGW1Tt/+3qjHAFAAOAwOAIDniDUAAJ4j1gAAeI5YAwDgOWINAIDniDUAAJ4j1gAAeI5YAwDgOWINAIDniDUAAJ4j1gAAeI5YAwDgOWINAIDniDUAAJ4j1rigVDKtdCqMegwAKHl8njWmOdM3pOef2K2+k4OSpMUtdbrm5stVsSAR8WQAUJrYs8YkoyMp/fSfX9SpEwNyzsk5p+PdffrpP78oF7qoxwOAkkSsMUlX5zGF4eRD385JI8NJHevui2gqAChtxBqTDPQPzfg+dRg6nR0YjmAiAADvWWOS+qYaxfYemxZsM6muoSqiqQDMSTqlhu99XfX/8r8UnDursxuu04nf/mMll66MejJcIvasMUnr6iaVV5TJAju/LIgFWthQrYYltRFOBiBTzQ/8iRZ99yuK951QMDyk6ud+ohWffI9ifcejHg2XiFhjklg8plvuvlrL1yxWWSKu8ooyXbZumd5y50aZ2ewPACBS8ZNHVfPUDxWM/uJtK3NONjqi+u9/M8LJMB8cBsc0FZUJXXPz5brm5qgnATBXia5OubIyKTkyaXmQHFXl3pcimgrzxZ41ABSRZHObLJWcttzF4hppa49gImQDsQaAIpJculLn3nSNwrLJFzFy8TL1veu3IpoK80WsAaDIdP/Hv9XATe9SWJaQC2IaXr5WXZ/5hpJLV0Q9Gi4R71kDQJFxFQvU89H/rJ6PfEGWSsqVV0Y9EuaJWANAsYrF5WL8mC8GHAYHAMBzxBoAAM9lFGszu8PMdptZp5l9cob7/8jMOszsZTP7VzPjLAYAALJk1libWUzSA5LeKWmdpHvNbN2U1V6QtNk5d6Wk70j6i2wPCgBAqcpkz3qLpE7n3D7n3KikhyXdPXEF59zjzrmh8ZvbJbVmd0wAAEpXJrFukdQ14fbh8WUX8gFJP5jPUAAA4Beyek6/mb1P0mZJb73A/fdJuk+S6puWZ/OpAQAoWpnsWXdLaptwu3V82SRmdpukT0m6yzk3MvV+SXLOPeic2+yc21xV23Qp8wIAUHIyifUzktrNbJWZJSTdI2nrxBXM7GpJX9VYqPnAVAAAsmjWWDvnUpI+KulRSa9KesQ5t8vMPm9md42v9peSqiV928xeNLOtF3g4AAAwRxm9Z+2c2yZp25Rln5nw9W1ZngsAAIzjCmYAAHiOWAMA4DliDQCA54g1AACe44NOi9yHez6d9+f8avOf5v05AaCYEesiFkWoo3zebOHFBgDfEGtgikJ+scELDaA4EWugiPBCAyhOxBqAFwr5hYbEiw3kFrEuUoX+gw8oNIX8f44XGv4j1gBQ4gr5hYZUGi82iDUAoKAVyouN++fxd7koCgAAnmPPGkBRcE7acbxRPzjUquF0TG9fdlS3th5VPHBRjwbMG7EGUBQe7FirbYdaNZwe+7H2Wt9C/fjwMn3xhmcVs4iHA+aJw+AACl7PUKX+38G286GWpOF0XHtO12rHsaYIJwOyg1gXoUI52QLIlpdO1iuw6Ye7h9NxPX2sMYKJgOwi1gAKXk0iqWCGQ91xC7WwfDT/AwFZRqwBFLzNTSdn3LOOmdM72o5EMBGQXcQaQMFLxJy+eP2zaigfVmUspQXxpCpjKX3iqlfUUnUu6vFQgpKhaf+ZavUOl2fl8TgbHEBRaF84oP952xPa3b9Qo+lA6+pPKxELox4LJejHh5v1dzuvUOhMKWdaV9+vT1/z8rwek1gDKBoxk9bVn456DJSwXacW6q9fWa+RdGzCsnp97pmr5vW4HAYHACBLvrtvhUbTk9OacoE6T9fO63GJNQAAWXLiXKWcpv9qQiyY31syxBoAgCzZ1NSrMktPW54K55dbYg0AQJb82qqDqk6kFJ8Q
7IpYSv++/fV5PS4nmAEAkCV15Ul9+ean9HDnKj17vFF15aN69+oDurH5hD44j8cl1gAAZFF9+ag+sn63tH531h6Tw+AAAHiOWAMA4DlijawK06FS55Jybvp1mgEAl4b3rJEV6dG0Or/zoo49fVAudKpYVKW191yt+suXRD0aABQ89qyRFa8+tEM9Tx9UmArlQqdzJwb1ylef1GB3f9SjAUDBI9aYt5H+czrV0SOXmnyFnjCZ1qEfZe9sSAAoVcQa8zbce1YWj02/w0lDxwbyPxAAFBlijXmrXFKjMDX98noKTLUrG/I/EAAUGWKNeUtUl2vpjasUJCbvXcfKYmq77fKIpgKA4sHZ4MiK9t+4SpVN1Tr82B6lhpJauKZRl/3alapcVBX1aABQ8Ig1ssICU9vb2tX2tvaoRwGAosNhcAAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHLEGAMBzxBoAAM8RawAAPEesAQDwHB+ROYN0KlRPV69GhpNqWlqnmroFUY8EAChhxHqK/pOD+rdtL8s5Jxc6OUnL1yzW1W9pl5lFPR4AoARxGHwC55ye/OFOJUdTSiXTSqdDhelQXa8f15EDJ6MeDwBQotiznqC/d1Cp0fS05elUqH2vHlXLqqYIpgJyZyQd6CfdzXr+5CItrhzWncsPa2nVuajHAjAFsZ4gTIfSBY50h+kwv8MAOXY2GdPHf3adTgxXaDgdV9xC/dOB5frs5hd1TVNv1OMBORWmQg0cPCWLBapZXi8L/H6bk1hPUN9Yo5nelo7FA7WtWZz/gYAc+u6+FTp2rlKjYUySlHKBUmnpL17YoG/d/lN5/rMLuGS9u46q46GnJeckJ8Uq4tr44ZtUs7w+6tEuiPesJwhiga592xWKxQIF4z+p4vGY6hprtHJtc8TTAdn1xNHm86GeaDgdU9dgVQQTIVvSo2mdOXhK504ORj2Kd4b7hrTr69uVPpdUejil9EhKo6eH9dLfPKH0DG+D+oI96yma2xp0+3s26+CeYxo5N6rFrQ1a2tbg/SESYK4qYzP/YAqdqfwC98F/R36+T6//n5clk1zaqaploTbed6MStRVRj+aFnu0H5JybtsU/okwAAAdiSURBVNyFTr07j2jxprYIppode9YzWFBdoSs2rdBVN7Vr2YpFhBpF6a6VXaqIpSYtCxSqrfqsmhcMRzQV5qO/84Q6v/uS0iMppYdTCpNpDRzq0ytf/XnUo3kjeXZULjX9HKQwDJU8OxrBRJkh1kCJurX1iG5Z1qNEkFZlLKXKWEqNlSP6zOYXox4Nl+jwY3sVTj2UGzqdPXJGQ8cGohnKMw1vWqJY+fSDyiaprt3f3/jhMDhQogKT/vDNHbpnzX692rdQDRWjunLRKU4sK2Ajp2f+tTuLmUYHhrVgSU2eJ/JPw7pm1ayo15kDp86/sAkSMS3ZvFxVzbURT3dhxBoocUurzvG71UWiYV2zBrtPTzvM69JO1a11EU3lFwtMV/7BL+nYjoPqefqggnhMS29apaarWqIe7aKINQAUidZb2nX0qQNKDo6cD3aQiGnVu9YrXlEW7XAeCWKBlt64SktvXBX1KBkj1gBQJMqqErr2k7ep6/G96t15VImacrW9fa0a1vGrp4WOWANAESmrLtfqX9mg1b+yIepRkEWcDQ4AgOeINQAAniPWAAB4jlgDAOA5Yg0AgOc4GxznjQ6O6OiT+zV4uF81y+u19IZVKqtKRD0WAJQ8Yg1J0tmeM3rhS48rTIUKk2n1vnJUh360W9d84u2qbKyOejwAKGkcBockac8/Pq/UuaTC5Ni1csNkWqmhUe39Nh/qAABRI9aQC51Od56c4Q6p77Xj+R8IADAJsYZkksVm/lYIyvgWAYCoZfST2MzuMLPdZtZpZp+c4f5yM/vH8ft3mNnKbA+K3DEzLd7cJotP/naweKDmLSsimgoA8IZZY21mMUkPSHqnpHWS
7jWzdVNW+4CkPufcGkn/VdKfZ3tQ5Fb7u69STWudgkRMsfK4gkRMtSsbtPrujVGPBgAlL5OzwbdI6nTO7ZMkM3tY0t2SOiasc7ekz41//R1Jf2tm5pxzWZwVORSvLNPV979NA4f6NHRsQFVLa1XTVh/1WAAAZRbrFkldE24flnTdhdZxzqXM7LSkRZJmOGsJvjIz1a5oUO2KhqhHAQBMkNffszaz+yTdN35z5P5fDXbm8/lLxf2/+LJRvGDKB7Zz7rGNc49tnHuXX+pfzCTW3ZLaJtxuHV820zqHzSwuaaGk3qkP5Jx7UNKDkmRmzzrnNl/K0MgM2zg/2M65xzbOPbZx7pnZs5f6dzM5G/wZSe1mtsrMEpLukbR1yjpbJf32+Ne/Iekx3q8GACA7Zt2zHn8P+qOSHpUUk/QN59wuM/u8pGedc1sl/XdJ3zSzTkmnNBZ0AACQBRm9Z+2c2yZp25Rln5nw9bCk98zxuR+c4/qYO7ZxfrCdc49tnHts49y75G1sHK0GAMBvXEsSAADP5TzWXKo09zLYxn9kZh1m9rKZ/auZcQ3ROZptG09Y791m5syMs2ovQSbb2czeO/79vMvM/ne+Zyx0Gfy8WG5mj5vZC+M/M+6MYs5CZmbfMLPjZjbjryfbmL8e/zd42cw2zfqgzrmc/dHYCWmvS1otKSHpJUnrpqzz+5K+Mv71PZL+MZczFdufDLfx2yQtGP/6I2zj7G/j8fVqJD0habukzVHPXWh/Mvxebpf0gqT68duLo567kP5kuI0flPSR8a/XSToQ9dyF9kfSzZI2Sdp5gfvvlPQDSSbpekk7ZnvMXO9Zn79UqXNuVNIblyqd6G5Jfz/+9Xck3WpmluO5isms29g597hzbmj85naN/a48MpfJ97EkfUFj18UfzudwRSST7fwhSQ845/okyTnHZ7jOTSbb2EmqHf96oaQjeZyvKDjnntDYb0ZdyN2S/sGN2S6pzsyWXuwxcx3rmS5V2nKhdZxzKUlvXKoUmclkG0/0AY29okPmZt3G44ex2pxz38/nYEUmk+/ltZLWmtnPzWy7md2Rt+mKQybb+HOS3mdmhzX2W0Afy89oJWWuP7fze7lRRMvM3idps6S3Rj1LMTGzQNJfSXp/xKOUgrjGDoXforEjRE+Y2UbnXH+kUxWXeyU95Jz7kpndoLFraGxwzoVRD1bKcr1nPZdLlepilyrFBWWyjWVmt0n6lKS7nHMjeZqtWMy2jWskbZD0EzM7oLH3oLZyktmcZfK9fFjSVudc0jm3X9IejcUbmclkG39A0iOS5Jx7SlKFxq4bjuzJ6Of2RLmONZcqzb1Zt7GZXS3pqxoLNe/xzd1Ft7Fz7rRzrtE5t9I5t1Jj5wXc5Zy75OsAl6hMfl58T2N71TKzRo0dFt+XzyELXCbb+JCkWyXJzK7QWKxP5HXK4rdV0m+NnxV+vaTTzrmjF/sLOT0M7rhUac5luI3/UlK1pG+Pn7t3yDl3V2RDF5gMtzHmKcPt/Kikd5hZh6S0pE845zgSl6EMt/H9kr5mZn+osZPN3s8O1NyY2bc09qKycfy9/89KKpMk59xXNHYuwJ2SOiUNSfqdWR+TfwMAAPzGFcwAAPAcsQYAwHPEGgAAzxFrAAA8R6wBAPAcsQYAwHPEGgAAzxFrAAA89/8BMesZK5KbyH4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot simple_net's decision regions over [0, 1] x [0, 1] (100 presumably sets the\n",
    "# grid resolution) and overlay the dataset points on the same axes.\n",
    "ax = simple_net.display_decision_bounds([0, 1], [0,1], 100)\n",
    "dataset.plot_2d(ax=ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 8 8 8\n",
      "1 16 16 16\n",
      "2 32 32 32\n",
      "3 16 16 16\n",
      "4 2 2 2\n"
     ]
    }
   ],
   "source": [
    "#### PREACTIVATION CHECK\n",
    "# Naive interval-arithmetic bounds computed over small_hbox must contain the true\n",
    "# pre-activations of any point inside small_hbox (here: its center).\n",
    "small_hbox = Hyperbox.build_linf_ball(np.array([0.5, 0.5]), 0.2)\n",
    "x = torch.tensor([0.5, 0.5])\n",
    "bounds = PreactivationBounds.naive_ia_from_hyperbox(simple_net, small_hbox)\n",
    "\n",
    "def check_preacts_within_bounds(x, net=simple_net, bounds=bounds):\n",
    "    \"\"\"Compare the pre-activations of `net` at `x` against per-layer `bounds`.\n",
    "\n",
    "    Prints one line per layer: layer index, number of units, number of units whose\n",
    "    pre-activation is >= the lower bound, and number that are <= the upper bound.\n",
    "    Raises AssertionError if any unit falls outside its bounds.\n",
    "    \"\"\"\n",
    "    preacts = net(x, return_preacts=True)\n",
    "    for i, preact in enumerate(preacts):\n",
    "        preact = preact.detach().numpy().reshape(-1)\n",
    "        ith_bound = bounds.get_ith_layer_bounds(i)\n",
    "        lower, upper = ith_bound[0], ith_bound[1]\n",
    "        n_units = len(lower)\n",
    "        n_above_lower = sum(lower <= preact)\n",
    "        n_below_upper = sum(upper >= preact)\n",
    "        print(i, n_units, n_above_lower, n_below_upper)\n",
    "        assert n_above_lower == n_units == n_below_upper, (\n",
    "            f'layer {i}: pre-activations fall outside the computed bounds')\n",
    "\n",
    "check_preacts_within_bounds(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To compute the local Lipschitz constant of a DNN, we need:\n",
    "# 1) A region to compute over\n",
    "# 2) A neural network to evaluate\n",
    "# 3) Function parameters {LP_norm, preactivation_techniques, etc.}\n",
    "\n",
    "# LipParameters holds items (1, 3)\n",
    "# LipProblem holds items (1, 2, 3)\n",
    "# LipMIPResult holds all info from a completed Lipschitz computation\n",
    "\n",
    "\n",
    "# EvaluationParameters holds parameters for how to evaluate Lipschitz constants.\n",
    "# TODO: cover EvaluationParameters later.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "302.5213302198108 [0.22593368 0.64090465]\n",
      "tensor(302.5213)\n"
     ]
    }
   ],
   "source": [
    "# Goal: compute LipMIP's maximal local Lipschitz constant of c^T f over the\n",
    "# l_inf ball of radius 0.5 around (0.5, 0.5), i.e. the whole unit square.\n",
    "hbox = Hyperbox.build_linf_ball(np.array([0.5, 0.5]), 0.5)\n",
    "c_vector = torch.tensor([1.0, -1.0])\n",
    "\n",
    "params = lm.LipParameters(hbox, c_vector, verbose=False)\n",
    "problem = lm.LipProblem(simple_net, params)\n",
    "result = problem.compute_max_lipschitz()\n",
    "print(result.value, result.best_x)\n",
    "# Cross-check with autograd at the maximizer LipMIP reports: the gradient's l1 norm\n",
    "# (the dual of the l_inf norm, presumably the norm used here) should match result.value.\n",
    "# NOTE(review): if best_x lies on a boundary between linear regions, autograd may\n",
    "# return the gradient of a neighbouring region instead.\n",
    "print(simple_net.get_grad_at_point(result.best_x, c_vector).norm(p=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CLEVER on the same box and objective as the LipMIP run above\n",
    "clever = CLEVER(\n",
    "    simple_net, c_vector, hbox,\n",
    "    'linf',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "302.5213623046875"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compare with result.value from LipMIP\n",
    "clever_value = clever.compute()\n",
    "clever_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SeqLip is run with the 'l2' norm here, unlike the 'linf' runs of the other methods\n",
    "seq = SeqLip(\n",
    "    simple_net, c_vector, hbox,\n",
    "    'l2',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "381.98178558419875"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_value = seq.compute()\n",
    "seq_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1842.0687"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NaiveUB is not given hbox, so this is presumably a global (not local) upper bound\n",
    "spec_ub = NaiveUB(\n",
    "    simple_net, c_vector, 'linf',\n",
    ")\n",
    "spec_ub_value = spec_ub.compute()\n",
    "spec_ub_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(302.5213)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rando = RandomLB(\n",
    "    simple_net, c_vector, hbox, 'linf',\n",
    ")\n",
    "rando_value = rando.compute()\n",
    "rando_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LipLP on the same box and objective as the LipMIP run above\n",
    "liplp = LipLP(\n",
    "    simple_net, c_vector, hbox, 'linf',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "580.4085997770305"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LipLP with tighter_relu disabled\n",
    "liplp_loose = liplp.compute(tighter_relu=False)\n",
    "liplp_loose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "580.4085997770305"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LipLP with tighter_relu enabled\n",
    "liplp_tight = liplp.compute(tighter_relu=True)\n",
    "liplp_tight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual autograd cross-check of get_grad_at_point at a random point of [0, 1]^2.\n",
    "# NOTE(review): `rand_out` and `loss` were not defined in any visible cell, so they\n",
    "# are built here to make this check runnable on a fresh kernel.\n",
    "rand_out = torch.rand(2, requires_grad=True)\n",
    "loss = torch.dot(c_vector, simple_net(rand_out).view(-1))\n",
    "loss.backward()\n",
    "display(simple_net.get_grad_at_point(rand_out.detach().numpy(), c_vector), rand_out.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bounds of the LipMIP model variables stored under 'fc_0_post'\n",
    "squire = result.squire  # `squire` is not defined in any visible cell; use the one from the LipMIP result\n",
    "[(var.LB, var.UB) for var in squire.var_dict['fc_0_post']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re  # not imported in the top import cell\n",
    "squire = result.squire\n",
    "# Names of the per-layer ReLU variable groups in the LipMIP model\n",
    "relu_keys = [key for key in squire.var_dict.keys() if re.match(r'^relu_\\d+$', key)]\n",
    "display(relu_keys, squire.var_dict['relu_1'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this only displays the bound method; presumably it was meant to be\n",
    "# called with a model variable name. `squire` itself is not defined in any visible cell.\n",
    "result.squire.model.getVarByName"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# and do the other methods\n",
    "# NaiveUB again, with the same arguments as `spec_ub` above\n",
    "nub = NaiveUB(\n",
    "    simple_net, c_vector, 'linf',\n",
    ")\n",
    "nub.compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RandomLB with the same arguments as `rando` above\n",
    "rlb = RandomLB(\n",
    "    simple_net, c_vector, hbox, 'linf',\n",
    ")\n",
    "rlb.compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fast_lip = FastLip(simple_net, c_vector, hbox, 'linf')\n",
    "fast_lip.compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LipSDP is an l2-norm method. For inputs in R^d, ||g||_1 <= sqrt(d) * ||g||_2,\n",
    "# so scaling by sqrt(d) turns it into an l_inf-Lipschitz upper bound comparable\n",
    "# with the 'linf' results above.\n",
    "input_dim = 2  # simple_net takes 2-D inputs (hbox is a 2-D box)\n",
    "LipSDPCustom(simple_net, c_vector).compute() * math.sqrt(input_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_random_points = 1000  # presumably the number of random points sampled in hbox\n",
    "simple_net.random_max_grad(hbox, c_vector, n_random_points, pnorm=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the origin of this probe point is undocumented\n",
    "probe_point = np.array([0.40050755, 0.6])\n",
    "simple_net.get_grad_at_point(probe_point, c_vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient as computed by the LipMIP squire (compare with simple_net.get_grad_at_point)\n",
    "squire_probe_point = np.array([0.5005, 0.4886])\n",
    "result.squire.get_grad_at_point(squire_probe_point)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Simple idea: get Lipschitz constants as we train\n",
    "# Train a fresh network and record its LipMIP value every EPOCHS_PER_ROUND epochs.\n",
    "# `new_net` is deliberately NOT aliased to `simple_net`, so every cell above keeps\n",
    "# analysing the original network when re-run.\n",
    "TOTAL_EPOCHS = 5000\n",
    "EPOCHS_PER_ROUND = 50\n",
    "\n",
    "new_net = ReLUNet(layer_sizes=(2, 20, 20, 20, 2), bias=True)\n",
    "loss_fxn = nn_train.LossFunctional(regularizers=[nn_train.XEntropyReg(),\n",
    "                                                 nn_train.LpWeightReg(lp='l1', scalar=1e-4)])\n",
    "\n",
    "train_params = nn_train.TrainParameters(trainset, trainset, EPOCHS_PER_ROUND,\n",
    "                                        test_after_epoch=EPOCHS_PER_ROUND,\n",
    "                                        loss_functional=loss_fxn)\n",
    "lip_params = lm.LipParameters(hbox, c_vector, verbose=False)\n",
    "lip_results = []  # one LipMIP result per training round\n",
    "done_epochs = 0\n",
    "while done_epochs < TOTAL_EPOCHS:\n",
    "    nn_train.training_loop(new_net, train_params, epoch_start_no=done_epochs)\n",
    "    lip_results.append(lm.LipProblem(new_net, lip_params).compute_max_lipschitz())\n",
    "    done_epochs += train_params.num_epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LipMIP value after each training round\n",
    "[res.value for res in lip_results]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FastLip on the trained network, for comparison with the last LipMIP value\n",
    "FastLip(\n",
    "    new_net, c_vector, hbox, 'linf',\n",
    ").compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
