{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using the PyTorch JIT Compiler with Pyro\n",
    "\n",
    "This tutorial shows how to use the PyTorch [jit compiler](https://pytorch.org/docs/master/jit.html) in Pyro models.\n",
    "\n",
    "#### Summary:\n",
    "- You can use compiled functions in Pyro models.\n",
    "- You cannot use pyro primitives inside compiled functions.\n",
    "- If your model has static structure, you can use a `Jit*` version of an `ELBO` algorithm, e.g.\n",
    "  ```diff\n",
    "  - Trace_ELBO()\n",
    "  + JitTrace_ELBO()\n",
    "  ```\n",
    "- The [HMC](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.HMC) and [NUTS](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.NUTS) classes accept a `jit_compile=True` kwarg.\n",
    "- Models should pass all tensor inputs as `*args` and all non-tensor inputs as `**kwargs`.\n",
    "- Each different value of `**kwargs` triggers a separate compilation.\n",
    "- Use `**kwargs` to specify all variation in structure (e.g. time series length).\n",
    "- To ignore jit warnings in safe code blocks, use `with pyro.util.ignore_jit_warnings():`.\n",
    "- To ignore all jit warnings in `HMC` or `NUTS`, pass `ignore_jit_warnings=True`.\n",
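    "\n",
    "A minimal sketch of the `pyro.util.ignore_jit_warnings()` bullet above, using an illustrative model rather than one from Pyro: under the jit, reading a tensor's size as a Python `int` may emit a `TracerWarning`. Assuming the data size never changes between calls with the same `**kwargs`, that warning is safe to silence locally:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro import poutine\n",
    "from pyro.util import ignore_jit_warnings\n",
    "\n",
    "def model(data):\n",
    "    loc = pyro.sample('loc', dist.Normal(0., 10.))\n",
    "    # Under the jit, int(...) on a size may emit a TracerWarning. We assume the\n",
    "    # data size is fixed across calls, so silencing it here is treated as safe.\n",
    "    with ignore_jit_warnings():\n",
    "        num_points = int(data.size(0))\n",
    "    with pyro.plate('data', num_points):\n",
    "        pyro.sample('obs', dist.Normal(loc, 1.), obs=data)\n",
    "\n",
    "# Eager usage: the context manager is harmless outside the jit as well.\n",
    "trace = poutine.trace(model).get_trace(torch.randn(10))\n",
    "```\n",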
    "\n",
    "#### Table of contents\n",
    "- [Introduction](#Introduction)\n",
    "- [A simple model](#A-simple-model)\n",
    "- [Varying structure](#Varying-structure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from torch.distributions import constraints\n",
    "from pyro import poutine\n",
    "from pyro.distributions.util import broadcast_shape\n",
    "from pyro.infer import Trace_ELBO, JitTrace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO, SVI\n",
    "from pyro.infer.mcmc import MCMC, NUTS\n",
    "from pyro.infer.autoguide import AutoDiagonalNormal\n",
    "from pyro.optim import Adam\n",
    "\n",
    "smoke_test = ('CI' in os.environ)\n",
    "assert pyro.__version__.startswith('0.5.0')\n",
    "pyro.enable_validation(True)    # <---- This is always a good idea!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Introduction\n",
    "\n",
    "PyTorch 1.0 includes a [jit compiler](https://pytorch.org/docs/master/jit.html) to speed up models. You can think of compilation as a \"static mode\", whereas PyTorch usually operates in \"eager mode\".\n",
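    "\n",
    "To make the static-mode idea concrete, here is a small plain-PyTorch sketch (the `total` function is illustrative, not part of Pyro). `torch.jit.trace` records the operations executed for one example input, so Python control flow such as a loop over the input length is frozen into the compiled function. PyTorch will likely emit a `TracerWarning` while tracing it. This is why inputs that determine structure, such as sequence lengths, need special handling:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def total(x):\n",
    "    # The loop's trip count depends on the input length.\n",
    "    result = torch.zeros(())\n",
    "    for t in range(x.size(0)):\n",
    "        result = result + x[t]\n",
    "    return result\n",
    "\n",
    "# Trace with a length-3 example input: the loop is unrolled three times.\n",
    "traced_total = torch.jit.trace(total, (torch.ones(3),))\n",
    "\n",
    "eager_out = total(torch.ones(5))          # eager mode sums all 5 entries\n",
    "traced_out = traced_total(torch.ones(5))  # traced version still sums only 3\n",
    "```\n",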
    "\n",
    "Pyro supports the jit compiler in two ways. First, you can use compiled functions inside Pyro models (but those functions cannot contain Pyro primitives). Second, you can use Pyro's jit inference algorithms to compile entire inference steps; in static models this can reduce the Python overhead of Pyro models and speed up inference.\n",
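    "\n",
    "As a minimal sketch of the first approach (the `positive_scale` helper and the model below are illustrative assumptions), a pure tensor function is compiled with `torch.jit.trace` and then called from an ordinary, uncompiled Pyro model. All `pyro.sample` statements stay outside the compiled function:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro import poutine\n",
    "\n",
    "def positive_scale(raw):\n",
    "    # Pure tensor math only: no Pyro primitives inside compiled code.\n",
    "    return torch.nn.functional.softplus(raw) + 1e-3\n",
    "\n",
    "# Hypothetical compiled helper, traced once with a scalar example input.\n",
    "compiled_scale = torch.jit.trace(positive_scale, (torch.tensor(0.),))\n",
    "\n",
    "def model(data):\n",
    "    loc = pyro.sample('loc', dist.Normal(0., 10.))\n",
    "    raw_scale = pyro.sample('raw_scale', dist.Normal(0., 1.))\n",
    "    scale = compiled_scale(raw_scale)  # compiled function called from eager Pyro code\n",
    "    with pyro.plate('data', data.size(0)):\n",
    "        pyro.sample('obs', dist.Normal(loc, scale), obs=data)\n",
    "\n",
    "trace = poutine.trace(model).get_trace(torch.randn(20))\n",
    "```\n",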
    "\n",
    "The rest of this tutorial focuses on Pyro's jitted inference algorithms: [JitTrace_ELBO](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.trace_elbo.JitTrace_ELBO), [JitTraceGraph_ELBO](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.tracegraph_elbo.JitTraceGraph_ELBO), [JitTraceEnum_ELBO](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.traceenum_elbo.JitTraceEnum_ELBO), [JitTraceMeanField_ELBO](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.trace_mean_field_elbo.JitTraceMeanField_ELBO), [HMC(jit_compile=True)](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.HMC), and [NUTS(jit_compile=True)](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.NUTS). For further reading, see the [examples/](https://github.com/uber/pyro/tree/dev/examples) directory, where most examples include a `--jit` option to run in compiled mode.\n",
    "\n",
    "## A simple model\n",
    "\n",
    "Let's start with a simple Gaussian model and an [autoguide](http://docs.pyro.ai/en/dev/infer.autoguide.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(data):\n",
    "    loc = pyro.sample(\"loc\", dist.Normal(0., 10.))\n",
    "    scale = pyro.sample(\"scale\", dist.LogNormal(0., 3.))\n",
    "    with pyro.plate(\"data\", data.size(0)):\n",
    "        pyro.sample(\"obs\", dist.Normal(loc, scale), obs=data)\n",
    "\n",
    "guide = AutoDiagonalNormal(model)\n",
    "\n",
    "data = dist.Normal(0.5, 2.).sample((100,))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's run as usual with an SVI object and `Trace_ELBO`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.71 s, sys: 31.4 ms, total: 2.74 s\n",
      "Wall time: 2.76 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "pyro.clear_param_store()\n",
    "elbo = Trace_ELBO()\n",
    "svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)\n",
    "for i in range(2 if smoke_test else 1000):\n",
    "    svi.step(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, to run jit-compiled inference, we simply replace\n",
    "```diff\n",
    "- elbo = Trace_ELBO()\n",
    "+ elbo = JitTrace_ELBO()\n",
    "```\n",
    "Also note that the `AutoDiagonalNormal` guide behaves a little differently on its first invocation (it runs the model to produce a prototype trace), and we don't want to record this warmup behavior when compiling. Thus we call `guide(data)` once to initialize it, then run the compiled SVI:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.1 s, sys: 30.4 ms, total: 1.13 s\n",
      "Wall time: 1.16 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "pyro.clear_param_store()\n",
    "\n",
    "guide(data)  # Do any lazy initialization before compiling.\n",
    "\n",
    "elbo = JitTrace_ELBO()\n",
    "svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)\n",
    "for i in range(2 if smoke_test else 1000):\n",
    "    svi.step(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that we get a more than 2x speedup for this small model (1.16 s vs 2.76 s wall time).\n",
    "\n",
    "Let's now use the same model, but instead use MCMC to generate samples from the model's posterior. We will use the No-U-Turn Sampler (NUTS)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3dc474967f304f56a22df195ce1ed06f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Warmup', style=ProgressStyle(description_width='initial')), H…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "CPU times: user 4.61 s, sys: 101 ms, total: 4.71 s\n",
      "Wall time: 4.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "nuts_kernel = NUTS(model)\n",
    "pyro.set_rng_seed(1)\n",
    "mcmc_run = MCMC(nuts_kernel, num_samples=100).run(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can compile the potential energy computation in NUTS by passing `jit_compile=True` to the NUTS kernel. We also silence JIT warnings caused by tensor constants in the model by passing `ignore_jit_warnings=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "134dac15856941dfaf427a9e6089f7e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Warmup', style=ProgressStyle(description_width='initial')), H…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "CPU times: user 2.04 s, sys: 74.1 ms, total: 2.11 s\n",
      "Wall time: 2.09 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)\n",
    "pyro.set_rng_seed(1)\n",
    "mcmc_run = MCMC(nuts_kernel, num_samples=100).run(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Enabling JIT compilation roughly doubles sampling throughput here: the same MCMC run takes 2.09 s of wall time instead of 4.7 s."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Varying structure\n",
    "\n",
    "Time series models often run on datasets of multiple time series with different lengths. To accommodate varying structure like this, Pyro requires models to separate all model inputs into tensors and non-tensors.$^\\dagger$\n",
    "\n",
    "- Non-tensor inputs should be passed as `**kwargs` to the model and guide. These can determine model structure, so a separate model is compiled for each distinct value of the passed `**kwargs`.\n",
    "- Tensor inputs should be passed as `*args`. These must not determine model structure. However, `len(args)` may determine model structure (as is done, e.g., in semi-supervised models).\n",
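    "\n",
    "The rule above can be pictured with a hypothetical, simplified cache (this is an illustration, not Pyro's actual implementation): a compiled step is looked up by `len(args)` plus the sorted `**kwargs`, and a new compilation happens only on a cache miss. Fed the same sequence lengths used below, it compiles three times:\n",
    "\n",
    "```python\n",
    "class CompileCache:\n",
    "    # Hypothetical stand-in for a Jit* ELBO's per-structure compile cache.\n",
    "\n",
    "    def __init__(self, compile_fn):\n",
    "        self.compile_fn = compile_fn\n",
    "        self.compiled = {}\n",
    "        self.num_compilations = 0\n",
    "\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        # Tensor args contribute only their count; non-tensor kwargs form the key.\n",
    "        key = (len(args), tuple(sorted(kwargs.items())))\n",
    "        if key not in self.compiled:\n",
    "            self.num_compilations += 1\n",
    "            self.compiled[key] = self.compile_fn(**kwargs)\n",
    "        return self.compiled[key](*args)\n",
    "\n",
    "def fake_compile(num_sequences, length):\n",
    "    # Stand-in for tracing: returns a function specialized to `length`.\n",
    "    return lambda sequence: sum(sequence[:length]) / num_sequences\n",
    "\n",
    "lengths = [24] * 50 + [48] * 20 + [72] * 5\n",
    "step = CompileCache(fake_compile)\n",
    "for length in lengths:\n",
    "    step([0.0] * length, num_sequences=len(lengths), length=length)\n",
    "```\n",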
    "\n",
    "To illustrate this with a time series model, we will pass in a sequence of observations as a tensor `arg` and the sequence length as a non-tensor `kwarg`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(sequence, num_sequences, length, state_dim=16):\n",
    "    # This is a Gaussian HMM model.\n",
    "    with pyro.plate(\"states\", state_dim):\n",
    "        trans = pyro.sample(\"trans\", dist.Dirichlet(0.5 * torch.ones(state_dim)))\n",
    "        emit_loc = pyro.sample(\"emit_loc\", dist.Normal(0., 10.))\n",
    "    emit_scale = pyro.sample(\"emit_scale\", dist.LogNormal(0., 3.))\n",
    "\n",
    "    # We're doing manual data subsampling, so we need to scale to actual data size.\n",
    "    with poutine.scale(scale=num_sequences):\n",
    "        # We'll use enumeration inference over the hidden x.\n",
    "        x = 0\n",
    "        for t in pyro.markov(range(length)):\n",
    "            x = pyro.sample(\"x_{}\".format(t), dist.Categorical(trans[x]),\n",
    "                            infer={\"enumerate\": \"parallel\"})\n",
    "            pyro.sample(\"y_{}\".format(t), dist.Normal(emit_loc[x], emit_scale),\n",
    "                        obs=sequence[t])\n",
    "\n",
    "guide = AutoDiagonalNormal(poutine.block(model, expose=[\"trans\", \"emit_scale\", \"emit_loc\"]))\n",
    "\n",
    "# This is fake data of different lengths.\n",
    "lengths = [24] * 50 + [48] * 20 + [72] * 5\n",
    "sequences = [torch.randn(length) for length in lengths]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's run SVI as usual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 52.4 s, sys: 270 ms, total: 52.7 s\n",
      "Wall time: 52.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "pyro.clear_param_store()\n",
    "elbo = TraceEnum_ELBO(max_plate_nesting=1)\n",
    "svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)\n",
    "for i in range(1 if smoke_test else 10):\n",
    "    for sequence in sequences:\n",
    "        svi.step(sequence,                                            # tensor args\n",
    "                 num_sequences=len(sequences), length=len(sequence))  # non-tensor args"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again we'll simply swap in a `Jit*` implementation:\n",
    "```diff\n",
    "- elbo = TraceEnum_ELBO(max_plate_nesting=1)\n",
    "+ elbo = JitTraceEnum_ELBO(max_plate_nesting=1)\n",
    "```\n",
    "Note that we are manually specifying the `max_plate_nesting` arg. Usually Pyro can figure this out automatically by running the model once on the first invocation; however, to avoid this extra work when the compiler runs on the first step, we pass it in manually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 21.9 s, sys: 201 ms, total: 22.1 s\n",
      "Wall time: 22.2 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "pyro.clear_param_store()\n",
    "\n",
    "# Do any lazy initialization before compiling.\n",
    "guide(sequences[0], num_sequences=len(sequences), length=len(sequences[0]))\n",
    "\n",
    "elbo = JitTraceEnum_ELBO(max_plate_nesting=1)\n",
    "svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)\n",
    "for i in range(1 if smoke_test else 10):\n",
    "    for sequence in sequences:\n",
    "        svi.step(sequence,                                            # tensor args\n",
    "                 num_sequences=len(sequences), length=len(sequence))  # non-tensor args"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again we see a more than 2x speedup (22.2 s vs 52.8 s wall time). Note that since there were three different sequence lengths, compilation was triggered three times.\n",
    "\n",
    "$^\\dagger$ Note that this section applies only to SVI; HMC/NUTS assume fixed model arguments."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
