{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive Widget for Visualising GAN-Generated Future Trajectories"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This widget is designed for the FetchPush and FetchPickAndPlace environments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "if torch.cuda.is_available():  # Only select a GPU when one is present\n",
    "    torch.cuda.set_device(0)\n",
    "\n",
    "%matplotlib notebook\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "np.set_printoptions(precision=4, linewidth=300, suppress=True)\n",
    "from ipywidgets import interact\n",
    "from ipywidgets.widgets import IntSlider\n",
    "import joblib\n",
    "\n",
    "from plot_utils import generate_trajectories, load_experiment\n",
    "from envs.fetch_push import FetchPush\n",
    "from envs.fetch_pick_and_place import FetchPickAndPlace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the experiment. Here we load the GANs trained on FetchPickAndPlace; to use FetchPush instead, swap the commented lines in the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/henry/anaconda3/envs/deepRL/lib/python3.6/site-packages/sklearn/base.py:318: UserWarning: Trying to unpickle estimator StandardScaler from version 0.22 when using version 0.22.2.post1. This might lead to breaking code or invalid results. Use at your own risk.\n",
      "  UserWarning)\n"
     ]
    }
   ],
   "source": [
    "#expt_name = \"FetchPush\"\n",
    "expt_name = \"FetchPickAndPlace\"\n",
    "#env = FetchPush(remove_gripper=True)\n",
    "env = FetchPickAndPlace()\n",
    "controller, planner = load_experiment(expt_name, param_name=\"final\")\n",
    "planning_args = planner.args\n",
    "traj_len = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We roll out the loaded planner and controller in the environment to gather a few trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trajectory 0 completed\n",
      "Trajectory 1 completed\n",
      "Trajectory 2 completed\n"
     ]
    }
   ],
   "source": [
    "num_trajs = 3\n",
    "\n",
    "def gather_trajs(num_trajs, traj_len=50, render=False):\n",
    "    \"\"\"Roll out the planner for num_trajs episodes of traj_len states each, recording\n",
    "    the observations, actions and simulator states so any timestep can be revisited.\"\"\"\n",
    "    trajectories = []\n",
    "    for i in range(num_trajs):\n",
    "        traj = {}\n",
    "        obs = env.reset()\n",
    "        if render:\n",
    "            env.render()\n",
    "        end_goal = obs[\"desired_goal\"]\n",
    "        traj[\"states\"] = [obs[\"observation\"]]\n",
    "        traj[\"end_goal\"] = end_goal\n",
    "        traj[\"sim_state\"] = [env.save_state()]\n",
    "        traj[\"actions\"] = []\n",
    "        # The initial state is already stored, so take traj_len-1 further steps\n",
    "        for j in range(traj_len-1):\n",
    "            action, _ = planner.generate_next_action(obs[\"observation\"],\n",
    "                                                    end_goal,\n",
    "                                                    controller.imagination,\n",
    "                                                    env, **planning_args)\n",
    "            traj[\"actions\"].append(action)\n",
    "            obs,_,_,_ = env.step(action)\n",
    "            if render:\n",
    "                env.render()\n",
    "            traj[\"states\"].append(obs[\"observation\"])\n",
    "            traj[\"sim_state\"].append(env.save_state())\n",
    "        print(\"Trajectory {} completed\".format(i))\n",
    "        trajectories.append(traj)\n",
    "    return trajectories\n",
    "\n",
    "# Pass traj_len explicitly so the stored trajectories match the slider ranges used below\n",
    "trajectories = gather_trajs(num_trajs=num_trajs, traj_len=traj_len, render=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trajectories(num_imagined_trajs, num_steps, traj_ind, timestep, plot_exact=False,\n",
    "                      end_points_only=False, object_only=False):\n",
    "    # Start the imagined rollouts from the chosen timestep of the chosen gathered trajectory\n",
    "    traj = trajectories[traj_ind]\n",
    "    generate_trajectories(controller.imagination, env, num_imagined_trajs, object=True,\n",
    "                          plot_exact=plot_exact, end_points_only=end_points_only,\n",
    "                          object_only=object_only,\n",
    "                          start_state=traj[\"states\"][timestep],\n",
    "                          start_env_state=traj[\"sim_state\"][timestep], fixed_axes=True,\n",
    "                          end_goal=traj[\"end_goal\"], plot_model=False, num_steps=num_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualising the Imagined Future Trajectories\n",
    "The widget plots the position of the gripper (SOLID RED CIRCLE), the position of the object (WHITE CIRCLE WITH RED BORDER) and the target goal (SOLID BLACK CIRCLE). It then plots a randomly generated set of imagined future trajectories from the trained ensemble of GANs: blue for the gripper and red for the object. If the \"plot_exact\" checkbox is selected, the actual trajectories obtained in the simulator by executing the imagined actions are shown (black for the gripper, green for the object).\n",
    "\n",
    "The first slider sets the number of imagined trajectories to plot, and the second sets how many future time steps each imagined trajectory covers. The traj_ind slider selects one of the trajectories gathered above, and the timestep slider moves through that trajectory to choose the starting time step. Setting cts_upd to True makes the plot update continuously as you drag a slider, which may be slow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "610acc4d9f244763b7d25d9470a2de1e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(IntSlider(value=10, continuous_update=False, description='num_imagined_trajs', min=1), I…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<function __main__.plot_trajectories(num_imagined_trajs, num_steps, traj_ind, timestep, plot_exact=False, end_points_only=False, object_only=False)>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cts_upd = False\n",
    "\n",
    "interact(plot_trajectories,\n",
    "         num_imagined_trajs=IntSlider(min=1, max=100, value=10, step=1, title=\"Number of imagined trajs\", continuous_update=cts_upd),\n",
    "         num_steps=IntSlider(min=1, max=traj_len, value=5, step=1, title=\"Imagined traj lengths\", continuous_update=cts_upd),\n",
    "         traj_ind=IntSlider(min=0, max=len(trajectories)-1, value=0, step=1, title=\"Traj index\", continuous_update=cts_upd),\n",
    "         timestep=IntSlider(min=0, max=traj_len-1, value=0, step=1, title=\"Timestep\", continuous_update=cts_upd),\n",
    "         plot_exact=False, object_only=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
