{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple debugging of lipMIP\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "import lipMIP as lm\n",
    "import gurobipy as gb\n",
    "from relu_nets import ReLUNet\n",
    "from hyperbox import Hyperbox\n",
    "from pre_activation_bounds import PreactivationBounds\n",
    "import numpy as np\n",
    "from pprint import pprint\n",
    "import torch\n",
    "import utilities as utils\n",
    "import random "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed for every RNG used in this notebook: torch (network init, presumably),\n",
    "# numpy (test_random_norms) and `random` (verify picks its seeds with random.randint)\n",
    "SEED = 363\n",
    "torch.random.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "\n",
    "# ReLU network (layer sizes presumably 2 -> 4 -> 3), the unit l_inf ball around the\n",
    "# origin, and the output direction c defining the scalar objective c^T f(x)\n",
    "network = ReLUNet([2, 4, 3])\n",
    "hbox = Hyperbox.build_linf_ball(np.zeros(2), 1.0)\n",
    "c_vector = np.array([1.0, -1.0, 0.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Academic license - for non-commercial use only\n",
      "-0.7759191940738912\n"
     ]
    }
   ],
   "source": [
    "# Solve lipMIP for the max Lipschitz constant of c^T f over hbox (l_inf input norm)\n",
    "output = lm.compute_max_lipschitz(network, hbox, 'l_inf', c_vector, verbose=False)\n",
    "squire = output['squire']    # used below to inspect the encoded gradient variables\n",
    "model = output['model']      # solver model (gurobipy, presumably)\n",
    "MAXPOINT = output['best_x']  # input reported for the optimum\n",
    "# Show the value itself; it is compared directly against non-negative l1 norms in verify()\n",
    "output['lipschitz']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient of c^T f at one random point of the box; c is converted to a tensor in the\n",
    "# network's dtype, exactly as collect_max does\n",
    "network.get_grad_at_point(hbox.random_point(tensor_or_np='tensor'),\n",
    "                          torch.tensor(c_vector).type(network.dtype))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The network's own random search over hbox (presumably returns the largest sampled gradient)\n",
    "num_grad_samples = 1000\n",
    "network.random_max_grad(hbox, c_vector, num_grad_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_max(network, domain, c_vector, num_random=100):\n",
    "    \"\"\"Estimate max ||grad_x c^T f(x)||_1 over `domain` by random sampling.\n",
    "\n",
    "    Draws `num_random` points with `domain.random_point(tensor_or_np='tensor')` and keeps\n",
    "    the one whose gradient (from `network.get_grad_at_point`) has the largest l1 norm.\n",
    "    Because it only samples, the result is a lower bound on the true maximum.\n",
    "\n",
    "    Returns a dict with keys 'norm' (the best l1 norm, or -1 if nothing was sampled),\n",
    "    'point' and 'grad' (None if nothing was sampled).\n",
    "    \"\"\"\n",
    "    # c is identical for every sample: convert it to a tensor of the network's dtype once\n",
    "    c_tensor = torch.tensor(c_vector).type(network.dtype)\n",
    "    max_norm = -1\n",
    "    max_point = None\n",
    "    max_grad = None\n",
    "    for _ in range(num_random):\n",
    "        random_input = domain.random_point(tensor_or_np='tensor')\n",
    "        grad = network.get_grad_at_point(random_input, c_tensor)\n",
    "        grad_norm = grad.norm(p=1)\n",
    "        if grad_norm > max_norm:\n",
    "            max_norm = grad_norm\n",
    "            max_point = random_input\n",
    "            max_grad = grad\n",
    "    return {'norm': max_norm, 'point': max_point, 'grad': max_grad}\n",
    "# Random-search estimate (default 100 samples) for the notebook-level network and box\n",
    "collect_max(network, hbox, c_vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# l1 gradient norm at a hard-coded point (presumably copied from an earlier collect_max run);\n",
    "# tensors are built in the network's dtype, as collect_max does\n",
    "probe_point = torch.tensor([0.079484679, 0.967835081]).type(network.dtype)\n",
    "network.get_grad_at_point(probe_point, torch.tensor(c_vector).type(network.dtype)).norm(p=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def verify(seed=None, num_random=1000):\n",
    "    \"\"\"Compare lipMIP's Lipschitz value with a random-search estimate on a fresh network.\n",
    "\n",
    "    A new ReLUNet([2, 4, 3]) is built after seeding torch with `seed` (a random int in\n",
    "    [1, 1000] when None). The seed is printed so a failing case can be replayed with\n",
    "    verify(seed=...). lipMIP then solves for the max Lipschitz constant over the unit\n",
    "    l_inf box, and collect_max samples `num_random` gradients over the same box.\n",
    "\n",
    "    Returns |lipMIP value - sampled max l1 gradient norm| as a float.\n",
    "    Raises ValueError if num_random < 1, since nothing would be sampled.\n",
    "    \"\"\"\n",
    "    if num_random < 1:\n",
    "        raise ValueError('num_random must be >= 1')\n",
    "    rand_seed = random.randint(1, 1000) if seed is None else seed\n",
    "    print(\"RAND SEED\", rand_seed)\n",
    "    torch.random.manual_seed(rand_seed)\n",
    "    network = ReLUNet([2, 4, 3])\n",
    "    hbox = Hyperbox.build_linf_ball(np.zeros(2), 1.0)\n",
    "    c_vector = np.array([1.0, -1.0, 0.0])\n",
    "    output = lm.compute_max_lipschitz(network, hbox, 'l_inf', c_vector, verbose=False)\n",
    "    random_max = collect_max(network, hbox, c_vector, num_random=num_random)\n",
    "\n",
    "    mip_out = output['lipschitz']\n",
    "    rand_out = random_max['norm'].item()\n",
    "    # abs() so a sampled norm *above* the MIP value (MIP under-reporting) also fails the check\n",
    "    return abs(mip_out - rand_out)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regression sweep: on every random network the lipMIP value must agree with the\n",
    "# random-search maximum to within TOL\n",
    "NUM_TRIALS = 1000\n",
    "TOL = 1e-6\n",
    "for trial in range(NUM_TRIALS):\n",
    "    gap = verify()\n",
    "    assert gap < TOL, f'trial {trial}: |lipMIP - sampled max| = {gap:.3e} (see RAND SEED printed above)'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-run the random search on the notebook-level network: verify() only rebinds local names,\n",
    "# so `network`, `hbox` and `c_vector` here are still the ones from the setup cell\n",
    "collect_max(network, hbox, c_vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the squire returned by lipMIP before experimenting on it (presumably an independent copy)\n",
    "squire_clone = squire.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient computed by the cloned squire at the input lipMIP reported as best (MAXPOINT)\n",
    "sq_grad = squire_clone.get_grad_at_point(MAXPOINT)\n",
    "sq_grad\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Squire gradient next to its l1 norm (sum of absolute components)\n",
    "sq_grad_l1 = sum(abs(component) for component in sq_grad)\n",
    "sq_grad, sq_grad_l1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that each gradient variable's solved value (.X) lies inside its backprop bounds\n",
    "grad_vars = squire.get_vars('gradient')\n",
    "grad_bounds = squire.get_backprop_bounds(0, two_col=True)\n",
    "# read the solution values once, and cover every coordinate instead of a hard-coded 2\n",
    "grad_values = [var.X for var in grad_vars]\n",
    "print([bounds[0] <= value <= bounds[1] for value, bounds in zip(grad_values, grad_bounds)])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the gradient-variable constraints for a single coordinate\n",
    "def check_constr_set(i):\n",
    "    \"\"\"Return (c1, c2) for gradient coordinate `i` of the solved model.\n",
    "\n",
    "    c1: the solved gradient value lies in [grad_lo, grad_hi] taken from grad_bounds.\n",
    "    c2: its absolute value lies in [0, max(|grad_lo|, |grad_hi|)], which follows from c1.\n",
    "\n",
    "    Reads the globals grad_vars / grad_bounds assigned in an earlier cell.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the original body stopped mid-expression (`c1 = grad_lo <= output_var <= `),\n",
    "    # which made this cell a SyntaxError; the checks below are a best guess -- confirm intent.\n",
    "    input_var = grad_vars[i].X\n",
    "    output_var = abs(input_var)\n",
    "    grad_lo, grad_hi = grad_bounds[i]\n",
    "    c1 = grad_lo <= input_var <= grad_hi\n",
    "    c2 = 0 <= output_var <= max(abs(grad_lo), abs(grad_hi))\n",
    "    return c1, c2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network gradient of c^T f at (0.5, 0.5); tensors built in the network's dtype from c_vector\n",
    "out_grad = network.get_grad_at_point(torch.tensor([0.5, 0.5]).type(network.dtype),\n",
    "                                      torch.tensor(c_vector).type(network.dtype))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Elementwise difference between the squire gradient and the network gradient.\n",
    "# NOTE(review): sq_grad is evaluated at MAXPOINT but out_grad at (0.5, 0.5); the difference is\n",
    "# only meaningful if those points coincide (or share a linear region) -- confirm.\n",
    "utils.as_numpy(sq_grad) -  utils.as_numpy(out_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same gradient via plain autograd, as a cross-check of get_grad_at_point.\n",
    "# Cast to the network's dtype *before* requires_grad_: casting afterwards returns a new\n",
    "# non-leaf tensor whenever the dtype changes, and its .grad is never populated.\n",
    "x = torch.tensor([0.5, 0.5]).type(network.dtype).requires_grad_(True)\n",
    "c_tensor = torch.tensor(c_vector).type(network.dtype)\n",
    "# squeeze().dot() gives the scalar c^T f(x), matching grad_x; .mv needs a 2-D output\n",
    "y = network.forward(x).squeeze().dot(c_tensor)\n",
    "y.backward()\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug experiment: drop a named constraint from the solved model.\n",
    "# NOTE(review): 'aoesu' looks like a placeholder name -- confirm which constraint was meant.\n",
    "constr_name = 'aoesu'\n",
    "constr = model.getConstrByName(constr_name)\n",
    "if constr is None:\n",
    "    # no constraint with this name: report it instead of handing None to model.remove\n",
    "    print('No constraint named', repr(constr_name))\n",
    "else:\n",
    "    model.remove(constr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient functional \n",
    "def grad_x(x, c_vector=c_vector):\n",
    "    x = torch.tensor(x, requires_grad=True, dtype=torch.float32)\n",
    "    output = network(x).squeeze().dot(torch.tensor(c_vector, dtype=torch.float32))\n",
    "    output.backward()\n",
    "    return x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_random_norms(num=100, dim=2, radius=1.0):\n",
    "    \"\"\"Random-search the largest l1 norm of grad_x over the box [-radius, radius]^dim.\n",
    "\n",
    "    Samples `num` points uniformly with np.random and returns\n",
    "    (max_norm, max_norm_point, gradient at max_norm_point). When num <= 0 nothing is\n",
    "    sampled and (-1, None, None) is returned.\n",
    "    \"\"\"\n",
    "    max_norm = -1\n",
    "    max_norm_point = None\n",
    "    max_norm_grad = None\n",
    "    for _ in range(num):\n",
    "        # uniform on [-radius, radius)^dim; the defaults match hbox (unit l_inf ball in 2-D)\n",
    "        random_point = (np.random.random(dim) * 2 - 1) * radius\n",
    "        grad = grad_x(random_point)\n",
    "        norm = torch.norm(grad, 1)\n",
    "        if norm > max_norm:\n",
    "            max_norm = norm\n",
    "            max_norm_point = random_point\n",
    "            # keep the gradient rather than recomputing it after the loop\n",
    "            max_norm_grad = grad\n",
    "    return max_norm, max_norm_point, max_norm_grad\n",
    "\n",
    "test_random_norms(100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
