{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys \n",
    "sys.path.append('..')\n",
    "from other_methods.naive_methods import NaiveUB\n",
    "from other_methods.fast_lip import FastLip\n",
    "from other_methods.lip_sdp import LipSDPCustom\n",
    "from relu_nets import ReLUNet\n",
    "import torch\n",
    "import utilities as utils\n",
    "import numpy as np\n",
    "import math\n",
    "from hyperbox import Hyperbox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch so the randomly initialised network below is reproducible across runs.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Two-entry vector; the saved state_dict shows the network has two outputs.\n",
    "c_vector = torch.Tensor([1.0, -1.0])\n",
    "# Hyperbox built by build_unit_hypercube(4); the saved first-layer weight is 8x4,\n",
    "# i.e. the network takes 4 input features.\n",
    "unit_box = Hyperbox.build_unit_hypercube(4)\n",
    "\n",
    "# `test_net` is used by the later cells but was never defined in this notebook,\n",
    "# so a fresh kernel failed with NameError. The saved state_dict shows a\n",
    "# 4 -> 8 -> 2 architecture.\n",
    "# NOTE(review): assumes ReLUNet takes the list of layer sizes as its first argument -- confirm.\n",
    "test_net = ReLUNet([4, 8, 2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('net.1.weight', tensor([[-0.2420, -0.4277, -0.2591,  0.2209],\n",
       "                      [-0.1439,  0.4846,  0.4228,  0.3916],\n",
       "                      [ 0.0603,  0.4963,  0.3303,  0.4624],\n",
       "                      [ 0.4159,  0.2385, -0.4251, -0.0779],\n",
       "                      [ 0.1009, -0.3171, -0.2157,  0.4901],\n",
       "                      [ 0.4927,  0.4081,  0.3245,  0.3202],\n",
       "                      [ 0.4527, -0.0518,  0.2519, -0.1986],\n",
       "                      [-0.1238, -0.0092,  0.4427,  0.3163]])),\n",
       "             ('net.1.bias',\n",
       "              tensor([-0.2179, -0.0707,  0.2108, -0.2958, -0.0026, -0.1606, -0.0699, -0.1206])),\n",
       "             ('net.3.weight',\n",
       "              tensor([[-0.1803,  0.1849,  0.3457,  0.2782,  0.2728,  0.2723, -0.2501,  0.1121],\n",
       "                      [ 0.2284,  0.0413, -0.1938,  0.0264,  0.3304,  0.0769,  0.2050,  0.1452]])),\n",
       "             ('net.3.bias', tensor([-0.2676, -0.2789]))])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the parameter tensors (weights and biases) stored in the network.\n",
    "net_params = test_net.state_dict()\n",
    "net_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: despite its name, `fast_lip` holds a LipSDPCustom instance, not a FastLip one.\n",
    "fast_lip = LipSDPCustom(\n",
    "    test_net,\n",
    "    c_vector,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2, 8) torch.Size([2])\n",
      "SNTHS\n",
      "[array([[-0.24197292, -0.4276663 , -0.25911695,  0.22092235],\n",
      "       [-0.14391285,  0.48464197,  0.42284632,  0.39162278],\n",
      "       [ 0.06028676,  0.496297  ,  0.33026093,  0.46235538],\n",
      "       [ 0.415949  ,  0.23847353, -0.42511708, -0.07792896],\n",
      "       [ 0.10090196, -0.31707644, -0.21569586,  0.49009258],\n",
      "       [ 0.4926731 ,  0.40812284,  0.32453418,  0.32015443],\n",
      "       [ 0.45269096, -0.05183452,  0.25190532, -0.19857115],\n",
      "       [-0.12381262, -0.00924289,  0.44274247,  0.31627023]],\n",
      "      dtype=float32), array([-0.40870488,  0.14368993,  0.5394963 ,  0.2518722 , -0.05757976,\n",
      "        0.19538134, -0.455074  , -0.03309274], dtype=float32)]\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "could not broadcast input array from shape (8,4) into shape (8)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-be5b6092b925>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfast_lip\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/grad/lipMIP/other_methods/lip_sdp.py\u001b[0m in \u001b[0;36mcompute\u001b[0;34m(self, attr_name)\u001b[0m\n\u001b[1;32m     53\u001b[0m                 \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextract_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnetwork\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_vector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m                 \u001b[0mtmp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtempfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNamedTemporaryFile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'w+b'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m                 \u001b[0msavemat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtmp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m                 \u001b[0;31m# Build matlab stuff\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/configs/env/lib/python3.6/site-packages/scipy-1.3.0-py3.6-linux-x86_64.egg/scipy/io/matlab/mio.py\u001b[0m in \u001b[0;36msavemat\u001b[0;34m(file_name, mdict, appendmat, format, long_field_names, do_compression, oned_as)\u001b[0m\n\u001b[1;32m    277\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    278\u001b[0m             \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Format should be '4' or '5'\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 279\u001b[0;31m         \u001b[0mMW\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mput_variables\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    280\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    281\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/configs/env/lib/python3.6/site-packages/scipy-1.3.0-py3.6-linux-x86_64.egg/scipy/io/matlab/mio5.py\u001b[0m in \u001b[0;36mput_variables\u001b[0;34m(self, mdict, write_header)\u001b[0m\n\u001b[1;32m    847\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile_stream\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    848\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# not compressing\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 849\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_matrix_writer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_top\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvar\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0masbytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_global\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/configs/env/lib/python3.6/site-packages/scipy-1.3.0-py3.6-linux-x86_64.egg/scipy/io/matlab/mio5.py\u001b[0m in \u001b[0;36mwrite_top\u001b[0;34m(self, arr, name, is_global)\u001b[0m\n\u001b[1;32m    588\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_var_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    589\u001b[0m         \u001b[0;31m# write the header and data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 590\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    591\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    592\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/configs/env/lib/python3.6/site-packages/scipy-1.3.0-py3.6-linux-x86_64.egg/scipy/io/matlab/mio5.py\u001b[0m in \u001b[0;36mwrite\u001b[0;34m(self, arr)\u001b[0m\n\u001b[1;32m    606\u001b[0m             \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    607\u001b[0m         \u001b[0;31m# Try to convert things that aren't arrays\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 608\u001b[0;31m         \u001b[0mnarr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_writeable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    609\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mnarr\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    610\u001b[0m             raise TypeError('Could not convert %s (type %s) to array'\n",
      "\u001b[0;32m~/configs/env/lib/python3.6/site-packages/scipy-1.3.0-py3.6-linux-x86_64.egg/scipy/io/matlab/mio5.py\u001b[0m in \u001b[0;36mto_writeable\u001b[0;34m(source)\u001b[0m\n\u001b[1;32m    448\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mEmptyStructMarker\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    449\u001b[0m     \u001b[0;31m# Next try and convert to an array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 450\u001b[0;31m     \u001b[0mnarr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msource\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    451\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mnarr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mobject\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobject_\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    452\u001b[0m        \u001b[0mnarr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnarr\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0msource\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/configs/env/lib/python3.6/site-packages/numpy/core/numeric.py\u001b[0m in \u001b[0;36masanyarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m    589\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    590\u001b[0m     \"\"\"\n\u001b[0;32m--> 591\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msubok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    592\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    593\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: could not broadcast input array from shape (8,4) into shape (8)"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the output saved with this notebook shows this call failing inside\n",
    "# scipy's savemat (reached from other_methods/lip_sdp.py) with a ValueError about\n",
    "# broadcasting shape (8,4) into shape (8); the fix belongs in lip_sdp.py.\n",
    "lip_sdp_result = fast_lip.compute()\n",
    "lip_sdp_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 1000 is presumably the number of random samples drawn -- confirm\n",
    "# against the definition of random_max_grad.\n",
    "max_grad = test_net.random_max_grad(unit_box, c_vector, 1000)\n",
    "max_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the same extract_weights helper that compute() invokes (per the saved\n",
    "# traceback) so that its return value can be inspected directly.\n",
    "weight_out = LipSDPCustom.extract_weights(\n",
    "    test_net,\n",
    "    c_vector,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast the first extracted entry to an object-dtype array.\n",
    "# The original line ended in a dangling `.` (a SyntaxError) and used the `np.object`\n",
    "# alias, which newer NumPy releases no longer provide; the builtin `object` is equivalent.\n",
    "weight_out[0].astype(object)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pack the extracted weights into a 1-D object array, one entry per slot.\n",
    "# `weight_list` was never defined in this notebook; NOTE(review): assumes the intended\n",
    "# input is `weight_out`, a sequence of arrays -- confirm.\n",
    "# np.array(..., dtype=object) on arrays that share a leading dimension, such as (8,4)\n",
    "# and (8,), tries to broadcast them into one block and raises ValueError, so the\n",
    "# container is allocated first and filled element by element. The builtin `object`\n",
    "# replaces the removed `np.object` alias.\n",
    "weight_array = np.empty(len(weight_out), dtype=object)\n",
    "for idx, weight in enumerate(weight_out):\n",
    "    weight_array[idx] = weight\n",
    "weight_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
