{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jordanm/.virtualenvs/myvenv/lib/python3.7/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "In this example, we build a binary MNIST classifier and then run GeoCert on several test points.\n",
    "'''\n",
    "\n",
    "# =====================\n",
    "# Imports\n",
    "# =====================\n",
    "%load_ext line_profiler\n",
    "import sys\n",
    "import pickle \n",
    "sys.path.append('..')\n",
    "sys.path.append('../mister_ed') # library for adversarial examples\n",
    "\n",
    "import geocert_oop as geo\n",
    "from plnn import PLNN\n",
    "import _polytope_ as _poly_\n",
    "from _polytope_ import Polytope, Face\n",
    "import utilities as utils\n",
    "import os\n",
    "import time \n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "import adversarial_perturbations as ap \n",
    "import prebuilt_loss_functions as plf\n",
    "import loss_functions as lf \n",
    "import adversarial_attacks as aa\n",
    "import utils.pytorch_utils as me_utils\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import mnist.mnist_loader as  ml \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code to select only a subset of digits from MNIST datasets\n",
    "def select_digits(dataset, digits, mb_size=128, remap_label=True):\n",
    "    \"\"\"Keep only the examples labelled with one of `digits` and re-batch them.\n",
    "\n",
    "    Args:\n",
    "        dataset: iterable of (data, labels) minibatches; `labels` must be a\n",
    "            1-D tensor and `data` must have one leading entry per label.\n",
    "        digits: sequence of digit classes to keep.\n",
    "        mb_size: size of the re-split minibatches (the last may be smaller).\n",
    "        remap_label: if True (at most two digits), labels become 1 where the\n",
    "            original label equals digits[0] and 0 otherwise; if False the\n",
    "            original labels are kept.\n",
    "\n",
    "    Returns:\n",
    "        List of (data, labels) tuples. `labels` is always a 1-D long tensor\n",
    "        and no minibatch is empty; the list is empty if nothing matched.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if remap_label is set with more than two digits.\n",
    "    \"\"\"\n",
    "    if remap_label and len(digits) > 2:\n",
    "        raise NotImplementedError(\"Only handling between 2 types of digits for now\")\n",
    "\n",
    "    if remap_label:\n",
    "        def label_map(labels):\n",
    "            return labels == digits[0]\n",
    "    else:\n",
    "        def label_map(labels):\n",
    "            return labels\n",
    "\n",
    "    sel_data, sel_labels = [], []\n",
    "    for data, labels in dataset:\n",
    "        # All-False mask with the same shape (and mask dtype) as `labels`.\n",
    "        mask = labels != labels\n",
    "        for digit in digits:\n",
    "            mask = mask | (labels == digit)\n",
    "        # Boolean indexing on the leading axis keeps the per-example shape,\n",
    "        # so nothing here is tied to 1x28x28 images.\n",
    "        sel_data.append(data[mask])\n",
    "        sel_labels.append(labels[mask])\n",
    "\n",
    "    if not sel_data:\n",
    "        # torch.cat rejects an empty list; an empty dataset selects nothing.\n",
    "        return []\n",
    "\n",
    "    # Then concatenate and resplit into minibatches of size mb_size\n",
    "    cat_data = torch.cat(sel_data)\n",
    "    cat_labels = torch.cat(sel_labels)\n",
    "    full_dataset = []\n",
    "    # Stepping by mb_size yields ceil(len / mb_size) batches, so an exact\n",
    "    # multiple of mb_size no longer produces an empty trailing minibatch.\n",
    "    for start in range(0, len(cat_labels), mb_size):\n",
    "        stop = start + mb_size\n",
    "        # reshape(-1) (rather than squeeze()) keeps a 1-example batch 1-D.\n",
    "        batch_labels = label_map(cat_labels[start:stop]).reshape(-1).long()\n",
    "        full_dataset.append((cat_data[start:stop], batch_labels))\n",
    "    return full_dataset\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make a dataset containing only 1's and 7's\n",
    "MINIBATCH_SIZE = 128\n",
    "\n",
    "\n",
    "def load_or_build_digit_subset(cache_path, split, digits=(1, 7), mb_size=MINIBATCH_SIZE):\n",
    "    \"\"\"Load a cached digit subset, or build it from MNIST and cache it.\n",
    "\n",
    "    Args:\n",
    "        cache_path: pickle file used as the on-disk cache.\n",
    "        split: split name handed to ml.load_mnist_data ('train' or 'val').\n",
    "        digits: digit classes to keep.\n",
    "        mb_size: minibatch size handed to select_digits.\n",
    "\n",
    "    Returns:\n",
    "        The list of (data, labels) minibatches produced by select_digits.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # NOTE: unpickling executes arbitrary code; only load cache files that\n",
    "        # were written locally by this notebook.\n",
    "        with open(cache_path, 'rb') as f:\n",
    "            return pickle.load(f)\n",
    "    except (FileNotFoundError, EOFError, pickle.UnpicklingError):\n",
    "        # Missing or unreadable cache: fall through and rebuild it.\n",
    "        pass\n",
    "\n",
    "    subset = select_digits(ml.load_mnist_data(split), list(digits), mb_size=mb_size)\n",
    "    # Must be opened for writing: opening with 'rb' here fails on a cache miss\n",
    "    # because the file does not exist yet.\n",
    "    with open(cache_path, 'wb') as f:\n",
    "        pickle.dump(subset, f)\n",
    "    return subset\n",
    "\n",
    "\n",
    "os_trainset = load_or_build_digit_subset('os_trainset.pkl', 'train')\n",
    "os_valset = load_or_build_digit_subset('os_valset.pkl', 'val')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define functions to train and evaluate a network \n",
    "\n",
    "def l1_loss(net):\n",
    "    \"\"\"Sum of the L1 norms of every parameter of `net` with more than one dimension.\n",
    "\n",
    "    1-D parameters (e.g. bias vectors) are skipped. Returns the int 0 when no\n",
    "    parameter qualifies.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    for param in net.parameters():\n",
    "        if param.dim() > 1:\n",
    "            total = total + param.norm(p=1)\n",
    "    return total\n",
    "\n",
    "def train(net, trainset, num_epochs, lr=1e-3, l1_scale=2e-3):\n",
    "    \"\"\"Train `net` with Adam on cross-entropy plus an L1 weight penalty.\n",
    "\n",
    "    Args:\n",
    "        net: module mapping a (batch, features) tensor to class logits.\n",
    "        trainset: iterable of (data, labels) minibatches; each data tensor is\n",
    "            flattened to (batch, -1) before being fed to `net`.\n",
    "        num_epochs: number of passes over `trainset`.\n",
    "        lr: Adam learning rate.\n",
    "        l1_scale: multiplier applied to l1_loss(net) in the objective.\n",
    "\n",
    "    Prints, per epoch, the fraction of misclassified training examples\n",
    "    (measured on each minibatch before that minibatch's update).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `trainset` yields no examples.\n",
    "    \"\"\"\n",
    "    opt = optim.Adam(net.parameters(), lr=lr)\n",
    "    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch\n",
    "    for epoch in range(num_epochs):\n",
    "        err_acc = 0\n",
    "        err_count = 0\n",
    "        for data, labels in trainset:\n",
    "            n = data.shape[0]\n",
    "            if n == 0:\n",
    "                # An empty minibatch gives a NaN loss and would poison the weights.\n",
    "                continue\n",
    "            output = net(data.view(n, -1))\n",
    "            loss = criterion(output, labels) + l1_scale * l1_loss(net)\n",
    "\n",
    "            # Count errors per example (not per batch) so that a smaller last\n",
    "            # minibatch is weighted correctly, matching test_acc.\n",
    "            err_acc = err_acc + (output.max(1)[1] != labels).float().sum()\n",
    "            err_count += n\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        if err_count == 0:\n",
    "            raise ValueError(\"trainset contains no examples\")\n",
    "        print(\"(%02d) error:\" % epoch, err_acc / err_count)\n",
    "            \n",
    "        \n",
    "def test_acc(net, valset):\n",
    "    \"\"\"Print the classification accuracy of `net` over `valset`.\n",
    "\n",
    "    Args:\n",
    "        net: module mapping a (batch, features) tensor to class logits.\n",
    "        valset: iterable of (data, labels) minibatches; each data tensor is\n",
    "            flattened to (batch, -1) before being fed to `net`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `valset` yields no examples.\n",
    "    \"\"\"\n",
    "    err_acc = 0\n",
    "    err_count = 0\n",
    "    with torch.no_grad():  # evaluation only: no autograd graph needed\n",
    "        for data, labels in valset:\n",
    "            n = data.shape[0]\n",
    "            if n == 0:\n",
    "                # mean() of an empty batch is NaN and would poison the total.\n",
    "                continue\n",
    "            output = net(data.view(n, -1))\n",
    "            err_acc += (output.max(1)[1] != labels).sum().item()\n",
    "            err_count += n\n",
    "\n",
    "    if err_count == 0:\n",
    "        raise ValueError(\"valset contains no examples\")\n",
    "    print(\"Accuracy of: %.03f\" % (1 - err_acc / err_count))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (1): Linear(in_features=784, out_features=15, bias=True)\n",
      "  (2): ReLU()\n",
      "  (3): Linear(in_features=15, out_features=50, bias=True)\n",
      "  (4): ReLU()\n",
      "  (5): Linear(in_features=50, out_features=10, bias=True)\n",
      "  (6): ReLU()\n",
      "  (7): Linear(in_features=10, out_features=2, bias=True)\n",
      ")\n",
      "(00) error: tensor(0.2947)\n",
      "(01) error: tensor(0.0129)\n",
      "(02) error: tensor(0.0102)\n",
      "(03) error: tensor(0.0087)\n",
      "(04) error: tensor(0.0082)\n",
      "(05) error: tensor(0.0072)\n",
      "(06) error: tensor(0.0071)\n",
      "(07) error: tensor(0.0067)\n",
      "(08) error: tensor(0.0066)\n",
      "(09) error: tensor(0.0062)\n",
      "Accuracy of: 0.991\n"
     ]
    }
   ],
   "source": [
    "# Define the network architecture.\n",
    "MNIST_DIM = 784\n",
    "NUM_EPOCHS = 10\n",
    "SEED = 0\n",
    "\n",
    "# Seed torch's RNG before the network is constructed so that a fresh\n",
    "# Restart & Run All reproduces the same trained network (and therefore the\n",
    "# same polytopes/facets in the cells below).\n",
    "torch.manual_seed(SEED)\n",
    "network = PLNN([MNIST_DIM, 15, 50, 10, 2])\n",
    "net = network.net\n",
    "\n",
    "# Train and evaluate the network \n",
    "train(net, os_trainset, NUM_EPOCHS)\n",
    "test_acc(net, os_valset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 28, 28])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x10f7fc588>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAADCpJREFUeJzt3WuoXfWZx/Hvo6YipmAyxXCwQWPRwaJg5SDCiFQ6liiFWBBpkCEj0vRFA1PoixEHGV+IyNA0FJRCiqHJUG2EtpgXMlMNI87AUIziNWdSnZLSXExSvNSAWC/PvDgrnaOes/bOvq19fL4fOJy113/ttR6X+eW/1v6v7H9kJpLqOaPrAiR1w/BLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrqrEkeLCJ8nFAas8yMfrYbquePiPURcSAiXouIO4fZl6TJikGf7Y+IM4HfAjcAh4BngI2Zub/lPfb80phNoue/GngtM3+XmX8Gfg5sGGJ/kiZomPBfAPxhwetDzbqPiYjNEbEvIvYNcSxJIzb2D/wyczuwHbzsl6bJMD3/YWDtgtdfbNZJWgaGCf8zwCURsS4iPgd8C9gzmrIkjdvAl/2Z+UFEbAH+HTgT2JGZr4ysMkljNfBQ30AH855fGruJPOQjafky/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qaiBp+gGiIiDwDvAh8AHmTk7iqIkjd9Q4W9cn5l/HMF+JE2Ql/1SUcOGP4FfR8SzEbF5FAVJmoxhL/uvzczDEXE+8ERE/E9mPr1wg+YvBf9ikKZMZOZodhRxD3AyM3/Qss1oDiZpSZkZ/Ww38GV/RJwbEZ8/tQx8HXh50P1JmqxhLvvXAL+KiFP7eTgz/20kVUkau5Fd9vd1MC/7F3XVVVe1tl922WUTquT0nXfeea3tDzzwwJJtW7ZsaX3vgw8+OFBN1Y39sl/S8mb4paIMv1SU4ZeKMvxSUYZfKsqhvj5t3bp1ybZLL710qH33Gspbt27dUPufVidPnmxt37y5/anw3bt3j7KczwyH+iS1MvxSUYZfKsrwS0UZfqkowy8VZfilokbx7b0lXHPNNQO1ATz88MOt7W+//fZANS13K1eubG2/+OKLJ1RJTfb8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4/x9OnLkyJJt27Zta33v3Xff3dp+/vnnt7bPzMy0tnep13/b+vXrB973qlWrWttXrFjR2v7+++8PfOwK7Pmlogy/VJThl4oy/FJRhl8qyvBLRRl+qaie39sfETuAbwDHM/PyZt1qYDdwEXAQuDUz3+x5sGX8vf1aXK9nEA4dOjTwvp966qnW9ltuuaW1/c03e/6R/Ewa5ff2/xT45JMadwJ7M/MSYG/zWtIy0jP8mfk08MYnVm8AdjbLO4GbR1yXpDEb9J5/TWYebZZfB9aMqB5JEzL0s/2ZmW338hGxGWifdE3SxA3a8x+LiBmA5vfxpTbMzO2ZOZuZswMeS9IYDBr+PcCmZnkT8NhoypE0KT3DHxGPAP8N/HVEHIqIO4D7gRsi4lXgb5vXkpaRnvf8mblxiaavjbgW6WNWr17d2n7WWX4dxTB8wk8qyvBLRRl+qSjDLxVl+KWiDL9UlGMlGsptt902tn0/+uijre0nTpwY27ErsOeXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIc59dQbrzxxq5L0IDs+aWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMf51ZkjR460tj/++OMTqqQme36pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKqrnOH9E7AC+ARzPzMubdfcA3wZOfXH6XZnpoOxn0O23397aft111w2877fe
equ1/YUXXhh43+qtn57/p8D6RdZvy8wrmx+DLy0zPcOfmU8Db0ygFkkTNMw9/5aIeDEidkTEqpFVJGkiBg3/j4EvAVcCR4GtS20YEZsjYl9E7BvwWJLGYKDwZ+axzPwwMz8CfgJc3bLt9syczczZQYuUNHoDhT8iZha8/Cbw8mjKkTQp/Qz1PQJ8FfhCRBwC/hn4akRcCSRwEPjOGGuUNAY9w5+ZGxdZ/dAYalEHIqK1/eyzz25tP+MMnxNbrvw/JxVl+KWiDL9UlOGXijL8UlGGXyoqMnNyB4uY3MHUl3POOae1/eTJk2M79v79+1vbr7jiirEd+7MsM9vHbxv2/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlFN0qzP33Xdf1yWUZs8vFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0U5zq/OzM3NdV1Cafb8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1RUz3H+iFgL7ALWAAlsz8wfRcRqYDdwEXAQuDUz3xxfqRqHe++9d6z737Vr15JtBw4cGOux1a6fnv8D4PuZ+WXgGuC7EfFl4E5gb2ZeAuxtXktaJnqGPzOPZuZzzfI7wBxwAbAB2NlsthO4eVxFShq907rnj4iLgK8AvwHWZObRpul15m8LJC0TfT/bHxErgV8A38vMP0X8/3RgmZlLzcMXEZuBzcMWKmm0+ur5I2IF88H/WWb+sll9LCJmmvYZ4Phi783M7Zk5m5mzoyhY0mj0DH/Md/EPAXOZ+cMFTXuATc3yJuCx0ZcnaVz6uez/G+DvgJci4vlm3V3A/cCjEXEH8Hvg1vGUqHG68MILx7r/EydOLNn27rvvjvXYatcz/Jn5X8BS831/bbTlSJoUn/CTijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxXV93Rd0mLee++91va5ubkJVaLTZc8vFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0VFZrZvELEW2AWsARLYnpk/ioh7gG8DpyZgvyszH++xr/aDaeKuv/761vYnn3yytf3IkSOt7WvXrj3tmjSczIx+tuvnIZ8PgO9n5nMR8Xng2Yh4omnblpk/GLRISd3pGf7MPAocbZbfiYg54IJxFyZpvE7rnj8iLgK+AvymWbUlIl6MiB0RsWqJ92yOiH0RsW+oSiWNVN/hj4iVwC+A72Xmn4AfA18CrmT+ymDrYu/LzO2ZOZuZsyOoV9KI9BX+iFjBfPB/lpm/BMjMY5n5YWZ+BPwEuHp8ZUoatZ7hj4gAHgLmMvOHC9bPLNjsm8DLoy9P0rj0M9R3LfCfwEvAR83qu4CNzF/yJ3AQ+E7z4WDbvhzqk8as36G+nuEfJcMvjV+/4fcJP6kowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGTnqL7j8DvF7z+QrNuGk1rbdNaF1jboEZZ24X9bjjRf8//qYNH7JvW7/ab1tqmtS6wtkF1VZuX/VJRhl8qquvwb+/4+G2mtbZprQusbVCd1NbpPb+k7nTd80vqSCfhj4j1EXEgIl6LiDu7qGEpEXEwIl6KiOe7nmKsmQbteES8vGDd6oh4IiJebX4vOk1aR7XdExGHm3P3fETc1FFtayPiPyJif0S8EhH/0Kzv9Ny11NXJeZv4ZX9EnAn8FrgBOAQ8A2zMzP0TLWQJEXEQmM3MzseEI+I64CSwKzMvb9b9C/BGZt7f/MW5KjP/cUpquwc42fXMzc2EMjMLZ5YGbgb+ng7PXUtdt9LBeeui578aeC0zf5eZfwZ+DmzooI6pl5lPA298YvUGYGezvJP5PzwTt0RtUyEzj2bmc83yO8CpmaU7PXctdXWii/BfAPxhwetDTNeU3wn8OiKejYjNXReziDULZkZ6HVjTZTGL6Dlz8yR9YmbpqTl3g8x4
PWp+4Pdp12bmVcCNwHeby9uplPP3bNM0XNPXzM2TssjM0n/R5bkbdMbrUesi/IeBtQtef7FZNxUy83Dz+zjwK6Zv9uFjpyZJbX4f77iev5immZsXm1maKTh30zTjdRfhfwa4JCLWRcTngG8Bezqo41Mi4tzmgxgi4lzg60zf7MN7gE3N8ibgsQ5r+Zhpmbl5qZml6fjcTd2M15k58R/gJuY/8f9f4J+6qGGJui4GXmh+Xum6NuAR5i8D32f+s5E7gL8C9gKvAk8Cq6eotn9lfjbnF5kP2kxHtV3L/CX9i8Dzzc9NXZ+7lro6OW8+4ScV5Qd+UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeK+j/KdNy0Sdc/gwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# With a trained net, we can pick some examples and run GeoCert on them \n",
    "# EXAMPLE_NUMBER indexes into the first validation minibatch only.\n",
    "EXAMPLE_NUMBER = 70\n",
    "data, labels = next(iter(os_valset)) # select a minibatch of the validation set\n",
    "example = data[EXAMPLE_NUMBER]\n",
    "print(example.shape)\n",
    "# plt.gray() changes pyplot's default colormap, so it also affects the\n",
    "# imshow calls in later cells.\n",
    "plt.gray()\n",
    "plt.imshow(example.squeeze())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the object that holds the algorithm parameters\n",
    "geocert = geo.IncrementalGeoCert(\n",
    "    network,\n",
    "    verbose=True,\n",
    "    config_fxn='serial',\n",
    "    config_fxn_kwargs={'use_clarkson': True},\n",
    ")\n",
    "\n",
    "# Run the algorithm and time it\n",
    "start = time.time()\n",
    "lp_dist, cw_bound, cw_example, adver_example, boundary = geocert.min_dist(\n",
    "    example.view(1, -1), lp_norm='l_2', compute_upper_bound=True)\n",
    "end = time.time()\n",
    "\n",
    "# Report the distances found and the network outputs on both inputs\n",
    "print(\"Found an adversarial example at dist\", lp_dist, \" in %.02f seconds \" % (end - start))\n",
    "if cw_bound is not None:\n",
    "    print(\"Carlini-Wagner attack found example at distance\", cw_bound)\n",
    "original_logits = network(example)\n",
    "adver_logits = network(torch.Tensor(adver_example).view(1, 28, 28))\n",
    "print(\"Original output was \", original_logits.data.cpu().numpy())\n",
    "print(\"Adversarial output was \", adver_logits.data.cpu().numpy())\n",
    "\n",
    "# Show the original image next to each adversarial example\n",
    "to_displays = [\n",
    "    (example.cpu().numpy().reshape((28, 28)), 'Original'),\n",
    "    (adver_example.reshape((28, 28)), 'GeoCert'),\n",
    "]\n",
    "if cw_example is not None:\n",
    "    to_displays.append((cw_example.reshape((28, 28)), 'Carlini-Wagner L2'))\n",
    "f, axarr = plt.subplots(1, len(to_displays), figsize=(12, 12))\n",
    "for i, (image, title) in enumerate(to_displays):\n",
    "    axarr[i].imshow(image)\n",
    "    axarr[i].set_title(title)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-create the GeoCert object (this rebinds `geocert`) and build the\n",
    "# Polytope for `example`, so facet generation can be inspected by hand below.\n",
    "geocert = geo.IncrementalGeoCert(network, verbose=True, config_fxn='serial',\n",
    "                                 config_fxn_kwargs={'use_clarkson': True})\n",
    "initial_poly = Polytope.from_polytope_dict(network.compute_polytope(example, True))\n",
    "# seen_dict is the same object as geocert.seen_to_polytope_map, not a copy.\n",
    "seen_dict = geocert.seen_to_polytope_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#geocert._update_step(initial_poly, None)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NOT USING CLARKSON\n",
      "SHAPE (76,)\n",
      "SHAPE (76,)\n",
      "SHAPE (76,)\n",
      "SHAPE (76,)\n",
      "SHAPE (76,)\n",
      "SHAPE (76,)\n",
      "SHAPE (74,)\n",
      "SHAPE (74,)\n",
      "SHAPE (74,)\n",
      "SHAPE (74,)\n",
      "SHAPE (73,)\n",
      "SHAPE (73,)\n",
      "SHAPE (73,)\n",
      "SHAPE (71,)\n",
      "SHAPE (70,)\n",
      "SHAPE (68,)\n",
      "SHAPE (68,)\n",
      "SHAPE (67,)\n",
      "SHAPE (67,)\n",
      "SHAPE (67,)\n",
      "SHAPE (66,)\n",
      "SHAPE (66,)\n",
      "SHAPE (65,)\n",
      "SHAPE (64,)\n",
      "SHAPE (59,)\n",
      "SHAPE (56,)\n",
      "SHAPE (56,)\n",
      "SHAPE (56,)\n",
      "SHAPE (56,)\n",
      "SHAPE (56,)\n",
      "SHAPE (56,)\n",
      "SHAPE (56,)\n",
      "SHAPE (55,)\n",
      "SHAPE (55,)\n",
      "SHAPE (52,)\n",
      "SHAPE (51,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "SHAPE (50,)\n",
      "CONFIG [9] is weird\n",
      "CONFIG [15] is weird\n",
      "CONFIG [18] is weird\n",
      "CONFIG [20] is weird\n",
      "CONFIG [23] is weird\n",
      "CONFIG [31] is weird\n",
      "CONFIG [33] is weird\n",
      "CONFIG [35] is weird\n",
      "CONFIG [41] is weird\n",
      "CONFIG [45] is weird\n",
      "CONFIG [47] is weird\n",
      "CONFIG [48] is weird\n",
      "CONFIG [49] is weird\n",
      "CONFIG [50] is weird\n",
      "CONFIG [51] is weird\n",
      "CONFIG [53] is weird\n",
      "CONFIG [54] is weird\n",
      "CONFIG [65] is weird\n",
      "CONFIG [67] is weird\n",
      "{'infeasible': 26}\n"
     ]
    }
   ],
   "source": [
    "# Generate the facets of the initial polytope (plus the rejected candidates).\n",
    "# NOTE(review): use_clarkson=False here, while the geocert object above was\n",
    "# configured with use_clarkson=True - confirm which setting is intended.\n",
    "facets, rejects = initial_poly.generate_facets_configs(seen_dict, network, use_clarkson=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True True\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'status': 'optimal',\n",
       " 'x': <784x1 matrix, tc='d'>,\n",
       " 's': <75x1 matrix, tc='d'>,\n",
       " 'y': <1x1 matrix, tc='d'>,\n",
       " 'z': <75x1 matrix, tc='d'>,\n",
       " 'primal objective': 0.0,\n",
       " 'dual objective': -0.0,\n",
       " 'gap': 0.0,\n",
       " 'relative gap': None,\n",
       " 'primal infeasibility': 2.7137676122440157e-13,\n",
       " 'dual infeasibility': 0.0,\n",
       " 'primal slack': -3.04858385829813e-13,\n",
       " 'dual slack': -0.0,\n",
       " 'residual as primal infeasibility certificate': None,\n",
       " 'residual as dual infeasibility certificate': None}"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Debug one facet by hand: pick the first facet whose first tight constraint\n",
    "# has index 9 (raises IndexError if there is none).\n",
    "weird_facet = [facet for facet in facets if facet.tight_list[0] == 9][0]\n",
    "print(weird_facet.is_feasible, weird_facet.is_facet)\n",
    "\n",
    "# NOTE(review): a bare expression in the middle of a cell displays nothing.\n",
    "weird_facet.config\n",
    "\n",
    "# Residual poly_a @ interior - poly_b at the facet's interior point; idxs are\n",
    "# the constraints whose residual is >= 0.\n",
    "# NOTE(review): int_point and idxs are not used again in this cell.\n",
    "int_point = (np.matmul(weird_facet.poly_a, weird_facet.interior).squeeze() - weird_facet.poly_b)\n",
    "idxs = [i for i, el in enumerate(int_point ) if el >= 0]\n",
    "\n",
    "\n",
    "# NOTE(review): this import belongs in the imports cell at the top.\n",
    "from cvxopt import matrix, solvers\n",
    "#solvers.options['show_progress'] = False\n",
    "# Zero objective: the LP below is a pure feasibility problem.\n",
    "c = np.zeros(weird_facet.poly_a.shape[1])\n",
    "\n",
    "# Box constraints 0 <= x <= 1 on all 784 coordinates, written as A x <= b.\n",
    "A_upper = np.eye(784)\n",
    "b_upper = np.ones(784)\n",
    "A_lower = - np.eye(784)\n",
    "b_lower = np.zeros(784)\n",
    "\n",
    "# NOTE(review): full_a / full_b are never passed to the solver below. Also,\n",
    "# np.vstack stacks 1-D inputs as rows, so full_b only works if poly_b has a\n",
    "# compatible shape - TODO confirm whether np.concatenate was intended.\n",
    "full_a = np.vstack((weird_facet.poly_a, A_upper, A_lower))\n",
    "full_b = np.vstack((weird_facet.poly_b, b_upper, b_lower))\n",
    "def test_fxn():\n",
    "    \"\"\"Solve the zero-objective LP over the facet's inequality (poly_a, poly_b)\n",
    "    and equality (a_eq, b_eq) constraints with GLPK; return cvxopt's result.\"\"\"\n",
    "    cvxopt_out = solvers.lp(matrix(c), matrix(weird_facet.poly_a),\n",
    "                        matrix(weird_facet.poly_b),\n",
    "                        A=matrix(weird_facet.a_eq), b=matrix(weird_facet.b_eq),\n",
    "                        solver='glpk')\n",
    "    return cvxopt_out\n",
    "    \n",
    "    \n",
    "def test_fxn2(): \n",
    "    \"\"\"Same LP solve as test_fxn, but the result is discarded (returns None).\n",
    "\n",
    "    NOTE(review): identical to test_fxn apart from the missing return; not\n",
    "    called anywhere in this cell.\n",
    "    \"\"\"\n",
    "    cvxopt_out = solvers.lp(matrix(c), matrix(weird_facet.poly_a),\n",
    "                        matrix(weird_facet.poly_b),\n",
    "                        A=matrix(weird_facet.a_eq), b=matrix(weird_facet.b_eq),\n",
    "                        solver='glpk')    \n",
    "test_fxn()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17.5 ms ± 1.18 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "# Time the LP solve: 5 runs of 100 loops each.\n",
    "%timeit -n 100 -r 5 test_fxn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "facet = facets[7]\n",
    "facet.is_feasible = None  # cleared before check_facet() is called again\n",
    "\n",
    "facet.check_facet()\n",
    "print(\"TIGHT CONSTRAINT\", facet.tight_list)\n",
    "\n",
    "# Indices of the constraints with poly_a @ interior - poly_b >= 0\n",
    "weird_constraints = [\n",
    "    idx\n",
    "    for idx, flag in enumerate((np.matmul(facet.poly_a, facet.interior).squeeze() - facet.poly_b) >= 0)\n",
    "    if flag\n",
    "]\n",
    "print(weird_constraints)\n",
    "# NOTE(review): row 3 of poly_a is printed next to poly_b[0], and row 6 next\n",
    "# to poly_b[3] - confirm this pairing is intended.\n",
    "print(facet.poly_a[3][:10], facet.poly_b[0])\n",
    "print(facet.poly_a[6][:10], facet.poly_b[3])\n",
    "\n",
    "(np.matmul(facet.poly_a, facet.interior).squeeze() - facet.poly_b) >= 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step a small distance off facets[0]'s interior point in both directions\n",
    "# along a_eq and compare the ReLU configurations found on the two sides.\n",
    "facet = facets[0]\n",
    "lo = torch.Tensor(facet.interior.T + facet.a_eq * 1e-4)\n",
    "hi = torch.Tensor(facet.interior.T - facet.a_eq * 1e-4)\n",
    "lo_pr, lo_conf = network.relu_config(lo)\n",
    "hi_pr, hi_conf = network.relu_config(hi)\n",
    "\n",
    "# Elementwise product of the two sides' values, one entry per item of lo_pr\n",
    "multi_pr = [lo_el * hi_el for lo_el, hi_el in zip(lo_pr, hi_pr)]\n",
    "\n",
    "print([prod < 0 for prod in multi_pr])\n",
    "utils.flatten_config(facet.config) in [utils.flatten_config(hi_conf), utils.flatten_config(lo_conf)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def index_to_config_coord(config, index):\n",
    "    \"\"\" Given an index of the flattened array, returns the 2d index of where\n",
    "        this corresponds to the configs\n",
    "\n",
    "    Args:\n",
    "        config: sequence of objects exposing .numel() (e.g. tensors).\n",
    "        index: non-negative position in the concatenation of all flattened\n",
    "            entries of `config`.\n",
    "\n",
    "    Returns:\n",
    "        (i, offset): `offset` is the position inside the flattened config[i].\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if index is negative or >= the total element count.\n",
    "    \"\"\"\n",
    "    config_shapes = [_.numel() for _ in config]\n",
    "    # The lower bound matters: a negative index used to pass the check and\n",
    "    # silently come back as (0, index).\n",
    "    assert 0 <= index < sum(config_shapes)\n",
    "\n",
    "    for i, config_len in enumerate(config_shapes):\n",
    "        if index < config_len:\n",
    "            return (i, index)\n",
    "        index -= config_len\n",
    "\n",
    "        \n",
    "config = initial_poly.config\n",
    "# Walk exactly the valid flat indices: a fixed range(100) trips the\n",
    "# function's assertion as soon as it exceeds the number of config entries.\n",
    "num_coords = sum(_.numel() for _ in config)\n",
    "for i in range(num_coords):\n",
    "    print(i, index_to_config_coord(config, i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: number of positions at which two equal-length strings differ\n",
    "a = 'alpha'\n",
    "b = 'alzxa'\n",
    "sum(left != right for left, right in zip(a, b))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: lengths of three hand-pasted all-zero lists.\n",
    "layers = ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
    "            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
    "            0, 0] ,[0, 0, 0, 0, 0, 0, 0, 0, 0, 0])\n",
    "[len(el) for el in layers]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: count the entries of a hand-pasted 0/1 vector.\n",
    "len([0., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1.])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this notebook defines index_to_config_coord in an earlier\n",
    "# cell as well; this later definition shadows it - keep only one.\n",
    "def index_to_config_coord(config, index):\n",
    "    \"\"\" Given an index of the flattened array, returns the 2d index of where\n",
    "        this corresponds to the configs\n",
    "\n",
    "    Args:\n",
    "        config: sequence of objects exposing .numel() (e.g. tensors).\n",
    "        index: non-negative position in the concatenation of all flattened\n",
    "            entries of `config`.\n",
    "\n",
    "    Returns:\n",
    "        (i, offset): `offset` is the position inside the flattened config[i].\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if index is negative or >= the total element count.\n",
    "    \"\"\"\n",
    "    config_shapes = [_.numel() for _ in config]\n",
    "    # The lower bound matters: a negative index used to pass the check and\n",
    "    # silently come back as (0, index).\n",
    "    assert 0 <= index < sum(config_shapes)\n",
    "\n",
    "    for i, config_len in enumerate(config_shapes):\n",
    "        if index < config_len:\n",
    "            return (i, index)\n",
    "        index -= config_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flip entry [0][3] of the initial polytope's config in place (x -> 1 - x).\n",
    "# NOTE(review): not idempotent - every re-run of this cell toggles it back.\n",
    "initial_poly.config[0][3] = int(1 - initial_poly.config[0][3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the current value of config entry [0][3].\n",
    "initial_poly.config[0][3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Write both digit subsets to disk as pickle files\n",
    "for cache_path, subset in (('os_trainset.pkl', os_trainset),\n",
    "                           ('os_valset.pkl', os_valset)):\n",
    "    with open(cache_path, 'wb') as handle:\n",
    "        pickle.dump(subset, handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3",
   "language": "python",
   "name": "py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
