{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matlab.engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np \n",
    "import torch\n",
    "from hyperbox import Hyperbox\n",
    "from interval_analysis import HBoxIA\n",
    "from relu_nets import ReLUNet\n",
    "from lipMIP import LipProblem\n",
    "from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip\n",
    "from neural_nets import train\n",
    "from neural_nets import data_loaders as dl\n",
    "from experiment import Experiment, InstanceGroup, Result, ResultList, MethodNest\n",
    "from utilities import Factory\n",
    "import math\n",
    "import neural_nets.data_loaders as dl \n",
    "import neural_nets.train as train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the global numpy and torch RNGs so Restart & Run All rebuilds the same random\n",
    "# networks and sample points (assumes ReLUNet and the sampling code draw from these RNGs -- confirm).\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "box_methods = [CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, LipProblem, SeqLip]\n",
    "# Local methods: run at random points, each on an l_inf ball (Experiment.do_random_evals in test_random).\n",
    "local_methods = [LipProblem, LipLP, FastLip, RandomLB, CLEVER]\n",
    "# Global methods: run once over the whole unit hypercube (Experiment.do_unit_hypercube_eval in test_random).\n",
    "global_methods = [SeqLip, LipSDP, NaiveUB]\n",
    "# Thread count handed to the local-method Experiment in test_random.\n",
    "NUM_THREADS = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dim_scale(k, val, dim):\n",
    "    \"\"\"Return `val` rescaled for method `k` so values are comparable across methods.\n",
    "\n",
    "    Values reported by 'SeqLip' and 'LipSDP' are multiplied by sqrt(dim); every other\n",
    "    method's value is returned unchanged. Presumably those two methods bound the l2\n",
    "    Lipschitz constant and sqrt(dim) converts it to the l_inf setting used by the\n",
    "    experiments here -- TODO confirm.\n",
    "    \"\"\"\n",
    "    rescaled = k in ('SeqLip', 'LipSDP')\n",
    "    return math.sqrt(dim) * val if rescaled else val\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 281 rows, 212 columns and 1625 nonzeros\n",
      "Model fingerprint: 0x459b914c\n",
      "Variable types: 181 continuous, 31 integer (31 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [8e-04, 1e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+00]\n",
      "  RHS range        [1e-06, 4e-01]\n",
      "Presolve removed 123 rows and 94 columns\n",
      "Presolve time: 0.00s\n",
      "Presolved: 158 rows, 118 columns, 817 nonzeros\n",
      "Variable types: 92 continuous, 26 integer (26 binary)\n",
      "\n",
      "Root relaxation: objective 7.272977e-01, 139 iterations, 0.00 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0    0.72684    0   18          -    0.72684      -     -    0s\n",
      "     0     0    0.66130    0   19          -    0.66130      -     -    0s\n",
      "     0     0    0.65715    0   21          -    0.65715      -     -    0s\n",
      "     0     0    0.65479    0   21          -    0.65479      -     -    0s\n",
      "     0     0    0.64878    0   21          -    0.64878      -     -    0s\n",
      "     0     0    0.64844    0   21          -    0.64844      -     -    0s\n",
      "     0     0    0.64391    0   20          -    0.64391      -     -    0s\n",
      "     0     0    0.64313    0   20          -    0.64313      -     -    0s\n",
      "     0     0    0.64303    0   22          -    0.64303      -     -    0s\n",
      "H    0     0                       0.3124236    0.64303   106%     -    0s\n",
      "     0     0    0.64302    0   22    0.31242    0.64302   106%     -    0s\n",
      "     0     0    0.64298    0   22    0.31242    0.64298   106%     -    0s\n",
      "     0     2    0.64284    0   22    0.31242    0.64284   106%     -    0s\n",
      "*   22    10              16       0.5090703    0.64277  26.3%  14.7    0s\n",
      "*   23     9              16       0.5091417    0.64277  26.2%  14.0    0s\n",
      "\n",
      "Cutting planes:\n",
      "  Learned: 2\n",
      "  Gomory: 5\n",
      "  Cover: 4\n",
      "  Implied bound: 3\n",
      "  MIR: 20\n",
      "  Flow cover: 2\n",
      "  RLT: 2\n",
      "  Relax-and-lift: 7\n",
      "\n",
      "Explored 55 nodes (1612 simplex iterations) in 0.10 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 3: 0.509142 0.50907 0.312424 \n",
      "No other solutions better than 0.509142\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 5.091416650624e-01, best bound 5.091416650624e-01, gap 0.0000%\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 283 rows, 212 columns and 1623 nonzeros\n",
      "Model fingerprint: 0x3e283d17\n",
      "Variable types: 181 continuous, 31 integer (31 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [8e-04, 1e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+00]\n",
      "  RHS range        [1e-06, 4e-01]\n",
      "Presolve removed 104 rows and 80 columns\n",
      "Presolve time: 0.00s\n",
      "Presolved: 179 rows, 132 columns, 911 nonzeros\n",
      "Variable types: 101 continuous, 31 integer (31 binary)\n",
      "\n",
      "Root relaxation: objective 9.346867e-01, 140 iterations, 0.00 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0    0.92376    0   26          -    0.92376      -     -    0s\n",
      "     0     0    0.79512    0   25          -    0.79512      -     -    0s\n",
      "     0     0    0.79446    0   25          -    0.79446      -     -    0s\n",
      "     0     0    0.74519    0   29          -    0.74519      -     -    0s\n",
      "     0     0    0.73918    0   28          -    0.73918      -     -    0s\n",
      "     0     0    0.73899    0   28          -    0.73899      -     -    0s\n",
      "     0     0    0.73898    0   27          -    0.73898      -     -    0s\n",
      "     0     0    0.73588    0   28          -    0.73588      -     -    0s\n",
      "H    0     0                       0.3208933    0.73588   129%     -    0s\n",
      "     0     0    0.73537    0   28    0.32089    0.73537   129%     -    0s\n",
      "     0     0    0.73444    0   28    0.32089    0.73444   129%     -    0s\n",
      "     0     0    0.73381    0   28    0.32089    0.73381   129%     -    0s\n",
      "     0     0    0.73351    0   28    0.32089    0.73351   129%     -    0s\n",
      "     0     0    0.73333    0   28    0.32089    0.73333   129%     -    0s\n",
      "     0     0    0.73333    0   28    0.32089    0.73333   129%     -    0s\n",
      "     0     0    0.73173    0   28    0.32089    0.73173   128%     -    0s\n",
      "     0     0    0.73082    0   28    0.32089    0.73082   128%     -    0s\n",
      "     0     0    0.73075    0   27    0.32089    0.73075   128%     -    0s\n",
      "     0     0    0.73067    0   28    0.32089    0.73067   128%     -    0s\n",
      "     0     2    0.73061    0   28    0.32089    0.73061   128%     -    0s\n",
      "*   25    13              20       0.4969436    0.73002  46.9%  16.4    0s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 10\n",
      "  Cover: 4\n",
      "  Implied bound: 2\n",
      "  MIR: 17\n",
      "  Flow cover: 5\n",
      "  RLT: 15\n",
      "  Relax-and-lift: 11\n",
      "\n",
      "Explored 154 nodes (4099 simplex iterations) in 0.17 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 2: 0.496944 0.320893 \n",
      "No other solutions better than 0.496944\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 4.969436048681e-01, best bound 4.969436048681e-01, gap 0.0000%\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 245 rows, 206 columns and 1543 nonzeros\n",
      "Model fingerprint: 0x547c43c9\n",
      "Variable types: 181 continuous, 25 integer (25 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [8e-04, 1e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+00]\n",
      "  RHS range        [1e-06, 4e-01]\n",
      "Presolve removed 117 rows and 105 columns\n",
      "Presolve time: 0.00s\n",
      "Presolved: 128 rows, 101 columns, 632 nonzeros\n",
      "Variable types: 79 continuous, 22 integer (22 binary)\n",
      "\n",
      "Root relaxation: objective 6.909198e-01, 87 iterations, 0.00 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0    0.69092    0   17          -    0.69092      -     -    0s\n",
      "H    0     0                       0.4177592    0.69092  65.4%     -    0s\n",
      "     0     0    0.61473    0   18    0.41776    0.61473  47.1%     -    0s\n",
      "     0     0    0.60429    0   14    0.41776    0.60429  44.7%     -    0s\n",
      "     0     0    0.60289    0   16    0.41776    0.60289  44.3%     -    0s\n",
      "     0     0    0.58883    0   16    0.41776    0.58883  40.9%     -    0s\n",
      "     0     0    0.58195    0   17    0.41776    0.58195  39.3%     -    0s\n",
      "     0     0    0.57856    0   17    0.41776    0.57856  38.5%     -    0s\n",
      "     0     0    0.57744    0   16    0.41776    0.57744  38.2%     -    0s\n",
      "     0     0    0.57727    0   16    0.41776    0.57727  38.2%     -    0s\n",
      "     0     0    0.57670    0   16    0.41776    0.57670  38.0%     -    0s\n",
      "     0     0    0.57649    0   15    0.41776    0.57649  38.0%     -    0s\n",
      "     0     0    0.57602    0   15    0.41776    0.57602  37.9%     -    0s\n",
      "     0     0    0.57584    0   15    0.41776    0.57584  37.8%     -    0s\n",
      "H    0     0                       0.4778085    0.57584  20.5%     -    0s\n",
      "     0     0    0.57545    0   16    0.47781    0.57545  20.4%     -    0s\n",
      "     0     0    0.57528    0   16    0.47781    0.57528  20.4%     -    0s\n",
      "     0     0    0.57528    0   16    0.47781    0.57528  20.4%     -    0s\n",
      "     0     2    0.57461    0   16    0.47781    0.57461  20.3%     -    0s\n",
      "*   10     2               8       0.4778636    0.57461  20.2%  11.0    0s\n",
      "\n",
      "Cutting planes:\n",
      "  Learned: 1\n",
      "  Gomory: 3\n",
      "  Cover: 3\n",
      "  Implied bound: 2\n",
      "  MIR: 20\n",
      "  Flow cover: 7\n",
      "  RLT: 9\n",
      "  Relax-and-lift: 9\n",
      "\n",
      "Explored 24 nodes (848 simplex iterations) in 0.09 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 3: 0.477864 0.477808 0.417759 \n",
      "No other solutions better than 0.477864\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 4.778635938031e-01, best bound 4.778635938031e-01, gap 0.0000%\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 289 rows, 212 columns and 1639 nonzeros\n",
      "Model fingerprint: 0xe9b3c0f6\n",
      "Variable types: 181 continuous, 31 integer (31 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [8e-04, 1e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+00]\n",
      "  RHS range        [1e-06, 5e-01]\n",
      "Presolve removed 114 rows and 80 columns\n",
      "Presolve time: 0.00s\n",
      "Presolved: 175 rows, 132 columns, 965 nonzeros\n",
      "Variable types: 101 continuous, 31 integer (31 binary)\n",
      "\n",
      "Root relaxation: objective 2.514361e+00, 111 iterations, 0.00 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0    2.44195    0   24          -    2.44195      -     -    0s\n",
      "     0     0    2.12651    0   25          -    2.12651      -     -    0s\n",
      "     0     0    2.12351    0   26          -    2.12351      -     -    0s\n",
      "     0     0    1.83707    0   23          -    1.83707      -     -    0s\n",
      "     0     0    1.81706    0   22          -    1.81706      -     -    0s\n",
      "     0     0    1.81006    0   22          -    1.81006      -     -    0s\n",
      "     0     0    1.80971    0   23          -    1.80971      -     -    0s\n",
      "     0     0    1.80932    0   22          -    1.80932      -     -    0s\n",
      "     0     0    1.45742    0   27          -    1.45742      -     -    0s\n",
      "     0     0    1.45093    0   26          -    1.45093      -     -    0s\n",
      "     0     0    1.44978    0   25          -    1.44978      -     -    0s\n",
      "     0     0    1.35250    0   23          -    1.35250      -     -    0s\n",
      "     0     0    1.34304    0   27          -    1.34304      -     -    0s\n",
      "     0     0    1.33949    0   27          -    1.33949      -     -    0s\n",
      "     0     0    1.33944    0   26          -    1.33944      -     -    0s\n",
      "     0     0    1.29635    0   25          -    1.29635      -     -    0s\n",
      "     0     0    1.29356    0   26          -    1.29356      -     -    0s\n",
      "     0     0    1.29315    0   26          -    1.29315      -     -    0s\n",
      "     0     0    1.28452    0   28          -    1.28452      -     -    0s\n",
      "     0     0    1.28287    0   27          -    1.28287      -     -    0s\n",
      "     0     0    1.28271    0   27          -    1.28271      -     -    0s\n",
      "     0     0    1.27269    0   27          -    1.27269      -     -    0s\n",
      "     0     0    1.27186    0   28          -    1.27186      -     -    0s\n",
      "     0     0    1.26915    0   28          -    1.26915      -     -    0s\n",
      "     0     0    1.26900    0   29          -    1.26900      -     -    0s\n",
      "     0     0    1.26074    0   29          -    1.26074      -     -    0s\n",
      "     0     0    1.26057    0   30          -    1.26057      -     -    0s\n",
      "     0     0    1.25120    0   29          -    1.25120      -     -    0s\n",
      "     0     0    1.25095    0   29          -    1.25095      -     -    0s\n",
      "     0     0    1.23717    0   29          -    1.23717      -     -    0s\n",
      "     0     0    1.23632    0   29          -    1.23632      -     -    0s\n",
      "     0     0    1.23628    0   29          -    1.23628      -     -    0s\n",
      "     0     0    1.22067    0   28          -    1.22067      -     -    0s\n",
      "     0     0    1.22031    0   28          -    1.22031      -     -    0s\n",
      "     0     0    1.20224    0   30          -    1.20224      -     -    0s\n",
      "     0     0    1.20075    0   29          -    1.20075      -     -    0s\n",
      "     0     0    1.20053    0   29          -    1.20053      -     -    0s\n",
      "     0     0    1.18670    0   28          -    1.18670      -     -    0s\n",
      "     0     0    1.18607    0   28          -    1.18607      -     -    0s\n",
      "     0     0    1.18597    0   28          -    1.18597      -     -    0s\n",
      "     0     0    1.17908    0   29          -    1.17908      -     -    0s\n",
      "     0     0    1.17903    0   28          -    1.17903      -     -    0s\n",
      "     0     0    1.17433    0   29          -    1.17433      -     -    0s\n",
      "     0     0    1.17433    0   29          -    1.17433      -     -    0s\n",
      "H    0     0                       0.4046918    1.17433   190%     -    0s\n",
      "     0     2    1.17419    0   29    0.40469    1.17419   190%     -    0s\n",
      "*   41    27              24       0.4578704    1.17343   156%  19.0    0s\n",
      "*  744    71              19       0.4815739    0.66635  38.4%  13.7    0s\n",
      "H  870    35                       0.4875655    0.56153  15.2%  13.3    0s\n",
      "*  874    35              20       0.4929188    0.56153  13.9%  13.3    0s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 1\n",
      "  Cover: 6\n",
      "  Implied bound: 17\n",
      "  MIR: 84\n",
      "  Flow cover: 11\n",
      "  Inf proof: 2\n",
      "  RLT: 26\n",
      "  Relax-and-lift: 26\n",
      "  BQP: 2\n",
      "\n",
      "Explored 922 nodes (14299 simplex iterations) in 0.43 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 5: 0.492919 0.487566 0.481574 ... 0.404692\n",
      "No other solutions better than 0.492919\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 4.929187795966e-01, best bound 4.929187795966e-01, gap 0.0000%\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 281 rows, 211 columns and 1619 nonzeros\n",
      "Model fingerprint: 0xe1eb263d\n",
      "Variable types: 181 continuous, 30 integer (30 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [8e-04, 1e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+00]\n",
      "  RHS range        [1e-06, 5e-01]\n",
      "Presolve removed 123 rows and 89 columns\n",
      "Presolve time: 0.00s\n",
      "Presolved: 158 rows, 122 columns, 828 nonzeros\n",
      "Variable types: 93 continuous, 29 integer (29 binary)\n",
      "\n",
      "Root relaxation: objective 1.814411e+00, 87 iterations, 0.00 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0    1.81363    0   20          -    1.81363      -     -    0s\n",
      "H    0     0                       0.5186584    1.81363   250%     -    0s\n",
      "     0     0    1.22811    0   20    0.51866    1.22811   137%     -    0s\n",
      "     0     0    1.21886    0   22    0.51866    1.21886   135%     -    0s\n",
      "     0     0    1.21832    0   22    0.51866    1.21832   135%     -    0s\n",
      "     0     0    1.06017    0   19    0.51866    1.06017   104%     -    0s\n",
      "     0     0    1.03957    0   20    0.51866    1.03957   100%     -    0s\n",
      "     0     0    1.03157    0   21    0.51866    1.03157  98.9%     -    0s\n",
      "     0     0    1.03138    0   20    0.51866    1.03138  98.9%     -    0s\n",
      "     0     0    0.94982    0   20    0.51866    0.94982  83.1%     -    0s\n",
      "     0     0    0.92577    0   21    0.51866    0.92577  78.5%     -    0s\n",
      "     0     0    0.92496    0   21    0.51866    0.92496  78.3%     -    0s\n",
      "     0     0    0.92477    0   21    0.51866    0.92477  78.3%     -    0s\n",
      "     0     0    0.88535    0   20    0.51866    0.88535  70.7%     -    0s\n",
      "     0     0    0.87905    0   20    0.51866    0.87905  69.5%     -    0s\n",
      "     0     0    0.87895    0   20    0.51866    0.87895  69.5%     -    0s\n",
      "     0     0    0.87717    0   21    0.51866    0.87717  69.1%     -    0s\n",
      "     0     0    0.87702    0   21    0.51866    0.87702  69.1%     -    0s\n",
      "     0     0    0.87388    0   20    0.51866    0.87388  68.5%     -    0s\n",
      "     0     0    0.87388    0   20    0.51866    0.87388  68.5%     -    0s\n",
      "H    0     0                       0.6740526    0.87388  29.6%     -    0s\n",
      "     0     2    0.87360    0   20    0.67405    0.87360  29.6%     -    0s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 12\n",
      "  Cover: 3\n",
      "  Implied bound: 8\n",
      "  MIR: 38\n",
      "  Flow cover: 8\n",
      "  RLT: 13\n",
      "  Relax-and-lift: 12\n",
      "\n",
      "Explored 95 nodes (1901 simplex iterations) in 0.12 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 2: 0.674053 0.518658 \n",
      "No other solutions better than 0.674053\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 6.740525519542e-01, best bound 6.740525519542e-01, gap 0.0000%\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 305 rows, 214 columns and 1671 nonzeros\n",
      "Model fingerprint: 0xbf5b8652\n",
      "Variable types: 181 continuous, 33 integer (33 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [8e-04, 1e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+00]\n",
      "  RHS range        [1e-06, 6e-01]\n",
      "Presolve removed 116 rows and 74 columns\n",
      "Presolve time: 0.00s\n",
      "Presolved: 189 rows, 140 columns, 1049 nonzeros\n",
      "Variable types: 107 continuous, 33 integer (33 binary)\n",
      "\n",
      "Root relaxation: objective 2.942342e+00, 108 iterations, 0.00 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0    2.92479    0   29          -    2.92479      -     -    0s\n",
      "H    0     0                       0.2290832    2.92479  1177%     -    0s\n",
      "     0     0    2.40913    0   30    0.22908    2.40913   952%     -    0s\n",
      "     0     0    2.38000    0   31    0.22908    2.38000   939%     -    0s\n",
      "     0     0    2.37032    0   29    0.22908    2.37032   935%     -    0s\n",
      "     0     0    2.37005    0   30    0.22908    2.37005   935%     -    0s\n",
      "     0     0    2.25694    0   29    0.22908    2.25694   885%     -    0s\n",
      "H    0     0                       0.3069022    2.25694   635%     -    0s\n",
      "     0     0    2.23062    0   29    0.30690    2.23062   627%     -    0s\n",
      "     0     0    2.22261    0   29    0.30690    2.22261   624%     -    0s\n",
      "     0     0    2.21542    0   29    0.30690    2.21542   622%     -    0s\n",
      "     0     0    2.21538    0   29    0.30690    2.21538   622%     -    0s\n",
      "     0     0    2.07478    0   28    0.30690    2.07478   576%     -    0s\n",
      "     0     0    2.06381    0   29    0.30690    2.06381   572%     -    0s\n",
      "     0     0    2.05699    0   31    0.30690    2.05699   570%     -    0s\n",
      "     0     0    2.05400    0   32    0.30690    2.05400   569%     -    0s\n",
      "     0     0    2.05281    0   32    0.30690    2.05281   569%     -    0s\n",
      "     0     0    2.05267    0   32    0.30690    2.05267   569%     -    0s\n",
      "     0     0    2.00553    0   30    0.30690    2.00553   553%     -    0s\n",
      "     0     0    1.99773    0   29    0.30690    1.99773   551%     -    0s\n",
      "     0     0    1.99492    0   31    0.30690    1.99492   550%     -    0s\n",
      "     0     0    1.99458    0   31    0.30690    1.99458   550%     -    0s\n",
      "     0     0    1.97905    0   31    0.30690    1.97905   545%     -    0s\n",
      "     0     0    1.97787    0   31    0.30690    1.97787   544%     -    0s\n",
      "     0     0    1.97750    0   32    0.30690    1.97750   544%     -    0s\n",
      "     0     0    1.97204    0   30    0.30690    1.97204   543%     -    0s\n",
      "     0     2    1.97160    0   30    0.30690    1.97160   542%     -    0s\n",
      "*  122    59              20       0.3418399    1.80189   427%  23.0    0s\n",
      "H 1261   276                       0.4606114    1.11448   142%  25.6    0s\n",
      "H 1261   275                       0.4646614    1.11448   140%  25.6    0s\n",
      "H 1612   275                       0.4646614    0.99188   113%  24.5    1s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 18\n",
      "  Cover: 34\n",
      "  Implied bound: 27\n",
      "  MIR: 136\n",
      "  Flow cover: 31\n",
      "  Inf proof: 4\n",
      "  RLT: 62\n",
      "  Relax-and-lift: 23\n",
      "\n",
      "Explored 3555 nodes (74570 simplex iterations) in 1.86 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 5: 0.464661 0.460611 0.34184 ... 0.229083\n",
      "No other solutions better than 0.464661\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 4.646613889018e-01, best bound 4.646613889018e-01, gap 0.0000%\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def test_random(layer_sizes, k_trials, num_random, radius):\n",
    "    \"\"\"Run the local and global methods on `k_trials` freshly built random ReLU nets.\n",
    "\n",
    "    For every trial a new ReLUNet(layer_sizes=layer_sizes) is built, then\n",
    "      * `local_methods` are evaluated at `num_random` points drawn from the unit\n",
    "        hypercube, with balls produced by Hyperbox.build_linf_ball(radius=radius);\n",
    "      * `global_methods` are evaluated over the whole unit hypercube.\n",
    "    c_vector is [1, -1], so the last layer must have exactly 2 units (asserted).\n",
    "\n",
    "    Returns (local_results, global_results): two lists with one entry per trial.\n",
    "    \"\"\"\n",
    "    # Good parameters are [16, 16, 16, 2]\n",
    "    assert layer_sizes[-1] == 2\n",
    "    hypercube = Hyperbox.build_unit_hypercube(layer_sizes[0])\n",
    "    make_ball = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    objective = np.array([1.0, -1.0])\n",
    "    per_trial_local, per_trial_global = [], []\n",
    "    for _trial in range(k_trials):\n",
    "        net = ReLUNet(layer_sizes=layer_sizes)\n",
    "        local_exp = Experiment(local_methods, network=net, c_vector=objective,\n",
    "                               primal_norm='linf', verbose=True, num_threads=NUM_THREADS)\n",
    "        per_trial_local.append(local_exp.do_random_evals(\n",
    "            num_random_points=num_random, sample_domain=hypercube, ball_factory=make_ball))\n",
    "        global_exp = Experiment(global_methods, network=net, c_vector=objective,\n",
    "                                primal_norm='linf')\n",
    "        per_trial_global.append(global_exp.do_unit_hypercube_eval())\n",
    "    return per_trial_local, per_trial_global\n",
    "\n",
    "\n",
    "output = test_random(layer_sizes=[16, 16, 16, 2], k_trials=2, num_random=3, radius=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw (not dim_scale'd) per-method mean/stdev of 'value': first the local results pooled\n",
    "# across every trial, then the global results. The formatted comparison table is printed\n",
    "# by format_local_globals in the next cell.\n",
    "local_pooled = ResultList([res for trial in output[0] for res in trial.results])\n",
    "display(local_pooled.average_stdevs('value'))\n",
    "display(ResultList(output[1]).average_stdevs('value'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Method     |      Time |  Time STD |     Value | Value STD |       Err |    ErrSTD\n",
      "----------------------------------------------------------------------------------\n",
      "RandomLB   |    0.2391 |    0.0213 |    0.3825 |    0.0614 |    0.7418 |    0.1178\n",
      "CLEVER     |    1.4833 |    0.0766 |    0.4264 |    0.0543 |    0.8307 |    0.1265\n",
      "LipProblem |    0.4754 |    0.6342 |    0.5193 |    0.0707 |    1.0000 |    0.0000\n",
      "LipSDP     |    2.6371 |    0.0113 |    1.5234 |    0.0437 |    2.9820 |    0.3623\n",
      "LipLP      |    0.0133 |    0.0015 |    1.7969 |    0.8518 |    3.4917 |    1.7659\n",
      "SeqLip     |    0.0069 |    0.0002 |    1.8570 |    0.1078 |    3.6387 |    0.4975\n",
      "FastLip    |    0.0008 |    0.0000 |    2.0895 |    0.9286 |    4.0663 |    1.9448\n",
      "NaiveUB    |    0.0002 |    0.0000 |   16.0833 |    0.7159 |   31.5003 |    4.0661\n"
     ]
    }
   ],
   "source": [
    "def _merge_stat_dicts(*stat_dicts):\n",
    "    \"\"\"Merge method-keyed dicts left to right; on a duplicate key the later dict wins.\"\"\"\n",
    "    merged = {}\n",
    "    for stat_dict in stat_dicts:\n",
    "        merged.update(stat_dict)\n",
    "    return merged\n",
    "\n",
    "\n",
    "def _global_rel_errs(global_rl, right_answers, num_per, dim, method_names):\n",
    "    \"\"\"(mean, stdev, count) of dim_scale'd global value / LipProblem answer, per global method.\n",
    "\n",
    "    Trial t's global result is compared with each of the num_per[t] LipProblem answers from\n",
    "    that trial's random points; those answers form a contiguous slice of right_answers\n",
    "    because the local results are concatenated trial by trial.\n",
    "    \"\"\"\n",
    "    # zip() would silently drop trials if the inputs disagree, so fail loudly instead.\n",
    "    if len(num_per) != len(global_rl.results):\n",
    "        raise ValueError('expected one global result per trial: got {} local trials but {} '\n",
    "                         'global results'.format(len(num_per), len(global_rl.results)))\n",
    "    errs = {}\n",
    "    for k in method_names:\n",
    "        ratios = []\n",
    "        offset = 0\n",
    "        for num, result in zip(num_per, global_rl.results):\n",
    "            for right_answer in right_answers[offset:offset + num]:\n",
    "                ratios.append(dim_scale(k, result.values(k), dim) / right_answer)\n",
    "            offset += num\n",
    "        ratio_arr = np.array(ratios)\n",
    "        errs[k] = (ratio_arr.mean(), ratio_arr.std(), len(ratios))\n",
    "    return errs\n",
    "\n",
    "\n",
    "def _print_method_table(all_times, all_values, all_errs):\n",
    "    \"\"\"Print a fixed-width table, one row per method, sorted by its (mean value, value stdev).\"\"\"\n",
    "    key_order = [k for k, _ in sorted(all_values.items(), key=lambda p: p[1])]\n",
    "    max_len_k = max(len(k) for k in key_order + ['method'])\n",
    "    pad = lambda s: s + ' ' * (max_len_k - len(s))\n",
    "    header_pad = lambda s: '|' + ' ' * (10 - len(s)) + s\n",
    "    dformat = lambda x: '|{:10.4f}'.format(x)\n",
    "\n",
    "    columns = ['Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD']\n",
    "    header = pad('Method') + ' ' + ' '.join(header_pad(col) for col in columns)\n",
    "    print(header + '\\n' + '-' * len(header))\n",
    "    for k in key_order:\n",
    "        row = [pad(k),\n",
    "               dformat(all_times[k][0]), dformat(all_times[k][1]),\n",
    "               dformat(all_values[k][0]), dformat(all_values[k][1]),\n",
    "               dformat(all_errs[k][0]), dformat(all_errs[k][1])]\n",
    "        print(' '.join(row))\n",
    "\n",
    "\n",
    "def format_local_globals(result_list_list, global_result, dim):\n",
    "    \"\"\"Print a comparison table (time, value, relative error) for every local and global method.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    result_list_list : list of ResultList\n",
    "        Local-method results, one ResultList per trial (output[0] of test_random).\n",
    "    global_result : list\n",
    "        Global-method results, one per trial (output[1] of test_random).\n",
    "    dim : int\n",
    "        Input dimension, forwarded to dim_scale and ResultList.get_rel_err.\n",
    "\n",
    "    Values and their stdevs pass through dim_scale. For global methods 'Err' is computed\n",
    "    here as the ratio to the LipProblem answer at every local point of the same trial;\n",
    "    for local methods it comes from local_rl.get_rel_err(dim) (presumably the same ratio\n",
    "    -- confirm). Rows are sorted by mean value. The table is printed; returns None.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If result_list_list and global_result cover a different number of trials.\n",
    "    \"\"\"\n",
    "    num_per = [len(rl.results) for rl in result_list_list]\n",
    "    local_rl = ResultList([res for rl in result_list_list for res in rl.results])\n",
    "    global_rl = ResultList(global_result)\n",
    "\n",
    "    local_times = local_rl.average_stdevs('time')\n",
    "    local_vals = local_rl.average_stdevs('value')\n",
    "    global_times = global_rl.average_stdevs('time')\n",
    "    global_vals = global_rl.average_stdevs('value')\n",
    "    local_errs = local_rl.get_rel_err(dim)\n",
    "\n",
    "    # For a method present in both dicts, the global entry wins.\n",
    "    merged_vals = _merge_stat_dicts(local_vals, global_vals)\n",
    "    all_values = {k: (dim_scale(k, v[0], dim), dim_scale(k, v[1], dim))\n",
    "                  for k, v in merged_vals.items()}\n",
    "    all_times = _merge_stat_dicts(local_times, global_times)\n",
    "\n",
    "    # LipProblem answers in the same trial-concatenated order as local_rl.results.\n",
    "    right_answers = [res.values('LipProblem') for res in local_rl.results]\n",
    "    global_errs = _global_rel_errs(global_rl, right_answers, num_per, dim, global_vals.keys())\n",
    "    # Here the local entry wins on a shared key.\n",
    "    all_errs = _merge_stat_dicts(global_errs, local_errs)\n",
    "\n",
    "    _print_method_table(all_times, all_values, all_errs)\n",
    "\n",
    "\n",
    "# 16 is the input dimension (layer_sizes[0] in the test_random call).\n",
    "format_local_globals(output[0], output[1], 16)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../scripts/exp_2_synthetic_radius01.pkl', 'rb') as f:\n",
    "    obj = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ResultList(output[1]).average_stdevs('value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_synthetic(layer_sizes, k_trials, num_random, radius):\n",
    "    # Good parameters are [16, 16, 16, 2]    \n",
    "    data_params = dl.RandomKParameters(num_points=2000, k=20, radius=0.02, \n",
    "                         dimension=layer_sizes[0])\n",
    "    dataset = dl.RandomDataset(data_params, batch_size=128, \n",
    "                               random_seed=1234)\n",
    "    train_set, _ = dataset.split_train_val(1.0)\n",
    "\n",
    "    train_params = train.TrainParameters(train_set, train_set, 500,\n",
    "                                         test_after_epoch=20)\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    c_vector = np.array([1.0, -1.0])\n",
    "    local_results = []\n",
    "    global_results = []\n",
    "    for _ in range(k_trials):\n",
    "        net = ReLUNet(layer_sizes=layer_sizes)\n",
    "        train.training_loop(net, train_params)\n",
    "        \n",
    "        local_exp = Experiment(local_methods, network=net, c_vector=c_vector, primal_norm='linf', \n",
    "                               verbose=True, num_threads=NUM_THREADS)\n",
    "        local_results.append(local_exp.do_random_evals(num_random_points=num_random, \n",
    "                                sample_domain=sample_domain, ball_factory=ball_factory))\n",
    "        global_exp = Experiment(global_methods, network=net, c_vector=c_vector, primal_norm='linf')\n",
    "        global_results.append(global_exp.do_unit_hypercube_eval())\n",
    "    return local_results, global_results\n",
    "\n",
    "synth_out = test_synthetic([10, 20, 30, 20, 2], 1, 10, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "format_resultList(synth_out[0][0], 10)\n",
    "print({k: dim_scale(k, v,10) for k,v in synth_out[1][0].values().items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_mnist_bin(layer_sizes, k_trials, num_random, radius):\n",
    "    DIGITS = [1,7]\n",
    "    # Good params here are [784, 20, 20, 2]\n",
    "    data_results = []\n",
    "    random_results = []\n",
    "    global_results = []\n",
    "    train_data = dl.load_mnist_data('train', digits=DIGITS, use_cuda=True)\n",
    "    val_data = dl.load_mnist_data('val', digits=DIGITS, use_cuda=True)\n",
    "    \n",
    "    xentropyReg = train.XEntropyReg()\n",
    "    l1_reg = train.LpWeightReg(scalar=1e-5, lp='l1')\n",
    "    loss = train.LossFunctional(regularizers=[xentropyReg, l1_reg])\n",
    "    train_params = train.TrainParameters(train_data, val_data, 10, loss_functional=loss)\n",
    "    cvec = np.array([1.0,-1.0])\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    for i in range(k_trials):\n",
    "        # Train the net \n",
    "        net = ReLUNet(layer_sizes)\n",
    "        train.training_loop(net, train_params, use_cuda=True)\n",
    "        \n",
    "        # Do the data experiments \n",
    "        local_exp = Experiment(local_methods, network=net, c_vector=cvec, primal_norm='linf', verbose=True, num_threads=NUM_THREADS)\n",
    "        data_to_check = dl.load_mnist_data('val', digits=DIGITS, use_cuda=False, shuffle=True, batch_size=num_random)\n",
    "        data_to_check = next(iter(data_to_check))[0].cpu().numpy()\n",
    "        data_results.append(local_exp.do_data_evals(data_to_check, ball_factory))\n",
    "        # Do the random point experiments\n",
    "        #random_results.append(local_exp.do_random_evals(num_random, sample_domain, ball_factory ))\n",
    "        \n",
    "        # Do the global experiments\n",
    "        global_exp = Experiment(global_methods, network=net, c_vector=cvec, primal_norm='linf', verbose=True, num_threads=NUM_THREADS) \n",
    "        global_results.append(global_exp.do_unit_hypercube_eval())\n",
    "    return data_results, random_results, global_results\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00 | Accuracy: 99.45\n",
      "Epoch 01 | Accuracy: 99.63\n",
      "Epoch 02 | Accuracy: 99.68\n",
      "Epoch 03 | Accuracy: 99.68\n",
      "Epoch 04 | Accuracy: 99.70\n",
      "Epoch 05 | Accuracy: 99.72\n",
      "Epoch 06 | Accuracy: 99.73\n",
      "Epoch 07 | Accuracy: 99.74\n",
      "Epoch 08 | Accuracy: 99.75\n",
      "Epoch 09 | Accuracy: 99.78\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 4321 rows, 3411 columns and 43155 nonzeros\n",
      "Model fingerprint: 0x28a3d291\n",
      "Variable types: 2597 continuous, 814 integer (814 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-08, 1e+01]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+01]\n",
      "  RHS range        [1e-06, 8e+00]\n",
      "Presolve removed 172 rows and 118 columns\n",
      "Presolve time: 0.05s\n",
      "Presolved: 4149 rows, 3293 columns, 35992 nonzeros\n",
      "Variable types: 2479 continuous, 814 integer (814 binary)\n",
      "\n",
      "Root relaxation: objective 2.873353e+02, 1640 iterations, 0.05 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0  287.31009    0  798          -  287.31009      -     -    0s\n",
      "     0     0  245.26200    0  784          -  245.26200      -     -    0s\n",
      "     0     0  235.88186    0  780          -  235.88186      -     -    0s\n",
      "     0     0  235.69089    0  769          -  235.69089      -     -    0s\n",
      "     0     0  235.50769    0  758          -  235.50769      -     -    0s\n",
      "     0     0  235.34361    0  752          -  235.34361      -     -    0s\n",
      "     0     0  235.12522    0  744          -  235.12522      -     -    0s\n",
      "     0     0  234.95886    0  738          -  234.95886      -     -    0s\n",
      "     0     0  234.81639    0  795          -  234.81639      -     -    0s\n",
      "     0     0  234.70451    0  795          -  234.70451      -     -    0s\n",
      "     0     0  234.68341    0  795          -  234.68341      -     -    0s\n",
      "     0     0  234.67505    0  795          -  234.67505      -     -    0s\n",
      "     0     0  232.08253    0  723          -  232.08253      -     -    0s\n",
      "     0     0  231.93113    0  792          -  231.93113      -     -    0s\n",
      "     0     0  231.91312    0  796          -  231.91312      -     -    0s\n",
      "     0     0  230.24896    0  798          -  230.24896      -     -    1s\n",
      "     0     0  230.22409    0  795          -  230.22409      -     -    1s\n",
      "     0     0  229.24694    0  802          -  229.24694      -     -    1s\n",
      "     0     0  229.21078    0  803          -  229.21078      -     -    1s\n",
      "     0     0  228.97407    0  802          -  228.97407      -     -    1s\n",
      "     0     0  228.97407    0  802          -  228.97407      -     -    1s\n",
      "H    0     0                     198.7252193  228.97407  15.2%     -    3s\n",
      "     0     2  228.97386    0  802  198.72522  228.97386  15.2%     -    3s\n",
      "   560   435  214.21351  486   75  198.72522  228.95910  15.2%  12.5    5s\n",
      "*  646   416             540     222.3348275  228.95910  2.98%  14.3    5s\n",
      "H  652   394                     222.3348329  228.92205  2.96%  14.6    7s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 6\n",
      "  Cover: 339\n",
      "  Clique: 11\n",
      "  MIR: 5\n",
      "  Flow cover: 23\n",
      "  RLT: 173\n",
      "  Relax-and-lift: 440\n",
      "\n",
      "Explored 679 nodes (17980 simplex iterations) in 9.18 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 2: 222.335 198.725 \n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 2.223348329240e+02, best bound 2.223375828812e+02, gap 0.0012%\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 3281 rows, 3074 columns and 40399 nonzeros\n",
      "Model fingerprint: 0xc36e3da9\n",
      "Variable types: 2597 continuous, 477 integer (477 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-08, 1e+01]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 2e+01]\n",
      "  RHS range        [1e-06, 1e+01]\n",
      "Presolve removed 916 rows and 843 columns\n",
      "Presolve time: 0.16s\n",
      "Presolved: 2365 rows, 2231 columns, 26164 nonzeros\n",
      "Variable types: 1777 continuous, 454 integer (454 binary)\n",
      "\n",
      "Root relaxation: objective 2.212533e+02, 1175 iterations, 0.02 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0  221.25325    0  435          -  221.25325      -     -    0s\n",
      "     0     0  205.22354    0  389          -  205.22354      -     -    0s\n",
      "     0     0  205.12417    0  388          -  205.12417      -     -    0s\n",
      "     0     0  205.12406    0  387          -  205.12406      -     -    0s\n",
      "     0     0  203.75762    0  405          -  203.75762      -     -    0s\n",
      "     0     0  203.72862    0  405          -  203.72862      -     -    0s\n",
      "     0     0  203.72732    0  405          -  203.72732      -     -    0s\n",
      "     0     0  202.76447    0  404          -  202.76447      -     -    0s\n",
      "     0     0  202.69633    0  404          -  202.69633      -     -    0s\n",
      "     0     0  202.69583    0  404          -  202.69583      -     -    0s\n",
      "     0     0  202.39993    0  399          -  202.39993      -     -    0s\n",
      "     0     0  202.39574    0  399          -  202.39574      -     -    0s\n",
      "     0     0  202.32420    0  380          -  202.32420      -     -    0s\n",
      "     0     0  202.32416    0  380          -  202.32416      -     -    0s\n",
      "     0     0  202.30390    0  397          -  202.30390      -     -    0s\n",
      "     0     0  202.30390    0  397          -  202.30390      -     -    0s\n",
      "     0     2  202.28966    0  397          -  202.28966      -     -    1s\n",
      "*  342    87             135     182.7019259  191.39897  4.76%  22.9    1s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 1\n",
      "  Cover: 74\n",
      "  Implied bound: 2\n",
      "  MIR: 55\n",
      "  Flow cover: 46\n",
      "  RLT: 63\n",
      "  Relax-and-lift: 54\n",
      "\n",
      "Explored 506 nodes (11351 simplex iterations) in 2.40 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 1: 182.702 \n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 1.827019258927e+02, best bound 1.827019258927e+02, gap 0.0000%\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 4345 rows, 3415 columns and 43211 nonzeros\n",
      "Model fingerprint: 0xc1ec4017\n",
      "Variable types: 2597 continuous, 818 integer (818 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-08, 1e+01]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+01]\n",
      "  RHS range        [1e-06, 8e+00]\n",
      "Presolve removed 168 rows and 110 columns\n",
      "Presolve time: 0.07s\n",
      "Presolved: 4177 rows, 3305 columns, 36060 nonzeros\n",
      "Variable types: 2487 continuous, 818 integer (818 binary)\n",
      "\n",
      "Root relaxation: objective 2.901767e+02, 1683 iterations, 0.05 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0  290.14805    0  802          -  290.14805      -     -    0s\n",
      "     0     0  254.68386    0  768          -  254.68386      -     -    0s\n",
      "     0     0  240.59923    0  789          -  240.59923      -     -    0s\n",
      "     0     0  240.30809    0  784          -  240.30809      -     -    0s\n",
      "     0     0  240.24078    0  779          -  240.24078      -     -    0s\n",
      "     0     0  240.13070    0  774          -  240.13070      -     -    0s\n",
      "     0     0  240.10702    0  772          -  240.10702      -     -    0s\n",
      "     0     0  239.97276    0  768          -  239.97276      -     -    0s\n",
      "     0     0  239.95049    0  766          -  239.95049      -     -    0s\n",
      "     0     0  239.82505    0  764          -  239.82505      -     -    0s\n",
      "     0     0  239.73184    0  762          -  239.73184      -     -    0s\n",
      "     0     0  239.67244    0  760          -  239.67244      -     -    0s\n",
      "     0     0  239.64649    0  799          -  239.64649      -     -    0s\n",
      "     0     0  239.61002    0  799          -  239.61002      -     -    0s\n",
      "     0     0  239.60527    0  799          -  239.60527      -     -    0s\n",
      "     0     0  236.66496    0  757          -  236.66496      -     -    0s\n",
      "     0     0  236.54488    0  781          -  236.54488      -     -    0s\n",
      "     0     0  236.53574    0  788          -  236.53574      -     -    0s\n",
      "     0     0  234.83626    0  780          -  234.83626      -     -    1s\n",
      "     0     0  234.80365    0  777          -  234.80365      -     -    1s\n",
      "     0     0  234.28624    0  804          -  234.28624      -     -    1s\n",
      "     0     0  234.25644    0  804          -  234.25644      -     -    1s\n",
      "     0     0  234.22836    0  768          -  234.22836      -     -    1s\n",
      "     0     0  234.22836    0  768          -  234.22836      -     -    1s\n",
      "H    0     0                      65.5744005  234.22836   257%     -    3s\n",
      "     0     2  234.22504    0  768   65.57440  234.22504   257%     -    3s\n",
      "   493   425     cutoff  449        65.57440  234.19547   257%  13.4    5s\n",
      "*  873   619             749     222.3348215  234.19547  5.33%  34.5    7s\n",
      "   880   618  227.63425  182  832  222.33482  234.18753  5.33%  34.3   10s\n",
      "   963   673  227.88774   81  863  222.33482  227.88774  2.50%  31.4   15s\n",
      "  1021   712  226.04894  262  862  222.33482  227.77222  2.45%  29.6   20s\n",
      "  1045   728  222.42034  655  866  222.33482  227.37246  2.27%  28.9   25s\n",
      "  1069   745  223.16065  108  113  222.33482  223.16065  0.37%  42.9   31s\n",
      "\n",
      "Explored 1074 nodes (55656 simplex iterations) in 31.62 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 2: 222.335 65.5744 \n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 2.223348223194e+02, best bound 2.223348223194e+02, gap 0.0000%\n"
     ]
    }
   ],
   "source": [
    "mnist_out = test_mnist_bin([784, 20, 20, 20, 2], 1, 3, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Method     |      Time |  Time STD |     Value | Value STD |       Err |    ErrSTD\n",
      "----------------------------------------------------------------------------------\n",
      "RandomLB   |    0.3341 |    0.0036 |  113.3396 |   60.3840 |    0.5324 |    0.2566\n",
      "CLEVER     |    5.2868 |    0.0350 |  118.1102 |   67.1299 |    0.5538 |    0.2869\n",
      "LipProblem |   14.5292 |   12.4994 |  209.1239 |   18.6831 |    1.0000 |    0.0000\n",
      "LipLP      |    0.2050 |    0.0355 |  267.0902 |   30.7458 |    1.2740 |    0.0357\n",
      "FastLip    |    0.0014 |    0.0000 |  313.4928 |   30.0406 |    1.4981 |    0.0105\n",
      "LipSDP     |   17.2346 |    0.0000 |  397.3521 |    0.0000 |    1.9164 |    0.1828\n",
      "SeqLip     |    0.0124 |    0.0000 |  403.4369 |    0.0000 |    1.9458 |    0.1856\n",
      "NaiveUB    |    0.0003 |    0.0000 |  619.9666 |    0.0000 |    2.9901 |    0.2851\n",
      " \n"
     ]
    }
   ],
   "source": [
    "format_local_globals(mnist_out[0], mnist_out[2], 784)\n",
    "\n",
    "print(' ')\n",
    "#format_local_globals(mnist_out[1], mnist_out[2], 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(mnist_out)\n",
    "format_resultList(mnist_out[0][0], 784)\n",
    "print('\\n')\n",
    "format_resultList(mnist_out[1][0], 784)\n",
    "print({k: dim_scale(k, v,784) for k,v in mnist_out[2][0].values().items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "format_resultList(mnist_out[0][0], 784)\n",
    "print('\\n')\n",
    "format_resultList(mnist_out[1][0], 784)\n",
    "print()\n",
    "print({k: dim_scale(k, v, 784) for k,v in mnist_out[2][0].values().items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "format_resultList(output[0][0], 16)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "format_resultList(output[0], 16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_synthetic(layer_sizes, k_trials, num_random, radius):\n",
    "    # exp1 layers: 16,16,16,2\n",
    "    data_params = dl.RandomKParameters(num_points=2000, k=20, radius=0.02, \n",
    "                         dimension=layer_sizes[0])\n",
    "    dataset = dl.RandomDataset(data_params, batch_size=128, \n",
    "                               random_seed=1234)\n",
    "    train_set, _ = dataset.split_train_val(1.0)\n",
    "\n",
    "    train_params = train.TrainParameters(train_set, train_set, 500,\n",
    "                                         test_after_epoch=20)\n",
    "    results = []\n",
    "    c_vector = np.array([1.0, -1.0])\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    for _ in range(k_trials):\n",
    "        net = ReLUNet(layer_sizes=layer_sizes)\n",
    "        train.training_loop(net, train_params)\n",
    "        exp = Experiment(local_methods, network=net, c_vector=c_vector, primal_norm='linf', verbose=True, num_threads=4)\n",
    "        results.append(exp.do_random_evals(num_random, sample_domain, ball_factory))\n",
    "        print('\\n')\n",
    "    return results\n",
    "\n",
    "\n",
    "\n",
    "synth_out = test_synthetic([10, 20, 30, 20, 2], 2, 0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "format_resultList(synth_out[1], 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_random(layer_sizes, k, radius):\n",
    "    assert layer_sizes[-1] == 2\n",
    "    c_vector = np.array([1.0, -1.0])\n",
    "    results = []\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    for _ in range(k):\n",
    "        \n",
    "        random_net = ReLUNet(layer_sizes=layer_sizes)\n",
    "        exp = Experiment(local_methods, network=random_net, c_vector=c_vector, primal_norm='linf', verbose=True)\n",
    "        results.append(exp.do_random_evals(5, sample_domain, ball_factory))\n",
    "    return ResultList(results)\n",
    "output = test_random([16, 16, 16, 2], 2, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output.results[0].average_stdevs('value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_resultList(results, dim):\n",
    "    times = results.average_stdevs('time')\n",
    "    values = results.average_stdevs('value')\n",
    "    rel_errs = results.get_rel_err(dim)\n",
    "    keys = sorted(list(times.keys()))\n",
    "    max_len_k = max(len(_) for _ in keys + ['method'])\n",
    "\n",
    "    pad = lambda s: s + ' ' * (max_len_k - len(s))\n",
    "    def dim_scale(k, val):\n",
    "        if k not in ['SeqLip', 'LipSDP']:\n",
    "            return val\n",
    "        else:\n",
    "            return math.sqrt(dim) *val\n",
    "    header_pad = lambda s: '|' + ' ' * (10 - len(s)) + s\n",
    "    header = pad('Method') +' '+ ' '.join(header_pad(_) for _ in ['Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD'])\n",
    "    print(header + '\\n' + '-' * len(header))\n",
    "    key_order = [_[0] for _ in sorted([(k, dim_scale(k, values[k][0])) for k in keys], key=lambda p: p[1])]\n",
    "    for k in key_order:\n",
    "        elements = [pad(k),  \n",
    "                    '|{:10.4f}'.format(times[k][0]),\n",
    "                    '|{:10.4f}'.format(times[k][1]),\n",
    "                    '|{:10.4f}'.format(dim_scale(k, values[k][0])),\n",
    "                    '|{:10.4f}'.format(dim_scale(k, values[k][1]))]\n",
    "        if rel_errs != {}:\n",
    "            elements.extend(['|{:10.4f}'.format(rel_errs[k][0] * 100),\n",
    "                             '|{:10.4f}'.format(rel_errs[k][1] * 100)])\n",
    "        print(' '.join(elements))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_cvec():\n",
    "    idx_1 = np.random.randint(10)\n",
    "    while True:\n",
    "        idx_2 = np.random.randint(10)\n",
    "        if idx_2 != idx_1: \n",
    "            break \n",
    "    output = np.zeros(10)\n",
    "    output[idx_1] = 1.0\n",
    "    output[idx_2] = 1.0\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "DIGITS = None\n",
    "USE_CUDA = True\n",
    "train_data = dl.load_mnist_data('train', digits=DIGITS, use_cuda=USE_CUDA)\n",
    "val_data = dl.load_mnist_data('val', digits=DIGITS, use_cuda=USE_CUDA)\n",
    "xentropyReg = train.XEntropyReg()\n",
    "l1_reg = train.LpWeightReg(scalar=1e-3, lp='l1')\n",
    "loss = train.LossFunctional(regularizers=[xentropyReg, l1_reg])\n",
    "train_params = train.TrainParameters(train_data, val_data, 50, loss_functional=loss, test_after_epoch=5)\n",
    "train_batch_1 = next(iter(train_data))[0].view(-1, 784)\n",
    "    \n",
    "def test_mnist(layer_sizes, k, radius):\n",
    "    results = []\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    for i in range(k):\n",
    "        net = ReLUNet(layer_sizes)\n",
    "        train.training_loop(net, train_params, use_cuda=USE_CUDA)\n",
    "        cvec = random_cvec()\n",
    "        exp = Experiment(local_methods, network=net, c_vector=cvec, primal_norm='linf', verbose=True, num_threads=4) \n",
    "        results.append(exp.do_data_evals(train_batch_1, ball_factory, num_random=5))\n",
    "    return results\n",
    "mnist_test1= test_mnist([784, 20, 20, 20, 10], 1, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "format_resultList(mnist_test1[0], 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "format_resultList(mnist_test1[0], 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rl = ResultList(mnist_test1)\n",
    "\n",
    "def dim_scale(k, val, dim):\n",
    "    if k not in ['SeqLip', 'LipSDP']:\n",
    "        return val\n",
    "    else:\n",
    "        return math.sqrt(dim) *val\n",
    "        \n",
    "def errors(result):\n",
    "    vals = result.values()\n",
    "    true = vals['LipProblem']\n",
    "    return sorted([(k, dim_scale(k, v, 784) / true) for k, v in vals.items()], key=lambda p: p[1])\n",
    "outputs = []\n",
    "for result in rl.results:\n",
    "    outputs.append(errors(result))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dict_outs = {}\n",
    "for output in outputs:\n",
    "    for k,v in output:\n",
    "        if k not in dict_outs:\n",
    "            dict_outs[k] = []\n",
    "        dict_outs[k].append(v)\n",
    "avg_outs = {}\n",
    "for k, v in dict_outs.items():\n",
    "    avg_outs[k] = sum(v) / len(v)\n",
    "avg_outs    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
