{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5eeje4O8fviH",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Model-Based Reinforcement Learning\n",
    "\n",
    "## Principle\n",
    "We consider the optimal control problem of an MDP with a **known** reward function $R$ and subject to **unknown deterministic** dynamics $s_{t+1} = f(s_t, a_t)$:\n",
    "\n",
    "$$\\max_{(a_0,a_1,\\dotsc)} \\sum_{t=0}^\\infty \\gamma^t R(s_t,a_t)$$\n",
    "\n",
    "In **model-based reinforcement learning**, this problem is solved in **two steps**:\n",
    "1. **Model learning**:\n",
    "We learn a model of the dynamics $f_\\theta \\simeq f$ through regression on interaction data.\n",
    "2. **Planning**:\n",
    "We leverage the dynamics model $f_\\theta$ to compute the optimal trajectory $$\\max_{(a_0,a_1,\\dotsc)} \\sum_{t=0}^\\infty \\gamma^t R(\\hat{s}_t,a_t)$$ following the learnt dynamics $\\hat{s}_{t+1} = f_\\theta(\\hat{s}_t, a_t)$.\n",
    "\n",
    "(We can easily extend to unknown rewards and stochastic dynamics, but we consider the simpler case in this notebook for ease of presentation.)\n",
    "\n",
    "\n",
    "## Motivation\n",
    "\n",
    "### Sparse rewards\n",
    "* In model-free reinforcement learning, we only obtain a reinforcement signal when encountering rewards. In environments with **sparse rewards**, the chance of obtaining a reward randomly is **negligible**, which prevents any learning.\n",
    "* However, even in the **absence of rewards** we still receive a **stream of state transition data**. We can exploit this data to learn about the task at hand.\n",
    "\n",
    "### Complexity of the policy/value vs dynamics:\n",
    "Is it easier to decide which action is best, or to predict what is going to happen?\n",
    "* Some problems can have **complex dynamics** but a **simple optimal policy or value function**. For instance, consider the problem of learning to swim. Predicting the movement requires understanding fluid dynamics and vortices, while the optimal policy simply consists of moving the limbs in sync.\n",
    "* Conversely, other problems can have **simple dynamics** but **complex policies/value functions**. Think of the game of Go: its rules are simple (placing a stone merely changes the board state at this location) but the corresponding optimal policy is very complicated.\n",
    "\n",
    "Intuitively, model-free RL should be applied to the first category of problems and model-based RL to the second category.\n",
    "\n",
    "### Inductive bias\n",
    "Oftentimes, real-world problems exhibit a particular **structure**: for instance, any problem involving motion of physical objects will be **continuous**. It can also be **smooth**, **invariant** to translations, etc. This knowledge can then be incorporated in machine learning models to foster efficient learning. In contrast, there can often be **discontinuities** in the policy decisions or value function: e.g. think of a collision vs near-collision state.\n",
    "\n",
    "### Sample efficiency\n",
    "Overall, it is generally recognized that model-based approaches tend to **learn faster** than model-free techniques (see e.g. [[Sutton, 1990]](http://papersdb.cs.ualberta.ca/~papersdb/uploaded_files/paper_p160-sutton.pdf.stjohn)).\n",
    "\n",
    "### Interpretability\n",
    "In real-world applications, we may want to know **how a policy will behave before actually executing it**, for instance for **safety-check** purposes. However, model-free reinforcement learning only recommends which action to take at the current time, without being able to predict its consequences. In order to obtain the trajectory, we have no choice but to execute the policy. In stark contrast, model-based methods are more interpretable in the sense that we can probe the policy for its intended (and predicted) trajectory."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2-oVNY_KTw6R"
   },
   "source": [
    "## Our challenge: Automated Parking System\n",
    "\n",
    "We consider the **parking-v0** task of the [highway-env](https://github.com/eleurent/highway-env) environment. It is a **goal-conditioned continuous control** task where an agent **drives a car** by controlling the gas pedal and steering angle and must **park in a given location** with the appropriate heading.\n",
    "\n",
    "This MDP has several properties which justify using model-based methods:\n",
    "* The policy/value is highly dependent on the goal which adds a significant level of complexity to a model-free learning process, whereas the dynamics are completely independent of the goal and hence can be simpler to learn.\n",
    "* In the context of an industrial application, we can reasonably expect for safety concerns that the planned trajectory is required to be known in advance, before execution.\n",
    "\n",
    "### Warming up\n",
    "We start with a few useful installs and imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bzMSuJEOfviP",
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\" Remove \" > /dev/null 2>&1\" to see what is going on under the hood\"\"\"\n",
    "\n",
    "# # Install environment and visualization dependencies \n",
    "# !pip install git+https://github.com/eleurent/highway-env#egg=highway-env  > /dev/null 2>&1\n",
    "# !pip install gym pyvirtualdisplay > /dev/null 2>&1\n",
    "# !apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n",
    "\n",
    "# Environment\n",
    "import gym\n",
    "import highway_env\n",
    "\n",
    "# Models and computation\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from collections import namedtuple\n",
    "# torch.set_default_tensor_type(\"torch.cuda.FloatTensor\")\n",
    "\n",
    "# Visualization\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "from tqdm import tnrange\n",
    "from IPython import display as ipythondisplay\n",
    "from pyvirtualdisplay import Display\n",
    "from gym.wrappers import Monitor\n",
    "import base64\n",
    "\n",
    "# IO\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "n2Bu_Pqop0E7"
   },
   "source": [
    "We also define a simple helper function for visualization of episodes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "so7yH4ucyB-3"
   },
   "outputs": [],
   "source": [
    "display = Display(visible=0, size=(1400, 900))\n",
    "display.start()\n",
    "\n",
    "def show_video():\n",
    "    html = []\n",
    "    for mp4 in Path(\"video\").glob(\"*.mp4\"):\n",
    "        video_b64 = base64.b64encode(mp4.read_bytes())\n",
    "        html.append('''<video alt=\"{}\" autoplay \n",
    "                      loop controls style=\"height: 400px;\">\n",
    "                      <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n",
    "                 </video>'''.format(mp4, video_b64.decode('ascii')))\n",
    "    ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nFtBY6JSqPFa"
   },
   "source": [
    "### Let's try it!\n",
    "\n",
    "Make the environment, and run an episode with random actions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jKZt9Cb1rJ6n"
   },
   "outputs": [],
   "source": [
    "env = gym.make(\"parking-v0\")\n",
    "env = Monitor(env, './video', force=True, video_callable=lambda episode: True)\n",
    "env.reset()\n",
    "done = False\n",
    "while not done:\n",
    "    action = env.action_space.sample()\n",
    "    obs, reward, done, info = env.step(action)\n",
    "env.close()\n",
    "show_video()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ewG5f_essAS0"
   },
   "source": [
    "The environment is a `GoalEnv`, which means the agent receives a dictionary containing both the current `observation` and the `desired_goal` that conditions its policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "XIC98mGhr7v6"
   },
   "outputs": [],
   "source": [
    "print(\"Observation format:\", obs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "voagCILztJ3J"
   },
   "source": [
    "There is also an `achieved_goal` that won't be useful here (it only serves when the state and goal spaces are different, as a projection from the observation to the goal space).\n",
    "\n",
    "Alright! We are now ready to apply the model-based reinforcement learning paradigm."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I2PuVAvyfvib",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Experience collection\n",
    "First, we randomly interact with the environment to produce a batch of experiences \n",
    "\n",
    "$$D = \\{s_t, a_t, s_{t+1}\\}_{t\\in[1,N]}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tvUYSL7sfvie",
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "Transition = namedtuple('Transition', ['state', 'action', 'next_state'])\n",
    "\n",
    "def collect_interaction_data(env, size=1000, action_repeat=2):\n",
    "    data, done = [], True\n",
    "    for _ in tnrange(size, desc=\"Collect interaction data\"):\n",
    "        action = env.action_space.sample()\n",
    "        for _ in range(action_repeat):\n",
    "            previous_obs = env.reset() if done else obs\n",
    "            obs, reward, done, info = env.step(action)\n",
    "            data.append(Transition(torch.Tensor(previous_obs[\"observation\"]),\n",
    "                                   torch.Tensor(action),\n",
    "                                   torch.Tensor(obs[\"observation\"])))\n",
    "    return data\n",
    "\n",
    "data = collect_interaction_data(env)\n",
    "print(\"Sample transition:\", data[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4Th1JezEfvir"
   },
   "source": [
    "## Build a dynamics model\n",
    "\n",
    "We now design a model to represent the system dynamics. We choose a **structured model** inspired by *Linear Time-Invariant (LTI) systems*\n",
    "\n",
    "$$\\dot{x} = f_\\theta(x, u) = A_\\theta(x, u)x + B_\\theta(x, u)u$$\n",
    "\n",
    "where the $(x, u)$ notation comes from the Control Theory community and stands for the state and action $(s,a)$. Intuitively, we learn at each point $(x_t, u_t)$ the **linearization** of the true dynamics $f$ with respect to $(x, u)$.\n",
    "\n",
    "We parametrize $A_\\theta$ and $B_\\theta$ as two fully-connected networks with one hidden layer.\n"
   ]
  },
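  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid, the model above describes a time derivative $\\dot{x}$, whereas planning needs a next state. The `forward` method in the next cell bridges the two with a single explicit Euler step, which can be written as:\n",
    "\n",
    "$$x_{t+1} = x_t + \\Delta t \\, \\big(A_\\theta(x_t, u_t)\\,x_t + B_\\theta(x_t, u_t)\\,u_t\\big), \\qquad \\Delta t = \\frac{1}{f_{\\text{policy}}}$$\n",
    "\n",
    "The notation here is ours rather than the code's: $f_{\\text{policy}}$ stands for the environment's `policy_frequency` setting, and with $n$ and $m$ denoting the state and action dimensions (`state_size` and `action_size`), the network outputs are reshaped into $A_\\theta(x_t, u_t) \\in \\mathbb{R}^{n \\times n}$ and $B_\\theta(x_t, u_t) \\in \\mathbb{R}^{n \\times m}$. Note also that the two position entries of the network input are set to zero, so the predicted matrices do not depend on where the car is, only on the remaining state entries and the action."
   ]
  },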
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "F7Gl2kKJfviu",
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class DynamicsModel(nn.Module):\n",
    "    STATE_X = 0\n",
    "    STATE_Y = 1\n",
    "\n",
    "    def __init__(self, state_size, action_size, hidden_size, dt):\n",
    "        super().__init__()\n",
    "        self.state_size, self.action_size, self.dt = state_size, action_size, dt\n",
    "        A_size, B_size = state_size * state_size, state_size * action_size\n",
    "        self.A1 = nn.Linear(state_size + action_size, hidden_size)\n",
    "        self.A2 = nn.Linear(hidden_size, A_size)\n",
    "        self.B1 = nn.Linear(state_size + action_size, hidden_size)\n",
    "        self.B2 = nn.Linear(hidden_size, B_size)\n",
    "\n",
    "    def forward(self, x, u):\n",
    "        \"\"\"\n",
    "            Predict x_{t+1} = f(x_t, u_t)\n",
    "        :param x: a batch of states\n",
    "        :param u: a batch of actions\n",
    "        \"\"\"\n",
    "        xu = torch.cat((x, u), -1)\n",
    "        xu[:, self.STATE_X:self.STATE_Y+1] = 0  # Remove dependency on the position (x, y)\n",
    "        A = self.A2(F.relu(self.A1(xu)))\n",
    "        A = torch.reshape(A, (x.shape[0], self.state_size, self.state_size))\n",
    "        B = self.B2(F.relu(self.B1(xu)))\n",
    "        B = torch.reshape(B, (x.shape[0], self.state_size, self.action_size))\n",
    "        dx = A @ x.unsqueeze(-1) + B @ u.unsqueeze(-1)\n",
    "        return x + dx.squeeze(-1)*self.dt\n",
    "\n",
    "\n",
    "dynamics = DynamicsModel(state_size=env.observation_space.spaces[\"observation\"].shape[0],\n",
    "                         action_size=env.action_space.shape[0],\n",
    "                         hidden_size=64,\n",
    "                         dt=1/env.unwrapped.config[\"policy_frequency\"])\n",
    "print(\"Forward initial model on a sample transition:\", dynamics(data[0].state.unsqueeze(0),\n",
    "                                                                data[0].action.unsqueeze(0)).detach())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FFsgc7gffvi0"
   },
   "source": [
    "## Fit the model on data\n",
    "We can now train our model $f_\\theta$ in a supervised fashion to minimize an MSE loss $L^2(f_\\theta; D)$ over our experience batch $D$ by gradient descent (here, full-batch updates with the Adam optimizer):\n",
    "\n",
    "$$L^2(f_\\theta; D) = \\frac{1}{|D|}\\sum_{s_t,a_t,s_{t+1}\\in D}||s_{t+1}- f_\\theta(s_t, a_t)||^2$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NwCDLD1wfvi2",
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(dynamics.parameters(), lr=0.01)\n",
    "\n",
    "# Split dataset into training and validation\n",
    "train_ratio = 0.7\n",
    "train_data, validation_data = data[:int(train_ratio * len(data))], \\\n",
    "                              data[int(train_ratio * len(data)):]\n",
    "\n",
    "def compute_loss(model, data_t, loss_func = torch.nn.MSELoss()):\n",
    "    states, actions, next_states = data_t\n",
    "    predictions = model(states, actions)\n",
    "    return loss_func(predictions, next_states)\n",
    "\n",
    "def transpose_batch(batch):\n",
    "    return Transition(*map(torch.stack, zip(*batch)))\n",
    "\n",
    "def train(model, train_data, validation_data, epochs=1500):\n",
    "    train_data_t = transpose_batch(train_data)\n",
    "    validation_data_t = transpose_batch(validation_data)\n",
    "    losses = np.full((epochs, 2), np.nan)\n",
    "    for epoch in tnrange(epochs, desc=\"Train dynamics\"):\n",
    "        # Evaluate the loss on the training and validation sets\n",
    "        loss = compute_loss(model, train_data_t)\n",
    "        with torch.no_grad():\n",
    "            validation_loss = compute_loss(model, validation_data_t)\n",
    "        losses[epoch] = [loss.item(), validation_loss.item()]\n",
    "        # Backpropagate the training loss and step the optimizer\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    # Plot losses\n",
    "    plt.plot(losses)\n",
    "    plt.yscale(\"log\")\n",
    "    plt.xlabel(\"epochs\")\n",
    "    plt.ylabel(\"loss\")\n",
    "    plt.legend([\"training\", \"validation\"])\n",
    "    plt.show()\n",
    "\n",
    "train(dynamics, train_data, validation_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NXBODCuYfvi_",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Visualize trained dynamics\n",
    "\n",
    "In order to qualitatively evaluate our model, we can choose some values of steering angle *(right, center, left)* and acceleration *(slow, fast)* in order to predict and visualize the corresponding trajectories from an initial state.  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SMPA55bCfvjB",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def predict_trajectory(state, actions, model, action_repeat=1):\n",
    "    states = []\n",
    "    for action in actions:\n",
    "        for _ in range(action_repeat):\n",
    "            state = model(state, action)\n",
    "            states.append(state)\n",
    "    return torch.stack(states, dim=0)\n",
    "\n",
    "def plot_trajectory(states, color):\n",
    "    scales = np.array(highway_env.envs.parking_env.ParkingEnv.DEFAULT_CONFIG[\"observation\"][\"scales\"])\n",
    "    states = np.clip(states.squeeze(1).detach().numpy() * scales, -100, 100)\n",
    "    plt.plot(states[:, 0], states[:, 1], color=color, marker='.')\n",
    "    plt.arrow(states[-1,0], states[-1,1], states[-1,4]*1, states[-1,5]*1, color=color)\n",
    "\n",
    "def visualize_trajectories(model, state, horizon=15):\n",
    "    plt.cla()\n",
    "    # Draw a car\n",
    "    plt.plot(state.numpy()[0]+2.5*np.array([-1, -1, 1, 1, -1]),\n",
    "             state.numpy()[1]+1.0*np.array([-1, 1, 1, -1, -1]), 'k')\n",
    "    # Draw trajectories\n",
    "    state = state.unsqueeze(0)\n",
    "    colors = iter(plt.get_cmap(\"tab20\").colors)\n",
    "    # Generate commands\n",
    "    for steering in np.linspace(-0.5, 0.5, 3):\n",
    "        for acceleration in np.linspace(0.8, 0.4, 2):\n",
    "            actions = torch.Tensor([acceleration, steering]).view(1,1,-1)\n",
    "            # Predict trajectories\n",
    "            states = predict_trajectory(state, actions, model, action_repeat=horizon)\n",
    "            plot_trajectory(states, color=next(colors))\n",
    "    plt.axis(\"equal\")\n",
    "    plt.show()\n",
    "    \n",
    "visualize_trajectories(dynamics, state=torch.Tensor([0, 0, 0, 0, 1, 0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DOa0j1_muNXi"
   },
   "source": [
    "## Reward model\n",
    "We assume that the reward $R(s,a)$ is known (chosen by the system designer), and takes the form of a **weighted L1-norm** between the state and the goal."
   ]
  },
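  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the quantity computed by `reward_model` in the next cell can be written out as follows. The symbols $g$ (goal) and $w_i$ (entries of the environment's `REWARD_WEIGHTS`) are our notation for this sketch:\n",
    "\n",
    "$$R(s, g) = -\\Big(\\sum_{i} w_i \\, |s_i - g_i|\\Big)^{1/2}$$\n",
    "\n",
    "So the weighted L1 distance is raised to the power $0.5$ and negated, which makes the reward always non-positive and equal to zero exactly at the goal when the weights are positive. When a discount factor `gamma` is passed, the reward at step $t$ is additionally multiplied by $\\gamma^t$.\n",
    "\n",
    "As a purely illustrative example with made-up numbers, take two state dimensions with weights $w = (1, 0.5)$ and a difference $s - g = (0.2, -0.1)$. The weighted L1 distance is $1 \\times 0.2 + 0.5 \\times 0.1 = 0.25$, hence $R = -\\sqrt{0.25} = -0.5$."
   ]
  },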
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cRubRv9buNXj"
   },
   "outputs": [],
   "source": [
    "def reward_model(states, goal, gamma=None):\n",
    "    \"\"\"\n",
    "        The reward is a weighted L1-norm between the state and a goal\n",
    "    :param Tensor states: a batch of states. shape: [batch_size, state_size].\n",
    "    :param Tensor goal: a goal state. shape: [state_size].\n",
    "    :param float gamma: a discount factor\n",
    "    \"\"\"\n",
    "    goal = goal.expand(states.shape)\n",
    "    reward_weights = torch.Tensor(env.unwrapped.REWARD_WEIGHTS)\n",
    "    rewards = -torch.pow(torch.norm((states-goal)*reward_weights, p=1, dim=-1), 0.5)\n",
    "    if gamma:\n",
    "        time = torch.arange(rewards.shape[0], dtype=torch.float).unsqueeze(-1).expand(rewards.shape)\n",
    "        rewards *= torch.pow(gamma, time)\n",
    "    return rewards\n",
    "\n",
    "obs = env.reset()\n",
    "print(\"Reward of a sample transition:\", reward_model(torch.Tensor(obs[\"observation\"]).unsqueeze(0),\n",
    "                                                     torch.Tensor(obs[\"desired_goal\"])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Q5D6W4p7fvjI"
   },
   "source": [
    "## Leverage dynamics model for planning\n",
    "\n",
    "We now use the learnt dynamics model $f_\\theta$ for planning.\n",
    "In order to solve the optimal control problem, we use a sampling-based optimization algorithm: the **Cross-Entropy Method** (`CEM`). It is an optimization algorithm applicable to problems that are both **combinatorial** and **continuous**, which is our case: find the best performing sequence of actions.\n",
    "\n",
    "This method approximates the optimal importance sampling estimator by repeating two phases:\n",
    "1. **Draw samples** from a probability distribution. We use Gaussian distributions over sequences of actions.\n",
    "2. Minimize the **cross-entropy** between this distribution and a **target distribution** to produce better samples in the next iteration. We define this target distribution by selecting the top-k performing sampled sequences.\n",
    "\n",
    "![Credits to Olivier Sigaud](https://github.com/yfletberliac/rlss2019-hands-on/blob/master/imgs/cem.png?raw=1)\n",
    "\n",
    "Note that as we have a locally linear dynamics model, we could instead choose an `Iterative LQR` planner which would be more efficient. We prefer `CEM` in this educational setting for its simplicity and generality."
   ]
  },
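  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two phases concrete, here is a sketch of one CEM iteration as it is carried out by the `cem_planner` function in the next cell. The notation is ours: $H$ is the planning horizon, $P$ the population size, $K$ the number of selected sequences, and $\\mu_t, \\sigma_t$ the mean and standard deviation of the action distribution at planning step $t$, initialized to $\\mu_t = 0$ and $\\sigma_t = 1$.\n",
    "\n",
    "$$\\begin{aligned}\n",
    "a_t^{(i)} &= \\mathrm{clip}\\big(\\mu_t + \\sigma_t \\odot \\varepsilon_t^{(i)}\\big), \\quad \\varepsilon_t^{(i)} \\sim \\mathcal{N}(0, I), \\quad i = 1, \\dotsc, P, \\quad t = 0, \\dotsc, H-1 \\\\\n",
    "J^{(i)} &= \\sum_{k} R\\big(\\hat{s}_k^{(i)}, g\\big), \\quad \\text{with } \\hat{s}^{(i)} \\text{ predicted by } f_\\theta \\text{ under the actions } a^{(i)} \\\\\n",
    "E &= \\text{indices of the } K \\text{ largest returns } J^{(i)} \\\\\n",
    "\\mu_t &\\leftarrow \\frac{1}{K} \\sum_{i \\in E} a_t^{(i)}, \\qquad \\sigma_t \\leftarrow \\sqrt{\\frac{1}{K} \\sum_{i \\in E} \\big(a_t^{(i)} - \\mu_t\\big)^2}\n",
    "\\end{aligned}$$\n",
    "\n",
    "The square and square root in the last line are taken element-wise, and the clipping uses the bounds of the action space. In this implementation, the sum defining $J^{(i)}$ is undiscounted and runs over all predicted states, each planned action being applied for several consecutive model steps. After the last iteration, only the first mean action $\\mu_0$ is returned, and the planner is called again from the next observed state."
   ]
  },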
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bzPKYg23fvjL",
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def cem_planner(state, goal, action_size, horizon=5, population=100, selection=10, iterations=5):\n",
    "    state = state.expand(population, -1)\n",
    "    action_mean = torch.zeros(horizon, 1, action_size)\n",
    "    action_std = torch.ones(horizon, 1, action_size)\n",
    "    for _ in range(iterations):\n",
    "        # 1. Draw sample sequences of actions from a normal distribution\n",
    "        actions = action_mean + action_std * torch.randn(horizon, population, action_size)\n",
    "        actions = torch.clamp(actions, min=env.action_space.low.min(), max=env.action_space.high.max())\n",
    "        states = predict_trajectory(state, actions, dynamics, action_repeat=5)\n",
    "        # 2. Fit the distribution to the top-k performing sequences\n",
    "        returns = reward_model(states, goal).sum(dim=0)\n",
    "        _, best = returns.topk(selection, largest=True, sorted=False)\n",
    "        best_actions = actions[:, best, :]\n",
    "        action_mean, action_std = best_actions.mean(dim=1, keepdim=True), best_actions.std(dim=1, unbiased=False, keepdim=True)\n",
    "    return action_mean[0].squeeze(dim=0)\n",
    "  \n",
    "  \n",
    "# Run the planner on a sample transition\n",
    "action = cem_planner(torch.Tensor(obs[\"observation\"]),\n",
    "                     torch.Tensor(obs[\"desired_goal\"]),\n",
    "                     env.action_space.shape[0])\n",
    "print(\"Planned action:\", action)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_8L6vEPWyea7"
   },
   "source": [
    "## Visualize a few episodes\n",
    "\n",
    "All aboard!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xOcOP7Of18T2"
   },
   "outputs": [],
   "source": [
    "env = gym.make(\"parking-v0\")\n",
    "env = Monitor(env, './video', force=True, video_callable=lambda episode: True)\n",
    "for episode in tnrange(3, desc=\"Test episodes\"):\n",
    "    obs, done = env.reset(), False\n",
    "    while not done:\n",
    "        action = cem_planner(torch.Tensor(obs[\"observation\"]),\n",
    "                             torch.Tensor(obs[\"desired_goal\"]),\n",
    "                             env.action_space.shape[0])\n",
    "        obs, reward, done, info = env.step(action.numpy())\n",
    "env.close()\n",
    "show_video()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "psBBQIv4fvjT",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Limits\n",
    "\n",
    "### Model bias\n",
    "\n",
    "In model-based reinforcement learning, we replace our original optimal control problem by another problem: optimizing our learnt approximate MDP. When settling for this approximate MDP to plan with, we introduce a **bias** that can only **decrease the true performance** of the corresponding planned policy. This is called the problem of model bias.\n",
    "\n",
    "In some MDPs, even slight model errors lead to a dramatic drop in performance, as illustrated in the beginning of the following video:\n",
    "\n",
    "[![Approximate Robust Control of Uncertain Dynamical Systems](https://img.youtube.com/vi/8khqd3BJo0A/0.jpg)](https://www.youtube.com/watch?v=8khqd3BJo0A)\n",
    "\n",
    "The question of how to address model bias belongs to the field of **Safe Reinforcement Learning**. \n",
    "\n",
    "### [The call of the void (l'appel du vide)](https://www.urbandictionary.com/define.php?term=L%27appel%20du%20vide)\n",
    "\n",
    "The model will be accurate only in the region of the state space that was explored and covered in $D$.\n",
    "Outside of $D$, the model may diverge and **hallucinate** high rewards.\n",
    "This effect is problematic when the model is used by a planning algorithm, as the latter will try to **exploit** these hallucinated high rewards and will steer the agent towards **unknown** (and thus dangerous) **regions** where the model is erroneously optimistic.\n",
    "\n",
    "### Computational cost of planning\n",
    "\n",
    "At test time, the planning step typically requires **sampling a lot of trajectories** to find a near-optimal candidate, which may turn out to be very costly. This may be prohibitive in a high-frequency real-time setting. **Model-free** methods, which directly recommend the best action, are **much more efficient** in that regard."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "parking_model_based.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
