{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matlab.engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from hyperbox import Hyperbox\n",
    "from interval_analysis import HBoxIA\n",
    "from relu_nets import ReLUNet\n",
    "from lipMIP import LipProblem\n",
    "from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip\n",
    "from neural_nets import train\n",
    "from neural_nets import data_loaders as dl\n",
    "from experiment import Experiment, InstanceGroup, Result\n",
    "from utilities import Factory\n",
    "import utilities as utils \n",
    "import gurobipy as gb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Sanity check: \n",
    "    - we're finding that the MIP solution yields an 'x' point at which the NN is not differentiable\n",
    "    - this causes problems when comparing against the pytorch-computed gradient because it's not clear \n",
    "      what pytorch is doing at this point \n",
    "    - proposed solution is to, for a given NN, instance, and MIP optimal point, find a DIFFERENTIABLE \n",
    "      x with the same optimum. If the NN is in general position, such an x is guaranteed to exist within any \n",
    "      neighborhood of x*.\n",
    "      \n",
    "      \n",
    "In this notebook we'll \n",
    "1) demonstrate an example where this is a problem\n",
    "2) And then we'll demonstrate the solution\n",
    "\"\"\"\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00 | Accuracy: 64.00\n",
      "Epoch 100 | Accuracy: 79.60\n",
      "Epoch 200 | Accuracy: 81.60\n",
      "Epoch 300 | Accuracy: 82.40\n",
      "Epoch 400 | Accuracy: 88.00\n",
      "Epoch 500 | Accuracy: 91.20\n",
      "Epoch 600 | Accuracy: 93.20\n",
      "Epoch 700 | Accuracy: 95.20\n",
      "Epoch 800 | Accuracy: 95.60\n",
      "Epoch 900 | Accuracy: 96.40\n"
     ]
    }
   ],
   "source": [
    "# 1) Build a dataset and train a neural network\n",
    "dataset_params = dl.RandomKParameters(250, 20, radius=0.02, dimension=4)\n",
    "dataset = dl.RandomBinaryDataset(dataset_params, random_seed=420)\n",
    "train_data, val_data = dataset.split_train_val(1.0)\n",
    "train_params = train.TrainParameters(train_data, train_data, 1000, test_after_epoch=100)\n",
    "test_net = ReLUNet(layer_sizes=[4, 20, 20, 2])\n",
    "train.training_loop(test_net, train_params)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LipMIP Result: \n",
       "\tValue 161.822\n",
       "\tRuntime 0.186"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2) Create and solve a lipMIP problem\n",
    "hbox1 = Hyperbox.build_linf_ball(np.array([0.5 for _ in range(test_net.layer_sizes[0])]), 0.2)\n",
    "c_vec = np.array([1.0, -1.0])\n",
    "prob = LipProblem(test_net, hbox1, c_vec, num_threads=4)\n",
    "prob.compute_max_lipschitz()\n",
    "result = prob.result\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(154.5907)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3) Examine what we think the true solution should be:\n",
    "rand_lb = RandomLB(test_net, c_vec, hbox1, 'linf')\n",
    "rand_lb.compute(num_points=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BEST X:  [0.3127116012600085, 0.3, 0.565522962521815, 0.36548028386861103]\n",
      "PYTORCH GRAD [-27.473738  39.043407   9.800853 -15.490733]\n",
      "MIP GRAD [-27.245864479162954, 33.86189576790028, 22.693260697662772, -14.63638939533962]\n"
     ]
    }
   ],
   "source": [
    "# 4) Best x and what MIP + Pytorch have to say about it \n",
    "best_x = result.best_x\n",
    "print(\"BEST X: \", list(best_x))\n",
    "print(\"PYTORCH GRAD\", test_net.get_grad_at_point(best_x, c_vec).numpy())\n",
    "print(\"MIP GRAD\", result.squire.get_grad_at_point(best_x))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fc_1_pre[0] -1\n",
      "fc_1_pre[1] -1\n",
      "fc_1_pre[2] -1\n",
      "fc_1_pre[3] +1\n",
      "fc_1_pre[4] -1\n",
      "fc_1_pre[5] ??\n",
      "fc_1_pre[6] +1\n",
      "fc_1_pre[7] +1\n",
      "fc_1_pre[8] +1\n",
      "fc_1_pre[9] -1\n",
      "fc_1_pre[10] +1\n",
      "fc_1_pre[11] +1\n",
      "fc_1_pre[12] -1\n",
      "fc_1_pre[13] -1\n",
      "fc_1_pre[14] +1\n",
      "fc_1_pre[15] +1\n",
      "fc_1_pre[16] +1\n",
      "fc_1_pre[17] +1\n",
      "fc_1_pre[18] +1\n",
      "fc_1_pre[19] -1\n",
      "fc_2_pre[0] +1\n",
      "fc_2_pre[1] -1\n",
      "fc_2_pre[2] +1\n",
      "fc_2_pre[3] +1\n",
      "fc_2_pre[4] +1\n",
      "fc_2_pre[5] +1\n",
      "fc_2_pre[6] +1\n",
      "fc_2_pre[7] +1\n",
      "fc_2_pre[8] -1\n",
      "fc_2_pre[9] ??\n",
      "fc_2_pre[10] +1\n",
      "fc_2_pre[11] +1\n",
      "fc_2_pre[12] -1\n",
      "fc_2_pre[13] -1\n",
      "fc_2_pre[14] -1\n",
      "fc_2_pre[15] -1\n",
      "fc_2_pre[16] +1\n",
      "fc_2_pre[17] +1\n",
      "fc_2_pre[18] -1\n",
      "fc_2_pre[19] -1\n"
     ]
    }
   ],
   "source": [
    "# 5) Verify that the function is indeed nondifferentiable at that point:\n",
    "model = result.squire.model\n",
    "import re\n",
    "def sign(x):\n",
    "    if x < 0:\n",
    "        return '-1'\n",
    "    elif x > 0:\n",
    "        return '+1'\n",
    "    else:\n",
    "        return '??'\n",
    "fc_pre = r'fc_\\d+_pre\\[\\d+\\]'\n",
    "for var in model.getVars():\n",
    "    if re.match(fc_pre, var.varName):\n",
    "        print(var.varName, sign(var.X))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# And the polytope from the sign config\n",
    "signs = result.squire.get_sign_configs()\n",
    "sign_poly = test_net.polytope_from_signs(signs)\n",
    "best_x = result.best_x\n",
    "best_x\n",
    "sign_poly.contains(best_x, tolerance=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimize a model with 44 rows, 5 columns and 204 nonzeros\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-02, 5e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e+00, 1e+00]\n",
      "  RHS range        [2e-02, 1e+00]\n",
      "Presolve time: 0.01s\n",
      "Presolved: 5 rows, 45 columns, 205 nonzeros\n",
      "\n",
      "Iteration    Objective       Primal Inf.    Dual Inf.      Time\n",
      "       0    1.0000000e+00   0.000000e+00   3.939915e+01      0s\n",
      "      11    2.0307286e-02   0.000000e+00   0.000000e+00      0s\n",
      "\n",
      "Solved in 11 iterations and 0.01 seconds\n",
      "Optimal objective  2.030728623e-02\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([0.15274361, 0.36230745, 0.56706253, 0.54873953])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_out = sign_poly.intersects_hbox(hbox1)\n",
    "model_out\n",
    "\n",
    "\n",
    "#model_out.getVars()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-27.4737,  39.0434,   9.8009, -15.4907])\n",
      "tensor(161.8217)\n"
     ]
    }
   ],
   "source": [
    "print(test_net.get_grad_at_point(best_x, c_vec))\n",
    "print(test_net.get_grad_at_point(model_out, c_vec).norm(p=1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
