{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "parking_her.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.5"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5eeje4O8fviH",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "# Parking with Hindsight Experience Replay\n",
        "\n",
        "##  Warming up\n",
        "We start with a few useful installs and imports:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "bzMSuJEOfviP",
        "pycharm": {
          "is_executing": false,
          "name": "#%%\n"
        },
        "colab": {}
      },
      "source": [
        "# Install environment and agent\n",
        "!pip install git+https://github.com/eleurent/highway-env#egg=highway-env  > /dev/null 2>&1\n",
        "!pip install stable-baselines==2.6.0 > /dev/null 2>&1\n",
        "\n",
        "# Environment\n",
        "import gym\n",
        "import highway_env\n",
        "\n",
        "# Agent\n",
        "from stable_baselines import HER, SAC"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "_wACJRDjqP-f",
        "colab_type": "text"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "pycharm": {
          "name": "#%% \n"
        },
        "id": "Y5TOvonYqP-g",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "env = gym.make(\"parking-ActionRepeat-v0\")\n",
        "model = HER('MlpPolicy', env, SAC, n_sampled_goal=4,\n",
        "            goal_selection_strategy='future',\n",
        "            verbose=1, buffer_size=int(1e6),\n",
        "            learning_rate=1e-3,\n",
        "            gamma=0.9, batch_size=256,\n",
        "            policy_kwargs=dict(layers=[256, 256, 256]))\n",
        "model.learn(int(2e4))\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "n2Bu_Pqop0E7"
      },
      "source": [
        "## Visualize a few episodes\n",
        "\n",
        "We first define a simple helper function for visualization of episodes:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "so7yH4ucyB-3",
        "colab": {}
      },
      "source": [
        "!pip install gym pyvirtualdisplay > /dev/null 2>&1\n",
        "!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n",
        "\n",
        "from IPython import display as ipythondisplay\n",
        "from pyvirtualdisplay import Display\n",
        "from gym.wrappers import Monitor\n",
        "from pathlib import Path\n",
        "import base64\n",
        "from tqdm import tnrange\n",
        "\n",
        "display = Display(visible=0, size=(1400, 900))\n",
        "display.start()\n",
        "\n",
        "def show_video():\n",
        "    html = []\n",
        "    for mp4 in Path(\"video\").glob(\"*.mp4\"):\n",
        "        video_b64 = base64.b64encode(mp4.read_bytes())\n",
        "        html.append('''<video alt=\"{}\" autoplay \n",
        "                      loop controls style=\"height: 400px;\">\n",
        "                      <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n",
        "                 </video>'''.format(mp4, video_b64.decode('ascii')))\n",
        "    ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_8L6vEPWyea7"
      },
      "source": [
        "\n",
        "Test the policy"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "xOcOP7Of18T2",
        "pycharm": {
          "name": "#%%\n"
        },
        "colab": {}
      },
      "source": [
        "env = gym.make(\"parking-ActionRepeat-v0\")\n",
        "env = Monitor(env, './video', force=True, video_callable=lambda episode: True)\n",
        "for episode in tnrange(3, desc=\"Test episodes\"):\n",
        "    obs, done = env.reset(), False\n",
        "    env.unwrapped.automatic_rendering_callback = env.video_recorder.capture_frame\n",
        "    while not done:\n",
        "        action, _ = model.predict(obs)\n",
        "        obs, reward, done, info = env.step(action)\n",
        "env.close()\n",
        "show_video()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}