{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "GTiLRzJ_rYA-"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "import torch\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import foolbox as fb\n",
    "import eagerpy as ep\n",
    "\n",
    "import transforms\n",
    "import resnet_3layer as resnet\n",
    "\n",
    "# This code is tested with Foolbox 3.0.0b; you might\n",
    "# have to install the latest master version from git with\n",
    "#\n",
    "# pip3 install git+https://github.com/bethgelab/foolbox.git\n",
    "#\n",
    "# Fail early, with an actionable message, if an older Foolbox major\n",
    "# version is installed.\n",
    "assert int(fb.__version__.split('.')[0]) >= 3, (\n",
    "    f'Foolbox >= 3 is required, found {fb.__version__}; install it with '\n",
    "    'pip3 install git+https://github.com/bethgelab/foolbox.git')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5TDsX2H2rYBN"
   },
   "outputs": [],
   "source": [
    "# Number of random mixup samples whose softmax outputs are averaged\n",
    "# per forward pass of CombinedModel.\n",
    "num_sample_MIOL = 15\n",
    "# Mixing weight given to the input image; the sampled pool image is\n",
    "# weighted by (1 - lamdaOL).\n",
    "lamdaOL = 0.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = 'mixup_model_IAT.ckpt'\n",
    "url = f'https://github.com/BLINDED_FOR_REVIEW/robustness_workshop/releases/download/v0.0.1/{filename}'\n",
    "\n",
    "# Fetch the pretrained checkpoint once; later runs reuse the local copy.\n",
    "if not os.path.isfile(filename):\n",
    "    print('Downloading pretrained weights.')\n",
    "    # Download under a temporary name and rename only on success, so an\n",
    "    # interrupted transfer never leaves a truncated checkpoint that the\n",
    "    # isfile() check above would accept on the next run.\n",
    "    partial_filename = filename + '.part'\n",
    "    try:\n",
    "        urllib.request.urlretrieve(url, partial_filename)\n",
    "        os.replace(partial_filename, filename)\n",
    "    finally:\n",
    "        if os.path.exists(partial_filename):\n",
    "            os.remove(partial_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GR5LcewzrYBc"
   },
   "source": [
    "### Load backbone model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Do8LodSorYBf"
   },
   "outputs": [],
   "source": [
    "# Build the ResNet-50 backbone with a 10-way output layer.\n",
    "CLASSIFIER = resnet.model_dict['resnet50']\n",
    "classifier = CLASSIFIER(num_classes=10)\n",
    "\n",
    "# Prefer the first GPU, but fall back to the CPU so the notebook still\n",
    "# runs (slowly) on machines without CUDA.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "classifier = classifier.to(device)\n",
    "\n",
    "# map_location lets a checkpoint saved on a GPU be restored on whichever\n",
    "# device was selected above.\n",
    "# NOTE: torch.load unpickles the file; only load checkpoints you trust.\n",
    "classifier.load_state_dict(torch.load('mixup_model_IAT.ckpt', map_location=device))\n",
    "classifier.eval();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "onppDLrirYBo"
   },
   "source": [
    "### Construct image pools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Cd7GM1_BrYBs"
   },
   "outputs": [],
   "source": [
    "def onehot(ind):\n",
    "    \"\"\"Return a length-10 float32 vector that is 1 at `ind` and 0 elsewhere.\"\"\"\n",
    "    encoded = np.zeros(10, dtype=np.float32)\n",
    "    encoded[ind] = 1\n",
    "    return encoded\n",
    "\n",
    "train_trans, test_trans = transforms.cifar_transform()\n",
    "\n",
    "# NOTE(review): despite its name, `trainset` is built from the CIFAR-10\n",
    "# test split (train=False); it differs from `testset` only in its transform.\n",
    "# Confirm this is intended before changing it.\n",
    "trainset = torchvision.datasets.CIFAR10(root='~/cifar/',\n",
    "                                        train=False,\n",
    "                                        download=True,\n",
    "                                        transform=train_trans,\n",
    "                                        target_transform=onehot)\n",
    "testset = torchvision.datasets.CIFAR10(root='~/cifar/',\n",
    "                                       train=False,\n",
    "                                       download=True,\n",
    "                                       transform=test_trans,\n",
    "                                       target_transform=onehot)\n",
    "\n",
    "# Optionally evaluate on only the first N test images to shorten the run.\n",
    "# None (the default) keeps the full test set.\n",
    "num_test_samples = None\n",
    "if num_test_samples is not None:\n",
    "    testset = torch.utils.data.Subset(\n",
    "        testset, range(min(num_test_samples, len(testset))))\n",
    "\n",
    "# Batches of one image: each batch becomes a single mixup-pool entry.\n",
    "dataloader_train = torch.utils.data.DataLoader(\n",
    "    trainset,\n",
    "    batch_size=1,\n",
    "    shuffle=True,\n",
    "    num_workers=2)\n",
    "\n",
    "# Evaluation batches, iterated in dataset order.\n",
    "dataloader_test = torch.utils.data.DataLoader(\n",
    "    testset,\n",
    "    batch_size=10,\n",
    "    shuffle=False,\n",
    "    num_workers=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3sBHEl2DrYB1",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "# Upper bound on the number of single-image batches put into the pool.\n",
    "num_pool = 10000\n",
    "\n",
    "# One (initially empty) list of pool images per class index 0..9.\n",
    "mixup_pool_OL = {class_idx: [] for class_idx in range(10)}\n",
    "\n",
    "progress = tqdm(enumerate(dataloader_train), total=num_pool)\n",
    "for i, (img_batch, label_batch) in progress:\n",
    "    # Labels are one-hot, so the position of the maximum is the class index.\n",
    "    label_ind = torch.max(label_batch.data, 1)[1]\n",
    "    mixup_pool_OL[label_ind.numpy()[0]].append(img_batch.to(device))\n",
    "    # Stop once num_pool batches have been collected.\n",
    "    if i + 1 >= num_pool:\n",
    "        break\n",
    "\n",
    "print('Finish constructing mixup_pool_OL')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "W79o71oSrYB8"
   },
   "source": [
    "### Construct a surrogate model that wraps MI-OL inside the classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     5
    ],
    "colab": {},
    "colab_type": "code",
    "id": "ySaBEdFmrYB_"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Softmax over the last (class) dimension, shared by all forward passes.\n",
    "soft_max = nn.Softmax(dim=-1)\n",
    "\n",
    "class CombinedModel(nn.Module):\n",
    "    \"\"\"Wrap a classifier so that its forward pass performs MI-OL inference.\n",
    "\n",
    "    Every input is mixed with `num_sample_MIOL` randomly drawn pool images\n",
    "    whose class differs from the classifier's own prediction for that input,\n",
    "    and the softmax outputs of the mixed inputs are averaged.\n",
    "\n",
    "    Relies on the notebook-level globals `num_sample_MIOL`, `lamdaOL`,\n",
    "    `mixup_pool_OL` and `soft_max`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, classifier):\n",
    "        \"\"\"Store the wrapped classifier as a sub-module.\"\"\"\n",
    "        super(CombinedModel, self).__init__()\n",
    "        self.classifier = classifier\n",
    "\n",
    "    def forward(self, img_batch):\n",
    "        \"\"\"Return the averaged softmax output for `img_batch`.\n",
    "\n",
    "        The result has the same shape as the wrapped classifier's output and\n",
    "        is differentiable with respect to `img_batch`.\n",
    "        \"\"\"\n",
    "        # forward pass without PL/OL; only used to decide which class the\n",
    "        # sampled pool images must NOT belong to, so no gradient is needed\n",
    "        pred_cle = self.classifier(img_batch)\n",
    "        _, predicted_cle = torch.max(soft_max(pred_cle.detach()), 1)\n",
    "        predicted_cle = predicted_cle.cpu().numpy()\n",
    "\n",
    "        pred_cle_mixup_all_OL = 0\n",
    "\n",
    "        # perform MI-OL\n",
    "        for k in range(num_sample_MIOL):\n",
    "            pool_imgs = []\n",
    "\n",
    "            for b in range(img_batch.shape[0]):\n",
    "                # rejection-sample a class other than the predicted one\n",
    "                xs_cle_label = np.random.randint(10)\n",
    "                while xs_cle_label == predicted_cle[b]:\n",
    "                    xs_cle_label = np.random.randint(10)\n",
    "                xs_cle_index = np.random.randint(len(mixup_pool_OL[xs_cle_label]))\n",
    "                # pool entries are single-image batches; [0] drops that batch axis\n",
    "                pool_imgs.append(mixup_pool_OL[xs_cle_label][xs_cle_index][0])\n",
    "\n",
    "            # Pool images are constants: stack them directly on the input's\n",
    "            # device instead of copying each one through numpy.\n",
    "            pool_batch = torch.stack(pool_imgs).detach().to(\n",
    "                device=img_batch.device, dtype=img_batch.dtype)\n",
    "            mixup_img_batch = (1 - lamdaOL) * pool_batch + lamdaOL * img_batch\n",
    "\n",
    "            # Use the wrapped module, not the notebook-global `classifier`.\n",
    "            pred_cle_mixup = self.classifier(mixup_img_batch)\n",
    "            pred_cle_mixup_all_OL = pred_cle_mixup_all_OL + soft_max(pred_cle_mixup)\n",
    "\n",
    "        pred_cle_mixup_all_OL = pred_cle_mixup_all_OL / num_sample_MIOL\n",
    "\n",
    "        return pred_cle_mixup_all_OL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0m0rRzwHrYCM"
   },
   "outputs": [],
   "source": [
    "# eval() returns the module itself, so construction and the switch to\n",
    "# evaluation mode can be chained into a single statement.\n",
    "combined_classifier = CombinedModel(classifier).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bizzUyAtrYCT"
   },
   "outputs": [],
   "source": [
    "# Both wrappers share the same input range and device.\n",
    "# NOTE(review): bounds=(-1, 1) presumably matches the scaling applied by\n",
    "# transforms.cifar_transform() - confirm.\n",
    "fb_model_kwargs = dict(bounds=(-1, 1), device=device)\n",
    "\n",
    "# The plain classifier and its MI-OL-wrapped counterpart.\n",
    "iAT_model = fb.models.PyTorchModel(classifier, **fb_model_kwargs)\n",
    "iAT_OL_model = fb.models.PyTorchModel(combined_classifier, **fb_model_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Final attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Perturbation budget; it is doubled when passed to the attack below.\n",
    "# NOTE(review): presumably because the model bounds (-1, 1) span twice the\n",
    "# width of [0, 1] - confirm.\n",
    "epsilon = 8 / 255\n",
    "attack = fb.attacks.LinfPGD(steps=50, abs_stepsize=1 / 255)\n",
    "\n",
    "# Running sum of (batch accuracy * batch size) and number of images seen.\n",
    "total_samples = 0\n",
    "acc = 0\n",
    "\n",
    "for images, labels in tqdm(dataloader_test):\n",
    "    images = images.to(device)\n",
    "    # convert the one-hot labels to class indices\n",
    "    labels = labels.argmax(1).to(device)\n",
    "    N = len(images)\n",
    "\n",
    "    # The attack returns (raw adversarials, adversarials clipped to the valid\n",
    "    # epsilon region, boolean success mask); only the clipped images are scored.\n",
    "    _, adv_clipped, _ = attack(iAT_OL_model, images, labels, epsilons=2 * epsilon)\n",
    "\n",
    "    # Weight each batch accuracy by its size so the final ratio is per-image.\n",
    "    acc += fb.utils.accuracy(iAT_OL_model, adv_clipped, labels) * N\n",
    "    total_samples += N\n",
    "\n",
    "print()\n",
    "print(f'Model accuracy on adversarial examples: {acc / total_samples:.3f}')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "include_colab_link": true,
   "name": "mixup_attack.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
