{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matlab.engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import lipMIP as lm\n",
    "import neural_nets.data_loaders as dl\n",
    "import neural_nets.train as nn_train\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "from relu_nets import ReLUNet\n",
    "import lipMIP as lm\n",
    "import torch\n",
    "from hyperbox import Hyperbox\n",
    "from pre_activation_bounds import PreactivationBounds\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from other_methods.lip_sdp import LipSDPCustom\n",
    "from other_methods.fast_lip import FastLip\n",
    "from other_methods.naive_methods import NaiveUB, RandomLB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAHSCAYAAADfUaMwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXRb5Z3/8c9Xki0v2WNnD3EIYQmErWYrW8pWlgIdtgkMS1uWKW1oh25D2ym0dJnpMN2YgRZaWiilMECBCZBAKWUp/RGKQ0ggCdlD9sTZ43iRLT2/P+yCEyuJEtv6SvL7dY7PkZ57sT7nQdHH97lXkoUQBAAA/ES8AwAA0NNRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAs5jXA1dUVISqqiqvhwcAIKumT5++PoRQmW6bWxlXVVWppqbG6+EBAMgqM3t/V9tYpgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMBZzDtAT7ZqTYMefmK55izYpjGjynXFxSNVNbLcOxYAIMsoYyeLltbpxq+9raZEUsmktGBxnf78Wq1+9J3DdcShfb3jAQCyiGVqJ3f+apHqG1qLWJJSKamxKaUf/Xy+bzAAQNZRxk7embs17fiSZfVKNKeynAYA4IkydtKrLJp2vLgooljUspwGAOCJMnZyyfnDFY/vOP3FxRGdd+YQRSKUMQD0JJSxk3+6eD+ddeogFReZysuiKi4yffSYAZp07RjvaACALLMQgssDV1dXh5qaGpfHziUbNyW0bGW9hg8tVeXAuHccAEA3MbPpIYTqdNt4a5OzAf2LNaB/sXcMAIAjlqkBAHDGkTEA5LD1G5s0c/YW9S6P6egj+vNuiwJFGQNAjrrvoaV66A/LFIu1LmLG4xH97HtHaP9RfGxuoWGZGgBy0JszNurhJ5cr0RxU35BUfUNSmzY368u3zVIq5XPhLboPZQwAOejJKavU2NTx0/i21yc1Z/42h0ToTpQxAOSg7Q3JtONmUkNj+m3IX5QxAOSg006qVEm840t0Mhk0/uA+DonQnShjAMhB55w+RPtXlau0pPVlOhJpvYDrK58bq5KS9J9tj/zF1dQAkIOKiyK6+z+O1J9fq9Vrb2xQ/75FuuDsoRpT1cs7GroBZQwAOSoWi+isCYN11oTB3lHQzVimBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4yKmMzO9vM5pnZQjO7Jc32/czsJTObYWazzOzcro8KAEBh2mMZm1lU0l2SzpE0TtLlZjZup93+TdKjIYSjJE2UdHdXBwUAoFBlcmR8rKSFIYTFIYSEpEckXbjTPkFSn7bbfSWt6rqIAAAUtkzKeLik5e3ur2gba+/bkq40sxWSpki6Kd0vMrMbzKzGzGpqa2v3IS4AAIWnqy7gulzS/SGEEZLOlfSgmXX43SGEe0MI1SGE6srKyi56aAAA8lsmZbxS0sh290e0jbV3raRHJSmE8LqkEkkVXREQAIBCl0kZvylprJmNNrNitV6gNXmnfZZJOl2SzOwQtZYx69AAAGRgj2UcQmiRNEnS85LmqvWq6dlmdruZXdC225clXW9mMyU9LOlTIYTQXaEBACgksUx2CiFMUeuFWe3Hbm13e46kE7s2GgAAPQOfwAUAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADij
jAEAcEYZAwDgjDIGULA2bEpo9dpGhRC8owC7FfMOAABdbc26Rt32n3M0f3GdImbq369It375EB0+rq93NCAtjowBFJRkMmjS19/W3AXb1Nwc1JRIac26Jn35tlmq3dDkHQ9IizIGUFDemrVJW7e1KJXacbwlGfTMH1f7hAL2gDIGUFBqNyTSniNubg5atabRIRGwZ5QxgIJyyIG9OxwVS1JpSURHje+X/UBABihjAAVl9H7l+ugxAxSPf/jyVhQzVQyI6/STKx2TAbvG1dQACs63vzpOT0xZqf+bulpNiZQ+9tEKXXXZKMXjUe9oQFqUMYCCE42aLj1/hC49f4R3FCAjLFMDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAICCldi4Weuef1Vb3pqd01+lyfuMAQAFacG//1wLf/BzRYqLFJJJlYwcpuOm3KfSkUO9o3XAkTEAoOCsm/qKFv3HPUo1Nqlla52S2xtUv2CJaj75We9oaVHGAICCs+TOB5Ssb9hhLCRT2r5wqermLXZKtWuUMQCg4CQ2bk47brGYmjdtzXKaPaOMAQAFZ8iFZyhSEu8wHlJBfY48xCHR7lHGAICCU/X5q1QyfLAipSWtA5GIomUlOuxn31I0TUl742pqAEDBKerbWye/+aSW3feY1k15WSVDB6lq0lXqd8zh3tHSMq/3XVVXV4eamhqXxwYAINvMbHoIoTrdNpapAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMBZzDsAACB/hVRKG16aprr3FqvXuDEaeOpxsgjHeXuLMgYA7JPExs2adtqVqn9/lUJLiywWVdnokTrhz79TUb8+3vHyCn++AAD2yeybv6e6BUuVrNuuVGOTknX1qpu3WHO+/APvaHmHMgYA7LUQgtb84XmFRPOO44lmrXpsilOq/EUZI2MrVjVo3sJtam5OeUcBkANCMpl+vIXXiL3FOWPs0draRt3yvdlatqJe0ajJTPrXSQfqtJMHeUcD4MTMVHnWyap9/lWFZLvyjUY06OyT/YLlKY6MsVshBN38rVlatLROTYmU6huS2l6f1Pd/Nk8LltR5xwPg6NA7b1XRwP6KlpdJkqLlpYpXDtChP7vVOVn+4cgYuzV3wTbVbkgotdOqU3NzSk88s1L/etNBPsEAuCsbNVwfm/eCVj3yrLa+O099xh+sYRPPU6ytnJE5yhi7tXFzQuneMphKSWvXN2U/EICcEutVrv2uu8w7Rt5jmRq7deiBfZRIc8FWvDiiEz4ywCERABQeyhi71b9fsSZ+coRKSj58qhQXR1QxsFjnnTnUMRkAFA6WqbFHN1w1WgeP7aPHJ6/Qtu0tOvWECl16wQiVlUa9owFAQaCMsUdmplNPqNCpJ1R4RwGAgsQyNQAAzjIqYzM728zmmdlCM7tlF/tcZmZzzGy2mf2+a2MCwI5attfr/V8+ohlXfknzbvupGpav9o4E7LM9LlObWVTSXZLOlLRC0ptmNjmEMKfdPmMlfV3SiSGETWbWoz6aKdnQqA2v/k0KQQNPPU7R0hLvSEBBS2zcrNeOv1iJdRuU3N4gKy7Wkp/dr2Of+ZUGnFTtHQ/Ya5mcMz5W0sIQwmJJMrNHJF0oaU67fa6XdFcIYZMkhRDWdXXQXLXuuVf01hU3y8wktX6359G//6kGnXOqczKgcC34/t1qXLn2gy8pCImEkgnp7U9/TR+b/+IH/x6BfJHJMvVwScvb3V/RNtbegZIONLO/mtk0Mzs73S8ysxvMrMbMampra/ctcQ5pWrdBb/3jF5Tctl0tW+vUsrVOybp6TZ/4BTWt2+AdDyhYa556ocO3BUlS09oNamS5Gnmoqy7gikkaK2mCpMsl/dLM+u28Uwjh3hBCdQihurKysose2s/qx59TCGk2
hKDVj03Neh6gp9jlqaBkShFOEyEPZVLGKyWNbHd/RNtYeyskTQ4hNIcQlkiar9ZyLmgt2+qUau7413kq0azmrdscEgE9w6jPXqFIWemOg9Go+h4zXvFKPhkO+SeTMn5T0lgzG21mxZImSpq80z5PqfWoWGZWodZl68VdmDMnVZ55kqLFRR3Go/G4Bp3FV4gB3aXqxis0+BMfU6S0RNFeZYr2KlfZ6BE6+qGfeEcD9skeL+AKIbSY2SRJz0uKSvp1CGG2md0uqSaEMLlt21lmNkdSUtJXQwgFf9K079GHauhl52n1Y1OV3F4vqfUrxIZeco76fuQw53RA4bJoVEc/9BPVzVuszTXvqHTkUA04qVqW7ltNgDxgIe1Jz+5XXV0dampqXB67K4UQtG7qK1rx2yclSSOu+qQGnTuBqzkBADsws+khhLTvvePjMDvJzDT43AkafO4E7ygAgDzFmg4AAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADjjs6nRo4RUSuv/9Fdtmva2SoYN1tDLzlVRn17esQD0cJQxeoxkQ6OmnXWNtr07X8m6ekXKSjX3lv/U8X96UH2PPMQ7HoAejGVq9BhL7nxAW2e+p2Rd63dPp+ob1LJlm2Zc8S/y+ipRAJAoY/QgKx58UqmGxg7jDSvWqGHpCodEANCKMkbPYZZ+PAQpwj8FAH54BUKPMfKaixUpLekwXlo1QmWjhjskAoBWlDF6jKqbrla/Yw5XtLxMikUV7VWmov599ZGHf+odDUAPx9XU6DGi8WId/6ffauNf3mx7a9MgDfmHsxQrL/OOBqCHo4zRo5iZBp5yrAaecqx3FAD4AMvUAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDJG3kts3Kx3/+W7+tPIk/Ti/hO04Pt3KdmU8I4FABmLeQcAOiPZ2KS/fvRSNSxfrZBoliQt/OE92viXGh333G+c0wFAZjgyRl5b/fhUNa1Z/0ERS1KqoUmbps3Q5jdnOSYDgMxRxshrm16foeT2+g7jIRW0Zfq7DokAYO9Rxshr5QeMUqQ03mHcYlGVVg13SAQAe48yRl4bcdUnFSkq2mHMolEVD+inyjNPckoF7LtUIqFlv35M0z7+KdVc/Dmte/5V70jIAi7gQl4rrhig41/8nWZ++muqm79EktT/+CN15AP/JYtGndMBeyfV0qJpZ16jrTPnKrm9QZK0/sX/p6rPX6mDv/8V53ToTpQx8l7fIw/RKTOeVlPtRkWKYirq18c7ErBP1j71grbOfO+DIpak5PYGLbnztxp145UqHTHEMR26E8vUKBjxygEUMfLa2mdeSntBosWi2vDyNIdEyBbKGAByRHHlACnW8fSKRSIqGtDPIRGyhTIGgByx37WXdrggUZKsqEiVZ57okAjZQhkDQI7odfAYHX7v9xQtL1WsTy9Fe5crPnSQjn/+/rQljcLBBVwAkEOGTzxfQy44Q5umva1oaYn6HXeELMJxU6GjjAEgx0TLSlVx2gneMZBF/LkFAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgLOMytjMzjazeWa20Mxu2c1+F5tZMLPqrosIAEBh22MZm1lU0l2SzpE0TtLlZjYuzX69JX1R0htdHRIAgEKWyZHxsZIWhhAWhxASkh6RdGGa/b4r6YeSGrswHwAABS+TMh4uaXm7+yvaxj5gZkdLGhlCeLYLswEA0CN0+gIuM4tI+rGkL2ew7w1m
VmNmNbW1tZ19aAAACkImZbxS0sh290e0jf1db0mHSXrZzJZKOl7S5HQXcYUQ7g0hVIcQqisrK/c9NQAABSSTMn5T0lgzG21mxZImSpr8940hhC0hhIoQQlUIoUrSNEkXhBBquiUxAAAFZo9lHEJokTRJ0vOS5kp6NIQw28xuN7MLujsgAACFLpbJTiGEKZKm7DR26y72ndD5WAAA9Bx8AhcAAM4oYwAAnFHGAAA4o4wBAHBGGQMA4IwyBgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgLOMvs84XzSuqdWaJ/+oVFOzBp83QeVjq7wjAQCwRwVTxqsefVYzr/u6JJNSKc371o815ivX6cDbvuAdDQCA3SqIZerExs2aee3XlWpoUqqhUammhFKNTVr04/u0Zfq73vEAANitgijjdVNelsWiHcZTjQmtfPhph0QAAGSuIMpYqZB+PASFVCq7WQAA2EsFUcaV55yq0JLsMB4tLdGwy85zSAQAQOYKoozjlQN02J23KlIalxUVSZGIImUlGnndZep//JHe8QAA2K2CuZp65Kcv0cAJx2n141OVbExo8Pmnq++Rh3jHyln1DUk9NXWVXp22Xv36FOmS84er+oj+3rEAoEeyEHZxvrWbVVdXh5qaGpfH7unqG5K67ubpWlvbpKZE6zn1knhEn7miSldcNNI5HQAUJjObHkKoTretIJapsXee+ePqHYpYkhqbUvrVQ0u1ra7FMRkA9EyUcQ/02t827FDEf1cUM82dv9UhEQD0bJRxDzSgf7HMOo6nUkF9+hRlPxAA9HCUcQ90ySeGKV684//6SESqGBDXQWN6OaUCgJ6LMu6BDju4ryZdN0Yl8YjKy6IqiUc0cliZfnz74bJ0h8wAgG5VMG9twt755NnD9PEJg/Xegm3q3SumMVXlFDEgKdGc0sZNCfXvV9xhBQnoLpRxD1ZaEtVR4/t5xwByQghBv310mX73+HL9/S2fl3xiuG64erQiEf5QRfeijAFA0lNTV+nBx5apsenDdxo8/sxKlZRE9amJoxyToSdgDQYAJD342PIdilhqff/9w09+eKQMdBfKGAAkbdqcSDte35BUMkkZo3tRxgAgaf+q8rTjQweXKBbjpRLdi2cYAEi66doxisd3fEmMxyP64vUHOCVCT0IZA4CkIw/rpzu/f4SOObKfBvQv1hGH9tUdt43XiccO9I6GHoCrqQGgzaEH9dFPvnuEdwz0QBwZAwDgjDIGAMAZZQzkkJBMKtnQ6B0DQJZRxkAOSDY26d1J39Zz/Y7Sc/2P0svjz9GGV//mHQtAllDGQA54+5qvavkDTyjV2CQlU9r+3mK9ef4N2jZ7gXc0AFlAGQPOGleu1bopL7cWcTuppiYtuuOXTqkAZBNlDDirX7JckXhxh/GQTGnbHI6MgZ6AMgaclR+0v1JNTR3GLRZTv2MOd0gEINsoY8BZvHKARlx9kaJlpTuMR0rjGvOV65xSAcgmyhjIAYfdeavG3jpJ8aGDFC0rVcUZJ+rEVx9R2eiR3tEAZIF5fU9ndXV1qKmpcXlsAACyzcymhxCq023jyBgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBlljC6zeUuz1qxrlNd3ZANAvop5B0D+27Apoe/cMVfvzN2iSMTUr2+RvnnzwTp6fD/vaACQFzgyRqeEEPTFb87UzNmb1dwS1JRIaW1tk772nXe0ak2DdzwAyAuUMTpl9rxtWlvbqGRqx/GWZNATU1b5hAKAPEMZo1PW1jbKzDqMt7QErVzNkTEAZIIyRqccMra3WpIdL9gqiUc4Zwx0sYYVazTj6q/oj5XH6E/7naz53/0fpRIJ71joApQxOmXYkFKddlKlSuIfPpViMVOf3kU674whjsmAwtK8aYteO+4i
rXp0ipo3b1XT6nVadMcv9dblN3tHQxfgamp02je+eJDGHdhbf3h2lRoakjrl+IG6ZuIolZXx9AK6yrLfPK6WbdulZPKDsVRDo2pf+Ivq3lukXgePcUyHzuLVEp0WiZguOm+4LjpvuHcUoGBten2GUg2NHcYtFtPWWfMo4zzHMjUA5IHe4w5QJF7ccUMqqGz0iOwHQpeijAEgD+x3/URZ0Y6LmVZUpPKDRqtv9XinVOgqlDEA5IHSEUN0/Au/Ve/xB8liMVlxkQadN0HHTf112rcXIr9wzhgA8kS/6vE65a3Jat5ap0hxkaIlce9I6CKUMQDkmaI+vbwjoIuxTA0AgDPKGAAAZxmVsZmdbWbzzGyhmd2SZvuXzGyOmc0ysxfNbFTXRwUAoDDtsYzNLCrpLknnSBon6XIzG7fTbjMkVYcQDpf0uKT/7OqgAAAUqkyOjI+VtDCEsDiEkJD0iKQL2+8QQngphFDfdneaJN6BDgBAhjIp4+GSlre7v6JtbFeulTS1M6EAAOhJuvStTWZ2paRqSafuYvsNkm6QpP32268rHxoAgLyVyZHxSkkj290f0Ta2AzM7Q9I3JV0QQmhK94tCCPeGEKpDCNWVlZX7khcAgIKTSRm/KWmsmY02s2JJEyVNbr+DmR0l6R61FvG6ro8JAEDh2mMZhxBaJE2S9LykuZIeDSHMNrPbzeyCtt3ukNRL0mNm9raZTd7FrwMAADvJ6JxxCGGKpCk7jd3a7vYZXZwLAIAeg0/gAgDAGWUMAIAzyhgAAGeUMQAAzihjAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOMvoKxQBINdsq2vRz+9fpBf/UqsQpFNOqNCkz4xRv75F3tGAvUYZA8g7yWTQ52+ZoWUrG9TSEiRJL7yyTrPmbNFDdx+joiIW/ZBfeMYCyDt/m7FRq9c1fVDEUmtBb9rcrFenrXdMBuwbyhhA3ln8/nYlEskO4w2NSS1aut0hEdA5lDGAvDNiWJnixdEO46UlEY0cXuqQCOgcyhhA3jnxmAHq3SumSLtXsEhEKi2J6rQTK/2CAfuIMgaQd2KxiO654ygdd/QARSOmSET6yOH9de+PjlY83vGIGcjEytUNeuOtjVqzrjHrj83V1ADyUsXAuO64bbxakkEKQbEYxxbYN01NSf3bf8zR9FmbVVxkSiRSOvmECn3r5oOz9rzi2Qsgr8WiRhGjU/77vkWaPnOTEomU6rYnlWgOem3aBj3w6LKsZeAZDADosVKpoKkvrlWiOeww3pRI6YlnV2UtB2UMAOixUqmgRHMq7bb6hpas5aCMAQA9ViwW0dj9e3UYN5OOOqxf1nJQxgCAHu2rnxur0pKIYm0X4hfFTGWlUX3h+jFZy8DV1ACAHu2QA/vo/v+u1qP/t0KLlmzXwQf21mUXjNCginjWMlDGAIAeb/iQUt38z2PdHp9lagAAnFHGAAA4o4wBAHBGGSPrQiqlxT/9jV7cf4Ker6hWzcWf1/YFS71jAYAbLuBC1r1703e08ndPKVnf+mHsa59+URteeUOnvP2MSkcMcU4HANnHkTGyqmnteq144IkPiliSFIKSDY1acuf9brkAwBNljKzaNmeBIiUd37sXEs3a9PoMh0QA4I8yRlaVVY1QqinRYdyiUfU6OHufdgMAuYQyRlaVjR6pgROO63B0HIkXaf+bP+OUCgB8UcbIuqMf+ZmGXnqOIvFiWVFMZWNHqfr/7lHvcQd4RwMAFxZC2PNe3aC6ujrU1NS4PDZyQyqRULKhSUV9e3tHAYBuZ2bTQwjV6bbx1ia4iRQXK1Jc7B0DANyxTA0AgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwBllDACAM8oYAABnlDEAAM4oYwAAnFHGAAA4o4wBAHDGVyh2o/eX1+vBx5dpweI6jakq15WX7Kf9R5V7xwIA5BjKuJvMnb9VN31j
phLNKaVS0pJl2/Xq6+v1k+8ervGH9PWOBwDIISxTd5Of3rtQjU2tRSxJqZTU2JTST36x0DcYACDncGTcTd5bUJd2fMGSOoUQZGZZTgSk95dp6/Wrh5ZqzbpG7T+qXJ+9Zn8dcSirN0A2cWTcTcrLo2nHS0uiFDFyxtQX1+jb/zVXi5Zu1/b6pN6Zu1VfunWWZs7e7B0N2KXNNe9o/u3/rUU/vk8Ny1d7x+kSlHE3ueT84YrHd5zeeHFEF503zCkRsKMQgu6+f7GamlI7jDclUrr7N4udUgG7FkLQrBu/pWmnX6UF379b87/1U7186Me16tFnvaN1GmXcTa65bJQ+PmGwiotM5WVRFReZTju5Utf9U5V3NECSVN+Q1LZtLWm3LX6/PstpsCchBG145Q0tvet3qv3jXxRSqT3/RwVmw59f16qHn1ayvkFKpZRKJJRqaNLM676h5i3bvON1CueMu0k0avrapAN1w1WjtXJNg4YNKVH/vsXesYAPlMSjKi6OqKUh2WHboIq4QyLsSsu2Or1+xtXaPn+JQktSFoupZGilTnjlYcUrB3jHy5qVDz+t5PaGDuORWFTrX3hNQy85xyFV1+DIuJv161ukQw/qQxEj50Sjpon/MEIlO51OKYlH9JkrRjmlQjpzv/5f2jZ7vpJ19Uo1NilZt131S1fonc/d6h0tqywaldJccxNkUiS/6yy/0wPolE/94yhd/g8jVVYaVVGRqW/vmG66boxOP3mQdzS0s+rhpxWamncYC80tWvfMSwrJjisbhWr4lRcqWlbScUMyqcqzTsp+oC7EMjXQg0Uipmv/qUrXTByl+voW9SqPKRLhav9cs6vCDalU61sls5zHy8CTj9F+10/U+/c8rJBMymJRKUhH/e5HivXK7083pIwBKBY19eld5B0DuzDoE6dp9R+ek1ralXIkooGnHKNIrGe9jI+74xaN/Mylqn3uFUXLSjXkoo8XxHnznvV/EQDy0Lg7btGm16arefNWJbfXK1peqkhJicb/4nve0Vz0PmSMeh8yxjtGl6KMASDHlQwdpAlzntOqx6Zqy4zZ6j3uAA2//HzFevfyjoYuYiEElweurq4ONTU1Lo8NAEC2mdn0EEJ1um1cTQ0AgDPKGAAAZ/ZZzH0AAAaWSURBVJQxAADOKGMAAJxRxgAAOOOtTeiULTPmaO3TL8qKYhp26bkqP4DPNAaAvUUZY5/N+doP9f49v1eqMSGLRrTw33+ucXd8XaP++XLvaACQV1imxj7Z/OYsLbvn90rVN0qplEJzi1INTZrzlX9X45pa73gAkFcoY+yTVX94TsmGpg7jFolo3bMvZz8QAOQxyhj7JBKNSOm+3ccki/K0AoC9kdGrppmdbWbzzGyhmd2SZnvczP63bfsbZlbV1UGRW4Zd9glFios7jIdkSoM/8TGHRACQv/ZYxmYWlXSXpHMkjZN0uZmN22m3ayVtCiEcIOknkn7Y1UGRW/occbDGfuNGRUrirT9lpYqUxnX4r36g4or8/zozAMimTK6mPlbSwhDCYkkys0ckXShpTrt9LpT07bbbj0v6HzOz4PUtFMiKA275rIb943la+8xLihQXacgnz1R8cIV3LADIO5mU8XBJy9vdXyHpuF3tE0JoMbMtkgZKWt8VIZG7ykaP1OibrvaOAQB5LatX2pjZDWZWY2Y1tbW8/QUAACmzMl4paWS7+yPaxtLuY2YxSX0lbdj5F4UQ7g0hVIcQqisrK/ctMQAABSaTMn5T0lgzG21mxZImSpq80z6TJV3TdvsSSX/mfDEAAJnZ4znjtnPAkyQ9Lykq6dchhNlmdrukmhDCZEn3SXrQzBZK2qjWwgYAABnI6LOpQwhTJE3ZaezWdrcbJV3atdEAAOgZ+KgkAACcUcYAADijjAEAcEYZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOKOMAQBwRhkDAOCMMgYAwJl5fe2wmdVKer8bfnWFpPXd8Ht7
Euaw85jDzmMOO4857LyunMNRIYTKdBvcyri7mFlNCKHaO0c+Yw47jznsPOaw85jDzsvWHLJMDQCAM8oYAABnhVjG93oHKADMYecxh53HHHYec9h5WZnDgjtnDABAvinEI2MAAPJK3paxmZ1tZvPMbKGZ3ZJme9zM/rdt+xtmVpX9lLktgzn8kpnNMbNZZvaimY3yyJnL9jSH7fa72MyCmXFl604ymUMzu6ztuTjbzH6f7Yy5LoN/y/uZ2UtmNqPt3/O5HjlzlZn92szWmdm7u9huZnZn2/zOMrOjuzxECCHvfiRFJS2StL+kYkkzJY3baZ/PSfpF2+2Jkv7XO3cu/WQ4hx+TVNZ2+0bmcO/nsG2/3pJelTRNUrV37lz6yfB5OFbSDEn92+4P8s6dSz8ZzuG9km5suz1O0lLv3Ln0I+kUSUdLencX28+VNFWSSTpe0htdnSFfj4yPlbQwhLA4hJCQ9IikC3fa50JJD7TdflzS6WZmWcyY6/Y4hyGEl0II9W13p0kakeWMuS6T56EkfVfSDyU1ZjNcnshkDq+XdFcIYZMkhRDWZTljrstkDoOkPm23+0palcV8OS+E8KqkjbvZ5UJJvw2tpknqZ2ZDuzJDvpbxcEnL291f0TaWdp8QQoukLZIGZiVdfshkDtu7Vq1/GeJDe5zDtuWskSGEZ7MZLI9k8jw8UNKBZvZXM5tmZmdnLV1+yGQOvy3pSjNbIWmKpJuyE61g7O3r5V6LdeUvQ2EysyslVUs61TtLPjGziKQfS/qUc5R8F1PrUvUEta7OvGpm40MIm11T5ZfLJd0fQviRmZ0g6UEzOyyEkPIOhlb5emS8UtLIdvdHtI2l3cfMYmpdmtmQlXT5IZM5lJmdIembki4IITRlKVu+2NMc9pZ0mKSXzWypWs81TeYirh1k8jxcIWlyCKE5hLBE0ny1ljNaZTKH10p6VJJCCK9LKlHrZy4jMxm9XnZGvpbxm5LGmtloMytW6wVak3faZ7Kka9puXyLpz6HtTDwkZTCHZnaUpHvUWsScp+tot3MYQtgSQqgIIVSFEKrUet79ghBCjU/cnJTJv+Wn1HpULDOrUOuy9eJshsxxmczhMkmnS5KZHaLWMq7Nasr8NlnS1W1XVR8vaUsIYXVXPkBeLlOHEFrMbJKk59V6JeGvQwizzex2STUhhMmS7lPrUsxCtZ6Yn+iXOPdkOId3SOol6bG2a9+WhRAucAudYzKcQ+xGhnP4vKSzzGyOpKSkr4YQWOVqk+EcflnSL83sZrVezPUpDk4+ZGYPq/UPvoq28+q3SSqSpBDCL9R6nv1cSQsl1Uv6dJdn4P8HAAC+8nWZGgCAgkEZAwDgjDIGAMAZZQwAgDPKGAAAZ5QxAADOKGMAAJxRxgAAOPv/g+NdOOliW/8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build a small random 2D binary dataset (seeded via random_seed for reproducibility).\n",
    "# Alternative generator that was also tried: dl.RandomKParameters(10, 25, radius=0.02)\n",
    "data_params = dl.EricParameters(25, 0.05)\n",
    "dataset = dl.Random2DBinaryDataset(data_params, batch_size=32, random_seed=420)\n",
    "dataset.plot_2d()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# And then make a simple ReLU network to classify it.\n",
    "# Seed torch first so the random initialization (and all training / Lipschitz\n",
    "# results downstream) are reproducible on Restart & Run All.\n",
    "torch.manual_seed(420)\n",
    "simple_net = ReLUNet(layer_sizes=(2, 8, 16, 16, 16, 2), bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on all of the data: split_train_val(1.0) presumably puts everything in\n",
    "# trainset, so trainset also doubles as the evaluation set.\n",
    "trainset, valset = dataset.split_train_val(1.0)\n",
    "num_train_epochs, eval_every = 1000, 200\n",
    "train_params = nn_train.TrainParameters(trainset, trainset, num_train_epochs,\n",
    "                                         test_after_epoch=eval_every)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Train simple_net with train_params. Later cells use simple_net directly, so the\n",
    "# network is presumably updated in place. %%time reports the cost of this long run.\n",
    "nn_train.training_loop(simple_net, train_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the training points on the learned decision regions over [0, 1] x [0, 1].\n",
    "decision_ax = simple_net.display_decision_bounds([0, 1], [0, 1], 100)\n",
    "dataset.plot_2d(ax=decision_ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### PREACTIVATION CHECK\n",
    "# Sanity check: the naive interval-arithmetic bounds computed over small_hbox\n",
    "# should contain the actual pre-activations of any point inside that box.\n",
    "PREACT_CENTER = np.array([0.5, 0.5])\n",
    "PREACT_RADIUS = 0.2\n",
    "small_hbox = Hyperbox.build_linf_ball(PREACT_CENTER, PREACT_RADIUS)\n",
    "x = torch.tensor(PREACT_CENTER, dtype=torch.float32)\n",
    "bounds = PreactivationBounds.naive_ia_from_hyperbox(simple_net, small_hbox)\n",
    "\n",
    "def check_preacts_within_bounds(x, net=None, preact_bounds=None):\n",
    "    \"\"\"Report, per layer, how many pre-activations of `x` lie inside the bounds.\n",
    "\n",
    "    Prints one row per layer: layer index, number of units, #units with\n",
    "    lower <= preact, #units with upper >= preact. Both counts should equal the\n",
    "    number of units when `x` lies inside the box the bounds were built from.\n",
    "\n",
    "    `net` and `preact_bounds` default to the notebook's `simple_net` and `bounds`.\n",
    "    Returns True iff every pre-activation of every layer is within its bounds.\n",
    "    \"\"\"\n",
    "    net = simple_net if net is None else net\n",
    "    preact_bounds = bounds if preact_bounds is None else preact_bounds\n",
    "    preacts = net(x, return_preacts=True)\n",
    "    all_inside = True\n",
    "    for i, preact in enumerate(preacts):\n",
    "        preact = preact.data.numpy().reshape(-1)\n",
    "        ith_bound = preact_bounds.get_ith_layer_bounds(i)\n",
    "        lows, highs = ith_bound[0], ith_bound[1]\n",
    "        n_above_low = sum(lows <= preact)\n",
    "        n_below_high = sum(highs >= preact)\n",
    "        print(i, len(lows), n_above_low, n_below_high)\n",
    "        all_inside = all_inside and n_above_low == n_below_high == len(lows)\n",
    "    return all_inside\n",
    "\n",
    "check_preacts_within_bounds(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To compute the local Lipschitz constant of a DNN, we need:\n",
    "# 1) A region to compute over\n",
    "# 2) A neural network to evaluate\n",
    "# 3) Function parameters {LP_norm, preactivation_techniques, etc}\n",
    "\n",
    "# LipParameters holds items (1, 3)\n",
    "# LipProblem holds items (1, 2, 3)\n",
    "# LipMIPResult holds all info from a completed Lipschitz computation\n",
    "\n",
    "\n",
    "# EvaluationParameters holds parameters for how to evaluate Lipschitz constants;\n",
    "# we'll set these up later.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the maximal local Lipschitz constant of c^T f over an l_inf ball of\n",
    "# radius 0.5 around (0.5, 0.5), i.e. the whole unit square [0, 1]^2.\n",
    "LIP_CENTER = np.array([0.5, 0.5])\n",
    "LIP_RADIUS = 0.5\n",
    "hbox = Hyperbox.build_linf_ball(LIP_CENTER, LIP_RADIUS)\n",
    "# c_vector selects the output direction (presumably logit 0 minus logit 1).\n",
    "c_vector = torch.tensor([1.0, -1.0])\n",
    "\n",
    "params = lm.LipParameters(hbox, c_vector, verbose=False)\n",
    "problem = lm.LipProblem(simple_net, params)\n",
    "result = problem.compute_max_lipschitz()\n",
    "print(result.value, result.best_x)\n",
    "# Sanity check: the l1 gradient norm (dual of l_inf) at the maximizer should match\n",
    "# result.value. Use result.best_x rather than a point copied from an old run,\n",
    "# which goes stale as soon as the network is retrained.\n",
    "print(simple_net.get_grad_at_point(result.best_x, c_vector).norm(p=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline methods for comparison with result.value, starting with a naive upper bound.\n",
    "naive_ub = NaiveUB(simple_net, c_vector, 'linf')\n",
    "naive_ub.compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random-sampling lower bound over hbox (presumably should not exceed result.value).\n",
    "random_lb = RandomLB(simple_net, c_vector, hbox, 'linf')\n",
    "random_lb.compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest l1 gradient norm found by sampling (presumably random points in hbox).\n",
    "n_grad_samples = 1000\n",
    "simple_net.random_max_grad(hbox, c_vector, n_grad_samples, pnorm=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): hardcoded probe point, presumably copied from an earlier run's output;\n",
    "# it will not track the network if it is retrained.\n",
    "probe_x = np.array([0.40050755, 0.6])\n",
    "simple_net.get_grad_at_point(probe_x, c_vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient as reported by result.squire at a fixed point.\n",
    "# NOTE(review): hardcoded point, presumably copied from an earlier run; goes stale on retrain.\n",
    "squire_probe_x = np.array([0.5005, 0.4886])\n",
    "result.squire.get_grad_at_point(squire_probe_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Simple idea: track the Lipschitz constant as we train\n",
    "# Train a fresh network in chunks of EPOCHS_PER_CHUNK epochs and compute the maximal\n",
    "# Lipschitz constant over hbox after each chunk. Dedicated names are used so the\n",
    "# earlier simple_net / train_params / problem are NOT clobbered: previously\n",
    "# `new_net = simple_net = ...` silently replaced the trained simple_net, so any\n",
    "# earlier cell re-run afterwards used this differently-shaped network instead.\n",
    "TOTAL_EPOCHS = 1000\n",
    "EPOCHS_PER_CHUNK = 50\n",
    "new_net = ReLUNet(layer_sizes=(2, 20, 20, 20, 2), bias=True)\n",
    "chunk_train_params = nn_train.TrainParameters(trainset, trainset, EPOCHS_PER_CHUNK,\n",
    "                                              test_after_epoch=EPOCHS_PER_CHUNK)\n",
    "done_epochs = 0\n",
    "lip_results = []\n",
    "lip_params = lm.LipParameters(hbox, c_vector, verbose=False)\n",
    "while done_epochs < TOTAL_EPOCHS:\n",
    "    # epoch_start_no keeps epoch numbering continuous across chunks.\n",
    "    nn_train.training_loop(new_net, chunk_train_params, epoch_start_no=done_epochs)\n",
    "    lip_results.append(lm.LipProblem(new_net, lip_params).compute_max_lipschitz())\n",
    "    done_epochs += chunk_train_params.num_epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lipschitz constant after each training chunk, in training order.\n",
    "[lip_result.value for lip_result in lip_results]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
