{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Poutine: A Guide to Programming with Effect Handlers in Pyro\n",
    "\n",
    "**Note to readers**: This tutorial is a guide to the API details of Pyro's effect handling library, [Poutine](http://docs.pyro.ai/en/dev/poutine.html). We recommend readers first orient themselves with the simplified [minipyro.py](https://github.com/uber/pyro/blob/dev/pyro/contrib/minipyro.py) which contains a minimal, readable implementation of Pyro's runtime and the effect handler abstraction described here. Pyro's effect handler library is more general than minipyro's but also contains more layers of indirection; it helps to read them side-by-side."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "import pyro.poutine as poutine\n",
    "\n",
    "from pyro.poutine.runtime import effectful\n",
    "\n",
    "pyro.set_rng_seed(101)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "Inference in probabilistic programming involves manipulating or transforming probabilistic programs written as generative models. For example, nearly all approximate inference algorithms require computing the unnormalized joint probability of values of latent and observed variables under a generative model.\n",
    "\n",
    "Consider the following example model from the [introductory inference tutorial](http://pyro.ai/examples/intro_part_ii.html):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scale(guess):\n",
    "    weight = pyro.sample(\"weight\", dist.Normal(guess, 1.0))\n",
    "    return pyro.sample(\"measurement\", dist.Normal(weight, 0.75))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model defines a joint probability distribution over `\"weight\"` and `\"measurement\"`:\n",
    "\n",
    "$${\\sf weight} \\, | \\, {\\sf guess} \\sim \\cal {\\sf Normal}({\\sf guess}, 1) $$\n",
    "$${\\sf measurement} \\, | \\, {\\sf guess}, {\\sf weight} \\sim {\\sf Normal}({\\sf weight}, 0.75)$$\n",
    "\n",
    "If we had access to the inputs and outputs of each `pyro.sample` site, we could compute their log-joint:\n",
    "```python\n",
    "logp = dist.Normal(guess, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()\n",
    "```\n",
    "However, the way we wrote `scale` above does not seem to expose these intermediate distribution objects, and rewriting it to return them would be intrusive and would violate the separation of concerns between models and inference algorithms that a probabilistic programming language like Pyro is designed to enforce.\n",
    "\n",
    "To resolve this conflict and facilitate inference algorithm development, Pyro exposes [Poutine](http://docs.pyro.ai/en/dev/poutine.html), a library of *effect handlers*, or composable building blocks for examining and modifying the behavior of Pyro programs. Most of Pyro's internals are implemented on top of Poutine."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A first look at Poutine: Pyro's library of algorithmic building blocks\n",
    "\n",
    "Effect handlers, a common abstraction in the programming languages community, give *nonstandard interpretations* or *side effects* to the behavior of particular statements in a programming language, like `pyro.sample` or `pyro.param`. For background reading on effect handlers in programming language research, see the optional \"References\" section at the end of this tutorial. \n",
    "\n",
    "Rather than reviewing more definitions, let's look at a first example that addresses the problem above: we can compose two existing effect handlers, `poutine.condition` (which sets output values of `pyro.sample` statements) and `poutine.trace` (which records the inputs, distributions, and outputs of `pyro.sample` statements), to concisely define a new effect handler that computes the log-joint:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-3.0203)\n"
     ]
    }
   ],
   "source": [
    "def make_log_joint(model):\n",
    "    def _log_joint(cond_data, *args, **kwargs):\n",
    "        conditioned_model = poutine.condition(model, data=cond_data)\n",
    "        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)\n",
    "        return trace.log_prob_sum()\n",
    "    return _log_joint\n",
    "\n",
    "scale_log_joint = make_log_joint(scale)\n",
    "print(scale_log_joint({\"measurement\": 9.5, \"weight\": 8.23}, 8.5))"
   ]
  },
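  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the same number can be reproduced without Pyro by evaluating the two normal log-densities from the introduction by hand. The sketch below uses only the standard library; the helper `normal_log_prob` is a hypothetical name introduced for this illustration and is not part of Pyro.\n",
    "```python\n",
    "import math\n",
    "\n",
    "def normal_log_prob(value, loc, scale):\n",
    "    # log-density of Normal(loc, scale) evaluated at value\n",
    "    var = scale ** 2\n",
    "    return -0.5 * math.log(2 * math.pi * var) - (value - loc) ** 2 / (2 * var)\n",
    "\n",
    "guess, weight, measurement = 8.5, 8.23, 9.5\n",
    "logp = normal_log_prob(weight, guess, 1.0) + normal_log_prob(measurement, weight, 0.75)\n",
    "print(round(logp, 4))\n",
    "```\n",
    "Up to floating-point precision, this agrees with the value printed by `scale_log_joint` above."
   ]
  },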
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That snippet is short, but still somewhat opaque - `poutine.condition`, `poutine.trace`, and `trace.log_prob_sum` are all black boxes.  Let's remove a layer of boilerplate from `poutine.condition` and `poutine.trace` and explicitly implement what `trace.log_prob_sum` is doing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-3.0203)\n"
     ]
    }
   ],
   "source": [
    "from pyro.poutine.trace_messenger import TraceMessenger\n",
    "from pyro.poutine.condition_messenger import ConditionMessenger\n",
    "\n",
    "def make_log_joint_2(model):\n",
    "    def _log_joint(cond_data, *args, **kwargs):\n",
    "        with TraceMessenger() as tracer:\n",
    "            with ConditionMessenger(data=cond_data):\n",
    "                model(*args, **kwargs)\n",
    "        \n",
    "        trace = tracer.trace\n",
    "        logp = 0.\n",
    "        for name, node in trace.nodes.items():\n",
    "            if node[\"type\"] == \"sample\":\n",
    "                if node[\"is_observed\"]:\n",
    "                    assert node[\"value\"] is cond_data[name]\n",
    "                logp = logp + node[\"fn\"].log_prob(node[\"value\"]).sum()\n",
    "        return logp\n",
    "    return _log_joint\n",
    "\n",
    "scale_log_joint = make_log_joint_2(scale)\n",
    "print(scale_log_joint({\"measurement\": 9.5, \"weight\": 8.23}, 8.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This makes things a little more clear: we can now see that `poutine.trace` and `poutine.condition` are wrappers for context managers that presumably communicate with the model through something inside `pyro.sample`. We can also see that `poutine.trace`  produces a data structure (a [Trace](http://docs.pyro.ai/en/dev/poutine.html#trace)) containing a dictionary whose keys are `sample` site names and values are dictionaries containing the distribution (`\"fn\"`) and output (`\"value\"`) at each site, and that the output values at each site are exactly the values specified in `data`.\n",
    "\n",
    "Finally, `TraceMessenger` and `ConditionMessenger` are Pyro effect handlers, or `Messenger`s: stateful context manager objects that are placed on a global stack and send messages (hence the name) up and down the stack at each effectful operation, like a `pyro.sample` call.  A `Messenger` is placed at the bottom of the stack when its `__enter__` method is called, i.e. when it is used in a \"with\" statement.\n",
    "\n",
    "We'll look at this process in more detail later in this tutorial.  For a simplified implementation in only a few lines of code, see [pyro.contrib.minipyro](https://github.com/uber/pyro/blob/dev/pyro/contrib/minipyro.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementing new effect handlers with the `Messenger` API\n",
    "\n",
    "Although it's easiest to build new effect handlers by composing the existing ones in `pyro.poutine`, implementing a new effect as a `pyro.poutine.messenger.Messenger` subclass is actually fairly straightforward. Before diving into the API, let's look at another example: a version of our log-joint computation that performs the sum while the model is executing. We'll then review what each part of the example is actually doing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-3.0203)\n",
      "tensor(-3.0203)\n"
     ]
    }
   ],
   "source": [
    "class LogJointMessenger(poutine.messenger.Messenger):\n",
    "    \n",
    "    def __init__(self, cond_data):\n",
    "        self.data = cond_data\n",
    "    \n",
    "    # __call__ is syntactic sugar for using Messengers as higher-order functions.\n",
    "    # Messenger already defines __call__, but we re-define it here\n",
    "    # for exposition and to change the return value:\n",
    "    def __call__(self, fn):\n",
    "        def _fn(*args, **kwargs):\n",
    "            with self:\n",
    "                fn(*args, **kwargs)\n",
    "                return self.logp.clone()\n",
    "        return _fn\n",
    "    \n",
    "    def __enter__(self):\n",
    "        self.logp = torch.tensor(0.)\n",
    "        # All Messenger subclasses must call the base Messenger.__enter__()\n",
    "        # in their __enter__ methods\n",
    "        return super(LogJointMessenger, self).__enter__()\n",
    "    \n",
    "    # __exit__ takes the same arguments in all Python context managers\n",
    "    def __exit__(self, exc_type, exc_value, traceback):\n",
    "        self.logp = torch.tensor(0.)\n",
    "        # All Messenger subclasses must call the base Messenger.__exit__ method\n",
    "        # in their __exit__ methods.\n",
    "        return super(LogJointMessenger, self).__exit__(exc_type, exc_value, traceback)\n",
    "    \n",
    "    # _pyro_sample will be called once per pyro.sample site.\n",
    "    # It takes a dictionary msg containing the name, distribution,\n",
    "    # observation or sample value, and other metadata from the sample site.\n",
    "    def _pyro_sample(self, msg):\n",
    "        assert msg[\"name\"] in self.data\n",
    "        msg[\"value\"] = self.data[msg[\"name\"]]\n",
    "        # Since we've observed a value for this site, we set the \"is_observed\" flag to True\n",
    "        # This tells any other Messengers not to overwrite msg[\"value\"] with a sample.\n",
    "        msg[\"is_observed\"] = True\n",
    "        self.logp = self.logp + (msg[\"scale\"] * msg[\"fn\"].log_prob(msg[\"value\"])).sum()\n",
    "\n",
    "with LogJointMessenger(cond_data={\"measurement\": 9.5, \"weight\": 8.23}) as m:\n",
    "    scale(8.5)\n",
    "    print(m.logp.clone())\n",
    "    \n",
    "scale_log_joint = LogJointMessenger(cond_data={\"measurement\": 9.5, \"weight\": 8.23})(scale)\n",
    "print(scale_log_joint(8.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A convenient bit of boilerplate that allows the use of `LogJointMessenger` as a context manager, decorator, or higher-order function is the following.  Most of the existing effect handlers in `pyro.poutine`, including `poutine.trace` and `poutine.condition` which we used earlier, are `Messenger`s wrapped this way in `pyro.poutine.handlers`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-3.0203)\n"
     ]
    }
   ],
   "source": [
    "def log_joint(model=None, cond_data=None):\n",
    "    msngr = LogJointMessenger(cond_data=cond_data)\n",
    "    return msngr(model) if model is not None else msngr\n",
    "\n",
    "scale_log_joint = log_joint(scale, cond_data={\"measurement\": 9.5, \"weight\": 8.23})\n",
    "print(scale_log_joint(8.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The `Messenger` API in more detail\n",
    "\n",
    "Our `LogJointMessenger` implementation has three important methods: `__enter__`, `__exit__`, and `_pyro_sample`. \n",
    "\n",
    "`__enter__` and `__exit__` are special methods needed by any Python context manager. When implementing new `Messenger` classes, if we override `__enter__` and `__exit__`, we always need to call the base `Messenger`'s `__enter__` and `__exit__` methods for the new `Messenger` to be applied correctly.\n",
    "\n",
    "The last method `LogJointMessenger._pyro_sample`, is called once at each sample site. It reads and modifies a *message*, which is a dictionary containing the sample site's name, distribution, sampled or observed value, and other metadata. We'll examine the contents of a message in more detail in the next section.\n",
    "\n",
    "Instead of `_pyro_sample`, a generic `Messenger` actually contains two methods that are called once per operation where side effects are performed:\n",
    "1. `_process_message` modifies a message and sends the result to the `Messenger` just above on the stack\n",
    "2. `_postprocess_message` modifies a message and sends the result to the next `Messenger` down on the stack. It is always called after all active `Messenger`s have had their `_process_message` method applied to the message.\n",
    "\n",
    "Although custom `Messenger`s can override `_process_message` and `_postprocess_message`, it's convenient to avoid requiring all effect handlers to be aware of all possible effectful operation types. For this reason, by default `Messenger._process_message` will use `msg[\"type\"]` to dispatch to a corresponding method `Messenger._pyro_<type>`, e.g. `Messenger._pyro_sample` as in `LogJointMessenger`.  Just as exception handling code ignores unhandled exception types, this allows `Messenger`s to simply forward operations they don't know how to handle up to the next `Messenger` in the stack:\n",
    "```python\n",
    "class Messenger(object):\n",
    "    ...\n",
    "    def _process_message(self, msg):\n",
    "        method_name = \"_pyro_{}\".format(msg[\"type\"])  # e.g. _pyro_sample when msg[\"type\"] == \"sample\"\n",
    "        if hasattr(self, method_name):\n",
    "            getattr(self, method_name)(msg)\n",
    "    ...\n",
    "```"
   ]
  },
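  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dispatch-by-type behavior described above can be tried out in plain Python. This is a minimal sketch rather than Pyro's implementation: `ToyMessenger` and `SampleOnly` are hypothetical classes, and the messages carry only the two fields needed for the illustration.\n",
    "```python\n",
    "class ToyMessenger:\n",
    "    # Hypothetical stand-in for the dispatch logic sketched above; not Pyro's class.\n",
    "    def _process_message(self, msg):\n",
    "        method = getattr(self, '_pyro_' + msg['type'], None)\n",
    "        if method is not None:\n",
    "            method(msg)\n",
    "\n",
    "class SampleOnly(ToyMessenger):\n",
    "    # Handles 'sample' messages; every other message type is left untouched.\n",
    "    def _pyro_sample(self, msg):\n",
    "        msg['value'] = 1.0\n",
    "\n",
    "sample_msg = {'type': 'sample', 'value': None}\n",
    "param_msg = {'type': 'param', 'value': None}\n",
    "handler = SampleOnly()\n",
    "handler._process_message(sample_msg)\n",
    "handler._process_message(param_msg)\n",
    "print(sample_msg['value'], param_msg['value'])\n",
    "```\n",
    "Because `SampleOnly` defines no `_pyro_param` method, the `param` message passes through unchanged, which is the forwarding behavior the exception-handling analogy refers to."
   ]
  },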
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interlude: the global `Messenger` stack\n",
    "\n",
    "See [pyro.contrib.minipyro](https://github.com/uber/pyro/blob/dev/pyro/contrib/minipyro.py) for an end-to-end implementation of the mechanism in this section.\n",
    "\n",
    "The order in which `Messenger`s are applied to an operation like a `pyro.sample` statement is determined by the order in which their `__enter__` methods are called.  `Messenger.__enter__` appends a `Messenger` to the end (the bottom) of the global handler stack:\n",
    "```python\n",
    "class Messenger(object):\n",
    "    ...\n",
    "    # __enter__ pushes a Messenger onto the stack\n",
    "    def __enter__(self):\n",
    "        ...\n",
    "        _PYRO_STACK.append(self)\n",
    "        ...\n",
    "    \n",
    "    # __exit__ removes a Messenger from the stack\n",
    "    def __exit__(self, ...):\n",
    "        ...\n",
    "        assert _PYRO_STACK[-1] is self\n",
    "        _PYRO_STACK.pop()\n",
    "        ...\n",
    "```\n",
    "\n",
    "`pyro.poutine.runtime.apply_stack` then traverses the stack twice at each operation, first from bottom to top to apply each `_process_message` and then from top to bottom to apply each `_postprocess_message`:\n",
    "```python\n",
    "def apply_stack(msg):  # simplified\n",
    "    for handler in reversed(_PYRO_STACK):\n",
    "        handler._process_message(msg)\n",
    "    ...\n",
    "    default_process_message(msg)\n",
    "    ...\n",
    "    for handler in _PYRO_STACK:\n",
    "        handler._postprocess_message(msg) \n",
    "    ...\n",
    "    return msg\n",
    "```"
   ]
  },
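  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The traversal order can be checked with a small plain-Python sketch. `STACK`, `ToyHandler`, and `toy_apply_stack` are hypothetical names for this illustration; the sketch leaves out `default_process_message` and everything else elided by `...` in the snippets above.\n",
    "```python\n",
    "STACK = []  # stand-in for the global handler stack; the last element is the bottom\n",
    "\n",
    "class ToyHandler:\n",
    "    def __init__(self, name, log):\n",
    "        self.name, self.log = name, log\n",
    "\n",
    "    def __enter__(self):\n",
    "        STACK.append(self)\n",
    "        return self\n",
    "\n",
    "    def __exit__(self, exc_type, exc_value, traceback):\n",
    "        assert STACK[-1] is self\n",
    "        STACK.pop()\n",
    "\n",
    "    def _process_message(self, msg):\n",
    "        self.log.append('process ' + self.name)\n",
    "\n",
    "    def _postprocess_message(self, msg):\n",
    "        self.log.append('postprocess ' + self.name)\n",
    "\n",
    "def toy_apply_stack(msg):\n",
    "    for handler in reversed(STACK):  # bottom to top\n",
    "        handler._process_message(msg)\n",
    "    for handler in STACK:  # top to bottom\n",
    "        handler._postprocess_message(msg)\n",
    "    return msg\n",
    "\n",
    "log = []\n",
    "with ToyHandler('outer', log):\n",
    "    with ToyHandler('inner', log):\n",
    "        toy_apply_stack({'type': 'sample'})\n",
    "print(log)\n",
    "```\n",
    "The printed list shows `inner` (entered last, so at the bottom of the stack) processing the message first and postprocessing it last."
   ]
  },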
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Returning to the `LogJointMessenger` example\n",
    "\n",
    "The second method `_postprocess_message` is necessary because some effects can only be applied after all other effect handlers have had a chance to update the message once. In the case of `LogJointMessenger`, other effects, like enumeration, may modify a sample site's value or distribution (`msg[\"value\"]` or `msg[\"fn\"]`), so we move the log-probability computation to a new method, `_pyro_post_sample`, which is called by `_postprocess_message` (via a dispatch mechanism like the one used by `_process_message`) at each `sample` site after all active handlers' `_pyro_sample` methods have been applied:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-3.0203)\n"
     ]
    }
   ],
   "source": [
    "class LogJointMessenger2(poutine.messenger.Messenger):\n",
    "    \n",
    "    def __init__(self, cond_data):\n",
    "        self.data = cond_data\n",
    "    \n",
    "    def __call__(self, fn):\n",
    "        def _fn(*args, **kwargs):\n",
    "            with self:\n",
    "                fn(*args, **kwargs)\n",
    "                return self.logp.clone()\n",
    "        return _fn\n",
    "    \n",
    "    def __enter__(self):\n",
    "        self.logp = torch.tensor(0.)\n",
    "        return super(LogJointMessenger2, self).__enter__()\n",
    "    \n",
    "    def __exit__(self, exc_type, exc_value, traceback):\n",
    "        self.logp = torch.tensor(0.)\n",
    "        return super(LogJointMessenger2, self).__exit__(exc_type, exc_value, traceback)\n",
    "\n",
    "    def _pyro_sample(self, msg):\n",
    "        if msg[\"name\"] in self.data:\n",
    "            msg[\"value\"] = self.data[msg[\"name\"]]\n",
    "            msg[\"done\"] = True\n",
    "            \n",
    "    def _pyro_post_sample(self, msg):\n",
    "        assert msg[\"done\"]  # the \"done\" flag asserts that no more modifications to value and fn will be performed.\n",
    "        self.logp = self.logp + (msg[\"scale\"] * msg[\"fn\"].log_prob(msg[\"value\"])).sum()\n",
    "\n",
    "\n",
    "with LogJointMessenger2(cond_data={\"measurement\": 9.5, \"weight\": 8.23}) as m:\n",
    "    scale(8.5)\n",
    "    print(m.logp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inside the messages sent by `Messenger`s\n",
    "\n",
    "As the previous two examples mentioned, the actual messages sent up and down the stack are dictionaries with a particular set of keys. Consider the following sample statement:\n",
    "```python\n",
    "pyro.sample(\"x\", dist.Bernoulli(0.5), infer={\"enumerate\": \"parallel\"}, obs=None)\n",
    "```\n",
    "This sample statement is converted into an initial message before any effects are applied, and each effect handler's `_process_message` and `_postprocess_message` may update fields in place or add new fields.  We write out the full initial message here for completeness:\n",
    "```python\n",
    "msg = {\n",
    "    # The following fields contain the name, inputs, function, and output of a site.\n",
    "    # These are generally the only fields you'll need to think about.\n",
    "    \"name\": \"x\",\n",
    "    \"fn\": dist.Bernoulli(0.5),\n",
    "    \"value\": None,  # msg[\"value\"] will eventually contain the value returned by pyro.sample\n",
    "    \"is_observed\": False,  # because obs=None by default; only used by sample sites\n",
    "    \"args\": (),  # positional arguments passed to \"fn\" when it is called; usually empty for sample sites\n",
    "    \"kwargs\": {},  # keyword arguments passed to \"fn\" when it is called; usually empty for sample sites\n",
    "    # This field typically contains metadata needed or stored by a particular inference algorithm\n",
    "    \"infer\": {\"enumerate\": \"parallel\"},\n",
    "    # The remaining fields are generally only used by Pyro's internals,\n",
    "    # or for implementing more advanced effects beyond the scope of this tutorial\n",
    "    \"type\": \"sample\",  # label used by Messenger._process_message to dispatch, in this case to _pyro_sample\n",
    "    \"done\": False,\n",
    "    \"stop\": False,\n",
    "    \"scale\": torch.tensor(1.),  # Multiplicative scale factor that can be applied to each site's log_prob\n",
    "    \"mask\": None,\n",
    "    \"continuation\": None,\n",
    "    \"cond_indep_stack\": (),  # Will contain metadata from each pyro.plate enclosing this sample site.\n",
    "}\n",
    "```\n",
    "Note that when we use `poutine.trace` or `TraceMessenger` as in our first two versions of `make_log_joint`, the contents of `msg` are exactly the information stored in the trace for each sample and param site."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementing inference algorithms with existing effect handlers: examples\n",
    "\n",
    "It turns out that many inference operations, like our first version of `make_log_joint` above, have strikingly short implementations in terms of existing effect handlers in `pyro.poutine`. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example: Variational inference with a Monte Carlo ELBO\n",
    "\n",
    "For example, here is an implementation of variational inference with a Monte Carlo ELBO that uses `poutine.trace`, `poutine.condition`, and `poutine.replay`.  This is very similar to the simple ELBO in [pyro.contrib.minipyro](https://github.com/uber/pyro/blob/dev/pyro/contrib/minipyro.py)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def monte_carlo_elbo(model, guide, batch, *args, **kwargs):\n",
    "    # assuming batch is a dictionary, we use poutine.condition to fix values of observed variables\n",
    "    conditioned_model = poutine.condition(model, data=batch)\n",
    "    \n",
    "    # we'll approximate the expectation in the ELBO with a single sample:\n",
    "    # first, we run the guide forward unmodified and record values and distributions\n",
    "    # at each sample site using poutine.trace\n",
    "    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)\n",
    "    \n",
    "    # we use poutine.replay to set the values of latent variables in the model\n",
    "    # to the values sampled above by our guide, and use poutine.trace\n",
    "    # to record the distributions that appear at each sample site in in the model\n",
    "    model_trace = poutine.trace(\n",
    "        poutine.replay(conditioned_model, trace=guide_trace)\n",
    "    ).get_trace(*args, **kwargs)\n",
    "    \n",
    "    elbo = 0.\n",
    "    for name, node in model_trace.nodes.items():\n",
    "        if node[\"type\"] == \"sample\":\n",
    "            elbo = elbo + node[\"fn\"].log_prob(node[\"value\"]).sum()\n",
    "            if not node[\"is_observed\"]:\n",
    "                elbo = elbo - guide_trace.nodes[name][\"fn\"].log_prob(node[\"value\"]).sum()\n",
    "    return -elbo"
   ]
  },
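  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols, `monte_carlo_elbo` returns the negative of a single-sample estimate of the ELBO. The notation here is an assumption made for this sketch rather than something defined elsewhere in the tutorial: $x$ stands for the observed sites in `batch`, $z$ for the latent sites, $p$ for the model, and $q$ for the guide.\n",
    "\n",
    "$$\\mathrm{ELBO} = \\mathbb{E}_{q(z)}\\big[\\log p(x, z) - \\log q(z)\\big] \\approx \\log p(x, \\hat z) - \\log q(\\hat z), \\quad \\hat z \\sim q(z)$$\n",
    "\n",
    "The loop over `model_trace` adds `log_prob` at every model sample site, which gives $\\log p(x, \\hat z)$, and subtracts the guide's `log_prob` at each site that is not observed, which gives $\\log q(\\hat z)$."
   ]
  },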
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use `poutine.trace` and `poutine.block` to record `pyro.param` calls for optimization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, guide, data):\n",
    "    optimizer = pyro.optim.Adam({})\n",
    "    for batch in data:\n",
    "        # this poutine.trace will record all of the parameters that appear in the model and guide\n",
    "        # during the execution of monte_carlo_elbo\n",
    "        with poutine.trace() as param_capture:\n",
    "            # we use poutine.block here so that only parameters appear in the trace above\n",
    "            with poutine.block(hide_fn=lambda node: node[\"type\"] != \"param\"):\n",
    "                loss = monte_carlo_elbo(model, guide, batch)\n",
    "        \n",
    "        loss.backward()\n",
    "        params = set(node[\"value\"].unconstrained()\n",
    "                     for node in param_capture.trace.nodes.values())\n",
    "        optimizer.step(params)\n",
    "        pyro.infer.util.zero_grads(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example: exact inference via sequential enumeration\n",
    "\n",
    "Here is an example of a very different inference algorithm--exact inference via enumeration--implemented with `pyro.poutine`.  A complete explanation of this algorithm is beyond the scope of this tutorial and may be found in Chapter 3 of the short online book [Design and Implementation of Probabilistic Programming Languages](http://dippl.org/chapters/03-enumeration.html).  This example uses `poutine.queue`, itself implemented using `poutine.trace`, `poutine.replay`, and `poutine.block`, to enumerate over possible values of all discrete variables in a model and compute a marginal distribution over all possible return values or the possible values at a particular sample site:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sequential_discrete_marginal(model, data, site_name=\"_RETURN\"):\n",
    "    \n",
    "    from six.moves import queue  # queue data structures\n",
    "    q = queue.Queue()  # Instantiate a first-in first-out queue\n",
    "    q.put(poutine.Trace())  # seed the queue with an empty trace\n",
    "    \n",
    "    # as before, we fix the values of observed random variables with poutine.condition\n",
    "    # assuming data is a dictionary whose keys are names of sample sites in model\n",
    "    conditioned_model = poutine.condition(model, data=data)\n",
    "    \n",
    "    # we wrap the conditioned model in a poutine.queue,\n",
    "    # which repeatedly pushes and pops partially completed executions from a Queue()\n",
    "    # to perform breadth-first enumeration over the set of values of all discrete sample sites in model\n",
    "    enum_model = poutine.queue(conditioned_model, queue=q)\n",
    "    \n",
    "    # actually perform the enumeration by repeatedly tracing enum_model\n",
    "    # and accumulate samples and trace log-probabilities for postprocessing\n",
    "    samples, log_weights = [], []\n",
    "    while not q.empty():\n",
    "        trace = poutine.trace(enum_model).get_trace()\n",
    "        samples.append(trace.nodes[site_name][\"value\"])\n",
    "        log_weights.append(trace.log_prob_sum())\n",
    "        \n",
    "    # we take the samples and log-joints and turn them into a histogram:\n",
    "    samples = torch.stack(samples, 0)\n",
    "    log_weights = torch.stack(log_weights, 0)\n",
    "    log_weights = log_weights - dist.util.logsumexp(log_weights, dim=0)\n",
    "    return dist.Empirical(samples, log_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(Note that `sequential_discrete_marginal` is very general, but is also quite slow. For high-performance parallel enumeration that applies to a less general class of models, see the enumeration tutorial.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: implementing lazy evaluation with the `Messenger` API\n",
    "\n",
    "Now that we've learned more about the internals of `Messenger`, let's use it to implement a slightly more complicated effect: lazy evaluation. We first define a `LazyValue` class that we will use to build up a computation graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LazyValue(object):\n",
    "    def __init__(self, fn, *args, **kwargs):\n",
    "        self._expr = (fn, args, kwargs)\n",
    "        self._value = None\n",
    "        \n",
    "    def __str__(self):\n",
    "        return \"({} {})\".format(str(self._expr[0]), \" \".join(map(str, self._expr[1])))\n",
    "        \n",
    "    def evaluate(self):\n",
    "        if self._value is None:\n",
    "            fn, args, kwargs = self._expr\n",
    "            fn = fn.evaluate() if isinstance(fn, LazyValue) else fn\n",
    "            args = tuple(arg.evaluate() if isinstance(arg, LazyValue) else arg\n",
    "                         for arg in args)\n",
    "            kwargs = {k: v.evaluate() if isinstance(v, LazyValue) else v\n",
    "                      for k, v in kwargs.items()}\n",
    "            self._value = fn(*args, **kwargs)\n",
    "        return self._value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `LazyValue`, implementing lazy evaluation as a `Messenger` compatible with other effect handlers is suprisingly easy. We just make each `msg[\"value\"]` a `LazyValue` and introduce a new operation type `\"apply\"` for deterministic operations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LazyMessenger(pyro.poutine.messenger.Messenger):\n",
    "    def _process_message(self, msg):\n",
    "        if msg[\"type\"] in (\"apply\", \"sample\") and not msg[\"done\"]:\n",
    "            msg[\"done\"] = True\n",
    "            msg[\"value\"] = LazyValue(msg[\"fn\"], *msg[\"args\"], **msg[\"kwargs\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, just like `torch.autograd` overloads `torch` tensor operations to record an autograd graph, we need to wrap any operations we'd like to be lazy.  We'll use `pyro.poutine.runtime.effectful` as a decorator to expose these operations to `LazyMessenger`. `effectful` constructs a message much like the one above and sends it up and down the effect handler stack, but allows us to set the type (in this case, to `\"apply\"` instead of `\"sample\"`) so that these operations aren't mistaken for `sample` statements by other effect handlers like `TraceMessenger`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "@effectful(type=\"apply\")\n",
    "def add(x, y):\n",
    "    return x + y\n",
    "\n",
    "@effectful(type=\"apply\")\n",
    "def mul(x, y):\n",
    "    return x * y\n",
    "\n",
    "@effectful(type=\"apply\")\n",
    "def sigmoid(x):\n",
    "    return torch.sigmoid(x)\n",
    "\n",
    "@effectful(type=\"apply\")\n",
    "def normal(loc, scale):\n",
    "    return dist.Normal(loc, scale)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Applied to another model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "((<function normal at 0x7fc41cbfdc80> (<function add at 0x7fc41cbf91e0> (<function mul at 0x7fc41cbfda60> ((<function normal at 0x7fc41cbfdc80> 8.5 1.0) ) 0.8) 1.0) (<function sigmoid at 0x7fc41cbfdb70> ((<function normal at 0x7fc41cbfdc80> 0.0 0.25) ))) )\n",
      "tensor(6.5436)\n"
     ]
    }
   ],
   "source": [
    "def biased_scale(guess):\n",
    "    weight = pyro.sample(\"weight\", normal(guess, 1.))\n",
    "    tolerance = pyro.sample(\"tolerance\", normal(0., 0.25))\n",
    "    return pyro.sample(\"measurement\", normal(add(mul(weight, 0.8), 1.), sigmoid(tolerance)))\n",
    "\n",
    "with LazyMessenger():\n",
    "    v = biased_scale(8.5)\n",
    "    print(v)\n",
    "    print(v.evaluate())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Together with other effect handlers like `TraceMessenger` and `ConditionMessenger`, with which it freely composes, `LazyMessenger` demonstrates how to use Poutine to quickly and concisely implement state-of-the-art PPL techniques like [delayed sampling with Rao-Blackwellization](https://arxiv.org/abs/1708.07787)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References: algebraic effects and handlers in programming language research\n",
    "\n",
    "This section contains some references to PL papers for readers interested in this direction.\n",
    "\n",
    "Algebraic effects and handlers, which were developed starting in the early 2000s and are a subject of active research in the programming languages community, are a versatile abstraction for building modular implementations of nonstandard interpreters of particular statements in a programming language, like `pyro.sample` or `pyro.param`.  They were originally introduced to address the difficulty of composing nonstandard interpreters implemented with monads and monad transformers.\n",
    "\n",
    "- For an accessible introduction to the effect handlers literature, see the excellent review/tutorial paper [\"Handlers in Action\"](http://homepages.inf.ed.ac.uk/slindley/papers/handlers.pdf) by Ohad Kammar, Sam Lindley, and Nicolas Oury, and the references therein.\n",
    "\n",
    "- Algebraic effect handlers were originally introduced by Gordon Plotkin and Matija Pretnar in the paper [\"Handlers of Algebraic Effects\"](https://link.springer.com/chapter/10.1007/978-3-642-00590-9_7).\n",
    "\n",
    "- A useful mental model of effect handlers is as exception handlers that are capable of resuming computation in the `try` block after raising an exception and performing some processing in the `except` block. This metaphor is explored further in the experimental programming language [Eff](http://math.andrej.com/eff/) and its companion paper [\"Programming with Algebraic Effects and Handlers\"](https://arxiv.org/abs/1203.1539) by Andrej Bauer and Matija Pretnar.\n",
    "\n",
    "- Most effect handlers in Pyro are \"linear,\" meaning that they only resume once per effectful operation and do not alter the order of execution of the original program. One exception is `poutine.queue`, which uses an inefficient implementation strategy for multiple resumptions like the one described for delimited continuations in the paper [\"Capturing the Future by Replaying the Past\"](http://delivery.acm.org/10.1145/3240000/3236771/icfp18main-p36-p.pdf) by James Koppel, Gabriel Scherer, and Armando Solar-Lezama.  \n",
    "\n",
    "- More efficient implementation strategies for effect handlers in mainstream programming languages like Python or JavaScript is an area of active research. One promising line of work involves selective continuation-passing style transforms as in the paper [\"Type-Directed Compilation of Row-Typed Algebraic Effects\"](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/12/algeff.pdf) by Daan Leijen."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
