{"title": "Kullback-Leibler Proximal Variational Inference", "book": "Advances in Neural Information Processing Systems", "page_first": 3402, "page_last": 3410, "abstract": "We propose a new variational inference method based on the Kullback-Leibler (KL) proximal term. We make two contributions towards improving efficiency of variational inference. Firstly, we derive a KL proximal-point algorithm and show its equivalence to gradient descent with natural gradient in stochastic variational inference. Secondly, we use the proximal framework to derive efficient variational algorithms for non-conjugate models. We propose a splitting procedure to separate non-conjugate terms from conjugate ones. We then linearize the non-conjugate terms and show that the resulting subproblem admits a closed-form solution. Overall, our approach converts a non-conjugate model to subproblems that involve inference in well-known conjugate models. We apply our method to many models and derive generalizations for non-conjugate exponential family. Applications to real-world datasets show that our proposed algorithms are easy to implement, fast to converge, perform well, and reduce computations.", "full_text": "Kullback-Leibler Proximal Variational Inference\n\nEcole Polytechnique F\u00b4ed\u00b4erale de Lausanne\n\nEcole Polytechnique F\u00b4ed\u00b4erale de Lausanne\n\nMohammad Emtiyaz Khan\u2217\n\nLausanne, Switzerland\nemtiyaz@gmail.com\n\nPierre Baqu\u00b4e\u2217\n\nLausanne, Switzerland\n\npierre.baque@epfl.ch\n\nFranc\u00b8ois Fleuret\n\nIdiap Research Institute\nMartigny, Switzerland\n\nfrancois.fleuret@idiap.ch\n\nPascal Fua\n\nEcole Polytechnique F\u00b4ed\u00b4erale de Lausanne\n\nLausanne, Switzerland\n\npascal.fua@epfl.ch\n\nAbstract\n\nWe propose a new variational inference method based on a proximal framework\nthat uses the Kullback-Leibler (KL) divergence as the proximal term. We make\ntwo contributions towards exploiting the geometry and structure of the variational\nbound. 
First, we propose a KL proximal-point algorithm and show its equivalence to variational inference with natural gradients (e.g., stochastic variational inference). Second, we use the proximal framework to derive efficient variational algorithms for non-conjugate models. We propose a splitting procedure to separate non-conjugate terms from conjugate ones. We linearize the non-conjugate terms to obtain subproblems that admit a closed-form solution. Overall, our approach converts inference in a non-conjugate model to subproblems that involve inference in well-known conjugate models. We show that our method is applicable to a wide variety of models and can result in computationally efficient algorithms. Applications to real-world datasets show performance comparable to existing methods.

1 Introduction

Variational methods are a popular alternative to Markov chain Monte Carlo (MCMC) methods for Bayesian inference. They have been used extensively for their speed and ease of use. In particular, methods based on evidence lower bound (ELBO) optimization are quite popular because they convert a difficult integration problem into an optimization problem. This reformulation enables the application of optimization techniques for large-scale Bayesian inference.

Recently, an approach called stochastic variational inference (SVI) has gained popularity for inference in conditionally-conjugate exponential-family models [1]. SVI exploits the geometry of the posterior distribution by using natural gradients and uses a stochastic method to improve scalability. The resulting updates are simple and easy to implement.

Several generalizations of SVI have been proposed for general latent-variable models where the lower bound might be intractable [2, 3, 4].
These generalizations, although important, do not take the geometry of the posterior distribution into account. In addition, none of these approaches exploits the structure of the lower bound. In practice, not all factors of the joint distribution introduce difficulty into the optimization; it is therefore desirable to treat "difficult" terms differently from "easy" terms.

* A note on contributions: P. Baqué proposed the use of the KL proximal term and showed that the resulting proximal steps have closed-form solutions. The rest of the work was carried out by M. E. Khan.

In this context, we propose a splitting method for variational inference; this method exploits both the structure and the geometry of the lower bound. Our approach is based on the proximal-gradient framework. We make two important contributions. First, we propose a proximal-point algorithm that uses the Kullback-Leibler (KL) divergence as the proximal term. We show that the addition of this term incorporates the geometry of the posterior distribution. We establish the equivalence of our approach to variational methods that use natural gradients (e.g., [1, 5, 6]).

Second, following the proximal-gradient framework, we propose a splitting approach for variational inference. In this approach, we linearize difficult terms so that the resulting optimization problem is easy to solve. We apply this approach to variational inference in non-conjugate models.
We show that linearizing the non-conjugate terms leads to subproblems that have closed-form solutions. Our approach therefore converts inference in a non-conjugate model to subproblems that involve inference in well-known conjugate models, for which efficient implementations exist.

2 Latent Variable Models and Evidence Lower-Bound Optimization

Consider a general latent-variable model with a data vector y of length N and a latent vector z of length D, following a joint distribution p(y, z) (we drop the parameters of the distribution from the notation). ELBO methods approximate the posterior p(z|y) by a distribution q(z|\lambda) that maximizes a lower bound to the marginal likelihood. Here, \lambda is the vector of parameters of the distribution q. As shown in (1), the lower bound is obtained by first multiplying and then dividing by q(z|\lambda), and then applying Jensen's inequality using the concavity of the logarithm. The approximate posterior q(z|\lambda) is obtained by maximizing the lower bound with respect to \lambda:

\log p(y) = \log \int q(z|\lambda) \frac{p(y, z)}{q(z|\lambda)} dz \ge \max_{\lambda} E_{q(z|\lambda)} \left[ \log \frac{p(y, z)}{q(z|\lambda)} \right] := \mathcal{L}(\lambda).   (1)

Unfortunately, the lower bound may not always be easy to optimize, e.g., some terms in the lower bound might be intractable or might admit a form that is not easy to optimize.
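To make the bound in (1) concrete, here is a minimal Monte Carlo sketch for a toy conjugate model where the exact evidence is available in closed form. This is purely illustrative (the model, names, and sample sizes are our own choices, not part of the paper): the bound equals log p(y) when q is the exact posterior and is strictly smaller otherwise.

```python
import numpy as np

# Toy model for illustrating (1): p(z) = N(0, 1), p(y|z) = N(z, 1),
# with a single observation y.  The exact evidence is p(y) = N(y | 0, 2),
# so the bound L(lambda) equals log p(y) when q is the exact posterior
# N(y/2, 1/2), and is strictly smaller for any other Gaussian q.
def elbo(y, m, v, n_samples=100_000, seed=0):
    """Monte Carlo estimate of E_q[log p(y, z) - log q(z)] for q = N(m, v)."""
    rng = np.random.default_rng(seed)
    z = m + np.sqrt(v) * rng.standard_normal(n_samples)
    log_prior = -0.5 * z**2 - 0.5 * np.log(2 * np.pi)
    log_lik = -0.5 * (y - z)**2 - 0.5 * np.log(2 * np.pi)
    log_q = -0.5 * (z - m)**2 / v - 0.5 * np.log(2 * np.pi * v)
    return np.mean(log_prior + log_lik - log_q)

y = 1.3
log_evidence = -0.25 * y**2 - 0.5 * np.log(2 * np.pi * 2.0)  # log N(y | 0, 2)
tight = elbo(y, y / 2.0, 0.5)   # q = exact posterior: bound is tight
loose = elbo(y, 0.0, 2.0)       # mismatched q: bound lies strictly below
```

At the optimum the integrand is pointwise constant at log p(y), so `tight` matches `log_evidence` up to floating-point error, while `loose` falls below it by exactly the KL divergence from q to the posterior.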
In addition, the optimization can be slow when N and D are large.

3 The KL Proximal-Point Algorithm for Conjugate Models

In this section, we introduce a proximal-point method based on the Kullback-Leibler (KL) proximal function and establish its relation to existing approaches based on natural gradients [1, 5, 6]. In particular, for conditionally-conjugate exponential-family models, we show that each iteration of our proximal-point approach is equivalent to a step along the natural gradient.

The KL divergence between two distributions q(z|\lambda) and q(z|\lambda') is defined as D_{KL}[q(z|\lambda) \,\|\, q(z|\lambda')] := E_{q(z|\lambda)}[\log q(z|\lambda) - \log q(z|\lambda')]. Using the KL divergence as the proximal term, we introduce a proximal-point algorithm that generates a sequence \lambda_k by solving the following subproblems:

KL Proximal-Point: \lambda_{k+1} = \arg\max_{\lambda} \mathcal{L}(\lambda) - \frac{1}{\beta_k} D_{KL}[q(z|\lambda) \,\|\, q(z|\lambda_k)],   (2)

given an initial value \lambda_0 and a bounded sequence of step sizes \beta_k > 0.

One benefit of using the KL term is that it takes the geometry of the posterior distribution into account. This fact has led to its extensive use in both the optimization and statistics literature, e.g., for speeding up the expectation-maximization algorithm [7, 8], for convex optimization [9], for message-passing in graphical models [10], and for approximate Bayesian inference [11, 12, 13].

Relationship to methods that use natural gradients: An alternative approach to incorporating the geometry of the posterior distribution is to use natural gradients [6, 5, 1]. We now establish its relationship to our approach. The natural gradient can be interpreted as finding a descent direction that ensures a fixed amount of change in the distribution.
For variational inference, this is equivalent to the following [1, 14]:

\arg\max_{\Delta\lambda} \mathcal{L}(\lambda_k + \Delta\lambda), \quad \text{s.t. } D^{sym}_{KL}[q(z|\lambda_k + \Delta\lambda) \,\|\, q(z|\lambda_k)] \le \epsilon,   (3)

where D^{sym}_{KL} is the symmetric KL divergence. The proximal-point subproblem (2) appears to be related to a Lagrangian of the above optimization. In fact, as we show below, the two problems are equivalent for conditionally-conjugate exponential-family models.

We consider the set-up described in [15], which is slightly more general than that of [1]. Consider a Bayesian network with nodes z_i and joint distribution \prod_i p(z_i | pa_i), where pa_i are the parents of z_i. We assume that each factor is an exponential-family distribution defined as follows:

p(z_i | pa_i) := h_i(z_i) \exp\left[ \eta_i(pa_i)^T T_i(z_i) - A_i(\eta_i) \right],   (4)

where \eta_i is the natural parameter, T_i(z_i) is the sufficient statistic, A_i(\eta_i) is the partition function, and h_i(z_i) is the base measure. We seek a factorized approximation, shown in (5), where each z_i belongs to the same exponential-family distribution as the corresponding factor of the joint distribution. The parameters of this distribution are denoted by \lambda_i to differentiate them from the joint-distribution parameters \eta_i. Also note that here the subscript refers to the factor i, not to the iteration:

q(z|\lambda) = \prod_i q_i(z_i|\lambda_i), \quad \text{where } q_i(z_i) := h_i(z_i) \exp\left[ \lambda_i^T T_i(z_i) - A_i(\lambda_i) \right].   (5)

For this model, we show the following equivalence between a gradient-descent method based on natural gradients and our proximal-point approach. The proof is given in the supplementary material.

Theorem 1.
For the model shown in (4) and the posterior approximation shown in (5), the sequence \lambda_k generated by the proximal-point algorithm (2) is equal to the one obtained using gradient descent along the natural gradient with step lengths \beta_k/(1+\beta_k).

Proof of convergence: Convergence of the proximal-point algorithm (2) is proved in [8]. We give a summary of the results here. We assume \beta_k = 1; however, the proof holds for any bounded sequence \beta_k. Let the space of all \lambda be denoted by S. Define the set S_0 := \{\lambda \in S : \mathcal{L}(\lambda) \ge \mathcal{L}(\lambda_0)\}. Then \|\lambda_{k+1} - \lambda_k\| \to 0 under the following conditions:

(A) A maximum of \mathcal{L} exists, and the gradient of \mathcal{L} is continuous and defined in S_0.
(B) The KL divergence and its gradient are continuous and defined in S_0 \times S_0.
(C) D_{KL}[q(z|\lambda) \,\|\, q(z|\lambda')] = 0 only when \lambda' = \lambda.

In our case, conditions (A) and (B) are either assumed or satisfied, and condition (C) can be ensured by choosing an appropriate parameterization of q.

4 The KL Proximal-Gradient Algorithm for Non-conjugate Models

The proximal-point algorithm (2) might be difficult to optimize for non-conjugate models, e.g., due to the non-conjugate factors. In this section, we present an algorithm based on the proximal-gradient framework, where we first split the objective function into "difficult" and "easy" terms and then, to simplify the optimization, linearize the difficult term. See [16] for a good review of proximal methods for machine learning.

We split the ratio p(y, z)/q(z|\lambda) \equiv c\, \tilde{p}_d(z|\lambda)\, \tilde{p}_e(z|\lambda), where \tilde{p}_d contains all factors that make the optimization difficult and \tilde{p}_e contains the rest (c is a constant).
This results in the following split:

\mathcal{L}(\lambda) = E_{q(z|\lambda)} \left[ \log \frac{p(y, z|\theta)}{q(z|\lambda)} \right] := \underbrace{E_{q(z|\lambda)}[\log \tilde{p}_d(z|\lambda)]}_{f(\lambda)} + \underbrace{E_{q(z|\lambda)}[\log \tilde{p}_e(z|\lambda)]}_{h(\lambda)} + \log c.   (6)

Note that \tilde{p}_d and \tilde{p}_e can be unnormalized factors of the distribution. In the worst case, we can set \tilde{p}_e(z|\lambda) \equiv 1 and take the rest as \tilde{p}_d(z|\lambda). We give an example of the split in the next section.

The main idea is to linearize the difficult term f so that the resulting problem admits a simple form. Specifically, we use a proximal-gradient algorithm that solves the following sequence of subproblems to maximize \mathcal{L}. Here, \nabla f(\lambda_k) is the gradient of f at \lambda_k:

KL Proximal-Gradient: \lambda_{k+1} = \arg\max_{\lambda} \lambda^T \nabla f(\lambda_k) + h(\lambda) - \frac{1}{\beta_k} D_{KL}[q(z|\lambda) \,\|\, q(z|\lambda_k)].   (7)

Note that our linear approximation is equivalent to the one used in gradient descent. Also, the approximation is tight at \lambda_k. Therefore, it does not introduce any error into the optimization; rather, it only acts as a surrogate for taking the next step. Existing variational methods have used approximations such as ours, e.g., see [17, 18, 19]. Most of these methods first approximate the \log \tilde{p}_d(z|\lambda) term by a linear or quadratic approximation and then compute the expectation. As a result, their approximation is not tight and can result in bad performance [20]. In contrast, our approximation is applied directly to E_q[\log \tilde{p}_d(z|\lambda)] and is therefore tight at \lambda_k.

The convergence of our approach is covered by the results shown in [21], which prove convergence of an algorithm more general than ours. Below, we summarize the results.
As before, we assume that the maximum exists and that \mathcal{L} is continuous. We make three additional assumptions. First, the gradient of f is L-Lipschitz continuous in S, i.e., \|\nabla f(\lambda) - \nabla f(\lambda')\| \le L \|\lambda - \lambda'\| for all \lambda, \lambda' \in S. Second, the function h is concave. Third, there exists an \alpha > 0 such that

(\lambda_{k+1} - \lambda_k)^T \nabla_1 D_{KL}[q(z|\lambda_{k+1}) \,\|\, q(z|\lambda_k)] \ge \alpha \|\lambda_{k+1} - \lambda_k\|^2,   (8)

where \nabla_1 denotes the gradient with respect to the first argument. Under these conditions, \|\lambda_{k+1} - \lambda_k\| \to 0 when 0 < \beta_k < \alpha/L. The choice of the constant \alpha is also discussed in [21]. Note that even though h is required to be concave, f can be non-convex. The lower bound usually contains concave terms, e.g., the entropy term. In the worst case, when there are no concave terms, we can simply choose h \equiv 0.

5 Examples of KL Proximal-Gradient Variational Inference

In this section, we show a few examples where the subproblem (7) has a closed-form solution.

Generalized linear model: We consider the generalized linear model shown in (9). Here, y is the output vector (of length N) whose n'th entry is equal to y_n, whereas X is an N \times D feature matrix that contains the feature vectors x_n^T as rows. The weight vector z is a Gaussian with mean \mu and covariance \Sigma. To obtain the probability of y_n, the linear predictor x_n^T z is passed through p(y_n|\cdot):

p(y, z) := \prod_{n=1}^{N} p(y_n | x_n^T z)\, \mathcal{N}(z | \mu, \Sigma).   (9)

We restrict the posterior distribution to be a Gaussian q(z|\lambda) = \mathcal{N}(z | m, V) with mean m and covariance V; therefore \lambda := \{m, V\}.
For this posterior family, the non-Gaussian terms p(y_n | x_n^T z) are difficult to handle, while the Gaussian term \mathcal{N}(z|\mu, \Sigma) is easy because it is conjugate to q. Therefore, we set \tilde{p}_e(z|\lambda) \equiv \mathcal{N}(z|\mu, \Sigma)/\mathcal{N}(z|m, V) and let the rest of the terms go into \tilde{p}_d. By substituting into (6) and using the definition of the KL divergence, we get the lower bound shown below in (10). The first term is the function f that will be linearized, and the second term is the function h:

\mathcal{L}(m, V) := \underbrace{\sum_{n=1}^{N} E_{q(z|\lambda)}[\log p(y_n | x_n^T z)]}_{f(m, V)} + \underbrace{E_{q(z|\lambda)} \left[ \log \frac{\mathcal{N}(z|\mu, \Sigma)}{\mathcal{N}(z|m, V)} \right]}_{h(m, V)}.   (10)

For the linearization, we compute the gradient of f using the chain rule. Denote f_n(\tilde{m}_n, \tilde{v}_n) := E_{q(z|\lambda)}[\log p(y_n | x_n^T z)], where \tilde{m}_n := x_n^T m and \tilde{v}_n := x_n^T V x_n. The gradients of f w.r.t. m and V can then be expressed in terms of the gradients of f_n w.r.t. \tilde{m}_n and \tilde{v}_n:

\nabla_m f(m, V) = \sum_{n=1}^{N} x_n \nabla_{\tilde{m}_n} f_n(\tilde{m}_n, \tilde{v}_n), \quad \nabla_V f(m, V) = \sum_{n=1}^{N} x_n x_n^T \nabla_{\tilde{v}_n} f_n(\tilde{m}_n, \tilde{v}_n).   (11)

For notational simplicity, we denote the gradients of f_n at \tilde{m}_{nk} := x_n^T m_k and \tilde{v}_{nk} := x_n^T V_k x_n by

\alpha_{nk} := -\nabla_{\tilde{m}_n} f_n(\tilde{m}_{nk}, \tilde{v}_{nk}), \quad \gamma_{nk} := -2 \nabla_{\tilde{v}_n} f_n(\tilde{m}_{nk}, \tilde{v}_{nk}).   (12)

Using (11) and (12), we get the following linear approximation of f:

f(m, V) \approx \lambda^T \nabla f(\lambda_k) := m^T [\nabla_m f(m_k, V_k)] + \mathrm{Tr}[V \{\nabla_V f(m_k, V_k)\}]   (13)
= -\sum_{n=1}^{N} \left[ \alpha_{nk} (x_n^T m) + \tfrac{1}{2} \gamma_{nk} (x_n^T V x_n) \right].   (14)

Substituting the above into (7), we get the following subproblem in the k'th iteration:

(m_{k+1}, V_{k+1}) = \arg\max_{m, V \succ 0} -\sum_{n=1}^{N} \left[ \alpha_{nk} (x_n^T m) + \tfrac{1}{2} \gamma_{nk} (x_n^T V x_n) \right] + E_{q(z|\lambda)} \left[ \log \frac{\mathcal{N}(z|\mu, \Sigma)}{\mathcal{N}(z|m, V)} \right] - \frac{1}{\beta_k} D_{KL}[\mathcal{N}(z|m, V) \,\|\, \mathcal{N}(z|m_k, V_k)].   (15)

Taking the gradient w.r.t. m and V and setting it to zero, we get the following closed-form solutions (details are given in the supplementary material):

m_{k+1} = \left[ (1 - r_k)\Sigma^{-1} + r_k V_k^{-1} \right]^{-1} \left[ (1 - r_k)(\Sigma^{-1}\mu - X^T \alpha_k) + r_k V_k^{-1} m_k \right],   (16)
V_{k+1}^{-1} = r_k V_k^{-1} + (1 - r_k) \left[ \Sigma^{-1} + X^T \mathrm{diag}(\gamma_k) X \right],   (17)

where r_k := 1/(1 + \beta_k), and \alpha_k and \gamma_k are the vectors of \alpha_{nk} and \gamma_{nk}, respectively, for all n.

Computationally efficient updates: Even though the updates are available in closed form, they are not efficient when the dimensionality D is large. In such a case, an explicit computation of V is costly because the resulting D \times D matrix is extremely large. We now derive efficient updates that avoid an explicit computation of V.

Our derivation involves two key steps. The first step is to show that V_{k+1} can be parameterized by \gamma_k. Specifically, if we initialize V_0 = \Sigma, then we can show that

V_{k+1} = \left[ \Sigma^{-1} + X^T \mathrm{diag}(\tilde{\gamma}_{k+1}) X \right]^{-1}, \quad \text{where } \tilde{\gamma}_{k+1} = r_k \tilde{\gamma}_k + (1 - r_k)\gamma_k,   (18)

with \tilde{\gamma}_0 = \gamma_0. A detailed derivation is given in the supplementary material.

The second key step is to express the updates in terms of \tilde{m}_n and \tilde{v}_n. For this purpose, we define some new quantities. Let \tilde{m} be the vector whose n'th entry is \tilde{m}_n, and similarly let \tilde{v} be the vector of \tilde{v}_n for all n. Denote the corresponding vectors in the k'th iteration by \tilde{m}_k and \tilde{v}_k, respectively. Finally, define \tilde{\mu} := X\mu and \tilde{\Sigma} := X \Sigma X^T.

Now, using the fact that \tilde{m} = Xm and \tilde{v} = \mathrm{diag}(X V X^T) and applying the Woodbury matrix identity, we can express the updates in terms of \tilde{m} and \tilde{v}, as shown below (a detailed derivation is given in the supplementary material):

\tilde{m}_{k+1} = \tilde{m}_k + (1 - r_k)(I - \tilde{\Sigma} B_k^{-1})(\tilde{\mu} - \tilde{m}_k - \tilde{\Sigma} \alpha_k), \quad \text{where } B_k := \tilde{\Sigma} + [\mathrm{diag}(r_k \tilde{\gamma}_k)]^{-1},
\tilde{v}_{k+1} = \mathrm{diag}(\tilde{\Sigma}) - \mathrm{diag}(\tilde{\Sigma} A_k^{-1} \tilde{\Sigma}), \quad \text{where } A_k := \tilde{\Sigma} + [\mathrm{diag}(\tilde{\gamma}_k)]^{-1}.   (19)

Note that these updates depend only on \tilde{\mu}, \tilde{\Sigma}, \alpha_k, and \gamma_k, whose sizes depend only on N and are independent of D. Most importantly, these updates avoid an explicit computation of V and only require storing \tilde{m}_k and \tilde{v}_k, both of which scale linearly with N.

Also note that the matrices A_k and B_k differ only slightly, so we can reduce computation by using A_k in place of B_k. In our experiments, this did not create any convergence issues.

To assess convergence, we can use the optimality condition. By taking the norm of the derivative of \mathcal{L} at m_{k+1} and V_{k+1} and simplifying, we get the following criterion: \|\tilde{\mu} - \tilde{m}_{k+1} - \tilde{\Sigma}\alpha_{k+1}\|^2 + \mathrm{Tr}[\tilde{\Sigma} \{\mathrm{diag}(\tilde{\gamma}_k - \gamma_{k+1} - 1)\} \tilde{\Sigma}] \le \epsilon, for some \epsilon > 0 (a derivation is in the supplementary material).

Linear-Basis Function Model and Gaussian Process: The algorithm presented above can be extended to linear-basis function models by using the weight-space view presented in [22]. Consider a non-linear basis function \phi(x) that maps a D-dimensional feature vector into an N-dimensional feature space. The generalized linear model of (9) is extended to a linear-basis function model by replacing x_n^T z with the latent function g(x) := \phi(x)^T z. The Gaussian prior on z then translates to a kernel function \kappa(x, x') := \phi(x)^T \Sigma \phi(x') and a mean function \tilde{\mu}(x) := \phi(x)^T \mu in the latent function space. Given input vectors x_n, we define the kernel matrix \tilde{\Sigma}, whose (i, j)'th entry is \kappa(x_i, x_j), and the mean vector \tilde{\mu}, whose i'th entry is \tilde{\mu}(x_i).

Assuming a Gaussian posterior distribution over the latent function g(x), we can compute its mean \tilde{m}(x) and variance \tilde{v}(x) using the proximal-gradient algorithm.
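The closed-form updates (16)-(17) can be sanity-checked numerically. The sketch below is our own construction, not the paper's code: it deliberately uses a conjugate Gaussian likelihood y_n ~ N(x_n^T z, s2), for which the gradients in (12) are alpha_nk = -(y_n - x_n^T m_k)/s2 and gamma_nk = 1/s2, so the iterates should converge to the exact Gaussian posterior.

```python
import numpy as np

# Sketch of the closed-form updates (16)-(17).  We pick a conjugate
# Gaussian likelihood y_n ~ N(x_n^T z, s2) so the fixed point can be
# compared against the exact posterior; for this choice (see (12))
# alpha_nk = -(y_n - x_n^T m_k)/s2 and gamma_nk = 1/s2.
def kl_proximal_glm(y, X, mu, Sigma, s2, beta=0.25, iters=100):
    r = 1.0 / (1.0 + beta)                     # r_k := 1/(1 + beta_k)
    Sigma_inv = np.linalg.inv(Sigma)
    m, V_inv = mu.copy(), Sigma_inv.copy()
    for _ in range(iters):
        alpha = -(y - X @ m) / s2              # gradients at (m_k, V_k)
        gamma = np.full(len(y), 1.0 / s2)
        # mean update (16), using the current V_k^{-1}
        A = (1 - r) * Sigma_inv + r * V_inv
        b = (1 - r) * (Sigma_inv @ mu - X.T @ alpha) + r * (V_inv @ m)
        m = np.linalg.solve(A, b)
        # precision update (17)
        V_inv = r * V_inv + (1 - r) * (Sigma_inv + X.T @ (gamma[:, None] * X))
    return m, np.linalg.inv(V_inv)

rng = np.random.default_rng(1)
X = rng.standard_normal((40, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.5 * rng.standard_normal(40)
mu, Sigma, s2 = np.zeros(3), np.eye(3), 0.5
m, V = kl_proximal_glm(y, X, mu, Sigma, s2)
V_exact = np.linalg.inv(np.eye(3) + X.T @ X / s2)  # exact Gaussian posterior
m_exact = V_exact @ (X.T @ y / s2)
```

With a non-conjugate likelihood, only the two gradient lines would change (e.g., to expectations computed as in [23, 24]); the rest of the iteration is identical.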
We define \tilde{m} to be the vector of \tilde{m}(x_n) for all n, and similarly \tilde{v} to be the vector of all \tilde{v}(x_n). Following the same derivation as in the previous section, we can show that the updates of (19) give us the posterior mean \tilde{m} and variance \tilde{v}. These updates are the kernelized versions of (16) and (17).

For prediction, we only need the converged values of \alpha_k and \gamma_k, denoted by \alpha_* and \gamma_*, respectively. Given a new input x_*, define \kappa_{**} := \kappa(x_*, x_*) and let \kappa_* be the vector whose n'th entry is \kappa(x_n, x_*). The predictive mean and variance can be computed as shown below:

\tilde{m}(x_*) = \tilde{\mu}_* - \kappa_*^T \alpha_*, \quad \tilde{v}(x_*) = \kappa_{**} - \kappa_*^T \left[ \tilde{\Sigma} + (\mathrm{diag}(\tilde{\gamma}_*))^{-1} \right]^{-1} \kappa_*.   (20)

A pseudo-code is given in Algorithm 1.

Algorithm 1 Proximal-gradient algorithm for linear-basis function models and Gaussian processes
  Given: training data (y, X), test data x_*, kernel mean \tilde{\mu}, covariance \tilde{\Sigma}, step-size sequence r_k, and threshold \epsilon.
  Initialize: \tilde{m}_0 \leftarrow \tilde{\mu}, \tilde{v}_0 \leftarrow \mathrm{diag}(\tilde{\Sigma}), and \tilde{\gamma}_0 \leftarrow \delta 1.
  repeat
    For all n in parallel: \alpha_{nk} \leftarrow -\nabla_{\tilde{m}_n} f_n(\tilde{m}_{nk}, \tilde{v}_{nk}) and \gamma_{nk} \leftarrow -2 \nabla_{\tilde{v}_n} f_n(\tilde{m}_{nk}, \tilde{v}_{nk}).
    Update \tilde{m}_k and \tilde{v}_k using (19).
    \tilde{\gamma}_{k+1} \leftarrow r_k \tilde{\gamma}_k + (1 - r_k)\gamma_k.
  until \|\tilde{\mu} - \tilde{m}_k - \tilde{\Sigma}\alpha_k\| + \mathrm{Tr}[\tilde{\Sigma}\, \mathrm{diag}(\tilde{\gamma}_k - \gamma_{k+1} - 1)\, \tilde{\Sigma}] \le \epsilon.
  Predict test inputs x_* using (20).
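As an end-to-end illustration of Algorithm 1, the following sketch (our own test case, not the paper's implementation) runs the kernelized m-update of (19) for GP regression with a Gaussian likelihood, where the gradients are exact (alpha_k = -(y - m_k)/s2, gamma_k = 1/s2) and the fixed point is the standard GP posterior mean, which we compare against. For simplicity, a fixed number of iterations stands in for the stopping criterion.

```python
import numpy as np

# Algorithm 1 on a toy GP regression problem with Gaussian noise s2.
# In this conjugate case alpha_k = -(y - m_k)/s2 and gamma_k = 1/s2,
# and the converged mean must match the standard GP posterior mean
# K (K + s2 I)^{-1} y (zero mean function, so mu_tilde = 0).
def algorithm1_gp_mean(y, K, s2, r=0.8, iters=300, delta=1e-3):
    N = len(y)
    mu_t = np.zeros(N)                   # mean function: mu_tilde = 0
    m = mu_t.copy()
    g_t = np.full(N, delta)              # gamma_tilde_0 = delta * 1
    I = np.eye(N)
    for _ in range(iters):
        alpha = -(y - m) / s2
        gamma = np.full(N, 1.0 / s2)
        B = K + np.diag(1.0 / (r * g_t))                 # B_k from (19)
        m = m + (1 - r) * (I - K @ np.linalg.inv(B)) @ (mu_t - m - K @ alpha)
        g_t = r * g_t + (1 - r) * gamma                  # gamma_tilde update
    return m

rng = np.random.default_rng(2)
x = np.sort(rng.uniform(0, 5, 30))
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)  # squared-exponential kernel
y = np.sin(x) + 0.1 * rng.standard_normal(30)
s2 = 0.1
m = algorithm1_gp_mean(y, K, s2)
m_exact = K @ np.linalg.solve(K + s2 * np.eye(30), y)  # standard GP mean
```

Note that only kernel-space quantities of size N appear in the loop, mirroring the efficient updates of (19); no D-dimensional covariance is ever formed.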
Here, we initialize \tilde{\gamma} to a small constant \delta 1; otherwise, solving the first equation might be ill-conditioned. These updates also work for Gaussian process (GP) models with a kernel \kappa(x, x') and mean function \tilde{\mu}(x), and for many other latent Gaussian models, such as matrix-factorization models.

6 Experiments and Results

We now present results on real data. Our goal is to show that our approach gives results comparable to existing methods and is easy to implement. We also show that, in some cases, our method is significantly faster than the alternatives due to the kernel trick.

We show results on three models: Bayesian logistic regression, GP classification with a logistic likelihood, and GP regression with a Laplace likelihood. For these likelihoods, expectations can be computed (almost) exactly, for which we used the methods described in [23, 24]. We use a fixed step size of \beta_k = 0.25 and \beta_k = 1 for the logistic and Laplace likelihoods, respectively.

We consider three datasets for each model. A summary is given in Table 1. These datasets can be found at the data repositories¹ of LIBSVM and UCI.

Bayesian logistic regression: Results for Bayesian logistic regression are shown in Table 2. We consider two datasets. For 'a1a', N > D, and for 'Colon', N < D. We compare our 'Proximal' method to three existing methods: the 'MAP' method, which finds the mode of the penalized log-likelihood; the 'Mean-Field' method, where the distribution is factorized across dimensions; and the 'Cholesky' method of [25]. We implemented these methods using the 'minFunc' software by Mark Schmidt². We used L-BFGS for optimization. All algorithms are stopped when the optimality condition is below 10^{-4}. We set the Gaussian prior to \Sigma = \delta I and \mu = 0.
To set the hyperparameter \delta, we use cross-validation for MAP and the maximum marginal-likelihood estimate for the rest of the methods. As we compare running times as well, we use a common range of hyperparameter values for all methods. These values are shown in Table 1.

For the Bayesian methods, we report the negative of the marginal-likelihood approximation ('Neg-Log-Lik'). This is (the negative of) the value of the lower bound at the maximum. We also report the log-loss, computed as -\sum_n \log \hat{p}_n / N, where \hat{p}_n are the predictive probabilities of the test data and N is the total number of test pairs. A lower value is better, and a value of 1 is equivalent to random coin-flipping. In addition, we report the total time taken for hyperparameter selection.

¹ https://archive.ics.uci.edu/ml/datasets.html and http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/
² Available at https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html

Model     Dataset     N       D     %Train  #Splits  Hyperparameter range
LogReg    a1a         32,561  123   5%      1        delta = logspace(-3,1,30)
          Colon       62      2000  50%     10       delta = logspace(0,6,30)
GP class  Ionosphere  351     34    50%     10       for all GP datasets:
          Sonar       208     60    50%     10       log(l) = linspace(-1,6,15),
          USPS-3vs5   1,540   256   50%     5        log(sigma) = linspace(-1,6,15)
GP reg    Housing     506     13    50%     10
          Triazines   186     60    50%     10
          Space ga    3,106   6     50%     1        log(b) = linspace(-5,1,2)

Table 1: A list of models and datasets. %Train is the % of training data.
The last column shows the hyperparameter values ('linspace' and 'logspace' refer to Matlab commands).

For MAP, this is the total cross-validation time, whereas for the Bayesian methods it is the time taken to compute 'Neg-Log-Lik' for all hyperparameter values over the whole range.

Dataset  Method      Neg-Log-Lik   Log Loss     Time
a1a      MAP         —             0.499        27s
         Mean-Field  792.8         0.505        21s
         Cholesky    590.1         0.488        12m
         Proximal    590.1         0.488        7m
Colon    MAP         —             0.78 (0.01)  7s (0.00)
         Mean-Field  18.35 (0.11)  0.78 (0.01)  15m (0.04)
         Proximal    15.82 (0.13)  0.70 (0.01)  18m (0.14)

Table 2: A summary of the results obtained on Bayesian logistic regression. In all columns, a lower value implies better performance.

We summarize these results in Table 2. For all columns, a lower value is better. We see that for 'a1a', the fully Bayesian methods perform slightly better than MAP. More importantly, the Proximal method is faster than the Cholesky method but obtains the same error and marginal-likelihood estimate. For the Proximal method, we use the updates of (16) and (17) because D ≪ N, but even in this scenario, the Cholesky method is slow due to the expensive line search over a large number of parameters.

For the 'Colon' dataset, we use the update (19) for the Proximal method. We do not compare to the Cholesky method here because it is too slow for large datasets. In Table 2, we see that our implementation is as fast as the Mean-Field method but performs significantly better.

Overall, with the Proximal method, we achieve the same results as the Cholesky method in less time. In some cases, we can also match the running time of the Mean-Field method. Note that the Mean-Field method does not give bad predictions, and its minimum log-loss values are comparable to our approach.
However, as the Neg-Log-Lik values for the Mean-Field method are inaccurate, it ends up choosing a bad hyperparameter value. This is expected, as the Mean-Field method makes an extreme approximation. Cross-validation is therefore more appropriate for the Mean-Field method.
Gaussian process classification and regression: We compare the Proximal method to expectation propagation (EP) and the Laplace approximation. We use the GPML toolbox for this comparison, with a squared-exponential kernel for the Gaussian process with two scale parameters σ and l (as defined in the GPML toolbox). We do a grid search over these hyperparameters; the grid values are given in Table 1. We report the log-loss and running time for each method.
The left plot in Figure 1 shows the log-loss for GP classification on the 'USPS 3vs5' dataset, where the Proximal method shows very similar behaviour to EP. These results are summarized in Table 3. We see that our method performs similarly to EP, and sometimes a bit better. The running times of EP and the Proximal method are also comparable. The advantage of our approach is that it is easier to implement than EP and is also numerically robust. The predictive probabilities obtained with EP and the Proximal method for the 'USPS 3vs5' dataset are shown in the right plot of Figure 1. The horizontal axis shows the test examples in ascending order of their predictive probabilities under EP; the probabilities themselves are shown on the y-axis. A higher value implies better performance, so the Proximal method gives estimates better than EP. This improvement is due to numerical error in the likelihood implementation: for the Proximal method, we use the method of [23], which is quite accurate, and designing such accurate likelihood approximations for EP is challenging.

Figure 1: In the left figure, the top row shows the log-loss and the bottom row the running time in seconds for the 'USPS 3vs5' dataset; in each plot, the minimum value of the log-loss is marked with a black circle. The right figure shows the predictive probabilities obtained with EP and the Proximal method, with test examples sorted in ascending order of their predictive probabilities under EP.

Table 3: Results for GP classification using a logistic likelihood and GP regression using a Laplace likelihood. For all rows, a lower value is better. (Times: s is sec, m is min, h is hr.)

                 Log Loss                              Time
Data        Laplace      EP           Proximal     Laplace     EP          Proximal
Ionosphere  .285 (.002)  .234 (.002)  .230 (.002)  10s (.3)    3.8m (.10)  3.6m (.10)
Sonar       .410 (.002)  .341 (.003)  .317 (.004)  4s (.01)    45s (.01)   63s (.13)
USPS-3vs5   .101 (.002)  .065 (.002)  .055 (.003)  1m (.06)    1h (.06)    1h (.02)
Housing     1.03 (.004)  .300 (.006)  .310 (.009)  .36m (.00)  25m (.65)   61m (1.8)
Triazines   1.35 (.006)  1.36 (.006)  1.35 (.006)  10s (.10)   8m (.04)    14m (.30)
Space ga    1.01 (—)     .767 (—)     .742 (—)     2m (—)      5h (—)      11h (—)

7 Discussion and Future Work

In this paper, we have proposed a proximal framework that uses the KL proximal term to take the geometry of the posterior distribution into account. We established the equivalence between our proximal-point algorithm and natural-gradient methods. We proposed a proximal-gradient algorithm that exploits the structure of the bound to simplify the optimization. An important future direction is to apply stochastic approximations to approximate gradients; this extension is discussed in [21]. It is also important to design a line-search method to set the step sizes.
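The KL proximal-point iteration summarized above can be written schematically as follows (notation ours: L(q) denotes the variational lower bound, Q the variational family, and β_k > 0 a step size):

```latex
q_{k+1} \;=\; \arg\max_{q \in \mathcal{Q}} \;\; \mathcal{L}(q) \;-\; \frac{1}{\beta_k}\, \mathrm{KL}\!\left[\, q \,\big\|\, q_k \,\right].
```

For conditionally-conjugate exponential-family models, the equivalence established above means that iterating this update corresponds to natural-gradient ascent on the lower bound, as in SVI [1].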
In addition, our proximal framework can also be used for distributed optimization in variational inference [26, 11].

Acknowledgments

Mohammad Emtiyaz Khan would like to thank Masashi Sugiyama and Akiko Takeda from University of Tokyo, Matthias Grossglauser and Vincent Etter from EPFL, and Hannes Nickisch from Philips Research (Hamburg) for useful discussions and feedback. Pierre Baqué was supported in part by the Swiss National Science Foundation, under the grant CRSII2-147693 "Tracking in the Wild".

References

[1] Matthew D. Hoffman, David M. Blei, Chong Wang, and John Paisley. Stochastic variational inference. The Journal of Machine Learning Research, 14(1):1303–1347, 2013.
[2] Tim Salimans and David A. Knowles. Fixed-form variational posterior approximation through stochastic linear regression. Bayesian Analysis, 8(4):837–882, 2013.
[3] Rajesh Ranganath, Sean Gerrish, and David M. Blei. Black box variational inference. arXiv preprint arXiv:1401.0118, 2013.
[4] Michalis Titsias and Miguel Lázaro-Gredilla. Doubly stochastic variational Bayes for non-conjugate inference. In International Conference on Machine Learning, 2014.
[5] Masa-Aki Sato. Online model selection based on the variational Bayes. Neural Computation, 13(7):1649–1681, 2001.
[6] A. Honkela, T. Raiko, M. Kuusela, M. Tornio, and J. Karhunen. Approximate Riemannian conjugate gradient learning for fixed-form variational Bayes. The Journal of Machine Learning Research, 11:3235–3268, 2011.
[7] Stéphane Chrétien and Alfred O. Hero III. Kullback proximal algorithms for maximum-likelihood estimation. IEEE Transactions on Information Theory, 46(5):1800–1810, 2000.
[8] Paul Tseng. An analysis of the EM algorithm and entropy-like proximal point methods. Mathematics of Operations Research, 29(1):27–44, 2004.
[9] M. Teboulle. Convergence of proximal-like algorithms. SIAM Journal on Optimization, 7(4):1069–1083, 1997.
[10] Pradeep Ravikumar, Alekh Agarwal, and Martin J. Wainwright. Message-passing for graph-structured linear programs: Proximal projections, convergence and rounding schemes. In International Conference on Machine Learning, 2008.
[11] Behnam Babagholami-Mohamadabadi, Sejong Yoon, and Vladimir Pavlovic. D-MFVI: Distributed mean field variational inference using Bregman ADMM. arXiv preprint arXiv:1507.00824, 2015.
[12] Bo Dai, Niao He, Hanjun Dai, and Le Song. Scalable Bayesian inference via particle mirror descent. arXiv preprint arXiv:1506.03101, 2015.
[13] Lucas Theis and Matthew D. Hoffman. A trust-region method for stochastic variational inference with applications to streaming data. In International Conference on Machine Learning, 2015.
[14] Razvan Pascanu and Yoshua Bengio. Revisiting natural gradient for deep networks. arXiv preprint arXiv:1301.3584, 2013.
[15] Ulrich Paquet. On the convergence of stochastic variational inference in Bayesian networks. NIPS Workshop on Variational Inference, 2014.
[16] Nicholas G. Polson, James G. Scott, and Brandon T. Willard. Proximal algorithms in statistics and machine learning. arXiv preprint arXiv:1502.03175, 2015.
[17] Harri Lappalainen and Antti Honkela. Bayesian non-linear independent component analysis by multi-layer perceptrons. In Advances in Independent Component Analysis, pages 93–121. Springer, 2000.
[18] Chong Wang and David M. Blei. Variational inference in nonconjugate models. The Journal of Machine Learning Research, 14(1):1005–1031, April 2013.
[19] M. Seeger and H. Nickisch. Large scale Bayesian inference and experimental design for sparse linear models. SIAM Journal on Imaging Sciences, 4(1):166–199, 2011.
[20] Antti Honkela and Harri Valpola. Unsupervised variational Bayesian learning of nonlinear models. In Advances in Neural Information Processing Systems, pages 593–600, 2004.
[21] Mohammad Emtiyaz Khan, Reza Babanezhad, Wu Lin, Mark Schmidt, and Masashi Sugiyama. Convergence of proximal-gradient stochastic variational inference under non-decreasing step-size sequence. arXiv preprint arXiv:1511.00146, 2015.
[22] Carl Edward Rasmussen and Christopher K. I. Williams. Gaussian Processes for Machine Learning. MIT Press, 2006.
[23] B. Marlin, M. Khan, and K. Murphy. Piecewise bounds for estimating Bernoulli-logistic latent Gaussian models. In International Conference on Machine Learning, 2011.
[24] Mohammad Emtiyaz Khan. Decoupled variational inference. In Advances in Neural Information Processing Systems, 2014.
[25] E. Challis and D. Barber. Concave Gaussian variational approximations for inference in large-scale Bayesian linear models. In International Conference on Artificial Intelligence and Statistics, 2011.
[26] Huahua Wang and Arindam Banerjee. Bregman alternating direction method of multipliers. In Advances in Neural Information Processing Systems, 2014.