{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom SVI Objectives\n",
    "\n",
    "Pyro provides support for various optimization-based approaches to Bayesian inference, with `Trace_ELBO` serving as the basic implementation of SVI (stochastic variational inference).\n",
    "See the [docs](http://docs.pyro.ai/en/dev/inference_algos.html#module-pyro.infer.svi) for more information on the various SVI implementations and SVI \n",
    "tutorials [I](http://pyro.ai/examples/svi_part_i.html), \n",
    "[II](http://pyro.ai/examples/svi_part_ii.html), \n",
    "and [III](http://pyro.ai/examples/svi_part_iii.html) for background on SVI.\n",
    "\n",
    "In this tutorial we show how advanced users can modify and/or augment the variational\n",
    "objectives (alternatively: loss functions) provided by Pyro to support special use cases."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. [Basic SVI Usage](#Basic-SVI-Usage)\n",
    "    1. [A Lower Level Pattern](#A-Lower-Level-Pattern)\n",
    "2. [Example: Custom Regularizer](#Example:-Custom-Regularizer)\n",
    "3. [Example: Scaling the Loss](#Example:-Scaling-the-Loss)\n",
    "4. [Example: Mixing Optimizers](#Example:-Mixing-Optimizers)\n",
    "5. [Example: Custom ELBO](#Example:-Custom-ELBO)\n",
    "6. [Example: KL Annealing](#Example:-KL-Annealing)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic SVI Usage\n",
    "\n",
    "We first review the basic usage pattern of `SVI` objects in Pyro. We assume that the user\n",
    "has defined a `model` and a `guide`.  The user then creates an optimizer and an `SVI` object:\n",
    "\n",
    "```python\n",
    "optimizer = pyro.optim.Adam({\"lr\": 0.001, \"betas\": (0.90, 0.999)})\n",
    "svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())\n",
    "```\n",
    "\n",
    "Gradient steps can then be taken with a call to `svi.step(...)`. The arguments to `step()` are then\n",
    "passed to `model` and `guide`.\n",
    "\n",
    "\n",
    "### A Lower-Level Pattern\n",
    "\n",
    "The nice thing about the above pattern is that it allows Pyro to take care of various \n",
    "details for us, for example:\n",
    "\n",
    "- `pyro.optim.Adam` dynamically creates a new `torch.optim.Adam` optimizer whenever a new parameter is encountered \n",
    "- `SVI.step()` zeros gradients between gradient steps\n",
    "\n",
    "If we want more control, we can directly manipulate the differentiable loss method of \n",
    "the various `ELBO` classes. For example, (assuming we know all the parameters in advance) \n",
    "this is equivalent to the previous code snippet:\n",
    "\n",
    "```python\n",
    "# define optimizer and loss function\n",
    "optimizer = torch.optim.Adam(my_parameters, {\"lr\": 0.001, \"betas\": (0.90, 0.999)})\n",
    "loss_fn = pyro.infer.Trace_ELBO().differentiable_loss\n",
    "# compute loss\n",
    "loss = loss_fn(model, guide, model_and_guide_args)\n",
    "loss.backward()\n",
    "# take a step and zero the parameter gradients\n",
    "optimizer.step()\n",
    "optimizer.zero_grad()\n",
    "```\n",
    "\n",
    "## Example: Custom Regularizer\n",
    "\n",
    "Suppose we want to add a custom regularization term to the SVI loss. Using the above \n",
    "usage pattern, this is easy to do. First we define our regularizer:\n",
    "\n",
    "```python\n",
    "def my_custom_L2_regularizer(my_parameters):\n",
    "    reg_loss = 0.0\n",
    "    for param in my_parameters:\n",
    "        reg_loss = reg_loss + param.pow(2.0).sum()\n",
    "    return reg_loss  \n",
    "```\n",
    "\n",
    "Then the only change we need to make is:\n",
    "\n",
    "```diff\n",
    "- loss = loss_fn(model, guide)\n",
    "+ loss = loss_fn(model, guide) + my_custom_L2_regularizer(my_parameters)\n",
    "```\n",
    "\n",
    "## Example: Scaling the Loss\n",
    "\n",
    "Depending on the optimization algorithm, the scale of the loss may or not matter. Suppose \n",
    "we want to scale our loss function by the number of datapoints before we differentiate it.\n",
    "This is easily done:\n",
    "\n",
    "```diff\n",
    "- loss = loss_fn(model, guide)\n",
    "+ loss = loss_fn(model, guide) / N_data\n",
    "```\n",
    "\n",
    "Note that in the case of SVI, where each term in the loss function is a log probability \n",
    "from the model or guide, this same effect can be achieved using [poutine.scale](http://docs.pyro.ai/en/dev/poutine.html#pyro.poutine.scale). For \n",
    "example we can use the `poutine.scale` decorator to scale both the model and guide:\n",
    "\n",
    "```python\n",
    "@poutine.scale(scale=1.0/N_data)\n",
    "def model(...):\n",
    "    pass\n",
    "   \n",
    "@poutine.scale(scale=1.0/N_data)\n",
    "def guide(...):\n",
    "    pass\n",
    "```\n",
    "\n",
    "## Example: Mixing Optimizers\n",
    "\n",
    "The various optimizers in `pyro.optim` allow the user to specify optimization settings (e.g. learning rates) on\n",
    "a per-parameter basis. But what if we want to use different optimization algorithms for different parameters? \n",
    "We can do this using Pyro's `MultiOptimizer` (see below), but we can also achieve the same thing if we directly manipulate `differentiable_loss`:\n",
    "\n",
    "```python\n",
    "adam = torch.optim.Adam(adam_parameters, {\"lr\": 0.001, \"betas\": (0.90, 0.999)})\n",
    "sgd = torch.optim.SGD(sgd_parameters, {\"lr\": 0.0001})\n",
    "loss_fn = pyro.infer.Trace_ELBO().differentiable_loss\n",
    "# compute loss\n",
    "loss = loss_fn(model, guide)\n",
    "loss.backward()\n",
    "# take a step and zero the parameter gradients\n",
    "adam.step()\n",
    "sgd.step()\n",
    "adam.zero_grad()\n",
    "sgd.zero_grad()\n",
    "```\n",
    "\n",
    "For completeness, we also show how we can do the same thing using [MultiOptimizer](http://docs.pyro.ai/en/dev/optimization.html?highlight=multi%20optimizer#module-pyro.optim.multi), which allows\n",
    "us to combine multiple Pyro optimizers. Note that since `MultiOptimizer` uses `torch.autograd.grad` under the hood (instead of `torch.Tensor.backward()`), it has a slightly different interface; in particular the `step()` method also takes parameters as inputs.\n",
    "\n",
    "```python\n",
    "def model():\n",
    "    pyro.param('a', ...)\n",
    "    pyro.param('b', ...)\n",
    "    ...\n",
    "  \n",
    "adam = pyro.optim.Adam({'lr': 0.1})\n",
    "sgd = pyro.optim.SGD({'lr': 0.01})\n",
    "optim = MixedMultiOptimizer([(['a'], adam), (['b'], sgd)])\n",
    "with pyro.poutine.trace(param_only=True) as param_capture:\n",
    "    loss = elbo.differentiable_loss(model, guide)\n",
    "params = {'a': pyro.param('a'), 'b': pyro.param('b')}\n",
    "optim.step(loss, params)\n",
    "```\n",
    "\n",
    "## Example: Custom ELBO\n",
    "\n",
    "In the previous three examples we bypassed creating a `SVI` object and directly manipulated \n",
    "the differentiable loss function provided by an `ELBO` implementation. Another thing we \n",
    "can do is create custom `ELBO` implementations and pass those into the `SVI` machinery. \n",
    "For example, a simplified version of a `Trace_ELBO` loss function might look as follows:\n",
    "\n",
    "```python\n",
    "# note that simple_elbo takes a model, a guide, and their respective arguments as inputs\n",
    "def simple_elbo(model, guide, *args, **kwargs):\n",
    "    # run the guide and trace its execution\n",
    "    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)\n",
    "    # run the model and replay it against the samples from the guide\n",
    "    model_trace = poutine.trace(\n",
    "        poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)\n",
    "    # construct the elbo loss function\n",
    "    return -1*(model_trace.log_prob_sum() - guide_trace.log_prob_sum())\n",
    "\n",
    "svi = SVI(model, guide, optim, loss=simple_elbo)\n",
    "```\n",
    "Note that this is basically what the `elbo` implementation in [\"mini-pyro\"](https://github.com/uber/pyro/blob/dev/pyro/contrib/minipyro.py) looks like.\n",
    "\n",
    "### Example: KL Annealing\n",
    "\n",
    "In the [Deep Markov Model Tutorial](http://pyro.ai/examples/dmm.html) the ELBO variational objective\n",
    "is modified during training. In particular the various KL-divergence terms between latent random\n",
    "variables are scaled downward (i.e. annealed) relative to the log probabilities of the observed data.\n",
    "In the tutorial this is accomplished using `poutine.scale`. We can accomplish the same thing by defining \n",
    "a custom loss function. This latter option is not a very elegant pattern but we include it anyway to \n",
    "show the flexibility we have at our disposal. \n",
    "\n",
    "```python\n",
    "def simple_elbo_kl_annealing(model, guide, *args, **kwargs):\n",
    "    # get the annealing factor and latents to anneal from the keyword\n",
    "    # arguments passed to the model and guide\n",
    "    annealing_factor = kwargs.pop('annealing_factor', 1.0)\n",
    "    latents_to_anneal = kwargs.pop('latents_to_anneal', [])\n",
    "    # run the guide and replay the model against the guide\n",
    "    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)\n",
    "    model_trace = poutine.trace(\n",
    "        poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)\n",
    "        \n",
    "    elbo = 0.0\n",
    "    # loop through all the sample sites in the model and guide trace and\n",
    "    # construct the loss; note that we scale all the log probabilities of\n",
    "    # samples sites in `latents_to_anneal` by the factor `annealing_factor`\n",
    "    for site in model_trace.values():\n",
    "        if site[\"type\"] == \"sample\":\n",
    "            factor = annealing_factor if site[\"name\"] in latents_to_anneal else 1.0\n",
    "            elbo = elbo + factor * site[\"fn\"].log_prob(site[\"value\"]).sum()\n",
    "    for site in guide_trace.values():\n",
    "        if site[\"type\"] == \"sample\":\n",
    "            factor = annealing_factor if site[\"name\"] in latents_to_anneal else 1.0        \n",
    "            elbo = elbo - factor * site[\"fn\"].log_prob(site[\"value\"]).sum()\n",
    "    return -elbo\n",
    "\n",
    "svi = SVI(model, guide, optim, loss=simple_elbo_kl_annealing)\n",
    "svi.step(other_args, annealing_factor=0.2, latents_to_anneal=[\"my_latent\"])\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
