{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "428c82b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#python3.7 \n",
    "# install dependencies with requirements.txt (included in zip)\n",
    "\n",
    "# to change to Mario level 8-3 set:\n",
    "# world = 8, level=3 in gen_color_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4533ab27",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import *\n",
    "from train_utils import *\n",
    "from pytorchtools import EarlyStopping\n",
    "\n",
    "import torch.optim as optim\n",
    "from tqdm import tqdm\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9e6442d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available:\n",
    "    device = \"cuda:0\"\n",
    "    print('Using GPU')\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "    print('using CPU')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6120195",
   "metadata": {},
   "source": [
    "# Data Generation/Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd1466b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_data = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "814906be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get training data loaders\n",
    "if new_data:\n",
    "    expert_observations, color_observations, expert_actions, episode_schedule = gen_color_data(output_path=None,\n",
    "                                                                                               num_interactions=int(1e4),\n",
    "                                                                                               bad=False)#=int(30000/322))\n",
    "    np.savez_compressed(\n",
    "                        'mario_HOLDOUT.npz',\n",
    "                        expert_actions=expert_actions,#np.array(acts),\n",
    "                        color_observations=color_observations,\n",
    "                        expert_observations=expert_observations,#np.array(states),\n",
    "                        episode_schedule = episode_schedule#np.array(episode_schedule)\n",
    "                )\n",
    "else:\n",
    "    arrs = np.load('marioCROP.npz')\n",
    "    expert_observations = arrs['expert_observations']\n",
    "    color_observations = arrs['color_observations']\n",
    "    expert_actions = arrs['expert_actions']\n",
    "    episode_schedule = arrs['episode_schedule']\n",
    "    \n",
    "    expert_dataset = ExpertDataset(expert_observations, expert_actions)\n",
    "    episode_labels = torch.FloatTensor(episode_schedule[:,0]).to(device)\n",
    "    step_labels = torch.FloatTensor(episode_schedule[:,1]).to(device)\n",
    "    DatasetWithInDices = dataset_with_indices(ExpertDataset)\n",
    "    dset = DatasetWithInDices(expert_observations, expert_actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f26270e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get evaluation data\n",
    "holdout_arrs = np.load('marioCROP_HOLDOUT.npz')\n",
    "holdout_expert_observations = holdout_arrs['expert_observations']\n",
    "holdout_color_observations = holdout_arrs['color_observations']\n",
    "holdout_expert_actions = holdout_arrs['expert_actions']\n",
    "holdout_episode_schedule = holdout_arrs['episode_schedule']\n",
    "\n",
    "holdout_expert_dataset = ExpertDataset(holdout_expert_observations, holdout_expert_actions)\n",
    "\n",
    "eval_loader = th.utils.data.DataLoader(\n",
    "    dataset=holdout_expert_dataset, batch_size=128, shuffle=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cbc2d13",
   "metadata": {},
   "source": [
    "# Encoder Pre-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84d1a827",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset for encoder pretraining\n",
    "# RUN THIS FOR PRE-TRAINING\n",
    "\n",
    "train_size = int(0.8 * len(dset))\n",
    "test_size = int(0.2 * len(dset))\n",
    "\n",
    "train_expert_dataset, test_expert_dataset = random_split(\n",
    "        dset, [train_size, test_size], generator=torch.Generator().manual_seed(42))\n",
    "\n",
    "\n",
    "kwargs = {\"num_workers\": 8, \"pin_memory\": False}\n",
    "train_loader = th.utils.data.DataLoader(\n",
    "        dataset=train_expert_dataset, batch_size=64, shuffle=True, **kwargs\n",
    ")\n",
    "test_loader = th.utils.data.DataLoader(\n",
    "        dataset=test_expert_dataset, batch_size=64, shuffle=True, **kwargs,\n",
    ")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96b22656",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ae_outputs(encoder,decoder,n=10):\n",
    "    plt.figure(figsize=(16,4.5))\n",
    "    try:\n",
    "        imgs, targets = next(iter(test_loader))\n",
    "    except Exception as e:\n",
    "        imgs, targets, _ = next(iter(test_loader))\n",
    " \n",
    "    targets = targets.numpy()\n",
    "    \n",
    "    for i in range(n):\n",
    "      ax = plt.subplot(2,n,i+1)\n",
    "      img = imgs[i].unsqueeze(0).to(device)\n",
    "      encoder.eval()\n",
    "      decoder.eval()\n",
    "      with torch.no_grad():\n",
    "         rec_img  = decoder(encoder(img))\n",
    "      for k in range(4):\n",
    "         plt.imshow(img.cpu().squeeze().numpy()[k], cmap='gist_gray')\n",
    "      ax.get_xaxis().set_visible(False)\n",
    "      ax.get_yaxis().set_visible(False)  \n",
    "      if i == n//2:\n",
    "        ax.set_title('Original images')\n",
    "      ax = plt.subplot(2, n, i + 1 + n)\n",
    "      for k in range(4):\n",
    "          plt.imshow(rec_img.cpu().squeeze().numpy()[k], cmap='gist_gray')  \n",
    "      ax.get_xaxis().set_visible(False)\n",
    "      ax.get_yaxis().set_visible(False)  \n",
    "      if i == n//2:\n",
    "         ax.set_title('Reconstructed images')\n",
    "    plt.show()  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "907e5c59",
   "metadata": {},
   "outputs": [],
   "source": [
    "#https://raw.githubusercontent.com/lyakaap/NetVLAD-pytorch/master/hard_triplet_loss.py\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class HardQuadrupletLoss(nn.Module):\n",
    "    \"\"\"Hard/Hardest Triplet Loss\n",
    "    (pytorch implementation of https://omoindrot.github.io/triplet-loss)\n",
    "\n",
    "    For each anchor, we get the hardest positive and hardest negative to form a triplet.\n",
    "    \"\"\"\n",
    "    def __init__(self, margin1=0.1, margin2=.05, hardest=False, squared=False, epsilon=10, tau=11):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            margin: margin for triplet loss\n",
    "            hardest: If true, loss is considered only hardest triplets.\n",
    "            squared: If true, output is the pairwise squared euclidean distance matrix.\n",
    "                If false, output is the pairwise euclidean distance matrix.\n",
    "        \"\"\"\n",
    "        super(HardQuadrupletLoss, self).__init__()\n",
    "        self.margin1 = margin1\n",
    "        self.margin2 = margin2\n",
    "        self.hardest = hardest\n",
    "        self.squared = squared\n",
    "        self.epsilon = epsilon\n",
    "        self.tau = tau\n",
    "\n",
    "    def forward(self, embeddings, labels, idx):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            labels: labels of the batch, of size (batch_size,)\n",
    "            embeddings: tensor of shape (batch_size, embed_dim)\n",
    "\n",
    "        Returns:\n",
    "            triplet_loss: scalar tensor containing the triplet loss\n",
    "        \"\"\"\n",
    "        pairwise_dist = _pairwise_distance(embeddings, squared=self.squared)\n",
    "\n",
    "        if self.hardest:\n",
    "            # Get the hardest positive pairs\n",
    "            mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()\n",
    "            valid_positive_dist = pairwise_dist * mask_anchor_positive\n",
    "            hardest_positive_dist, _ = torch.max(valid_positive_dist, dim=1, keepdim=True)\n",
    "\n",
    "            # Get the hardest negative1 pairs\n",
    "            mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()\n",
    "            max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True)\n",
    "            anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (\n",
    "                    1.0 - mask_anchor_negative)\n",
    "            hardest_negative_dist, _ = torch.min(anchor_negative_dist, dim=1, keepdim=True)\n",
    "            \n",
    "            # Get hardest negative 2 pairs\n",
    "\n",
    "            # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss\n",
    "            quad_loss = F.relu(hardest_positive_dist - hardest_negative_dist + self.margin1)\n",
    "            quad_loss += F.relu(hardest_positive_dist - hardest_negative2_dist + self.margin2)\n",
    "            quad_loss = torch.mean(quad_loss)\n",
    "        else:\n",
    "            anc_pos_dist = pairwise_dist.unsqueeze(dim=2)\n",
    "            anc_neg_dist = pairwise_dist.unsqueeze(dim=1)\n",
    "            anc_neg2_dist = pairwise_dist.unsqueeze(dim=0)\n",
    "\n",
    "            # Compute a 3D tensor of size (batch_size, batch_size, batch_size)\n",
    "            # triplet_loss[i, j, k] will contain the triplet loss of anc=i, pos=j, neg=k\n",
    "            # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)\n",
    "            # and the 2nd (batch_size, 1, batch_size)\n",
    "            loss = F.relu(anc_pos_dist - anc_neg_dist + self.margin1)\n",
    "            loss += F.relu(anc_pos_dist - anc_neg2_dist + self.margin2)\n",
    "            \n",
    "            #print('\\ninit loss stats', torch.min(loss).item(), torch.max(loss).item())\n",
    "\n",
    "            mask = _get_quadruplet_mask(labels, idx).float()\n",
    "            quadruplet_loss = loss * mask\n",
    "            \n",
    "            #print('masked loss stats', torch.min(triplet_loss).item(), torch.max(triplet_loss).item())\n",
    "\n",
    "            # Remove negative losses (i.e. the easy triplets)\n",
    "            #quadruplet_loss = F.relu(quadruplet_loss)\n",
    "\n",
    "            # Count number of hard triplets (where triplet_loss > 0)\n",
    "            hard_quadruplets = torch.gt(quadruplet_loss, 1e-16).float()\n",
    "            num_hard_quadruplets = torch.sum(hard_quadruplets)\n",
    "\n",
    "            quadruplet_loss = torch.sum(quadruplet_loss) / (num_hard_quadruplets + 1e-16)\n",
    "\n",
    "        return quadruplet_loss\n",
    "\n",
    "\n",
    "\n",
    "def _pairwise_distance(x, squared=False, eps=1e-16):\n",
    "    # Compute the 2D matrix of distances between all the embeddings.\n",
    "\n",
    "    cor_mat = torch.matmul(x, x.t())\n",
    "    norm_mat = cor_mat.diag()\n",
    "    distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0)\n",
    "    distances = F.relu(distances)\n",
    "\n",
    "    if not squared:\n",
    "        mask = torch.eq(distances, 0.0).float()\n",
    "        distances = distances + mask * eps\n",
    "        distances = torch.sqrt(distances)\n",
    "        distances = distances * (1.0 - mask)\n",
    "\n",
    "    return distances\n",
    "\n",
    "\n",
    "def _get_anchor_positive_triplet_mask(labels):\n",
    "    # Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.\n",
    "\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    indices_not_equal = torch.eye(labels.shape[0]).to(device).byte() ^ 1\n",
    "\n",
    "    # Check if labels[i] == labels[j]\n",
    "    labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)\n",
    "\n",
    "    mask = indices_not_equal * labels_equal\n",
    "\n",
    "    return mask\n",
    "\n",
    "\n",
    "def _get_anchor_negative_triplet_mask(labels):\n",
    "    # Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.\n",
    "\n",
    "    # Check if labels[i] != labels[k]\n",
    "    labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)\n",
    "    mask = labels_equal ^ 1\n",
    "\n",
    "    return mask\n",
    "\n",
    "\n",
    "def _get_quadruplet_mask(\n",
    "        labels,\n",
    "        idx\n",
    "    ):\n",
    "    B = labels.size(0)\n",
    "\n",
    "    # Make sure that i != j != k != l\n",
    "    indices_equal = torch.eye(B, dtype=torch.bool).to(device)  # [B, B] \n",
    "    indices_not_equal = ~indices_equal  # [B, B] \n",
    "    i_not_equal_j = indices_not_equal.view(B, B, 1, 1)  # [B, B, 1, 1]\n",
    "    j_not_equal_k = indices_not_equal.view(1, B, B, 1)  # [B, 1, 1, B] \n",
    "    k_not_equal_l = indices_not_equal.view(1, 1, B, B)  # [1, 1, B, B] \n",
    "    distinct_indices = i_not_equal_j & j_not_equal_k & k_not_equal_l  # [B, B, B, B] \n",
    "\n",
    "    # Make sure that labels[i] == labels[j] \n",
    "    #            and labels[j] != labels[k] \n",
    "    #            and labels[k] != labels[l]\n",
    "    labels_equal = labels.view(1, B) == labels.view(B, 1)  # [B, B]\n",
    "    i_equal_j = labels_equal.view(B, B, 1, 1)  # [B, B, 1, 1]\n",
    "    j_equal_k = labels_equal.view(1, B, B, 1)  # [1, B, B, 1]\n",
    "    l_equal_i = labels_equal.view(B, 1, 1, B)  # [1, 1, B, B]\n",
    "    label_match = i_equal_j & ~j_equal_k #& ~l_equal_i\n",
    "    \n",
    "    eps = 15#self.epsilon\n",
    "    tau = 15#self.tau\n",
    "    \n",
    "    ep_labels = episode_labels[idx]\n",
    "    lst = step_labels[idx] \n",
    "    \n",
    "    ep_labels_equal = ep_labels.view(1, B) == ep_labels.view(B, 1)  # [B, B]\n",
    "    i_equal_j_ep = labels_equal.view(B, B, 1, 1)  # [B, B, 1, 1]\n",
    "    j_equal_k_ep = labels_equal.view(1, B, B, 1)  # [1, B, B, 1]\n",
    "    k_equal_l_ep = labels_equal.view(1, 1, B, B)  # [1, 1, B, B]\n",
    "    episode_match = i_equal_j_ep & j_equal_k_ep & k_equal_l_ep\n",
    "    \n",
    "    # uncomment this to keep quadruplets in the same episode\n",
    "    '''\n",
    "    within_eps_over_tau = torch.logical_and(torch.abs(lst[:, None, None, None] - lst[None, :, None, None]) <= eps,\n",
    "                                            torch.abs(lst[:,None,None,None]-lst[None,None,:,None]) <= eps,\n",
    "                                            torch.abs(lst[:, None, None,None] - lst[None, None,None,:]) >= tau)\n",
    "    '''\n",
    "    within_eps_over_tau = (torch.abs(lst[:, None, None, None] - lst[None, :, None, None]) <= eps) & (torch.abs(lst[:,None,None,None]-lst[None,None,:,None] <= eps)) & (torch.abs(lst[:, None, None,None] - lst[None, None,None,:]) >= tau)\n",
    "    #return within_eps_over_tau & episode_match & label_match & distinct_indices  # [B, B, B, B] \n",
    "    return within_eps_over_tau & label_match & distinct_indices\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7dbbc58",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Training function\n",
    "def train_siamese_epoch(vae, device, dataloader, optimizer, criterion,beta = 1., lbda = .1):\n",
    "    # Set train mode for both the encoder and the decoder\n",
    "    vae.train()\n",
    "    train_loss = 0.0\n",
    "    vae_loss_tot = 0.0\n",
    "    siam_loss_tot = 0.0\n",
    "    # Iterate the dataloader (we do not need the label values, this is unsupervised learning)\n",
    "    for batch, (x, actions, idx) in enumerate(dataloader,1): \n",
    "        # Move tensor to the proper device\n",
    "        x = x.to(device)\n",
    "        actions = actions.to(device)\n",
    "        idx = idx.to(device)\n",
    "        \n",
    "        #forward pass\n",
    "        x_hat, z = vae(x)\n",
    "        \n",
    "        \n",
    "        # Evaluate VAE loss\n",
    "        vae_loss = ((x - x_hat)**2).sum() + beta * vae.encoder.kl\n",
    "\n",
    "        # Evaluate triplet/quadruplet loss\n",
    "        siam_loss = criterion(embeddings = z, labels = actions, idx = idx)\n",
    "        \n",
    "        loss = vae_loss + lbda * siam_loss \n",
    "        \n",
    "        # Backward pass\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Print batch loss\n",
    "        #print('\\t partial train loss (single batch): %f' % (loss.item()))\n",
    "        \n",
    "        train_loss+=loss.item()\n",
    "        vae_loss_tot += vae_loss.item()\n",
    "        siam_loss_tot += siam_loss.item()\n",
    "    \n",
    "    print('Avg VAE/SIAM loss: ', vae_loss_tot / len(dataloader.dataset), siam_loss_tot / len(dataloader.dataset))\n",
    "        \n",
    "    return train_loss / len(dataloader.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b474a140",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Set the random seed for reproducible results\n",
    "torch.manual_seed(0)\n",
    "\n",
    "d = 7\n",
    "\n",
    "vae = VAE(z_dim = d,nc=4)\n",
    "lr = 1e-4\n",
    "\n",
    "optimizer = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=1e-7)\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(f'Selected device: {device}')\n",
    "\n",
    "vae.to(device)\n",
    "\n",
    "criterion = HardQuadrupletLoss(margin1=2., margin2=2.5,squared=True,hardest=False,epsilon=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59a7a1c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100\n",
    "\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "   train_loss = train_siamese_epoch(vae,device,train_loader,optimizer, criterion,beta =1.5,lbda=1500)\n",
    "   #val_loss = test_epoch(vae,device,test_loader, beta = 1.5)\n",
    "   #print('\\n EPOCH {}/{} \\t train loss {:.3f} \\t val loss {:.3f}'.format(epoch + 1, num_epochs,train_loss,val_loss))\n",
    "   plot_ae_outputs(vae.encoder,vae.decoder,n=10)\n",
    "   print('epoch, train loss: ', epoch, train_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79e3f161",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(vae, 'mario_svae.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d17b87d",
   "metadata": {},
   "source": [
    "# t-SNE for Pre-trained Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2398810",
   "metadata": {},
   "outputs": [],
   "source": [
    "acts = {0:'NOOP',1:'RIGHT',2:'RIGHT+A',3:'RIGHT+B',4:'RIGHT+A+B',5:'A',6:'LEFT'}\n",
    "encoded_samples = []\n",
    "for sample in tqdm(test_expert_dataset):\n",
    "    img = torch.FloatTensor(sample[0]).unsqueeze(0).to(device)\n",
    "    label = sample[1]\n",
    "    # Encode image\n",
    "    vae.eval()\n",
    "    with torch.no_grad():\n",
    "        encoded_img  = vae.encoder(img)\n",
    "    # Append to list\n",
    "    encoded_img = encoded_img.flatten().cpu().numpy()\n",
    "    encoded_sample = {f\"Enc. Variable {i}\": enc for i, enc in enumerate(encoded_img)}\n",
    "    encoded_sample['label'] = acts[label]\n",
    "    encoded_samples.append(encoded_sample)\n",
    "    \n",
    "encoded_samples = pd.DataFrame(encoded_samples)\n",
    "encoded_samples\n",
    "\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "import plotly.express as px\n",
    "\n",
    "px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1', color=encoded_samples.label.astype(str), opacity=0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "073aff33",
   "metadata": {},
   "outputs": [],
   "source": [
    "#beta = 1.5\n",
    "tsne = TSNE(n_components=2)\n",
    "tsne_results = tsne.fit_transform(encoded_samples.drop(['label'],axis=1))\n",
    "\n",
    "fig = px.scatter(tsne_results, x=0, y=1, color=encoded_samples.label.astype(str),labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20f7cce4",
   "metadata": {},
   "source": [
    "# Downstream Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbdcee11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset for downstream \n",
    "# RUN THIS FOR DOWNSTREAM TRAINING\n",
    "\n",
    "train_size = int(0.8 * len(expert_dataset))\n",
    "test_size = int(0.2 * len(expert_dataset))\n",
    "\n",
    "train_expert_dataset, test_expert_dataset = random_split(\n",
    "    expert_dataset, [train_size, test_size]\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "kwargs = {\"num_workers\": 8, \"pin_memory\": False}\n",
    "train_loader = th.utils.data.DataLoader(\n",
    "        dataset=train_expert_dataset, batch_size=32, shuffle=False, **kwargs\n",
    ")\n",
    "test_loader = th.utils.data.DataLoader(\n",
    "    dataset=test_expert_dataset, batch_size=32, shuffle=False, **kwargs\n",
    ")\n",
    "  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73029787",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_ds_epoch(net, \n",
    "          train_loader, \n",
    "          optimizer,\n",
    "          iso_coeff=10,\n",
    "          clst_coeff =0.,\n",
    "          sep_coeff = 0.,\n",
    "          rep_coeff = 0.,\n",
    "          validation = False,\n",
    "          resnetbc=False\n",
    "         ):\n",
    "    running_CE = 0.\n",
    "    running_iso = 0.\n",
    "    running_clst= 0.\n",
    "    running_sep = 0.\n",
    "    running_rep = 0.\n",
    "    \n",
    "    if not resnetbc:\n",
    "        I = torch.eye(net.isometry.weight.data.shape[0]).to(device)\n",
    "        max_dist = net.prototype_shape[1]\n",
    "    \n",
    "    criterion = nn.CrossEntropyLoss()   \n",
    "    \n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for (states, actions) in train_loader:\n",
    "        states = states.to(device)\n",
    "        actions = actions.to(device)\n",
    "        \n",
    "        if not validation:\n",
    "            optimizer.zero_grad()\n",
    "        \n",
    "        if not resnetbc:\n",
    "            logits, min_distances = net(states.float())\n",
    "        else:\n",
    "            logits = net(states.float())\n",
    "        \n",
    "        _, predicted = torch.max(logits.data, 1)\n",
    "        total += actions.size(0)\n",
    "        correct += (predicted == actions).sum().item()\n",
    "            \n",
    "\n",
    "        if not resnetbc:\n",
    "            #cluster cost\n",
    "            # it's the mean of the min distances between encoding and prototypes of same class\n",
    "            # where min is taken over batch dimension\n",
    "            # torch.t(model.prototype_action_identity[:,target]) is distance matrix with entries\n",
    "            # for prototypes of wrong class zeroed out\n",
    "            # torch.max((max_dist - min_distances) * prototypes_of_correct_class, dim=1\n",
    "            # gives you the min distances between encodings and correct protos over \n",
    "            # batch dim\n",
    "            prototypes_of_correct_class = torch.t(net.prototype_action_identity[:,actions]).to(device)\n",
    "            inverted_distances, _ = torch.max((max_dist - min_distances) * prototypes_of_correct_class, dim=1)\n",
    "            clst_cost = torch.mean(max_dist - inverted_distances)\n",
    "\n",
    "            #separation cost\n",
    "            prototypes_of_wrong_class = 1 - prototypes_of_correct_class\n",
    "            inverted_distances_to_nontarget_prototypes, _ = torch.max((max_dist - min_distances) * prototypes_of_wrong_class, dim=1)\n",
    "            sep_cost = torch.mean(max_dist - inverted_distances_to_nontarget_prototypes)\n",
    "\n",
    "            #avg clustering cost\n",
    "            #avg_sep = torch.sum(min_distances * prototypes_of_wrong_class,dim=1) / torch.sum(prototypes_of_wrong_class,dim=1)\n",
    "            #avg_sep = torch.mean(avg_sep)\n",
    "\n",
    "            # Rep term\n",
    "            rep = torch.sum(torch.min(min_distances, dim=0)[0])\n",
    "            \n",
    "            CE = criterion(logits, actions.long())\n",
    "\n",
    "            A = net.isometry.weight.data\n",
    "            iso_penalty = torch.linalg.matrix_norm(torch.mm(A.T, A) - I, ord = 'fro')**2 # 2 gives operator norm\n",
    "\n",
    "            loss = CE + iso_coeff * iso_penalty + clst_coeff * clst_cost - sep_coeff * sep_cost + rep_coeff * rep \n",
    "\n",
    "            running_CE += CE.item()\n",
    "            running_iso += iso_penalty.item()\n",
    "            running_clst += clst_cost.item()\n",
    "            running_sep += sep_cost.item()\n",
    "            running_rep += rep.item()\n",
    "        else:\n",
    "            loss = criterion(logits, actions.long())\n",
    "            running_CE += loss.item()\n",
    "\n",
    "        if not validation:\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "     \n",
    "    if not validation:\n",
    "        print('Train Acc: %d %%' % (100 * correct / total))\n",
    "    else:\n",
    "        print('Val Acc: %d %%' % (100 * correct / total))\n",
    "    return running_CE, running_iso, running_clst, running_sep, running_rep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6660f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_iso(net, \n",
    "          train_loader, \n",
    "          val_loader,\n",
    "          iso_coeff=10,\n",
    "          clst_coeff =0.,\n",
    "          sep_coeff = 0.,\n",
    "          rep_coeff = 0.,\n",
    "          push_interval=None,\n",
    "          push_epochs=100,\n",
    "          n_epochs=100,\n",
    "          main_lr = 1e-6,\n",
    "          pruned=False,\n",
    "          lr_sched=False,\n",
    "          resnetbc=False\n",
    "         ):\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    net_optimizer = optim.Adam(net.parameters(), lr=main_lr)\n",
    "    if push_interval is not None:\n",
    "        ll_optimizer = optim.Adam(net.last_layer.parameters(), lr=1e-5)\n",
    "    else:\n",
    "        ll_optimizer = None\n",
    "    if lr_sched:\n",
    "        if pruned:\n",
    "            sched = StepLR(ll_optimizer, 10, .8)\n",
    "        else:\n",
    "            sched = StepLR(net_optimizer, 10, .8)\n",
    "    \n",
    "    \n",
    "    early_stopping = EarlyStopping(patience=50, verbose=True)\n",
    "    \n",
    "    #wandb.init(project='mario')\n",
    "    #wandb.watch(net)\n",
    "    \n",
    "    for i in range(n_epochs):\n",
    "        net.train()\n",
    "        if pruned:\n",
    "            (ce,iso,clst,sep,rep) = train_ds_epoch(net, \n",
    "                                                   train_loader, \n",
    "                                                   ll_optimizer,\n",
    "                                                   iso_coeff,\n",
    "                                                   clst_coeff,\n",
    "                                                   sep_coeff,\n",
    "                                                   rep_coeff,\n",
    "                                                   validation=False,\n",
    "                                                   resnetbc=resnetbc)\n",
    "        else:\n",
    "            (ce,iso,clst,sep,rep) = train_ds_epoch(net, \n",
    "                                                   train_loader, \n",
    "                                                   net_optimizer,\n",
    "                                                   iso_coeff,\n",
    "                                                   clst_coeff,\n",
    "                                                   sep_coeff,\n",
    "                                                   rep_coeff,\n",
    "                                                   validation=False,\n",
    "                                                   resnetbc=resnetbc)         \n",
    "        #wandb.log({'CE':ce,'Iso':iso,'Clst':clst,'Sep':sep,'Rep':rep})\n",
    "        print('\\nEpoch', i, {'CE':ce,'Iso':iso,'Clst':clst,'Sep':sep,'Rep':rep})\n",
    "        \n",
    "        if push_interval is not None:\n",
    "            if i > 0 and i % push_interval == 0:\n",
    "                # push\n",
    "                prots = save_proto(net, train_loader, project=True, device=device)\n",
    "                \n",
    "                #freeze isometry\n",
    "                for param in net.isometry.parameters():\n",
    "                    param.requires_grad = False\n",
    "                #freeze protos\n",
    "                net.prototype_vectors.requires_grad = False\n",
    "                \n",
    "                for j in range(push_epochs):\n",
    "                    (ce,iso,clst,sep,rep) = train_ds_epoch(net, \n",
    "                                                           train_loader, \n",
    "                                                           ll_optimizer,\n",
    "                                                           iso_coeff,\n",
    "                                                           clst_coeff,\n",
    "                                                           sep_coeff,\n",
    "                                                           rep_coeff,\n",
    "                                                           validation=False,\n",
    "                                                           resnetbc=resnetbc)\n",
    "                    #wandb.log({'ll CE':ce,'ll Iso':iso,'ll Clst':clst,'ll Sep':sep,'ll Rep':rep})\n",
    "                    \n",
    "                # train isometry\n",
    "                for param in net.isometry.parameters():\n",
    "                    param.requires_grad = True\n",
    "                \n",
    "                # train protos\n",
    "                net.prototype_vectors.requires_grad = True\n",
    "                    \n",
    "        # validation step\n",
    "        net.eval()\n",
    "        (ce,iso,clst,sep,rep) = train_ds_epoch(net, \n",
    "                                               val_loader, \n",
    "                                               net_optimizer,\n",
    "                                               iso_coeff,\n",
    "                                               clst_coeff,\n",
    "                                               sep_coeff,\n",
    "                                               rep_coeff,\n",
    "                                               validation=True,\n",
    "                                               resnetbc=resnetbc)\n",
    "        early_stopping(ce, net)\n",
    "        #wandb.log({'val CE':ce,'val Iso':iso,'val Clst':clst,'val Sep':sep,'val Rep':rep})\n",
    "        \n",
    "        if early_stopping.early_stop:\n",
    "            print('early Stopping')\n",
    "            break\n",
    "    net.load_state_dict(torch.load('checkpoint.pt'))\n",
    "    \n",
    "    \n",
    "    # final push\n",
    "    print('Doing final push!')\n",
    "    if not resnetbc:\n",
    "        prots = save_proto(net, train_loader, project=True, device=device)\n",
    "                \n",
    "        #freeze isometry\n",
    "        for param in net.isometry.parameters():\n",
    "            param.requires_grad = False\n",
    "        #freeze protos\n",
    "        net.prototype_vectors.requires_grad = False\n",
    "                \n",
    "        for j in range(push_epochs):\n",
    "            (ce,iso,clst,sep,rep) = train_ds_epoch(net, \n",
    "                                                      train_loader, \n",
    "                                                      ll_optimizer,\n",
    "                                                      iso_coeff,\n",
    "                                                      clst_coeff,\n",
    "                                                      sep_coeff,\n",
    "                                                      rep_coeff,\n",
    "                                                      validation=False,\n",
    "                                                      resnetbc=resnetbc)\n",
    "            #wandb.log({'ll CE':ce,'ll Iso':iso,'ll Clst':clst,'ll Sep':sep,'ll Rep':rep})\n",
    "\n",
    "    if lr_sched:\n",
    "        sched.step()\n",
    "    \n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c56d5d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(net, device, test_loader):\n",
    "    \"\"\"Return the action-prediction accuracy (in %) of `net` on `test_loader`.\n",
    "\n",
    "    The first output of `net` is treated as per-action logits; the predicted\n",
    "    action is their arg-max and is compared with the ground-truth actions.\n",
    "\n",
    "    Args:\n",
    "        net: model whose forward returns `(logits, ...)`.\n",
    "        device: device each batch is moved to.\n",
    "        test_loader: iterable of `(states, actions)` batches.\n",
    "\n",
    "    Returns:\n",
    "        Accuracy as a float in [0, 100].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `test_loader` yields no samples.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    # Evaluation only: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        for states, actions in test_loader:\n",
    "            states = states.to(device)\n",
    "            actions = actions.to(device)\n",
    "\n",
    "            logits, _ = net(states)\n",
    "            predicted = logits.argmax(dim=1)\n",
    "\n",
    "            total += actions.size(0)\n",
    "            correct += (predicted == actions).sum().item()\n",
    "\n",
    "    if total == 0:\n",
    "        raise ValueError('test_loader yielded no samples')\n",
    "    return 100 * (correct / total)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce14f332",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flip_fidelity(net,\n",
    "                  world=1,\n",
    "                  stage=1,\n",
    "                  action_type='simple',\n",
    "                  saved_path='trained_models',\n",
    "                  output_path='output',\n",
    "                  n_flip=int(10),\n",
    "                  num_interactions=int(10)):\n",
    "    \"\"\"Measure how closely `net` imitates a pre-trained PPO expert on a Mario level.\n",
    "\n",
    "    The expert (weights at `{saved_path}/ppo_super_mario_bros_{world}_{stage}`)\n",
    "    plays the level, acting greedily 99.5% of the time and otherwise sampling\n",
    "    from its policy. At every step `net` is queried on the same frame with the\n",
    "    top 15 rows cropped off (resized back to 84x84), and its arg-max action is\n",
    "    compared with the expert's action.\n",
    "\n",
    "    A flip is a step whose expert action differs from the expert's previous\n",
    "    action within the same episode. A flip counts as correct when `net`\n",
    "    matched the expert both at that step and at the step before it.\n",
    "\n",
    "    Args:\n",
    "        net: agent to evaluate; its forward returns logits or `(logits, ...)`.\n",
    "        world, stage: level to play.\n",
    "        action_type: 'right', 'simple', or anything else for complex movement.\n",
    "        saved_path: directory containing the expert's weights.\n",
    "        output_path: forwarded to `create_train_env`.\n",
    "        n_flip: stop once this many flip steps have been seen.\n",
    "        num_interactions: stop after this many environment steps.\n",
    "\n",
    "    Returns:\n",
    "        (total_fidelity, flip_fidelity): percent agreement over all steps and\n",
    "        over flip steps; each is nan if no such step was observed.\n",
    "    \"\"\"\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(123)\n",
    "    else:\n",
    "        torch.manual_seed(123)\n",
    "    # Dedicated RNG for the expert's exploratory actions: results are reproducible\n",
    "    # without touching the global `random` / numpy RNG state.\n",
    "    rng = random.Random(123)\n",
    "\n",
    "    if action_type == \"right\":\n",
    "        actions = RIGHT_ONLY\n",
    "    elif action_type == \"simple\":\n",
    "        actions = SIMPLE_MOVEMENT\n",
    "    else:\n",
    "        actions = COMPLEX_MOVEMENT\n",
    "    n_actions = len(actions)\n",
    "\n",
    "    env = create_train_env(world, stage, actions, output_path=output_path)\n",
    "    model = PPO(env.observation_space.shape[0], n_actions)\n",
    "    weights_path = \"{}/ppo_super_mario_bros_{}_{}\".format(saved_path, world, stage)\n",
    "    if torch.cuda.is_available():\n",
    "        model.load_state_dict(torch.load(weights_path))\n",
    "        model.cuda()\n",
    "    else:\n",
    "        model.load_state_dict(torch.load(weights_path,\n",
    "                                         map_location=lambda storage, loc: storage))\n",
    "    model.eval()\n",
    "\n",
    "    def _logits(out):\n",
    "        # Models here return either bare logits or a `(logits, ...)` tuple.\n",
    "        return out[0] if isinstance(out, (tuple, list)) else out\n",
    "\n",
    "    step_number = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    flip_correct = 0\n",
    "    flip_total = 0\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            while flip_total < n_flip and step_number < num_interactions:\n",
    "                state = torch.from_numpy(env.reset())\n",
    "                # Reset per episode so no flip is counted across an episode boundary.\n",
    "                prev_action = None\n",
    "                prev_agent_action = None\n",
    "                while flip_total < n_flip and step_number < num_interactions:\n",
    "                    if torch.cuda.is_available():\n",
    "                        state = state.cuda()\n",
    "\n",
    "                    # expert action: greedy most of the time, otherwise sampled\n",
    "                    policy = F.softmax(_logits(model(state)), dim=1)\n",
    "                    if rng.random() < .995:\n",
    "                        action = torch.argmax(policy).item()\n",
    "                    else:\n",
    "                        probs = policy.cpu().numpy()[0]\n",
    "                        action = rng.choices(range(n_actions), weights=probs)[0]\n",
    "\n",
    "                    # agent action on the cropped frame (assumes 84x84 frames)\n",
    "                    agent_state = torchvision.transforms.functional.crop(state, 15, 0, 84 - 15, 84)\n",
    "                    agent_state = torchvision.transforms.functional.resize(agent_state, (84, 84))\n",
    "                    agent_policy = F.softmax(_logits(net(agent_state)), dim=1)\n",
    "                    agent_action = torch.argmax(agent_policy).item()\n",
    "\n",
    "                    if action == agent_action:\n",
    "                        correct += 1\n",
    "                    total += 1\n",
    "\n",
    "                    if prev_action is not None and action != prev_action:\n",
    "                        flip_total += 1\n",
    "                        if agent_action == action and prev_agent_action == prev_action:\n",
    "                            flip_correct += 1\n",
    "                    prev_action = action\n",
    "                    prev_agent_action = agent_action\n",
    "\n",
    "                    state, reward, done, info = env.step(action)\n",
    "                    state = torch.from_numpy(state)\n",
    "                    env.render()\n",
    "\n",
    "                    step_number += 1\n",
    "\n",
    "                    if info[\"flag_get\"]:\n",
    "                        print(\"World {} stage {} completed\".format(world, stage))\n",
    "                    # A finished episode must be reset before stepping again.\n",
    "                    if done or info[\"flag_get\"]:\n",
    "                        break\n",
    "    finally:\n",
    "        env.close()\n",
    "\n",
    "    total_fidelity = (correct / total) * 100. if total else float('nan')\n",
    "    flip_score = (flip_correct / flip_total) * 100. if flip_total else float('nan')\n",
    "    print('got', flip_total, ' flip points')\n",
    "\n",
    "    return total_fidelity, flip_score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc6f48e1",
   "metadata": {},
   "source": [
    "# Train ResNet-BC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc52524e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Black-box baseline (ResNet-BC): a ResNet-18 whose stem takes 4 input channels\n",
    "# and whose head outputs one logit per each of the 7 actions.\n",
    "bbox = models.resnet18(pretrained=False)\n",
    "bbox.conv1 = nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7,\n",
    "                       stride=2, padding=3, bias=False)\n",
    "bbox.fc = nn.Linear(in_features=512, out_features=7)\n",
    "bbox = bbox.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3ddd810",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the ResNet-BC baseline with the same training loop;\n",
    "# resnetbc=True makes train_iso skip the final prototype push.\n",
    "bbox = train_iso(bbox,\n",
    "                 train_loader,\n",
    "                 test_loader,\n",
    "                 iso_coeff=1e-8,      # alt: .01\n",
    "                 clst_coeff=4e-4,\n",
    "                 sep_coeff=1e-4,\n",
    "                 rep_coeff=1e-5,\n",
    "                 n_epochs=2,          # alt: 1000\n",
    "                 main_lr=5e-6,\n",
    "                 push_interval=None,\n",
    "                 push_epochs=0,       # alt: 20\n",
    "                 resnetbc=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bb7f810",
   "metadata": {},
   "source": [
    "# Train ProtoX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baacd1f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the prototype network: 350 prototypes of size 512*3*3 and 7 actions.\n",
    "net = ProtoIsoResNet(prototype_shape=(350, 512*3*3), num_actions=7, tl=False, sim_method=0)  # MY sim method\n",
    "# Initialise the isometry as the identity map.\n",
    "with torch.no_grad():\n",
    "    net.isometry.weight.copy_(torch.eye(512*3*3))\n",
    "\n",
    "# Load the encoder from pre-training (mario_svae.pt). With strict=False, keys that are\n",
    "# missing on either side are ignored, so only the matching encoder weights are copied.\n",
    "# NOTE: torch.load unpickles the whole saved object -- only load trusted files.\n",
    "encoder_state = torch.load('mario_svae.pt', map_location='cpu').encoder.state_dict()\n",
    "net.load_state_dict(encoder_state, strict=False)\n",
    "\n",
    "# Freeze the pre-trained encoder. (The loop used to set the misspelled attribute\n",
    "# `requries_grad`, which silently left the encoder trainable.)\n",
    "for param in net.convunit.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "net = net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "996cf493",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train ProtoX (train_iso also runs early stopping and the final prototype push).\n",
    "net = train_iso(net,\n",
    "                train_loader,\n",
    "                test_loader,\n",
    "                iso_coeff=1e-8,      # alt: .01\n",
    "                clst_coeff=4e-4,\n",
    "                sep_coeff=1e-4,\n",
    "                rep_coeff=1e-5,\n",
    "                n_epochs=2,          # alt: 1000\n",
    "                main_lr=5e-6,\n",
    "                push_interval=25,\n",
    "                push_epochs=0)       # alt: 20\n",
    "# Save the whole trained module; it is reloaded further below as `Af`.\n",
    "torch.save(net, 'mario_ds.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c971dd05",
   "metadata": {},
   "source": [
    "# ProtoX t-SNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "313171e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.manifold import TSNE\n",
    "import plotly.express as px\n",
    "\n",
    "# Encode each test sample with the ProtoX network (first output of push_forward)\n",
    "# and build one DataFrame row per sample: the flattened encoding plus the action\n",
    "# name in the 'label' column.\n",
    "net.eval()\n",
    "encoded_samples = []\n",
    "with torch.no_grad():\n",
    "    for sample in tqdm(test_expert_dataset):\n",
    "        img = torch.FloatTensor(sample[0]).unsqueeze(0).to(device)\n",
    "        label = sample[1]\n",
    "        encoded_img, _ = net.push_forward(img)\n",
    "        encoded_img = encoded_img.flatten().cpu().numpy()\n",
    "        encoded_sample = {f\"Enc. Variable {i}\": enc for i, enc in enumerate(encoded_img)}\n",
    "        encoded_sample['label'] = acts[label]\n",
    "        encoded_samples.append(encoded_sample)\n",
    "\n",
    "encoded_samples = pd.DataFrame(encoded_samples)\n",
    "\n",
    "# Quick look at the first two raw encoding dimensions (t-SNE is in the next cell).\n",
    "px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1',\n",
    "           color=encoded_samples.label.astype(str), opacity=0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82c44499",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t-SNE is stochastic: fix random_state so the embedding is reproducible.\n",
    "tsne = TSNE(n_components=2, random_state=0)\n",
    "tsne_results = tsne.fit_transform(encoded_samples.drop(columns=['label']))\n",
    "\n",
    "fig = px.scatter(tsne_results, x=0, y=1, color=encoded_samples.label.astype(str),\n",
    "                 labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42d25581",
   "metadata": {},
   "source": [
    "# Evaluate fidelity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd2ea058",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Action-prediction accuracy (%) of the ProtoX network on the evaluation loader\n",
    "test(net, device, test_loader=eval_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd0c1c9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fidelity (%) of the ResNet-BC baseline `bbox` to the PPO expert over all steps (tf)\n",
    "# and over action-flip steps (ff). NOTE(review): pass `net` instead to evaluate ProtoX.\n",
    "tf, ff = flip_fidelity(bbox,\n",
    "                       output_path=None,\n",
    "                       n_flip=10000,\n",
    "                       num_interactions=30000)\n",
    "print(tf, ff)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70f1c8b2",
   "metadata": {},
   "source": [
    "# Make Pruned Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf96ac45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prune: `prot_equivs` presumably groups equivalent prototypes of `net`, and\n",
    "# `merge_weights` builds the merged network from them (see the project code).\n",
    "eq = prot_equivs(net)\n",
    "unique_ix = [key for key in eq]  # same as list(eq)\n",
    "pruned_net = merge_weights(net, device, eq, unique_ix, action_filter=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34ab90e1",
   "metadata": {},
   "source": [
    "# Visualization Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd10c322",
   "metadata": {},
   "outputs": [],
   "source": [
    "def top_k_prots(net, oidx, action_filter=False, k=3, score_sort=False, sim_method=0, act=None):\n",
    "    \"\"\"Return the k prototypes most relevant to observation `oidx`.\n",
    "\n",
    "    Args:\n",
    "        net: prototype network whose forward returns `(logits, dist)`, with one\n",
    "            distance per prototype in `dist[0]`.\n",
    "        oidx: index into the global `expert_observations` / `expert_actions`.\n",
    "        action_filter: accepted for call compatibility; unused.\n",
    "        k: number of prototypes to return.\n",
    "        score_sort: False ranks by similarity; True ranks by similarity times\n",
    "            the last-layer weight from each prototype to `action`.\n",
    "        sim_method: 1 -> log((1 + d) / (d + 1e-10)); otherwise exp(-0.05 * d).\n",
    "        act: action to explain; defaults to the expert action at `oidx`.\n",
    "\n",
    "    Returns:\n",
    "        (idx, sims, fcs): prototype indices, their similarities (0-d tensors)\n",
    "        and their last-layer weights for the action.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        _, dist = net(torch.FloatTensor(expert_observations[oidx]).unsqueeze(0).to(device))\n",
    "\n",
    "        if sim_method == 1:\n",
    "            sim = torch.log((1 + dist) / (dist + 1e-10))[0]\n",
    "        else:\n",
    "            sim = torch.exp(dist * -.05)[0]\n",
    "\n",
    "        # fcl[p, a]: last-layer weight from prototype p to action a. Layers without\n",
    "        # a `weight` store `log_weight` instead.\n",
    "        weight = getattr(net.last_layer, 'weight', None)\n",
    "        if weight is not None:\n",
    "            fcl = weight.data.T\n",
    "        else:\n",
    "            fcl = torch.exp(net.last_layer.log_weight.data.T)\n",
    "\n",
    "        action = int(expert_actions[oidx].item() if act is None else act)\n",
    "\n",
    "        # Score through fcl so score_sort also works for log-weight last layers.\n",
    "        ranking = sim * fcl[:, action] if score_sort else sim\n",
    "        idx = torch.topk(ranking, k=k).indices.tolist()\n",
    "\n",
    "        sims = [sim[j] for j in idx]\n",
    "        fcs = [fcl[j][action].item() for j in idx]\n",
    "\n",
    "    return idx, sims, fcs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6afc81c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prot_rep(net, ix, push_loader, project=False, sim_method=0, max_samples=5000):\n",
    "    \"\"\"Find the sample in `push_loader` whose encoding is most similar to prototype `ix`.\n",
    "\n",
    "    Batches are scanned until more than `max_samples` samples have been seen.\n",
    "    The winner's position is a running sample count over the loader and is used\n",
    "    to index the global `color_observations` / `expert_actions` arrays.\n",
    "    NOTE(review): that is only valid if `push_loader` iterates those arrays in\n",
    "    order (no shuffling) -- confirm for the loader passed in.\n",
    "\n",
    "    Args:\n",
    "        net: prototype network with `push_forward(x) -> (enc, dist)`.\n",
    "        ix: prototype index.\n",
    "        push_loader: iterable of `(x, actions)` batches.\n",
    "        project: accepted for call compatibility; unused.\n",
    "        sim_method: 1 -> log((1 + d) / (d + 1e-10)); otherwise exp(-0.05 * d).\n",
    "        max_samples: approximate cap on the number of samples scanned.\n",
    "\n",
    "    Returns:\n",
    "        (colour observation of the best sample, its encoding).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no candidate sample was found.\n",
    "    \"\"\"\n",
    "    best_prot = None\n",
    "    best_sim = -float('inf')\n",
    "    best_idx = -1\n",
    "    seen = 0\n",
    "    with torch.no_grad():\n",
    "        for x, _ in push_loader:\n",
    "            batch_size = x.shape[0]\n",
    "            x = x.to(device)\n",
    "            enc, dist = net.push_forward(x)\n",
    "\n",
    "            if sim_method == 1:\n",
    "                sim = torch.log((1 + dist) / (dist + 1e-10))\n",
    "            else:\n",
    "                sim = torch.exp(-.05 * dist)\n",
    "            sim = sim[:, ix]\n",
    "\n",
    "            batch_best = torch.argmax(sim).item()\n",
    "            batch_best_sim = sim[batch_best].item()\n",
    "            if batch_best_sim > best_sim:\n",
    "                best_prot = enc[batch_best]\n",
    "                best_idx = seen + batch_best\n",
    "                best_sim = batch_best_sim\n",
    "\n",
    "            seen += batch_size\n",
    "            if seen > max_samples:\n",
    "                break\n",
    "\n",
    "    if best_idx < 0:\n",
    "        raise ValueError('no candidate sample found for prototype {}'.format(ix))\n",
    "    # Report and return the overall best sample, not the last batch's arg-max.\n",
    "    print('prototype action is ', acts[expert_actions[best_idx]])\n",
    "    return color_observations[best_idx], best_prot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb38be27",
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain_color(net,\n",
    "                  oidx,\n",
    "                  push_loader,\n",
    "                  action_filter=False,\n",
    "                  k=3,\n",
    "                  fname=None,\n",
    "                  score_sort=False,\n",
    "                  sim_method=0):\n",
    "    \"\"\"Plot observation `oidx` next to the frames representing its top-k prototypes.\n",
    "\n",
    "    Panel 0 shows the colour frame of `oidx` titled with its expert action.\n",
    "    Panels 1..k show, for each prototype chosen by `top_k_prots`, the frame\n",
    "    returned by `prot_rep`. With k > 1, each panel is titled with its\n",
    "    similarity, its last-layer weight for the action and their product (Points).\n",
    "\n",
    "    Args:\n",
    "        net: prototype network (switched to eval mode).\n",
    "        oidx: index of the observation to explain.\n",
    "        push_loader: loader scanned by `prot_rep` for prototype frames.\n",
    "        action_filter, score_sort: forwarded to `top_k_prots`.\n",
    "        k: number of prototypes to show.\n",
    "        fname: if not None, the figure is also saved to this path.\n",
    "        sim_method: similarity used by both `top_k_prots` and `prot_rep`.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    acts = {0: 'NOOP', 1: 'RIGHT', 2: 'RIGHT+A', 3: 'RIGHT+B', 4: 'RIGHT+A+B', 5: 'A', 6: 'LEFT'}\n",
    "\n",
    "    fig, axes = plt.subplots(1, 1 + k, figsize=(30, 30))\n",
    "    action = expert_actions[oidx].item()\n",
    "\n",
    "    # input panel\n",
    "    axes[0].imshow(color_observations[oidx].astype(int))\n",
    "    if k > 1:\n",
    "        axes[0].set_title('Input w/ action: ' + acts[action], size=25)\n",
    "    else:\n",
    "        axes[0].set_title('Input w/ action: ' + acts[action] + '\\nat t=' + str(oidx), size=40)\n",
    "\n",
    "    # prototype panels\n",
    "    top_k_ix, sims, fcs = top_k_prots(net, oidx,\n",
    "                                      action_filter=action_filter,\n",
    "                                      k=k,\n",
    "                                      sim_method=sim_method,\n",
    "                                      score_sort=score_sort)\n",
    "    for j in range(1, 1 + k):\n",
    "        im, _ = prot_rep(net, top_k_ix[j - 1], push_loader, sim_method=sim_method)\n",
    "        axes[j].imshow(im.astype(np.uint8))\n",
    "        if k > 1:\n",
    "            title = ('Sim score: {:.2f} \\n'.format(sims[j - 1]) + acts[action]\n",
    "                     + ' score: {:.2f}\\nPoints: {:.2f}'.format(fcs[j - 1], sims[j - 1] * fcs[j - 1]))\n",
    "            axes[j].set_title(title, size=25)\n",
    "        else:\n",
    "            axes[j].set_title('Most Similar Prototype', size=40)\n",
    "\n",
    "    # hide ticks and frames on every panel\n",
    "    for ax in axes:\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_yticklabels([])\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.axis('off')\n",
    "\n",
    "    if fname is not None:\n",
    "        fig.savefig(fname, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc8b2d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "def color_ball(net, oidx, min_sim, n_samples=1000, alpha=.2, Af=True):\n",
    "    \"\"\"Overlay all states whose encoding is similar to that of state `oidx`.\n",
    "\n",
    "    Left panel: the colour frame of `oidx`. Right panel: every frame among the\n",
    "    first `n_samples` observations whose similarity exp(-0.05 * ||e1 - e2||) to\n",
    "    `oidx` exceeds `min_sim`, drawn on top of each other with opacity `alpha`.\n",
    "\n",
    "    Args:\n",
    "        net: model used for encoding.\n",
    "        oidx: index of the query observation.\n",
    "        min_sim: similarity threshold.\n",
    "        n_samples: number of leading observations to compare against (capped at\n",
    "            the number available).\n",
    "        alpha: opacity of each overlaid frame.\n",
    "        Af: True encodes with `net.push_forward` (first output); False with\n",
    "            `net.encoder`.\n",
    "    \"\"\"\n",
    "    def encode(i):\n",
    "        x = torch.FloatTensor(expert_observations[i]).unsqueeze(0).to(device)\n",
    "        if Af:\n",
    "            enc, _ = net.push_forward(x)\n",
    "            return enc\n",
    "        return net.encoder(x)\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2)\n",
    "    axes[0].imshow(color_observations[oidx].astype(int))\n",
    "\n",
    "    best_list = []\n",
    "    with torch.no_grad():\n",
    "        enc1 = encode(oidx)\n",
    "        for i in range(min(n_samples, len(expert_observations))):\n",
    "            d = torch.norm(enc1 - encode(i))\n",
    "            if torch.exp(-.05 * d) > min_sim:\n",
    "                best_list.append(i)\n",
    "\n",
    "    print('Found ', len(best_list))\n",
    "    for ix in best_list:\n",
    "        axes[1].imshow(color_observations[ix].astype(int), alpha=alpha)\n",
    "    for ax in axes:\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_yticklabels([])\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.axis('off')\n",
    "    axes[0].set_title('Input')\n",
    "    # Selection is threshold-based, not a fixed top-N.\n",
    "    axes[1].set_title('Most Similar States\\n(similarity > {})'.format(min_sim))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd078430",
   "metadata": {},
   "source": [
    "# Compare similar states between Af and f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2947a988",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Af: the trained ProtoX network saved above; f: the pre-trained model whose encoder\n",
    "# initialised it. map_location lets checkpoints saved on another device load on `device`.\n",
    "# NOTE: torch.load unpickles arbitrary objects -- only load trusted files.\n",
    "Af = torch.load('mario_ds.pt', map_location=device)\n",
    "f = torch.load('mario_svae.pt', map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92eff702",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neighbourhood of state 137 under the pre-trained encoder (Af=False -> f.encoder)\n",
    "color_ball(f, oidx=137, min_sim=0.95, n_samples=1000, alpha=0.4, Af=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edae1fbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neighbourhood of state 137 under the ProtoX network `Af` loaded above\n",
    "# (default Af=True -> push_forward features), for comparison with f.\n",
    "color_ball(Af, oidx=137, min_sim=0.45, n_samples=1000, alpha=0.4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b43cff0",
   "metadata": {},
   "source": [
    "# Explain step 32 with the pruned network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e125a19d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explain observation 32 with the pruned network; score_sort=True ranks prototypes\n",
    "# by similarity times their last-layer weight for the expert action.\n",
    "explain_color(pruned_net.to(device), 32, train_loader,\n",
    "              k=3, fname=None, action_filter=False,\n",
    "              score_sort=True, sim_method=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aba0189e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef455a95",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
