{"title": "Autoconj: Recognizing and Exploiting Conjugacy Without a Domain-Specific Language", "book": "Advances in Neural Information Processing Systems", "page_first": 10716, "page_last": 10726, "abstract": "Deriving conditional and marginal distributions using conjugacy relationships can be time consuming and error prone. In this paper, we propose a strategy for automating such derivations. Unlike previous systems which focus on relationships between pairs of random variables, our system (which we call Autoconj) operates directly on Python functions that compute log-joint distribution functions. Autoconj provides support for conjugacy-exploiting algorithms in any Python-embedded PPL. This paves the way for accelerating development of novel inference algorithms and structure-exploiting modeling strategies. The package can be downloaded at https://github.com/google-research/autoconj.", "full_text": "Autoconj: Recognizing and Exploiting Conjugacy\n\nWithout a Domain-Speci\ufb01c Language\n\nMatthew D. Hoffman\u2217\n\nGoogle AI\n\nmhoffman@google.com\n\nMatthew J Johnson*\n\nGoogle Brain\n\nmattjj@google.com\n\nDustin Tran\nGoogle Brain\n\ntrandustin@google.com\n\nAbstract\n\nDeriving conditional and marginal distributions using conjugacy relationships can\nbe time consuming and error prone. In this paper, we propose a strategy for au-\ntomating such derivations. Unlike previous systems which focus on relationships\nbetween pairs of random variables, our system (which we call Autoconj) oper-\nates directly on Python functions that compute log-joint distribution functions.\nAutoconj provides support for conjugacy-exploiting algorithms in any Python-\nembedded PPL. This paves the way for accelerating development of novel infer-\nence algorithms and structure-exploiting modeling strategies.1\n\n1\n\nIntroduction\n\nSome models enjoy a property called conjugacy that makes computation easier. 
Conjugacy lets us\ncompute complete conditional distributions, that is, the distribution of some variable conditioned\non all other variables in the model. Complete conditionals are at the heart of many classical sta-\ntistical inference algorithms such as Gibbs sampling (Geman and Geman, 1984), coordinate-ascent\nvariational inference (Jordan et al., 1999), and even the venerable expectation-maximization (EM)\nalgorithm (Dempster et al., 1977). Conjugacy also makes it possible to marginalize out some vari-\nables, which makes many algorithms faster and/or more accurate (e.g.; Grif\ufb01ths and Steyvers, 2004).\nMany popular models in the literature enjoy some form of conjugacy, and these models can often be\nextended in ways that preserve conjugacy.\nFor experienced researchers, deriving conditional and marginal distributions using conjugacy rela-\ntionships is straightforward. But it is also time consuming and error prone, and diagnosing bugs in\nthese derivations can require signi\ufb01cant effort (Cook et al., 2006).\nThese considerations motivated specialized systems such as BUGS (Spiegelhalter et al., 1995),\nVIBES (Winn and Bishop, 2005), and their many successors.\nIn these systems, one speci\ufb01es a\nmodel in a probabilistic programming language (PPL), provides observed values for some of the\nmodel\u2019s variables, and lets the system automatically translate the model speci\ufb01cation into an algo-\nrithm (typically Gibbs sampling or variational inference) that approximates the model\u2019s posterior\nconditioned on the observed variables.\nThese systems are useful, but their monolithic design imposes a major limitation: they are dif\ufb01cult\nto compose with other systems. 
For example, a user who wants to interleave Gibbs sampling steps\nwith some customized Markov chain Monte Carlo (MCMC) kernel will \ufb01nd it very dif\ufb01cult to take\nadvantage of BUGS\u2019 Gibbs sampler.\nIn this paper, we propose a different strategy for exploiting conditional conjugacy relationships.\nUnlike previous approaches (which focus on relationships between pairs of random variables) our\n\n\u2217equal contribution\n1 Autoconj (including experiments) is available at https://github.com/google-research/autoconj.\n\n32nd Conference on Neural Information Processing Systems (NeurIPS 2018), Montr\u00e9al, Canada.\n\n\fsystem (which we call Autoconj) operates directly on Python functions that compute log-joint dis-\ntribution functions. If asked to compute a marginal distribution, Autoconj returns a Python function\nthat implements that marginal distribution\u2019s log-joint. If asked to compute a complete conditional,\nit returns a Python function that returns distribution objects.\nAutoconj is not tied to any particular approximate inference algorithm. But, because Autoconj\nis a simple Python API, implementing conjugacy-exploiting approximate inference algorithms\nusing Autoconj is easy and fast (as we demonstrate in section 5).\nIn particular, working in\nthe Python/NumPy ecosystem gives Autoconj users access to vectorized kernels, automatic dif-\nferentiation (via Autograd (Maclaurin et al., 2014)), sophisticated optimization algorithms (via\nscipy.optimize), and even accelerated hardware (via TensorFlow).\nAutoconj provides support for conjugacy-exploiting algorithms in any Python-embedded PPL. 
More ambitiously, we hope that, just as automatic differentiation has accelerated research in deep learning, Autoconj will accelerate the development of novel inference algorithms and modeling strategies that exploit conjugacy.

2 Background: Exponential Families and Conjugacy

To develop a system that can automatically find and exploit conjugacy, we first develop a general perspective on exponential families. Given a probability space $(\mathcal{X}, \mathcal{B}(\mathcal{X}), \nu)$, where $\mathcal{B}(\mathcal{X})$ is the Borel sigma algebra with respect to the standard topology on $\mathcal{X}$, and a statistic function $t : \mathcal{X} \to \mathbb{R}^n$, define the corresponding exponential family of densities (Wainwright and Jordan, 2008), indexed by the natural parameter $\eta \in \mathbb{R}^n$, and log-normalizer function $A$ as

$$p(x; \eta) = \exp\{\langle \eta, t(x) \rangle - A(\eta)\}, \qquad A(\eta) \triangleq \log \int \exp\{\langle \eta, t(x) \rangle\}\, \nu(dx), \qquad (1)$$

where $\langle \cdot\,, \cdot \rangle$ denotes the standard inner product. The log-normalizer function $A$ is directly related to the cumulant-generating function, and in particular it satisfies

$$\nabla A(\eta) = \mathbb{E}[t(x)], \qquad \nabla^2 A(\eta) = \mathbb{E}\big[t(x)t(x)^\top\big] - \mathbb{E}[t(x)]\, \mathbb{E}[t(x)]^\top, \qquad (2)$$

where the expectation is with respect to $p(x; \eta)$. For a given statistic function $t$, when the corresponding distribution can be sampled efficiently, and when $A$ and its derivatives can be evaluated efficiently, we say the exponential family (or the statistic function that defines it) is tractable.

Consider an exponential-family model where the log density has the form

$$\log p(z, x) = \log p(z_1, z_2, \ldots, z_M, x) = \sum_{\beta \in \mathcal{B}} \big\langle \eta_\beta(x),\ t_{z_1}(z_1)^{\beta_1} \otimes \cdots \otimes t_{z_M}(z_M)^{\beta_M} \big\rangle \triangleq g(t_{z_1}(z_1), \ldots, t_{z_M}(z_M)), \qquad (3)$$

where $\mathcal{B} \subseteq \{0, 1\}^M$ is an index set, we take $t_{z_m}(z_m)^0 \equiv 1$, and where the functions $\{t_{z_m}(z_m)\}_{m=1}^M$ are each the sufficient statistics of a tractable exponential family. In words, the log joint density $\log p(z, x)$ can be written as a multilinear (or multiaffine) polynomial $g$ applied to the statistic functions $\{t_{z_m}(z_m)\}$. These models arise when building complex distributions from simpler, tractable ones, and the algebraic structure in $g$ corresponds to graphical model structure (Wainwright and Jordan, 2008; Koller and Friedman, 2009). In general the posterior $\log p(z \mid x)$ is not tractable, but it admits efficient approximate inference algorithms.

Models of the form (3) are known as conditionally conjugate models. Each conditional $p(z_m \mid z_{\neg m})$ (where $z_{\neg m} \triangleq \{z_1, \ldots, z_M\} \setminus \{z_m\}$) is a tractable exponential family. Moreover, the parameters of these conditional densities can be extracted using differentiation. We formalize this below.

Claim 2.1. Given an exponential family with density of the form (3), we have

$$p(z_m \mid z_{\neg m}) = \exp\big\{\langle \eta^*_{z_m}, t_{z_m}(z_m) \rangle - A_{z_m}(\eta^*_{z_m})\big\} \quad \text{where} \quad \eta^*_{z_m} \triangleq \nabla_{t_{z_m}} g(t_{z_1}(z_1), \ldots, t_{z_M}(z_M)).$$

As a consequence, if we had code for evaluating the functions $g$ and $\{t_{z_m}\}$, along with a table of sampling routines corresponding to each tractable statistic $t_{z_m}$, then we could use automatic differentiation to write a generic Gibbs sampling algorithm. This generic algorithm could be extended to work with any tractable exponential-family distribution simply by populating a table matching tractable statistics functions to their corresponding samplers. 
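As a concrete toy check of Claim 2.1 (ours, independent of Autoconj), consider a Beta-Bernoulli model: the log-joint is linear in the statistics $t(\theta) = (\log \theta, \log(1-\theta))$, so differentiating $g$ with respect to those statistics recovers the natural parameters of the Beta complete conditional. The sketch below uses a finite-difference gradient in place of automatic differentiation:

```python
import numpy as np

# Beta-Bernoulli: log p(theta, x) is linear in t(theta) = (log theta,
# log(1 - theta)), so grad_t g gives the natural parameters of the
# Beta complete conditional p(theta | x).
a, b = 2.0, 3.0                       # Beta prior hyperparameters
x = np.array([1, 0, 1, 1, 0])         # observed coin flips (3 heads, 2 tails)

def g(t):
    """Log-joint as a (here linear) function of the statistics t."""
    return (a - 1 + x.sum()) * t[0] + (b - 1 + len(x) - x.sum()) * t[1]

def numerical_grad(f, t, eps=1e-6):
    grad = np.zeros_like(t)
    for i in range(len(t)):
        d = np.zeros_like(t)
        d[i] = eps
        grad[i] = (f(t + d) - f(t - d)) / (2 * eps)
    return grad

theta = 0.5
t = np.array([np.log(theta), np.log1p(-theta)])
eta = numerical_grad(g, t)            # natural parameters of p(theta | x)
# A Beta(alpha, beta) density has natural parameters (alpha - 1, beta - 1).
alpha_post, beta_post = eta[0] + 1, eta[1] + 1
```

With 3 heads and 2 tails, this recovers the textbook conjugate posterior Beta(a + 3, b + 2) = Beta(5, 5), matching the gradient-based recipe in the claim.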
Note this differs from a table of pairs of random variables: conjugacy derives from this lower-level algebraic relationship.

The model structure (3) can be exploited in other approximate inference algorithms, including variational mean field (Wainwright and Jordan, 2008) and stochastic variational inference (Hoffman et al., 2013). Consider the variational distribution

$$q(z) = \prod_m q(z_m; \eta_{z_m}), \qquad q(z_m; \eta_{z_m}) = \exp\{\langle \eta_{z_m}, t_{z_m}(z_m) \rangle - A_{z_m}(\eta_{z_m})\}, \qquad (4)$$

where $\eta_{z_m}$ are natural parameters of the variational factors. We write the variational evidence lower bound objective $\mathcal{L} = \mathcal{L}(\eta_{z_1}, \ldots, \eta_{z_M})$ for approximating the posterior $p(z \mid x)$ as

$$\log p(x) = \log \int p(z, x)\, \nu_z(dz) = \log \mathbb{E}_{q(z)}\left[\frac{p(z, x)}{q(z)}\right] \geq \mathbb{E}_{q(z)}\left[\log \frac{p(z, x)}{q(z)}\right] \triangleq \mathcal{L}. \qquad (5)$$

We can write block coordinate ascent updates for this objective using differentiation:

Claim 2.2. Given a model with density of the form (3) and variational problem (4)-(5), we have

$$\arg\max_{\eta_{z_m}} \mathcal{L}(\eta_{z_1}, \ldots, \eta_{z_M}) = \nabla_{\mu_{z_m}} g(\mu_{z_1}, \ldots, \mu_{z_M}) \quad \text{where} \quad \mu_{z_{m'}} \triangleq \nabla A_{z_{m'}}(\eta_{z_{m'}}), \quad m' = 1, \ldots, M.$$

Thus if we had code for evaluating the functions $g$ and $\{t_{z_m}\}$, along with a table of log-normalizer functions $A_{z_m}$ corresponding to each tractable statistic $t_{z_m}$, then we could use automatic differentiation to write a generic block coordinate-ascent variational inference algorithm. New tractable structures could be added to this algorithm's repertoire simply by populating the table of statistics and their corresponding log-normalizers.

If all this tractable exponential-family structure can be exploited generically, why is writing conjugacy-exploiting inference software still so laborious? 
The reason is that it is not always easy to get our hands on the representation (3). Even when a model's log joint density could be written as in (3), it is often difficult and error-prone to write code to evaluate g directly; it is much more natural to specify model densities without being constrained to this form. The situation is analogous to deep learning research before flexible automatic differentiation: we're stuck writing too much code by hand, and even though in principle this process could be automated, our current software tools aren't up to the task unless we're willing to get locked into a limited mini-language.

Based on this derivation, Autoconj is built to automatically extract these tractable structures (i.e., the functions g and {t_zm}). It does this given log density functions written in plain Python and NumPy. And it reaps automatic structure-exploiting inference algorithms as a result.

3 Analyzing Log-Joint Functions

To extract sufficient statistics and natural parameters from a log-joint function, Autoconj first represents that function in a convenient canonical form. It applies a canonicalization process, which comprises two stages: 1. a tracer maps Python log-joint probability functions to symbolic term graphs; 2. a domain-specific rewrite system puts the log-joint functions in a canonical form and extracts the component functions defined in Section 2.

3.1 Tracing Python programs to generate term graphs

The tracer's purpose is to map a Python function denoting a log-joint function to an acyclic term graph data structure. It accomplishes this mapping without having to analyze Python syntax or reason about its semantics directly; instead, the tracer monitors the execution of a Python function in terms of the primitive functions that are applied to its arguments to produce its final output. 
As a consequence, intermediates like non-primitive function calls and auxiliary data structures, including tuples/lists/dicts as well as custom classes, do not appear in the trace and instead all get traced through. The ultimate output of the tracer is a directed acyclic data flow graph, where nodes represent application of primitive functions (typically NumPy functions) and edges represent data flow. This approach is both simple to implement and able to handle essentially any Python code.

Figure 1: Left: Python code for evaluating the log joint density of a Gaussian mixture model. Right: canonicalized computation graph, representing the same log joint density function but rewritten as a sum of np.einsums of statistic functions.

A weakness of this tracing approach is that we only trace one evaluation of the function on example arguments, and we assume that the trace represents the same mathematical function that the original Python code denotes. This assumption can fail. For example, if a Python function has an if/else that depends on the value of the arguments (and is not expressed in a primitive function), then the tracer could only follow one branch, and so instead raises an error. In the context of tracing log-joint functions, this limitation does not seem to arise too frequently, but it does affect our handling of discrete random variables; for densities of discrete random variables, the tracer can intercept either indexing expressions like pi[z] or the use of the primitive function one_hot.

Figure 1 summarizes the tracer's use on Python code to generate a term graph. To implement the tracing mechanism, we reuse Autograd's tracer (Maclaurin et al., 2014), which is designed to be general-purpose and extensible with a simple API. 
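To make the tracing idea concrete, here is a heavily simplified sketch of such a mechanism (ours, for exposition only; Autoconj reuses Autograd's far more complete tracer). Boxed values record each primitive applied to them, yielding a data-flow graph whose nodes hold a function name, parent nodes, and an example value:

```python
# Minimal operator-overloading tracer sketch (illustrative only).
class Node:
    def __init__(self, fun, parents, value):
        self.fun, self.parents, self.value = fun, parents, value

class Box:
    """Wraps a Node and overloads primitives to extend the graph."""
    def __init__(self, node):
        self.node = node

    def _lift(self, other):
        # Constants become leaf nodes; boxed values reuse their node.
        return other.node if isinstance(other, Box) else Node('const', [], other)

    def __mul__(self, other):
        o = self._lift(other)
        return Box(Node('mul', [self.node, o], self.node.value * o.value))

    def __add__(self, other):
        o = self._lift(other)
        return Box(Node('add', [self.node, o], self.node.value + o.value))

def trace(f, example_arg):
    """Runs f on a boxed example argument and returns the output node."""
    out = f(Box(Node('input', [], example_arg)))
    return out.node

# Tracing f(x) = x * 2.0 + 3.0 on an example input records an add node
# whose first parent is the mul node, mirroring the data flow.
graph = trace(lambda x: x * 2.0 + 3.0, 1.5)
```

As in the real system, only primitive applications appear in the graph; any Python control structure around them is traced through.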
Other similar tracing mechanisms are common\nin probabilistic programming (Goodman and Stuhlm\u00fcller, 2014).\n\n3.2 Domain-speci\ufb01c term graph rewriting system\n\nThe goal of the rewrite system is to take a log-joint term graph and manipulate it into a canoni-\ncal form. Mathematically, the canonical form described in Section 2 is a multilinear polynomial\ng on tensor-valued statistic functions t1, . . . , tM . For term graphs, we say a term graph is in this\ncanonical form when its output node represents a sum of np.einsum nodes, with each np.einsum\nnode corresponding to a monomial term in g and each np.einsum argument being either a con-\nstant, a nonlinear function of an input, or an input itself, with the latter two cases corresponding to\nstatistic functions tm. We rely on np.einsum because it is capable of expressing arbitrary tensor\ncontractions, meaning it is a uniform way to express arbitrary monomial terms in g.\nAt its core, the rewrite system is based on pattern-directed invocation of rewrite rules, each of which\ncan match and then modify a small subgraph corresponding to a few primitive function applications.\nOur pattern language is a new Python-embedded DSL, which is compiled into continuation-passing\nmatcher combinators (Radul, 2013). In addition to basic matchers for data types and each primitive\nfunction, the pattern combinators include Choice, which produces a match if any of its argument\ncombinators produce a match, and Segment, which can match any number of elements in a list,\nincluding argument lists. By using continuation passing, backtracking is effectively handled by\nthe Python call stack, and it\u2019s straightforward to extract just one match or all possible matches.\nThe pattern language compiler is only ~300 lines and is fully extensible by registering new syntax\nhandlers.\nA rewrite rule is then a pattern paired with a rewriter function. 
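The continuation-passing combinator style described above can be illustrated with a miniature, self-contained matcher (our sketch; Autoconj's pattern DSL is richer and is compiled from pattern syntax). Each matcher takes the remaining items, the bindings so far, and a success continuation, so backtracking over Choice alternatives and Segment lengths falls out of the Python call stack:

```python
# Toy continuation-passing matcher combinators (illustrative only).
def Lit(value):
    def match(items, bindings, succeed):
        if items and items[0] == value:
            return succeed(items[1:], bindings)
        return None
    return match

def Val(name):
    def match(items, bindings, succeed):
        if items:
            return succeed(items[1:], dict(bindings, **{name: items[0]}))
        return None
    return match

def Choice(*matchers):
    def match(items, bindings, succeed):
        for m in matchers:  # try alternatives in order
            result = m(items, bindings, succeed)
            if result is not None:
                return result
        return None
    return match

def Segment(name):
    def match(items, bindings, succeed):
        for split in range(len(items) + 1):  # try ever-longer prefixes
            result = succeed(items[split:], dict(bindings, **{name: items[:split]}))
            if result is not None:
                return result
        return None
    return match

def match_pattern(matchers, items):
    def run(ms, items, bindings):
        if not ms:
            return bindings if not items else None
        return ms[0](items, bindings, lambda rest, b: run(ms[1:], rest, b))
    return run(list(matchers), list(items), {})

# Match an einsum-of-a-sum, in the spirit of Listing 1's pattern.
bindings = match_pattern(
    [Lit('einsum'), Val('formula'), Segment('args1'),
     Choice(Lit('add'), Lit('subtract')), Segment('args2')],
    ['einsum', 'ij,j->i', 'W', 'x', 'add', 'b'])
```

Here Segment('args1') backtracks through prefixes until Choice can match 'add', and the final bindings collect the sub-terms a rewriter would consume.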
A rewriter essentially represents a syntactic macro operating on the term subgraph, using matched sub-terms collected by the pattern to generate a new term subgraph. To specify each rewriter, we again make use of the tracer: we simply write a Python function that, when traced on appropriate arguments, produces the new subgraph, which we then patch into the term graph. This mechanism is analogous to quasiquoting (Radul, 2013), since it specifies a syntactic transformation in terms of native Python expressions. Thus by using pattern matching and tracing-based rewriters, we can define general rewrite rules without writing any code that manually traverses or modifies the term graph data structure. As a result, it is straightforward to add new rewrite rules to the system. 
See Listing 1 for an example rewrite rule.

pat = (Einsum, Str('formula'), Segment('args1'),
       (Choice(Subtract('op'), Add('op')), Val('x'), Val('y')), Segment('args2'))

def rewriter(formula, op, x, y, args1, args2):
    return op(np.einsum(formula, *(args1 + (x,) + args2)),
              np.einsum(formula, *(args1 + (y,) + args2)))

distribute_einsum = Rule(pat, rewriter)  # Rule is a namedtuple

Listing 1: A rewrite for distributing np.einsum over addition and subtraction.

Rewrite rules are composed into a term rewriting system by an alternating strategy with two steps. In the first step, for each rule we look for a pattern match anywhere in the term graph starting from the output; if no match is found then the process terminates, and if there is a match we apply the corresponding rewriter and move to the second step. In the second step, we traverse the graph from the inputs to the output, performing common subexpression elimination (CSE) and applying local simplifications that only involve one primitive at a time (like replacing a np.dot with an equivalent np.einsum) and hence don't require pattern matching. By alternating rewrites with CSE, we remove any redundancies introduced by the rewrites. It is straightforward to compose new rewrite systems, involving different sets of rewrite rules or different strategies for applying them.

The process is summarized in Figure 1. The rewriting process aims to transform the term graph of a log joint density into the canonical sum-of-einsums polynomial form corresponding to Eq. (3) (up to commutativity). We do not have a proof that the rewrites are terminating or confluent (Baader and Nipkow, 1999), and the set of possible terms is very complex, though intuitively each rewrite rule applied makes strict progress towards the canonical form (e.g. by distributing multiplication across addition). 
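Such rewrites are semantics-preserving, which is easy to check numerically. The snippet below is our own sanity check (not part of the rewrite system): it confirms that Listing 1's distribution rule and the np.dot-to-np.einsum local simplification both leave values unchanged:

```python
import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(3, 4)
x, y = rng.randn(4), rng.randn(4)

# Listing 1's rule: einsum distributes over addition/subtraction.
before = np.einsum('ij,j->i', A, x + y)
after = np.einsum('ij,j->i', A, x) + np.einsum('ij,j->i', A, y)

# Local simplification: np.dot is a special case of np.einsum.
dot_direct = np.dot(A, x)
dot_einsum = np.einsum('ij,j->i', A, x)
```

Both pairs agree to floating-point precision, as expected for rules that only rearrange the polynomial g.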
In practice there have been no problems with termination or normalization.\nOnce we have processed the log-joint term graph into a canonical form, it is straightforward to\nextract the objects of interest (namely the statistic functions t1, . . . , tM and the polynomial g), match\nthe tractable statistics with corresponding log-normalizer and sampler functions from a table, and\nperform any further manipulations like automatic differentiation. Moreover, we can map the term\ngraph back into a Python function (via an interpreter), so the rewrite system is hermetic: we can use\nits output with any other Python tools, like Autograd or SciPy, without those tools needing to know\nanything about it.\nTerm rewriting systems have a long history in compilers and symbolic math systems (Sussman\net al., 2018; Radul, 2013; Diehl, 2013; Rozenberg, 1997; Baader and Nipkow, 1999). The main\nnovelty here is the application domain and speci\ufb01c concerns and capabilities that arise from it; we\u2019re\nmanipulating exponential families of densities for multidimensional random variables, and hence our\nsystem is focused on matrix and tensor manipulations, which have limited support in other systems,\nand a speci\ufb01c canonical form informed by structure-exploiting approximate inference algorithms.\nOur implementation is closely related to the term rewriting system in scmutils (Sussman et al.,\n2018) and Rules (Radul, 2013), which also use a pattern language (embedded in Scheme) based on\ncontinuation-passing matcher combinators and quasiquote-based syntactic macros. 
Two differences in the implementation are that our system operates on term graphs rather than syntax trees, and that we use tracing to implement a kind of macro system on our term graph data structures (instead of using Scheme's built-in quasiquotes and homoiconicity).

3.3 Recognizing Sufficient Statistics and Natural Parameters

Once the log-joint graph has been canonicalized as a sum of np.einsums of functions of the inputs, we can discover and extract exponential-family structure.

Suppose we are interested in the complete conditional of an input z. We first need to find all nodes that represent sufficient statistics of z. We begin at the output node, and search up through the graph, ignoring any nodes that do not depend on z. We walk through any add or subtract nodes until we reach an np.einsum node. If z is a parent of more than one argument of that np.einsum node, then the node represents a nonlinear function of z and we label it as a sufficient statistic (if the node has any inputs that do not depend on z we also need to split those out). Otherwise, we walk through the np.einsum node since it is a linear function of z. If at any point in the search we reach either z or a node that is not linear in z (i.e., not an add, subtract, or np.einsum), we label it as a sufficient statistic.

Once we have found the set of sufficient statistic nodes, we can determine whether they correspond to a known tractable exponential family. For example, in Figure 1, z has integer support and the one-hot statistic, so its complete conditional is a categorical distribution; π's support is the simplex and its only sufficient statistic is log π, so π's complete conditional is a Dirichlet; τ's support is the non-negative reals, and its sufficient statistics are τ and log τ, so its complete conditional is a gamma distribution. 
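The lookup step can be sketched as a small table keyed on support and statistics, mirroring the examples just given. The key structure and names below are our illustration; Autoconj's real table maps tractable statistic functions to samplers and log-normalizers rather than to family names:

```python
# Illustrative statistics-to-family lookup (not Autoconj's actual table).
FAMILY_TABLE = {
    ('integer', frozenset({'one_hot'})): 'categorical',
    ('simplex', frozenset({'log'})): 'dirichlet',
    ('nonnegative', frozenset({'identity', 'log'})): 'gamma',
    ('real', frozenset({'identity', 'square'})): 'normal',
}

def recognize(support, statistic_names):
    """Looks up the exponential family for a support/statistics pair."""
    key = (support, frozenset(statistic_names))
    if key not in FAMILY_TABLE:
        raise ValueError('no known exponential family for %r' % (key,))
    return FAMILY_TABLE[key]

family_tau = recognize('nonnegative', ['identity', 'log'])  # tau's conditional
family_z = recognize('integer', ['one_hot'])                # z's conditional
```

An unknown support/statistics pair falls through to an exception, matching the behavior described next.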
If the suf\ufb01cient-statistic functions do not correspond to a known exponential\nfamily, then the system raises an exception.\nFinally, to get the natural parameters we can simply take the symbolic gradient of the output node\nwith respect to each suf\ufb01cient-statistic node using Autograd.\n\n4 Related Work\n\nMany probabilistic programming languages (PPLs) exploit conjugacy relationships. PPLs like\nBUGS (Spiegelhalter et al., 1995), VIBES (Winn and Bishop, 2005), and Augur (Tristan et al.,\n2014) build an explicit graph of random variables and \ufb01nd conjugate pairs in that graph. This strat-\negy remains widely applicable, but ties the system very strongly to the PPL\u2019s model representation.\nMost recently, Birch (Murray et al., 2018) utilizes a \ufb02exible strategy for combining conjugacy and\napproximate inference in order to enable algorithms such as Sequential Monte Carlo with Rao-\nBlackwellization. Autoconj could extend their conjugacy component.\nPPLs such as Hakaru (Narayanan et al., 2016) have considered treating conditioning and marginal-\nization as program transformations based on computer algebra (Carette and Shan, 2016; Gehr et al.,\n2016). Unfortunately, most existing computer algebra systems have very limited support for linear\nalgebra and multidimensional array processing, which in turn makes it hard for these systems to\neither express models using NumPy-style broadcasting or take advantage of vectorized hardware\n(although Narayanan and Shan (2017) take steps to address this). Exploiting multivariate-Gaussian\nstructure in these languages is particularly cumbersome. Orthogonal to our work, Narayanan and\nShan (2017) advances symbolic manipulation for general probability spaces such as mixed discrete-\nand-continuous events. 
These ideas could also be used in Autoconj.\n\n5 Examples and Experiments\n\nIn this section we provide code snippets and empirical results to demonstrate Autoconj\u2019s function-\nality, as well as the bene\ufb01ts of being embedded in Python as opposed to a more narrowly focused\ndomain-speci\ufb01c language. We begin with some examples.\nListing 2 demonstrates doing exact conditioning and marginalization in a trivial Beta-Bernoulli\nmodel. The log-joint is implemented using NumPy, and is passed to complete_conditional()\nand marginalize(). These functions also take an argnum parameter that says which parameter to\nmarginalize out or take the complete conditional of (0 in this example, referring to counts_prob)\nand a support parameter. Finally, they take a list of dummy arguments that are used to propagate\nshapes and types when tracing the log-joint function.\nListing 3 demonstrates how one can handle a more complicated compound prior: the normal-gamma\ndistribution, which is the natural conjugate prior for Bayesian linear regression. Note that we can\ncall complete_conditional() on the function produced by marginalize().\nWe can extend the marginalize-and-condition strategy above to more complicated models. 
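The Beta(60.5, 40.5) complete conditional quoted for Listing 2 agrees with the textbook conjugate update, which is easy to verify by hand (our check, independent of Autoconj's output):

```python
# Conjugate update for a Beta prior with Bernoulli observations:
# Beta(a, b) prior + (heads, tails) data -> Beta(a + heads, b + tails).
n_heads, n_draws = 60, 100
prior_a, prior_b = 0.5, 0.5

post_a = prior_a + n_heads                # 60.5
post_b = prior_b + (n_draws - n_heads)    # 40.5
post_mean = post_a / (post_a + post_b)    # posterior mean of counts_prob
```

This is exactly the distribution object that complete_conditional() constructs automatically from the log-joint.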
In the supplement, we demonstrate how one can implement the Kalman-filter recursion with Autoconj. The generative process is

x_1 ∼ Normal(0, s_0);  x_{t>1} ∼ Normal(x_{t−1}, s_x);  y_t ∼ Normal(x_t, s_y).  (6)

The core recursion consists of using marginalize() to compute p(x_{t+1}, y_{t+1} | y_{1:t}) from the functions p(x_t | y_{1:t}) and p(x_{t+1}, y_{t+1} | x_t), then using marginalize() again to compute p(y_{t+1} | y_{1:t}) and complete_conditional() to compute p(x_{t+1} | y_{1:t+1}). As in the normal-gamma example, it is up to the user to reason about the graphical model structure, but Autoconj handles all of the conditioning and marginalization automatically. The same code could be applied to a hidden Markov model (which has the same graphical model structure) by simply changing the distributions in the log-joint and the support from real to integer.

def log_joint(counts_prob, n_heads, n_draws, prior_a, prior_b):
    log_prob = (prior_a-1)*np.log(counts_prob) + (prior_b-1)*np.log1p(-counts_prob)
    log_prob += n_heads*np.log(counts_prob) + (n_draws-n_heads)*np.log1p(-counts_prob)
    log_prob += -gammaln(prior_a) - gammaln(prior_b) + gammaln(prior_a + prior_b)
    return log_prob

n_heads, n_draws = 60, 100
prior_a, prior_b = 0.5, 0.5
all_args = [0.5, n_heads, n_draws, prior_a, prior_b]
make_complete_conditional = autoconj.complete_conditional(
    log_joint, 0, SupportTypes.UNIT_INTERVAL, *all_args)
# A Beta(60.5, 40.5) distribution object.
complete_conditional = make_complete_conditional(n_heads, n_draws, prior_a, prior_b)
# Computes the marginal log-probability of n_heads, n_draws given prior_a, prior_b
marginal = autoconj.marginalize(log_joint, 0, SupportTypes.UNIT_INTERVAL, *all_args)
print('log p(n_heads=60 | a, b) =', marginal(n_heads, n_draws, prior_a, prior_b))

Listing 2: Exact inference in a simple Beta-Bernoulli model.

def log_joint(tau, beta, x, y, a, b, kappa, mu0):
    log_p_tau = log_probs.gamma_gen_log_prob(tau, a, b)
    log_p_beta = log_probs.norm_gen_log_prob(beta, mu0, 1. / np.sqrt(kappa * tau))
    log_p_y = log_probs.norm_gen_log_prob(y, np.dot(x, beta), 1. / np.sqrt(tau))
    return log_p_tau + log_p_beta + log_p_y

# log p(tau, x, y), marginalizing out beta
tau_x_y_log_prob = autoconj.marginalize(log_joint, 1, SupportTypes.REAL, *all_args)
# compute and sample from p(tau | x, y)
make_tau_posterior = autoconj.complete_conditional(
    tau_x_y_log_prob, 0, SupportTypes.NONNEGATIVE, *all_args_ex_beta)
tau_sample = make_tau_posterior(x, y, a, b, kappa, mu0).rvs()
# compute and sample from p(beta | tau, x, y)
make_beta_conditional = autoconj.complete_conditional(
    log_joint, 1, SupportTypes.REAL, *all_args)
beta_sample = make_beta_conditional(tau, x, y, a, b, kappa, mu0)

Listing 3: Exact inference in a Bayesian linear regression with normal-gamma compound prior. We factorize the joint posterior on the mean and precision as p(µ, τ | x, y) = p(τ | x, y)p(µ | x, y, τ). We first compute the marginal joint distribution p(x, y, τ) by calling marginalize() on the full log-joint. We then compute the marginal posterior p(τ | x, y) by calling complete_conditional() on the marginal p(x, y, τ), and finally we compute p(µ | x, y, τ) by calling complete_conditional() on the full log-joint.

When not all complete conditionals are tractable, the variational evidence lower bound (ELBO) is not tractable to compute exactly. Several strategies exist for dealing with this problem. 
One approach is to find a lower bound on the log-joint that is only a function of expected sufficient statistics of some exponential family (Jaakkola and Jordan, 1996; Blei and Lafferty, 2005). Another is to linearize problematic terms in the log-joint (Khan et al., 2015).

Knowledge of conjugate pairs is not sufficient to implement either of these strategies, which rely on direct manipulation of the log-joint to achieve a kind of quasi-conjugacy. But Autoconj naturally facilitates these strategies, since it does not require that the log-joint functions it is given exactly correspond to any true generative process.

Listing 4 demonstrates variational inference for Bayesian logistic regression (which has a non-conjugate likelihood) using Autoconj to optimize the bound of Jaakkola and Jordan (1996). One could also use Autoconj to implement other methods such as proximal variational inference (Khan and Wu, 2017; Khan et al., 2016, 2015).

def log_joint_bound(beta, xi, x, y):
    log_prior = np.sum(-0.5 * beta**2 - 0.5 * np.log(2*np.pi))
    y_logits = (2 * y - 1) * np.dot(x, beta)
    # Lower bound on -log(1 + exp(-y_logits)).
    lamda = (0.5 - expit(xi)) / (2. * xi)
    log_likelihood_bound = np.sum(-np.log(1 + np.exp(-xi)) + 0.5 * (y_logits - xi)
                                  + lamda * (y_logits ** 2 - xi ** 2))
    return log_prior + log_likelihood_bound

def xi_update(beta_mean, beta_secondmoment, x):
    """Sets the bound parameters xi to their optimal value."""
    beta_cov = beta_secondmoment - np.outer(beta_mean, beta_mean)
    return np.sqrt(np.einsum('ij,ni,nj->n', beta_cov, x, x) +
                   x.dot(beta_mean)**2)

neg_energy, (t_beta,), (lognorm_beta,) = meanfield.multilin_repr(
    log_joint_bound, argnums=(0,), supports=(SupportTypes.REAL,),
    example_args=(beta, xi, x, y))

elbo = partial(meanfield.elbo, neg_energy, (lognorm_beta,))
mu_beta = grad(lognorm_beta)(grad(neg_energy)(t_beta(beta), xi, x, y))  # initialize

for iteration in range(100):
    xi = xi_update(mu_beta[0], mu_beta[1], x)
    mu_beta = grad(lognorm_beta)(grad(neg_energy)(mu_beta, xi, x, y))
    print('{}\t{}'.format(iteration, elbo(mu_beta, xi, x, y)))

Listing 4: Variational Bayesian logistic regression using the lower bound of Jaakkola and Jordan (1996). Autoconj can work with log_joint_bound() even though it is not a true log-joint density.

Figure 2: Comparison of algorithms for Bayesian factor analysis according to their estimate of the expected log-joint as a function of runtime. (left) Relative to other algorithms, mean-field ADVI grossly underfits. (right) Zoom-in on other algorithms. Block coordinate-ascent variational inference (CAVI) converges faster than Gibbs.

Factor Analysis Autoconj facilitates many structure-exploiting inference algorithms. Here, we demonstrate why such algorithms are important for efficient inference, and that Autoconj supports their diverse collection. 
We generate data from a linear factor model:

  w_{mk} ∼ Normal(0, 1);   z_{nk} ∼ Normal(0, 1);   τ ∼ Gamma(α, β);   x_{mn} ∼ Normal(w_m^⊤ z_n, τ^{−1/2}).

There are N examples of D-dimensional vectors x ∈ R^{N×D}, and the data assumes a latent factorization according to all examples' feature representations z ∈ R^{N×K} and the principal components w ∈ R^{D×K}. As a toy demonstration, we use relatively small N, D, and K.

Autoconj naturally produces a structured mean-field approximation, since conditioned on w and x the rows of z each have multivariate-Gaussian complete conditionals (and vice versa for z and w). We compared Autoconj's structured block coordinate-ascent variational inference (CAVI) with Autoconj's block Gibbs, mean-field ADVI (Kucukelbir et al., 2016), and MAP implemented using scipy.optimize. All algorithms besides ADVI yield reasonable results, demonstrating the value of exploiting conjugacy when it is available.

Implementation                   Runtime (s)
Autoconj (NumPy; 1 CPU)          62.9
Autoconj (TensorFlow; 1 CPU)     75.9
Autoconj (TensorFlow; 6 CPU)     19.7
Autoconj (TensorFlow; 1 GPU)      4.3

Table 1: Time to run 500 iterations of variational inference on a mixture of Gaussians. TensorFlow offers little advantage on one CPU core, but an order-of-magnitude speedup on GPU.

Benchmarking Autoconj  While we used NumPy as a numerical backend for Autoconj, other Python-based backends are possible. We wrote a simple translator that replaces the NumPy ops in our computation graph with TensorFlow ops (Abadi et al., 2016).
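To make the factor-analysis generative process above concrete, here is a minimal NumPy sketch of the data generation (our own illustration; the function name, the seed, and the choice of β as a rate parameter are assumptions, not the paper's code):

```python
import numpy as np

def sample_factor_model(N, D, K, alpha=2.0, beta=8.0, seed=0):
  """Draws one dataset from the linear factor model above.

  Treats beta as a rate, so Gamma(alpha, beta) has mean alpha/beta;
  NumPy's gamma sampler takes a scale parameter, hence scale=1/beta.
  """
  rng = np.random.RandomState(seed)
  w = rng.normal(0., 1., size=(D, K))      # principal components, R^{D x K}
  z = rng.normal(0., 1., size=(N, K))      # per-example factors, R^{N x K}
  tau = rng.gamma(alpha, 1. / beta)        # observation precision
  x = rng.normal(z.dot(w.T), tau ** -0.5)  # observations, R^{N x D}
  return w, z, tau, x

w, z, tau, x = sample_factor_model(N=50, D=10, K=5)
```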
We can therefore take a log-joint written in NumPy, extract complete conditionals or marginals from that model, and then run the conditional or marginal computations in a TensorFlow graph (possibly on a GPU or TPU).

We ran Autoconj's CAVI in NumPy and TensorFlow for a mixture-of-Gaussians model:

  π ∼ Dirichlet(α);   z_n ∼ Categorical(π);   μ_{kd} ∼ Normal(0, σ);   τ_{kd} ∼ Gamma(a, b);   x_{nd} ∼ Normal(μ_{z_n,d}, τ_{z_n,d}^{−1/2}).

See Listing 5. We automatically translated the NumPy CAVI ops to TensorFlow ops, and benchmarked 500 iterations of CAVI in NumPy and TensorFlow on CPU and GPU. Table 1 shows the results, which clearly demonstrate the value of running on GPUs.

import numpy as np
import autoconj.pplham as ph  # a simple "probabilistic programming language"

def make_model(alpha, beta):
  def sample_model():
    """Generates a matrix of shape [num_examples, num_features]."""
    epsilon = ph.norm.rvs(0, 1, size=[num_examples, num_latents])
    w = ph.norm.rvs(0, 1, size=[num_features, num_latents])
    tau = ph.gamma.rvs(alpha, beta)
    x = ph.norm.rvs(np.dot(epsilon, w.T), 1. / np.sqrt(tau))
    return [epsilon, w, tau, x]
  return sample_model

num_examples = 50
num_features = 10
num_latents = 5
alpha = 2.
beta = 8.
sampler = make_model(alpha, beta)
log_joint_fn = ph.make_log_joint_fn(sampler)

Listing 5: Implementing the log-joint function for Table 1. This example also illustrates how Autoconj could be embedded in a probabilistic programming language where models are sampling functions and utilities exist for tracing their execution (e.g., Tran et al., 2018).

6 Discussion

In this paper, we proposed a strategy for automatically deriving conjugacy relationships. Unlike previous systems, which focus on relationships between pairs of random variables, Autoconj operates directly on Python functions that compute log-joint distribution functions.
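For concreteness, a hand-written NumPy log-joint for the mixture-of-Gaussians model above, of the kind Autoconj consumes, might look as follows (our own sketch using scipy.stats; we assume Gamma(a, b) is shape/rate, and the benchmark's actual code may differ in detail):

```python
import numpy as np
from scipy import stats

def log_joint(pi, z, mu, tau, x, alpha, sigma, a, b):
  """Log-joint of the mixture-of-Gaussians model above.

  pi: [K] mixture weights; z: [N] integer cluster assignments;
  mu, tau: [K, D] component means and precisions; x: [N, D] data.
  """
  K = pi.shape[0]
  lp = stats.dirichlet.logpdf(pi, np.full(K, alpha))      # pi ~ Dirichlet(alpha)
  lp += np.sum(np.log(pi[z]))                             # z_n ~ Categorical(pi)
  lp += np.sum(stats.norm.logpdf(mu, 0., sigma))          # mu_kd ~ Normal(0, sigma)
  lp += np.sum(stats.gamma.logpdf(tau, a, scale=1. / b))  # tau_kd ~ Gamma(a, b)
  lp += np.sum(stats.norm.logpdf(x, mu[z], tau[z] ** -0.5))
  return lp
```

Because this is an ordinary Python function, nothing requires it to come from a PPL; a tracing utility like the one in Listing 5 could equally well generate it.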
This provides support for conjugacy-exploiting algorithms in any Python-embedded PPL, and paves the way for accelerating development of novel inference algorithms and structure-exploiting modeling strategies.

Acknowledgements. We thank the anonymous reviewers for their suggestions and Hung Bui for helpful discussions.

References

Abadi, M., Barham, P., Chen, J., Chen, Z., Davis, A., Dean, J., Devin, M., Ghemawat, S., Irving, G., Isard, M., Kudlur, M., Levenberg, J., Monga, R., Moore, S., Murray, D. G., Steiner, B., Tucker, P., Vasudevan, V., Warden, P., Wicke, M., Yu, Y., and Zheng, X. (2016). TensorFlow: A system for large-scale machine learning. In Proceedings of the 12th USENIX Conference on Operating Systems Design and Implementation, OSDI'16, pages 265-283, Berkeley, CA, USA. USENIX Association.

Baader, F. and Nipkow, T. (1999). Term Rewriting and All That. Cambridge University Press.

Blei, D. M. and Lafferty, J. D. (2005). Correlated topic models. In Proceedings of the 18th International Conference on Neural Information Processing Systems.

Carette, J. and Shan, C.-C. (2016). Simplifying probabilistic programs using computer algebra. In Gavanelli, M. and Reppy, J., editors, Practical Aspects of Declarative Languages, pages 135-152, Cham. Springer International Publishing.

Cook, S. R., Gelman, A., and Rubin, D. B. (2006). Validation of software for Bayesian models using posterior quantiles. Journal of Computational and Graphical Statistics, 15(3):675-692.

Dempster, A. P., Laird, N. M., and Rubin, D. B. (1977). Maximum likelihood from incomplete data via the EM algorithm. Journal of the Royal Statistical Society, Series B (Methodological), pages 1-38.

Diehl, S. (2013). Pyrewrite: Python term rewriting. Accessed: 2018-5-17.

Gehr, T., Misailovic, S., and Vechev, M. (2016). PSI: Exact symbolic inference for probabilistic programs. In International Conference on Computer Aided Verification, pages 62-83. Springer.

Geman, S. and Geman, D. (1984). Stochastic relaxation, Gibbs distributions, and the Bayesian restoration of images. IEEE Transactions on Pattern Analysis and Machine Intelligence, (6):721-741.

Goodman, N. D. and Stuhlmüller, A. (2014). The Design and Implementation of Probabilistic Programming Languages. http://dippl.org. Accessed: 2018-5-17.

Griffiths, T. L. and Steyvers, M. (2004). Finding scientific topics. Proceedings of the National Academy of Sciences, 101(suppl 1):5228-5235.

Hoffman, M. D., Blei, D. M., Wang, C., and Paisley, J. (2013). Stochastic variational inference. Journal of Machine Learning Research, 14:1303-1347.

Jaakkola, T. and Jordan, M. (1996). A variational approach to Bayesian logistic regression models and their extensions. In International Workshop on Artificial Intelligence and Statistics, volume 82, page 4.

Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., and Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine Learning, 37(2):183-233.

Khan, M. E., Babanezhad, R., Lin, W., Schmidt, M., and Sugiyama, M. (2016). Faster stochastic variational inference using proximal-gradient methods with general divergence functions. In Conference on Uncertainty in Artificial Intelligence (UAI).

Khan, M. E., Baqué, P., Fleuret, F., and Fua, P. (2015). Kullback-Leibler proximal variational inference. In Advances in Neural Information Processing Systems, pages 3402-3410.

Khan, M. E. and Wu, L. (2017). Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models. In Artificial Intelligence and Statistics (AISTATS).

Koller, D. and Friedman, N. (2009).
Probabilistic Graphical Models: Principles and Techniques. MIT Press.

Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., and Blei, D. M. (2016). Automatic differentiation variational inference. arXiv preprint arXiv:1603.00788.

Maclaurin, D., Duvenaud, D., Johnson, M., and Adams, R. P. (2014). Autograd: Reverse-mode differentiation of native Python. Accessed: 2018-5-17.

Murray, L. M., Lundén, D., Kudlicka, J., Broman, D., and Schön, T. B. (2018). Delayed sampling and automatic Rao-Blackwellization of probabilistic programs. In Artificial Intelligence and Statistics.

Narayanan, P., Carette, J., Romano, W., Shan, C.-c., and Zinkov, R. (2016). Probabilistic inference by program transformation in Hakaru (system description). In Kiselyov, O. and King, A., editors, Functional and Logic Programming, pages 62-79, Cham. Springer International Publishing.

Narayanan, P. and Shan, C.-c. (2017). Symbolic conditioning of arrays in probabilistic programs. Proceedings of the ACM on Programming Languages, 1(ICFP):11.

Radul, A. (2013). Rules: An extensible pattern matching, pattern dispatch, and term rewriting system for MIT Scheme. Accessed: 2018-5-17.

Rozenberg, G. (1997). Handbook of Graph Grammars and Comp., volume 1. World Scientific.

Spiegelhalter, D. J., Thomas, A., Best, N. G., and Gilks, W. R. (1995). BUGS: Bayesian inference using Gibbs sampling, version 0.50. MRC Biostatistics Unit, Cambridge.

Sussman, G. J., Abelson, H., Wisdom, J., Katzenelson, J., Mayer, H., Hanson, C. P., Halfant, M., Siebert, B., Rozas, G. J., Skordos, P., Koniaris, K., Lin, K., and Zuras, D. (2018). SCMUTILS. Accessed: 2018-5-17.

Tran, D., Hoffman, M. D., Moore, D., Suter, C., Vasudevan, S., Radul, A., Johnson, M., and Saurous, R. A. (2018). Simple, distributed, and accelerated probabilistic programming. In Neural Information Processing Systems.

Tristan, J.-B., Huang, D., Tassarotti, J., Pocock, A. C., Green, S., and Steele, G. L. (2014). Augur: Data-parallel probabilistic modeling. In Neural Information Processing Systems.

Wainwright, M. J. and Jordan, M. I. (2008). Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1-2):1-305.

Winn, J. and Bishop, C. M. (2005). Variational message passing. Journal of Machine Learning Research, 6(Apr):661-694.