{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Bi1XwcTQl1Bp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "import torch\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import foolbox as fb\n",
    "\n",
    "# This notebook was tested with Foolbox 3.0.0b. If the check below fails,\n",
    "# install the latest master version from git with\n",
    "#\n",
    "#     pip3 install git+https://github.com/bethgelab/foolbox.git\n",
    "#\n",
    "# The message tells the user which version was found instead of raising a\n",
    "# bare AssertionError.\n",
    "assert int(fb.__version__.split('.')[0]) >= 3, (\n",
    "    f'Foolbox >= 3 is required, found {fb.__version__}'\n",
    ")\n",
    "\n",
    "import resnet\n",
    "\n",
    "import logging\n",
    "\n",
    "# Debug output of the attack goes to kwta_debug.log through this logger.\n",
    "logger = logging.getLogger('kwtalogger')\n",
    "logger.setLevel(logging.DEBUG)\n",
    "\n",
    "# logging.getLogger returns the same logger object on every call, so\n",
    "# re-running this cell must not attach a second FileHandler -- that would\n",
    "# write every debug message to the log file once per cell execution.\n",
    "if not any(isinstance(h, logging.FileHandler) for h in logger.handlers):\n",
    "    fh = logging.FileHandler('kwta_debug.log')\n",
    "    fh.setLevel(logging.DEBUG)\n",
    "    formatter = logging.Formatter('%(asctime)s - %(message)s')\n",
    "    fh.setFormatter(formatter)\n",
    "    logger.addHandler(fh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load pretrained weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = 'kwta_spresnet18_0.1_cifar_adv.pth'\n",
    "url = f'https://github.com/BLINDED_FOR_REVIEW/robustness_workshop/releases/download/v0.0.1/{filename}'\n",
    "\n",
    "if not os.path.isfile(filename):\n",
    "    print('Downloading pretrained weights.')\n",
    "    # Fetch into a temporary name and rename only after the transfer has\n",
    "    # finished. An interrupted download would otherwise leave a truncated\n",
    "    # checkpoint behind that the isfile() check above treats as complete\n",
    "    # on the next run.\n",
    "    partial_path = filename + '.part'\n",
    "    try:\n",
    "        urllib.request.urlretrieve(url, partial_path)\n",
    "        os.replace(partial_path, filename)\n",
    "    finally:\n",
    "        # Only present if the download or the rename failed.\n",
    "        if os.path.exists(partial_path):\n",
    "            os.remove(partial_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "i9Ck9uRil1B2"
   },
   "source": [
    "### load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ozOMJhwql1B8"
   },
   "outputs": [],
   "source": [
    "# Per-channel normalisation constants handed to transforms.Normalize below;\n",
    "# the same value is used for all three colour channels.\n",
    "norm_mean = 0\n",
    "norm_var = 1\n",
    "\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=(norm_mean,) * 3, std=(norm_var,) * 3),\n",
    "])\n",
    "\n",
    "# CIFAR-10 test split, downloaded into ./data on first use.\n",
    "cifar_test = datasets.CIFAR10(\n",
    "    root=\"./data\",\n",
    "    train=False,\n",
    "    transform=transform_test,\n",
    "    download=True,\n",
    ")\n",
    "test_loader = DataLoader(dataset=cifar_test, batch_size=200, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BwxyNL0bl1CG"
   },
   "source": [
    "### load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LxtxTQb7l1CI"
   },
   "outputs": [],
   "source": [
    "# gamma is passed as the sparsity of each of the four entries below;\n",
    "# epsilon is the bound the attack clips its perturbation to.\n",
    "gamma = 0.1\n",
    "epsilon = 0.031\n",
    "filepath = f'kwta_spresnet18_{gamma}_cifar_adv.pth'\n",
    "# Prefer the first GPU, but fall back to the CPU so the notebook still runs\n",
    "# (slowly) on a machine without CUDA instead of failing here.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "model = resnet.SparseResNet18(sparsities=[gamma, gamma, gamma, gamma], sparse_func='vol').to(device)\n",
    "# map_location lets the checkpoint load regardless of the device it was\n",
    "# saved from.\n",
    "# NOTE(review): torch.load unpickles the file, which can execute arbitrary\n",
    "# code -- only load checkpoints obtained from a trusted source.\n",
    "model.load_state_dict(torch.load(filepath, map_location=device))\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "O4Ho8WgsotFR"
   },
   "source": [
    "### clean accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2_gt-jT-l1CQ",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# acc counts correctly classified images; total_number counts all images seen.\n",
    "acc = 0\n",
    "total_number = 0\n",
    "\n",
    "# Inference only: without no_grad() every forward pass records an autograd\n",
    "# graph for the model parameters, which costs memory and is never used.\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "        logits = model(images.to(device))\n",
    "        predictions = logits.detach().cpu().numpy().argmax(1)\n",
    "        acc += np.sum(predictions == labels.cpu().numpy())\n",
    "        total_number += images.shape[0]\n",
    "\n",
    "# the clean accuracy is much lower than what is reported in the paper\n",
    "# but the authors claimed that this checkpoint is more robust.\n",
    "print(f'Clean accuracy is {acc / total_number:.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### final attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the network to Foolbox; inputs are declared to lie in [0, 1].\n",
    "fmodel = fb.models.PyTorchModel(\n",
    "    model,\n",
    "    bounds=(0, 1),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "import eagerpy as ep\n",
    "\n",
    "def best_other_classes(logits: ep.Tensor, exclude: ep.Tensor) -> ep.Tensor:\n",
    "    \"\"\"Return, per row, the index of the largest logit other than `exclude`.\"\"\"\n",
    "    # Subtracting +inf at the excluded position pushes it to -inf, so the\n",
    "    # argmax below can never select it.\n",
    "    exclusion_penalty = ep.onehot_like(logits, exclude, value=ep.inf)\n",
    "    remaining_logits = logits - exclusion_penalty\n",
    "    return remaining_logits.argmax(axis=-1)\n",
    "\n",
    "def loss_fn(x, classes):\n",
    "    \"\"\"Logit margin of `classes` over the strongest competing class.\n",
    "\n",
    "    Evaluates `fmodel` on `x` and returns, for every input, the logit of\n",
    "    the given class minus the largest logit among all other classes.\n",
    "    \"\"\"\n",
    "    logits = fmodel(x)\n",
    "    runner_up = best_other_classes(logits, classes)\n",
    "\n",
    "    batch_size = len(x)\n",
    "    batch_index = range(batch_size)\n",
    "\n",
    "    margins = logits[batch_index, classes] - logits[batch_index, runner_up]\n",
    "    # one scalar per input\n",
    "    assert margins.shape == (batch_size,)\n",
    "    return margins\n",
    "\n",
    "def es_gradient_estimator(x, y, samples, sigma, clip=False):\n",
    "    \"\"\"Estimate the gradient of `loss_fn` w.r.t. `x` from random probes.\n",
    "\n",
    "    Draws `samples // 2` Gaussian noise tensors. Each one is evaluated at\n",
    "    `x + sigma * noise` and `x - sigma * noise`, and the difference of the\n",
    "    two losses weights that noise direction in the running sum.\n",
    "\n",
    "    Args:\n",
    "        x: batch of inputs; indexed below as one batch axis plus three more.\n",
    "        y: labels forwarded to `loss_fn`.\n",
    "        samples: sampling budget; `samples // 2` noise pairs are drawn.\n",
    "        sigma: scale of the probing noise.\n",
    "        clip: if True, clip the probed points to the bounds of `fmodel`.\n",
    "\n",
    "    Returns:\n",
    "        Tensor shaped like `x` holding the gradient estimate.\n",
    "    \"\"\"\n",
    "    gradient = ep.zeros_like(x)\n",
    "    for _ in range(samples // 2):\n",
    "        noise = ep.normal(x, shape=x.shape)\n",
    "\n",
    "        pos_theta = x + sigma * noise\n",
    "        neg_theta = x - sigma * noise\n",
    "\n",
    "        if clip:\n",
    "            # Take the bounds from the wrapped model; no standalone\n",
    "            # `bounds` variable exists in this notebook.\n",
    "            pos_theta = pos_theta.clip(*fmodel.bounds)\n",
    "            neg_theta = neg_theta.clip(*fmodel.bounds)\n",
    "\n",
    "        pos_loss = loss_fn(pos_theta, y)\n",
    "        neg_loss = loss_fn(neg_theta, y)\n",
    "\n",
    "        # broadcast the per-sample loss difference over the remaining axes\n",
    "        gradient += (pos_loss - neg_loss)[:, None, None, None] * noise\n",
    "\n",
    "    # NOTE(review): this scale factor is kept as in the original code; the\n",
    "    # PGD loop in this notebook only uses the sign of the estimate, so the\n",
    "    # exact normalisation does not affect the attack -- confirm before\n",
    "    # relying on the magnitude elsewhere.\n",
    "    gradient /= 2 * sigma * 2 * samples\n",
    "\n",
    "    return gradient\n",
    "\n",
    "def gradient_estimator_pgd(images, labels, steps=100, sigma=8/255, lr=0.01):\n",
    "    \"\"\"Signed-gradient attack driven by `es_gradient_estimator`.\n",
    "\n",
    "    Only the inputs whose `loss_fn` value is still >= 0 are updated in each\n",
    "    step. The perturbation is clipped to [-epsilon, epsilon] and the\n",
    "    perturbed images to [0, 1].\n",
    "\n",
    "    Args:\n",
    "        images: torch batch of inputs (moved to `device` here).\n",
    "        labels: torch batch of labels (moved to `device` here).\n",
    "        steps: maximum number of update steps.\n",
    "        sigma: noise scale passed to the gradient estimator.\n",
    "        lr: step size applied to the sign of the estimated gradient.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(adversarials, mask)`: the final perturbed images and a\n",
    "        boolean tensor that is True where `loss_fn` is still >= 0 for them.\n",
    "    \"\"\"\n",
    "    ep_images = ep.astensor(images.to(device))\n",
    "    ep_labels = ep.astensor(labels.to(device))\n",
    "\n",
    "    deltas = ep.zeros_like(ep_images)\n",
    "\n",
    "    # Always holds the current perturbed images, so the caller receives the\n",
    "    # real attack result rather than a tensor of zeros.\n",
    "    adversarials = (ep_images + deltas).clip(0, 1)\n",
    "    mask = loss_fn(ep_images, ep_labels) >= 0\n",
    "\n",
    "    for it in range(steps):\n",
    "        # Nothing left to attack. Checked up front so a batch with no\n",
    "        # remaining candidates never reaches the estimator as an empty batch.\n",
    "        if mask.sum() == 0:\n",
    "            break\n",
    "\n",
    "        # sampling budget grows with the iteration count\n",
    "        if it < 20:\n",
    "            samples = 100\n",
    "        elif it < 40:\n",
    "            samples = 1000\n",
    "        else:\n",
    "            samples = 20000\n",
    "\n",
    "        grads = es_gradient_estimator(adversarials[mask], ep_labels[mask], samples, sigma)\n",
    "\n",
    "        # update only subportion of deltas\n",
    "        _deltas = np.array(deltas.numpy())\n",
    "        _deltas[mask.numpy()] = (deltas[mask] - lr * grads.sign()).numpy()\n",
    "        deltas = ep.from_numpy(deltas, _deltas)\n",
    "\n",
    "        deltas = deltas.clip(-epsilon, epsilon)\n",
    "        adversarials = (ep_images + deltas).clip(0, 1)\n",
    "\n",
    "        new_logit_diffs = loss_fn(adversarials, ep_labels)\n",
    "        mask = new_logit_diffs >= 0\n",
    "\n",
    "        values = new_logit_diffs.numpy()\n",
    "        message = f'({it} / {mask.sum()}) {float(new_logit_diffs.mean().raw):.3f}: {np.array2string(values[mask.numpy()], precision=2, separator=\",\")}'\n",
    "        logger.debug(message)\n",
    "\n",
    "    return adversarials, mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the attack batch by batch and report the running share of images\n",
    "# flagged by the second return value of gradient_estimator_pgd.\n",
    "acc = 0\n",
    "total_images = 0\n",
    "\n",
    "for k, batch in enumerate(test_loader):\n",
    "    images, labels = batch\n",
    "    perturbations, correct = gradient_estimator_pgd(images, labels)\n",
    "\n",
    "    n_still_correct = float(correct.sum().numpy())\n",
    "    acc = acc + n_still_correct\n",
    "    total_images = total_images + len(images)\n",
    "    print(f'({k}) model accuracy on perturbed images is {100 * acc / total_images:.1f}')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "include_colab_link": true,
   "name": "kwta_attack.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
