{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../src')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import cv2\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from torchray.utils import get_device\n",
    "from pathlib import Path\n",
    "\n",
    "from data.dataloader import VOCDataModule, COCODataModule\n",
    "from utils.image_utils import get_unnormalized_image\n",
    "from utils.helper import get_targets_from_annotations, get_filename_from_annotations, extract_masks\n",
    "from models.explainer_classifier import ExplainerClassifierModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_cam_on_image(img: np.ndarray,\n",
    "                      mask: np.ndarray,\n",
    "                      use_rgb: bool = False,\n",
    "                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:\n",
    "    \"\"\" This function overlays the cam mask on the image as an heatmap.\n",
    "    By default the heatmap is in BGR format.\n",
    "    :param img: The base image in RGB or BGR format.\n",
    "    :param mask: The cam mask.\n",
    "    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.\n",
    "    :param colormap: The OpenCV colormap to be used.\n",
    "    :returns: The default image with the cam overlay.\n",
    "    \"\"\"\n",
    "    heatmap = cv2.applyColorMap(np.uint8(255 * (1-mask)), colormap)\n",
    "    if use_rgb:\n",
    "        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)\n",
    "    heatmap = np.float32(heatmap) / 255\n",
    "\n",
    "    if np.max(img) > 1:\n",
    "        raise Exception(\n",
    "            \"The input image should np.float32 in the range [0, 1]\")\n",
    "\n",
    "    cam = heatmap + img\n",
    "    cam = cam / np.max(cam)\n",
    "    return np.uint8(255 * cam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_class_mask(explainer, image, class_id):\n",
    "    class_mask = explainer(image)[0][class_id].sigmoid()\n",
    "\n",
    "    return class_mask.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_topk_classes(explainer, image, k=5):\n",
    "    class_masks = explainer(image)[0].sigmoid()\n",
    "    class_mask_means = class_masks.mean(dim=(1,2))\n",
    "\n",
    "    values, topk_classes = class_mask_means.topk(k)\n",
    "    return values.cpu().numpy(), topk_classes.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_class_scores(explainer, classifier, image, class_id):\n",
    "    class_mask = explainer(image)[0][class_id].sigmoid()\n",
    "    masked_image = class_mask.unsqueeze(0).unsqueeze(0) * image\n",
    "\n",
    "    unmasked_logits = classifier(image)[0]\n",
    "    masked_logits = classifier(masked_image)[0]\n",
    "\n",
    "    unmasked_class_prob = unmasked_logits.sigmoid()[class_id]\n",
    "    masked_class_prob = masked_logits.sigmoid()[class_id]\n",
    "\n",
    "    return unmasked_class_prob, masked_class_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_test_dataloader(dataset, data_path):\n",
    "    if dataset == \"VOC\":\n",
    "        data_module = VOCDataModule(data_path, test_batch_size=1)\n",
    "    elif dataset == \"COCO\":\n",
    "        data_module = COCODataModule(data_path, test_batch_size=1)\n",
    "        \n",
    "    data_module.setup(stage=\"test\")\n",
    "    test_dataloader = data_module.test_dataloader()\n",
    "\n",
    "    return test_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"VOC\"\n",
    "num_classes = 20\n",
    "data_path = Path(\"../datasets/VOC2007/\")\n",
    "classifier_type = \"resnet50\"\n",
    "explainer_classifier_path = Path(\"../src/checkpoints/explainer_resnet50_voc.ckpt\")\n",
    "output_dir = Path(f\"./topk_attributions/{classifier_type}\")\n",
    "\n",
    "explainer_classifier = ExplainerClassifierModel.load_from_checkpoint(explainer_classifier_path, \n",
    "                                                                     num_classes=num_classes, \n",
    "                                                                     dataset=dataset, \n",
    "                                                                     classifier_type=classifier_type)\n",
    "                                                                     \n",
    "device = get_device()\n",
    "explainer = explainer_classifier.explainer.to(device)\n",
    "explainer.freeze()\n",
    "classifier = explainer_classifier.classifier.to(device)\n",
    "classifier.freeze()\n",
    "\n",
    "dataloader = get_test_dataloader(dataset, data_path)\n",
    "image_list = list(enumerate(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "random.shuffle(image_list)\n",
    "n_images = 8\n",
    "count = 0\n",
    "for i, item in image_list:\n",
    "    image, annotation = item\n",
    "    image = image.to(device)\n",
    "    filename = get_filename_from_annotations(annotation, dataset=dataset)[:-4]\n",
    "    targets = get_targets_from_annotations(annotation, dataset=dataset)\n",
    "    target_classes = [i for i, val in enumerate(targets[0]) if val == 1.0]\n",
    "    topk_values, topk_classes = get_topk_classes(explainer, image, k=5)\n",
    "\n",
    "    fig = plt.figure(figsize=(25, 5))\n",
    "    original_image = np.transpose(get_unnormalized_image(image).cpu().numpy().squeeze(), (1, 2, 0))\n",
    "    plt.imsave(output_dir / f\"{filename}_original.png\", original_image, format=\"png\")\n",
    "    plt.subplot(1, 7, 1)\n",
    "    plt.imshow(original_image)\n",
    "    plt.axis(\"off\")\n",
    "\n",
    "    segmentations = explainer(image)\n",
    "    aggregated_mask, _ = extract_masks(segmentations, targets)\n",
    "    aggregated_mask = aggregated_mask[0].cpu().numpy()\n",
    "    aggregated_attribution = show_cam_on_image(original_image, aggregated_mask)\n",
    "    plt.imsave(output_dir / f\"{filename}_aggregated.png\", aggregated_attribution, format=\"png\")\n",
    "    plt.subplot(1, 7, 2)\n",
    "    plt.imshow(aggregated_attribution)\n",
    "    plt.axis(\"off\")\n",
    "    for j, class_id in enumerate(topk_classes):\n",
    "        unmasked_class_prob, masked_class_prob = get_class_scores(explainer, classifier, image, class_id)\n",
    "        class_mask = get_class_mask(explainer, image, class_id)\n",
    "        attribution = show_cam_on_image(original_image, class_mask)\n",
    "        plt.imsave(output_dir / f\"{filename}_rank_{j}_class_{class_id+1}.png\", attribution, format=\"png\")\n",
    "        plt.subplot(1, 7, j+3)\n",
    "        plt.imshow(attribution, vmin=0, vmax=1)\n",
    "        class_title = f\"{class_id+1}**\" if class_id in target_classes else f\"{class_id+1}\"\n",
    "        plt.title(f\"{class_title}: CLS={unmasked_class_prob*100:.2f}, MASK={topk_values[j]*100:.2f}\")\n",
    "        plt.axis(\"off\")\n",
    "\n",
    "    count += 1\n",
    "    if count >= n_images:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "7ca055faf3ae29c3c66579db75b878dd7b19a8a2f3a26672c87b0ac03bac584e"
  },
  "kernelspec": {
   "display_name": "Python 3.8.6 ('pytorch_gpu')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
