{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference with Discrete Latent Variables\n",
    "\n",
    "This tutorial describes Pyro's enumeration strategy for discrete latent variable models.\n",
    "This tutorial assumes the reader is already familiar with the [Tensor Shapes Tutorial](http://pyro.ai/examples/tensor_shapes.html).\n",
    "\n",
    "#### Summary \n",
    "\n",
    "- Pyro implements automatic enumeration over discrete latent variables.\n",
    "- This strategy can be used alone or inside SVI (via [TraceEnum_ELBO](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.traceenum_elbo.TraceEnum_ELBO)), HMC, or NUTS.\n",
    "- The standalone [infer_discrete](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.discrete.infer_discrete) can generate samples or MAP estimates.\n",
    "- Annotate a sample site with `infer={\"enumerate\": \"parallel\"}` to trigger enumeration.\n",
    "- If a sample site determines downstream structure, instead use `{\"enumerate\": \"sequential\"}`.\n",
    "- Write your models to allow arbitrarily deep batching on the left, e.g. use broadcasting.\n",
    "- Inference cost is exponential in treewidth, so try to write models with narrow treewidth.\n",
    "- If you have trouble, ask for help on [forum.pyro.ai](https://forum.pyro.ai)!\n",
    "\n",
    "#### Table of contents\n",
    "\n",
    "- [Overview](#Overview)\n",
    "- [Mechanics of enumeration](#Mechanics-of-enumeration)\n",
    "  - [Multiple latent variables](#Multiple-latent-variables)\n",
    "  - [Examining discrete latent states](#Examining-discrete-latent-states)\n",
    "  - [Indexing with enumerated variables](#Indexing-with-enumerated-variables)\n",
    "- [Plates and enumeration](#Plates-and-enumeration)\n",
    "  - [Dependencies among plates](#Dependencies-among-plates)\n",
    "- [Time series example](#Time-series-example)\n",
    "  - [How to enumerate more than 25 variables](#How-to-enumerate-more-than-25-variables)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from torch.distributions import constraints\n",
    "from pyro import poutine\n",
    "from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete\n",
    "from pyro.infer.autoguide import AutoDiagonalNormal\n",
    "from pyro.ops.indexing import Vindex\n",
    "\n",
    "smoke_test = ('CI' in os.environ)\n",
    "assert pyro.__version__.startswith('0.5.0')\n",
    "pyro.enable_validation()\n",
    "pyro.set_rng_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview <a class=\"anchor\" id=\"Overview\"></a>\n",
    "\n",
    "Pyro's enumeration strategy encompasses popular algorithms including variable elimination, exact message passing, forward-filter-backward-sample, inside-outside, Baum-Welch, and many other special-case algorithms. Aside from enumeration, Pyro implements a number of inference strategies including variational inference ([SVI](http://docs.pyro.ai/en/dev/inference_algos.html)) and Monte Carlo ([HMC](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.HMC) and [NUTS](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.NUTS)). Enumeration can be used either as a stand-alone strategy via [infer_discrete](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.discrete.infer_discrete), or as a component of other strategies. Thus enumeration allows Pyro to marginalize out discrete latent variables in HMC and SVI models, and to use variational enumeration of discrete variables in SVI guides."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mechanics of enumeration  <a class=\"anchor\" id=\"Mechanics-of-enumeration\"></a>\n",
    "\n",
    "The core idea of enumeration is to interpret discrete [pyro.sample](http://docs.pyro.ai/en/dev/primitives.html#pyro.sample) statements as full enumeration rather than random sampling. Other inference algorithms can then sum out the enumerated values. For example, a sample statement might return a tensor of scalar shape under the standard \"sample\" interpretation (we'll illustrate with a trivial model and guide):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "guide z = 4\n",
      "model z = 4\n"
     ]
    }
   ],
   "source": [
    "def model():\n",
    "    z = pyro.sample(\"z\", dist.Categorical(torch.ones(5)))\n",
    "    print('model z = {}'.format(z))\n",
    "\n",
    "def guide():\n",
    "    z = pyro.sample(\"z\", dist.Categorical(torch.ones(5)))\n",
    "    print('guide z = {}'.format(z))\n",
    "\n",
    "elbo = Trace_ELBO()\n",
    "elbo.loss(model, guide);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, under the enumeration interpretation, the same sample site will return a fully enumerated set of values, based on its distribution's [.enumerate_support()](https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.enumerate_support) method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "guide z = tensor([0, 1, 2, 3, 4])\n",
      "model z = tensor([0, 1, 2, 3, 4])\n"
     ]
    }
   ],
   "source": [
    "elbo = TraceEnum_ELBO(max_plate_nesting=0)\n",
    "elbo.loss(model, config_enumerate(guide, \"parallel\"));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that we've used \"parallel\" enumeration to enumerate along a new tensor dimension. This is cheap and allows Pyro to parallelize computation, but requires downstream program structure to avoid branching on the value of `z`. To support dynamic program structure, you can instead use \"sequential\" enumeration, which runs the entire model-guide pair once per enumerated value and is therefore more expensive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "guide z = 4\n",
      "model z = 4\n",
      "guide z = 3\n",
      "model z = 3\n",
      "guide z = 2\n",
      "model z = 2\n",
      "guide z = 1\n",
      "model z = 1\n",
      "guide z = 0\n",
      "model z = 0\n"
     ]
    }
   ],
   "source": [
    "elbo = TraceEnum_ELBO(max_plate_nesting=0)\n",
    "elbo.loss(model, config_enumerate(guide, \"sequential\"));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Parallel enumeration is cheaper but more complex than sequential enumeration, so we'll focus the rest of this tutorial on the parallel variant. Note that both forms can be interleaved.\n",
    "\n",
    "### Multiple latent variables <a class=\"anchor\" id=\"Multiple-latent-variables\"></a>\n",
    "\n",
    "We just saw that a single discrete sample site can be enumerated via nonstandard interpretation. A model with a single discrete latent variable is a mixture model. Models with multiple discrete latent variables can be more complex, including HMMs, CRFs, DBNs, and other structured models. In models with multiple discrete latent variables, Pyro enumerates each variable in a different tensor dimension (counting from the right; see [Tensor Shapes Tutorial](http://pyro.ai/examples/tensor_shapes.html)). This allows Pyro to determine the dependency graph among variables and then perform cheap exact inference using variable elimination algorithms.\n",
    "\n",
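    "As a sketch of why this helps (the number of states $K$ and the function $f$ below are our own illustrative notation, not Pyro API), consider a chain of three discrete variables whose joint distribution factorizes as $p(x)\\,p(y \\mid x)\\,p(z \\mid y)$. Because each factor involves at most two variables, the sum over all joint states can be pushed inward one variable at a time:\n",
    "\n",
    "$$\\sum_{x,y,z} p(x)\\,p(y \\mid x)\\,p(z \\mid y)\\,f(z) = \\sum_{x} p(x) \\sum_{y} p(y \\mid x) \\sum_{z} p(z \\mid y)\\,f(z)$$\n",
    "\n",
    "If each variable takes $K$ values, the left-hand side has $K^3$ terms, whereas evaluating the right-hand side from the innermost sum outward costs on the order of $K^2$ operations per eliminated variable.\n",
    "\n",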
    "To understand enumeration dimension allocation, consider the following model, where we collapse variables out of the model rather than enumerate them in the guide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model x.shape = torch.Size([3])\n",
      "model y.shape = torch.Size([3, 1])\n",
      "model z.shape = torch.Size([3, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "@config_enumerate\n",
    "def model():\n",
    "    p = pyro.param(\"p\", torch.randn(3, 3).exp(), constraint=constraints.simplex)\n",
    "    x = pyro.sample(\"x\", dist.Categorical(p[0]))\n",
    "    y = pyro.sample(\"y\", dist.Categorical(p[x]))\n",
    "    z = pyro.sample(\"z\", dist.Categorical(p[y]))\n",
    "    print('model x.shape = {}'.format(x.shape))\n",
    "    print('model y.shape = {}'.format(y.shape))\n",
    "    print('model z.shape = {}'.format(z.shape))\n",
    "    return x, y, z\n",
    "    \n",
    "def guide():\n",
    "    pass\n",
    "\n",
    "pyro.clear_param_store()\n",
    "elbo = TraceEnum_ELBO(max_plate_nesting=0)\n",
    "elbo.loss(model, guide);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Examining discrete latent states <a class=\"anchor\" id=\"Examining-discrete-latent-states\"></a>\n",
    "\n",
    "While enumeration in SVI allows fast learning of parameters like `p` above, it does not give access to predicted values of the discrete latent variables like `x,y,z` above. We can access these using a standalone [infer_discrete](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.discrete.infer_discrete) handler. In this case the guide was trivial, so we can simply wrap the model in `infer_discrete`. We need to pass a `first_available_dim` argument to tell `infer_discrete` which dimensions are available for enumeration; this is related to the `max_plate_nesting` arg of `TraceEnum_ELBO` via\n",
    "```\n",
    "first_available_dim = -1 - max_plate_nesting\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model x.shape = torch.Size([3])\n",
      "model y.shape = torch.Size([3, 1])\n",
      "model z.shape = torch.Size([3, 1, 1])\n",
      "model x.shape = torch.Size([])\n",
      "model y.shape = torch.Size([])\n",
      "model z.shape = torch.Size([])\n",
      "x = 0\n",
      "y = 2\n",
      "z = 1\n"
     ]
    }
   ],
   "source": [
    "serving_model = infer_discrete(model, first_available_dim=-1)\n",
    "x, y, z = serving_model()  # takes the same args as model(), here no args\n",
    "print(\"x = {}\".format(x))\n",
    "print(\"y = {}\".format(y))\n",
    "print(\"z = {}\".format(z))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that under the hood `infer_discrete` runs the model twice: first in forward-filter mode where sites are enumerated, then in replay-backward-sample mode where sites are sampled. `infer_discrete` can also perform MAP inference by passing `temperature=0`. Note that while `infer_discrete` produces correct posterior samples, it does not currently produce correct logprobs, and should not be used in other gradient-based inference algorithms."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Indexing with enumerated variables <a class=\"anchor\" id=\"Indexing-with-enumerated-variables\"></a>\n",
    "\n",
    "It can be tricky to use [advanced indexing](https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html) to select an element of a tensor using one or more enumerated variables. For example, suppose a plated random variable `z` depends on two different random variables:\n",
    "```py\n",
    "p = pyro.param(\"p\", torch.randn(5, 4, 3, 2).exp(),\n",
    "               constraint=constraints.simplex)\n",
    "x = pyro.sample(\"x\", dist.Categorical(torch.ones(4)))\n",
    "y = pyro.sample(\"y\", dist.Categorical(torch.ones(3)))\n",
    "with pyro.plate(\"z_plate\", 5):\n",
    "    p_xy = p[..., x, y, :]  # Not compatible with enumeration!\n",
    "    z = pyro.sample(\"z\", dist.Categorical(p_xy))\n",
    "```\n",
    "Due to advanced indexing semantics, the expression `p[..., x, y, :]` will work correctly without enumeration, but is incorrect when `x` or `y` is enumerated. Pyro provides a simple way to index correctly, but first let's see how to correctly index using PyTorch's advanced indexing without Pyro:\n",
    "```py\n",
    "# Compatible with enumeration, but not recommended:\n",
    "p_xy = p[torch.arange(5, device=p.device).reshape(5, 1),\n",
    "         x.unsqueeze(-1),\n",
    "         y.unsqueeze(-1),\n",
    "         torch.arange(2, device=p.device)]\n",
    "```\n",
    "Pyro provides a helper [Vindex()[]](http://docs.pyro.ai/en/dev/ops.html#pyro.ops.indexing.Vindex) to use enumeration-compatible advanced indexing semantics rather than PyTorch/NumPy semantics. `Vindex()[]` makes the `.__getitem__()` operator broadcast like other familiar operators `+`, `*` etc. Using `Vindex()[]` we can write the same expression as if `x` and `y` were numbers (i.e. not enumerated):\n",
    "```py\n",
    "# Recommended syntax compatible with enumeration:\n",
    "p_xy = Vindex(p)[..., x, y, :]\n",
    "```\n",
    "Here is a complete example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   p.shape = torch.Size([5, 4, 3, 2])\n",
      "   x.shape = torch.Size([4, 1])\n",
      "   y.shape = torch.Size([3, 1, 1])\n",
      "p_xy.shape = torch.Size([3, 4, 5, 2])\n",
      "   z.shape = torch.Size([2, 1, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "@config_enumerate\n",
    "def model():\n",
    "    p = pyro.param(\"p\", torch.randn(5, 4, 3, 2).exp(), constraint=constraints.simplex)\n",
    "    x = pyro.sample(\"x\", dist.Categorical(torch.ones(4)))\n",
    "    y = pyro.sample(\"y\", dist.Categorical(torch.ones(3)))\n",
    "    with pyro.plate(\"z_plate\", 5):\n",
    "        p_xy = Vindex(p)[..., x, y, :]\n",
    "        z = pyro.sample(\"z\", dist.Categorical(p_xy))\n",
    "    print('   p.shape = {}'.format(p.shape))\n",
    "    print('   x.shape = {}'.format(x.shape))\n",
    "    print('   y.shape = {}'.format(y.shape))\n",
    "    print('p_xy.shape = {}'.format(p_xy.shape))\n",
    "    print('   z.shape = {}'.format(z.shape))\n",
    "    return x, y, z\n",
    "    \n",
    "def guide():\n",
    "    pass\n",
    "\n",
    "pyro.clear_param_store()\n",
    "elbo = TraceEnum_ELBO(max_plate_nesting=1)\n",
    "elbo.loss(model, guide);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plates and enumeration <a class=\"anchor\" id=\"Plates-and-enumeration\"></a>\n",
    "\n",
    "Pyro [plates](http://docs.pyro.ai/en/dev/primitives.html#pyro.plate) express conditional independence among random variables. Pyro's enumeration strategy can take advantage of plates to reduce the high cost (exponential in the size of the plate) of enumerating a cartesian product down to a low cost (linear in the size of the plate) of enumerating conditionally independent random variables in lock-step. This is especially important for e.g. minibatched data.\n",
    "\n",
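    "To make the saving concrete (a sketch in our own notation: $N$ is the plate size, $K$ the number of values of each enumerated variable $x_n$, and $y_n$ the corresponding observation), conditional independence lets the sum over the Cartesian product factorize into a product of per-element sums:\n",
    "\n",
    "$$\\sum_{x_1,\\dots,x_N} \\prod_{n=1}^{N} p(x_n)\\,p(y_n \\mid x_n) = \\prod_{n=1}^{N} \\sum_{x_n} p(x_n)\\,p(y_n \\mid x_n)$$\n",
    "\n",
    "The left-hand side sums over $K^N$ joint configurations, while the right-hand side needs only $N K$ terms, all of which can be computed in one batched tensor operation.\n",
    "\n",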
    "To illustrate, consider a Gaussian mixture model with a shared variance and different means."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running model with 10 data points\n",
      "x.shape = torch.Size([10])\n",
      "dist.Normal(loc[x], scale).batch_shape = torch.Size([10])\n",
      "Running model with 10 data points\n",
      "x.shape = torch.Size([3, 1])\n",
      "dist.Normal(loc[x], scale).batch_shape = torch.Size([3, 1])\n"
     ]
    }
   ],
   "source": [
    "@config_enumerate\n",
    "def model(data, num_components=3):\n",
    "    print('Running model with {} data points'.format(len(data)))\n",
    "    p = pyro.sample(\"p\", dist.Dirichlet(0.5 * torch.ones(num_components)))\n",
    "    scale = pyro.sample(\"scale\", dist.LogNormal(0, num_components))\n",
    "    with pyro.plate(\"components\", num_components):\n",
    "        loc = pyro.sample(\"loc\", dist.Normal(0, 10))\n",
    "    with pyro.plate(\"data\", len(data)):\n",
    "        x = pyro.sample(\"x\", dist.Categorical(p))\n",
    "        print(\"x.shape = {}\".format(x.shape))\n",
    "        pyro.sample(\"obs\", dist.Normal(loc[x], scale), obs=data)\n",
    "        print(\"dist.Normal(loc[x], scale).batch_shape = {}\".format(\n",
    "            dist.Normal(loc[x], scale).batch_shape))\n",
    "        \n",
    "guide = AutoDiagonalNormal(poutine.block(model, hide=[\"x\", \"data\"]))\n",
    "\n",
    "data = torch.randn(10)\n",
    "        \n",
    "pyro.clear_param_store()\n",
    "elbo = TraceEnum_ELBO(max_plate_nesting=1)\n",
    "elbo.loss(model, guide, data);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Observe that the model is run twice, first by the `AutoDiagonalNormal` to trace sample sites, and second by `elbo` to compute the loss. In the first run, `x` has the standard interpretation of one sample per datum, hence shape `(10,)`. In the second run, enumeration can use the same three values, with shape `(3,1)`, for all data points, and relies on broadcasting for any downstream sample or observe sites that depend on data. For example, in the `pyro.sample(\"obs\",...)` statement, the distribution has shape `(3,1)`, the data has shape `(10,)`, and the broadcasted log probability tensor has shape `(3,10)`.\n",
    "\n",
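    "In equation form (a sketch using our own symbols: $y_n$ for the $n$-th datum, $k$ for the enumerated value of `x`, and $\\sigma$ for `scale`), combining the `(3,1)` log probability of `x` with the `(3,10)` log probability of `obs` and then summing out `x` amounts to a log-sum-exp over the enumeration dimension:\n",
    "\n",
    "$$\\log p(y_n \\mid p, \\mathrm{loc}, \\sigma) = \\log \\sum_{k=0}^{2} p_k \\, \\mathcal{N}(y_n \\mid \\mathrm{loc}_k, \\sigma)$$\n",
    "\n",
    "so each of the 10 data points contributes a sum over only 3 terms.\n",
    "\n",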
    "For a more in-depth treatment of enumeration in mixture models, see the [Gaussian Mixture Model Tutorial](http://pyro.ai/examples/gmm.html) and the [HMM Example](http://pyro.ai/examples/hmm.html).\n",
    "\n",
    "### Dependencies among plates <a class=\"anchor\" id=\"Dependencies-among-plates\"></a>\n",
    "\n",
    "The computational savings of enumerating in vectorized plates comes with restrictions on the dependency structure of models. These restrictions are in addition to the usual restrictions of conditional independence. The enumeration restrictions are checked by `TraceEnum_ELBO` and will result in an error if violated (however the usual conditional independence restriction cannot be generally verified by Pyro). For completeness we list all three restrictions:\n",
    "\n",
    "#### Restriction 1: conditional independence\n",
    "Variables within a plate may not depend on each other (along the plate dimension). This applies to any variable, whether or not it is enumerated. This applies to both sequential plates and vectorized plates. For example the following model is invalid:\n",
    "```py\n",
    "def invalid_model():\n",
    "    x = 0\n",
    "    for i in pyro.plate(\"invalid\", 10):\n",
    "        x = pyro.sample(\"x_{}\".format(i), dist.Normal(x, 1.))\n",
    "```\n",
    "\n",
    "#### Restriction 2: no downstream coupling\n",
    "No variable outside of a vectorized plate can depend on an enumerated variable inside of that plate. This would violate Pyro's exponential speedup assumption. For example the following model is invalid:\n",
    "```py\n",
    "@config_enumerate\n",
    "def invalid_model(data):\n",
    "    with pyro.plate(\"plate\", 10):  # <--- invalid vectorized plate\n",
    "        x = pyro.sample(\"x\", dist.Bernoulli(0.5))\n",
    "    assert x.shape == (10,)\n",
    "    pyro.sample(\"obs\", dist.Normal(x.sum(), 1.), obs=data)\n",
    "```\n",
    "To work around this restriction, you can convert the vectorized plate to a sequential plate:\n",
    "```py\n",
    "@config_enumerate\n",
    "def valid_model(data):\n",
    "    x = []\n",
    "    for i in pyro.plate(\"plate\", 10):  # <--- valid sequential plate\n",
    "        x.append(pyro.sample(\"x_{}\".format(i), dist.Bernoulli(0.5)))\n",
    "    assert len(x) == 10\n",
    "    pyro.sample(\"obs\", dist.Normal(sum(x), 1.), obs=data)\n",
    "```\n",
    "\n",
    "#### Restriction 3: single path leaving each plate\n",
    "The final restriction is subtle, but is required to enable Pyro's exponential speedup:\n",
    "\n",
    "> For any enumerated variable `x`, the set of all enumerated variables on which `x` depends must be linearly orderable in their vectorized plate nesting.\n",
    "\n",
    "This requirement only applies when there are at least two plates and at least three variables in different plate contexts. The simplest counterexample is a Boltzmann machine:\n",
    "```py\n",
    "@config_enumerate\n",
    "def invalid_model(data):\n",
    "    plate_1 = pyro.plate(\"plate_1\", 10, dim=-1)  # vectorized\n",
    "    plate_2 = pyro.plate(\"plate_2\", 10, dim=-2)  # vectorized\n",
    "    with plate_1:\n",
    "        x = pyro.sample(\"x\", dist.Bernoulli(0.5))\n",
    "    with plate_2:\n",
    "        y = pyro.sample(\"y\", dist.Bernoulli(0.5))\n",
    "    with plate_1, plate_2:\n",
    "        z = pyro.sample(\"z\", dist.Bernoulli((1. + x + y) / 4.))\n",
    "        ...\n",
    "```\n",
    "Here we see that the variable `z` depends on variable `x` (which is in `plate_1` but not `plate_2`) and depends on variable `y` (which is in `plate_2` but not `plate_1`). This model is invalid because there is no way to linearly order `x` and `y` such that one's plate nesting is a subset of the other's.\n",
    "\n",
    "To work around this restriction, you can convert one of the plates to a sequential plate:\n",
    "```py\n",
    "@config_enumerate\n",
    "def valid_model(data):\n",
    "    plate_1 = pyro.plate(\"plate_1\", 10, dim=-1)  # vectorized\n",
    "    plate_2 = pyro.plate(\"plate_2\", 10)          # sequential\n",
    "    with plate_1:\n",
    "        x = pyro.sample(\"x\", dist.Bernoulli(0.5))\n",
    "    for i in plate_2:\n",
    "        y = pyro.sample(\"y_{}\".format(i), dist.Bernoulli(0.5))\n",
    "        with plate_1:\n",
    "            z = pyro.sample(\"z_{}\".format(i), dist.Bernoulli((1. + x + y) / 4.))\n",
    "            ...\n",
    "```\n",
    "but beware that this increases the computational complexity, which may be exponential in the size of the sequential plate."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Time series example  <a class=\"anchor\" id=\"Time-series-example\"></a>\n",
    "\n",
    "Consider a discrete HMM with latent states $x_t$ and observations $y_t$. Suppose we want to learn the transition and emission probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dim = 4\n",
    "num_steps = 10\n",
    "data = dist.Categorical(torch.ones(num_steps, data_dim)).sample()\n",
    "\n",
    "def hmm_model(data, data_dim, hidden_dim=10):\n",
    "    print('Running for {} time steps'.format(len(data)))\n",
    "    # Sample global matrices wrt a Jeffreys prior.\n",
    "    with pyro.plate(\"hidden_state\", hidden_dim):\n",
    "        transition = pyro.sample(\"transition\", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))\n",
    "        emission = pyro.sample(\"emission\", dist.Dirichlet(0.5 * torch.ones(data_dim)))\n",
    "\n",
    "    x = 0  # initial state\n",
    "    for t, y in enumerate(data):\n",
    "        x = pyro.sample(\"x_{}\".format(t), dist.Categorical(transition[x]),\n",
    "                        infer={\"enumerate\": \"parallel\"})\n",
    "        pyro.sample(\"y_{}\".format(t), dist.Categorical(emission[x]), obs=y)\n",
    "        print(\"x_{}.shape = {}\".format(t, x.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can learn the global parameters using SVI with an autoguide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running for 10 time steps\n",
      "x_0.shape = torch.Size([])\n",
      "x_1.shape = torch.Size([])\n",
      "x_2.shape = torch.Size([])\n",
      "x_3.shape = torch.Size([])\n",
      "x_4.shape = torch.Size([])\n",
      "x_5.shape = torch.Size([])\n",
      "x_6.shape = torch.Size([])\n",
      "x_7.shape = torch.Size([])\n",
      "x_8.shape = torch.Size([])\n",
      "x_9.shape = torch.Size([])\n",
      "Running for 10 time steps\n",
      "x_0.shape = torch.Size([10, 1])\n",
      "x_1.shape = torch.Size([10, 1, 1])\n",
      "x_2.shape = torch.Size([10, 1, 1, 1])\n",
      "x_3.shape = torch.Size([10, 1, 1, 1, 1])\n",
      "x_4.shape = torch.Size([10, 1, 1, 1, 1, 1])\n",
      "x_5.shape = torch.Size([10, 1, 1, 1, 1, 1, 1])\n",
      "x_6.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1])\n",
      "x_7.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1])\n",
      "x_8.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n",
      "x_9.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "hmm_guide = AutoDiagonalNormal(poutine.block(hmm_model, expose=[\"transition\", \"emission\"]))\n",
    "\n",
    "pyro.clear_param_store()\n",
    "elbo = TraceEnum_ELBO(max_plate_nesting=1)\n",
    "elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the model was run twice here: first it was run without enumeration by `AutoDiagonalNormal`, so that the autoguide could record all sample sites; then it was run by `TraceEnum_ELBO` with enumeration enabled. We see in the first run that samples have the standard interpretation, whereas in the second run samples have the enumeration interpretation.\n",
    "\n",
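    "The sum over hidden states in this model has the structure of the classic HMM forward recursion. As a sketch in our own notation ($\\alpha_t$ is not a Pyro object, $K$ is the number of hidden states, $T$ the number of time steps, and $x_{-1} = 0$ denotes the fixed initial state), define\n",
    "\n",
    "$$\\alpha_t(x_t) = p(y_t \\mid x_t) \\sum_{x_{t-1}} p(x_t \\mid x_{t-1})\\,\\alpha_{t-1}(x_{t-1}), \\qquad \\alpha_0(x_0) = p(y_0 \\mid x_0)\\,p(x_0 \\mid x_{-1} = 0)$$\n",
    "\n",
    "Then, conditional on the transition and emission matrices, $p(y_0, \\dots, y_{T-1}) = \\sum_{x_{T-1}} \\alpha_{T-1}(x_{T-1})$, at a cost of order $T K^2$ rather than the $K^T$ terms of a naive sum over all state sequences.\n",
    "\n",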
    "For more complex examples, including minibatching and multiple plates, see the [HMM example](https://github.com/uber/pyro/blob/dev/examples/hmm.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How to enumerate more than 25 variables <a class=\"anchor\" id=\"How-to-enumerate-more-than-25-variables\"></a>\n",
    "\n",
    "PyTorch tensors have a dimension limit of 25 in CUDA and 64 in CPU. By default Pyro enumerates each sample site in a new dimension. If you need more sample sites, you can annotate your model with [pyro.markov](http://docs.pyro.ai/en/dev/poutine.html#pyro.poutine.markov) to tell Pyro when it is safe to recycle tensor dimensions. Let's see how that works with the HMM model from above. The only change we need is to annotate the for loop with `pyro.markov`, informing Pyro that the variables in each step of the loop depend only on variables outside of the loop and on variables at the current and previous steps of the loop:\n",
    "```diff\n",
    "- for t, y in enumerate(data):\n",
    "+ for t, y in pyro.markov(enumerate(data)):\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_0.shape = torch.Size([10, 1])\n",
      "x_1.shape = torch.Size([10, 1, 1])\n",
      "x_2.shape = torch.Size([10, 1])\n",
      "x_3.shape = torch.Size([10, 1, 1])\n",
      "x_4.shape = torch.Size([10, 1])\n",
      "x_5.shape = torch.Size([10, 1, 1])\n",
      "x_6.shape = torch.Size([10, 1])\n",
      "x_7.shape = torch.Size([10, 1, 1])\n",
      "x_8.shape = torch.Size([10, 1])\n",
      "x_9.shape = torch.Size([10, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "def hmm_model(data, data_dim, hidden_dim=10):\n",
    "    with pyro.plate(\"hidden_state\", hidden_dim):\n",
    "        transition = pyro.sample(\"transition\", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))\n",
    "        emission = pyro.sample(\"emission\", dist.Dirichlet(0.5 * torch.ones(data_dim)))\n",
    "\n",
    "    x = 0  # initial state\n",
    "    for t, y in pyro.markov(enumerate(data)):\n",
    "        x = pyro.sample(\"x_{}\".format(t), dist.Categorical(transition[x]),\n",
    "                        infer={\"enumerate\": \"parallel\"})\n",
    "        pyro.sample(\"y_{}\".format(t), dist.Categorical(emission[x]), obs=y)\n",
    "        print(\"x_{}.shape = {}\".format(t, x.shape))\n",
    "\n",
    "# We'll reuse the same guide and elbo.\n",
    "elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that this model now only needs three tensor dimensions: one for the plate, one for even states, and one for odd states. For more complex examples, see the Dynamic Bayes Net model in the [HMM example](https://github.com/uber/pyro/blob/dev/examples/hmm.py)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
