{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1, debug the hbox pushing works\n",
    "from relu_nets import ReLUNet\n",
    "from hyperbox import Hyperbox, BooleanHyperbox\n",
    "from interval_analysis import HBoxIA\n",
    "from lipMIP import LipProblem, GurobiSquire\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seeds so a fresh Restart Kernel -> Run All rebuilds the same setup.\n",
    "# NOTE(review): assumes ReLUNet draws its initial weights from torch's global RNG\n",
    "# and that Hyperbox sampling uses the numpy/torch global RNGs - confirm.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "n = 16  # input dimension (first entry of layer_sizes)\n",
    "# Five layer sizes -> four linear layers, presumably with a ReLU between each\n",
    "# consecutive pair: LRLRLRL\n",
    "test_net = ReLUNet(layer_sizes=[n, 8, 16, 32, 4], bias=True)\n",
    "# One weight per network output (4 entries, matching the last layer size).\n",
    "c_vector = np.array([1., -1., 0.5, 1.0])\n",
    "# NOTE(review): presumably (center, radius) of an l_inf ball - confirm.\n",
    "input_domain = Hyperbox.build_linf_ball(np.array([0.5] * n), 0.1)\n",
    "backprop_domain = Hyperbox.from_vector(c_vector)\n",
    "ia_obj = HBoxIA(test_net, input_domain, backprop_domain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the forward pass, then the backward pass, of the interval analysis.\n",
    "# NOTE(review): the backward pass presumably consumes the forward results, so\n",
    "# keep this order - confirm against HBoxIA.\n",
    "ia_obj.compute_forward()\n",
    "ia_obj.compute_backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the LipProblem for test_net over input_domain with c_vector as the\n",
    "# third positional argument; verbose and mip_gap are passed exactly as before.\n",
    "prob = LipProblem(\n",
    "    test_net,\n",
    "    input_domain,\n",
    "    c_vector,\n",
    "    verbose=True,\n",
    "    mip_gap=2.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0  0 -1 -1  1  0  0 -1  0 -1  1 -1  0  1 -1  1  0  0  0  1  0 -1  1 -1\n",
      "  0 -1  0  1  0  0  1  1]\n",
      "[-1  0  0 -1  1  1 -1  1  0  0 -1 -1  0  0  0  1]\n",
      "[ 1 -1  1  0 -1  0  0  1]\n",
      "Changed value of parameter Threads to 4\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Optimize a model with 420 rows, 320 columns and 2552 nonzeros\n",
      "Variable types: 280 continuous, 40 integer (40 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [2e-04, 1e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-08, 1e+00]\n",
      "  RHS range        [1e-08, 7e-01]\n",
      "Presolve removed 180 rows and 150 columns\n",
      "Presolve time: 0.00s\n",
      "Presolved: 240 rows, 170 columns, 1152 nonzeros\n",
      "Variable types: 130 continuous, 40 integer (40 binary)\n",
      "\n",
      "Root relaxation: objective 3.868013e+00, 189 iterations, 0.01 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0    3.86801    0   27          -    3.86801      -     -    0s\n",
      "     0     0    3.08414    0   34          -    3.08414      -     -    0s\n",
      "     0     0    3.02465    0   37          -    3.02465      -     -    0s\n",
      "     0     0    3.02448    0   36          -    3.02448      -     -    0s\n",
      "H    0     0                       0.6410401    3.02448   372%     -    0s\n",
      "     0     0    2.98053    0   36    0.64104    2.98053   365%     -    0s\n",
      "     0     0    2.84942    0   29    0.64104    2.84942   345%     -    0s\n",
      "     0     0    2.79881    0   31    0.64104    2.79881   337%     -    0s\n",
      "     0     0    2.79669    0   31    0.64104    2.79669   336%     -    0s\n",
      "     0     0    2.77533    0   32    0.64104    2.77533   333%     -    0s\n",
      "     0     0    2.77524    0   32    0.64104    2.77524   333%     -    0s\n",
      "     0     0    2.77368    0   32    0.64104    2.77368   333%     -    0s\n",
      "     0     0    2.77368    0   32    0.64104    2.77368   333%     -    0s\n",
      "     0     2    2.77368    0   32    0.64104    2.77368   333%     -    0s\n",
      "*  296   119              25       0.6847479    2.04700   199%  22.8    0s\n",
      "*  300   119              25       0.6882123    2.04700   197%  22.6    0s\n",
      "*  589   197              25       0.6889956    1.69745   146%  21.3    0s\n",
      "* 3433     5              24       0.7100243    0.76141  7.24%  17.7    0s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 19\n",
      "  Cover: 6\n",
      "  Implied bound: 14\n",
      "  MIR: 56\n",
      "  Flow cover: 14\n",
      "  Inf proof: 3\n",
      "\n",
      "Explored 3456 nodes (61700 simplex iterations) in 0.73 seconds\n",
      "Thread count was 4 (of 8 available processors)\n",
      "\n",
      "Solution count 5: 0.710024 0.688996 0.688212 ... 0.64104\n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 7.100242915523e-01, best bound 7.100242915523e-01, gap 0.0000%\n"
     ]
    }
   ],
   "source": [
    "# Run the solve and keep the returned object in `out`; later cells read\n",
    "# out.squire from it, so this cell must run before them.\n",
    "out = prob.compute_max_lipschitz()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the object returned by prob.compute_max_lipschitz().\n",
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The network takes n inputs, so the probe point needs n entries (the old\n",
    "# 2-entry point did not match the n = 16 input layer). np.resize repeats the\n",
    "# original 0.4 / 0.5 pattern until the vector has length n.\n",
    "x = np.resize([0.4, 0.5], n)\n",
    "# Gradient according to the MIP encoding (printed) ...\n",
    "print(out.squire.get_grad_at_point(x))\n",
    "# ... next to the gradient reported by the network itself (cell output).\n",
    "test_net.get_grad_at_point(x, c_vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight matrix of the final linear layer as a NumPy array, then its\n",
    "# transpose applied to c_vector.\n",
    "last_layer_weight = test_net.fcs[-1].weight.detach().numpy()\n",
    "last_layer_weight.T @ c_vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The original cell was a dangling 'ia_obj.' (apparently a leftover attribute\n",
    "# lookup), which is a SyntaxError. List the public attributes explicitly instead.\n",
    "[name for name in dir(ia_obj) if not name.startswith('_')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass at the point with every coordinate equal to 0.4. The input must\n",
    "# have n entries to match the network's first layer (the old 2-entry tensor did not).\n",
    "test_net(torch.full((n,), 0.4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the MIP encoding with the real network at random points of the input\n",
    "# domain: the 'logits' variable against a forward pass, and the MIP gradient\n",
    "# against the network gradient for c_vector. (The loop header had been commented\n",
    "# out, leaving an indented body that could not be parsed.)\n",
    "worst_diff = 0.0\n",
    "for pt in input_domain.random_point(num_points=1000):\n",
    "    mip_out = out.squire.get_var_at_point(pt, 'logits')\n",
    "    nn_out = test_net(torch.Tensor(pt)).squeeze().detach().numpy()\n",
    "    logit_diff = max(abs(mip_out - nn_out))\n",
    "\n",
    "    mip_grad = out.squire.get_grad_at_point(pt)\n",
    "    nn_grad = test_net.get_grad_at_point(pt, c_vector).detach().numpy()\n",
    "    grad_diff = max(abs(mip_grad - nn_grad))\n",
    "\n",
    "    worst_diff = max(worst_diff, logit_diff, grad_diff)\n",
    "\n",
    "# Report the single worst discrepancy instead of printing 1000 lines.\n",
    "print(worst_diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the first linear layer by hand (weight @ x + bias). x needs n entries\n",
    "# to match the layer's input width (the old 2-entry tensor did not).\n",
    "# Note: this rebinds `x` from a NumPy array to a torch tensor.\n",
    "x = torch.full((n,), 0.4)\n",
    "test_net.fcs[0].weight @ x + test_net.fcs[0].bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the torch tensor `x` to a NumPy array. This only works after the cell\n",
    "# that rebinds `x` to a torch tensor has run; the earlier `x` is a NumPy array,\n",
    "# which has no .numpy() method.\n",
    "x.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the gradient range held by the interval-analysis object.\n",
    "# NOTE(review): presumably populated by compute_backward() and rendered as\n",
    "# one (lower, upper) row per coordinate - confirm.\n",
    "ia_obj.gradient_range.as_twocol()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient at the origin. The point must have n entries to match the network\n",
    "# input (the old 2-entry tensor did not).\n",
    "test_net.get_grad_at_point(torch.zeros(n), c_vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the MIP's 'logits' variable against a real forward pass at random\n",
    "# points of the input domain; prints the largest absolute gap for each point.\n",
    "for sample in input_domain.random_point(num_points=1000):\n",
    "    sample_tensor = torch.Tensor(sample)\n",
    "    net_logits = test_net(sample_tensor).squeeze().detach().numpy()\n",
    "    mip_logits = out.squire.get_var_at_point(sample_tensor.numpy(), 'logits')\n",
    "    print(abs(net_logits - mip_logits).max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gurobipy as gb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspection only: evaluates to gurobipy's LinExpr attribute; nothing is assigned.\n",
    "gb.LinExpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a second HBoxIA from the same network and domains as ia_obj, run both\n",
    "# passes on it, then display one backward box.\n",
    "# NOTE(review): the meaning of the (1, True) arguments is not visible here -\n",
    "# confirm against HBoxIA.get_backward_box.\n",
    "test_pre = HBoxIA(test_net, input_domain, backprop_domain)\n",
    "test_pre.compute_forward()\n",
    "test_pre.compute_backward()\n",
    "test_pre.get_backward_box(1, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the network's num_relus attribute.\n",
    "# NOTE(review): presumably a count of ReLU layers or units - confirm which.\n",
    "test_net.num_relus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the pre_bounds attribute of the squire attached to the solve result.\n",
    "# NOTE(review): presumably the bounds computed before the MIP was built - confirm.\n",
    "out.squire.pre_bounds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
