{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "51ac1b95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Python 3.7\n",
    "# Install dependencies with requirements.txt (included in the zip).\n",
    "\n",
    "# To run experiments on Pong or any other ALE game,\n",
    "# just change env_id,\n",
    "# e.g. for Pong, env_id = \"PongNoFrameskip-v4\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb77c483",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import *\n",
    "from train_utils import *\n",
    "from pytorchtools import EarlyStopping\n",
    "\n",
    "import torch.optim as optim\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "from dagger import *\n",
    "#from training import *\n",
    "import cv2\n",
    "from visualization import *\n",
    "import pandas as pd\n",
    "\n",
    "# import EarlyStopping\n",
    "from pytorchtools import EarlyStopping\n",
    "\n",
    "import sys, os\n",
    "from copy import deepcopy\n",
    "\n",
    "# Pick the compute device. Note the call: `torch.cuda.is_available` without\n",
    "# parentheses is a function object and therefore always truthy.\n",
    "import torch\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda:0\"\n",
    "    print('Using GPU')\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "    print('using CPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f96514d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited local modules (models, train_utils, dagger, ...) automatically\n",
    "# before each cell runs, so edits take effect without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f05f32a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title get_data\n",
    "# get dataset of state-action pairs from expert\n",
    "from skimage.transform import resize\n",
    "from PIL import Image\n",
    "def gen_color_data(num_interactions=int(6e4), env_id=\"PongNoFrameskip-v4\", preprocess=False):\n",
    "    \"\"\"Roll out the PPO expert for `env_id` and record what it sees and does.\n",
    "\n",
    "    Args:\n",
    "        num_interactions: number of environment steps to record.\n",
    "        env_id: environment id passed to `get_env_and_model`.\n",
    "        preprocess: if True, crop each observation with `crop_pong` and resize\n",
    "            it to (84, 84, 4) before storing it.\n",
    "\n",
    "    Returns:\n",
    "        expert_observations: one observation per step, with its last axis moved\n",
    "            to position 1 (channel-first).\n",
    "        color_observations: (num_interactions, 84, 84, 3) RGB renders,\n",
    "            bicubic-resized to 84x84.\n",
    "        expert_actions: the expert's deterministic action at every step.\n",
    "        episode_schedule: (num_interactions, 2) rows of [episode number, step].\n",
    "    \"\"\"\n",
    "    env, ppo_expert = get_env_and_model(env_id)\n",
    "\n",
    "    print('state shape: ', env.observation_space.shape)\n",
    "    print('action shape: ', env.action_space.shape)\n",
    "\n",
    "    # Pre-allocate the per-step buffers.\n",
    "    if isinstance(env.action_space, gym.spaces.Box):\n",
    "        expert_observations = np.empty((num_interactions,) + env.observation_space.shape)\n",
    "        expert_actions = np.empty((num_interactions,) + (env.action_space.shape[0],))\n",
    "    else:\n",
    "        # NOTE(review): assumes 4 stacked 84x84 frames -- confirm this matches the\n",
    "        # observation space returned by get_env_and_model.\n",
    "        expert_observations = np.empty((num_interactions, 4, 84, 84))\n",
    "        expert_actions = np.empty((num_interactions,) + env.action_space.shape)\n",
    "\n",
    "    episode_schedule = np.empty((num_interactions, 2))\n",
    "    color_observations = np.empty((num_interactions, 84, 84, 3))\n",
    "\n",
    "    obs = env.reset()\n",
    "    ep_number = 0\n",
    "    try:\n",
    "        for i in tqdm(range(num_interactions)):\n",
    "            action, _ = ppo_expert.predict(obs, deterministic=True)\n",
    "            if preprocess:\n",
    "                obs = crop_pong(obs)[0]\n",
    "                obs = np.expand_dims(resize(obs, (84, 84, 4)), 0)\n",
    "\n",
    "            # Presumably (batch, H, W, C); stored with the channel axis first.\n",
    "            expert_observations[i] = obs.transpose(0, 3, 1, 2)\n",
    "\n",
    "            # Also record a downsized RGB render of the same step.\n",
    "            frame = env.render(mode='rgb_array')\n",
    "            im = Image.fromarray(frame).resize(size=(84, 84), resample=Image.BICUBIC, reducing_gap=3.0)\n",
    "            color_observations[i] = np.array(im)\n",
    "\n",
    "            expert_actions[i] = action\n",
    "            episode_schedule[i] = np.array([ep_number, i])\n",
    "\n",
    "            obs, reward, done, info = env.step(action)\n",
    "            if done:\n",
    "                ep_number = ep_number + 1\n",
    "                obs = env.reset()\n",
    "    finally:\n",
    "        # Release the environment even if the rollout is interrupted.\n",
    "        env.close()\n",
    "\n",
    "    return expert_observations, color_observations, expert_actions, episode_schedule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16cd7049",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dataset_with_indices(cls):\n",
    "    \"\"\"Return a subclass of `cls` whose items also carry their index.\n",
    "\n",
    "    `cls.__getitem__` must return a `(data, target)` pair; the subclass returns\n",
    "    `(data, target, index)`. The subclass keeps the name of `cls`.\n",
    "    \"\"\"\n",
    "    def _indexed_getitem(self, index):\n",
    "        sample, label = cls.__getitem__(self, index)\n",
    "        return sample, label, index\n",
    "\n",
    "    namespace = {'__getitem__': _indexed_getitem}\n",
    "    return type(cls.__name__, (cls,), namespace)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e7a918e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# True: roll out the expert again and save the result to .npz (next cell);\n",
    "# False: load previously saved arrays from disk.\n",
    "new_data = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "599f552f",
   "metadata": {},
   "outputs": [],
   "source": [
    "if new_data:\n",
    "    # NOTE(review): this branch writes 'seaquest_holdout.npz' (loaded as the\n",
    "    # held-out set further down), while the else-branch reads 'seaqest.npz'.\n",
    "    # Confirm both file names -- including the 'seaqest' spelling -- match disk.\n",
    "    expert_observations, color_observations, expert_actions, episode_schedule = gen_color_data(num_interactions=int(1e4), env_id=\"SeaquestNoFrameskip-v4\", preprocess=False)\n",
    "    np.savez_compressed(\n",
    "        'seaquest_holdout.npz',\n",
    "        expert_actions=expert_actions,\n",
    "        color_observations=color_observations,\n",
    "        expert_observations=expert_observations,\n",
    "        episode_schedule=episode_schedule,\n",
    "    )\n",
    "else:\n",
    "    # Use the NpzFile as a context manager so its zip handle is closed once\n",
    "    # every array has been read into memory.\n",
    "    with np.load('seaqest.npz') as arrs:\n",
    "        expert_observations = arrs['expert_observations']\n",
    "        color_observations = arrs['color_observations']\n",
    "        expert_actions = arrs['expert_actions']\n",
    "        episode_schedule = arrs['episode_schedule']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57fc8084",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain dataset over the full expert rollout; presumably yields (observation, action) pairs.\n",
    "expert_dataset = ExpertDataset(expert_observations, expert_actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac8e604",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Episode number and global step of every sample (columns of episode_schedule),\n",
    "# on `device` so the quadruplet mask can index them with batch indices.\n",
    "episode_labels, step_labels = (\n",
    "    torch.FloatTensor(episode_schedule[:, column]).to(device) for column in (0, 1)\n",
    ")\n",
    "\n",
    "# Same data as expert_dataset, but every item also returns its dataset index.\n",
    "DatasetWithInDices = dataset_with_indices(ExpertDataset)\n",
    "dset = DatasetWithInDices(expert_observations, expert_actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "826bb45a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out rollout, wrapped in an unshuffled loader for evaluation.\n",
    "# The context manager closes the .npz file once the arrays have been read.\n",
    "with np.load('seaquest_holdout.npz') as holdout_arrs:\n",
    "    holdout_expert_observations = holdout_arrs['expert_observations']\n",
    "    holdout_color_observations = holdout_arrs['color_observations']\n",
    "    holdout_expert_actions = holdout_arrs['expert_actions']\n",
    "    holdout_episode_schedule = holdout_arrs['episode_schedule']\n",
    "\n",
    "holdout_expert_dataset = ExpertDataset(holdout_expert_observations, holdout_expert_actions)\n",
    "eval_loader = th.utils.data.DataLoader(\n",
    "    dataset=holdout_expert_dataset, batch_size=128, shuffle=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1656f4c7",
   "metadata": {},
   "source": [
    "# Encoder Pretraining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ef0114d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset for encoder pretraining (items also carry their dataset index).\n",
    "# RUN THIS FOR PRE-TRAINING\n",
    "#\n",
    "# test_size takes the remainder so the sizes always sum to len(dset);\n",
    "# int(0.8 * n) + int(0.2 * n) can fall short of n (e.g. n = 3), which random_split rejects.\n",
    "train_size = int(0.8 * len(dset))\n",
    "test_size = len(dset) - train_size\n",
    "\n",
    "# Seeded generator: the split is reproducible across runs.\n",
    "train_expert_dataset, test_expert_dataset = random_split(\n",
    "    dset, [train_size, test_size], generator=torch.Generator().manual_seed(42)\n",
    ")\n",
    "\n",
    "kwargs = {\"num_workers\": 8, \"pin_memory\": False}\n",
    "train_loader = th.utils.data.DataLoader(\n",
    "    dataset=train_expert_dataset, batch_size=64, shuffle=True, **kwargs\n",
    ")\n",
    "test_loader = th.utils.data.DataLoader(\n",
    "    dataset=test_expert_dataset, batch_size=64, shuffle=True, **kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5a576b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Whole expert dataset in fixed (unshuffled) order, batches of 128.\n",
    "push_loader = th.utils.data.DataLoader(expert_dataset, batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0390f880",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ae_outputs(encoder, decoder, n=10, loader=None):\n",
    "    \"\"\"Show `n` inputs (top row) above their reconstructions (bottom row).\n",
    "\n",
    "    Args:\n",
    "        encoder, decoder: modules applied as `decoder(encoder(x))` to one sample.\n",
    "        n: number of samples to show; capped at the size of the first batch.\n",
    "        loader: DataLoader to draw the batch from; defaults to the global\n",
    "            `test_loader`. Batches may be `(x, y)` or `(x, y, index)`.\n",
    "\n",
    "    Only the last channel of each multi-channel frame is drawn.\n",
    "    \"\"\"\n",
    "    if loader is None:\n",
    "        loader = test_loader\n",
    "    # Images are the first element of the batch whether or not indices are included.\n",
    "    imgs = next(iter(loader))[0]\n",
    "    n = min(n, len(imgs))\n",
    "\n",
    "    encoder.eval()\n",
    "    decoder.eval()\n",
    "\n",
    "    def _draw(position, tensor, title):\n",
    "        # One opaque imshow per channel on the same axes only ever left the last\n",
    "        # channel visible, so draw that channel directly.\n",
    "        ax = plt.subplot(2, n, position)\n",
    "        plt.imshow(tensor.cpu().squeeze().numpy()[-1], cmap='gist_gray')\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "        if title is not None:\n",
    "            ax.set_title(title)\n",
    "\n",
    "    plt.figure(figsize=(16, 4.5))\n",
    "    for i in range(n):\n",
    "        img = imgs[i].unsqueeze(0).to(device)\n",
    "        with torch.no_grad():\n",
    "            rec_img = decoder(encoder(img))\n",
    "        middle = i == n // 2\n",
    "        _draw(i + 1, img, 'Original images' if middle else None)\n",
    "        _draw(i + 1 + n, rec_img, 'Reconstructed images' if middle else None)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed76857e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#https://raw.githubusercontent.com/lyakaap/NetVLAD-pytorch/master/hard_triplet_loss.py\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class HardQuadrupletLoss(nn.Module):\n",
    "    \"\"\"Quadruplet margin loss on embeddings, with batch-all or hardest mining.\n",
    "\n",
    "    Adapted from a hard/hardest triplet loss\n",
    "    (pytorch implementation of https://omoindrot.github.io/triplet-loss).\n",
    "    For anchor a, positive p (same label) and negative n (different label) the\n",
    "    per-example loss is\n",
    "        relu(d(a, p) - d(a, n) + margin1) + relu(d(a, p) - d(p, n) + margin2).\n",
    "    \"\"\"\n",
    "    def __init__(self, margin1=0.1, margin2=.05, hardest=False, squared=False, epsilon=10, tau=11):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            margin1: margin of the anchor-positive vs anchor-negative term.\n",
    "            margin2: margin of the anchor-positive vs positive-negative term.\n",
    "            hardest: if True, use only the hardest positive/negative per anchor;\n",
    "                otherwise average over all valid quadruplets with non-zero loss.\n",
    "            squared: if True, use squared euclidean distances.\n",
    "            epsilon, tau: stored on the module; `forward` does not pass them to\n",
    "                `_get_quadruplet_mask`, which applies its own thresholds.\n",
    "        \"\"\"\n",
    "        super(HardQuadrupletLoss, self).__init__()\n",
    "        self.margin1 = margin1\n",
    "        self.margin2 = margin2\n",
    "        self.hardest = hardest\n",
    "        self.squared = squared\n",
    "        self.epsilon = epsilon\n",
    "        self.tau = tau\n",
    "\n",
    "    def forward(self, embeddings, labels, idx):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            embeddings: tensor of shape (batch_size, embed_dim).\n",
    "            labels: labels of the batch, of size (batch_size,).\n",
    "            idx: dataset indices of the batch samples (used by the quadruplet\n",
    "                mask in the non-hardest mode).\n",
    "\n",
    "        Returns:\n",
    "            Scalar tensor with the quadruplet loss.\n",
    "        \"\"\"\n",
    "        pairwise_dist = _pairwise_distance(embeddings, squared=self.squared)\n",
    "\n",
    "        if self.hardest:\n",
    "            # Hardest positive: the farthest same-label sample of each anchor.\n",
    "            mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()\n",
    "            valid_positive_dist = pairwise_dist * mask_anchor_positive\n",
    "            hardest_positive_dist, hardest_positive_idx = torch.max(valid_positive_dist, dim=1, keepdim=True)\n",
    "\n",
    "            # Hardest negative: the closest different-label sample. Invalid entries are\n",
    "            # raised by the row maximum so they never win the min.\n",
    "            mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()\n",
    "            max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True)\n",
    "            anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)\n",
    "            hardest_negative_dist, hardest_negative_idx = torch.min(anchor_negative_dist, dim=1, keepdim=True)\n",
    "\n",
    "            # Second term mirrors the batch-all branch: distance between the selected\n",
    "            # hardest positive and hardest negative of each anchor.\n",
    "            hardest_negative2_dist = pairwise_dist[\n",
    "                hardest_positive_idx.squeeze(1), hardest_negative_idx.squeeze(1)\n",
    "            ].unsqueeze(1)\n",
    "\n",
    "            # Out-of-place adds: F.relu keeps its output for the backward pass, so an\n",
    "            # in-place `+=` on it makes loss.backward() fail.\n",
    "            quadruplet_loss = F.relu(hardest_positive_dist - hardest_negative_dist + self.margin1)\n",
    "            quadruplet_loss = quadruplet_loss + F.relu(hardest_positive_dist - hardest_negative2_dist + self.margin2)\n",
    "            quadruplet_loss = torch.mean(quadruplet_loss)\n",
    "        else:\n",
    "            # loss[i, j, k] for anchor i, positive j, negative k; shape [B, B, B].\n",
    "            anc_pos_dist = pairwise_dist.unsqueeze(dim=2)  # d(i, j): [B, B, 1]\n",
    "            anc_neg_dist = pairwise_dist.unsqueeze(dim=1)  # d(i, k): [B, 1, B]\n",
    "            pos_neg_dist = pairwise_dist.unsqueeze(dim=0)  # d(j, k): [1, B, B]\n",
    "            loss = F.relu(anc_pos_dist - anc_neg_dist + self.margin1)\n",
    "            loss = loss + F.relu(anc_pos_dist - pos_neg_dist + self.margin2)\n",
    "\n",
    "            # The mask is indexed [i, j, k, l]. A trailing axis lines loss[i, j, k] up\n",
    "            # with mask[i, j, k, :]; plain broadcasting would instead pair it with the\n",
    "            # last three mask axes and shift anchor/positive/negative by one position.\n",
    "            mask = _get_quadruplet_mask(labels, idx).float()\n",
    "            quadruplet_loss = loss.unsqueeze(-1) * mask\n",
    "\n",
    "            # Average over the hard quadruplets only (those with a non-zero loss).\n",
    "            num_hard_quadruplets = torch.gt(quadruplet_loss, 1e-16).float().sum()\n",
    "            quadruplet_loss = torch.sum(quadruplet_loss) / (num_hard_quadruplets + 1e-16)\n",
    "\n",
    "        return quadruplet_loss\n",
    "\n",
    "\n",
    "\n",
    "def _pairwise_distance(x, squared=False, eps=1e-16):\n",
    "    \"\"\"Euclidean distance matrix between all rows of `x`.\n",
    "\n",
    "    Expands ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 and clamps it at 0 to absorb\n",
    "    round-off. Unless `squared`, the square root is taken; exact zeros get `eps`\n",
    "    added first (keeping sqrt's gradient finite) and are zeroed again afterwards.\n",
    "    \"\"\"\n",
    "    gram = x @ x.t()\n",
    "    sq_norms = torch.diagonal(gram)\n",
    "    sq_dist = F.relu(sq_norms[:, None] - 2 * gram + sq_norms[None, :])\n",
    "    if squared:\n",
    "        return sq_dist\n",
    "\n",
    "    zero_mask = (sq_dist == 0.0).float()\n",
    "    return torch.sqrt(sq_dist + zero_mask * eps) * (1.0 - zero_mask)\n",
    "\n",
    "\n",
    "def _get_anchor_positive_triplet_mask(labels):\n",
    "    \"\"\"Boolean [B, B] mask: mask[a, p] is True iff a != p and labels[a] == labels[p].\n",
    "\n",
    "    The identity is built on `labels.device`, so CPU labels work even when CUDA\n",
    "    is available.\n",
    "    \"\"\"\n",
    "    indices_not_equal = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)\n",
    "    labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)\n",
    "    return indices_not_equal & labels_equal\n",
    "\n",
    "\n",
    "def _get_anchor_negative_triplet_mask(labels):\n",
    "    \"\"\"Boolean [B, B] mask: mask[a, n] is True iff labels[a] != labels[n].\"\"\"\n",
    "    return torch.unsqueeze(labels, 0) != torch.unsqueeze(labels, 1)\n",
    "\n",
    "\n",
    "def _get_quadruplet_mask(\n",
    "        labels,\n",
    "        idx,\n",
    "        eps=15,\n",
    "        tau=15,\n",
    "        same_episode=False,\n",
    "    ):\n",
    "    \"\"\"Boolean mask of valid quadruplets (i, j, k, l), shape [B, B, B, B].\n",
    "\n",
    "    Kept entries satisfy:\n",
    "      * i != j, j != k and k != l;\n",
    "      * labels[i] == labels[j] and labels[j] != labels[k];\n",
    "      * |step_i - step_j| <= eps, |step_i - step_k| <= eps and |step_i - step_l| >= tau;\n",
    "      * if `same_episode`, all four samples share an episode number.\n",
    "\n",
    "    Steps and episodes are read from the notebook globals `step_labels` and\n",
    "    `episode_labels`, indexed by the dataset indices `idx` of the batch.\n",
    "    Memory grows as B**4.\n",
    "\n",
    "    Args:\n",
    "        labels: labels of the batch, size (B,).\n",
    "        idx: dataset indices of the batch samples.\n",
    "        eps: maximum step distance of j and k from i (default 15).\n",
    "        tau: minimum step distance of l from i (default 15).\n",
    "        same_episode: also require i, j, k, l to come from one episode.\n",
    "    \"\"\"\n",
    "    B = labels.size(0)\n",
    "\n",
    "    # Consecutive indices must differ: [B, B, 1, 1] & [1, B, B, 1] & [1, 1, B, B].\n",
    "    not_eye = ~torch.eye(B, dtype=torch.bool, device=labels.device)\n",
    "    distinct_indices = (\n",
    "        not_eye.view(B, B, 1, 1) & not_eye.view(1, B, B, 1) & not_eye.view(1, 1, B, B)\n",
    "    )\n",
    "\n",
    "    # labels[i] == labels[j] and labels[j] != labels[k]  -> [B, B, B, 1].\n",
    "    labels_equal = labels.view(1, B) == labels.view(B, 1)\n",
    "    label_match = labels_equal.view(B, B, 1, 1) & ~labels_equal.view(1, B, B, 1)\n",
    "\n",
    "    # Step-distance constraints. abs() has to wrap the difference itself; wrapping\n",
    "    # the comparison would test a signed difference.\n",
    "    lst = step_labels[idx]\n",
    "    step_i = lst[:, None, None, None]\n",
    "    within_eps_over_tau = (\n",
    "        (torch.abs(step_i - lst[None, :, None, None]) <= eps)\n",
    "        & (torch.abs(step_i - lst[None, None, :, None]) <= eps)\n",
    "        & (torch.abs(step_i - lst[None, None, None, :]) >= tau)\n",
    "    )\n",
    "\n",
    "    mask = within_eps_over_tau & label_match & distinct_indices\n",
    "    if same_episode:\n",
    "        ep_labels = episode_labels[idx]\n",
    "        ep_equal = ep_labels.view(1, B) == ep_labels.view(B, 1)\n",
    "        mask = mask & ep_equal.view(B, B, 1, 1) & ep_equal.view(1, B, B, 1) & ep_equal.view(1, 1, B, B)\n",
    "    return mask\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7abf9236",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Training function\n",
    "def train_siamese_epoch(vae, device, dataloader, optimizer, criterion, beta=1., lbda=.1):\n",
    "    \"\"\"Train `vae` for one epoch on reconstruction + KL + a metric-learning loss.\n",
    "\n",
    "    Per batch: loss = sum((x - x_hat)**2) + beta * vae.encoder.kl\n",
    "                      + lbda * criterion(embeddings=z, labels=actions, idx=idx)\n",
    "\n",
    "    Args:\n",
    "        vae: model whose forward returns `(x_hat, z)` and whose encoder has `kl`.\n",
    "        device: device each batch is moved to.\n",
    "        dataloader: yields `(x, actions, idx)`; actions serve as labels and idx as\n",
    "            dataset indices for `criterion`.\n",
    "        optimizer: optimizer over the parameters of `vae`.\n",
    "        criterion: metric-learning loss, e.g. HardQuadrupletLoss.\n",
    "        beta: weight of the KL term.\n",
    "        lbda: weight of the metric-learning term.\n",
    "\n",
    "    Returns:\n",
    "        Sum of the per-batch total losses divided by the number of samples.\n",
    "    \"\"\"\n",
    "    vae.train()\n",
    "    train_loss = 0.0\n",
    "    vae_loss_tot = 0.0\n",
    "    siam_loss_tot = 0.0\n",
    "    for x, actions, idx in dataloader:\n",
    "        x = x.to(device)\n",
    "        actions = actions.to(device)\n",
    "        idx = idx.to(device)\n",
    "\n",
    "        x_hat, z = vae(x)\n",
    "\n",
    "        # Reconstruction error summed over the batch, plus the beta-weighted KL term.\n",
    "        vae_loss = ((x - x_hat) ** 2).sum() + beta * vae.encoder.kl\n",
    "        # Metric-learning loss on the latent codes, with the actions as labels.\n",
    "        siam_loss = criterion(embeddings=z, labels=actions, idx=idx)\n",
    "        loss = vae_loss + lbda * siam_loss\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.item()\n",
    "        vae_loss_tot += vae_loss.item()\n",
    "        siam_loss_tot += siam_loss.item()\n",
    "\n",
    "    # The VAE loss is a per-batch sum, so average it per sample. The criterion\n",
    "    # returns a per-batch average (HardQuadrupletLoss does), so average it per batch;\n",
    "    # dividing it by the number of samples under-reports it by about the batch size.\n",
    "    n_samples = len(dataloader.dataset)\n",
    "    n_batches = max(len(dataloader), 1)\n",
    "    print('Avg VAE/SIAM loss: ', vae_loss_tot / n_samples, siam_loss_tot / n_batches)\n",
    "\n",
    "    return train_loss / n_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f717215",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Set the random seed for reproducible results\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Latent dimension of the VAE (passed as z_dim).\n",
    "d = 7\n",
    "\n",
    "# NOTE(review): nc=4 presumably matches the 4 stacked frames per observation -- confirm in models.VAE.\n",
    "vae = VAE(z_dim = d,nc=4)\n",
    "lr = 1e-4\n",
    "\n",
    "optimizer = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=1e-7)\n",
    "\n",
    "# Rebinds the global `device` set in the imports cell; later cells use this value.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(f'Selected device: {device}')\n",
    "\n",
    "vae.to(device)\n",
    "\n",
    "# Batch-all quadruplet loss (hardest=False) on squared distances. `epsilon` is only\n",
    "# stored on the loss; the quadruplet mask applies its own step thresholds.\n",
    "criterion = HardQuadrupletLoss(margin1=2., margin2=2.5,squared=True,hardest=False,epsilon=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58a22fd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 1\n",
    "\n",
    "# Per epoch: one VAE + quadruplet-loss pass over train_loader, then a\n",
    "# reconstruction preview drawn from test_loader.\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "    train_loss = train_siamese_epoch(\n",
    "        vae, device, train_loader, optimizer, criterion, beta=1.5, lbda=1000\n",
    "    )\n",
    "    plot_ae_outputs(vae.encoder, vae.decoder, n=10)\n",
    "    print('epoch, train loss: ', epoch, train_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ed64ade",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Action index -> name, in the order of the ALE full 18-action set (the order\n",
    "# `env.unwrapped.get_action_meanings()` reports for Seaquest).\n",
    "acts = {0: 'NOOP',\n",
    "        1: 'FIRE',\n",
    "        2: 'UP',\n",
    "        3: 'RIGHT',\n",
    "        4: 'LEFT',\n",
    "        5: 'DOWN',\n",
    "        6: 'UP-RIGHT',\n",
    "        7: 'UP-LEFT',\n",
    "        8: 'DOWN-RIGHT',\n",
    "        9: 'DOWN-LEFT',\n",
    "        10: 'UP-FIRE',\n",
    "        11: 'RIGHT-FIRE',\n",
    "        12: 'LEFT-FIRE',\n",
    "        13: 'DOWN-FIRE',\n",
    "        14: 'UP-RIGHT-FIRE',\n",
    "        15: 'UP-LEFT-FIRE',\n",
    "        16: 'DOWN-RIGHT-FIRE',\n",
    "        17: 'DOWN-LEFT-FIRE'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce33ded2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saves the whole module object via pickle (not just a state_dict), so loading it\n",
    "# requires the VAE class to be importable under the same name.\n",
    "torch.save(vae, 'seaquest_svae.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32bb0f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the file holds a pickled nn.Module; torch >= 2.6 defaults to\n",
    "# weights_only=True and will then need weights_only=False here.\n",
    "vae = torch.load('seaquest_svae.pt', map_location=device)\n",
    "vae.eval()  # set once, before the loop\n",
    "\n",
    "encoded_samples = []\n",
    "for sample in tqdm(test_expert_dataset):\n",
    "    img = torch.FloatTensor(sample[0]).unsqueeze(0).to(device)\n",
    "    label = sample[1]\n",
    "    with torch.no_grad():\n",
    "        encoded_img = vae.encoder(img)\n",
    "    encoded_img = encoded_img.flatten().cpu().numpy()\n",
    "    encoded_sample = {f\"Enc. Variable {i}\": enc for i, enc in enumerate(encoded_img)}\n",
    "    # `acts` has int keys; int() also covers numpy/tensor scalar labels.\n",
    "    encoded_sample['label'] = acts[int(label)]\n",
    "    encoded_samples.append(encoded_sample)\n",
    "\n",
    "encoded_samples = pd.DataFrame(encoded_samples)\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "import plotly.express as px\n",
    "\n",
    "px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1', color=encoded_samples.label.astype(str), opacity=0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5640949",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D t-SNE projection of the encoder features (every column except 'label').\n",
    "tsne = TSNE(n_components=2)\n",
    "tsne_results = tsne.fit_transform(encoded_samples.drop(columns=['label']))\n",
    "\n",
    "fig = px.scatter(\n",
    "    tsne_results,\n",
    "    x=0,\n",
    "    y=1,\n",
    "    color=encoded_samples['label'].astype(str),\n",
    "    labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb179813",
   "metadata": {},
   "source": [
    "## DOWNSTREAM TRAINING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ea223fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train/test split of the plain expert dataset for downstream training.\n",
    "# RUN THIS FOR DOWNSTREAM TRAINING\n",
    "#\n",
    "# test_size takes the remainder so the sizes always sum to len(expert_dataset)\n",
    "# (int(0.8 * n) + int(0.2 * n) can fall short of n, which random_split rejects).\n",
    "# Seed 42 matches the pretraining split: expert_dataset and dset wrap the same\n",
    "# arrays, so with the same split sizes random_split yields the same partition and\n",
    "# the downstream test set stays disjoint from the encoder's pretraining samples.\n",
    "train_size = int(0.8 * len(expert_dataset))\n",
    "test_size = len(expert_dataset) - train_size\n",
    "\n",
    "train_expert_dataset, test_expert_dataset = random_split(\n",
    "    expert_dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)\n",
    ")\n",
    "\n",
    "kwargs = {\"num_workers\": 8, \"pin_memory\": False}\n",
    "train_loader = th.utils.data.DataLoader(\n",
    "    dataset=train_expert_dataset, batch_size=32, shuffle=False, **kwargs\n",
    ")\n",
    "test_loader = th.utils.data.DataLoader(\n",
    "    dataset=test_expert_dataset, batch_size=32, shuffle=False, **kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79b15aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_ds_epoch(net,\n",
    "                   train_loader,\n",
    "                   optimizer,\n",
    "                   iso_coeff=10,\n",
    "                   clst_coeff=0.,\n",
    "                   sep_coeff=0.,\n",
    "                   rep_coeff=0.,\n",
    "                   validation=False,\n",
    "                   resnetbc=False):\n",
    "    \"\"\"One pass of the downstream action classifier over `train_loader`.\n",
    "\n",
    "    Args:\n",
    "        net: classifier. Unless `resnetbc`, `net(states)` returns\n",
    "            `(logits, min_distances)` and `net` exposes `isometry`,\n",
    "            `prototype_shape` and `prototype_action_identity`; with\n",
    "            `resnetbc=True` it returns logits only.\n",
    "        train_loader: yields `(states, actions)` batches.\n",
    "        optimizer: stepped once per batch unless `validation`.\n",
    "        iso_coeff: weight of ||A^T A - I||_F^2 on `net.isometry.weight`.\n",
    "        clst_coeff: weight of the cluster cost (closest prototype of the\n",
    "            correct action).\n",
    "        sep_coeff: weight of the separation cost; it is subtracted, so distance\n",
    "            to prototypes of other actions is rewarded.\n",
    "        rep_coeff: weight of the rep term (sum over the second axis of\n",
    "            `min_distances` of its minimum over the batch).\n",
    "        validation: if True, no gradients are tracked and no step is taken.\n",
    "        resnetbc: plain cross-entropy model without prototype terms.\n",
    "\n",
    "    Returns:\n",
    "        Epoch sums (CE, iso, clst, sep, rep); the last four stay 0 when `resnetbc`.\n",
    "    \"\"\"\n",
    "    running_CE = 0.\n",
    "    running_iso = 0.\n",
    "    running_clst = 0.\n",
    "    running_sep = 0.\n",
    "    running_rep = 0.\n",
    "\n",
    "    if not resnetbc:\n",
    "        # NOTE(review): I is sized by weight.shape[0] but compared with A^T A,\n",
    "        # which is weight.shape[1] square -- consistent only for a square isometry.\n",
    "        I = torch.eye(net.isometry.weight.shape[0]).to(device)\n",
    "        max_dist = net.prototype_shape[1]\n",
    "\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for (states, actions) in train_loader:\n",
    "        states = states.to(device)\n",
    "        actions = actions.to(device)\n",
    "        # Integer class ids: needed by CrossEntropyLoss and for indexing prototypes.\n",
    "        targets = actions.long()\n",
    "\n",
    "        if not validation:\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "        # No autograd graph is needed for validation passes.\n",
    "        with torch.set_grad_enabled(not validation):\n",
    "            if not resnetbc:\n",
    "                logits, min_distances = net(states.float())\n",
    "            else:\n",
    "                logits = net(states.float())\n",
    "\n",
    "            _, predicted = torch.max(logits.detach(), 1)\n",
    "            total += actions.size(0)\n",
    "            correct += (predicted == actions).sum().item()\n",
    "\n",
    "            if not resnetbc:\n",
    "                # Cluster cost: per sample, distance to the closest prototype of its own\n",
    "                # action. Distances are inverted (max_dist - d) so the entries zeroed out\n",
    "                # for other actions lose the max (assuming d never exceeds max_dist).\n",
    "                prototypes_of_correct_class = torch.t(net.prototype_action_identity[:, targets]).to(device)\n",
    "                inverted_distances, _ = torch.max((max_dist - min_distances) * prototypes_of_correct_class, dim=1)\n",
    "                clst_cost = torch.mean(max_dist - inverted_distances)\n",
    "\n",
    "                # Separation cost: distance to the closest prototype of any other action.\n",
    "                prototypes_of_wrong_class = 1 - prototypes_of_correct_class\n",
    "                inverted_distances_to_nontarget_prototypes, _ = torch.max((max_dist - min_distances) * prototypes_of_wrong_class, dim=1)\n",
    "                sep_cost = torch.mean(max_dist - inverted_distances_to_nontarget_prototypes)\n",
    "\n",
    "                # Rep term: minimum over the batch, summed over the remaining axis.\n",
    "                rep = torch.sum(torch.min(min_distances, dim=0)[0])\n",
    "\n",
    "                CE = criterion(logits, targets)\n",
    "\n",
    "                # Use the live weight, not `.data`: a detached copy leaves the penalty\n",
    "                # out of the graph, so iso_coeff would never affect training.\n",
    "                A = net.isometry.weight\n",
    "                iso_penalty = torch.linalg.matrix_norm(torch.mm(A.T, A) - I, ord='fro') ** 2\n",
    "\n",
    "                loss = CE + iso_coeff * iso_penalty + clst_coeff * clst_cost - sep_coeff * sep_cost + rep_coeff * rep\n",
    "\n",
    "                running_CE += CE.item()\n",
    "                running_iso += iso_penalty.item()\n",
    "                running_clst += clst_cost.item()\n",
    "                running_sep += sep_cost.item()\n",
    "                running_rep += rep.item()\n",
    "            else:\n",
    "                loss = criterion(logits, targets)\n",
    "                running_CE += loss.item()\n",
    "\n",
    "        if not validation:\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "    # max(total, 1) avoids ZeroDivisionError on an empty loader.\n",
    "    accuracy = 100 * correct / max(total, 1)\n",
    "    if not validation:\n",
    "        print('Train Acc: %d %%' % accuracy)\n",
    "    else:\n",
    "        print('Val Acc: %d %%' % accuracy)\n",
    "    return running_CE, running_iso, running_clst, running_sep, running_rep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cad18703",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_iso(net, \n",
    "          train_loader, \n",
    "          val_loader,\n",
    "          iso_coeff=10,\n",
    "          clst_coeff =0.,\n",
    "          sep_coeff = 0.,\n",
    "          rep_coeff = 0.,\n",
    "          push_interval=None,\n",
    "          push_epochs=100,\n",
    "          n_epochs=100,\n",
    "          main_lr = 1e-6,\n",
    "          pruned=False,\n",
    "          lr_sched=False,\n",
    "          resnetbc=False\n",
    "         ):\n",
    "    \"\"\"Train `net` with early stopping and optional prototype pushes.\n",
    "\n",
    "    Every epoch trains on `train_loader` (the whole net, or only\n",
    "    `net.last_layer` when `pruned`), then evaluates on `val_loader`.\n",
    "    `EarlyStopping` tracks the validation CE; the best weights are expected in\n",
    "    'checkpoint.pt' (pytorchtools' default path) and are reloaded at the end.\n",
    "\n",
    "    A push is `save_proto(..., project=True)` followed by `push_epochs` epochs\n",
    "    of last-layer-only training with the isometry and prototype vectors\n",
    "    frozen. Pushes happen every `push_interval` epochs (epoch 0 excluded) when\n",
    "    `push_interval` is set, and once more after training unless `resnetbc`;\n",
    "    after that final push the isometry and prototypes stay frozen.\n",
    "\n",
    "    Args:\n",
    "        net: model to train. Unless `resnetbc`, it must expose `last_layer`,\n",
    "            `isometry` and `prototype_vectors`.\n",
    "        train_loader, val_loader: loaders passed to `train_ds_epoch`.\n",
    "        iso_coeff, clst_coeff, sep_coeff, rep_coeff: loss weights forwarded\n",
    "            to `train_ds_epoch`.\n",
    "        push_interval: epochs between periodic pushes (None disables them).\n",
    "        push_epochs: last-layer-only epochs after each push.\n",
    "        n_epochs: maximum number of epochs.\n",
    "        main_lr: Adam learning rate of the full network (the last-layer\n",
    "            optimizer uses 1e-5).\n",
    "        pruned: train only the last layer in the main loop.\n",
    "        lr_sched: multiply the main optimizer's LR by 0.8 every 10 epochs.\n",
    "        resnetbc: `net` is a plain classifier; no final push.\n",
    "\n",
    "    Returns:\n",
    "        `net` with the best checkpoint loaded (and, unless `resnetbc`, its\n",
    "        last layer re-tuned after the final push).\n",
    "    \"\"\"\n",
    "    net_optimizer = optim.Adam(net.parameters(), lr=main_lr)\n",
    "    # The last layer is trained on its own after every push and when `pruned`.\n",
    "    # It used to be created only when push_interval was set, so pruned=True or\n",
    "    # the final push with push_interval=None hit a None optimizer.\n",
    "    if push_interval is not None or pruned or not resnetbc:\n",
    "        ll_optimizer = optim.Adam(net.last_layer.parameters(), lr=1e-5)\n",
    "    else:\n",
    "        ll_optimizer = None\n",
    "    main_optimizer = ll_optimizer if pruned else net_optimizer\n",
    "\n",
    "    sched = None\n",
    "    if lr_sched:\n",
    "        sched = optim.lr_scheduler.StepLR(main_optimizer, 10, .8)\n",
    "\n",
    "    early_stopping = EarlyStopping(patience=50, verbose=True)\n",
    "\n",
    "    def run_epoch(loader, optimizer, validation):\n",
    "        # One train_ds_epoch pass with this call's loss coefficients.\n",
    "        return train_ds_epoch(net,\n",
    "                              loader,\n",
    "                              optimizer,\n",
    "                              iso_coeff,\n",
    "                              clst_coeff,\n",
    "                              sep_coeff,\n",
    "                              rep_coeff,\n",
    "                              validation=validation,\n",
    "                              resnetbc=resnetbc)\n",
    "\n",
    "    def set_push_params_trainable(flag):\n",
    "        # Toggle gradients of the isometry and the prototype vectors.\n",
    "        for param in net.isometry.parameters():\n",
    "            param.requires_grad = flag\n",
    "        net.prototype_vectors.requires_grad = flag\n",
    "\n",
    "    def push_and_tune_last_layer():\n",
    "        # Project prototypes, freeze isometry + prototypes, fit only the last layer.\n",
    "        save_proto(net, train_loader, project=True, device=device)\n",
    "        set_push_params_trainable(False)\n",
    "        for _ in range(push_epochs):\n",
    "            run_epoch(train_loader, ll_optimizer, validation=False)\n",
    "\n",
    "    #wandb.init(project='mario')\n",
    "    #wandb.watch(net)\n",
    "\n",
    "    for i in range(n_epochs):\n",
    "        net.train()\n",
    "        (ce,iso,clst,sep,rep) = run_epoch(train_loader, main_optimizer, validation=False)\n",
    "        print('\\nEpoch', i, {'CE':ce,'Iso':iso,'Clst':clst,'Sep':sep,'Rep':rep})\n",
    "\n",
    "        if push_interval is not None and i > 0 and i % push_interval == 0:\n",
    "            push_and_tune_last_layer()\n",
    "            # train isometry and protos again in the following epochs\n",
    "            set_push_params_trainable(True)\n",
    "\n",
    "        # validation step\n",
    "        net.eval()\n",
    "        (ce,iso,clst,sep,rep) = run_epoch(val_loader, net_optimizer, validation=True)\n",
    "        early_stopping(ce, net)\n",
    "\n",
    "        # Step the LR schedule once per epoch (it used to be stepped only once,\n",
    "        # after training had finished, so it never took effect).\n",
    "        if sched is not None:\n",
    "            sched.step()\n",
    "\n",
    "        if early_stopping.early_stop:\n",
    "            print('early Stopping')\n",
    "            break\n",
    "    net.load_state_dict(torch.load('checkpoint.pt'))\n",
    "\n",
    "    # final push\n",
    "    print('Doing final push!')\n",
    "    if not resnetbc:\n",
    "        push_and_tune_last_layer()\n",
    "\n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab8e83d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(net, device, test_loader):\n",
    "    \"\"\"Return the top-1 accuracy (in percent) of `net` on `test_loader`.\n",
    "\n",
    "    `test_loader` yields `(states, actions)` batches. `net` may return the\n",
    "    logits directly (e.g. the ResNet baseline) or a tuple whose first element\n",
    "    is the logits (as the prototype nets here do). The net is put in eval mode\n",
    "    and evaluated under `torch.no_grad()`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `test_loader` yields no samples.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "\n",
    "    total = 0.\n",
    "    correct = 0.\n",
    "    # Inference only: skip building autograd graphs.\n",
    "    with torch.no_grad():\n",
    "        for (states, actions) in test_loader:\n",
    "            states = states.to(device)\n",
    "            actions = actions.to(device)\n",
    "\n",
    "            out = net(states)\n",
    "            logits = out[0] if isinstance(out, tuple) else out\n",
    "\n",
    "            _, predicted = torch.max(logits, 1)\n",
    "            total += actions.size(0)\n",
    "            correct += (predicted == actions).sum().item()\n",
    "\n",
    "    if total == 0:\n",
    "        raise ValueError('test_loader yielded no samples')\n",
    "    acc = 100 * (correct / total)\n",
    "    return acc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "688186e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title flip_fidelity\n",
    "# Measure how closely a trained policy agrees with the PPO expert during expert rollouts.\n",
    "from skimage.transform import resize\n",
    "from PIL import Image\n",
    "def flip_fidelity(net,num_interactions=int(6e4), n_flip=10000,env_id=\"SeaquestNoFrameskip-v4\", preprocess=False,save=False):\n",
    "    \"\"\"Roll out the expert policy and measure how often `net` agrees with it.\n",
    "\n",
    "    The expert from `get_env_and_model(env_id)` drives the environment; at every\n",
    "    step `net` predicts an action for the same observation without acting.\n",
    "\n",
    "    Args:\n",
    "        net: policy to evaluate; may return the logits or a tuple whose first\n",
    "            element is the logits. It is put in eval mode.\n",
    "        num_interactions: maximum number of environment steps.\n",
    "        n_flip: stop early once this many flip steps have been seen.\n",
    "        env_id: environment id passed to `get_env_and_model`.\n",
    "        preprocess: crop (`crop_pong`) and resize the observation before `net`\n",
    "            sees it; the expert always sees the raw observation.\n",
    "        save: if True, also return the observations and expert actions at\n",
    "            flip steps.\n",
    "\n",
    "    Returns:\n",
    "        `(fidelity, flip_fidelity)`: the fraction of steps on which `net`\n",
    "        picked the expert's action, and the fraction of flip steps (the expert's\n",
    "        action differs from its previous one) on which `net` matched the expert\n",
    "        at both that step and the previous one (NaN when there were no flips).\n",
    "        With `save=True`, two arrays follow: the channel-first flip\n",
    "        observations and the expert actions at those steps.\n",
    "    \"\"\"\n",
    "    env, ppo_expert = get_env_and_model(env_id)\n",
    "\n",
    "    state_shape = env.observation_space.shape\n",
    "    action_shape = env.action_space.shape\n",
    "\n",
    "    print('state shape: ', state_shape)\n",
    "    print('action shape: ', action_shape)\n",
    "\n",
    "    expert_actions = np.empty((num_interactions,) + env.action_space.shape)\n",
    "    agent_actions = np.empty((num_interactions,) + env.action_space.shape)\n",
    "\n",
    "    # Flip points are only buffered when requested: a float64 buffer for 10000\n",
    "    # stacked 4x84x84 frames is over 2 GB.\n",
    "    flip_observations = []\n",
    "    flip_actions = []\n",
    "\n",
    "    # Inference mode: in train mode BatchNorm layers would update their running\n",
    "    # statistics on every rollout step.\n",
    "    net.eval()\n",
    "\n",
    "    obs = env.reset()\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    flip_total = 0\n",
    "    flip_correct = 0\n",
    "    for i in tqdm(range(num_interactions)):\n",
    "        action, _ = ppo_expert.predict(obs, deterministic=True)\n",
    "        # PREPROCESS AFTER EXPERT IS DONE: the expert acts on the raw observation.\n",
    "        if preprocess:\n",
    "            obs = crop_pong(obs)[0]\n",
    "            obs = np.expand_dims(resize(obs, (84,84,4)),0)\n",
    "\n",
    "        # The previous version also wrote into the notebook-level\n",
    "        # expert_observations / color_observations arrays (rendering every\n",
    "        # frame to do so), overwriting the dataset other cells use.\n",
    "        expert_actions[i] = action\n",
    "\n",
    "        # Presumably obs is (1, H, W, C); the net expects channel-first input.\n",
    "        with torch.no_grad():\n",
    "            out = net(torch.FloatTensor(obs).permute(0,3,1,2).to(device))\n",
    "        logits = out[0] if isinstance(out, tuple) else out\n",
    "        # argmax of the logits == argmax of their softmax\n",
    "        agent_action = torch.argmax(logits).item()\n",
    "\n",
    "        if action == agent_action:\n",
    "            correct += 1\n",
    "        total += 1\n",
    "\n",
    "        agent_actions[i] = agent_action\n",
    "\n",
    "        # A flip is a step where the expert changes its action.\n",
    "        if i >= 1 and action != expert_actions[i-1]:\n",
    "            flip_total += 1\n",
    "            if agent_action == action and agent_actions[i-1] == expert_actions[i-1]:\n",
    "                flip_correct += 1\n",
    "            if save:\n",
    "                # (1, H, W, C) -> (C, H, W); the old code assigned the raw obs\n",
    "                # into a (4, 84, 84) slot, which cannot broadcast.\n",
    "                flip_observations.append(np.asarray(obs)[0].transpose(2,0,1))\n",
    "                flip_actions.append(expert_actions[i])\n",
    "\n",
    "        obs, reward, done, info = env.step(action)\n",
    "        if done:\n",
    "            obs = env.reset()\n",
    "\n",
    "        if flip_total >= n_flip:\n",
    "            break\n",
    "\n",
    "    env.close()\n",
    "\n",
    "    fidelity = correct/total if total else float('nan')\n",
    "    flip_fid = flip_correct/flip_total if flip_total else float('nan')\n",
    "    if save:\n",
    "        return fidelity, flip_fid, np.array(flip_observations), np.array(flip_actions)\n",
    "    return fidelity, flip_fid"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37d60591",
   "metadata": {},
   "source": [
    "# ResNet BC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad5a06a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train bbox: black-box behavior-cloning baseline.\n",
    "# ResNet-18 (pretrained=False) with a 4-channel input stem and an 18-way output layer.\n",
    "bbox = models.resnet18(pretrained=False)\n",
    "bbox.conv1 = nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)\n",
    "bbox.fc = nn.Linear(in_features=512, out_features=18)\n",
    "bbox = bbox.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36bb79aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the ResNet-BC baseline. train_iso only accepts iso/clst/sep/rep\n",
    "# coefficients; the former `ev_coeff`/`div_coeff` keywords raised a TypeError.\n",
    "# NOTE(review): with resnetbc=True train_ds_epoch appears to use plain CE,\n",
    "# so these coefficients are presumably ignored -- confirm.\n",
    "bbox = train_iso(bbox, \n",
    "      train_loader, \n",
    "      test_loader, \n",
    "      iso_coeff=1e-8,#.01,\n",
    "      clst_coeff=4e-4,\n",
    "      sep_coeff=1e-4,\n",
    "      rep_coeff=1e-5,\n",
    "      n_epochs=2,\n",
    "      main_lr=5e-6,\n",
    "      push_interval=None,\n",
    "      push_epochs=1,\n",
    "      resnetbc=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05e73ab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fidelity of the ResNet-BC baseline w.r.t. the PPO expert: (all steps, flip steps)\n",
    "tf, ff = flip_fidelity(bbox,\n",
    "                       num_interactions=30000,\n",
    "                       n_flip=10000,\n",
    "                       env_id=\"SeaquestNoFrameskip-v4\",\n",
    "                       preprocess=False)\n",
    "print(tf, ff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b56d65e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the whole ResNet-BC module object for later reloading with torch.load.\n",
    "torch.save(obj=bbox, f='resnetbc_seaquest.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f297639b",
   "metadata": {},
   "source": [
    "# ProtoX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5eff1b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make Prototype network\n",
    "net = ProtoIsoResNet(prototype_shape=(360,512*3*3),num_actions=18,tl=False, sim_method=0)# MY sim method\n",
    "# Initialise the isometry to the identity map.\n",
    "net.isometry.weight.data.copy_(torch.eye(512*3*3))\n",
    "\n",
    "# Load the pre-trained encoder: seaquest_svae.pt holds a whole model object and\n",
    "# only its .encoder weights are used; strict=False skips non-matching keys.\n",
    "# map_location='cpu' lets a GPU-saved checkpoint load anywhere; .to(device) follows.\n",
    "encoder_state = torch.load('seaquest_svae.pt', map_location='cpu').encoder.state_dict()\n",
    "net.load_state_dict(encoder_state, strict=False)\n",
    "net = net.to(device)\n",
    "\n",
    "# Freeze encoder. (The attribute used to be misspelled `requries_grad`, which\n",
    "# only set an unused attribute and left the encoder trainable.)\n",
    "for parm in net.convunit.parameters():\n",
    "    parm.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42730ff4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identity matrix sized to the first dimension of the isometry weight.\n",
    "# NOTE(review): presumably read as a global by the isometry penalty -- confirm.\n",
    "I = torch.eye(net.isometry.weight.shape[0], device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "597ccbc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train ProtoX. With n_epochs=1 the periodic push (every push_interval=25 epochs,\n",
    "# skipping epoch 0) never fires; only the final push inside train_iso runs.\n",
    "net = train_iso(net,\n",
    "                train_loader,\n",
    "                test_loader,\n",
    "                iso_coeff=1e-8,\n",
    "                clst_coeff=4e-4,\n",
    "                sep_coeff=1e-4,\n",
    "                rep_coeff=1e-5,\n",
    "                n_epochs=1,\n",
    "                main_lr=5e-6,\n",
    "                push_interval=25,\n",
    "                push_epochs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d87a3b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the whole trained ProtoX module; later sections reload it from this file.\n",
    "torch.save(obj=net, f='seaquest_svae_ds.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59f9ee84",
   "metadata": {},
   "source": [
    "# t-SNE of the learned encodings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe9f46aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode every test sample with the trained ProtoX net: one DataFrame row per\n",
    "# sample, one column per encoding dimension, plus the action label.\n",
    "net = torch.load('seaquest_svae_ds.pt')\n",
    "net.eval()  # set once instead of on every loop iteration\n",
    "encoded_samples = []\n",
    "with torch.no_grad():\n",
    "    for sample in tqdm(test_expert_dataset):\n",
    "        img = torch.FloatTensor(sample[0]).unsqueeze(0).to(device)\n",
    "        label = sample[1]\n",
    "        # Encode image\n",
    "        encoded_img, _  = net.push_forward(img)\n",
    "        encoded_img = encoded_img.flatten().cpu().numpy()\n",
    "        encoded_sample = {f\"Enc. Variable {i}\": enc for i, enc in enumerate(encoded_img)}\n",
    "        encoded_sample['label'] = acts[label]\n",
    "        encoded_samples.append(encoded_sample)\n",
    "\n",
    "encoded_samples = pd.DataFrame(encoded_samples)\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "import plotly.express as px\n",
    "\n",
    "# First two raw encoding dimensions (the t-SNE projection follows in the next cell).\n",
    "px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1', color=encoded_samples.label.astype(str), opacity=0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b05095",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# 2-D t-SNE of the encodings; random_state fixes t-SNE's random initialisation\n",
    "# so the embedding (and the plot) is reproducible across runs.\n",
    "tsne = TSNE(n_components=2, random_state=0)\n",
    "tsne_results = tsne.fit_transform(encoded_samples.drop(['label'],axis=1))\n",
    "\n",
    "fig = px.scatter(tsne_results, x=0, y=1, color=encoded_samples.label.astype(str),labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "977d8a45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct (row-unique) prototype vectors\n",
    "net.prototype_vectors.unique(dim=0).size(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7004a5d9",
   "metadata": {},
   "source": [
    "# Fidelity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9309cb14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-1 accuracy (%) of the ProtoX net on eval_loader\n",
    "test(net, device=device, test_loader=eval_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c94d9db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fidelity of the ProtoX net w.r.t. the PPO expert: (all steps, flip steps)\n",
    "tf, ff = flip_fidelity(net,\n",
    "                       num_interactions=30000,\n",
    "                       n_flip=10000,\n",
    "                       env_id=\"SeaquestNoFrameskip-v4\",\n",
    "                       preprocess=False)\n",
    "print(tf, ff)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42c36d74",
   "metadata": {},
   "source": [
    "# Visualization functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "639430b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def top_k_prots(net, oidx, action_filter=False, k=3, score_sort=False, sim_method=0, act=None):\n",
    "    \"\"\"Return the `k` prototypes most relevant to expert observation `oidx`.\n",
    "\n",
    "    Similarities come from the prototype distances d returned by `net`:\n",
    "    log((1 + d) / (d + 1e-10)) when `sim_method == 1`, else exp(-0.05 * d).\n",
    "    Prototypes are ranked by similarity or, when `score_sort`, by similarity\n",
    "    times their last-layer weight for the action.\n",
    "\n",
    "    Args:\n",
    "        net: prototype network whose forward returns a pair; only the second\n",
    "            element (the distances) is used here.\n",
    "        oidx: index into the global `expert_observations` / `expert_actions`.\n",
    "        action_filter: currently unused.\n",
    "        k: number of prototypes to return.\n",
    "        score_sort: rank by similarity * last-layer weight.\n",
    "        sim_method: similarity function (see above).\n",
    "        act: action to score against; defaults to `expert_actions[oidx]`.\n",
    "\n",
    "    Returns:\n",
    "        `(idx, sims, fcs)`: prototype indices, their similarities (0-d\n",
    "        tensors) and their last-layer weights for the action (floats).\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        _, dist = net(torch.FloatTensor(expert_observations[oidx]).unsqueeze(0).to(device))\n",
    "\n",
    "    if sim_method == 1:\n",
    "        sim = torch.log( (1 + dist) / (dist + 1e-10) )[0]\n",
    "    else:\n",
    "        sim = torch.exp(dist * -.05)[0]\n",
    "    # (num_prototypes, num_actions) weights; some last layers store log-weights.\n",
    "    try:\n",
    "        fcl = net.last_layer.weight.data.T\n",
    "    except AttributeError:\n",
    "        fcl = torch.exp(net.last_layer.log_weight.data.T)\n",
    "\n",
    "    action = int(expert_actions[oidx].item() if act is None else act)\n",
    "    if score_sort:\n",
    "        # fcl[:, action] equals last_layer.weight[action] and also works for\n",
    "        # log-weight layers (indexing .weight directly raised there).\n",
    "        ranking = sim * fcl[:, action]\n",
    "    else:\n",
    "        ranking = sim\n",
    "    top = torch.topk(ranking, k=k)\n",
    "\n",
    "    # TODO FIX NONES\n",
    "    idx = top.indices.tolist()\n",
    "    sims = [sim[j] for j in idx]\n",
    "    fcs = [fcl[j][action].item() for j in idx]\n",
    "\n",
    "    return idx, sims, fcs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce846cdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prot_rep(net, ix, push_loader,project=False,sim_method=0,max_samples=5000):\n",
    "    \"\"\"Find the sample in `push_loader` whose encoding best matches prototype `ix`.\n",
    "\n",
    "    Scans `push_loader` batch by batch with `net.push_forward`, stopping after\n",
    "    the first batch that takes the sample count past `max_samples`, and keeps\n",
    "    the sample with the highest similarity to prototype `ix`\n",
    "    (log((1 + d) / (d + 1e-10)) if `sim_method == 1`, else exp(-0.05 * d)).\n",
    "    The sample's running position in the loader indexes the global\n",
    "    `color_observations` / `expert_actions`; this assumes `push_loader`\n",
    "    iterates those arrays in order without shuffling -- confirm for callers.\n",
    "\n",
    "    Args:\n",
    "        net: prototype network; `push_forward(x)` returns `(enc, dist)`.\n",
    "        ix: prototype index.\n",
    "        push_loader: loader yielding `(x, actions)` batches.\n",
    "        project: currently unused.\n",
    "        sim_method: similarity function (see above).\n",
    "        max_samples: approximate cap on the number of samples scanned.\n",
    "\n",
    "    Returns:\n",
    "        `(color_observation, encoding)` of the best-matching sample.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no sample was scored (e.g. an empty loader).\n",
    "    \"\"\"\n",
    "    best_prot = None\n",
    "    best_sim = -1*float('inf')\n",
    "    best_idx = -1\n",
    "    ct = 0\n",
    "    with torch.no_grad():\n",
    "        for x, _ in push_loader:\n",
    "            batch_size = x.shape[0]\n",
    "            x = x.to(device)\n",
    "            enc, dist = net.push_forward(x)\n",
    "\n",
    "            if sim_method == 1:\n",
    "                sim = torch.log( (1 + dist) / (dist + 1e-10) )\n",
    "            else:\n",
    "                sim = torch.exp(-.05 * dist)\n",
    "            sim = sim[:,ix]\n",
    "\n",
    "            bat_ix = torch.argmax(sim).item()\n",
    "            top_sim = sim[bat_ix].item()\n",
    "            if top_sim > best_sim:\n",
    "                best_prot = enc[bat_ix]\n",
    "                best_idx = bat_ix + ct\n",
    "                best_sim = top_sim\n",
    "\n",
    "            ct += batch_size\n",
    "            if ct > max_samples:\n",
    "                break\n",
    "\n",
    "    if best_idx < 0:\n",
    "        raise ValueError('no sample scored for prototype %d' % ix)\n",
    "    # Use the best match. The old code used `data_ix`, the argmax of the LAST\n",
    "    # batch scanned, so the shown frame/action need not match best_prot.\n",
    "    print('prototype action is ',acts[expert_actions[best_idx]])\n",
    "    return color_observations[best_idx], best_prot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33f53d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain_color(net,\n",
    "                 oidx,\n",
    "                 push_loader,\n",
    "                 action_filter=False,\n",
    "                 k=3,\n",
    "                 fname=None,\n",
    "                 score_sort=False,\n",
    "                 sim_method=0):\n",
    "    \"\"\"Plot expert frame `oidx` next to the training frames of its top-`k` prototypes.\n",
    "\n",
    "    Panel 0 shows `color_observations[oidx]` titled with the expert action.\n",
    "    Panels 1..k show, for each prototype index returned by `top_k_prots`, the\n",
    "    frame found by `prot_rep`; when k > 1 each is titled with its similarity,\n",
    "    its last-layer weight for the action, and their product ('Points').\n",
    "\n",
    "    Args:\n",
    "        net: prototype network; put in eval mode here.\n",
    "        oidx: index into the global expert/color observation arrays.\n",
    "        push_loader: loader scanned by `prot_rep` for each prototype.\n",
    "        action_filter, score_sort: forwarded to `top_k_prots`.\n",
    "        k: number of prototypes to show.\n",
    "        fname: if not None, the figure is also saved to this path.\n",
    "        sim_method: similarity function used by `top_k_prots` and `prot_rep`.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    # Action names come from the notebook-level `acts` mapping, which prot_rep\n",
    "    # also uses. A local 7-entry dict (NOOP/RIGHT/.../LEFT) used to shadow it\n",
    "    # here and raised KeyError for actions >= 7 of the 18-action nets above.\n",
    "    action = expert_actions[oidx].item()\n",
    "    action_name = acts[int(action)]\n",
    "\n",
    "    def hide_axis(ax):\n",
    "        # Image-only panel: no ticks, tick labels or frame.\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_yticklabels([])\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.axis('off')\n",
    "\n",
    "    # squeeze=False keeps `axes` indexable even when k == 0.\n",
    "    fig, axes = plt.subplots(1,1+k,figsize=(30,30),squeeze=False)\n",
    "    axes = axes[0]\n",
    "\n",
    "    # show input\n",
    "    axes[0].imshow(color_observations[oidx].astype(int))\n",
    "    if k > 1:\n",
    "        axes[0].set_title('Input w/ action: ' + action_name,size=25)\n",
    "    else:\n",
    "        axes[0].set_title('Input w/ action: ' + action_name + '\\nat t='+str(oidx),size=40)\n",
    "    hide_axis(axes[0])\n",
    "\n",
    "    # show prots\n",
    "    top_k_ix, sims, fcs = top_k_prots(net, oidx, \n",
    "                                      action_filter=action_filter,\n",
    "                                      k=k,\n",
    "                                      sim_method=sim_method,\n",
    "                                      score_sort=score_sort)\n",
    "    for j in range(1,1+k):\n",
    "        # Same similarity function as top_k_prots (prot_rep previously always\n",
    "        # fell back to its default sim_method=0).\n",
    "        im, _ = prot_rep(net, top_k_ix[j-1], push_loader, sim_method=sim_method)\n",
    "        axes[j].imshow(im.astype(np.uint8))\n",
    "        hide_axis(axes[j])\n",
    "        if k > 1:\n",
    "            axes[j].set_title('Sim score: {:.2f} \\n'.format(sims[j-1])+action_name+' score: {:.2f}\\nPoints: {:.2f}'.format(fcs[j-1],sims[j-1]*fcs[j-1]),size=25)\n",
    "        else:\n",
    "            axes[j].set_title('Most Similar Prototype',size=40)\n",
    "\n",
    "    if fname is not None:\n",
    "        fig.savefig(fname, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "581d16a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def color_ball(net, oidx,min_sim, n_samples=1000,alpha=.2,Af=True):\n",
    "    \"\"\"Overlay every stored frame whose encoding is close to that of frame `oidx`.\n",
    "\n",
    "    Puts `net` in eval mode, then encodes `expert_observations[oidx]` and each\n",
    "    of the first `n_samples` expert observations (first output of\n",
    "    `net.push_forward` when `Af`, otherwise `net.encoder`) and keeps the frames\n",
    "    whose similarity exp(-0.05 * ||enc_oidx - enc_i||) exceeds `min_sim`.\n",
    "    The left panel shows the input frame; the right panel draws all kept\n",
    "    color frames on top of each other with transparency `alpha`.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    fig, axes = plt.subplots(1,2)\n",
    "    axes[0].imshow(color_observations[oidx].astype(int))\n",
    "\n",
    "    def encode(ix):\n",
    "        # Encode stored observation `ix` as a batch of one.\n",
    "        x = torch.FloatTensor(expert_observations[ix]).unsqueeze(0).to(device)\n",
    "        if Af:\n",
    "            enc, _ = net.push_forward(x)\n",
    "            return enc\n",
    "        return net.encoder(x)\n",
    "\n",
    "    # Inference only: no autograd graphs for up to n_samples forward passes.\n",
    "    with torch.no_grad():\n",
    "        enc1 = encode(oidx)\n",
    "        best_list = []\n",
    "        for i in range(n_samples):\n",
    "            d = torch.norm(enc1 - encode(i))\n",
    "            sim = torch.exp(-.05 * d)\n",
    "            if sim > min_sim:\n",
    "                best_list.append(i)\n",
    "\n",
    "    print('Found ', len(best_list))\n",
    "    for ix in best_list:\n",
    "        axes[1].imshow(color_observations[ix].astype(int),alpha=alpha)\n",
    "    for ax in axes:\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_yticklabels([])\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.axis('off')\n",
    "    axes[0].set_title('Input')\n",
    "    # The count depends on min_sim; the old title always claimed 'Top 30'.\n",
    "    axes[1].set_title('Top {}\\n Most Similar States'.format(len(best_list)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7753b912",
   "metadata": {},
   "source": [
    "# Compare similar states under Af and f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bded941",
   "metadata": {},
   "outputs": [],
   "source": [
    "# f: model saved in seaquest_svae.pt (pre-trained encoder, used via .encoder);\n",
    "# Af: ProtoX network saved by the training section above (used via .push_forward).\n",
    "f = torch.load(f='seaquest_svae.pt')\n",
    "Af = torch.load(f='seaquest_svae_ds.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "991298f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frames similar to frame 200 under the raw encoder f (hand-picked threshold .943)\n",
    "color_ball(f, 200, min_sim=.943, n_samples=5000, alpha=.05, Af=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1780af2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frames similar to frame 200 under the ProtoX encoding Af (hand-picked threshold .507)\n",
    "color_ball(Af, 200, min_sim=.507, n_samples=5000, alpha=.05, Af=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "428df437",
   "metadata": {},
   "source": [
    "# Make the pruned network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f92ca5f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group equivalent prototypes and merge them into a smaller network.\n",
    "eq = prot_equivs(net)\n",
    "unique_ix = [ix for ix in eq]  # eq's iteration order; presumably the kept prototype indices -- confirm\n",
    "\n",
    "pruned_net = merge_weights(net, device, eq, unique_ix, action_filter=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36fb4c32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep num_prototypes consistent with the pruned prototype_vectors tensor.\n",
    "pruned_net.num_prototypes = pruned_net.prototype_vectors.size(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b9ca027",
   "metadata": {},
   "source": [
    "# Explain step 32 using pruned network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56a8286f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get prototype representations without projecting into net\n",
    "# Get prototype representations without projecting them into the net\n",
    "pruned_prots = save_proto(pruned_net, train_loader, device=device, project=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66d8e4d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last-layer weight shape of the unpruned net\n",
    "net.last_layer.weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4956f036",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last-layer weight shape of the pruned net\n",
    "pruned_net.last_layer.weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8435e420",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explain step 32 with the top-3 prototypes ranked by similarity * last-layer weight\n",
    "explain_color(pruned_net.to(device), 32, train_loader,\n",
    "              action_filter=False, k=3, fname=None,\n",
    "              score_sort=True, sim_method=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "475b103a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42f2efca",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
