{"title": "Learning to Pass Expectation Propagation Messages", "book": "Advances in Neural Information Processing Systems", "page_first": 3219, "page_last": 3227, "abstract": "Expectation Propagation (EP) is a popular approximate posterior inference algorithm that often provides a fast and accurate alternative to sampling-based methods. However, while the EP framework in theory allows for complex non-Gaussian factors, there is still a significant practical barrier to using them within EP, because doing so requires the implementation of message update operators, which can be difficult and require hand-crafted approximations. In this work, we study the question of whether it is possible to automatically derive fast and accurate EP updates by learning a discriminative model e.g., a neural network or random forest) to map EP message inputs to EP message outputs. We address the practical concerns that arise in the process, and we provide empirical analysis on several challenging and diverse factors, indicating that there is a space of factors where this approach appears promising.", "full_text": "Learning to Pass Expectation Propagation Messages\n\nNicolas Heess\u2217\nGatsby Unit, UCL\n\nDaniel Tarlow\n\nMicrosoft Research\n\nJohn Winn\n\nMicrosoft Research\n\nAbstract\n\nExpectation Propagation (EP) is a popular approximate posterior inference al-\ngorithm that often provides a fast and accurate alternative to sampling-based\nmethods. However, while the EP framework in theory allows for complex non-\nGaussian factors, there is still a signi\ufb01cant practical barrier to using them within\nEP, because doing so requires the implementation of message update operators,\nwhich can be dif\ufb01cult and require hand-crafted approximations. 
In this work, we study the question of whether it is possible to automatically derive fast and accurate EP updates by learning a discriminative model (e.g., a neural network or random forest) to map EP message inputs to EP message outputs. We address the practical concerns that arise in the process, and we provide empirical analysis on several challenging and diverse factors, indicating that there is a space of factors where this approach appears promising.

1 Introduction
Model-based machine learning and probabilistic programming offer the promise of a world where a probabilistic model can be specified independently of the inference routine that will operate on the model. The vision is to automatically perform fast and accurate approximate inference to compute a range of quantities of interest (e.g., marginal probabilities of query variables). Approaches to the general inference challenge can roughly be divided into two categories. We refer to the first category as the "uninformed" case, which is exemplified by e.g. Church [4], where the modeler has great freedom in the model specification. The cost of this flexibility is that inference routines have a more superficial understanding of the model structure, being unaware of symmetries and other idiosyncrasies of its components, which makes the already challenging inference task even harder.
The second category is what we refer to as the "informed" case, which is exemplified by e.g. BUGS [14], Stan [12], and Infer.NET [8]. Here, models must be constructed out of a toolbox of building blocks, and a building block can only be used if a set of associated computational operations has been implemented by the toolbox designers.
This gives inference routines a deeper understanding of the structure of the model and can lead to significantly faster inference, but the tradeoff is that efficient and accurate implementation of the building blocks can be a significant challenge. For example, EP message update operations, which are used by Infer.NET, often require the computation of integrals that do not have analytic expressions, so methods must be devised that are robust, accurate and efficient, which is generally quite nontrivial.
In this work, we aim to bridge the gap between the informed and the uninformed cases and achieve the best of both worlds by automatically implementing the computational operations required for the informed case from a specification such as would be given in the uninformed case. We train high-capacity discriminative models that learn to map EP message inputs to EP message outputs for each message operation needed for EP inference. Importantly, the training is done so that the learned modules implement the same EP communication protocol as hand-crafted modules, so after the training phase is complete, we get a factor that behaves like a fast hand-crafted approximation that exploits factor structure, but which was generated using only the specification that would be given in the uninformed case. Models may then be constructed from any combination of these learned modules and previously implemented modules.

*The majority of this work was done while NH was visiting Microsoft Research, Cambridge.

2 Background and Notation
2.1 Factor graphs, directed graphical models, and probabilistic programming
As is common for message passing algorithms, we assume that models of interest are represented as factor graphs: the joint distribution over a set of random variables x = {x_1, . . . , x_D} is specified in terms of non-negative factors ψ_1, . . .
, ψ_J, which capture the relation between variables, and it decomposes as p(x) = (1/Z) ∏_{j=1}^{J} ψ_j(x_{ψ_j}). Here x_{ψ_j} is used to mean the set of variables that factor ψ_j is defined over and whose index set we will denote by Scope(ψ_j). We further use x_{ψ_j − i} to mean the set of variables x_{ψ_j} excluding x_i. The set x may have a mix of discrete and continuous random variables, and factors can operate over variables of mixed types. We are interested in computing marginal probabilities p_i(x_i) = ∫ p(x) dx_{−i}, where x_{−i} is all variables except for x_i, and where integrals should be replaced by sums when the variable being integrated out is discrete. Note that this formulation allows for conditioning on variables by attaching factors with no inputs to variables which constrain the variable to be equal to a particular value, but we suppress this detail for simplicity of presentation.
Although our approach can be extended to factors of arbitrary form, for the purpose of this paper we will focus on directed factors, i.e. factors of the form ψ_j(x_out(j) | x_in(j)) which directly specify the (conditional) distribution (or density) over x_out(j) as a function of the vector of inputs x_in(j) (here x_{ψ_j} is the set of variables {x_out(j)} ∪ x_in(j)). In an (unconditioned) directed graphical model all factors will have this form, and we allow x_in(j) to be empty, for example, to allow for prior distributions over the variables.
Probabilistic programming is an umbrella term for the specification of probabilistic models via a programming-like syntax. In its most general form, an arbitrary program is specified, which can include calls to a random number generator (e.g. [4]). This can be related to the factor graph notation by introducing forward-sampling functions f_1, . . . , f_J.
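As a minimal sketch of this correspondence (our own illustration, with hypothetical factors, not code from the paper), a three-variable directed model can be written as a sequence of forward-sampling calls:

```python
import random

# Each f_j draws x_out(j) given its parents x_in(j); sampling the
# variables in topological order yields a joint sample, exactly as the
# corresponding directed factor graph would.
def f_precision():            # Gamma prior factor, no inputs
    return random.gammavariate(2.0, 1.0)

def f_weight(precision):      # Gaussian factor given a precision
    return random.gauss(0.0, (1.0 / precision) ** 0.5)

def f_obs(weight, x):         # noisy linear observation factor
    return random.gauss(weight * x, 0.1)

def forward_sample(x=1.5):
    """Run the probabilistic program: one joint sample of all variables."""
    tau = f_precision()
    w = f_weight(tau)
    y = f_obs(w, x)
    return {"tau": tau, "w": w, "y": y}
```

Each f_j here happens to wrap a standard distribution, but nothing prevents it from being an arbitrary piece of code, which is exactly the flexibility discussed next.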
If we associate each directed factor ψ_j(x_out(j) | x_in(j)) with a stochastic forward-sample function f_j mapping x_in(j) to x_out(j) and then define the probabilistic program as the sequential sampling of x_out(j) = f_j(x_in(j)) following a topological ordering of the variables, then there is a clear association between directed graphical models and forward-sampling procedures. Specifically, f_j is a stochastic function that draws a sample from ψ_j(x_out(j) | x_in(j)). The key difference is that the factor graph specification usually assumes that an analytic expression will be given to define ψ_j(x_out(j) | x_in(j)), while the forward-sampling formulation allows for f_j to execute an arbitrary piece of computer code. The extra flexibility afforded by the forward-sampling formulation has led to the popularity of methods like Approximate Bayesian Computation (ABC) [11], although the cost of this flexibility is that inference becomes less informed.
2.2 Expectation Propagation
Expectation Propagation (EP) is a message passing algorithm that is a generalization of sum-product belief propagation. It can be used for approximate marginal inference in models that have a mixed set of types. EP has been used successfully in a number of large-scale applications [5, 13], can be used with a wide range of factors, and can support some programming language constructs like for loops and if statements [7]. For a detailed review of EP, we recommend [6].
For the purposes of this paper there are two important aspects of EP. First, we use the common variant where the posterior is approximated as a fully factorized distribution (except for some homogeneous variables which we treat as a single vector-valued variable) and each variable then has an associated type, type(x), which determines the distribution family used in its approximation. The second aspect is the form of the message from a factor ψ to a variable i.
It is defined as follows:

m_{ψ→i}(x_i) = proj[ ∫ ψ(x_out | x_in) ∏_{i′∈Scope(ψ)} m_{i′→ψ}(x_{i′}) dx_{ψ−i} ] / m_{i→ψ}(x_i).    (1)

The update has an intuitive form. The proj operator ensures that the message being passed is a distribution of type type(x_i) – it only has an effect if its argument is outside the approximating family used for the target message. If the projection operation (proj[·]) is ignored, then the m_{i→ψ}(x_i) term in the denominator cancels with the corresponding term in the numerator, and standard belief propagation updates are recovered. The projection is implemented as finding the distribution q in the approximating family that minimizes the KL-divergence between the argument and q: proj[p] = argmin_q KL(p || q), where q is constrained to be a distribution of type(x_i). Multiplying the reverse message m_{i→ψ}(x_i) into the numerator before performing the projection effectively defines a "context", which can be seen as reweighting the approximation to the standard BP update, placing more importance in the region where other parts of the model have placed high probability mass.
3 Formulation
We now present the method that is the focus of this paper. The goal is to allow a user to specify a factor to be used in EP solely via specifying a forward sampling procedure; that is, we assume that the user provides an executable stochastic function f(x_in), which, given x_in, returns a sample of x_out. The user further specifies the families of distributions with which to represent the messages associated with the variables of the factor (e.g. Discrete, Gamma, Gaussian, Beta).
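For instance (our own illustration, anticipating the ball-throwing factor of Section 5), such an f can be a few lines of ordinary code that would be awkward to express as an analytic ψ; the noise level here is a made-up choice:

```python
import math, random

G = 9.81  # gravitational acceleration (m/s^2)

def f_throw(angle, velocity, height):
    """Forward sampler f(x_in) -> x_out: distance travelled by a ball
    thrown at `angle` (radians) with speed `velocity` from initial
    `height`, plus a small amount of observation noise.  Solves
    height + v*sin(a)*t - g*t^2/2 = 0 for the landing time t."""
    vy = velocity * math.sin(angle)
    vx = velocity * math.cos(angle)
    t_land = (vy + math.sqrt(vy * vy + 2.0 * G * height)) / G  # positive root
    return vx * t_land + random.gauss(0.0, 0.01)
```

From the method's point of view, f is a black box: it is only ever called to draw samples, never differentiated or integrated analytically.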
Below we show how to learn fast EP message operators so that the new factor can be used alongside existing factors in a variety of models.
Computing Targets with Importance Sampling   Our goal is to compute EP messages from the factor ψ that is associated with f, as if we had access to an analytic expression for ψ(x_out | x_in). The only way a factor interacts with the rest of the model is via the incoming and outgoing messages, so we can focus on this mapping, and the resulting operator can be used in any model. Given incoming messages {m_{i→ψ}(x_i)}_{i∈Scope(ψ)}, the simplest approach to computing m_{ψ→i}(x_i) is to use importance sampling. A proposal distribution q(x_in) is specified, and then the approach is based on the fact that

∫ ψ(x_out | x_in) ( ∏_{i′∈Scope(ψ)} m_{i′→ψ}(x_{i′}) ) dx_ψ = E_r[ ∏_{i′∈Scope(ψ)} m_{i′→ψ}(x_{i′}) / q(x_in) ],    (2)

where r(x) = q(x_in) ψ(x_out | x_in) can be sampled from by first drawing values of x_in from q, then passing those values through the forward-sampling procedure f to get a value for x_out. To use this procedure for computing messages m_{ψ→i}(x_i), we use importance sampling with proposal distribution r. Roughly, samples are drawn from r and weighted by ∏_{i′∈Scope(ψ)} m_{i′→ψ}(x_{i′}) / q(x_in), then all variables other than x_i are summed out to yield a mixture of point mass distributions over x_i. The proj[·] operator is then applied to this distribution. Note that a simple choice for q(x_in) is ∏_{i′∈in} m_{i′→ψ}(x_{i′}), in which case the weighting term simplifies to just be m_{out→ψ}(x_out). Despite its simplicity, however, we found this choice to often be suboptimal.
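As a concrete sketch of this pipeline (our own simplification: a single continuous output variable, the naive proposal q = product of incoming messages, and proj implemented as Gaussian moment matching):

```python
import math, random

def proj_gaussian(xs, ws):
    """proj[.]: KL-projection of a weighted sample onto a Gaussian,
    which for the Gaussian family reduces to moment matching."""
    z = sum(ws)
    mean = sum(w * x for w, x in zip(ws, xs)) / z
    var = sum(w * (x - mean) ** 2 for w, x in zip(ws, xs)) / z
    return mean, var

def message_to_out(f, sample_q, m_out, n=20000):
    """Importance-sample the numerator of eq. (1) for the message to
    x_out: draw x_in ~ q, push it through the forward sampler f, and
    weight by the reverse message m_out(x_out) (the weight the naive
    proposal choice simplifies to)."""
    xs, ws = [], []
    for _ in range(n):
        x_out = f(sample_q())
        xs.append(x_out)
        ws.append(m_out(x_out))
    return proj_gaussian(xs, ws)
```

Dividing the projected numerator by the reverse message m_{i→ψ} to obtain the final EP message is omitted here for brevity.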
We elaborate on this issue and give concrete suggestions for improving over the naive approach in the experiments section.
Generation of Training Data   For a given set of incoming messages {m_{i→ψ}(x_i)}_{i∈Scope(ψ)}, we can produce a target outgoing message using the technique from the previous section. To train a model to automatically compute these messages, we need many example incoming and target outgoing message pairs. We can generate such a data set by drawing sets of incoming messages from some specified distribution, then computing the target outgoing message as above (Algorithm 1).

Algorithm 1 Generate training data
1: Input: ψ, i, specifying we are learning to send message m_{ψ→i}(x_i).
2: Input: D_m, training distribution over messages {m_{i′→ψ}(x_{i′})}_{i′∈Scope(ψ)}.
3: Input: q(x_in), importance sampling distribution.
4: for n = 1 : N do
5:   Draw m^n = ⟨m^n_0(x_0), . . . , m^n_D(x_D)⟩ ∼ D_m.
6:   for k = 1 : K do
7:     Draw x^{nk}_in ∼ q(x^{nk}_in), then compute x^{nk}_out = f(x^{nk}_in).
8:     Compute importance weight w^{nk} = ∏_{i′∈Scope(ψ)} m^n_{i′→ψ}(x^{nk}_{i′}) / q(x^{nk}_in).
9:   end for
10:  Compute µ̂^n(x_i) = proj[ Σ_k w^{nk} δ_{x^{nk}_i}(x_i) / Σ_k w^{nk} ].
11:  Add pair (⟨m^n_0(x_0), . . . , m^n_D(x_D)⟩, µ̂^n(x_i)) to training set.
12: end for
13: Return training set.

Learning   Given the training data, we learn a neural network model that takes as input the sufficient statistics defining m^n = ⟨m^n_{i′→ψ}(x_{i′})⟩_{i′∈Scope(ψ)} and outputs sufficient statistics defining the approximation g(m^n) to target µ̂^n(x_i). For each output message that the factor needs to send, we train a separate network. The error measure that we optimize is the average KL divergence (1/N) Σ_{n=1}^{N} KL(µ̂^n || g(m^n)).
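When both the target µ̂^n and the network output g(m^n) are Gaussian, this per-example loss has a simple closed form; a minimal sketch (our notation, assuming messages parameterized by mean and variance):

```python
import math

def kl_gaussian(mu1, var1, mu2, var2):
    """KL( N(mu1, var1) || N(mu2, var2) ): the per-example training
    loss when the target message mu-hat and the network output g(m)
    are both Gaussian."""
    return 0.5 * (math.log(var2 / var1)
                  + (var1 + (mu1 - mu2) ** 2) / var2
                  - 1.0)
```

Averaging this over the training pairs gives the objective above; its derivatives with respect to mu2 and var2 are what back-propagation would push through the network.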
We differentiate this objective analytically for the appropriate output distribution type and compute gradients via back-propagation.
Choice of Decomposition Structure   So far, we have shown how to incorporate factors into a model when the definition of the factor is via the forward-sample procedure f rather than as an analytic expression ψ. When specifying a model, there is some flexibility in how this capability is used. The natural use case is when a model can mostly be expressed using factors that have analytic expressions and corresponding hand-constructed operator implementations, but when a few of the interactions that we would like to use are more easily specified in terms of a forward-sampling procedure or would be difficult to implement hand-crafted approximations for.
There is an alternative use-case, which is that even if we have analytic expressions and hand-crafted implementations for all the factors that we wish to use in a model, it might be that the approximations which arise due to the nature of message passing (that is, passing messages that factorize fully over variables) lead to a poor approximation in some block of the model. In this case, it may be desirable to collapse the problematic block of several factors into a single factor, then to use the approach we present here. If the new collapsed factor is sufficiently structured in a statistical sense, then this may lead to improved accuracy.
In this view, the goal should be to find groups of modeling components that go together logically, which are reusable, and which define interactions that have input-output structure that is amenable to the learned approximation strategy.
4 Related Work
Perhaps the most superficially similar line of work to the approach we present here is that of inference machines and truncated belief propagation [2, 3, 10, 9], where inference is done via an algorithm that is structurally similar to belief propagation, but where some parameters of the updates are learned. The fundamental difference between those approaches and ours is how the learning is performed. In inference machine training, learning is done jointly over parameters for all updates that will be used in the model. This means that the process of learning couples together all factors in the model; if part of the model changes, the parameters of the updates must be re-learned. A key property of our approach is that a factor may be learned once, then used in a variety of different models without need for re-training.
The most closely related work is ABC-EP [1]. This approach employs a very similar importance sampling strategy but performs inference simply by sending the messages that we use as training data. The advantage is that no function approximator needs to be chosen, and if enough samples are drawn for each message update, the accuracy should be good. There is also no up-front cost of learning as in our case. The downside is that generation and weighting of a sufficient number of samples can be very expensive, and it is usually not practical to generate enough samples every time a message needs to be sent. Our formulation allows for a very large number of samples to be generated once as an up-front cost; then, as long as the learning is done effectively, each message computation is much faster while still effectively drawing on a large number of samples.
Our approach also opens up the possibility of using more accurate but slower methods to generate the training samples, which we believe will be important as we look ahead to applying the method to even more complex factors. Empirically we have found that using importance sampling but reducing the number of samples so as to make runtime computation times close to our method can lead to unreliable inference.
Finally, at a high level, our goal in this work is to start from an informed general inference scheme and to extend the range of model specifications that can be used within the framework. There is work that aims for a similar goal but comes from the opposite direction of starting with a general specification language and aiming to build more informed inference routines. For example, [15] attempts to infer basic interaction structure from general probabilistic program specifications. Also of note is [16], which applies mean field-like variational inference to general program specifications. We believe both these directions and the direction we explore here to be promising and worth exploring.
5 Experimental Analyses
We now turn our attention to experimental evaluation. The primary question of interest is whether, given f, it is feasible to learn the mapping from EP message inputs to outputs in such a way that the learned factors can be used within nontrivial models. This obviously depends on the specifics of f and the model in which the learned factor is used. We attempt to explore these issues thoroughly.

Choice of Factors   We made specific choices about which functions f to apply our framework to. First, we wanted a simple factor to prove the concept and give an indication of the performance that we might expect. For this, we chose the sigmoid factor, which deterministically computes x_out = f(x_in) = 1/(1 + exp(−x_in)).
For this factor, sensible choices for the messages to x_out and x_in are Beta and Gaussian distributions respectively. Second, we wanted factors that stressed the framework in different ways. For the first of these, we chose a compound Gamma factor, which is sampled by first drawing a random Gamma variable r_2 with rate r_1 and shape s_1, then drawing another random Gamma variable x_out with rate r_2 and shape s_2. This defines x_out = f(r_1, s_1, s_2), which is a challenging factor because, depending on the choice of inputs, it can produce a very heavy-tailed distribution over x_out. Another challenging factor we experiment with is the product factor, which uses x_out = f(x_in,1, x_in,2) = x_in,1 × x_in,2. While this is a conceptually simple function, it is highly challenging to use within EP for several reasons, including symmetries due to signs, and the fact that message outputs can change very quickly as functions of message inputs (see Fig. 3).
One main reason for the above factor choices is that there are existing hand-crafted implementations in Infer.NET, which we can use to evaluate our learned message operators. It would have been straightforward to experiment with more example factors that could not be implemented with existing hand-crafted factors, but it would have been much harder to evaluate our proposed method. Finally, we developed a factor that models the throwing of a ball, which is representative of the type of factors that we believe our framework to be well-suited for, and which is not easily implemented with hand-crafted approximations.
For all factors, we use the extensible factor interface in Infer.NET to create factors that compute messages by running a forward pass of the learned neural network. We then studied these factors in a variety of models, using the default Infer.NET settings for all other implementation details, e.g. message schedules and other factor implementations.
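The forward-sampling definitions of these factors are only a few lines each; a sketch (ours, not the paper's code) of the compound Gamma and product samplers, noting that Python's random.gammavariate takes a scale parameter, so rates must be inverted:

```python
import random

def f_compound_gamma(rate1, shape1, shape2):
    """Compound Gamma: draw a rate r2 ~ Gamma(shape1, rate=rate1), then
    x_out ~ Gamma(shape2, rate=r2).  Heavy-tailed for suitable inputs."""
    r2 = random.gammavariate(shape1, 1.0 / rate1)   # gammavariate takes a scale
    return random.gammavariate(shape2, 1.0 / r2)

def f_product(x1, x2):
    """Deterministic product factor x_out = x1 * x2."""
    return x1 * x2
```

That such short samplers can correspond to hard message-update problems is exactly the point of these stress tests.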
Additional details of the models used in the experiments can be found in the supplemental material.

Sigmoid Factor   For the sigmoid factor, we ran two main sets of experiments. First, we learned a factor using the methodology described in Section 3 and evaluated how well the network was able to reconstruct the training data. In Fig. 1 we show histograms of KL errors for the network trained to send forward messages (Fig. 1a) and the network trained to send backward messages (Fig. 1b). To aid the interpretation of these results, we also show the best, median, and worst approximations for each. There are a small number of moderate-sized errors, but average performance is very good.
We then used the learned factor within a Bayesian logistic regression model where the output nonlinearity is implemented using either the default Infer.NET sigmoid factor or our learned sigmoid factor. The number of training points is given in the table. There were always 2000 data points for testing. Data points for training and testing were generated according to p(y = 1 | x) = sigmoid(w^T x). Entries of x were drawn from N(0, 1). Entries of w were drawn from N(0, 1) for all relevant dimensions, and the others were set to 0. Results are shown in Table 1, which appears in the supplementary materials. Predictive performance is very similar across the board, and although there are moderately large KL divergences between the learned posteriors in some cases, when we compared the distance between the true generating weights and the learned posterior means for the EP and NN case, we found them to be similar.

(a) Backward Message (to Gaussian)   (b) Forward Message (to Beta)
Figure 1: Sigmoid factor: Histogram of training KL divergences between target and predicted distributions for the two messages outgoing from the learned sigmoid factor (left: backward message; right: forward message). Also illustrated are best (1), median (2, 3), and worst (4) examples.
The red curve is the density of the target message, and the green is of the predicted message. In the inset are message parameters (left: Gaussian mean and precision; right: Beta α and β) for the true (top line) and predicted (middle line) message, along with the KL (bottom line).

Compound Gamma Factor   The compound Gamma factor is useful as a heavy-tailed prior over precisions of Gaussian random variables. Accordingly, we evaluate performance in the context of models where the factor provides a prior precision for learning a Gaussian or mixture of Gaussians model. As before, we trained a network using the methodology from Section 3. For this factor, we fixed the value of the inputs x_in, which is a standard way that the compound Gamma construction is used as a prior. We experimented with values of (3, 3, 1) and (1, 1, 1) for the inputs. In both cases, these settings induce a heavy-tailed distribution over the precision.
We begin by evaluating the importance sampler. We first evaluate the naive choice for proposal distribution q as described in Section 3. As can be seen in the bottom left plot of Fig. 2, there is a relatively large region of possible input-message space (white region) where almost no samples are drawn, and thus the importance sampling estimates will be unreliable. Here shape_in and rate_in denote the parameters of the message being sent from the precision variable to the compound Gamma factor. By instead using a mixture distribution for q, which has one component equivalent to the naive sampler and one broader component, we achieve the result in the top left of Fig. 2, which has better coverage of the space of possible messages.
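Such a defensive mixture proposal can be sketched as follows (our own construction; the paper does not give the exact mixture or weight it used):

```python
import random

def make_mixture_proposal(sample_naive, pdf_naive,
                          sample_broad, pdf_broad, w_broad=0.3):
    """Defensive mixture q(x) = (1 - w)*q_naive(x) + w*q_broad(x).
    Returns a (sampler, pdf) pair; using the mixture pdf in the
    importance weights keeps them bounded wherever q_broad has mass."""
    def sample():
        if random.random() < w_broad:
            return sample_broad()
        return sample_naive()
    def pdf(x):
        return (1.0 - w_broad) * pdf_naive(x) + w_broad * pdf_broad(x)
    return sample, pdf
```

The broad component need only cover the message regions the naive proposal misses; its exact shape matters much less than its support.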
The plots in the second column show the importance sampling estimates of factor-to-variable messages (one plot per message parameter) as a function of the variable-to-factor message coming from the precision variable, which are unreliable in the regions that would be expected based on the previous plot. The third column shows the same function but for the learned neural network model. Surprisingly, we see that the neural network has smoothed out some of the noise of the importance sampler, and that it has extrapolated in a smooth, reasonable manner. Overlaid on these plots are the message values that were actually encountered when running the experiments in Fig. 8, which are described next.

Figure 2: Compound Gamma plots. First column: Log sum of importance weights arising from improved importance sampler (top) and naive sampler (bottom) as a function of the incoming context message. Second column: Improved importance sampler estimate of outgoing message shape parameter (top) and rate parameter (bottom) as a function of the incoming context message. We show the sufficient statistics of the numerator of eq. 1. Third column: Learned neural network estimates for the same messages. Parameters of the variable-to-factor messages encountered when running the experiments in Fig. 8 are super-imposed as black dots. Rightmost plot: Precisions learned for mixture of Gaussians model with "learned" / standard Infer.NET ("default") factor for 20 and 100 datapoints respectively and true precisions λ_1 = 0.01; λ_2 = 1000.
Best viewed in color.

In the next experiments, we generate data from Gaussians with a wide range of variances, and we evaluate how well we are able to learn the precision as a function of the number of data points (x-axis). We compare to the same construction but using two hand-crafted Gamma factors to implement the compound Gamma prior. The plots in Fig. 8 in the supplementary material show the means of the learned precisions for two choices of compound Gamma parameters (top is (3, 3, 1), bottom is (1, 1, 1)). Even though some messages were passed in regions with little representation under the importance sampling, the factor was still able to perform reliably.
We next evaluate performance of the compound Gamma factors when learning a mixture of Gaussians. We generated data from a mixture of two Gaussians with fixed means but widely different variances, using the compound Gamma prior on the precisions of both Gaussians in the mixture. Results are shown in the right-most plot of Fig. 2. We see that both factors sometimes under-estimate the true variance, but the learned factor is equally as reliable as the hand-crafted version. We also observed in these experiments that the learned factor was an order of magnitude faster than the built-in factor (total runtime was 11s for the learned factor vs. 110s for the standard Infer.NET construction).
!\"#\n\n$\n\n$\n\n$\n\n$\n\n$\n\n$\n\n$\n\n$\n\n!\"!%!\n!\"!!#\n\n%\"\u2019!)\n!\"!!&\n\n!\"!!\u2019\n!\"!!+\n\n!\"!&(\n!\"!!%\n\n!\"!&(\n!\"!!#\n\n!\"!!(\n!\"!!&\n\n!\"&\u2019%\n!\"!!\u2019\n\n!\"!!+\n!\"!!*\n\n!\"#\n\n%\n!\n\n%\"#\n\n!\"#\n\n%\n!\n\n%\"#\n\n!\"#\n\n%\n!\n\n%\"#\n\n!\"#\n\n%\n!\n\n%\"#\n\n!\"#\n\n%\n!\n\n%\"#\n\n!\"#\n\n%\n!\n\n%\"#\n\n!\"#\n\n%\n!\n\n%\"#\n\n!\"#\n\n%\"#\n\n%\n!\n\nFigure 4: Learned posteriors from the multiplicative noise regression model. We compare the built-\nin factor\u2019s result (green) to our learned factor (red) and an importance sampler that is given the same\nruntime budget as the learned model (black). Top row: Representative posteriors over weights w.\nBottom row: Representative posteriors over \u03b7n variables. Inset gives KL between built-in factor\nand learned factor (red) and IS factor (black).\n\n \n\n!x\n\n \n\ny\n\n\u00b5\n\ny\n\n\u00b5\n\n \n\ni\n\n!10\n\n!10\n\n2\n\n0\n\n10\n\n0\n\n10\n\n0\n\n10\n\n0\n\n \n\n \n\n=0.1\n\n \n\n=1\n\n \n\n1\n\n!x\n\n=1;!y\n\n=1;\u00b5z\n\n=5;!z\n\nfor\n\n=0.1;!y\n\n=0.1;\u00b5z\n\n=5;!z\n\ni\n\nt\nn\no\np\n \nl\na\ne\nd\n\nfactor\n\nregular\nSHG\nNN\n\nProduct Factor The product factor is a surprisingly dif\ufb01cult factor to work with. To illustrate\nsome of the dif\ufb01culty, we provide plots of output message parameters along slices in input message\nspace (Fig. 3). In our \ufb01rst experiment with the product factor, we build a Bayesian linear regression\nmodel with multiplicative output noise. Given a vector of inputs xn, we take an inner product of\nxn with multivariate Gaussian variables w, then for each instance n multiply the result by a random\nnoise variable \u03b7n that is drawn from a Gaussian with mean 1 and standard deviation 0.1. Additive\nnoise is then added to the output to produce a noisy observation yn. The goal is to infer w and \u03b7\nvalues given x\u2019s and y\u2019s. 
We compare using the default Infer.NET product factor to using our learned product factor for the multiplication of η and the output of the inner products. Results are shown in Fig. 4, where we also compare to importance sampling, which was given a runtime budget similar to that of the neural network.
In the second experiment with the product factor, we implemented an ideal point model, which is essentially a latent-dimensional binary matrix-factorization model, using our learned product factor for the multiplications. This is the most challenging model we have considered yet, because (a) EP is known to be unreliable in matrix factorization models [13], and (b) there is an additional level of approximation due to the loopiness of the graph, which pushes the factor into more extreme ranges, which it might not have been trained as reliably for and/or where importance sampling estimates used for generating training data are unreliable.
We ran the model on a subset of US senate vote records from the 112th congress.1 We evaluated the model based on how well the learned factor version recovered the posteriors over senator latent factors that were found by the built-in product factor and the approximate product factor of [13]. The result of this experiment was that midway through inference, the learned factor version produced posteriors with means that were consistent with the built-in factors, although the variances were slightly larger, and the means were noisier. After this, we observed gradual degradation of the estimates for a subset of about 5-10% of the senators. By the end of inference, results had degraded significantly. Investigating the cause of this result, we found that a large number of zero-precision messages were being sent, which happens when the projected distribution has larger variance than

Figure 3: Message surfaces and failure case plot for the product factor (computing z = xy).
Left: Mean of the factor to z message as\na function of the mean-parameters of the incoming messages from x\nand y. Top row shows ground truth, the bottom row shows the learned\nNN approximation. Right: Posterior over the ideal-point variables\nfor all senators (inferred std.-dev. is shown as error bars). Senators\nare ordered according to ideal-points means inferred with factor [13]\n(SHG). Red/blue dots indicate true party af\ufb01liation.\n\nsenator index\n\nof\n\n \n\n \n\n10\n\n0\n\n!10\n\n\u00b5x\n\n!2\n\n \n\n!10\n\n\u00b5x\n\n1Data obtained from http://www.govtrack.us/\n\n7\n\n\fthe context message. We believe that the cause of this is that as the messages in this model begin to\nconverge, the messages being passed take on a distribution that is dif\ufb01cult to approximate (leading\nthe neural network to under\ufb01t), that is different from the training distribution, or is in a regime where\nimportance sampling estimates are noisy. In these cases, our KL-objective factors are overestimating\nthe variance.\nIn some cases, these errors can propagate and lead to complete failure of inference, and we have\nobserved this in our experiments. This leads to perhaps an obvious point, which is that our approach\nwill fail when messages required by inference are signi\ufb01cantly different from those that were in\nthe training distribution. This can happen via the choice of too extreme priors, too many observa-\ntions driving precisions to be extreme, and due to complicated effects arising from the dynamics of\nmessage passing on loopy graphs. We will discuss some possibly mitigating strategies in Section 6.\n\n0.1\n\n0\n\n30\n\n30\n\n0.1\n\n0\n\n0.1\n\n0\n\n0.1\n\n0\n\n0.2\n\n0.1\n\n0\n\n30\n\n50\n\nperson #7\n\n0 10\n\n30\n\n50\n\n0 10\n\n50\n\n0 10\n\n50\n\n0 10\n\n50\n\n0 10\n\nperson #4\n\n0.2\n\nperson #5\n\n0.2\n\nperson #2\n\n0.2\n\nperson #1\n\n0.2\n\n30\nvelocity\n\nFigure 5: Throwing a ball factor experiments. 
True distributions over individual throwing velocities (black) and predictive distributions based on the learned posterior over velocity rates.

Throwing a Ball Factor  With this factor, we model the distance that a ball travels as a function of the angle, velocity, and initial height from which it was thrown. While this is also a relatively simple interaction conceptually, it would be highly challenging to implement as a hand-crafted factor. In our framework, it suffices to provide a function f that, given the angle, velocity, and initial height, computes and returns the distance that the ball travels. We do so by constructing and solving the appropriate quadratic equation. Note that this requires multiplication and trigonometric functions.

We learn the factor as before and evaluate it in the context of two models. In the first model, we have person-specific distributions over height (Gaussian), log slope (Gaussian), and the rate parameter (Gamma) of a Gamma distribution that determines velocity. We then observe several samples (generated from the model) of noisy distances that the ball traveled for each person. We then use our learned factor to infer posteriors over the person-specific parameters. The inferred posteriors for several representative people are shown in Fig. 5.

Second, we extended the above model so that the person-specific rate parameter is produced by a linear regression model (with exponential link function) with observed person-specific features and unknown weights. We again generated data from the model, observed several sample throws per person, and inferred the regression weights. We found that we were able to recover the generating weights with reasonable accuracy, although the posterior was a bit overconfident: true (-.5, .5, 3) vs. posterior mean (-.43, .55, 3.1) with standard deviations (.04, .03, .02).

6 Discussion

We have shown that it is possible to learn to pass EP messages in several challenging cases. The techniques we use build upon a number of tools well known in the field, but their combination in this application is novel, and we believe it has great practical potential. Although we have established the viability of the idea, in its current form it works better for some factors than for others. Its success depends on (a) the ability of the function approximator to represent the required message updates (which may be highly discontinuous) and (b) the availability of reliable samples of these mappings (some factors may be very hard to invert). Here, we expect that great improvements can be made by taking advantage of recent progress in uninformed sampling and in high-capacity regression models. We tested factors with multiple models and/or datasets, but this does not mean that they will work with all models, hyper-parameter settings, or datasets (we found varying degrees of robustness to such variations). A critical ingredient here is an appropriate choice of the distribution of training messages, which, at the current stage, can require some manual tuning and experimentation. This suggests an interesting extension: maintain an estimate of the quality of the approximation over the domain of the factor, and re-train the factor on the fly when a message is encountered that lies in a low-confidence region. A second direction for future study, which is enabled by our work, is to add additional constraints during learning in order to guarantee that updates have certain desirable properties.
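The overall recipe discussed here, namely drawing incoming messages from a training distribution, estimating the target outgoing message by sampling, and fitting a regressor to the resulting input-output pairs, can be sketched for the product factor z = xy. This is a minimal illustration under our own simplifying choices (Gaussian messages parameterized by mean and standard deviation, plain Monte Carlo in place of the paper's importance-sampling estimates of the full EP projection, and a least-squares polynomial regressor standing in for the neural network); it is not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(1)

def sample_target_message(mu_x, sd_x, mu_y, sd_y, n=2000):
    """Monte Carlo estimate of the moments of z = x*y under the incoming
    Gaussian messages (a stand-in for the EP projection step)."""
    z = rng.normal(mu_x, sd_x, n) * rng.normal(mu_y, sd_y, n)
    return z.mean(), z.std()

# Build a training set: inputs are incoming-message parameters,
# targets are sampled estimates of the outgoing message moments.
inputs, targets = [], []
for _ in range(500):
    mu_x, mu_y = rng.uniform(-3, 3, 2)
    sd_x, sd_y = rng.uniform(0.1, 2, 2)
    inputs.append([mu_x, sd_x, mu_y, sd_y])
    targets.append(sample_target_message(mu_x, sd_x, mu_y, sd_y))
X = np.array(inputs)
T = np.array(targets)

def features(A):
    """Quadratic polynomial features of the message parameters."""
    cross = np.stack([A[:, i] * A[:, j]
                      for i in range(4) for j in range(i, 4)], axis=1)
    return np.hstack([np.ones((len(A), 1)), A, cross])

# Least-squares fit of the feature map to the sampled targets.
W, *_ = np.linalg.lstsq(features(X), T, rcond=None)

def predict_message(mu_x, sd_x, mu_y, sd_y):
    """Fast learned approximation of the message update."""
    return features(np.array([[mu_x, sd_x, mu_y, sd_y]]))[0] @ W
```

At inference time, `predict_message` replaces the expensive sampling step; in the paper the regressor is a neural network or random forest and the targets come from importance sampling against the context message, but the structure of the pipeline is the same.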
One such desirable property is convergence: we may, for example, be able to ask the network to learn the best message updates subject to a constraint that guarantees convergent message passing.

Acknowledgements: NH acknowledges funding from the European Community's Seventh Framework Programme (FP7/2007-2013) under grant agreement no. 270327, and from the Gatsby Charitable Foundation.

References

[1] S. Barthelmé and N. Chopin. ABC-EP: Expectation Propagation for likelihood-free Bayesian computation. In Proceedings of the 28th International Conference on Machine Learning, 2011.
[2] J. Domke. Parameter learning with truncated message-passing. In Computer Vision and Pattern Recognition (CVPR). IEEE, 2011.
[3] J. Domke. Learning graphical model parameters with approximate marginal inference. Pattern Analysis and Machine Intelligence (PAMI), 2013.
[4] N.D. Goodman, V.K. Mansinghka, D.M. Roy, K. Bonawitz, and J.B. Tenenbaum. Church: A language for generative models. In Proc. of Uncertainty in Artificial Intelligence (UAI), 2008.
[5] R. Herbrich, T.P. Minka, and T. Graepel. TrueSkill: A Bayesian skill rating system. Advances in Neural Information Processing Systems, 19:569, 2007.
[6] T.P. Minka. A family of algorithms for approximate Bayesian inference. PhD thesis, Massachusetts Institute of Technology, 2001.
[7] T.P. Minka and J. Winn. Gates: A graphical notation for mixture models. In Advances in Neural Information Processing Systems, 2008.
[8] T.P. Minka, J.M. Winn, J.P. Guiver, and D.A. Knowles. Infer.NET 2.5, 2012. Microsoft Research. http://research.microsoft.com/infernet.
[9] R. Shapovalov, D. Vetrov, and P. Kohli. Spatial inference machines. In Computer Vision and Pattern Recognition (CVPR). IEEE, 2013.
[10] S. Ross, D. Munoz, M. Hebert, and J.A. Bagnell. Learning message-passing inference machines for structured prediction. In Computer Vision and Pattern Recognition (CVPR). IEEE, 2011.
[11] D.B. Rubin. Bayesianly justifiable and relevant frequency calculations for the applied statistician. The Annals of Statistics, pages 1151-1172, 1984.
[12] Stan Development Team. Stan: A C++ library for probability and sampling, version 1.3, 2013.
[13] D.H. Stern, R. Herbrich, and T. Graepel. Matchbox: Large scale online Bayesian recommendations. In Proceedings of the 18th International Conference on World Wide Web, pages 111-120. ACM, 2009.
[14] A. Thomas. BUGS: A statistical modelling package. RTA/BCS Modular Languages Newsletter, 1994.
[15] D. Wingate, N.D. Goodman, A. Stuhlmueller, and J. Siskind. Nonstandard interpretations of probabilistic programs for efficient inference. In Advances in Neural Information Processing Systems, 2011.
[16] D. Wingate and T. Weber. Automated variational inference in probabilistic programming. arXiv:1301.1299, 2013.