{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project-local `models` package (under ../src) importable. Guarded so\n",
    "# that re-running this cell does not keep appending duplicate sys.path entries.\n",
    "SRC_DIR = '../src'\n",
    "if SRC_DIR not in sys.path:\n",
    "    sys.path.append(SRC_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torchvision.transforms as T\n",
    "from IPython.display import Latex, clear_output, display\n",
    "from PIL import Image\n",
    "from torchray.utils import get_device\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Project-local (importable because ../src was appended to sys.path above)\n",
    "from models.classifier import Resnet50ClassifierModel, VGG16ClassifierModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: VOC2007 images, VGG16 classifier with 20 outputs.\n",
    "dataset = 'VOC'\n",
    "num_classes = 20\n",
    "classifier_type = 'vgg16'\n",
    "\n",
    "# Input images and classifier weights\n",
    "data_path = Path('../datasets/VOC2007') / 'VOCdevkit' / 'VOC2007' / 'JPEGImages'\n",
    "classifier_checkpoint = Path('../src/checkpoints') / 'pretrained_classifiers' / 'vgg16_voc.ckpt'\n",
    "\n",
    "# Root directory under which each method's precomputed masks are stored\n",
    "masks_base_path = Path('../src/evaluation') / 'masks'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explanation methods whose masks are evaluated\n",
    "methods = [\n",
    "    'explainer',\n",
    "    'grad_cam',\n",
    "    'rise',\n",
    "    'rt_saliency',\n",
    "]\n",
    "\n",
    "# Classes used both as the image class and as the mask class below\n",
    "mask_objects = [\n",
    "    'bottle',\n",
    "    'car',\n",
    "    'cat',\n",
    "    'dog',\n",
    "    'person',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# VOC class name -> classifier output index (index = position in this list)\n",
    "target_dict = {name: index for index, name in enumerate([\n",
    "    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',\n",
    "    'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',\n",
    "    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',\n",
    "])}\n",
    "\n",
    "# Output index -> class name (used for printing and as result keys)\n",
    "inv_target_dict = {index: name for name, index in target_dict.items()}\n",
    "\n",
    "# Output indices of the classes listed in mask_objects\n",
    "mask_classes = [target_dict[name] for name in mask_objects]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected number of mask files (images) per class; used below as the\n",
    "# tqdm progress-bar total.\n",
    "class_count_dict = dict(\n",
    "    aeroplane=205, bicycle=250, bird=289, boat=176, bottle=240, bus=183, car=775,\n",
    "    cat=332, chair=545, cow=127, diningtable=247, dog=433, horse=279, motorbike=233,\n",
    "    person=2097, pottedplant=254, sheep=98, sofa=355, train=259, tvmonitor=255,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the classifier architecture selected by `classifier_type` from its checkpoint.\n",
    "classifier_classes = {\n",
    "    'vgg16': VGG16ClassifierModel,\n",
    "    'resnet50': Resnet50ClassifierModel,\n",
    "}\n",
    "if classifier_type not in classifier_classes:\n",
    "    raise ValueError(\"Unknown classifier type \" + classifier_type)\n",
    "model = classifier_classes[classifier_type].load_from_checkpoint(\n",
    "    classifier_checkpoint, num_classes=num_classes, dataset=dataset)\n",
    "\n",
    "# Inference only: move to the available device, switch to eval mode and\n",
    "# freeze all weights so no gradients are tracked for them.\n",
    "device = get_device()\n",
    "model.to(device)\n",
    "model.eval()\n",
    "for param in model.parameters():\n",
    "    param.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_results(model, image, mask, class_id, thresholds=None):\n",
    "    \"\"\"Mean softmax probability of `class_id` over binarised copies of `mask`.\n",
    "\n",
    "    For every threshold t the soft mask is binarised as (mask > t), multiplied\n",
    "    with the image and classified; the softmax probability of `class_id` for\n",
    "    the first batch element is recorded, and the mean over thresholds returned.\n",
    "\n",
    "    Args:\n",
    "        model: classifier returning logits indexed as [batch, class].\n",
    "        image: normalised input tensor; callers here pass one image (1, 3, H, W).\n",
    "        mask: soft mask tensor broadcastable to `image`, e.g. (1, 1, H, W).\n",
    "        class_id: output index whose probability is averaged.\n",
    "        thresholds: binarisation thresholds; defaults to 0.1, 0.2, ..., 0.9.\n",
    "\n",
    "    Returns:\n",
    "        Mean probability of `class_id` across all thresholds.\n",
    "    \"\"\"\n",
    "    if thresholds is None:\n",
    "        # linspace yields exactly nine thresholds; a float step in np.arange\n",
    "        # can gain or lose an endpoint through rounding.\n",
    "        thresholds = np.linspace(0.1, 0.9, 9)\n",
    "    probs = []\n",
    "    with torch.no_grad():  # inference only: never record an autograd graph\n",
    "        for threshold in thresholds:\n",
    "            binary_mask = (mask > threshold).float()\n",
    "            logits = model(binary_mask * image)\n",
    "            probs.append(torch.softmax(logits, dim=1)[0, class_id].item())\n",
    "    return np.mean(probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resume from previously saved results if the file exists; start fresh only\n",
    "# when it is missing so other failures (e.g. a corrupt file) are not hidden.\n",
    "# allow_pickle is needed because `results` is a dict saved as an object array;\n",
    "# only load files this notebook produced itself.\n",
    "try:\n",
    "    with np.load('class_scores_results.npz', allow_pickle=True) as cached:\n",
    "        results = cached['results'].item()\n",
    "except FileNotFoundError:\n",
    "    results = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating classifier on unmasked person images\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8b44f2ca2dc42b99900aad8dcbdb305",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2097.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Resize to 224x224, convert to a tensor and apply the standard ImageNet mean/std\n",
    "transformer = T.Compose([\n",
    "    T.Resize(size=(224, 224)),\n",
    "    T.ToTensor(),\n",
    "    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "for target_class in mask_classes:\n",
    "    class_name = inv_target_dict[target_class]\n",
    "\n",
    "    # Only the file names are needed here (not the masks), so the 'explainer'\n",
    "    # directory stands in for every method.\n",
    "    masks_dir = masks_base_path.joinpath(\n",
    "        '{}_{}_{}'.format(dataset, classifier_type, \"explainer\"),\n",
    "        'class_masks',\n",
    "        'target_class_{}'.format(target_class),\n",
    "        'masks_for_class_{}'.format(target_class),\n",
    "    )\n",
    "\n",
    "    clear_output(wait=True)\n",
    "    print(\"Evaluating classifier on unmasked {} images\".format(class_name))\n",
    "\n",
    "    all_scores = []\n",
    "    mask_files = masks_dir.glob('*.npz')\n",
    "    for mask_file in tqdm(mask_files, total=class_count_dict[class_name]):\n",
    "        rgb = Image.open(data_path / (mask_file.stem + '.jpg')).convert(\"RGB\")\n",
    "        batch = transformer(rgb).unsqueeze(0).to(device)\n",
    "        class_probs = torch.softmax(model(batch), dim=1)[0]\n",
    "        all_scores.append(class_probs[target_class].cpu().numpy())\n",
    "\n",
    "    # Record the same unmasked baseline under 'no_mask' for every method.\n",
    "    baseline = np.mean(all_scores)\n",
    "    for method in methods:\n",
    "        per_class = results.setdefault(method, {}).setdefault(class_name, {})\n",
    "        per_class['no_mask'] = {'mean_probs': baseline}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating method rt_saliency for person images with person masks\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "855412798f77461081c23e0e2cb59985",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2097.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Same preprocessing as the unmasked evaluation: resize, to-tensor, ImageNet mean/std\n",
    "transformer = T.Compose([\n",
    "    T.Resize(size=(224, 224)),\n",
    "    T.ToTensor(),\n",
    "    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "for method in methods:\n",
    "    method_results = results.setdefault(method, {})\n",
    "    for target_class in mask_classes:\n",
    "        class_name = inv_target_dict[target_class]\n",
    "        # Results are keyed by class *name*; reuse an existing entry so the\n",
    "        # 'no_mask' baseline written by the unmasked evaluation is preserved.\n",
    "        class_results = method_results.setdefault(class_name, {})\n",
    "        for mask_class in mask_classes:\n",
    "            mask_name = inv_target_dict[mask_class]\n",
    "            masks_dir = (masks_base_path / '{}_{}_{}'.format(dataset, classifier_type, method)\n",
    "                                         / 'class_masks'\n",
    "                                         / 'target_class_{}'.format(target_class)\n",
    "                                         / 'masks_for_class_{}'.format(mask_class))\n",
    "\n",
    "            clear_output(wait=True)\n",
    "            print(\"Evaluating method {} for {} images with {} masks\".format(method, class_name, mask_name))\n",
    "\n",
    "            all_scores = []\n",
    "            for filename in tqdm(masks_dir.glob('*.npz'), total=class_count_dict[class_name]):\n",
    "                jpeg_filename = os.path.splitext(filename.name)[0] + '.jpg'\n",
    "                image = Image.open(data_path / jpeg_filename).convert(\"RGB\")\n",
    "                image = transformer(image).unsqueeze(0).to(device)\n",
    "\n",
    "                # Each archive is expected to hold one mask under np.savez's default\n",
    "                # key 'arr_0' (presumably 2-D, H x W); reshaped to (1, 1, H, W) so it\n",
    "                # broadcasts over the image channels. Close the archive once read.\n",
    "                with np.load(filename) as mask_archive:\n",
    "                    mask = mask_archive['arr_0']\n",
    "                mask = torch.tensor(np.reshape(mask, [1, 1, *mask.shape]), device=device)\n",
    "                all_scores.append(compute_results(model=model, image=image, mask=mask, class_id=target_class))\n",
    "\n",
    "            class_results[mask_name] = {'mean_probs': np.mean(all_scores)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the nested results dict (stored as a pickled object array); the same\n",
    "# file is loaded with allow_pickle=True when the notebook resumes.\n",
    "np.savez(file='class_scores_results.npz', results=results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "explainer\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_c512d_row0_col0, #T_c512d_row1_col0, #T_c512d_row2_col0, #T_c512d_row3_col0, #T_c512d_row4_col5 {\n",
       "  background-color: #081d58;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c512d_row0_col1 {\n",
       "  background-color: #1f82b9;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c512d_row0_col2, #T_c512d_row0_col3, #T_c512d_row0_col4, #T_c512d_row0_col5, #T_c512d_row3_col1, #T_c512d_row3_col2, #T_c512d_row3_col5 {\n",
       "  background-color: #ffffd9;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c512d_row1_col1, #T_c512d_row1_col3, #T_c512d_row1_col4, #T_c512d_row4_col4 {\n",
       "  background-color: #fcfed1;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c512d_row1_col2 {\n",
       "  background-color: #22328f;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c512d_row1_col5 {\n",
       "  background-color: #fafdcf;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c512d_row2_col1, #T_c512d_row2_col2, #T_c512d_row3_col3, #T_c512d_row4_col2 {\n",
       "  background-color: #fdfed4;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c512d_row2_col3 {\n",
       "  background-color: #24479d;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c512d_row2_col4 {\n",
       "  background-color: #fcfed3;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c512d_row2_col5 {\n",
       "  background-color: #feffd6;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c512d_row3_col4 {\n",
       "  background-color: #091e5a;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c512d_row4_col0 {\n",
       "  background-color: #0d2163;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c512d_row4_col1 {\n",
       "  background-color: #f7fcc6;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c512d_row4_col3 {\n",
       "  background-color: #fdfed5;\n",
       "  color: #000000;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_c512d\">\n",
       "  <caption>explainer</caption>\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_c512d_level0_col0\" class=\"col_heading level0 col0\" >none</th>\n",
       "      <th id=\"T_c512d_level0_col1\" class=\"col_heading level0 col1\" >bottle</th>\n",
       "      <th id=\"T_c512d_level0_col2\" class=\"col_heading level0 col2\" >car</th>\n",
       "      <th id=\"T_c512d_level0_col3\" class=\"col_heading level0 col3\" >cat</th>\n",
       "      <th id=\"T_c512d_level0_col4\" class=\"col_heading level0 col4\" >dog</th>\n",
       "      <th id=\"T_c512d_level0_col5\" class=\"col_heading level0 col5\" >person</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_c512d_level0_row0\" class=\"row_heading level0 row0\" >bottle</th>\n",
       "      <td id=\"T_c512d_row0_col0\" class=\"data row0 col0\" >20.57</td>\n",
       "      <td id=\"T_c512d_row0_col1\" class=\"data row0 col1\" >15.29</td>\n",
       "      <td id=\"T_c512d_row0_col2\" class=\"data row0 col2\" >4.91</td>\n",
       "      <td id=\"T_c512d_row0_col3\" class=\"data row0 col3\" >4.82</td>\n",
       "      <td id=\"T_c512d_row0_col4\" class=\"data row0 col4\" >4.82</td>\n",
       "      <td id=\"T_c512d_row0_col5\" class=\"data row0 col5\" >2.72</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c512d_level0_row1\" class=\"row_heading level0 row1\" >car</th>\n",
       "      <td id=\"T_c512d_row1_col0\" class=\"data row1 col0\" >70.52</td>\n",
       "      <td id=\"T_c512d_row1_col1\" class=\"data row1 col1\" >6.71</td>\n",
       "      <td id=\"T_c512d_row1_col2\" class=\"data row1 col2\" >62.92</td>\n",
       "      <td id=\"T_c512d_row1_col3\" class=\"data row1 col3\" >6.67</td>\n",
       "      <td id=\"T_c512d_row1_col4\" class=\"data row1 col4\" >6.66</td>\n",
       "      <td id=\"T_c512d_row1_col5\" class=\"data row1 col5\" >7.11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c512d_level0_row2\" class=\"row_heading level0 row2\" >cat</th>\n",
       "      <td id=\"T_c512d_row2_col0\" class=\"data row2 col0\" >72.78</td>\n",
       "      <td id=\"T_c512d_row2_col1\" class=\"data row2 col1\" >6.09</td>\n",
       "      <td id=\"T_c512d_row2_col2\" class=\"data row2 col2\" >6.18</td>\n",
       "      <td id=\"T_c512d_row2_col3\" class=\"data row2 col3\" >60.55</td>\n",
       "      <td id=\"T_c512d_row2_col4\" class=\"data row2 col4\" >6.33</td>\n",
       "      <td id=\"T_c512d_row2_col5\" class=\"data row2 col5\" >5.74</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c512d_level0_row3\" class=\"row_heading level0 row3\" >dog</th>\n",
       "      <td id=\"T_c512d_row3_col0\" class=\"data row3 col0\" >58.00</td>\n",
       "      <td id=\"T_c512d_row3_col1\" class=\"data row3 col1\" >4.67</td>\n",
       "      <td id=\"T_c512d_row3_col2\" class=\"data row3 col2\" >4.70</td>\n",
       "      <td id=\"T_c512d_row3_col3\" class=\"data row3 col3\" >5.93</td>\n",
       "      <td id=\"T_c512d_row3_col4\" class=\"data row3 col4\" >57.60</td>\n",
       "      <td id=\"T_c512d_row3_col5\" class=\"data row3 col5\" >4.58</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c512d_level0_row4\" class=\"row_heading level0 row4\" >person</th>\n",
       "      <td id=\"T_c512d_row4_col0\" class=\"data row4 col0\" >69.37</td>\n",
       "      <td id=\"T_c512d_row4_col1\" class=\"data row4 col1\" >8.90</td>\n",
       "      <td id=\"T_c512d_row4_col2\" class=\"data row4 col2\" >6.18</td>\n",
       "      <td id=\"T_c512d_row4_col3\" class=\"data row4 col3\" >5.83</td>\n",
       "      <td id=\"T_c512d_row4_col4\" class=\"data row4 col4\" >6.57</td>\n",
       "      <td id=\"T_c512d_row4_col5\" class=\"data row4 col5\" >71.01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7f9bbdf1d730>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{table}\n",
      "\\caption{explainer}\n",
      "\\begin{tabular}{lrrrrrr}\n",
      " & none & bottle & car & cat & dog & person \\\\\n",
      "bottle & \\background-color#081d58 \\color#f1f1f1 20.57 & \\background-color#1f82b9 \\color#f1f1f1 15.29 & \\background-color#ffffd9 \\color#000000 4.91 & \\background-color#ffffd9 \\color#000000 4.82 & \\background-color#ffffd9 \\color#000000 4.82 & \\background-color#ffffd9 \\color#000000 2.72 \\\\\n",
      "car & \\background-color#081d58 \\color#f1f1f1 70.52 & \\background-color#fcfed1 \\color#000000 6.71 & \\background-color#22328f \\color#f1f1f1 62.92 & \\background-color#fcfed1 \\color#000000 6.67 & \\background-color#fcfed1 \\color#000000 6.66 & \\background-color#fafdcf \\color#000000 7.11 \\\\\n",
      "cat & \\background-color#081d58 \\color#f1f1f1 72.78 & \\background-color#fdfed4 \\color#000000 6.09 & \\background-color#fdfed4 \\color#000000 6.18 & \\background-color#24479d \\color#f1f1f1 60.55 & \\background-color#fcfed3 \\color#000000 6.33 & \\background-color#feffd6 \\color#000000 5.74 \\\\\n",
      "dog & \\background-color#081d58 \\color#f1f1f1 58.00 & \\background-color#ffffd9 \\color#000000 4.67 & \\background-color#ffffd9 \\color#000000 4.70 & \\background-color#fdfed4 \\color#000000 5.93 & \\background-color#091e5a \\color#f1f1f1 57.60 & \\background-color#ffffd9 \\color#000000 4.58 \\\\\n",
      "person & \\background-color#0d2163 \\color#f1f1f1 69.37 & \\background-color#f7fcc6 \\color#000000 8.90 & \\background-color#fdfed4 \\color#000000 6.18 & \\background-color#fdfed5 \\color#000000 5.83 & \\background-color#fcfed1 \\color#000000 6.57 & \\background-color#081d58 \\color#f1f1f1 71.01 \\\\\n",
      "\\end{tabular}\n",
      "\\end{table}\n",
      "\n",
      "\n",
      "\n",
      "grad_cam\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_cc2a9_row0_col0, #T_cc2a9_row1_col0, #T_cc2a9_row2_col0, #T_cc2a9_row3_col0, #T_cc2a9_row4_col0 {\n",
       "  background-color: #081d58;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_cc2a9_row0_col1 {\n",
       "  background-color: #216daf;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_cc2a9_row0_col2, #T_cc2a9_row2_col5, #T_cc2a9_row4_col3 {\n",
       "  background-color: #fdfed5;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row0_col3, #T_cc2a9_row0_col4, #T_cc2a9_row0_col5, #T_cc2a9_row3_col1, #T_cc2a9_row3_col2, #T_cc2a9_row3_col5 {\n",
       "  background-color: #ffffd9;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row1_col1 {\n",
       "  background-color: #fbfdd0;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row1_col2 {\n",
       "  background-color: #2076b3;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_cc2a9_row1_col3 {\n",
       "  background-color: #fcfed1;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row1_col4, #T_cc2a9_row2_col2 {\n",
       "  background-color: #fcfed3;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row1_col5, #T_cc2a9_row2_col4, #T_cc2a9_row3_col3 {\n",
       "  background-color: #f7fcc6;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row2_col1 {\n",
       "  background-color: #fdfed4;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row2_col3 {\n",
       "  background-color: #2ca1c2;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_cc2a9_row3_col4 {\n",
       "  background-color: #2073b2;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_cc2a9_row4_col1 {\n",
       "  background-color: #f0f9b7;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row4_col2 {\n",
       "  background-color: #f8fcca;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row4_col4 {\n",
       "  background-color: #f7fcc7;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_cc2a9_row4_col5 {\n",
       "  background-color: #234da0;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_cc2a9\">\n",
       "  <caption>grad_cam</caption>\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_cc2a9_level0_col0\" class=\"col_heading level0 col0\" >none</th>\n",
       "      <th id=\"T_cc2a9_level0_col1\" class=\"col_heading level0 col1\" >bottle</th>\n",
       "      <th id=\"T_cc2a9_level0_col2\" class=\"col_heading level0 col2\" >car</th>\n",
       "      <th id=\"T_cc2a9_level0_col3\" class=\"col_heading level0 col3\" >cat</th>\n",
       "      <th id=\"T_cc2a9_level0_col4\" class=\"col_heading level0 col4\" >dog</th>\n",
       "      <th id=\"T_cc2a9_level0_col5\" class=\"col_heading level0 col5\" >person</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_cc2a9_level0_row0\" class=\"row_heading level0 row0\" >bottle</th>\n",
       "      <td id=\"T_cc2a9_row0_col0\" class=\"data row0 col0\" >20.57</td>\n",
       "      <td id=\"T_cc2a9_row0_col1\" class=\"data row0 col1\" >16.10</td>\n",
       "      <td id=\"T_cc2a9_row0_col2\" class=\"data row0 col2\" >5.24</td>\n",
       "      <td id=\"T_cc2a9_row0_col3\" class=\"data row0 col3\" >4.85</td>\n",
       "      <td id=\"T_cc2a9_row0_col4\" class=\"data row0 col4\" >4.76</td>\n",
       "      <td id=\"T_cc2a9_row0_col5\" class=\"data row0 col5\" >3.61</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cc2a9_level0_row1\" class=\"row_heading level0 row1\" >car</th>\n",
       "      <td id=\"T_cc2a9_row1_col0\" class=\"data row1 col0\" >70.52</td>\n",
       "      <td id=\"T_cc2a9_row1_col1\" class=\"data row1 col1\" >6.93</td>\n",
       "      <td id=\"T_cc2a9_row1_col2\" class=\"data row1 col2\" >50.19</td>\n",
       "      <td id=\"T_cc2a9_row1_col3\" class=\"data row1 col3\" >6.65</td>\n",
       "      <td id=\"T_cc2a9_row1_col4\" class=\"data row1 col4\" >6.53</td>\n",
       "      <td id=\"T_cc2a9_row1_col5\" class=\"data row1 col5\" >8.92</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cc2a9_level0_row2\" class=\"row_heading level0 row2\" >cat</th>\n",
       "      <td id=\"T_cc2a9_row2_col0\" class=\"data row2 col0\" >72.78</td>\n",
       "      <td id=\"T_cc2a9_row2_col1\" class=\"data row2 col1\" >6.17</td>\n",
       "      <td id=\"T_cc2a9_row2_col2\" class=\"data row2 col2\" >6.39</td>\n",
       "      <td id=\"T_cc2a9_row2_col3\" class=\"data row2 col3\" >43.74</td>\n",
       "      <td id=\"T_cc2a9_row2_col4\" class=\"data row2 col4\" >8.99</td>\n",
       "      <td id=\"T_cc2a9_row2_col5\" class=\"data row2 col5\" >5.85</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cc2a9_level0_row3\" class=\"row_heading level0 row3\" >dog</th>\n",
       "      <td id=\"T_cc2a9_row3_col0\" class=\"data row3 col0\" >58.00</td>\n",
       "      <td id=\"T_cc2a9_row3_col1\" class=\"data row3 col1\" >4.46</td>\n",
       "      <td id=\"T_cc2a9_row3_col2\" class=\"data row3 col2\" >4.39</td>\n",
       "      <td id=\"T_cc2a9_row3_col3\" class=\"data row3 col3\" >8.23</td>\n",
       "      <td id=\"T_cc2a9_row3_col4\" class=\"data row3 col4\" >41.88</td>\n",
       "      <td id=\"T_cc2a9_row3_col5\" class=\"data row3 col5\" >4.38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_cc2a9_level0_row4\" class=\"row_heading level0 row4\" >person</th>\n",
       "      <td id=\"T_cc2a9_row4_col0\" class=\"data row4 col0\" >69.37</td>\n",
       "      <td id=\"T_cc2a9_row4_col1\" class=\"data row4 col1\" >11.81</td>\n",
       "      <td id=\"T_cc2a9_row4_col2\" class=\"data row4 col2\" >8.18</td>\n",
       "      <td id=\"T_cc2a9_row4_col3\" class=\"data row4 col3\" >5.90</td>\n",
       "      <td id=\"T_cc2a9_row4_col4\" class=\"data row4 col4\" >8.75</td>\n",
       "      <td id=\"T_cc2a9_row4_col5\" class=\"data row4 col5\" >56.38</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7f9bbdf1d730>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{table}\n",
      "\\caption{grad_cam}\n",
      "\\begin{tabular}{lrrrrrr}\n",
      " & none & bottle & car & cat & dog & person \\\\\n",
      "bottle & \\background-color#081d58 \\color#f1f1f1 20.57 & \\background-color#216daf \\color#f1f1f1 16.10 & \\background-color#fdfed5 \\color#000000 5.24 & \\background-color#ffffd9 \\color#000000 4.85 & \\background-color#ffffd9 \\color#000000 4.76 & \\background-color#ffffd9 \\color#000000 3.61 \\\\\n",
      "car & \\background-color#081d58 \\color#f1f1f1 70.52 & \\background-color#fbfdd0 \\color#000000 6.93 & \\background-color#2076b3 \\color#f1f1f1 50.19 & \\background-color#fcfed1 \\color#000000 6.65 & \\background-color#fcfed3 \\color#000000 6.53 & \\background-color#f7fcc6 \\color#000000 8.92 \\\\\n",
      "cat & \\background-color#081d58 \\color#f1f1f1 72.78 & \\background-color#fdfed4 \\color#000000 6.17 & \\background-color#fcfed3 \\color#000000 6.39 & \\background-color#2ca1c2 \\color#f1f1f1 43.74 & \\background-color#f7fcc6 \\color#000000 8.99 & \\background-color#fdfed5 \\color#000000 5.85 \\\\\n",
      "dog & \\background-color#081d58 \\color#f1f1f1 58.00 & \\background-color#ffffd9 \\color#000000 4.46 & \\background-color#ffffd9 \\color#000000 4.39 & \\background-color#f7fcc6 \\color#000000 8.23 & \\background-color#2073b2 \\color#f1f1f1 41.88 & \\background-color#ffffd9 \\color#000000 4.38 \\\\\n",
      "person & \\background-color#081d58 \\color#f1f1f1 69.37 & \\background-color#f0f9b7 \\color#000000 11.81 & \\background-color#f8fcca \\color#000000 8.18 & \\background-color#fdfed5 \\color#000000 5.90 & \\background-color#f7fcc7 \\color#000000 8.75 & \\background-color#234da0 \\color#f1f1f1 56.38 \\\\\n",
      "\\end{tabular}\n",
      "\\end{table}\n",
      "\n",
      "\n",
      "\n",
      "rise\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_8fbb0_row0_col0, #T_8fbb0_row1_col0, #T_8fbb0_row2_col0, #T_8fbb0_row3_col0, #T_8fbb0_row4_col0 {\n",
       "  background-color: #081d58;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row0_col1 {\n",
       "  background-color: #243392;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row0_col2 {\n",
       "  background-color: #7cccbb;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row0_col3 {\n",
       "  background-color: #59bfc0;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row0_col4 {\n",
       "  background-color: #85cfba;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row0_col5 {\n",
       "  background-color: #a9ddb7;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row1_col1 {\n",
       "  background-color: #7ecdbb;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row1_col2 {\n",
       "  background-color: #2351a2;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row1_col3 {\n",
       "  background-color: #7acbbc;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row1_col4 {\n",
       "  background-color: #5bc0c0;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row1_col5 {\n",
       "  background-color: #69c5be;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row2_col1 {\n",
       "  background-color: #48b9c3;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row2_col2 {\n",
       "  background-color: #3fb4c4;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row2_col3, #T_8fbb0_row4_col3 {\n",
       "  background-color: #1e83ba;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row2_col4 {\n",
       "  background-color: #34a9c3;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row2_col5 {\n",
       "  background-color: #55bec1;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row3_col1 {\n",
       "  background-color: #b9e4b5;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row3_col2 {\n",
       "  background-color: #b0e0b6;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row3_col3 {\n",
       "  background-color: #8cd2ba;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row3_col4 {\n",
       "  background-color: #225ca7;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row3_col5 {\n",
       "  background-color: #95d5b9;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_8fbb0_row4_col1 {\n",
       "  background-color: #234fa1;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row4_col2 {\n",
       "  background-color: #225da8;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row4_col4 {\n",
       "  background-color: #2355a4;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_8fbb0_row4_col5 {\n",
       "  background-color: #1d2e83;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_8fbb0\">\n",
       "  <caption>rise</caption>\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_8fbb0_level0_col0\" class=\"col_heading level0 col0\" >none</th>\n",
       "      <th id=\"T_8fbb0_level0_col1\" class=\"col_heading level0 col1\" >bottle</th>\n",
       "      <th id=\"T_8fbb0_level0_col2\" class=\"col_heading level0 col2\" >car</th>\n",
       "      <th id=\"T_8fbb0_level0_col3\" class=\"col_heading level0 col3\" >cat</th>\n",
       "      <th id=\"T_8fbb0_level0_col4\" class=\"col_heading level0 col4\" >dog</th>\n",
       "      <th id=\"T_8fbb0_level0_col5\" class=\"col_heading level0 col5\" >person</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_8fbb0_level0_row0\" class=\"row_heading level0 row0\" >bottle</th>\n",
       "      <td id=\"T_8fbb0_row0_col0\" class=\"data row0 col0\" >20.57</td>\n",
       "      <td id=\"T_8fbb0_row0_col1\" class=\"data row0 col1\" >18.67</td>\n",
       "      <td id=\"T_8fbb0_row0_col2\" class=\"data row0 col2\" >10.94</td>\n",
       "      <td id=\"T_8fbb0_row0_col3\" class=\"data row0 col3\" >12.05</td>\n",
       "      <td id=\"T_8fbb0_row0_col4\" class=\"data row0 col4\" >10.70</td>\n",
       "      <td id=\"T_8fbb0_row0_col5\" class=\"data row0 col5\" >9.73</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_8fbb0_level0_row1\" class=\"row_heading level0 row1\" >car</th>\n",
       "      <td id=\"T_8fbb0_row1_col0\" class=\"data row1 col0\" >70.52</td>\n",
       "      <td id=\"T_8fbb0_row1_col1\" class=\"data row1 col1\" >29.75</td>\n",
       "      <td id=\"T_8fbb0_row1_col2\" class=\"data row1 col2\" >56.66</td>\n",
       "      <td id=\"T_8fbb0_row1_col3\" class=\"data row1 col3\" >30.27</td>\n",
       "      <td id=\"T_8fbb0_row1_col4\" class=\"data row1 col4\" >34.43</td>\n",
       "      <td id=\"T_8fbb0_row1_col5\" class=\"data row1 col5\" >32.42</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_8fbb0_level0_row2\" class=\"row_heading level0 row2\" >cat</th>\n",
       "      <td id=\"T_8fbb0_row2_col0\" class=\"data row2 col0\" >72.78</td>\n",
       "      <td id=\"T_8fbb0_row2_col1\" class=\"data row2 col1\" >37.89</td>\n",
       "      <td id=\"T_8fbb0_row2_col2\" class=\"data row2 col2\" >39.28</td>\n",
       "      <td id=\"T_8fbb0_row2_col3\" class=\"data row2 col3\" >49.61</td>\n",
       "      <td id=\"T_8fbb0_row2_col4\" class=\"data row2 col4\" >41.89</td>\n",
       "      <td id=\"T_8fbb0_row2_col5\" class=\"data row2 col5\" >36.02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_8fbb0_level0_row3\" class=\"row_heading level0 row3\" >dog</th>\n",
       "      <td id=\"T_8fbb0_row3_col0\" class=\"data row3 col0\" >58.00</td>\n",
       "      <td id=\"T_8fbb0_row3_col1\" class=\"data row3 col1\" >19.64</td>\n",
       "      <td id=\"T_8fbb0_row3_col2\" class=\"data row3 col2\" >20.34</td>\n",
       "      <td id=\"T_8fbb0_row3_col3\" class=\"data row3 col3\" >23.72</td>\n",
       "      <td id=\"T_8fbb0_row3_col4\" class=\"data row3 col4\" >45.12</td>\n",
       "      <td id=\"T_8fbb0_row3_col5\" class=\"data row3 col5\" >22.83</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_8fbb0_level0_row4\" class=\"row_heading level0 row4\" >person</th>\n",
       "      <td id=\"T_8fbb0_row4_col0\" class=\"data row4 col0\" >69.37</td>\n",
       "      <td id=\"T_8fbb0_row4_col1\" class=\"data row4 col1\" >56.06</td>\n",
       "      <td id=\"T_8fbb0_row4_col2\" class=\"data row4 col2\" >53.39</td>\n",
       "      <td id=\"T_8fbb0_row4_col3\" class=\"data row4 col3\" >47.33</td>\n",
       "      <td id=\"T_8fbb0_row4_col4\" class=\"data row4 col4\" >54.81</td>\n",
       "      <td id=\"T_8fbb0_row4_col5\" class=\"data row4 col5\" >63.52</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7f9bbdf1d730>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{table}\n",
      "\\caption{rise}\n",
      "\\begin{tabular}{lrrrrrr}\n",
      " & none & bottle & car & cat & dog & person \\\\\n",
      "bottle & \\background-color#081d58 \\color#f1f1f1 20.57 & \\background-color#243392 \\color#f1f1f1 18.67 & \\background-color#7cccbb \\color#000000 10.94 & \\background-color#59bfc0 \\color#000000 12.05 & \\background-color#85cfba \\color#000000 10.70 & \\background-color#a9ddb7 \\color#000000 9.73 \\\\\n",
      "car & \\background-color#081d58 \\color#f1f1f1 70.52 & \\background-color#7ecdbb \\color#000000 29.75 & \\background-color#2351a2 \\color#f1f1f1 56.66 & \\background-color#7acbbc \\color#000000 30.27 & \\background-color#5bc0c0 \\color#000000 34.43 & \\background-color#69c5be \\color#000000 32.42 \\\\\n",
      "cat & \\background-color#081d58 \\color#f1f1f1 72.78 & \\background-color#48b9c3 \\color#f1f1f1 37.89 & \\background-color#3fb4c4 \\color#f1f1f1 39.28 & \\background-color#1e83ba \\color#f1f1f1 49.61 & \\background-color#34a9c3 \\color#f1f1f1 41.89 & \\background-color#55bec1 \\color#000000 36.02 \\\\\n",
      "dog & \\background-color#081d58 \\color#f1f1f1 58.00 & \\background-color#b9e4b5 \\color#000000 19.64 & \\background-color#b0e0b6 \\color#000000 20.34 & \\background-color#8cd2ba \\color#000000 23.72 & \\background-color#225ca7 \\color#f1f1f1 45.12 & \\background-color#95d5b9 \\color#000000 22.83 \\\\\n",
      "person & \\background-color#081d58 \\color#f1f1f1 69.37 & \\background-color#234fa1 \\color#f1f1f1 56.06 & \\background-color#225da8 \\color#f1f1f1 53.39 & \\background-color#1e83ba \\color#f1f1f1 47.33 & \\background-color#2355a4 \\color#f1f1f1 54.81 & \\background-color#1d2e83 \\color#f1f1f1 63.52 \\\\\n",
      "\\end{tabular}\n",
      "\\end{table}\n",
      "\n",
      "\n",
      "\n",
      "rt_saliency\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_c8f0e_row0_col0, #T_c8f0e_row1_col0, #T_c8f0e_row2_col0, #T_c8f0e_row2_col3, #T_c8f0e_row3_col4, #T_c8f0e_row4_col5 {\n",
       "  background-color: #081d58;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row0_col1 {\n",
       "  background-color: #21308b;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row0_col2 {\n",
       "  background-color: #75c9bd;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c8f0e_row0_col3 {\n",
       "  background-color: #36abc3;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row0_col4 {\n",
       "  background-color: #35aac3;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row0_col5 {\n",
       "  background-color: #2351a2;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row1_col1 {\n",
       "  background-color: #32a6c2;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row1_col2 {\n",
       "  background-color: #142670;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row1_col3 {\n",
       "  background-color: #34a9c3;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row1_col4 {\n",
       "  background-color: #6fc7bd;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c8f0e_row1_col5 {\n",
       "  background-color: #225da8;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row2_col1 {\n",
       "  background-color: #2397c1;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row2_col2 {\n",
       "  background-color: #53bdc1;\n",
       "  color: #000000;\n",
       "}\n",
       "#T_c8f0e_row2_col4 {\n",
       "  background-color: #234da0;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row2_col5 {\n",
       "  background-color: #2168ad;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row3_col0 {\n",
       "  background-color: #152772;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row3_col1 {\n",
       "  background-color: #1e88bc;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row3_col2 {\n",
       "  background-color: #39adc3;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row3_col3 {\n",
       "  background-color: #172976;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row3_col5 {\n",
       "  background-color: #12256d;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row4_col0 {\n",
       "  background-color: #162874;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row4_col1 {\n",
       "  background-color: #253595;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row4_col2 {\n",
       "  background-color: #216aad;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row4_col3 {\n",
       "  background-color: #2fa4c2;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "#T_c8f0e_row4_col4 {\n",
       "  background-color: #24449c;\n",
       "  color: #f1f1f1;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_c8f0e\">\n",
       "  <caption>rt_saliency</caption>\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_c8f0e_level0_col0\" class=\"col_heading level0 col0\" >none</th>\n",
       "      <th id=\"T_c8f0e_level0_col1\" class=\"col_heading level0 col1\" >bottle</th>\n",
       "      <th id=\"T_c8f0e_level0_col2\" class=\"col_heading level0 col2\" >car</th>\n",
       "      <th id=\"T_c8f0e_level0_col3\" class=\"col_heading level0 col3\" >cat</th>\n",
       "      <th id=\"T_c8f0e_level0_col4\" class=\"col_heading level0 col4\" >dog</th>\n",
       "      <th id=\"T_c8f0e_level0_col5\" class=\"col_heading level0 col5\" >person</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_c8f0e_level0_row0\" class=\"row_heading level0 row0\" >bottle</th>\n",
       "      <td id=\"T_c8f0e_row0_col0\" class=\"data row0 col0\" >20.57</td>\n",
       "      <td id=\"T_c8f0e_row0_col1\" class=\"data row0 col1\" >18.92</td>\n",
       "      <td id=\"T_c8f0e_row0_col2\" class=\"data row0 col2\" >11.19</td>\n",
       "      <td id=\"T_c8f0e_row0_col3\" class=\"data row0 col3\" >13.37</td>\n",
       "      <td id=\"T_c8f0e_row0_col4\" class=\"data row0 col4\" >13.44</td>\n",
       "      <td id=\"T_c8f0e_row0_col5\" class=\"data row0 col5\" >17.24</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c8f0e_level0_row1\" class=\"row_heading level0 row1\" >car</th>\n",
       "      <td id=\"T_c8f0e_row1_col0\" class=\"data row1 col0\" >70.52</td>\n",
       "      <td id=\"T_c8f0e_row1_col1\" class=\"data row1 col1\" >41.25</td>\n",
       "      <td id=\"T_c8f0e_row1_col2\" class=\"data row1 col2\" >67.02</td>\n",
       "      <td id=\"T_c8f0e_row1_col3\" class=\"data row1 col3\" >40.69</td>\n",
       "      <td id=\"T_c8f0e_row1_col4\" class=\"data row1 col4\" >31.87</td>\n",
       "      <td id=\"T_c8f0e_row1_col5\" class=\"data row1 col5\" >54.31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c8f0e_level0_row2\" class=\"row_heading level0 row2\" >cat</th>\n",
       "      <td id=\"T_c8f0e_row2_col0\" class=\"data row2 col0\" >72.78</td>\n",
       "      <td id=\"T_c8f0e_row2_col1\" class=\"data row2 col1\" >45.84</td>\n",
       "      <td id=\"T_c8f0e_row2_col2\" class=\"data row2 col2\" >36.27</td>\n",
       "      <td id=\"T_c8f0e_row2_col3\" class=\"data row2 col3\" >72.62</td>\n",
       "      <td id=\"T_c8f0e_row2_col4\" class=\"data row2 col4\" >59.21</td>\n",
       "      <td id=\"T_c8f0e_row2_col5\" class=\"data row2 col5\" >54.16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c8f0e_level0_row3\" class=\"row_heading level0 row3\" >dog</th>\n",
       "      <td id=\"T_c8f0e_row3_col0\" class=\"data row3 col0\" >58.00</td>\n",
       "      <td id=\"T_c8f0e_row3_col1\" class=\"data row3 col1\" >41.37</td>\n",
       "      <td id=\"T_c8f0e_row3_col2\" class=\"data row3 col2\" >34.76</td>\n",
       "      <td id=\"T_c8f0e_row3_col3\" class=\"data row3 col3\" >57.59</td>\n",
       "      <td id=\"T_c8f0e_row3_col4\" class=\"data row3 col4\" >61.22</td>\n",
       "      <td id=\"T_c8f0e_row3_col5\" class=\"data row3 col5\" >58.76</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_c8f0e_level0_row4\" class=\"row_heading level0 row4\" >person</th>\n",
       "      <td id=\"T_c8f0e_row4_col0\" class=\"data row4 col0\" >69.37</td>\n",
       "      <td id=\"T_c8f0e_row4_col1\" class=\"data row4 col1\" >64.71</td>\n",
       "      <td id=\"T_c8f0e_row4_col2\" class=\"data row4 col2\" >54.41</td>\n",
       "      <td id=\"T_c8f0e_row4_col3\" class=\"data row4 col3\" >43.42</td>\n",
       "      <td id=\"T_c8f0e_row4_col4\" class=\"data row4 col4\" >61.55</td>\n",
       "      <td id=\"T_c8f0e_row4_col5\" class=\"data row4 col5\" >73.59</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7f9bbdf1d730>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{table}\n",
      "\\caption{rt_saliency}\n",
      "\\begin{tabular}{lrrrrrr}\n",
      " & none & bottle & car & cat & dog & person \\\\\n",
      "bottle & \\background-color#081d58 \\color#f1f1f1 20.57 & \\background-color#21308b \\color#f1f1f1 18.92 & \\background-color#75c9bd \\color#000000 11.19 & \\background-color#36abc3 \\color#f1f1f1 13.37 & \\background-color#35aac3 \\color#f1f1f1 13.44 & \\background-color#2351a2 \\color#f1f1f1 17.24 \\\\\n",
      "car & \\background-color#081d58 \\color#f1f1f1 70.52 & \\background-color#32a6c2 \\color#f1f1f1 41.25 & \\background-color#142670 \\color#f1f1f1 67.02 & \\background-color#34a9c3 \\color#f1f1f1 40.69 & \\background-color#6fc7bd \\color#000000 31.87 & \\background-color#225da8 \\color#f1f1f1 54.31 \\\\\n",
      "cat & \\background-color#081d58 \\color#f1f1f1 72.78 & \\background-color#2397c1 \\color#f1f1f1 45.84 & \\background-color#53bdc1 \\color#000000 36.27 & \\background-color#081d58 \\color#f1f1f1 72.62 & \\background-color#234da0 \\color#f1f1f1 59.21 & \\background-color#2168ad \\color#f1f1f1 54.16 \\\\\n",
      "dog & \\background-color#152772 \\color#f1f1f1 58.00 & \\background-color#1e88bc \\color#f1f1f1 41.37 & \\background-color#39adc3 \\color#f1f1f1 34.76 & \\background-color#172976 \\color#f1f1f1 57.59 & \\background-color#081d58 \\color#f1f1f1 61.22 & \\background-color#12256d \\color#f1f1f1 58.76 \\\\\n",
      "person & \\background-color#162874 \\color#f1f1f1 69.37 & \\background-color#253595 \\color#f1f1f1 64.71 & \\background-color#216aad \\color#f1f1f1 54.41 & \\background-color#2fa4c2 \\color#f1f1f1 43.42 & \\background-color#24449c \\color#f1f1f1 61.55 & \\background-color#081d58 \\color#f1f1f1 73.59 \\\\\n",
      "\\end{tabular}\n",
      "\\end{table}\n",
      "\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Build one table per explanation method. Rows: the class whose probability is reported;\n",
    "# columns: which object was masked out (\"none\" = unmasked image).\n",
    "# Each cell is that class's mean classifier probability, in percent.\n",
    "\n",
    "# Column position of each mask key; the unmasked baseline (\"no_mask\") comes first.\n",
    "indices = {\"no_mask\": 0, **{obj: col for col, obj in enumerate(mask_objects, start=1)}}\n",
    "\n",
    "for method, method_results in results.items():\n",
    "    # NOTE(review): rows are labelled with mask_objects, which assumes results[method]\n",
    "    # yields its target classes in mask_objects order - TODO confirm.\n",
    "    assert len(method_results) == len(mask_objects), f\"{method}: unexpected number of target classes\"\n",
    "    confusion_matrix = np.zeros(shape=(len(mask_objects), len(indices)))\n",
    "    for row, target_class in enumerate(method_results):\n",
    "        for mask_class, stats in method_results[target_class].items():\n",
    "            # Keep full precision here; rounding to 2 decimals is the Styler's job.\n",
    "            confusion_matrix[row, indices[mask_class]] = float(stats['mean_probs']) * 100.0\n",
    "\n",
    "    print(method)\n",
    "    styled = (\n",
    "        pd.DataFrame(confusion_matrix, index=mask_objects, columns=[\"none\"] + mask_objects)\n",
    "        .style\n",
    "        .set_caption(method)\n",
    "        .format(precision=2)\n",
    "        .background_gradient(axis=1, vmin=5, low=0.2, cmap=\"YlGnBu\")\n",
    "    )\n",
    "    # convert_css=True translates the gradient's CSS into \\cellcolor / \\color commands;\n",
    "    # without it to_latex emits invalid tokens such as \"\\background-color#081d58\".\n",
    "    # The LaTeX document then needs colour-table support, e.g. \\usepackage[table]{xcolor}.\n",
    "    latex_table = styled.to_latex(convert_css=True)\n",
    "    display(styled)\n",
    "    print(latex_table)\n",
    "    print('\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "865fad0a43638678bdec2ad29cb1e9256b61e71a582e8f3afbf427c5ca7324ae"
  },
  "kernelspec": {
   "display_name": "pytorch_gpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
