{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple debugging of lipMIP\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "import lipMIP as lm\n",
    "import gurobipy as gb\n",
    "from relu_nets import ReLUNet\n",
    "from hyperbox import Hyperbox\n",
    "from pre_activation_bounds import PreactivationBounds\n",
    "import numpy as np\n",
    "from pprint import pprint\n",
    "import torch\n",
    "import utilities as utils\n",
    "import random "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed torch seed so the network below is reproducible\n",
    "# (presumably ReLUNet initialises its weights from torch's global RNG).\n",
    "torch.random.manual_seed(961)\n",
    "# Presumably layer sizes: 2 inputs, hidden width 4, 3 outputs.\n",
    "network = ReLUNet([2, 4, 3])\n",
    "# Objective direction: c_vector . network(x) = output[0] - output[1].\n",
    "c_vector = np.array([1.0, -1.0, 0.0])\n",
    "# l_inf ball centred at the origin of R^2 (radius argument 1.0).\n",
    "hbox = Hyperbox.build_linf_ball(np.zeros(2), 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Academic license - for non-commercial use only\n",
      "[[-0.21280498 -0.0304741 ]\n",
      " [-0.69830605 -0.17536431]]\n",
      "0 NEG\n",
      "1 NEG\n",
      "-0.8366125058027816\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'lipschitz': 0.8366125058027816,\n",
       " 'squire': <lipMIP.GurobiSquire at 0x7f2febe172e8>,\n",
       " 'model': <gurobi.Model MIP instance Unnamed: 37 constrs, 28 vars, Parameter changes: LogFile=gurobi.log, CSIdleTimeout=1800, OutputFlag=0>,\n",
       " 'runtime': 0.008215665817260742,\n",
       " 'preacts': <pre_activation_bounds.PreactivationBounds at 0x7f2febe17390>,\n",
       " 'best_x': array([-1.        ,  0.78843764])}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Solve the lipMIP problem once; the squire and Gurobi model are kept for the\n",
    "# inspection cells below.\n",
    "output = lm.compute_max_lipschitz(network, hbox, 'l_inf', c_vector, verbose=False)\n",
    "squire, model = output['squire'], output['model']\n",
    "print(-output['lipschitz'])\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_max(network, domain, c_vector, num_random=100):\n",
    "    \"\"\"Sampling-based estimate of the largest l1 gradient norm over `domain`.\n",
    "\n",
    "    Draws `num_random` points via `domain.random_point`, asks `network` for the\n",
    "    gradient at each (presumably of c_vector . network(x)), and keeps the point\n",
    "    with the largest l1 norm. A max over samples can only under-estimate the\n",
    "    true maximum, so this is a lower bound to compare against lipMIP.\n",
    "\n",
    "    Returns a dict with keys 'norm', 'point', 'grad'; when num_random <= 0 it is\n",
    "    {'norm': -1, 'point': None, 'grad': None}.\n",
    "    \"\"\"\n",
    "    # Loop-invariant: convert c_vector to the network's dtype once, not per sample.\n",
    "    c_tensor = torch.tensor(c_vector).type(network.dtype)\n",
    "    max_norm = -1\n",
    "    max_point = None\n",
    "    max_grad = None\n",
    "    for _ in range(num_random):\n",
    "        random_input = domain.random_point(tensor_or_np='tensor')\n",
    "        grad = network.get_grad_at_point(random_input, c_tensor)\n",
    "        grad_norm = grad.norm(p=1)\n",
    "        if grad_norm > max_norm:\n",
    "            max_norm = grad_norm\n",
    "            max_point = random_input\n",
    "            max_grad = grad\n",
    "    return {'norm': max_norm, 'point': max_point, 'grad': max_grad}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def verify(rand_seed=None, num_random=1000):\n",
    "    \"\"\"Compare lipMIP against random sampling on a freshly seeded ReLUNet([2, 4, 3]).\n",
    "\n",
    "    rand_seed: torch seed used before building the network; drawn with\n",
    "        random.randint(1, 1000) when None. The seed is printed either way, so the\n",
    "        network from a suspicious trial can be rebuilt with verify(that_seed).\n",
    "    num_random: number of sampled points passed to collect_max.\n",
    "\n",
    "    Returns |lipMIP value - sampled max l1 gradient norm|.\n",
    "    \"\"\"\n",
    "    if rand_seed is None:\n",
    "        rand_seed = random.randint(1, 1000)\n",
    "    print(\"RAND SEED\", rand_seed)\n",
    "    torch.random.manual_seed(rand_seed)\n",
    "    network = ReLUNet([2, 4, 3])\n",
    "    hbox = Hyperbox.build_linf_ball(np.zeros(2), 1.0)\n",
    "    c_vector = np.array([1.0, -1.0, 0.0])\n",
    "    output = lm.compute_max_lipschitz(network, hbox, 'l_inf', c_vector, verbose=False)\n",
    "    random_max = collect_max(network, hbox, c_vector, num_random=num_random)\n",
    "\n",
    "    mip_out = output['lipschitz']\n",
    "    rand_out = random_max['norm'].item()\n",
    "    return abs(mip_out - rand_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RAND SEED 974\n",
      "[[-0.51631914  0.20225757]\n",
      " [-0.68072516  0.03955125]]\n",
      "4.535877984324088e-08\n",
      "RAND SEED 739\n",
      "[[-0.12941976  0.09382055]\n",
      " [-0.08568464  0.34063642]]\n",
      "2.665025389259057e-08\n",
      "RAND SEED 111\n",
      "[[-0.36482294  0.21522331]\n",
      " [ 0.          0.64082174]]\n",
      "7.611128349793717e-08\n",
      "RAND SEED 86\n",
      "[[-0.84793314  0.36246632]\n",
      " [-0.47265927  0.1781985 ]]\n",
      "6.811981068111095e-08\n",
      "RAND SEED 783\n",
      "[[-0.34591384  0.01521935]\n",
      " [-0.35070164  0.09579232]]\n",
      "4.30519390137718e-08\n",
      "RAND SEED 962\n",
      "[[-0.09446075  0.13220752]\n",
      " [-0.60529137  0.17996593]]\n",
      "3.417937066707566e-08\n",
      "RAND SEED 517\n",
      "[[ 0.          0.01133063]\n",
      " [-0.00584812  0.03300117]]\n",
      "9.946725290643776e-09\n",
      "RAND SEED 836\n",
      "[[-0.02424333  0.19255953]\n",
      " [-0.1429609   0.06793079]]\n",
      "2.7356319987070776e-08\n",
      "RAND SEED 726\n",
      "[[-0.13139805  0.08190415]\n",
      " [ 0.00108906  0.17223318]]\n",
      "4.8295972776557505e-09\n",
      "RAND SEED 320\n",
      "[[-0.15982282  0.02528745]\n",
      " [ 0.13349356  0.21990534]]\n",
      "3.529110786937295e-08\n",
      "RAND SEED 170\n",
      "[[-0.55479253  0.19852609]\n",
      " [-0.54687027  0.28205348]]\n",
      "1.8905128351676126e-08\n",
      "RAND SEED 645\n",
      "[[-0.00314894  0.89091766]\n",
      " [-0.11313928  0.38251523]]\n",
      "6.328788204612579e-08\n",
      "RAND SEED 753\n",
      "[[-0.35683279  0.01459282]\n",
      " [-0.75386233 -0.02757695]]\n",
      "1 NEG\n",
      "3.63446250872812e-09\n",
      "RAND SEED 976\n",
      "[[-0.2719945   0.42249061]\n",
      " [ 0.          0.70207068]]\n",
      "2.3777759095011675e-08\n",
      "RAND SEED 375\n",
      "[[-0.11588712  0.41839366]\n",
      " [ 0.          0.35659714]]\n",
      "7.382613564921314e-08\n",
      "RAND SEED 531\n",
      "[[ 0.08012921  0.55087785]\n",
      " [-0.32276135  0.29398233]]\n",
      "0.03365065514546395\n",
      "RAND SEED 844\n",
      "[[-0.11083421  0.23039602]\n",
      " [-0.14211252  0.29019148]]\n",
      "0.2388079933597479\n",
      "RAND SEED 917\n",
      "[[0.04569268 0.05755624]\n",
      " [0.03335927 0.05013947]]\n",
      "5.544951164893064e-09\n",
      "RAND SEED 44\n",
      "[[-0.24776696  0.07101337]\n",
      " [ 0.04118481  0.40100609]]\n",
      "2.262004006325924e-08\n",
      "RAND SEED 573\n",
      "[[-0.28964447 -0.07355761]\n",
      " [-0.2701642   0.00496857]]\n",
      "0 NEG\n",
      "1.620112310973809e-08\n",
      "RAND SEED 784\n",
      "[[-0.43874252  0.39998289]\n",
      " [-0.63326621  0.20807136]]\n",
      "1.96787122153097e-08\n",
      "RAND SEED 231\n",
      "[[-0.30770584  0.21839772]\n",
      " [-0.68563104  0.28626853]]\n",
      "1.697644247400376e-10\n",
      "RAND SEED 427\n",
      "[[-0.04312262  0.27447858]\n",
      " [-0.2625353  -0.01555751]]\n",
      "1 NEG\n",
      "1.8063444073845858e-09\n",
      "RAND SEED 390\n",
      "[[-0.24071637  0.17440178]\n",
      " [ 0.          0.13842677]]\n",
      "3.678604920676065e-09\n",
      "RAND SEED 338\n",
      "[[-0.22384944  0.44827853]\n",
      " [-0.08306806  0.2500135 ]]\n",
      "2.5592855723566288e-08\n",
      "RAND SEED 590\n",
      "[[-0.47572894  0.41519804]\n",
      " [-0.26474181  0.09573293]]\n",
      "3.789112357210911e-10\n",
      "RAND SEED 686\n",
      "[[-0.07142803  0.31466327]\n",
      " [-0.13210605  0.03634235]]\n",
      "1.873045196720824e-08\n",
      "RAND SEED 477\n",
      "[[-0.0973104   0.09286032]\n",
      " [-0.19224864  0.06215463]]\n",
      "2.871492743583559e-08\n",
      "RAND SEED 370\n",
      "[[-0.1623371   0.240541  ]\n",
      " [-0.25137549  0.17154112]]\n",
      "0.10886487930490152\n",
      "RAND SEED 461\n",
      "[[-0.23922242 -0.09325986]\n",
      " [-0.09173292  0.08433391]]\n",
      "0 NEG\n",
      "3.665256814722184e-08\n",
      "RAND SEED 87\n",
      "[[-0.07909834  0.95941631]\n",
      " [-0.48349839  0.69623888]]\n",
      "7.458053730147185e-09\n",
      "RAND SEED 256\n",
      "[[-0.44719857  0.25018132]\n",
      " [-0.04908237  0.12232483]]\n",
      "1.4215578159060271e-08\n",
      "RAND SEED 357\n",
      "[[-0.04233195  0.19962928]\n",
      " [-0.45352311  0.        ]]\n",
      "1 NEG\n",
      "1.720770814728212e-08\n",
      "RAND SEED 329\n",
      "[[-0.34851392  0.03291276]\n",
      " [-0.16127369  0.04233927]]\n",
      "1.9341280688145446e-08\n",
      "RAND SEED 889\n",
      "[[-0.49121751  0.08564969]\n",
      " [-0.40865892  0.33517385]]\n",
      "5.072200881528488e-09\n",
      "RAND SEED 651\n",
      "[[-0.0682706   0.77906432]\n",
      " [-0.12526701  0.3025867 ]]\n",
      "2.19694540337656e-08\n",
      "RAND SEED 407\n",
      "[[ 0.          0.36455636]\n",
      " [-0.24852915  0.19662883]]\n",
      "1.8793396228122816e-08\n",
      "RAND SEED 331\n",
      "[[-0.10811067  0.60234566]\n",
      " [-0.04300823  0.41954103]]\n",
      "4.112874507633535e-08\n",
      "RAND SEED 465\n",
      "[[-0.21616461  0.        ]\n",
      " [ 0.          0.3243395 ]]\n",
      "0 NEG\n",
      "7.573957461204373e-09\n",
      "RAND SEED 514\n",
      "[[-0.5907116   0.28755377]\n",
      " [-0.16666538  0.0720271 ]]\n",
      "1.7260772589011708e-08\n",
      "RAND SEED 510\n",
      "[[-0.37759999 -0.12850577]\n",
      " [-0.41850129  0.25751904]]\n",
      "0 NEG\n",
      "1.9209051016488843e-09\n",
      "RAND SEED 754\n",
      "[[-0.41176439  0.44168803]\n",
      " [-0.82621066  0.39271785]]\n",
      "4.837093015019889e-08\n",
      "RAND SEED 192\n",
      "[[-0.27616889 -0.20491788]\n",
      " [-0.12847695  0.05834327]]\n",
      "0 NEG\n",
      "2.6055835766758406e-08\n",
      "RAND SEED 629\n",
      "[[-0.52176149  0.16212618]\n",
      " [-0.46620583  0.04335937]]\n",
      "2.1780238279589526e-08\n",
      "RAND SEED 477\n",
      "[[-0.0973104   0.09286032]\n",
      " [-0.19224864  0.06215463]]\n",
      "2.871492743583559e-08\n",
      "RAND SEED 67\n",
      "[[-0.32715816  0.41195177]\n",
      " [-0.51447517  0.50318895]]\n",
      "2.114534514152666e-08\n",
      "RAND SEED 817\n",
      "[[-0.24863382  0.45981579]\n",
      " [-0.17107426  0.11890964]]\n",
      "3.7254277440723627e-10\n",
      "RAND SEED 777\n",
      "[[-0.07359239  0.50980115]\n",
      " [-0.07354924  0.01059266]]\n",
      "3.9755507552641234e-08\n",
      "RAND SEED 254\n",
      "[[-0.40652195  0.07677283]\n",
      " [-0.26971076  0.03346101]]\n",
      "2.001870513801407e-08\n",
      "RAND SEED 799\n",
      "[[-0.01817803  0.19618384]\n",
      " [-0.01523856  0.07863146]]\n",
      "1.548298034226292e-08\n",
      "RAND SEED 481\n",
      "[[-0.50938753  0.51044205]\n",
      " [-0.33039704  0.79396705]]\n",
      "1.951243266962166e-08\n",
      "RAND SEED 312\n",
      "[[-0.0799304   0.09349629]\n",
      " [-0.41144501 -0.01824493]]\n",
      "1 NEG\n",
      "1.1916285680424465e-08\n",
      "RAND SEED 258\n",
      "[[-0.11326182  0.22387481]\n",
      " [-0.27671267  0.0701145 ]]\n",
      "2.4080744620569305e-08\n",
      "RAND SEED 709\n",
      "[[-0.29944249  0.16257889]\n",
      " [-0.60527517  0.28827783]]\n",
      "4.245358953269829e-08\n",
      "RAND SEED 21\n",
      "[[-0.33495798  0.07205407]\n",
      " [-0.6592413  -0.27518863]]\n",
      "1 NEG\n",
      "7.925522460539014e-09\n",
      "RAND SEED 100\n",
      "[[-0.20105124  0.20511326]\n",
      " [-0.51466623  0.06981017]]\n",
      "6.159715604248106e-08\n",
      "RAND SEED 301\n",
      "[[-0.3101564   0.06144065]\n",
      " [-0.2331081   0.28947728]]\n",
      "3.1870241667952826e-08\n",
      "RAND SEED 13\n",
      "[[-0.43538397 -0.11325325]\n",
      " [-0.16013169  0.05292869]]\n",
      "0 NEG\n",
      "3.580996749130705e-08\n",
      "RAND SEED 167\n",
      "[[-0.21171533  0.31713512]\n",
      " [-0.21588575  0.12986856]]\n",
      "2.6697437316425265e-08\n",
      "RAND SEED 454\n",
      "[[-0.49785664  0.01997468]\n",
      " [-0.55788006  0.02136369]]\n",
      "2.3549015404356055e-08\n",
      "RAND SEED 257\n",
      "[[ 0.05006082  0.49086829]\n",
      " [-0.04863886  0.19557141]]\n",
      "1.5026992983280252e-09\n",
      "RAND SEED 355\n",
      "[[-0.27566911  0.13086725]\n",
      " [-0.10777574  0.16602198]]\n",
      "7.810360580506881e-09\n",
      "RAND SEED 212\n",
      "[[ 0.0387024   0.20677818]\n",
      " [-0.26878835 -0.18281643]]\n",
      "1 NEG\n",
      "3.324601749454814e-08\n",
      "RAND SEED 793\n",
      "[[-0.21235088  0.68942405]\n",
      " [-0.09949639  0.35068103]]\n",
      "1.6864659446724772e-08\n",
      "RAND SEED 489\n",
      "[[-0.09958453  0.04048539]\n",
      " [-0.18144844  0.01804198]]\n",
      "1.9077655122323733e-08\n",
      "RAND SEED 212\n",
      "[[ 0.0387024   0.20677818]\n",
      " [-0.26878835 -0.18281643]]\n",
      "1 NEG\n",
      "3.324601749454814e-08\n",
      "RAND SEED 287\n",
      "[[-0.44130445  0.07267966]\n",
      " [-0.04629985  0.40507503]]\n",
      "4.5917644908577415e-08\n",
      "RAND SEED 686\n",
      "[[-0.07142803  0.31466327]\n",
      " [-0.13210605  0.03634235]]\n",
      "1.873045196720824e-08\n",
      "RAND SEED 777\n",
      "[[-0.07359239  0.50980115]\n",
      " [-0.07354924  0.01059266]]\n",
      "3.9755507552641234e-08\n",
      "RAND SEED 834\n",
      "[[-0.0484315   0.10245144]\n",
      " [ 0.01724421  0.08407073]]\n",
      "8.633168013805204e-09\n",
      "RAND SEED 743\n",
      "[[-0.44585461  0.52842465]\n",
      " [-0.41030177  0.42811703]]\n",
      "3.3144786693029005e-08\n",
      "RAND SEED 292\n",
      "[[-0.1886921   0.39050048]\n",
      " [-0.11519061  0.18008571]]\n",
      "6.650106354477714e-09\n",
      "RAND SEED 329\n",
      "[[-0.34851392  0.03291276]\n",
      " [-0.16127369  0.04233927]]\n",
      "1.9341280688145446e-08\n",
      "RAND SEED 580\n",
      "[[0.06667443 0.16642698]\n",
      " [0.10421441 0.1818896 ]]\n",
      "3.500439760273366e-08\n",
      "RAND SEED 497\n",
      "[[-0.44214998  0.1305485 ]\n",
      " [-0.10381017  0.32005712]]\n",
      "6.751181058639588e-09\n",
      "RAND SEED 265\n",
      "[[-0.17409179  0.00434363]\n",
      " [-0.05267468  0.05947533]]\n",
      "2.852249825213704e-08\n",
      "RAND SEED 605\n",
      "[[-0.40818227  0.38123252]\n",
      " [-0.18865669  0.2125752 ]]\n",
      "1.4925225610973314e-08\n",
      "RAND SEED 898\n",
      "[[-0.09742669  0.57923423]\n",
      " [-0.06207279  0.13762513]]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-6-33e561d4c529>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m     \u001b[0mverify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-5-00b2c9acbbae>\u001b[0m in \u001b[0;36mverify\u001b[0;34m()\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0mc_vector\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1.0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1.0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m     \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_max_lipschitz\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetwork\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhbox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'l_inf'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m     \u001b[0mrandom_max\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcollect_max\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetwork\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhbox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_random\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m     \u001b[0mmip_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'lipschitz'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-4-35bb14372442>\u001b[0m in \u001b[0;36mcollect_max\u001b[0;34m(network, domain, c_vector, num_random)\u001b[0m\n\u001b[1;32m      7\u001b[0m         \u001b[0mgrad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnetwork\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_grad_at_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrandom_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc_vector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetwork\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m         \u001b[0mgrad_norm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0mgrad_norm\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mmax_norm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m             \u001b[0mmax_norm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgrad_norm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m             \u001b[0mmax_point\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom_input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Re-check lipMIP against random sampling on many freshly seeded networks.\n",
    "# Each trial solves one MIP plus 1000 gradient samples, so this is slow; gaps are\n",
    "# appended as they arrive so an interrupted run still leaves partial results.\n",
    "verify_gaps = []\n",
    "for _ in range(1000):\n",
    "    verify_gaps.append(verify())\n",
    "max(verify_gaps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling estimate for the seed-961 network; compare with output['lipschitz'] above.\n",
    "collect_max(network, hbox, c_vector, num_random=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy of the solved squire, presumably so the probing below leaves `squire` untouched.\n",
    "squire_clone = squire.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient reported by the squire clone at a fixed input point; bound to a name\n",
    "# because the cells below read `sq_grad`.\n",
    "sq_grad = squire_clone.get_grad_at_point([-0.8786, -0.4186])\n",
    "sq_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The squire gradient together with its l1 norm.\n",
    "sq_grad, sum(abs(component) for component in sq_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grad_vars = squire.get_vars('gradient')\n",
    "grad_bounds = squire.get_backprop_bounds(0, two_col=True)\n",
    "# For every gradient coordinate: does its solution value (.X) lie within [lo, hi]?\n",
    "grad_values = [var.X for var in grad_vars]\n",
    "[lo <= value <= hi for (lo, hi), value in zip(grad_bounds, grad_values)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the constraints for one gradient coordinate i.\n",
    "# FIXME: unfinished -- the `c1 = ...` line below is cut off mid-expression, so this\n",
    "# cell is a SyntaxError and breaks Restart & Run All. Finish it or delete the cell.\n",
    "def check_constr_set(i):\n",
    "    \"\"\"Work-in-progress check for gradient coordinate `i` (reads globals grad_vars, grad_bounds).\"\"\"\n",
    "    # Solution value of the i-th gradient variable, and its absolute value.\n",
    "    input_var = grad_vars[i].X\n",
    "    output_var = abs(input_var)\n",
    "    # Precomputed [lo, hi] bounds for this gradient coordinate.\n",
    "    grad_lo, grad_hi = grad_bounds[i]\n",
    "    c1 = grad_lo <= output_var <= "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ReLUNet's own gradient of c_vector . network(x) at x = (0.5, 0.5); the objective\n",
    "# vector is cast to network.dtype the same way collect_max does.\n",
    "# NOTE(review): sq_grad is taken at (-0.8786, -0.4186), not (0.5, 0.5) -- align the\n",
    "# points before comparing the two gradients.\n",
    "out_grad = network.get_grad_at_point(torch.Tensor([0.5, 0.5]),\n",
    "                                     torch.tensor(c_vector).type(network.dtype))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Squire gradient minus ReLUNet gradient, elementwise.\n",
    "# NOTE(review): sq_grad is evaluated at (-0.8786, -0.4186) but out_grad at\n",
    "# (0.5, 0.5); the difference is only meaningful once both use the same point.\n",
    "utils.as_numpy(sq_grad) - utils.as_numpy(out_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual autograd check of d/dx [network(x) . (1, -1, 0)] at x = (0.5, 0.5).\n",
    "# Cast to network.dtype *before* enabling grad: if .type() changes the dtype of a\n",
    "# tensor that requires grad, it returns a new non-leaf tensor whose .grad stays None.\n",
    "direction = torch.tensor([1.0, -1.0, 0.0]).type(network.dtype)\n",
    "x = torch.tensor([0.5, 0.5]).type(network.dtype).requires_grad_(True)\n",
    "y = network.forward(x).mv(direction)\n",
    "y.backward()\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete a constraint from the solved Gurobi model, looked up by name.\n",
    "# NOTE(review): 'aoesu' looks like a placeholder name; if getConstrByName finds no\n",
    "# match it presumably returns None -- confirm the intended constraint name.\n",
    "model.remove(model.getConstrByName('aoesu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient functional\n",
    "def grad_x(x, c_vector=c_vector):\n",
    "    \"\"\"Autograd gradient of c_vector . network(x) with respect to x.\n",
    "\n",
    "    Uses the notebook-level `network`; the default `c_vector` is captured when\n",
    "    this cell runs. Both tensors are cast to network.dtype (as collect_max does)\n",
    "    instead of a hard-coded float32, so they match the network's dtype.\n",
    "    \"\"\"\n",
    "    # Cast first, then enable grad, so x stays a leaf and x.grad gets populated.\n",
    "    x = torch.tensor(x).type(network.dtype).requires_grad_(True)\n",
    "    c = torch.tensor(c_vector).type(network.dtype)\n",
    "    output = network(x).squeeze().dot(c)\n",
    "    output.backward()\n",
    "    return x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_random_norms(num=100):\n",
    "    \"\"\"Largest l1 norm of grad_x over `num` uniform samples from [-1, 1)^2.\n",
    "\n",
    "    Returns (max_norm, max_norm_point, grad_x at that point). With num <= 0 it\n",
    "    returns (-1, None, None) instead of calling grad_x(None).\n",
    "    \"\"\"\n",
    "    max_norm = -1\n",
    "    max_norm_point = None\n",
    "    max_norm_grad = None\n",
    "    for _ in range(num):\n",
    "        # Uniform in [-1, 1)^2, presumably chosen to match hbox; uses the unseeded\n",
    "        # global NumPy RNG, so results differ between runs.\n",
    "        random_point = np.random.random(2) * 2 - 1\n",
    "        grad = grad_x(random_point)\n",
    "        norm = torch.norm(grad, 1)\n",
    "        if norm > max_norm:\n",
    "            max_norm = norm\n",
    "            max_norm_point = random_point\n",
    "            # Keep the gradient instead of recomputing it after the loop.\n",
    "            max_norm_grad = grad\n",
    "    return max_norm, max_norm_point, max_norm_grad\n",
    "\n",
    "test_random_norms(100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
