{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple test to do lipMIP for a very easy MNIST network \n",
    "import sys \n",
    "sys.path.append('..')\n",
    "import lipMIP as lm\n",
    "import neural_nets.data_loaders as dl \n",
    "import neural_nets.train as train \n",
    "from relu_nets import ReLUNet\n",
    "import utilities as utils\n",
    "from hyperbox import Hyperbox\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MNIST train/validation splits, keeping only the digits 1 and 7\n",
    "# (a two-class problem; the network is trained in the next cell).\n",
    "DIGITS = [1, 7]\n",
    "oneseven_train = dl.load_mnist_data('train', digits=DIGITS)\n",
    "oneseven_val = dl.load_mnist_data('val', digits=DIGITS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully loaded: 2layer_2class_mnist.pkl ReLUNet(\n",
      "  (net): Sequential(\n",
      "    (1): Linear(in_features=784, out_features=20, bias=True)\n",
      "    (2): ReLU()\n",
      "    (3): Linear(in_features=20, out_features=2, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Build the 784 -> 20 -> ReLU -> 2 network (it must match the architecture of the\n",
    "# cached pickle), then hand it to train_cacher, which trains it or, as in the saved\n",
    "# run ('Successfully loaded'), reuses a cached copy.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)  # reproducible weight init when training from scratch\n",
    "network = ReLUNet([784, 20, 2])  # NOTE(review): confirm ReLUNet takes a layer-size list\n",
    "saved_name = '2layer_2class_mnist.pkl'\n",
    "trained_net = train.train_cacher(saved_name, network=network, trainset=oneseven_train, valset=oneseven_val, \n",
    "                                 num_epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAEEklEQVR4nO3dv0nsWRiA4Z3l+jcQTIzEzBbMLcAejE3EGqxBxAJswA4MDAwswgYUBkQRdDbZZMH57fXOeOd193lCP+ZwkpcPPIyOJpPJH0DPn4u+APAxcUKUOCFKnBAlToj6MTQcjUZ+lQtfbDKZjD76uc0JUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlToj6segLMF9XV1eD84ODg8H50dHR1Nn5+fkv3YlfY3NClDghSpwQJU6IEidEiROixAlR3jm/meXl5cH5+vr64Pz9/X1wfnx8PHV2eXk5+NnxeDw453NsTogSJ0SJE6LECVHihChxQpSnlG9mdXV1cL6xsTHT+bu7u1Nna2trg5/1lDJfNidEiROixAlR4oQocUKUOCFKnBDlnfOb+bevfL29vc10/unp6dTZ4+PjTGfzOTYnRIkTosQJUeKEKHFClDghSpwQ5Z3zm9nZ2Rmc7+3tfdn5r6+vM53N59icECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKU73N+MycnJ4u+Ar+JzQlR4oQocUKUOCFKnBAlTojylPLNLC0tfen5T09PX3o+P8/mhChxQpQ4IUqcECVOiBInRIkTorxz8g9nZ2eLvgJ/szkhSpwQJU6IEidEiROixAlR4oQo75wxm5ubg/P9/f2Zzn94eBicPz8/z3Q+82NzQpQ4IUqcECVOiBInRIkTosQJUd45Y1ZWVgbn29vbM51/c3MzOL+/v5/pfObH5oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0T5F4D/M+PxeNFX4CfZnBAlTogSJ0SJE6LECVHihChxQpR3zpitra0vPf/i4uJLz2d+bE6IEidEiROixAlR4oQocUKUp5SYw8PDRV+BCJsTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTojyfc7/mLu7u5nmdNicECVOiBInRIkTosQJUeKEKHFClHfOmNvb25k+f319PTh/eXmZ6Xx+H5sTosQJUeKEKHFClDghSpwQJU6IGk0mk+nD0Wj6EJiLyWQy+ujnNidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghavBPYwKLY3NClDghSpwQJU6IEidEiROi/gIhk1ijxGjR/AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Pick a single validation image (sliced so it stays a batch of size 1)\n",
    "# to probe the network's local Lipschitz behaviour around it.\n",
    "EXAMPLE_NUM = 9\n",
    "first_val_batch = next(iter(oneseven_val))\n",
    "val_images = first_val_batch[0]\n",
    "example = val_images[EXAMPLE_NUM:EXAMPLE_NUM + 1]\n",
    "utils.display_images(example, figsize=(4, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimize a model with 1630 rows, 2434 columns and 33854 nonzeros\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-06, 1e+00]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-08, 3e+00]\n",
      "  RHS range        [5e-03, 1e-01]\n",
      "\n",
      "Concurrent LP optimizer: dual simplex and barrier\n",
      "Showing barrier log only...\n",
      "\n",
      "Presolve removed 1630 rows and 2434 columns\n",
      "Presolve time: 0.03s\n",
      "Presolve: All rows and columns removed\n",
      "Iteration    Objective       Primal Inf.    Dual Inf.      Time\n",
      "       0    2.4315533e+02   0.000000e+00   0.000000e+00      0s\n",
      "\n",
      "Solved with dual simplex\n",
      "Solved in 0 iterations and 0.04 seconds\n",
      "Optimal objective  2.431553324e+02\n"
     ]
    }
   ],
   "source": [
    "# Run LipMIP on a small l_inf ball around the chosen example.\n",
    "# global_lo/global_hi presumably clip the ball to the [0, 1] pixel range.\n",
    "LINF_RADIUS = 0.01\n",
    "small_rad = Hyperbox.build_linf_ball(example, LINF_RADIUS, global_lo=0.0, global_hi=1.0)\n",
    "# c-vector over the two logits; NOTE(review): presumably targets logit_0 - logit_1\n",
    "c_vector = torch.Tensor([1.0, -1.0])\n",
    "lipmip_out = lm.compute_max_lipschitz(trained_net, small_rad, 'l_inf', c_vector, verbose=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
