{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matlab.engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import gurobipy as gb\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import utilities as utils\n",
    "from hyperbox import Hyperbox\n",
    "from interval_analysis import HBoxIA\n",
    "from relu_nets import ReLUNet\n",
    "from lipMIP import LipProblem\n",
    "from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip\n",
    "from neural_nets import train\n",
    "from neural_nets import data_loaders as dl\n",
    "from experiment import Experiment, InstanceGroup, Result\n",
    "from utilities import Factory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sanity check: nondifferentiable MIP optimizers\n",
    "\n",
    "- We find that the MIP solution can be a point $x^*$ at which the network is not differentiable (some ReLU pre-activation is exactly 0).\n",
    "- This complicates comparison against the PyTorch-computed gradient, because it is not clear which one-sided gradient PyTorch reports at such a point.\n",
    "- Proposed solution: for a given network, instance, and MIP optimal point $x^*$, find a **differentiable** $x$ with the same optimum. If the network is in general position, such an $x$ is guaranteed to exist within any neighborhood of $x^*$.\n",
    "\n",
    "In this notebook we:\n",
    "\n",
    "1. demonstrate an example where this is a problem, and\n",
    "2. demonstrate the proposed solution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00 | Accuracy: 52.00\n",
      "Epoch 100 | Accuracy: 65.60\n",
      "Epoch 200 | Accuracy: 77.60\n",
      "Epoch 300 | Accuracy: 80.00\n",
      "Epoch 400 | Accuracy: 81.20\n",
      "Epoch 500 | Accuracy: 83.20\n",
      "Epoch 600 | Accuracy: 85.20\n",
      "Epoch 700 | Accuracy: 86.00\n",
      "Epoch 800 | Accuracy: 87.60\n",
      "Epoch 900 | Accuracy: 87.20\n"
     ]
    }
   ],
   "source": [
    "# 1) Build a dataset and train a neural network.\n",
    "#    Global torch/numpy seeds are fixed so the weight initialisation is reproducible\n",
    "#    (assuming ReLUNet / training draw from these RNGs). train_data is passed as both\n",
    "#    TrainParameters datasets (presumably train and evaluation), so the printed\n",
    "#    accuracies are training accuracies.\n",
    "SEED = 420\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "dataset_params = dl.RandomKParameters(250, 20, radius=0.02, dimension=2)\n",
    "dataset = dl.RandomBinaryDataset(dataset_params, random_seed=SEED)\n",
    "train_data, val_data = dataset.split_train_val(1.0)\n",
    "train_params = train.TrainParameters(train_data, train_data, 1000, test_after_epoch=100)\n",
    "test_net = ReLUNet(layer_sizes=[2, 20, 20, 2])\n",
    "train.training_loop(test_net, train_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LipMIP Result: \n",
       "\tValue 76.921\n",
       "\tRuntime 0.154"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2) Create and solve a lipMIP problem on hbox1: the l_inf ball of radius 0.2\n",
    "#    centred at (0.5, ..., 0.5) in the network's input space.\n",
    "hbox1 = Hyperbox.build_linf_ball(np.full(test_net.layer_sizes[0], 0.5), 0.2)\n",
    "c_vec = np.array([1.0, -1.0])  # one weight per network output (the net has 2 outputs)\n",
    "prob = LipProblem(test_net, hbox1, c_vec, num_threads=4)\n",
    "prob.compute_max_lipschitz()\n",
    "result = prob.result\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(72.8327)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3) Examine what we think the true solution should be:\n",
    "#    RandomLB presumably takes the largest gradient norm over `num_points` random points\n",
    "#    of hbox1, i.e. a sampled lower bound that should not exceed the LipMIP value above.\n",
    "rand_lb = RandomLB(test_net, c_vec, hbox1, 'linf')\n",
    "rand_lb.compute(num_points=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BEST X:  [0.49633214296137573, 0.3152697391985641]\n",
      "PYTORCH GRAD [12.660749 22.725334]\n",
      "MIP GRAD [12.660749308223947, 22.725335458943402]\n"
     ]
    }
   ],
   "source": [
    "# 4) Best x and what MIP + Pytorch have to say about it\n",
    "#    Compare the gradient PyTorch computes at the MIP maximizer best_x with the one\n",
    "#    reported by the MIP solution (result.squire). In the saved run they agree, even\n",
    "#    though the next cell shows that best_x sits on a ReLU kink.\n",
    "best_x = result.best_x\n",
    "print(\"BEST X: \", list(best_x))\n",
    "print(\"PYTORCH GRAD\", test_net.get_grad_at_point(best_x, c_vec).numpy())\n",
    "print(\"MIP GRAD\", result.squire.get_grad_at_point(best_x))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fc_1_pre[0] +1\n",
      "fc_1_pre[1] +1\n",
      "fc_1_pre[2] -1\n",
      "fc_1_pre[3] +1\n",
      "fc_1_pre[4] -1\n",
      "fc_1_pre[5] +1\n",
      "fc_1_pre[6] -1\n",
      "fc_1_pre[7] +1\n",
      "fc_1_pre[8] -1\n",
      "fc_1_pre[9] -1\n",
      "fc_1_pre[10] -1\n",
      "fc_1_pre[11] +1\n",
      "fc_1_pre[12] +1\n",
      "fc_1_pre[13] ??\n",
      "fc_1_pre[14] +1\n",
      "fc_1_pre[15] +1\n",
      "fc_1_pre[16] -1\n",
      "fc_1_pre[17] +1\n",
      "fc_1_pre[18] +1\n",
      "fc_1_pre[19] -1\n",
      "fc_2_pre[0] +1\n",
      "fc_2_pre[1] ??\n",
      "fc_2_pre[2] +1\n",
      "fc_2_pre[3] +1\n",
      "fc_2_pre[4] -1\n",
      "fc_2_pre[5] +1\n",
      "fc_2_pre[6] +1\n",
      "fc_2_pre[7] +1\n",
      "fc_2_pre[8] -1\n",
      "fc_2_pre[9] +1\n",
      "fc_2_pre[10] +1\n",
      "fc_2_pre[11] +1\n",
      "fc_2_pre[12] +1\n",
      "fc_2_pre[13] +1\n",
      "fc_2_pre[14] -1\n",
      "fc_2_pre[15] -1\n",
      "fc_2_pre[16] +1\n",
      "fc_2_pre[17] -1\n",
      "fc_2_pre[18] +1\n",
      "fc_2_pre[19] -1\n"
     ]
    }
   ],
   "source": [
    "# 5) Verify that the function is indeed nondifferentiable at that point:\n",
    "#    print the sign of every pre-activation variable in the MIP solution;\n",
    "#    '??' marks a ReLU whose pre-activation is exactly 0, i.e. a kink.\n",
    "model = result.squire.model\n",
    "\n",
    "def sign(x):\n",
    "    \"\"\"Return '-1' if x < 0, '+1' if x > 0, and '??' otherwise (exactly zero).\"\"\"\n",
    "    if x < 0:\n",
    "        return '-1'\n",
    "    if x > 0:\n",
    "        return '+1'\n",
    "    return '??'\n",
    "\n",
    "fc_pre = r'fc_\\d+_pre\\[\\d+\\]'  # pre-activation variable names, e.g. fc_1_pre[13]\n",
    "for var in model.getVars():\n",
    "    if re.match(fc_pre, var.varName):\n",
    "        print(var.varName, sign(var.X))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Now consider sign configs: the ReLU sign pattern of the first two layers at the MIP\n",
    "# maximizer best_x, which lies inside hbox1. (The origin lies outside\n",
    "# hbox1 = [0.3, 0.7]^2, so its pattern need not be realizable inside hbox1.)\n",
    "# Neurons exactly on a kink at best_x (the '??' rows above) are assigned to whichever\n",
    "# side PyTorch's float32 pre-activation lands on.\n",
    "x_ref = torch.tensor(np.asarray(best_x, dtype=np.float32))\n",
    "preacts = test_net(x_ref, return_preacts=True)[:2]\n",
    "signs = [p.squeeze().detach().numpy() > 0 for p in preacts]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimize a model with 63 rows, 43 columns and 747 nonzeros\n",
      "Coefficient statistics:\n",
      "  Matrix range     [3e-03, 2e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [3e-01, 7e-01]\n",
      "  RHS range        [2e-03, 1e+00]\n",
      "Presolve removed 11 rows and 28 columns\n",
      "Presolve time: 0.01s\n",
      "\n",
      "Solved in 0 iterations and 0.01 seconds\n",
      "Infeasible or unbounded model\n"
     ]
    }
   ],
   "source": [
    "# Search hbox1 for a point realizing the sign pattern `signs`; judging by the\n",
    "# solver log below, find_feasible_from_signs builds and solves a Gurobi model.\n",
    "xout = test_net.find_feasible_from_signs(signs, hbox1)\n",
    "# TODO: once a feasible point is returned, compare the two gradients there:\n",
    "#print(\"PYTORCH GRAD\", test_net.get_grad_at_point(xout, c_vec).numpy())\n",
    "#print(\"MIP GRAD\", result.squire.get_grad_at_point(xout))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Do it LIVE! First scratch Gurobi model: input variables boxed to [-2, 2] plus a\n",
    "# slack variable in [0, 1]. Superseded by the later cell that rebuilds the model\n",
    "# with unbounded inputs.\n",
    "with utils.silent():\n",
    "    model = gb.Model()\n",
    "input_key = 'input'\n",
    "input_namer = utils.build_var_namer(input_key)\n",
    "input_vars = [model.addVar(lb=-2, ub=2, name=input_namer(i)) for i in range(test_net.layer_sizes[0])]\n",
    "slack_var = model.addVar(lb=0, ub=1.0, name='slack')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat on a small, untrained network so the sign-pattern search is easy to inspect.\n",
    "# NOTE: this rebinds `test_net` (the trained network above is replaced), so re-run\n",
    "# the notebook from the top to reproduce the earlier cells.\n",
    "torch.manual_seed(0)  # fixed seed so the (presumably randomly initialised) net and `signs` are reproducible\n",
    "test_net = ReLUNet([2, 10, 20, 4])\n",
    "preacts = [p.squeeze().detach().numpy()\n",
    "           for p in test_net(torch.tensor([0.0, 0.0]), return_preacts=True)[:2]]\n",
    "print(preacts)\n",
    "signs = [p > 0 for p in preacts]\n",
    "signs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sign configuration of the simple network at the origin, given as a batch of one input\n",
    "test_net.get_sign_configs(torch.zeros(1, test_net.layer_sizes[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably finds an input whose first-two-layer ReLU sign pattern matches `signs`.\n",
    "out = test_net.find_feasible_from_signs(signs)\n",
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the Gurobi model from scratch: unbounded input variables plus a margin\n",
    "# `slack` in [0, 1] that every pre-activation must clear on the side required by\n",
    "# `signs` (as in `add_layer` below; presumably the same for\n",
    "# `_add_layer_to_gurobi_model`). The objective maximizes that margin.\n",
    "with utils.silent():\n",
    "    model = gb.Model()\n",
    "input_namer = utils.build_var_namer('input')\n",
    "input_vars = [model.addVar(lb=-gb.GRB.INFINITY, ub=gb.GRB.INFINITY, name=input_namer(i))\n",
    "              for i in range(test_net.layer_sizes[0])]\n",
    "slack_var = model.addVar(lb=0, ub=1.0, name='slack')\n",
    "\n",
    "\n",
    "def add_layer(linear, signs, model, input_vars, slack_var):\n",
    "    \"\"\"Local reference version of one sign-constrained ReLU layer (not called below).\n",
    "\n",
    "    For each neuron i with pre-activation z_i = weight[i] . input_vars + bias[i]:\n",
    "      * signs[i] truthy -> add z_i - slack_var >= 0 and a free variable equal to z_i;\n",
    "      * otherwise       -> add z_i + slack_var <= 0 and a variable fixed at 0.\n",
    "    Returns the list of new (post-ReLU) output variables.\n",
    "    \"\"\"\n",
    "    weight = linear.weight.detach().numpy()\n",
    "    bias = linear.bias.detach().numpy()\n",
    "    out_vars = []\n",
    "    for i, row in enumerate(weight):\n",
    "        if signs[i]:\n",
    "            var = model.addVar(lb=-gb.GRB.INFINITY, ub=gb.GRB.INFINITY)\n",
    "            model.addConstr(var == gb.LinExpr(row, input_vars) + bias[i])\n",
    "            model.addConstr(gb.LinExpr(row, input_vars) + bias[i] - slack_var >= 0)\n",
    "        else:\n",
    "            var = model.addVar(lb=0.0, ub=0.0)\n",
    "            model.addConstr(gb.LinExpr(row, input_vars) + bias[i] + slack_var <= 0)\n",
    "        out_vars.append(var)\n",
    "    return out_vars\n",
    "\n",
    "\n",
    "# Chain the constrained layers, one per entry of `signs`, each on top of the previous\n",
    "# layer's output variables. To use the local reference version instead, call\n",
    "# add_layer(test_net.fcs[layer_idx], layer_signs, model, out_vars, slack_var).\n",
    "out_vars = input_vars\n",
    "for layer_idx, layer_signs in enumerate(signs):\n",
    "    out_vars = test_net._add_layer_to_gurobi_model(layer_idx, model, out_vars, slack_var, layer_signs)\n",
    "model.update()\n",
    "\n",
    "model.setObjective(slack_var, gb.GRB.MAXIMIZE)\n",
    "model.optimize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read back the optimal input point and print its sign configuration next to the\n",
    "# requested `signs` for comparison. Reading `input_vars` directly avoids assuming the\n",
    "# inputs are the first variables of the model, or that there are exactly two of them.\n",
    "bestx = [v.X for v in input_vars]\n",
    "print(test_net.get_sign_configs(torch.tensor([bestx])))\n",
    "print(signs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]],\n",
       "       dtype=torch.int8)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#### Polytope from signs\n",
    "# Diagonal 0/1 matrix that keeps the active (positive) first-layer neurons\n",
    "torch.diag(torch.as_tensor(signs[0]).to(torch.int8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_polytope_config(self, signs, comparison_form_flag=False,\n",
    "                            uncertain_constraints=None, as_tensor=False):\n",
    "    \"\"\"Return the polytope of inputs on which the network has ReLU sign pattern `signs`.\n",
    "\n",
    "    Draft of a ReLUNet method: `self` only needs `fcs`, the list of linear layers.\n",
    "\n",
    "    Args:\n",
    "        signs: one array-like of booleans (or 0/1 values) per hidden layer. A truthy\n",
    "            entry requires that neuron's pre-activation to be >= 0, a falsy one <= 0.\n",
    "        comparison_form_flag: if True, post-process (A, b) with utils.comparison_form.\n",
    "        uncertain_constraints: currently unused.\n",
    "        as_tensor: if True, return the per-layer torch blocks instead of stacked numpy.\n",
    "\n",
    "    Returns:\n",
    "        dict. The region is {x : poly_a @ x <= poly_b} (per-layer blocks in\n",
    "        a_stack / b_stack when as_tensor=True); total_a / total_b express the\n",
    "        network's output on that region as the affine map total_a @ x + total_b.\n",
    "    \"\"\"\n",
    "    # Float 0/1 masks on the weights' dtype/device: bool or integer masks cannot be\n",
    "    # matmul'd with the float weights below.\n",
    "    ref_weight = self.fcs[0].weight\n",
    "    configs = [torch.as_tensor(s).reshape(-1).to(device=ref_weight.device, dtype=ref_weight.dtype)\n",
    "               for s in signs]\n",
    "    # Lambda_k keeps the active neurons of layer k; J_k = diag(1 - 2 c_k) flips the sign\n",
    "    # of active rows so every constraint reads J_k (W_k x + b_k) <= 0.\n",
    "    lambdas = [torch.diag(config) for config in configs]\n",
    "    js = [torch.diag(-2 * config + 1) for config in configs]\n",
    "\n",
    "    # Compute Z_k = W_k * x + b_k for each layer, where W_k, b_k compose the affine\n",
    "    # layers with the fixed ReLU masks up to layer k.\n",
    "    wks = [self.fcs[0].weight]\n",
    "    bks = [self.fcs[0].bias]\n",
    "    for (i, fc) in enumerate(self.fcs[1:]):\n",
    "        current_wk = wks[-1]\n",
    "        current_bk = bks[-1]\n",
    "        current_lambda = lambdas[i]\n",
    "        precompute = fc.weight.matmul(current_lambda)\n",
    "        wks.append(precompute.matmul(current_wk))\n",
    "        bks.append(precompute.matmul(current_bk) + fc.bias)\n",
    "\n",
    "    # J_k (W_k x + b_k) <= 0  <=>  (J_k W_k) x <= -J_k b_k. zip stops after len(signs)\n",
    "    # layers, so with one entry per hidden layer the output layer adds no constraint.\n",
    "    a_stack = []\n",
    "    b_stack = []\n",
    "    for j, wk, bk in zip(js, wks, bks):\n",
    "        a_stack.append(j.matmul(wk))\n",
    "        b_stack.append(-j.matmul(bk))\n",
    "    if as_tensor:\n",
    "        return {'a_stack': a_stack,\n",
    "                'b_stack': b_stack,\n",
    "                'total_a': wks[-1],\n",
    "                'total_b': bks[-1]}\n",
    "\n",
    "    polytope_A = utils.as_numpy(torch.cat(a_stack, dim=0))\n",
    "    polytope_b = utils.as_numpy(torch.cat(b_stack, dim=0))\n",
    "\n",
    "    if comparison_form_flag:\n",
    "        polytope_A, polytope_b = utils.comparison_form(polytope_A, polytope_b)\n",
    "\n",
    "    return {'poly_a': polytope_A,\n",
    "            'poly_b': polytope_b,\n",
    "            'configs': configs,\n",
    "            'total_a': wks[-1],\n",
    "            'total_b': bks[-1]}\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
