{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jordanm/.virtualenvs/myvenv/lib/python3.7/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "In this example, we build a binary MNIST classifier and then run GeoCert on several test points.\n",
    "'''\n",
    "\n",
    "# =====================\n",
    "# Imports\n",
    "# =====================\n",
    "%load_ext line_profiler\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "sys.path.append('../mister_ed') # library for adversarial examples\n",
    "\n",
    "import geocert_oop as geo\n",
    "from plnn import PLNN\n",
    "import _polytope_ as _poly_\n",
    "from _polytope_ import Polytope, Face\n",
    "import utilities as utils\n",
    "import os\n",
    "import time \n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "import adversarial_perturbations as ap \n",
    "import prebuilt_loss_functions as plf\n",
    "import loss_functions as lf \n",
    "import adversarial_attacks as aa\n",
    "import utils.pytorch_utils as me_utils\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import mnist.mnist_loader as  ml \n",
    "# =====================\n",
    "# Data\n",
    "# =====================\n",
    "# Load the training split first, then the validation split\n",
    "mnist_trainset, mnist_valset = [ml.load_mnist_data(split)\n",
    "                                for split in ('train', 'val')]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code to select only a subset of digits from MNIST datasets\n",
    "def select_digits(dataset, digits, mb_size=128, remap_label=True):\n",
    "    '''\n",
    "    Keep only the examples of `dataset` whose label is in `digits`, then\n",
    "    re-batch the survivors into minibatches of at most `mb_size` examples.\n",
    "\n",
    "    ARGS:\n",
    "        dataset: iterable of (data, labels) minibatches, where `data` is\n",
    "                 indexed along dim 0 in step with the 1-D `labels` tensor\n",
    "        digits: sequence of label values to keep\n",
    "        mb_size: positive int, max number of examples per output minibatch\n",
    "        remap_label: if True (only supported for at most 2 digits), labels\n",
    "                     become 1 where the label equals digits[0] and 0\n",
    "                     otherwise; if False the original labels are kept\n",
    "    RETURNS:\n",
    "        list of (data, labels) tuples, labels being 1-D long tensors;\n",
    "        the list is empty when no example matches\n",
    "    RAISES:\n",
    "        NotImplementedError: remap_label is True and len(digits) > 2\n",
    "        ValueError: mb_size is not positive\n",
    "    '''\n",
    "    if remap_label and len(digits) > 2:\n",
    "        raise NotImplementedError(\"Only handling between 2 types of digits for now\")\n",
    "    if mb_size <= 0:\n",
    "        raise ValueError(\"mb_size must be a positive integer\")\n",
    "\n",
    "    if remap_label:\n",
    "        def label_map(labels):\n",
    "            return labels == digits[0]\n",
    "    else:\n",
    "        def label_map(labels):\n",
    "            return labels\n",
    "\n",
    "    sel_data, sel_labels = [], []\n",
    "    for data, labels in dataset:\n",
    "        # Start from an all-False mask (no sentinel label value needed) and\n",
    "        # OR in one digit at a time so repeated digits cannot overflow it\n",
    "        mask = labels != labels\n",
    "        for digit in digits:\n",
    "            mask = mask | (labels == digit)\n",
    "        # Mask indexing drops unselected rows and keeps the trailing dims,\n",
    "        # so the image shape does not have to be hard-coded\n",
    "        sel_data.append(data[mask])\n",
    "        sel_labels.append(labels[mask])\n",
    "\n",
    "    # torch.cat rejects an empty list, so an empty dataset yields no batches\n",
    "    if not sel_labels:\n",
    "        return []\n",
    "\n",
    "    # Concatenate, then resplit into minibatches of size mb_size\n",
    "    cat_data = torch.cat(sel_data)\n",
    "    cat_labels = torch.cat(sel_labels)\n",
    "\n",
    "    # Stepping by mb_size never emits an empty trailing minibatch, unlike\n",
    "    # rounding len / mb_size + 1 down\n",
    "    full_dataset = []\n",
    "    for start in range(0, len(cat_labels), mb_size):\n",
    "        stop = start + mb_size\n",
    "        # reshape(-1) keeps labels 1-D even for a single-example minibatch\n",
    "        full_dataset.append((cat_data[start:stop],\n",
    "                             label_map(cat_labels[start:stop]).reshape(-1).long()))\n",
    "    return full_dataset\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict both MNIST splits to the digits 1 and 7 (train first, then val)\n",
    "MINIBATCH_SIZE = 128\n",
    "os_trainset, os_valset = [select_digits(mnist_split, [1, 7], mb_size=MINIBATCH_SIZE)\n",
    "                          for mnist_split in (mnist_trainset, mnist_valset)]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3",
   "language": "python",
   "name": "py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
