{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tensor shapes in Pyro\n",
    "\n",
    "This tutorial introduces Pyro's organization of tensor dimensions. Before starting, you should familiarize yourself with [PyTorch broadcasting semantics](http://pytorch.org/docs/master/notes/broadcasting.html).\n",
    "\n",
    "#### Summary:\n",
    "- While you are learning or debugging, set `pyro.enable_validation(True)`.\n",
    "- Tensors broadcast by aligning on the right: `torch.ones(3,4,5) + torch.ones(5)`.\n",
    "- Distribution `.sample().shape == batch_shape + event_shape`.\n",
    "- Distribution `.log_prob(x).shape == batch_shape` (but not `event_shape`!).\n",
    "- Use `.expand()` to draw a batch of samples, or rely on `plate` to expand automatically.\n",
    "- Use `my_dist.to_event(1)` to declare a dimension as dependent.\n",
    "- Use `with pyro.plate('name', size):` to declare a dimension as conditionally independent.\n",
    "- All dimensions must be declared either dependent or conditionally independent.\n",
    "- Try to support batching on the left. This lets Pyro auto-parallelize.\n",
    "  - use negative indices like `x.sum(-1)` rather than `x.sum(2)`\n",
    "  - use ellipsis notation like `pixel = image[..., i, j]`\n",
    "  - use [Vindex](http://docs.pyro.ai/en/dev/ops.html#pyro.ops.indexing.Vindex) if `i,j` are enumerated, `pixel = Vindex(image)[..., i, j]`\n",
    "- When debugging, examine all shapes in a trace using [Trace.format_shapes()](http://docs.pyro.ai/en/dev/poutine.html#pyro.poutine.Trace.format_shapes).\n",
    "  \n",
    "#### Table of Contents\n",
    "- [Distribution shapes](#Distributions-shapes:-batch_shape-and-event_shape)\n",
    "  - [Examples](#Examples)\n",
    "  - [Reshaping distributions](#Reshaping-distributions)\n",
    "  - [It is always safe to assume dependence](#It-is-always-safe-to-assume-dependence)\n",
    "- [Declaring independence with plate](#Declaring-independent-dims-with-plate)\n",
    "- [Subsampling inside plate](#Subsampling-tensors-inside-a-plate)\n",
    "- [Broadcasting to allow Parallel Enumeration](#Broadcasting-to-allow-parallel-enumeration)\n",
    "  - [Writing parallelizable code](#Writing-parallelizable-code)\n",
    "  - [Automatic broadcasting inside pyro.plate](#Automatic-broadcasting-inside-pyro-plate)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pyro\n",
    "from torch.distributions import constraints\n",
    "from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal\n",
    "from pyro.distributions.util import broadcast_shape\n",
    "from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate\n",
    "import pyro.poutine as poutine\n",
    "from pyro.optim import Adam\n",
    "\n",
    "smoke_test = ('CI' in os.environ)\n",
    "assert pyro.__version__.startswith('0.5.0')\n",
    "pyro.enable_validation(True)    # <---- This is always a good idea!\n",
    "\n",
    "# We'll ue this helper to check our models are correct.\n",
    "def test_model(model, guide, loss):\n",
    "    pyro.clear_param_store()\n",
    "    loss.loss(model, guide)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distributions shapes: `batch_shape` and `event_shape` <a class=\"anchor\" id=\"Distributions-shapes:-batch_shape-and-event_shape\"></a>\n",
    "\n",
    "PyTorch `Tensor`s have a single `.shape` attribute, but `Distribution`s have two shape attributions with special meaning: `.batch_shape` and `.event_shape`. These two combine to define the total shape of a sample\n",
    "```py\n",
    "x = d.sample()\n",
    "assert x.shape == d.batch_shape + d.event_shape\n",
    "```\n",
    "Indices over `.batch_shape` denote conditionally independent random variables, whereas indices over `.event_shape` denote dependent random variables (ie one draw from a distribution). Because the dependent random variables define probability together, the `.log_prob()` method only produces a single number for each event of shape `.event_shape`. Thus the total shape of `.log_prob()` is `.batch_shape`:\n",
    "```py\n",
    "assert d.log_prob(x).shape == d.batch_shape\n",
    "```\n",
    "Note that the `Distribution.sample()` method also takes a `sample_shape` parameter that indexes over independent identically distributed (iid) random varables, so that\n",
    "```py\n",
    "x2 = d.sample(sample_shape)\n",
    "assert x2.shape == sample_shape + batch_shape + event_shape\n",
    "```\n",
    "In summary\n",
    "```\n",
    "      |      iid     | independent | dependent\n",
    "------+--------------+-------------+------------\n",
    "shape = sample_shape + batch_shape + event_shape\n",
    "```\n",
    "For example univariate distributions have empty event shape (because each number is an independent event). Distributions over vectors like `MultivariateNormal` have `len(event_shape) == 1`. Distributions over matrices like `InverseWishart` have `len(event_shape) == 2`.\n",
    "\n",
    "### Examples <a class=\"anchor\" id=\"Examples\"></a>\n",
    "\n",
    "The simplest distribution shape is a single univariate distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = Bernoulli(0.5)\n",
    "assert d.batch_shape == ()\n",
    "assert d.event_shape == ()\n",
    "x = d.sample()\n",
    "assert x.shape == ()\n",
    "assert d.log_prob(x).shape == ()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Distributions can be batched by passing in batched parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = Bernoulli(0.5 * torch.ones(3,4))\n",
    "assert d.batch_shape == (3, 4)\n",
    "assert d.event_shape == ()\n",
    "x = d.sample()\n",
    "assert x.shape == (3, 4)\n",
    "assert d.log_prob(x).shape == (3, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another way to batch distributions is via the `.expand()` method. This only works if \n",
    "parameters are identical along the leftmost dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4])\n",
    "assert d.batch_shape == (3, 4)\n",
    "assert d.event_shape == ()\n",
    "x = d.sample()\n",
    "assert x.shape == (3, 4)\n",
    "assert d.log_prob(x).shape == (3, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Multivariate distributions have nonempty `.event_shape`. For these distributions, the shapes of `.sample()` and `.log_prob(x)` differ:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))\n",
    "assert d.batch_shape == ()\n",
    "assert d.event_shape == (3,)\n",
    "x = d.sample()\n",
    "assert x.shape == (3,)            # == batch_shape + event_shape\n",
    "assert d.log_prob(x).shape == ()  # == batch_shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reshaping distributions <a class=\"anchor\" id=\"Reshaping-distributions\"></a>\n",
    "\n",
    "In Pyro you can treat a univariate distribution as multivariate by calling the [.to_event(n)](http://docs.pyro.ai/en/dev/distributions.html#pyro.distributions.torch_distribution.TorchDistributionMixin.to_event) property where `n` is the number of batch dimensions (from the right) to declare as *dependent*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = Bernoulli(0.5 * torch.ones(3,4)).to_event(1)\n",
    "assert d.batch_shape == (3,)\n",
    "assert d.event_shape == (4,)\n",
    "x = d.sample()\n",
    "assert x.shape == (3, 4)\n",
    "assert d.log_prob(x).shape == (3,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While you work with Pyro programs, keep in mind that samples have shape `batch_shape + event_shape`, whereas `.log_prob(x)` values have shape `batch_shape`. You'll need to ensure that `batch_shape` is carefully controlled by either trimming it down with `.to_event(n)` or by declaring dimensions as independent via `pyro.plate`.\n",
    "\n",
    "### It is always safe to assume dependence <a class=\"anchor\" id=\"It-is-always-safe-to-assume-dependence\"></a>\n",
    "\n",
    "Often in Pyro we'll declare some dimensions as dependent even though they are in fact independent, e.g.\n",
    "```py\n",
    "x = pyro.sample(\"x\", dist.Normal(0, 1).expand([10]).to_event(1))\n",
    "assert x.shape == (10,)\n",
    "```\n",
    "This is useful for two reasons: First it allows us to easily swap in a `MultivariateNormal` distribution later. Second it simplifies the code a bit since we don't need a `plate` (see below) as in\n",
    "```py\n",
    "with pyro.plate(\"x_plate\", 10):\n",
    "    x = pyro.sample(\"x\", dist.Normal(0, 1))  # .expand([10]) is automatic\n",
    "    assert x.shape == (10,)\n",
    "```\n",
    "The difference between these two versions is that the second version with `plate` informs Pyro that it can make use of conditional independence information when estimating gradients, whereas in the first version Pyro must assume they are dependent (even though the normals are in fact conditionally independent). This is analogous to d-separation in graphical models: it is always safe to add edges and assume variables *may* be dependent (i.e. to widen the model class), but it is unsafe to assume independence when variables are actually dependent (i.e. narrowing the model class so the true model lies outside of the class, as in mean field). In practice Pyro's SVI inference algorithm uses reparameterized gradient estimators for `Normal` distributions so both gradient estimators have the same performance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Declaring independent dims with `plate` <a class=\"anchor\" id=\"Declaring-independent-dims-with-plate\"></a>\n",
    "\n",
    "Pyro models can use the context manager [pyro.plate](http://docs.pyro.ai/en/dev/primitives.html#pyro.plate) to declare that certain batch dimensions are independent. Inference algorithms can then take advantage of this independence to e.g. construct lower variance gradient estimators or to enumerate in linear space rather than exponential space. An example of an independent dimension is the index over data in a minibatch: each datum should be independent of all others.\n",
    "\n",
    "The simplest way to declare a dimension as independent is to declare the rightmost batch dimension as independent via a simple\n",
    "```py\n",
    "with pyro.plate(\"my_plate\"):\n",
    "    # within this context, batch dimension -1 is independent\n",
    "```\n",
    "We recommend always providing an optional size argument to aid in debugging shapes\n",
    "```py\n",
    "with pyro.plate(\"my_plate\", len(my_data)):\n",
    "    # within this context, batch dimension -1 is independent\n",
    "```\n",
    "Starting with Pyro 0.2 you can additionally nest `plates`, e.g. if you have per-pixel independence:\n",
    "```py\n",
    "with pyro.plate(\"x_axis\", 320):\n",
    "    # within this context, batch dimension -1 is independent\n",
    "    with pyro.plate(\"y_axis\", 200):\n",
    "        # within this context, batch dimensions -2 and -1 are independent\n",
    "```\n",
    "Note that we always count from the right by using negative indices like -2, -1.\n",
    "\n",
    "Finally if you want to mix and match `plate`s for e.g. noise that depends only on `x`, some noise that depends only on `y`, and some noise that depends on both, you can declare multiple `plates` and use them as reusable context managers. In this case Pyro cannot automatically allocate a dimension, so you need to provide a `dim` argument (again counting from the right):\n",
    "```py\n",
    "x_axis = pyro.plate(\"x_axis\", 3, dim=-2)\n",
    "y_axis = pyro.plate(\"y_axis\", 2, dim=-3)\n",
    "with x_axis:\n",
    "    # within this context, batch dimension -2 is independent\n",
    "with y_axis:\n",
    "    # within this context, batch dimension -3 is independent\n",
    "with x_axis, y_axis:\n",
    "    # within this context, batch dimensions -3 and -2 are independent\n",
    "```\n",
    "Let's take a closer look at batch sizes within `plate`s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model1():\n",
    "    a = pyro.sample(\"a\", Normal(0, 1))\n",
    "    b = pyro.sample(\"b\", Normal(torch.zeros(2), 1).to_event(1))\n",
    "    with pyro.plate(\"c_plate\", 2):\n",
    "        c = pyro.sample(\"c\", Normal(torch.zeros(2), 1))\n",
    "    with pyro.plate(\"d_plate\", 3):\n",
    "        d = pyro.sample(\"d\", Normal(torch.zeros(3,4,5), 1).to_event(2))\n",
    "    assert a.shape == ()       # batch_shape == ()     event_shape == ()\n",
    "    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)\n",
    "    assert c.shape == (2,)     # batch_shape == (2,)   event_sahpe == ()\n",
    "    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5) \n",
    "\n",
    "    x_axis = pyro.plate(\"x_axis\", 3, dim=-2)\n",
    "    y_axis = pyro.plate(\"y_axis\", 2, dim=-3)\n",
    "    with x_axis:\n",
    "        x = pyro.sample(\"x\", Normal(0, 1))\n",
    "    with y_axis:\n",
    "        y = pyro.sample(\"y\", Normal(0, 1))\n",
    "    with x_axis, y_axis:\n",
    "        xy = pyro.sample(\"xy\", Normal(0, 1))\n",
    "        z = pyro.sample(\"z\", Normal(0, 1).expand([5]).to_event(1))\n",
    "    assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()\n",
    "    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()\n",
    "    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()\n",
    "    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)\n",
    "    \n",
    "test_model(model1, model1, Trace_ELBO())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is helpful to visualize the `.shape`s of each sample site by aligning them at the boundary between `batch_shape` and `event_shape`: dimensions to the right will be summed out in `.log_prob()` and dimensions to the left will remain. \n",
    "```\n",
    "batch dims | event dims\n",
    "-----------+-----------\n",
    "           |        a = sample(\"a\", Normal(0, 1))\n",
    "           |2       b = sample(\"b\", Normal(zeros(2), 1)\n",
    "           |                        .to_event(1))\n",
    "           |        with plate(\"c\", 2):\n",
    "          2|            c = sample(\"c\", Normal(zeros(2), 1))\n",
    "           |        with plate(\"d\", 3):\n",
    "          3|4 5         d = sample(\"d\", Normal(zeros(3,4,5), 1)\n",
    "           |                       .to_event(2))\n",
    "           |\n",
    "           |        x_axis = plate(\"x\", 3, dim=-2)\n",
    "           |        y_axis = plate(\"y\", 2, dim=-3)\n",
    "           |        with x_axis:\n",
    "        3 1|            x = sample(\"x\", Normal(0, 1))\n",
    "           |        with y_axis:\n",
    "      2 1 1|            y = sample(\"y\", Normal(0, 1))\n",
    "           |        with x_axis, y_axis:\n",
    "      2 3 1|            xy = sample(\"xy\", Normal(0, 1))\n",
    "      2 3 1|5           z = sample(\"z\", Normal(0, 1).expand([5])\n",
    "           |                       .to_event(1))\n",
    "```\n",
    "To examine the shapes of sample sites in a program automatically, you can trace the program and use the [Trace.format_shapes()](http://docs.pyro.ai/en/dev/poutine.html#pyro.poutine.Trace.format_shapes) method, which prints three shapes for each sample site: the distribution shape (both `site[\"fn\"].batch_shape` and `site[\"fn\"].event_shape`), the value shape (`site[\"value\"].shape`), and if log probability has been computed also the `log_prob` shape (`site[\"log_prob\"].shape`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trace Shapes:            \n",
      " Param Sites:            \n",
      "Sample Sites:            \n",
      "       a dist       |    \n",
      "        value       |    \n",
      "     log_prob       |    \n",
      "       b dist       | 2  \n",
      "        value       | 2  \n",
      "     log_prob       |    \n",
      " c_plate dist       |    \n",
      "        value     2 |    \n",
      "     log_prob       |    \n",
      "       c dist     2 |    \n",
      "        value     2 |    \n",
      "     log_prob     2 |    \n",
      " d_plate dist       |    \n",
      "        value     3 |    \n",
      "     log_prob       |    \n",
      "       d dist     3 | 4 5\n",
      "        value     3 | 4 5\n",
      "     log_prob     3 |    \n",
      "  x_axis dist       |    \n",
      "        value     3 |    \n",
      "     log_prob       |    \n",
      "  y_axis dist       |    \n",
      "        value     2 |    \n",
      "     log_prob       |    \n",
      "       x dist   3 1 |    \n",
      "        value   3 1 |    \n",
      "     log_prob   3 1 |    \n",
      "       y dist 2 1 1 |    \n",
      "        value 2 1 1 |    \n",
      "     log_prob 2 1 1 |    \n",
      "      xy dist 2 3 1 |    \n",
      "        value 2 3 1 |    \n",
      "     log_prob 2 3 1 |    \n",
      "       z dist 2 3 1 | 5  \n",
      "        value 2 3 1 | 5  \n",
      "     log_prob 2 3 1 |    \n"
     ]
    }
   ],
   "source": [
    "trace = poutine.trace(model1).get_trace()\n",
    "trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
    "print(trace.format_shapes())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Subsampling tensors inside a `plate` <a class=\"anchor\" id=\"Subsampling-tensors-inside-a-plate\"></a>\n",
    "\n",
    "One of the main uses of [plate](http://docs.pyro.ai/en/dev/primitives.html#pyro.plate) is to subsample data. This is possible within a `plate` because data are conditionally independent, so the expected value of the loss on, say, half the data should be half the expected loss on the full data.\n",
    "\n",
    "To subsample data, you need to inform Pyro of both the original data size and the subsample size; Pyro will then choose a random subset of data and yield the set of indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = torch.arange(100.)\n",
    "\n",
    "def model2():\n",
    "    mean = pyro.param(\"mean\", torch.zeros(len(data)))\n",
    "    with pyro.plate(\"data\", len(data), subsample_size=10) as ind:\n",
    "        assert len(ind) == 10    # ind is a LongTensor that indexes the subsample.\n",
    "        batch = data[ind]        # Select a minibatch of data.\n",
    "        mean_batch = mean[ind]   # Take care to select the relevant per-datum parameters.\n",
    "        # Do stuff with batch:\n",
    "        x = pyro.sample(\"x\", Normal(mean_batch, 1), obs=batch)\n",
    "        assert len(x) == 10\n",
    "        \n",
    "test_model(model2, guide=lambda: None, loss=Trace_ELBO())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Broadcasting to allow parallel enumeration <a class=\"anchor\" id=\"Broadcasting-to-allow-parallel-enumeration\"></a>\n",
    "\n",
    "Pyro 0.2 introduces the ability to enumerate discrete latent variables in parallel. This can significantly reduce the variance of gradient estimators when learning a posterior via [SVI](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.svi.SVI).\n",
    "\n",
    "To use parallel enumeration, Pyro needs to allocate tensor dimension that it can use for enumeration. To avoid conflicting with other dimensions that we want to use for `plate`s, we need to declare a budget of the maximum number of tensor dimensions we'll use. This budget is called `max_plate_nesting` and is an argument to [SVI](http://docs.pyro.ai/en/dev/inference_algos.html) (the argument is simply passed through to [TraceEnum_ELBO](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.traceenum_elbo.TraceEnum_ELBO)). Usually Pyro can determine this budget on its own (it runs the `(model,guide)` pair once and record what happens), but in case of dynamic model structure you may need to declare `max_plate_nesting` manually.\n",
    "\n",
    "To understand `max_plate_nesting` and how Pyro allocates dimensions for enumeration, let's revisit `model1()` from above. This time we'll map out three types of dimensions:\n",
    "enumeration dimensions on the left (Pyro takes control of these), batch dimensions in the middle, and event dimensions on the right."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "      max_plate_nesting = 3\n",
    "           |<--->|\n",
    "enumeration|batch|event\n",
    "-----------+-----+-----\n",
    "           |. . .|      a = sample(\"a\", Normal(0, 1))\n",
    "           |. . .|2     b = sample(\"b\", Normal(zeros(2), 1)\n",
    "           |     |                      .to_event(1))\n",
    "           |     |      with plate(\"c\", 2):\n",
    "           |. . 2|          c = sample(\"c\", Normal(zeros(2), 1))\n",
    "           |     |      with plate(\"d\", 3):\n",
    "           |. . 3|4 5       d = sample(\"d\", Normal(zeros(3,4,5), 1)\n",
    "           |     |                     .to_event(2))\n",
    "           |     |\n",
    "           |     |      x_axis = plate(\"x\", 3, dim=-2)\n",
    "           |     |      y_axis = plate(\"y\", 2, dim=-3)\n",
    "           |     |      with x_axis:\n",
    "           |. 3 1|          x = sample(\"x\", Normal(0, 1))\n",
    "           |     |      with y_axis:\n",
    "           |2 1 1|          y = sample(\"y\", Normal(0, 1))\n",
    "           |     |      with x_axis, y_axis:\n",
    "           |2 3 1|          xy = sample(\"xy\", Normal(0, 1))\n",
    "           |2 3 1|5         z = sample(\"z\", Normal(0, 1).expand([5]))\n",
    "           |     |                     .to_event(1))\n",
    "```\n",
    "Note that it is safe to overprovision `max_plate_nesting=4` but we cannot underprovision `max_plate_nesting=2` (or Pyro will error). Let's see how this works in practice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "@config_enumerate\n",
    "def model3():\n",
    "    p = pyro.param(\"p\", torch.arange(6.) / 6)\n",
    "    locs = pyro.param(\"locs\", torch.tensor([-1., 1.]))\n",
    "\n",
    "    a = pyro.sample(\"a\", Categorical(torch.ones(6) / 6))\n",
    "    b = pyro.sample(\"b\", Bernoulli(p[a]))  # Note this depends on a.\n",
    "    with pyro.plate(\"c_plate\", 4):\n",
    "        c = pyro.sample(\"c\", Bernoulli(0.3))\n",
    "        with pyro.plate(\"d_plate\", 5):\n",
    "            d = pyro.sample(\"d\", Bernoulli(0.4))\n",
    "            e_loc = locs[d.long()].unsqueeze(-1)\n",
    "            e_scale = torch.arange(1., 8.)\n",
    "            e = pyro.sample(\"e\", Normal(e_loc, e_scale)\n",
    "                            .to_event(1))  # Note this depends on d.\n",
    "\n",
    "    #                   enumerated|batch|event dims\n",
    "    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.\n",
    "    assert b.shape == (      2, 1, 1, 1   )  # Two enumerated Bernoullis, unexpanded.\n",
    "    assert c.shape == (   2, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.\n",
    "    assert d.shape == (2, 1, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.\n",
    "    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.\n",
    "\n",
    "    assert e_loc.shape   == (2, 1, 1, 1, 1, 1, 1,)\n",
    "    assert e_scale.shape == (                  7,)\n",
    "            \n",
    "test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a closer look at those dimensions. First note that Pyro allocates enumeration dims starting from the right at `max_plate_nesting`: Pyro allocates dim -3 to enumerate `a`, then dim -4 to enumerate `b`, then dim -5 to enumerate `c`, and finally dim -6 to enumerate `d`. Next note that samples only have extent (size > 1) in the new enumeration dimension. This helps keep tensors small and computation cheap. (Note that the `log_prob` shape will be broadcast up to contain both enumeratin shape and batch shape, so e.g. `trace.nodes['d']['log_prob'].shape == (2, 1, 1, 1, 5, 4)`.)\n",
    "\n",
    "We can draw a similar map of the tensor dimensions:\n",
    "```\n",
    "     max_plate_nesting = 2\n",
    "            |<->|\n",
    "enumeration batch event\n",
    "------------|---|-----\n",
    "           6|1 1|     a = pyro.sample(\"a\", Categorical(torch.ones(6) / 6))\n",
    "         2 1|1 1|     b = pyro.sample(\"b\", Bernoulli(p[a]))\n",
    "            |   |     with pyro.plate(\"c_plate\", 4):\n",
    "       2 1 1|1 1|         c = pyro.sample(\"c\", Bernoulli(0.3))\n",
    "            |   |         with pyro.plate(\"d_plate\", 5):\n",
    "     2 1 1 1|1 1|             d = pyro.sample(\"d\", Bernoulli(0.4))\n",
    "     2 1 1 1|1 1|1            e_loc = locs[d.long()].unsqueeze(-1)\n",
    "            |   |7            e_scale = torch.arange(1., 8.)\n",
    "     2 1 1 1|5 4|7            e = pyro.sample(\"e\", Normal(e_loc, e_scale)\n",
    "            |   |                             .to_event(1))\n",
    "```\n",
    "To automatically examine this model with enumeration semantics, we can create an enumerated trace and then use [Trace.format_shapes()](http://docs.pyro.ai/en/dev/poutine.html#pyro.poutine.Trace.format_shapes):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trace Shapes:                \n",
      " Param Sites:                \n",
      "            p             6  \n",
      "         locs             2  \n",
      "Sample Sites:                \n",
      "       a dist             |  \n",
      "        value       6 1 1 |  \n",
      "     log_prob       6 1 1 |  \n",
      "       b dist       6 1 1 |  \n",
      "        value     2 1 1 1 |  \n",
      "     log_prob     2 6 1 1 |  \n",
      " c_plate dist             |  \n",
      "        value           4 |  \n",
      "     log_prob             |  \n",
      "       c dist           4 |  \n",
      "        value   2 1 1 1 1 |  \n",
      "     log_prob   2 1 1 1 4 |  \n",
      " d_plate dist             |  \n",
      "        value           5 |  \n",
      "     log_prob             |  \n",
      "       d dist         5 4 |  \n",
      "        value 2 1 1 1 1 1 |  \n",
      "     log_prob 2 1 1 1 5 4 |  \n",
      "       e dist 2 1 1 1 5 4 | 7\n",
      "        value 2 1 1 1 5 4 | 7\n",
      "     log_prob 2 1 1 1 5 4 |  \n"
     ]
    }
   ],
   "source": [
    "trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()\n",
    "trace.compute_log_prob()  # optional, but allows printing of log_prob shapes\n",
    "print(trace.format_shapes())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Writing parallelizable code <a class=\"anchor\" id=\"Writing-parallelizable-code\"></a>\n",
    "\n",
    "It can be tricky to write Pyro models that correctly handle parallelized sample sites. Two tricks help: [broadcasting](http://pytorch.org/docs/master/notes/broadcasting.html) and [ellipsis slicing](http://python-reference.readthedocs.io/en/dev/docs/brackets/ellipsis.html). Let's look at a contrived model to see how these work in practice. Our aim is to write a model that works both with and without enumeration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "width = 8\n",
    "height = 10\n",
    "sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])\n",
    "enumerated = None  # set to either True or False below\n",
    "\n",
    "def fun(observe):\n",
    "    p_x = pyro.param(\"p_x\", torch.tensor(0.1), constraint=constraints.unit_interval)\n",
    "    p_y = pyro.param(\"p_y\", torch.tensor(0.1), constraint=constraints.unit_interval)\n",
    "    x_axis = pyro.plate('x_axis', width, dim=-2)\n",
    "    y_axis = pyro.plate('y_axis', height, dim=-1)\n",
    "\n",
    "    # Note that the shapes of these sites depend on whether Pyro is enumerating.\n",
    "    with x_axis:\n",
    "        x_active = pyro.sample(\"x_active\", Bernoulli(p_x))\n",
    "    with y_axis:\n",
    "        y_active = pyro.sample(\"y_active\", Bernoulli(p_y))\n",
    "    if enumerated:\n",
    "        assert x_active.shape  == (2, 1, 1)\n",
    "        assert y_active.shape  == (2, 1, 1, 1)\n",
    "    else:\n",
    "        assert x_active.shape  == (width, 1)\n",
    "        assert y_active.shape  == (height,)\n",
    "\n",
    "    # The first trick is to broadcast. This works with or without enumeration.\n",
    "    p = 0.1 + 0.5 * x_active * y_active\n",
    "    if enumerated:\n",
    "        assert p.shape == (2, 2, 1, 1)\n",
    "    else:\n",
    "        assert p.shape == (width, height)\n",
    "    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))\n",
    "\n",
    "    # The second trick is to index using ellipsis slicing.\n",
    "    # This allows Pyro to add arbitrary dimensions on the left.\n",
    "    for x, y in sparse_pixels:\n",
    "        dense_pixels[..., x, y] = 1\n",
    "    if enumerated:\n",
    "        assert dense_pixels.shape == (2, 2, width, height)\n",
    "    else:\n",
    "        assert dense_pixels.shape == (width, height)\n",
    "\n",
    "    with x_axis, y_axis:    \n",
    "        if observe:\n",
    "            pyro.sample(\"pixels\", Bernoulli(p), obs=dense_pixels)\n",
    "\n",
    "def model4():\n",
    "    fun(observe=True)\n",
    "\n",
    "def guide4():\n",
    "    fun(observe=False)\n",
    "\n",
    "# Test without enumeration.\n",
    "enumerated = False\n",
    "test_model(model4, guide4, Trace_ELBO())\n",
    "\n",
    "# Test with enumeration.\n",
    "enumerated = True\n",
    "test_model(model4, config_enumerate(guide4, \"parallel\"),\n",
    "           TraceEnum_ELBO(max_plate_nesting=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Automatic broadcasting inside pyro.plate<a class=\"anchor\" id=\"Automatic-broadcasting-inside-pyro-plate\"></a>\n",
    "\n",
    "Note that in all our model/guide specifications, we have relied on [pyro.plate](http://docs.pyro.ai/en/dev/primitives.html#pyro.plate) to automatically expand sample shapes to satisfy the constraints on batch shape enforced by `pyro.sample` statements. However this broadcasting is equivalent to hand-annotated `.expand()` statements.\n",
    "\n",
    "We will demonstrate this using `model4` from the [previous section](#Writing-parallelizable-code). Note the following changes to the code from earlier:\n",
    "\n",
    " - For the purpose of this example, we will only consider \"parallel\" enumeration, but broadcasting should work as expected without enumeration or with \"sequential\" enumeration.\n",
    " - We have separated out the sampling function which returns the tensors corresponding to the active pixels. Modularizing the model code into components is a common practice, and helps with maintainability of large models.\n",
    " - We would also like to use the `pyro.plate` construct to parallelize the ELBO estimator over [num_particles](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.elbo.ELBO). This is done by wrapping the contents of model/guide inside an outermost `pyro.plate` context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_particles = 100  # Number of samples for the ELBO estimator\n",
    "width = 8\n",
    "height = 10\n",
    "sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])\n",
    "\n",
    "def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):\n",
    "    with x_axis:\n",
    "        x_active = pyro.sample(\"x_active\", Bernoulli(p_x).expand([num_particles, width, 1]))\n",
    "    with y_axis:\n",
    "        y_active = pyro.sample(\"y_active\", Bernoulli(p_y).expand([num_particles, 1, height]))\n",
    "    return x_active, y_active\n",
    "\n",
    "def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):\n",
    "    with x_axis:\n",
    "        x_active = pyro.sample(\"x_active\", Bernoulli(p_x))\n",
    "    with y_axis:\n",
    "        y_active = pyro.sample(\"y_active\", Bernoulli(p_y))\n",
    "    return x_active, y_active \n",
    "\n",
    "def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):\n",
    "    with x_axis:\n",
    "        x_active = pyro.sample(\"x_active\", Bernoulli(p_x).expand([width, 1]))\n",
    "    with y_axis:\n",
    "        y_active = pyro.sample(\"y_active\", Bernoulli(p_y).expand([height]))\n",
    "    return x_active, y_active \n",
    "\n",
    "def fun(observe, sample_fn):\n",
    "    p_x = pyro.param(\"p_x\", torch.tensor(0.1), constraint=constraints.unit_interval)\n",
    "    p_y = pyro.param(\"p_y\", torch.tensor(0.1), constraint=constraints.unit_interval)\n",
    "    x_axis = pyro.plate('x_axis', width, dim=-2)\n",
    "    y_axis = pyro.plate('y_axis', height, dim=-1)\n",
    "\n",
    "    with pyro.plate(\"num_particles\", 100, dim=-3):\n",
    "        x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)\n",
    "        # Indices corresponding to \"parallel\" enumeration are appended \n",
    "        # to the left of the \"num_particles\" plate dim.\n",
    "        assert x_active.shape  == (2, 1, 1, 1)\n",
    "        assert y_active.shape  == (2, 1, 1, 1, 1)\n",
    "        p = 0.1 + 0.5 * x_active * y_active\n",
    "        assert p.shape == (2, 2, 1, 1, 1)\n",
    "\n",
    "        dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))\n",
    "        for x, y in sparse_pixels:\n",
    "            dense_pixels[..., x, y] = 1\n",
    "        assert dense_pixels.shape == (2, 2, 1, width, height)\n",
    "\n",
    "        with x_axis, y_axis:    \n",
    "            if observe:\n",
    "                pyro.sample(\"pixels\", Bernoulli(p), obs=dense_pixels)\n",
    "\n",
    "def test_model_with_sample_fn(sample_fn):\n",
    "    def model():\n",
    "        fun(observe=True, sample_fn=sample_fn)\n",
    "\n",
    "    @config_enumerate\n",
    "    def guide():\n",
    "        fun(observe=False, sample_fn=sample_fn)\n",
    "\n",
    "    test_model(model, guide, TraceEnum_ELBO(max_plate_nesting=3))\n",
    "\n",
    "test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)\n",
    "test_model_with_sample_fn(sample_pixel_locations_full_broadcasting)\n",
    "test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the first sampling function, we had to do some manual book-keeping and expand the `Bernoulli` distribution's batch shape to account for the conditionally independent dimensions added by the `pyro.plate` contexts. In particular, note how `sample_pixel_locations` needs knowledge of `num_particles`, `width` and `height` and is accessing these variables from the global scope, which is not ideal. \n",
    "\n",
    " - The second argument to `pyro.plate`, i.e. the optional `size` argument needs to be provided for implicit broadasting, so that it can infer the batch shape requirement for each of the sample sites. \n",
    " - The existing `batch_shape` of the sample site must be broadcastable with the size of the `pyro.plate` contexts. In our particular example, `Bernoulli(p_x)` has an empty batch shape which is universally broadcastable.\n",
    "\n",
    "Note how simple it is to achieve parallelization via tensorized operations using `pyro.plate`! `pyro.plate` also helps in code modularization because model components can be written agnostic of the `plate` contexts in which they may subsequently get embedded in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
