{"title": "VAE Learning via Stein Variational Gradient Descent", "book": "Advances in Neural Information Processing Systems", "page_first": 4236, "page_last": 4245, "abstract": "A new method for learning variational autoencoders (VAEs) is developed, based on Stein variational gradient descent. A key advantage of this approach is that one need not make parametric assumptions about the form of the encoder distribution. Performance is further enhanced by integrating the proposed encoder with importance sampling. Excellent performance is demonstrated across multiple unsupervised and semi-supervised problems, including semi-supervised analysis of the ImageNet data, demonstrating the scalability of the model to large datasets.", "full_text": "VAE Learning via Stein Variational Gradient Descent\n\nYunchen Pu, Zhe Gan, Ricardo Henao, Chunyuan Li, Shaobo Han, Lawrence Carin\n{yp42, zg27, r.henao, cl319, shaobo.han, lcarin}@duke.edu\n\nDepartment of Electrical and Computer Engineering, Duke University\n\nAbstract\n\nA new method for learning variational autoencoders (VAEs) is developed, based\non Stein variational gradient descent. A key advantage of this approach is that\none need not make parametric assumptions about the form of the encoder distri-\nbution. Performance is further enhanced by integrating the proposed encoder with\nimportance sampling. Excellent performance is demonstrated across multiple un-\nsupervised and semi-supervised problems, including semi-supervised analysis of\nthe ImageNet data, demonstrating the scalability of the model to large datasets.\n\n1\n\nIntroduction\n\nThere has been signi\ufb01cant recent interest in the variational autoencoder (VAE) [11], a generalization\nof the original autoencoder [33]. VAEs are typically trained by maximizing a variational lower\nbound of the data log-likelihood [2, 10, 11, 12, 18, 21, 22, 23, 30, 34, 35]. 
To compute the variational expression, one must be able to explicitly evaluate the associated distribution of latent features, i.e., the stochastic encoder must have an explicit analytic form. This requirement has motivated the design of encoders in which a neural network maps input data to the parameters of a simple distribution, e.g., Gaussian distributions have been widely utilized [1, 11, 27, 25].

The Gaussian assumption may be too restrictive in some cases [28]. Consequently, recent work has considered normalizing flows [28], in which random variables from (for example) a Gaussian distribution are fed through a series of nonlinear functions to increase the complexity and representational power of the encoder. However, because of the need to explicitly evaluate the distribution within the variational expression used when learning, these nonlinear functions must be relatively simple, e.g., planar flows. Further, one may require many layers to achieve the desired representational power.

We present a new approach for training a VAE. We recognize that the need for an explicit form for the encoder distribution is only a consequence of the fact that learning is performed based on the variational lower bound. For inference (e.g., at test time), we do not need an explicit form for the distribution of latent features; we only require fast sampling from the encoder. Consequently, rather than directly employing the traditional variational lower bound, we seek to minimize the Kullback-Leibler (KL) distance between the approximating distribution and the true posterior of model and latent parameters. Learning then becomes a novel application of Stein variational gradient descent (SVGD) [15], constituting its first application to training VAEs. 
We extend SVGD with importance sampling [1], and also demonstrate its novel use in semi-supervised VAE learning.

The concepts developed here are demonstrated on a wide range of unsupervised and semi-supervised learning problems, including a large-scale semi-supervised analysis of the ImageNet dataset. These experimental results illustrate the advantage of SVGD-based VAE training, relative to traditional approaches. Moreover, the results demonstrate further improvements realized by integrating SVGD with importance sampling.

Independent work by [3, 6] proposed similar models, in which the authors incorporated SVGD with VAEs [3] and importance sampling [6] for unsupervised learning tasks.

31st Conference on Neural Information Processing Systems (NIPS 2017), Long Beach, CA, USA.

2 Stein Learning of Variational Autoencoder (Stein VAE)

2.1 Review of VAE and Motivation for Use of SVGD

Consider data D = {x_n}_{n=1}^N, where x_n are modeled via decoder x_n|z_n ∼ p(x|z_n; θ). A prior p(z) is placed on the latent codes. To learn parameters θ, one typically is interested in maximizing the empirical expected log-likelihood, (1/N) Σ_{n=1}^N log p(x_n; θ). A variational lower bound is often employed:

L(θ, φ; x) = E_{z|x;φ} log [ p(x|z; θ) p(z) / q(z|x; φ) ] = −KL(q(z|x; φ) ‖ p(z|x; θ)) + log p(x; θ) ,   (1)

with log p(x; θ) ≥ L(θ, φ; x), and where E_{z|x;φ}[·] is approximated by averaging over a finite number of samples drawn from encoder q(z|x; φ). Parameters θ and φ are typically iteratively optimized via stochastic gradient descent [11], seeking to maximize Σ_{n=1}^N L(θ, φ; x_n).

To evaluate the variational expression in (1), we require the ability to sample efficiently from q(z|x; φ), to approximate the expectation. 
We also require a closed form for this encoder, to evaluate log[p(x|z; θ) p(z) / q(z|x; φ)]. In the proposed VAE learning framework, rather than maximizing the variational lower bound explicitly, we focus on the term KL(q(z|x; φ) ‖ p(z|x; θ)), which we seek to minimize. This can be achieved by leveraging Stein variational gradient descent (SVGD) [15]. Importantly, for SVGD we need only be able to sample from q(z|x; φ); we need not possess its explicit functional form.

In the above discussion, θ is treated as a parameter; below we treat it as a random variable, as was considered in the Appendix of [11]. Treatment of θ as a random variable allows for model averaging, and a point estimate of θ is revealed as a special case of the proposed method.

The set of codes associated with all x_n ∈ D is represented Z = {z_n}_{n=1}^N. The prior on {θ, Z} is here represented as p(θ, Z) = p(θ) Π_{n=1}^N p(z_n). We desire the posterior p(θ, Z|D). Consider the revised variational expression

L1(q; D) = E_{q(θ,Z)} log [ p(D|Z, θ) p(θ, Z) / q(θ, Z) ] = −KL(q(θ, Z) ‖ p(θ, Z|D)) + log p(D; M) ,   (2)

where p(D; M) is the evidence for the underlying model M. Learning q(θ, Z) such that L1 is maximized is equivalent to seeking q(θ, Z) that minimizes KL(q(θ, Z) ‖ p(θ, Z|D)). By leveraging and generalizing SVGD, we will perform the latter.

2.2 Stein Variational Gradient Descent (SVGD)

Rather than explicitly specifying a form for p(θ, Z|D), we sequentially refine samples of θ and Z, such that they are better matched to p(θ, Z|D). We alternate between updating the samples of θ and the samples of Z, analogous to how θ and φ are updated alternately in traditional VAE optimization of (1). We first consider updating samples of θ, with the samples of Z held fixed. 
Specifically, assume we have samples {θ_j}_{j=1}^M drawn from distribution q(θ), and samples {z_jn}_{j=1}^M drawn from distribution q(Z). We wish to transform {θ_j}_{j=1}^M by feeding them through a function, and the corresponding (implicit) transformed distribution from which they are drawn is denoted qT(θ). It is desired that, in a KL sense, qT(θ) q(Z) is closer to p(θ, Z|D) than was q(θ) q(Z). The following theorem is useful for defining how to best update {θ_j}_{j=1}^M.

Theorem 1 Assume θ and Z are random variables (RVs) drawn from distributions q(θ) and q(Z), respectively. Consider the transformation T(θ) = θ + εψ(θ; D) and let qT(θ) represent the distribution of θ' = T(θ). We have

∇_ε ( KL(qT ‖ p) ) |_{ε=0} = −E_{θ∼q(θ)} [ trace(A_p(θ; D)) ] ,   (3)

where qT = qT(θ) q(Z), p = p(θ, Z|D), A_p(θ; D) = ∇_θ log p̃(θ; D) ψ(θ; D)^T + ∇_θ ψ(θ; D), log p̃(θ; D) = E_{Z∼q(Z)}[log p(D, Z, θ)], and p(D, Z, θ) = p(D|Z, θ) p(θ, Z).

The proof is provided in Appendix A. Following [15], we assume ψ(θ; D) lives in a reproducing kernel Hilbert space (RKHS) with kernel k(·,·). Under this assumption, the solution for ψ(θ; D) that maximizes the decrease in the KL distance (3) is

ψ*(·; D) = E_{q(θ)} [ k(θ, ·) ∇_θ log p̃(θ; D) + ∇_θ k(θ, ·) ] .   (4)

Theorem 1 concerns updating samples from q(θ) assuming fixed q(Z). Similarly, to update q(Z) with q(θ) fixed, we employ a complementary form of Theorem 1 (omitted for brevity). 
In that case, we consider the transformation T(Z) = Z + εψ(Z; D), with Z ∼ q(Z), and function ψ(Z; D) is also assumed to be in an RKHS.

The expectations in (3) and (4) are approximated by samples θ_j^(t+1) = θ_j^(t) + ε Δθ_j^(t), with

Δθ_j^(t) ≈ (1/M) Σ_{j'=1}^M [ k_θ(θ_{j'}^(t), θ_j^(t)) ∇_{θ_{j'}^(t)} log p̃(θ_{j'}^(t); D) + ∇_{θ_{j'}^(t)} k_θ(θ_{j'}^(t), θ_j^(t)) ] ,   (5)

with ∇_θ log p̃(θ; D) ≈ (1/M) Σ_{j=1}^M ∇_θ log [ p(θ) Π_{n=1}^N p(x_n|z_jn, θ) ]. A similar update of samples is manifested for the latent variables z_jn^(t+1) = z_jn^(t) + ε Δz_jn^(t):

Δz_jn^(t) ≈ (1/M) Σ_{j'=1}^M [ k_z(z_{j'n}^(t), z_jn^(t)) ∇_{z_{j'n}^(t)} log p̃(z_{j'n}^(t); D) + ∇_{z_{j'n}^(t)} k_z(z_{j'n}^(t), z_jn^(t)) ] ,   (6)

where ∇_{z_n} log p̃(z_n; D) ≈ (1/M) Σ_{j=1}^M ∇_{z_n} log [ p(x_n|z_n, θ_j) p(z_n) ]. The kernels used to update samples of θ and z_n are in general different, denoted respectively k_θ(·,·) and k_z(·,·), and ε is a small step size. For notational simplicity, M is the same in (5) and (6), but in practice a different number of samples may be used for θ and Z.

If M = 1 for parameter θ, indices j and j' are removed in (5). Learning then reduces to gradient descent and a point estimate for θ, identical to the optimization procedure used for the traditional VAE expression in (1), but with the (multiple) samples associated with Z sequentially transformed via SVGD (and, importantly, without the need to assume a form for q(z|x; φ)). 
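To make the particle updates in (5)-(6) concrete, the following NumPy sketch applies one SVGD transformation to a set of particles, using an RBF kernel with the median-heuristic bandwidth adopted in Section 5. It is an illustrative stand-in, not the paper's implementation; the `score` callable (returning ∇_θ log p̃ at each particle) and the toy Gaussian target in the usage note are assumptions.

```python
import numpy as np

def rbf_kernel(particles, h=None):
    """RBF kernel matrix K and the gradient of k(theta_j', theta_j)
    with respect to its first argument theta_j'.

    The bandwidth h defaults to the median of pairwise squared
    distances, the heuristic used in Section 5.
    """
    diffs = particles[:, None, :] - particles[None, :, :]   # (M, M, d)
    sq_dists = (diffs ** 2).sum(axis=-1)                    # (M, M)
    if h is None:
        h = np.median(sq_dists) + 1e-8                      # median heuristic
    K = np.exp(-sq_dists / h)
    grad_K = -2.0 / h * diffs * K[..., None]                # (M, M, d)
    return K, grad_K

def svgd_step(particles, score, eps=0.1):
    """One SVGD transformation T(theta) = theta + eps * psi(theta),
    mirroring Eq. (5): a kernel-smoothed score term drives particles
    toward high-probability regions, and the kernel-gradient term
    repels them from one another, preventing collapse to a mode."""
    M = particles.shape[0]
    K, grad_K = rbf_kernel(particles)
    scores = score(particles)           # (M, d): grad log p~ at each particle
    phi = (K @ scores + grad_K.sum(axis=0)) / M
    return particles + eps * phi
```

For instance, iterating `svgd_step` with `score = lambda t: -(t - 5.0)` (the score of a unit-variance Gaussian centered at 5) drives an arbitrary initial particle set toward that target while the repulsive term maintains spread.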
Therefore, if only a point estimate of θ is desired, (1) can be optimized wrt θ, while for updating Z SVGD is applied.

2.3 Efficient Stochastic Encoder

At iteration t of the above learning procedure, we realize a set of latent-variable (code) samples {z_jn^(t)}_{j=1}^M for each x_n ∈ D under analysis. For large N, training may be computationally expensive. Further, the need to evolve (learn) samples {z_j*}_{j=1}^M for each new test sample, x*, is undesirable. We therefore develop a recognition model that efficiently computes samples of latent codes for a data sample of interest. The recognition model draws samples via z_jn = f_η(x_n, ξ_jn) with ξ_jn ∼ q0(ξ). Distribution q0(ξ) is selected such that it may be easily sampled, e.g., isotropic Gaussian.

After each iteration of updating the samples of Z, we refine recognition model f_η(x, ξ) to mimic the Stein sample dynamics. Assume recognition-model parameters η^(t) have been learned thus far. Using η^(t), latent codes for iteration t are constituted as z_jn^(t) = f_{η^(t)}(x_n, ξ_jn), with ξ_jn ∼ q0(ξ). These codes are computed for all data x_n ∈ B_t, where B_t ⊂ D is the minibatch of data at iteration t. The change in the codes is Δz_jn^(t), as defined in (6). We then update η to match the refined codes, as

η^(t+1) = argmin_η Σ_{x_n∈B_t} Σ_{j=1}^M ‖ f_η(x_n, ξ_jn) − z_jn^(t+1) ‖² .   (7)

The analytic solution of (7) is intractable. 
We update η with K steps of gradient descent as η^(t,k) = η^(t,k−1) − δ Σ_{x_n∈B_t} Σ_{j=1}^M Δη_jn^(t,k−1), where Δη_jn^(t,k−1) = ∂_η f_η(x_n, ξ_jn) ( f_η(x_n, ξ_jn) − z_jn^(t+1) ) |_{η=η^(t,k−1)}, δ is a small step size, η^(t) = η^(t,0), η^(t+1) = η^(t,K), and ∂_η f_η(x_n, ξ_jn) is the transpose of the Jacobian of f_η(x_n, ξ_jn) wrt η. Note that the use of minibatches mitigates challenges of training with large training sets, D.

The function f_η(x, ξ) plays a role analogous to q(z|x; φ) in (1), in that it yields a means of efficiently drawing samples of latent codes z, given observed x; however, we do not impose an explicit functional form for the distribution of these samples.

3 Stein Variational Importance Weighted Autoencoder (Stein VIWAE)

3.1 Multi-sample importance-weighted KL divergence

Recall the variational expression in (1) employed in conventional VAE learning. Recently, [1, 19] showed that the multi-sample (k samples) importance-weighted estimator

L_k(x) = E_{z_1,...,z_k ∼ q(z|x)} [ log (1/k) Σ_{i=1}^k p(x, z_i) / q(z_i|x) ] ,   (8)

provides a tighter lower bound and a better proxy for the log-likelihood, where z_1, . . . , z_k are random variables sampled independently from q(z|x). Recall from (3) that the KL divergence played a key role in the Stein-based learning of Section 2. Equation (8) motivates replacement of the KL objective function with the multi-sample importance-weighted KL divergence

KL^k_{q,p}(Θ; D) ≜ −E_{Θ^{1:k} ∼ q(Θ)} [ log (1/k) Σ_{i=1}^k p(Θ^i|D) / q(Θ^i) ] ,   (9)

where Θ = (θ, Z) and Θ^{1:k} = Θ^1, . . . , Θ^k are independent samples from q(θ, Z). Note that the special case of k = 1 recovers the standard KL divergence. 
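In practice the estimator in (8) is computed from log-weights with a log-sum-exp for numerical stability. A minimal sketch, using a toy conjugate model (with the prior as proposal) rather than the paper's encoder, so the bound can be checked against a closed-form log-likelihood; `log_norm` is a hypothetical helper, not from the paper:

```python
import numpy as np

def log_norm(v, mean, var):
    """Elementwise log N(v; mean, var)."""
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)

def iw_lower_bound(log_p_xz, log_q_z):
    """k-sample bound of Eq. (8): log (1/k) sum_i p(x, z_i)/q(z_i|x),
    evaluated from log importance weights via log-sum-exp."""
    log_w = log_p_xz - log_q_z          # (k,) log weights
    m = log_w.max()                     # subtract max for stability
    return m + np.log(np.mean(np.exp(log_w - m)))
```

In the toy model z ∼ N(0, 1), x|z ∼ N(z, 1), the marginal is x ∼ N(0, 2), so log p(x) is known exactly and one can observe the bound tightening as k grows, consistent with the monotonicity that Theorem 2 establishes for the importance-weighted KL divergence.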
Inspired by [1], the following theorem (proved in Appendix A) shows that increasing the number of samples k is guaranteed to reduce the KL divergence and provide a better approximation of the target distribution.

Theorem 2 For any natural number k, we have KL^k_{q,p}(Θ; D) ≥ KL^{k+1}_{q,p}(Θ; D) ≥ 0, and if q(Θ)/p(Θ|D) is bounded, then lim_{k→∞} KL^k_{q,p}(Θ; D) = 0.

We minimize (9) with a sample transformation based on a generalization of SVGD, and the recognition model (encoder) is trained in the same way as in Section 2.3. Specifically, we first draw samples {θ^{1:k}_j}_{j=1}^M and {z^{1:k}_jn}_{j=1}^M from a simple distribution q0(·), and convert these to approximate draws from p(θ^{1:k}, Z^{1:k}|D) by minimizing the multi-sample importance-weighted KL divergence via a nonlinear functional transformation.

3.2 Importance-weighted SVGD for VAEs

The following theorem generalizes Theorem 1 to the multi-sample importance-weighted KL divergence.

Theorem 3 Let Θ^{1:k} be RVs drawn independently from distribution q(Θ) and let KL^k_{q,p}(Θ; D) be the multi-sample importance-weighted KL divergence in (9). Let T(Θ) = Θ + εψ(Θ; D) and qT(Θ) represent the distribution of Θ' = T(Θ). We have

∇_ε ( KL^k_{qT,p}(Θ'; D) ) |_{ε=0} = −E_{Θ^{1:k} ∼ q(Θ)} ( A^k_p(Θ^{1:k}; D) ) .   (10)

The proof and detailed definitions are provided in Appendix A. The following corollaries generalize Theorem 1 and (4) via use of importance sampling, respectively.

Corollary 3.1 θ^{1:k} and Z^{1:k} are RVs drawn independently from distributions q(θ) and q(Z), respectively. 
Let T(θ) = θ + εψ(θ; D), let qT(θ) represent the distribution of θ' = T(θ), and let Θ' = (θ', Z). We have

∇_ε ( KL^k_{qT,p}(Θ'; D) ) |_{ε=0} = −E_{θ^{1:k} ∼ q(θ)} ( A^k_p(θ^{1:k}; D) ) ,   (11)

where A^k_p(θ^{1:k}; D) = (1/ω̃) Σ_{i=1}^k ω_i A_p(θ^i; D), ω_i = E_{Z^i ∼ q(Z)} [ p(θ^i, Z^i, D) / ( q(θ^i) q(Z^i) ) ], and ω̃ = Σ_{i=1}^k ω_i; A_p(θ; D) and log p̃(θ; D) are as defined in Theorem 1.

Corollary 3.2 Assume ψ(θ; D) lives in a reproducing kernel Hilbert space (RKHS) with kernel k_θ(·,·). The solution for ψ(θ; D) that maximizes the decrease in the KL distance (11) is

ψ*(·; D) = E_{θ^{1:k} ∼ q(θ)} [ (1/ω̃) Σ_{i=1}^k ω_i ( ∇_{θ^i} k_θ(θ^i, ·) + k_θ(θ^i, ·) ∇_{θ^i} log p̃(θ^i; D) ) ] .   (12)

Corollary 3.1 and Corollary 3.2 provide a means of updating multiple samples {θ^{1:k}_j}_{j=1}^M from q(θ) via T(θ^i) = θ^i + εψ(θ^i; D). The expectation wrt q(Z) is approximated via samples drawn from q(Z). Similarly, we can employ a complementary form of Corollary 3.1 and Corollary 3.2 to update multiple samples {Z^{1:k}_j}_{j=1}^M from q(Z). This suggests an importance-weighted learning procedure that alternates between updates of the particles {θ^{1:k}_j}_{j=1}^M and {Z^{1:k}_j}_{j=1}^M, similar to the one in Section 2.2. Detailed update equations are provided in Appendix B.

4 Semi-Supervised Learning with Stein VAE

Consider labeled data as pairs D_l = {x_n, y_n}_{n=1}^{N_l}, where the label y_n ∈ {1, . . . , C} and the decoder is modeled as (x_n, y_n|z_n) ∼ p(x, y|z_n; θ, θ̃) = p(x|z_n; θ) p(y|z_n; θ̃), where θ̃ represents the parameters of the decoder for labels. The set of codes associated with all labeled data is represented as Z_l = {z_n}_{n=1}^{N_l}. We desire to approximate the posterior distribution on the entire dataset, p(θ, θ̃, Z, Z_l|D, D_l), via samples, where D represents the unlabeled data and Z is the set of codes associated with D. In the following, we only discuss how to update the samples of θ, θ̃ and Z_l; updating the samples Z is the same as discussed in Sections 2 and 3.2 for Stein VAE and Stein VIWAE, respectively.

Assume samples {θ_j}_{j=1}^M drawn from distribution q(θ), samples {θ̃_j}_{j=1}^M drawn from distribution q(θ̃), and samples {z_jn}_{j=1}^M drawn from (distinct) distribution q(Z_l). The following corollary generalizes Theorem 1 and (4), and is useful for defining how to best update {θ_j}_{j=1}^M.

Corollary 3.3 Assume θ, θ̃, Z and Z_l are RVs drawn from distributions q(θ), q(θ̃), q(Z) and q(Z_l), respectively. Consider the transformation T(θ) = θ + εψ(θ; D, D_l), where ψ(θ; D, D_l) lives in an RKHS with kernel k_θ(·,·). Let qT(θ) represent the distribution of θ' = T(θ). For qT = qT(θ) q(Z) q(θ̃) and p = p(θ, θ̃, Z|D, D_l), we have

∇_ε ( KL(qT ‖ p) ) |_{ε=0} = −E_{θ ∼ q(θ)} ( A_p(θ; D, D_l) ) ,   (13)

where A_p(θ; D, D_l) = ∇_θ ψ(θ; D, D_l) + ∇_θ log p̃(θ; D, D_l) ψ(θ; D, D_l)^T, with log p̃(θ; D, D_l) = E_{Z ∼ q(Z)}[log p(D|Z, θ)] + E_{Z_l ∼ q(Z_l)}[log p(D_l|Z_l, θ)], and the solution for ψ(θ; D, D_l) that maximizes the change in the KL distance (13) is

ψ*(·; D, D_l) = E_{q(θ)} [ k(θ, ·) ∇_θ log p̃(θ; D, D_l) + ∇_θ k(θ, ·) ] .   (14)

Further details are provided in Appendix C.

5 Experiments

For all experiments, we use a radial basis-function (RBF) kernel as in [15], i.e., k(x, x') = exp(−(1/h)‖x − x'‖²), where the bandwidth h is the median of the pairwise distances between current samples. q0(θ) and q0(ξ) are set to isotropic Gaussian distributions. We share the samples of ξ across data points, i.e., ξ_jn = ξ_j for n = 1, . . . , N (this is not necessary, but it saves computation). The samples of θ and z, and the parameters of the recognition model, η, are optimized via Adam [9] with learning rate 0.0002. We do not perform any dataset-specific tuning or regularization other than dropout [32] and early stopping on validation sets. We set M = 100 and k = 50, and use minibatches of size 64 for all experiments, unless otherwise specified.

5.1 Expressive power of Stein recognition model

Gaussian Mixture Model We synthesize data by (i) drawing z_n ∼ (1/2) N(μ_1, I) + (1/2) N(μ_2, I), where μ_1 = [5, 5]^T and μ_2 = [−5, −5]^T; (ii) drawing x_n ∼ N(θ z_n, σ² I), where θ = [2, −1; 1, −2] and σ = 0.1. 
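The generative process just described can be written down directly; a minimal sketch of the synthetic-data generator (the function name and RNG plumbing are illustrative conveniences, not the paper's code):

```python
import numpy as np

def synthesize_gmm_data(n, rng):
    """Draw (x_n, z_n) pairs following Section 5.1: latent codes from an
    equal-weight mixture of two isotropic Gaussians, observations from a
    linear-Gaussian decoder x | z ~ N(theta z, sigma^2 I)."""
    mu = np.array([[5.0, 5.0], [-5.0, -5.0]])        # mu_1, mu_2
    theta = np.array([[2.0, -1.0], [1.0, -2.0]])     # decoder loading matrix
    sigma = 0.1
    comps = rng.integers(0, 2, size=n)               # mixture assignments
    z = mu[comps] + rng.standard_normal((n, 2))      # z ~ N(mu_c, I)
    x = z @ theta.T + sigma * rng.standard_normal((n, 2))
    return x, z
```

Under this θ, codes near μ_1 = (5, 5) map to observations near θμ_1 = (5, −5) and codes near μ_2 to (−5, 5), so the two components remain well separated in observation space.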
The recognition model f_η(x_n, ξ_j) is specified as a multi-layer perceptron (MLP) with 100 hidden units, by first concatenating ξ_j and x_n into a long vector. The dimension of ξ_j is set to 2. The recognition model for the standard VAE is also an MLP with 100 hidden units, and with the assumption of a Gaussian distribution for the latent codes [11].

Figure 1: Approximation of posterior distribution: Stein VAE vs. VAE. The figures represent Stein VAE with different numbers of samples: (left) 10 samples, (center) 50 samples, and (right) 100 samples.

We generate N = 10,000 data points for training and 10 data points for testing. The analytic form of the true posterior distribution is provided in Appendix D. Figure 1 shows the performance of Stein VAE approximations for the true posterior; other similar examples are provided in Appendix F. The Stein recognition model is able to capture the multi-modal posterior and produce accurate density approximation.

Poisson Factor Analysis Given a discrete vector x_n ∈ Z_+^P, Poisson factor analysis [36] assumes x_n is a weighted combination of V latent factors, x_n ∼ Pois(θ z_n), where θ ∈ R_+^{P×V} is the factor-loadings matrix and z_n ∈ R_+^V is the vector of factor scores. We consider topic modeling with Dirichlet priors on θ_v (the v-th column of θ) and gamma priors on each component of z_n.

We evaluate our model on the 20 Newsgroups dataset containing N = 18,845 documents with a vocabulary of P = 2,000. The data are partitioned into 10,314 training, 1,000 validation and 7,531 test documents. The number of factors (topics) is set to V = 128. θ is first learned by Markov chain Monte Carlo (MCMC) [4]. 
We then fix θ at its MAP value, and only learn the recognition model η using the standard VAE and Stein VAE; this is done, as in the previous example, to examine the accuracy of the recognition model in estimating the posterior of the latent factors, isolated from estimation of θ. The recognition model is an MLP with 100 hidden units.

An analytic form of the true posterior distribution p(z_n|x_n) is intractable for this problem. Consequently, we employ samples collected from MCMC as ground truth. With θ fixed, we sample z_n via Gibbs sampling, using 2,000 burn-in iterations followed by 2,500 collection draws, retaining every 10th collection sample. We show the marginal and pairwise posterior of one test data point in Figure 2. Additional results are provided in Appendix F. Stein VAE leads to a more accurate approximation than standard VAE, compared to the MCMC samples. Considering Figure 2, note that VAE significantly underestimates the variance of the posterior (examining the marginals), a well-known problem of variational Bayesian analysis [7]. In sharp contrast, Stein VAE yields highly accurate approximations to the true posterior.

Figure 2: Univariate marginals and pairwise posteriors. Purple, red and green represent the distribution inferred from MCMC, standard VAE and Stein VAE, respectively.

Table 1: Negative log-likelihood (NLL) on MNIST. †Trained with VAE and tested with IWAE. ‡Trained and tested with IWAE.

Method | NLL
DGLM [27] | 89.90
Normalizing flow [28] | 85.10
VAE + IWAE [1]† | 86.76
IWAE + IWAE [1]‡ | 84.78
Stein VAE + ELBO | 85.21
Stein VAE + S-ELBO | 84.98
Stein VIWAE + ELBO | 83.01
Stein VIWAE + S-ELBO | 82.88

5.2 Density estimation

Data We consider five benchmark datasets: MNIST and four text corpora: 20 Newsgroups (20News), New York Times (NYT), Science and RCV1-v2 (RCV2). For MNIST, we used the standard split of 50K training, 10K validation and 10K test examples. The latter three text corpora consist of 133K, 166K and 794K documents. These three datasets are split into 1K validation, 10K testing and the rest for training.

Evaluation Given new data x* (testing data), the marginal log-likelihood/perplexity values are estimated by the variational evidence lower bound (ELBO) while integrating out the decoder parameters θ: log p(x*) ≥ E_{q(z*)}[log p(x*, z*)] + H(q(z*)) = ELBO(q(z*)), where log p(x*, z*) ≈ E_{q(θ)}[log p(x*, θ, z*)] and H(q(·)) = −E_q(log q(·)) is the entropy. The expectation is approximated with samples {θ_j}_{j=1}^M and {z*_j}_{j=1}^M, with z*_j = f_η(x*, ξ_j), ξ_j ∼ q0(ξ). Directly evaluating q(z*) is intractable; it is thus estimated via the density transformation q(z) = q0(ξ) |det( ∂f_η(x, ξ)/∂ξ )|^{-1}.

We further estimate the marginal log-likelihood/perplexity values via the stochastic variational lower bound, as the mean of a 5K-sample importance-weighted estimate [1]. Therefore, for each dataset, we report four results: (i) Stein VAE + ELBO, (ii) Stein VAE + S-ELBO, (iii) Stein VIWAE + ELBO, and (iv) Stein VIWAE + S-ELBO; the first term denotes the training procedure (Stein VAE of Section 2 or Stein VIWAE of Section 3), and the second term denotes whether the testing log-likelihood/perplexity is estimated by the ELBO or by the stochastic variational lower bound, S-ELBO [1].

Table 2: Test perplexities on four text corpora.

Method | 20News | NYT | Science | RCV2
DocNADE [14] | 896 | 2496 | 1725 | 742
DEF [24] | — | 2416 | 1576 | —
NVDM [17] | 852 | — | — | 550
Stein VAE + ELBO | 849 | 2402 | 1499 | 549
Stein VAE + S-ELBO | 845 | 2401 | 1497 | 544
Stein VIWAE + ELBO | 837 | 2315 | 1453 | 523
Stein VIWAE + S-ELBO | 829 | 2277 | 1421 | 518

Model For MNIST, we train the model with one stochastic layer, z_n, with 50 hidden units and two deterministic layers, each with 200 units. The nonlinearity is set as tanh. The visible layer, x_n, follows a Bernoulli distribution. For the text corpora, we build a three-layer deep Poisson network [24]. The sizes of the hidden units are 200, 200 and 50 for the first, second and third layers, respectively (see [24] for detailed architectures).

Results The log-likelihood/perplexity results are summarized in Tables 1 and 2. On MNIST, our Stein VAE achieves a variational lower bound of -85.21 nats, which outperforms the standard VAE with the same model architecture. Our Stein VIWAE achieves a log-likelihood of -82.88 nats, exceeding normalizing flow (-85.1 nats) and the importance weighted autoencoder (-84.78 nats), which is the best prior result obtained by a feed-forward neural network (FNN). DRAW [5] and PixelRNN [20], which exploit spatial structure, achieved log-likelihoods of around -80 nats. Our approach can also be applied within these models, but this is left as interesting future work. 
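The entropy term in the evaluation above relies on the change-of-variables identity q(z) = q0(ξ) |det ∂f_η(x, ξ)/∂ξ|^{-1}. As a sanity check of that identity, the sketch below uses an affine map standing in for f_η (an assumption made only so the transformed density has a closed form; the actual recognition model is a neural network) and verifies it against the corresponding Gaussian:

```python
import numpy as np

def log_q_affine(xi, A, b):
    """log q(z) at z = A xi + b with xi ~ N(0, I), via the test-time
    change of variables q(z) = q0(xi) |det(df/dxi)|^{-1}."""
    d = xi.shape[0]
    log_q0 = -0.5 * (d * np.log(2 * np.pi) + xi @ xi)   # standard-normal base
    _, logdet = np.linalg.slogdet(A)                    # Jacobian of affine map
    return log_q0 - logdet

def log_gauss(z, mean, cov):
    """Closed-form log N(z; mean, cov), used only for verification."""
    d = z.shape[0]
    diff = z - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (d * np.log(2 * np.pi) + logdet
                   + diff @ np.linalg.solve(cov, diff))
```

For z = Aξ + b one has z ∼ N(b, AAᵀ), and `log_q_affine` agrees with `log_gauss(z, b, A @ A.T)` to machine precision.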
To further illustrate the benefit of model averaging, we vary the number of samples for θ (while retaining 100 samples for Z) and show the results, with the associated training/testing time, in Figure 3. When M = 1 for θ, our model reduces to a point estimate for that parameter. Increasing the number of samples of θ (model averaging) improves the negative log-likelihood (NLL). The testing time when using 100 samples of θ is around 0.12 ms per image.

Figure 3: NLL vs. training/testing time on MNIST with various numbers of samples for θ.

5.3 Semi-supervised Classification

We consider semi-supervised classification on MNIST and ImageNet [29] data. For each dataset, we report the results obtained by (i) VAE, (ii) Stein VAE, and (iii) Stein VIWAE.

MNIST We randomly split the training set into a labeled and an unlabeled set, and the number of labeled samples in each category varies from 10 to 300. We perform testing on the standard test set with 20 different training-set splits. The decoder for labels is implemented as p(y_n|z_n, θ̃) = softmax(θ̃ z_n). 
We consider two types of decoders for images p(xn|zn, \u03b8) and encoder f \u03b7(x, \u03be):\n\n7\n\n123456858687881510204060100200300Number of Samples (M)Time (s)Negative Log\u2212likelihood (nats)Negative Log\u2212likelihoodTesting Time for Entire DatasetTraining Time for Each Epoch\fN\u03c1\n\n10\n60\n100\n300\n\nVAE\u00a7\n\n3.33 \u00b1 0.14\n2.59 \u00b10.05\n2.40 \u00b10.02\n2.18 \u00b10.04\n\nVAE\n\nthis extension as future work.\n\nStein VIWAE\n1.90 \u00b1 0.05\n1.41 \u00b1 0.02\n0.99 \u00b1 0.02\n0.86 \u00b1 0.01\n\nStein VIWAE\n2.67 \u00b1 0.09\n2.09 \u00b1 0.03\n1.88 \u00b1 0.01\n1.75 \u00b1 0.01\n\nVAE\u2020\n\n2.44 \u00b1 0.17\n1.88 \u00b10.05\n1.47 \u00b10.02\n0.98 \u00b10.02\n\nFNN\n\nStein VAE\n2.78 \u00b1 0.24\n2.13 \u00b1 0.08\n1.92 \u00b1 0.05\n1.77 \u00b1 0.03\n\nCNN\n\nStein VAE\n1.94 \u00b1 0.24\n1.44 \u00b1 0.04\n1.01 \u00b1 0.03\n0.89 \u00b1 0.03\n\nTable 4: Semi-supervised classi\ufb01cation accuracy (%) on ImageNet.\n\nindicating the former produces more robust parameter estimates.\n\nState-of-\nresults [26] are achieved by the Ladder network, which can be employed with\n\nTable 3: Semi-supervised classi\ufb01cation error (%) on MNIST. N\u03c1 is the number\nof labeled images per class. \u00a7[12]; \u2020our implementation.\n\n(i) FNN: Following [12], we use a 50-dimensional latent variables zn and two hidden layers, each\nwith 600 hidden units, for both encoder and decoder; softplus is employed as the nonlinear activation\nfunction. (ii) All convolutional nets (CNN): Inspired by [31], we replace the two hidden layers with\n32 and 64 kernels of size 5 \u00d7 5 and a stride of 2. A fully connected layer is stacked on the CNN to\nproduce a 50-dimensional latent variables zn. We use the leaky recti\ufb01ed activation [16]. The input\nof the encoder is formed by spatially aligning and stacking xn and \u03be, while the output of decoder is\nthe image itself.\nTable 3 shows the classi-\n\ufb01cation results. 
Our Stein\nVAE and Stein VIWAE\nconsistently achieve bet-\nter performance than the\nVAE. We further observe\nthat the variance of Stein\nVIWAE results is much\nsmaller than that of Stein\nVAE results on small la-\nbeled data,\nthe-art\nour Stein-based approach, however, we will consider\nImageNet\n2012 We\nconsider scalability of our\nmodel to large datasets.\nWe split the 1.3 million\ntraining images into an\nunlabeled and labeled set,\nand vary the proportion\nof labeled images from\n1% to 40%. The classes\nare balanced to ensure\nthat no particular class\nis over-represented, i.e., the ratio of labeled and unlabeled images is the same for each class. We\nrepeat the training process 10 times for the training setting with labeled images ranging from 1% to\n10% , and 5 times for the the training setting with labeled images ranging from 20% to 40%. Each\ntime we utilize different sets of images as the unlabeled ones.\nWe employ an all convolutional net [31] for both the encoder and decoder, which replaces determin-\nistic pooling (e.g., max-pooling) with stridden convolutions. Residual connections [8] are incorpo-\nrated to encourage gradient \ufb02ow. The model architecture is detailed in Appendix E. Following [13],\nimages are resized to 256 \u00d7 256. A 224 \u00d7 224 crop is randomly sampled from the images or its\nhorizontal \ufb02ip with the mean subtracted [13]. We set M = 20 and k = 10.\nTable 4 shows classi\ufb01cation results indicating that Stein VAE and Stein IVWAE outperform VAE\nin all the experiments, demonstrating the effectiveness of our approach for semi-supervised classi-\n\ufb01cation. 
When the proportion of labeled examples is too small (< 10%), DGDN [21] outperforms all of the VAE-based models. This is not surprising, given that our models are deeper and thus have considerably more parameters than DGDN [21], leaving them more prone to overfitting when few labels are available.

Table 4: Semi-supervised classification accuracy (%) on ImageNet.

Proportion    VAE             Stein VAE       Stein VIWAE     DGDN [21]
1 %           35.92 ± 1.91    36.44 ± 1.66    36.91 ± 0.98    43.98 ± 1.15
2 %           40.15 ± 1.52    41.71 ± 1.14    42.57 ± 0.84    46.92 ± 1.11
5 %           44.27 ± 1.47    46.14 ± 1.02    46.20 ± 0.52    47.36 ± 0.91
10 %          46.92 ± 1.02    47.83 ± 0.88    48.67 ± 0.31    48.41 ± 0.76
20 %          50.43 ± 0.41    51.62 ± 0.24    51.77 ± 0.12    51.51 ± 0.28
30 %          53.24 ± 0.33    55.02 ± 0.22    55.45 ± 0.11    54.14 ± 0.12
40 %          56.89 ± 0.11    58.17 ± 0.16    58.21 ± 0.12    57.34 ± 0.18

6 Conclusion

We have employed SVGD to develop a new method for learning a variational autoencoder, in which we need not specify an a priori form for the encoder distribution. Fast inference is achieved by learning a recognition model that mimics the manner in which the inferred code samples are manifested. The method is further generalized and improved by performing importance sampling. An extensive set of results, for unsupervised and semi-supervised learning, demonstrates excellent performance and scaling to large datasets.

Acknowledgements

This research was supported in part by ARO, DARPA, DOE, NGA, ONR and NSF.

References

[1] Y. Burda, R. Grosse, and R. Salakhutdinov. Importance weighted autoencoders. In ICLR, 2016.

[2] L. Chen, S. Dai, Y. Pu, C. Li, Q. Su, and L. Carin. Symmetric variational autoencoder and connections to adversarial learning. arXiv, 2017.

[3] Y. Feng, D. Wang, and Q. Liu. Learning to draw samples with amortized Stein variational gradient descent. In UAI, 2017.

[4] Z. Gan, C. Chen, R. Henao, D. Carlson, and L. Carin. Scalable deep Poisson factor analysis for topic modeling.
In ICML, 2015.

[5] K. Gregor, I. Danihelka, A. Graves, and D. Wierstra. DRAW: A recurrent neural network for image generation. In ICML, 2015.

[6] J. Han and Q. Liu. Stein variational adaptive importance sampling. In UAI, 2017.

[7] S. Han, X. Liao, D. B. Dunson, and L. Carin. Variational Gaussian copula inference. In AISTATS, 2016.

[8] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.

[9] D. Kingma and J. Ba. Adam: A method for stochastic optimization. In ICLR, 2015.

[10] D. P. Kingma, T. Salimans, R. Jozefowicz, X. Chen, I. Sutskever, and M. Welling. Improving variational inference with inverse autoregressive flow. In NIPS, 2016.

[11] D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In ICLR, 2014.

[12] D. P. Kingma, D. J. Rezende, S. Mohamed, and M. Welling. Semi-supervised learning with deep generative models. In NIPS, 2014.

[13] A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural networks. In NIPS, 2012.

[14] H. Larochelle and S. Lauly. A neural autoregressive topic model. In NIPS, 2012.

[15] Q. Liu and D. Wang. Stein variational gradient descent: A general purpose Bayesian inference algorithm. In NIPS, 2016.

[16] A. L. Maas, A. Y. Hannun, and A. Y. Ng. Rectifier nonlinearities improve neural network acoustic models. In ICML, 2013.

[17] Y. Miao, L. Yu, and P. Blunsom. Neural variational inference for text processing. In ICML, 2016.

[18] A. Mnih and K. Gregor. Neural variational inference and learning in belief networks. In ICML, 2014.

[19] A. Mnih and D. J. Rezende. Variational inference for Monte Carlo objectives. In ICML, 2016.

[20] A. Oord, N. Kalchbrenner, and K. Kavukcuoglu. Pixel recurrent neural networks. In ICML, 2016.

[21] Y. Pu, Z. Gan, R. Henao, X. Yuan, C. Li, A. Stevens, and L. Carin.
Variational autoencoder for deep learning of images, labels and captions. In NIPS, 2016.

[22] Y. Pu, X. Yuan, and L. Carin. Generative deep deconvolutional learning. In ICLR workshop, 2015.

[23] Y. Pu, X. Yuan, A. Stevens, C. Li, and L. Carin. A deep generative deconvolutional image model. In AISTATS, 2016.

[24] R. Ranganath, L. Tang, L. Charlin, and D. M. Blei. Deep exponential families. In AISTATS, 2015.

[25] R. Ranganath, D. Tran, and D. M. Blei. Hierarchical variational models. In ICML, 2016.

[26] A. Rasmus, M. Berglund, M. Honkala, H. Valpola, and T. Raiko. Semi-supervised learning with ladder networks. In NIPS, 2015.

[27] D. J. Rezende, S. Mohamed, and D. Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In ICML, 2014.

[28] D. J. Rezende and S. Mohamed. Variational inference with normalizing flows. In ICML, 2015.

[29] O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein, A. C. Berg, and L. Fei-Fei. ImageNet large scale visual recognition challenge. IJCV, 2014.

[30] D. Shen, Y. Zhang, R. Henao, Q. Su, and L. Carin. Deconvolutional latent-variable model for text sequence matching. arXiv, 2017.

[31] J. T. Springenberg, A. Dosovitskiy, T. Brox, and M. Riedmiller. Striving for simplicity: The all convolutional net. In ICLR workshop, 2015.

[32] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. JMLR, 2014.

[33] P. Vincent, H. Larochelle, I. Lajoie, Y. Bengio, and P.-A. Manzagol. Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion. JMLR, 2010.

[34] Y. Pu, W. Wang, R. Henao, L. Chen, Z. Gan, C. Li, and L. Carin. Adversarial symmetric variational autoencoder. In NIPS, 2017.

[35] Y.
Zhang, D. Shen, G. Wang, Z. Gan, R. Henao, and L. Carin. Deconvolutional paragraph representation learning. In NIPS, 2017.

[36] M. Zhou, L. Hannah, D. Dunson, and L. Carin. Beta-negative binomial process and Poisson factor analysis. In AISTATS, 2012.