{"title": "A New Distribution on the Simplex with Auto-Encoding Applications", "book": "Advances in Neural Information Processing Systems", "page_first": 13670, "page_last": 13680, "abstract": "We construct a new distribution for the simplex using the Kumaraswamy distribution and an ordered stick-breaking process. We explore and develop the theoretical properties of this new distribution and prove that it exhibits symmetry (exchangeability) under the same conditions as the well-known Dirichlet. Like the Dirichlet, the new distribution is adept at capturing sparsity but, unlike the Dirichlet, has an exact and closed form reparameterization--making it well suited for deep variational Bayesian modeling. We demonstrate the distribution's utility in a variety of semi-supervised auto-encoding tasks. In all cases, the resulting models achieve competitive performance commensurate with their simplicity, use of explicit probability models, and abstinence from adversarial training.", "full_text": "A New Distribution on the Simplex\nwith Auto-Encoding Applications\n\nAndrew Stirn\u2217, Tony Jebara\u2020, David A Knowles\u2021\n\nDepartment of Computer Science\n\nColumbia University\nNew York, NY 10027\n\n{andrew.stirn,jebara,daknowles}@cs.columbia.edu\n\nAbstract\n\nWe construct a new distribution for the simplex using the Kumaraswamy dis-\ntribution and an ordered stick-breaking process. We explore and develop the\ntheoretical properties of this new distribution and prove that it exhibits symmetry\n(exchangeability) under the same conditions as the well-known Dirichlet. Like\nthe Dirichlet, the new distribution is adept at capturing sparsity but, unlike the\nDirichlet, has an exact and closed form reparameterization\u2013making it well suited\nfor deep variational Bayesian modeling. We demonstrate the distribution\u2019s utility\nin a variety of semi-supervised auto-encoding tasks. 
In all cases, the resulting models achieve competitive performance commensurate with their simplicity, use of explicit probability models, and abstinence from adversarial training.

1 Introduction

The Variational Auto-Encoder (VAE) [12] is a computationally efficient approach for performing variational inference [11, 27], since it avoids per-data-point variational parameters through the use of an inference network with shared global parameters. For models where stochastic gradient variational Bayes requires Monte Carlo estimates in lieu of closed-form expectations, [23, 12] note that low-variance estimators can be calculated from gradients of samples with respect to the variational parameters that describe their generating densities. In the case of the normal distribution, such gradients are straightforward to obtain via an explicit, tractable reparameterization, often referred to as the “reparameterization trick”. Unfortunately, most distributions do not admit such a convenient reparameterization. Computing low-bias and low-variance stochastic gradients is an active area of research, with a detailed breakdown of current methods presented in [4]. Of particular interest in Bayesian modeling is the well-known Dirichlet distribution, which often serves as a conjugate prior for latent categorical variables. Perhaps the most desirable property of a Dirichlet prior is its ability to induce sparsity by concentrating mass towards the corners of the simplex. In this work, we develop a surrogate distribution for the Dirichlet that offers explicit, tractable reparameterization, captures sparsity, and has barycentric symmetry (exchangeability) properties equivalent to those of the Dirichlet.

Generative processes can be used to infer missing class labels in semi-supervised learning. 
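For the normal distribution, the reparameterization trick mentioned above amounts to writing z = μ + σ·ε with parameter-free noise ε ∼ N(0, I). A minimal NumPy sketch (the function and variable names are illustrative, not from the paper's codebase):

```python
import numpy as np

def reparameterized_normal_sample(mu, log_sigma, rng):
    """Draw z ~ N(mu, exp(log_sigma)^2) as a deterministic function of the
    variational parameters (mu, log_sigma) and parameter-free noise eps."""
    eps = rng.standard_normal(size=np.shape(mu))  # randomness carries no parameters
    return mu + np.exp(log_sigma) * eps           # differentiable in mu and log_sigma

rng = np.random.default_rng(0)
samples = reparameterized_normal_sample(np.full(100_000, 2.0), np.zeros(100_000), rng)
```

Because ε is drawn from a fixed distribution, the sample is a smooth function of the variational parameters, which is exactly what an automatic-differentiation framework needs to backpropagate through the sampling step.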
The first VAE-based method that used deep generative models for semi-supervised learning derived two variational objectives for the same generative process, one for when labels are observed and one for when labels are latent, that are jointly optimized [13]. As they note, however, the variational distribution over class labels appears only in the objective for unlabeled data. Its absence from the labeled-data objective, as they point out, results from their lack of a (Dirichlet) prior on the (latent) labels.

* jointly affiliated with the New York Genome Center
† jointly affiliated with Spotify Technology S.A.
‡ jointly affiliated with Columbia University's Data Science Institute and the New York Genome Center

33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.

We suspect they neglected to specify this prior because, at the time, it would have rendered inference intractable. They ameliorate this shortcoming by introducing a discriminative third objective, the cross-entropy of the variational distribution over class labels, which they compute over the labeled data. They then jointly optimize the two variational objectives after adding a scaled version of the cross-entropy term. Our work builds on [13] while offering some key improvements. First, we remove the need for an additional discriminative loss through our use of a Dirichlet prior. We overcome intractability using our proposed distribution as an approximation for the Dirichlet posterior. Naturally, our generative process is slightly different, but it allows us to consider only unmodified variational objectives. Second, we do not stack models together. Kingma et al.'s best results utilized a standard VAE (M1) to learn a latent space upon which their semi-supervised VAE (M2) was fit. 
For SVHN data, they perform dimensionality reduction with PCA prior to fitting M1. We abandon the stacked-model approach in favor of training a single model with more expressive recognition and generative networks. Also, we use minimal preprocessing (rescaling pixel intensities to [0, 1]).

Use of the Kumaraswamy distribution [14] by the machine learning community has only occurred in the last few years. It has been used to fit Gaussian Mixture Models, for which a Dirichlet prior is part of the generative process, with VAEs [19]. To sample mixture weights from the variational posterior, they recognize they can decompose a Dirichlet into its stick-breaking Beta distributions and approximate them with the Kumaraswamy distribution. We too employ the same stick-breaking decomposition coupled with Kumaraswamy approximations. However, we improve on this technique by expounding and resolving the order dependence their approximation incurs. As we detail in section 2, using the Kumaraswamy for stick-breaking is not order agnostic (exchangeable); the generated variable has a density that depends on ordering. We leverage the observation that one can permute a Dirichlet's parameters, perform the stick-breaking sampling procedure with Beta distributions, and undo the permutation on the sampled variable without affecting its density. Those same authors also use this Beta-Kumaraswamy stick-breaking approximation to fit a Bayesian non-parametric model with a VAE [20]. Here too, they do not account for the impact ordering has on their approximation. Their latent space, being non-parametric, grows in dimensions when it insufficiently represents the data. As we demonstrate in section 2.2 and fig. 1, approximating sparse Dirichlet samples with the Kumaraswamy stick-breaking decomposition without accounting for the ordering dependence produces a large bias in the samples' last dimension. 
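The ordering bias described above (and shown in fig. 1) can be reproduced with a short simulation. The sketch below, which is our own illustration rather than the paper's code, places the Kumaraswamy inverse-CDF sampler inside a stick-breaking loop and compares per-dimension sample means under a fixed order versus random orderings (a coarser summary than the paper's histograms), using a sparsity-inducing α = (1/5)(1, . . . , 1):

```python
import numpy as np

def kumaraswamy_sample(a, b, rng):
    """Inverse-CDF sampling: T(u) = (1 - (1 - u)^(1/b))^(1/a), u ~ Uniform(0, 1)."""
    u = rng.uniform()
    return (1.0 - (1.0 - u) ** (1.0 / b)) ** (1.0 / a)

def kumaraswamy_stick_break(alpha, order, rng):
    """Generalized stick-breaking with Kumaraswamy sticks and a given ordering."""
    K = len(alpha)
    x = np.zeros(K)
    remaining = 1.0
    for i in range(K - 1):
        a = alpha[order[i]]
        b = alpha[order[i + 1:]].sum()  # mass of the not-yet-broken sticks
        v = kumaraswamy_sample(a, b, rng)
        x[order[i]] = v * remaining
        remaining -= x[order[i]]
    x[order[-1]] = remaining  # deterministic final stick
    return x

alpha = np.full(5, 0.2)  # sparsity-inducing, as in fig. 1
rng = np.random.default_rng(0)
n = 20_000
fixed_mean = np.mean([kumaraswamy_stick_break(alpha, np.arange(5), rng)
                      for _ in range(n)], axis=0)
random_mean = np.mean([kumaraswamy_stick_break(alpha, rng.permutation(5), rng)
                       for _ in range(n)], axis=0)
# Under the matching Dirichlet every dimension has mean 0.2; with a fixed
# order the last dimension's mean is visibly biased away from 0.2, while
# random orderings recover (by symmetry) a mean of 0.2 in every dimension.
```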
We conjecture that their Bayesian non-parametric model would utilize fewer dimensions with our proposed distribution; investigating this would be an interesting follow-up to our work.

Figure 1: Sampling bias for a 5-dimensional sparsity-inducing Dirichlet approximation using α = (1/5)(1, 1, 1, 1, 1). We maintain histograms for each sample dimension for three methods: Dirichlet, Kumaraswamy stick-breaks with a fixed order, and Kumaraswamy stick-breaks with a random ordering. Note the bias on the last dimension when using a fixed order. Randomizing order eliminates this bias.

2 A New Distribution on the Simplex

The stick-breaking process is a sampling procedure used to generate a K-dimensional random variable in the (K − 1)-simplex. The process requires sampling from K − 1 (often out of K) distributions, each with support over [0, 1]. Let p_i(v; a_i, b_i) be some distribution for v ∈ [0, 1] parameterized by a_i and b_i. Let o be some ordering (permutation) of {1, . . . , K}. Then, algorithm 1 captures a generalized stick-breaking process. The necessity of incorporating ordering will become clear in section 2.1.

Algorithm 1 A Generalized Stick-Breaking Process
Require: K ≥ 2, base distributions p_i(v; a_i, b_i) ∀ i ∈ {1, . . . , K}, and some ordering o
  Sample: v_{o_1} ∼ p_{o_1}(v; a_{o_1}, b_{o_1})
  Assign: x_{o_1} ← v_{o_1}, i ← 2
  while i < K do
    Sample: v_{o_i} ∼ p_{o_i}(v; a_{o_i}, b_{o_i})
    Assign: x_{o_i} ← v_{o_i} (1 − \sum_{j=1}^{i−1} x_{o_j}), i ← i + 1
  end while
  Assign: x_{o_K} ← 1 − \sum_{j=1}^{K−1} x_{o_j}
  return x

From a probabilistic perspective, algorithm 1 recursively creates a joint distribution p(x_{o_1}, . . . , x_{o_{K−1}}) from its chain-rule factors p(x_{o_1}) p(x_{o_2}|x_{o_1}) p(x_{o_3}|x_{o_2}, x_{o_1}) · · · p(x_{o_{K−1}}|x_{o_{K−2}}, . . . , x_{o_1}). Note, however, that x_{o_K} does not appear in the distribution. Its absence occurs because it is deterministic given x_{o_1}, . . . , x_{o_{K−1}} (the K − 1 degrees of freedom of the (K − 1)-simplex). Each iteration of the while loop generates p(x_{o_i}|x_{o_{i−1}}, . . . , x_{o_1}) by sampling p_{o_i}(v; a_{o_i}, b_{o_i}) and applying a change-of-variables transform T_i : [0, 1]^i → [0, 1]^i to the samples collected thus far. This transform and its inverse are

T_i(x_{o_1}, \ldots, x_{o_{i-1}}, v_{o_i}) = \left( x_{o_1}, \ldots, x_{o_{i-1}}, \; v_{o_i} \Big(1 - \sum_{j=1}^{i-1} x_{o_j}\Big) \right) \quad (1)

T_i^{-1}(x_{o_1}, \ldots, x_{o_{i-1}}, x_{o_i}) = \left( x_{o_1}, \ldots, x_{o_{i-1}}, \; x_{o_i} \Big(1 - \sum_{j=1}^{i-1} x_{o_j}\Big)^{-1} \right). \quad (2)

Applying the change-of-variables formula to the conditional distribution generated by a while-loop iteration allows us to formulate the conditional as an expression involving just p_i(v; a_i, b_i), which we assume access to, and \det(J_{T_i^{-1}}), where J_{T_i^{-1}} is the Jacobian of eq. (2):

p(x_{o_i}|x_{o_{i-1}}, \ldots, x_{o_1}) = p(v_{o_i}|x_{o_{i-1}}, \ldots, x_{o_1}) \cdot \det(J_{T_i^{-1}}) = p_{o_i}(v; a_{o_i}, b_{o_i}) \cdot \Big(1 - \sum_{j=1}^{i-1} x_{o_j}\Big)^{-1}.

A common application of the stick-breaking process is to construct a Dirichlet sample from Beta samples. If we wish to sample from Dirichlet(x; α), with α ∈ R_{++}^K, it suffices to assign p_i(v; a_i, b_i) ≡ Beta(x; α_i, \sum_{j=i+1}^K α_j). With this assignment, algorithm 1 will return a Dirichlet-distributed x with density

p(x_{o_1}, \ldots, x_{o_K}; \alpha) = \frac{\Gamma\big(\sum_{i=1}^K \alpha_{o_i}\big)}{\prod_{i=1}^K \Gamma(\alpha_{o_i})} \prod_{i=1}^K x_{o_i}^{\alpha_{o_i} - 1}.

This form requires substituting in algorithm 1's final assignment x_{o_K} ≡ 1 − \sum_{i=1}^{K−1} x_{o_i}. Upon inspection, the Dirichlet distribution is order agnostic (exchangeable). 
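The Beta-stick assignment above can be checked empirically. The following NumPy sketch (ours, not the paper's code) runs algorithm 1 with Beta sticks and compares the sample mean to the known Dirichlet mean α/Σ_i α_i; this is a moment check for illustration, not a proof:

```python
import numpy as np

def dirichlet_via_sticks(alpha, rng):
    """Algorithm 1 with p_i = Beta(alpha_i, sum_{j>i} alpha_j) and the identity
    ordering; the result is distributed Dirichlet(alpha) (theorem 1)."""
    K = len(alpha)
    x = np.zeros(K)
    remaining = 1.0
    for i in range(K - 1):
        v = rng.beta(alpha[i], alpha[i + 1:].sum())  # break off a fraction v
        x[i] = v * remaining
        remaining -= x[i]
    x[-1] = remaining  # final, deterministic stick
    return x

alpha = np.array([2.0, 1.0, 3.0])
rng = np.random.default_rng(0)
dirichlet_mean = np.mean([dirichlet_via_sticks(alpha, rng) for _ in range(20_000)],
                         axis=0)
# The Dirichlet mean is alpha / alpha.sum() = [1/3, 1/6, 1/2]
```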
In other words, given any ordering o, the random variable returned from algorithm 1 can be permuted to (x_1, . . . , x_K) (along with the parameters) without modifying its probability density. This convenience arises from the Beta distribution's form.

Theorem 1. For K ≥ 2 and p_i(v; a_i, b_i) ≡ Beta(x; α_i, \sum_{j=i+1}^K α_j), algorithm 1 returns a random variable whose density is captured via the Dirichlet distribution.

A proof of theorem 1 appears in section 7.1 (appendix). A variation of this proof also appears in [5].

2.1 The Kumaraswamy distribution

The Kumaraswamy(a, b) [14], a Beta-like distribution, has two parameters a, b > 0 and support for x ∈ [0, 1], with PDF f(x; a, b) = a b x^{a−1} (1 − x^a)^{b−1} and CDF F(x; a, b) = 1 − (1 − x^a)^b. With this analytically invertible CDF, one can reparameterize a sample u from the continuous Uniform(0, 1) via the transform T(u) = (1 − (1 − u)^{1/b})^{1/a} such that T(u) ∼ Kumaraswamy(a, b). Unfortunately, this convenient reparameterization comes at a cost when we derive p(x_{o_1}, . . . , x_{o_K}; α), the density of the variable returned by algorithm 1. If, in a manner similar to generating a Dirichlet sample from Beta samples, we let p_i(v; a_i, b_i) ≡ Kumaraswamy(x; α_i, \sum_{j=i+1}^K α_j), then the resulting variable's density is no longer order agnostic (exchangeable). The exponent on the Kumaraswamy's (1 − x^a) term, which admits analytic inverse-CDF sampling, can no longer cancel out the \det(J_{T_i^{-1}}) terms as the (1 − x) term in the Beta analog could. In the simplest case, the 1-simplex (K = 2), the possible orderings for algorithm 1 are o ∈ O = {{1, 2}, {2, 1}}. 
Indeed, algorithm 1 returns two distinct densities according to their respective orderings:

f_{12}(x; \alpha) = \alpha_1 \alpha_2 \, x_1^{\alpha_1 - 1} x_2^{\alpha_2 - 1} \left( \frac{1 - x_1^{\alpha_1}}{1 - x_1} \right)^{\alpha_2 - 1} \quad (3)

f_{21}(x; \alpha) = \alpha_1 \alpha_2 \, x_1^{\alpha_1 - 1} x_2^{\alpha_2 - 1} \left( \frac{1 - x_2^{\alpha_2}}{1 - x_2} \right)^{\alpha_1 - 1}. \quad (4)

In section 7.2 of the appendix, we derive f_{12} and f_{21} as well as the distribution for the 2-simplex, which has orderings o ∈ O = {{1, 2, 3}, {1, 3, 2}, {2, 1, 3}, {2, 3, 1}, {3, 1, 2}, {3, 2, 1}}. For K > 3, the algebraic book-keeping gets rather involved. We thus rely on algorithm 1 to succinctly represent the complicated densities over the simplex that describe the random variables generated by a stick-breaking process using the Kumaraswamy distribution as the base (stick-breaking) distribution. Our code repository§ contains a symbolic implementation of algorithm 1 with the Kumaraswamy that programmatically keeps track of the algebra.

2.2 The multivariate Kumaraswamy

We posit that a good surrogate for the Dirichlet will exhibit symmetry (exchangeability) properties identical to the Dirichlet it is approximating. If our stick-breaking distribution, p_i(v; a_i, b_i), cannot achieve symmetry for all values a_i = b_i > 0, then it is possible that the samples will exhibit bias (fig. 1). If x ∼ Beta(a, b), then (1 − x) ∼ Beta(b, a). It follows that when a = b, p(x) = p(1 − x). Unfortunately, Kumaraswamy(a, b) does not admit such symmetry for all a = b > 0. However, hope is not lost. From [6, 8], we have lemma 1.

Lemma 1. Given a function f of n variables, one can induce symmetry by taking the sum of f over all n! possible permutations of the variables.

If we define f_o(x_{o_1}, . . . , x_{o_K}; α_{o_1}, . . . , α_{o_K}) to be the joint density of the K-dimensional random variable returned from algorithm 1 with stick-breaking base distribution p_i(v; a_i, b_i) ≡ Kumaraswamy(x; α_i, \sum_{j=i+1}^K α_j) and some ordering o, then our proposed distribution for the (K − 1)-simplex is

\text{MV-Kumaraswamy}(x; \alpha) = \mathbb{E}_{o \sim \text{Uniform}(O)}\big[ f_o(x_{o_1}, \ldots, x_{o_K}; \alpha_{o_1}, \ldots, \alpha_{o_K}) \big], \quad (5)

where MV-Kumaraswamy stands for multivariate Kumaraswamy. Here, O is the set of all possible orderings (permutations) of {1, . . . , K}. In the context of [8], we create a U-statistic over the variables x, α. The expectation in eq. (5) is a summation since we are uniformly sampling o from a discrete set. We can therefore apply lemma 1 to eq. (5) to prove corollary 1.

Corollary 1. Let S ⊆ {1, . . . , K} be the set of indices i where for i ≠ j we have α_i = α_j. Define A = {1, . . . , K} \ S. Then, E_{o∼Uniform(O)}[f_o(x_{o_1}, . . . , x_{o_K}; α_{o_1}, . . . , α_{o_K})] is symmetric across barycentric axes x_a ∀ a ∈ A.

While the factorial growth (|O| = K!) for full symmetry is undesirable, we expect approximate symmetry should arise, in expectation, after O(K) samples. Since the problematic bias occurs during the last stick break, each label ideally experiences an ordering where it is not last; this occurs with probability (K − 1)/K. Thus, a label is not last, in expectation, after K/(K − 1) draws from Uniform(O). Therefore, to satisfy this condition for all labels, one needs K²/(K − 1) = O(K) samples, in expectation. An alternative, which we discuss and demonstrate below in fig. 4, would be to use the K cyclic orderings (e.g. {{1, 2, 3}, {2, 3, 1}, {3, 1, 2}} for K = 3) to achieve approximate symmetry (exchangeability).

§ https://github.com/astirn/MV-Kumaraswamy

In fig. 
2, we provide 1-simplex examples for varying α that demonstrate the effect ordering has on the Kumaraswamy densities f_{12}(x; α) and f_{21}(x; α) (respectively in eqs. (3) and (4)). In each example, we plot the symmetrized versions arising from our proposed distribution E_o[f_o(x; α)] (eq. (5)). For reference, we plot the corresponding Dirichlet(x; α), which is equivalent to Beta(x_1; α_1, α_2) for the 1-simplex. Qualitatively, we observe how effectively our proposed distribution resolves the differences between f_{12} and f_{21} and yields E_o[f_o(x; α)] ≈ Dirichlet(x; α).

Figure 2: Kumaraswamy asymmetry and symmetrization examples on the 1-simplex.

In fig. 3, we employ Beta distributed stick breaks to generate a Dirichlet random variable. In this example, we pick an α such that the resulting density should be symmetric only about the barycentric x_1 axis. Furthermore, because the resulting density is a Dirichlet, the densities arising from all possible orderings should be identical, with identical barycentric symmetry properties. The first row contains densities. The subsequent rows measure asymmetry across the specified barycentric axis by computing the absolute difference of the PDF folded along that axis. The first column is the expectation over all possible orderings. The second column is the expectation over the cyclic orderings. Each column thereafter represents a different stick-breaking order. Indeed, we find that the Dirichlet has an order-agnostic density with symmetry only about the barycentric x_1 axis.

Figure 3: 2-simplex with Beta sticks.

Figure 4: 2-simplex with Kumaraswamy sticks.

In fig. 4, we employ the same methodology with the same α as in fig. 3, but this time we use Kumaraswamy distributed stick breaks. Note the significant variations among the densities resulting from the different orderings. 
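For the 1-simplex, the symmetrization in eq. (5) can be verified directly from eqs. (3) and (4). This sketch (an illustration we add, with hypothetical function names) evaluates f_{12}, f_{21}, and their uniform mixture; when α_1 = α_2, the mixture satisfies g(x_1) = g(1 − x_1) exactly even though each single-ordering density does not:

```python
def f12(x1, a1, a2):
    """Eq. (3): 1-simplex density from Kumaraswamy stick-breaking, order (1, 2)."""
    x2 = 1.0 - x1
    return a1 * a2 * x1**(a1 - 1) * x2**(a2 - 1) * ((1 - x1**a1) / (1 - x1))**(a2 - 1)

def f21(x1, a1, a2):
    """Eq. (4): the same construction under the reversed order (2, 1)."""
    x2 = 1.0 - x1
    return a1 * a2 * x1**(a1 - 1) * x2**(a2 - 1) * ((1 - x2**a2) / (1 - x2))**(a1 - 1)

def mv_kumaraswamy_1simplex(x1, a1, a2):
    """Eq. (5) for K = 2: the uniform mixture over both orderings."""
    return 0.5 * (f12(x1, a1, a2) + f21(x1, a1, a2))
```

With a_1 = a_2 = 2, for example, f12(0.25, 2, 2) = 0.9375 while f12(0.75, 2, 2) = 1.3125 (asymmetric), yet the mixture takes the value 1.125 at both points.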
It follows that symmetry/asymmetry also varies with respect to ordering. We only see the desired symmetry about the barycentric x_1 axis when we take the expectation over all orderings. This example qualitatively illustrates corollary 1. However, we do achieve approximate symmetry when we average over the K cyclic orderings, suggesting we can, in practice, get away with linearly scaling complexity.

3 Gradient Variance

We compare our method's gradient variance to other non-explicit gradient reparameterization methods: Implicit Reparameterization Gradients (IRG) [4], RSVI [18], and Generalized Reparameterization Gradient (GRG) [22]. These works all seek gradient methods with low variance. In fig. 5, we compare MV-Kumaraswamy's (MVK) gradient variance to these other methods by leveraging techniques and code from [18]. Specifically, we consider their test that fits a variational Dirichlet posterior to Categorical data with a Dirichlet prior. In this conjugate setting, true analytic gradients can be computed. Their reported “gradient variance” is actually the mean square error with respect to the true gradient. In our test, however, we are fitting an MV-Kumaraswamy variational posterior. Therefore, we compute gradient variance, for all methods, according to variance's more common definition. Our tests show that IRG and RSVI (B = 10) offer similar variance; this result matches findings in [4].

Figure 5: Variance of the ELBO gradient's first dimension for GRG [22], RSVI [18], IRG [4], and MVK (ours) when fitting a variational posterior to Categorical data with 100 dimensions and a Dirichlet prior. They fit a Dirichlet. 
We fit a MV-Kumaraswamy using K = 100 samples from Uniform(O) to Monte-Carlo approximate the full expectation; this corresponds to linear complexity.

4 A single generative model for semi-supervised learning

We demonstrate the utility of the MV-Kumaraswamy in the context of a parsimonious generative model for semi-supervised learning, with observed data x, partially observable classes/labels y with prior π, and latent variable z, all of which are local to each data point. We specify

π ∼ Dirichlet(π; α),    z ∼ N(z; 0, I),
y|π ∼ Discrete(y; π),    x|y, z ∼ p(x|f_θ(y, z)),

where f_θ(y, z) is a neural network, with parameters θ, operating on the latent variables. For observable y, the evidence lower bound (ELBO) for a mean-field posterior approximation q(π, z) = q(π)q(z) is

\ln p(x, y) \geq \mathbb{E}_{q(\pi,z)}\big[ \ln p(x|f_\theta(y, z)) + \ln \pi_y \big] - D_{KL}(q(\pi) \,||\, p(\pi)) - D_{KL}(q(z) \,||\, p(z)) \equiv \mathcal{L}_l(x, y, \phi, \theta). \quad (6)

For latent y, we can derive an alternative ELBO that corresponds to the same generative process as eq. (6) by reintroducing y via marginalization:

\ln p(x) \geq \mathbb{E}_{q(\pi,z)}\Big[ \ln \sum_y p(x|f_\theta(y, z)) \pi_y \Big] - D_{KL}(q(\pi) \,||\, p(\pi)) - D_{KL}(q(z) \,||\, p(z)) \equiv \mathcal{L}_u(x, \phi, \theta). \quad (7)

We derive eqs. (6) and (7) in section 7.3 of the appendix.

Let L be our set of labeled data and U be our unlabeled set. We then consider a combined objective

\mathcal{L} = \frac{1}{|L|} \sum_{(x,y) \in L} \mathcal{L}_l(x, y, \phi, \theta) + \frac{1}{|U|} \sum_{x \in U} \mathcal{L}_u(x, \phi, \theta) \quad (8)

\approx \frac{1}{B} \sum_{(x_i,y_i) \sim L \;\forall\; i \in [B]} \mathcal{L}_l(x_i, y_i, \phi, \theta) + \frac{1}{B} \sum_{x_i \sim U \;\forall\; i \in [B]} \mathcal{L}_u(x_i, \phi, \theta) \quad (9)

that balances the two ELBOs evenly. Of concern is when |U| ≫ |L|. Here, the optimizer could effectively ignore \mathcal{L}_l(x, y, φ, θ). This possibility motivates our rebalancing in eq. (8). During optimization, we employ batch updates of size B to maximize eq. (9), which similarly balances the contribution between U and L. We define an epoch to be the set of batches (sampled without replacement) that constitute U. Therefore, when |U| ≫ |L|, the optimizer will observe samples from L many more times than samples from U. Intuitively, the data with observable labels, in conjunction with eq. (6), breaks symmetry and encourages the correct assignment of classes to labels.

Following [12, 13], we use an inference network with parameters φ and define our variational distribution q(z) = N(z; μ_φ(x), Σ_φ(x)), where μ_φ(x) and Σ_φ(x) are outputs of a neural network operating on the observable data. We restrict Σ_φ(x) to output a diagonal covariance and use a softplus, ln(exp(x) + 1), output layer to constrain it to the positive reals. Since μ_φ(x) ∈ R^{dim(z)}, we use an affine output layer. We let q(π) = MV-Kumaraswamy(π; α_φ(x)), where α_φ(x) is also an output of our inference network. We similarly restrict α_φ(x) to the positive reals via the softplus activation. We evaluate the expectations in eqs. (6) and (7) using Monte-Carlo integration. 
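The minibatch scheme around eq. (9) can be sketched as an index generator: each step draws B unlabeled indices without replacement (so one epoch is a single pass over U) and B labeled indices from L. How L is resampled here (uniformly with replacement, ignoring the label balancing used in the experiments) is our illustrative assumption, not a detail from the paper:

```python
import numpy as np

def balanced_batches(num_labeled, num_unlabeled, batch_size, rng):
    """Yield (labeled_idx, unlabeled_idx) pairs for eq. (9): every update uses
    B labeled and B unlabeled examples, so L contributes evenly to each step
    even when |U| >> |L|. An epoch is one shuffled pass over U."""
    unlabeled = rng.permutation(num_unlabeled)
    for start in range(0, num_unlabeled - batch_size + 1, batch_size):
        u_idx = unlabeled[start:start + batch_size]
        # oversample the small labeled set (with replacement) every step
        l_idx = rng.choice(num_labeled, size=batch_size, replace=True)
        yield l_idx, u_idx

# e.g. the MNIST split from section 5: |L| = 600, |U| = 49,400, B = 250
batches = list(balanced_batches(600, 49_400, 250, np.random.default_rng(0)))
```

Over one such epoch the optimizer sees every unlabeled example once but revisits the labeled pool many times, which is the rebalancing behavior the text describes.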
For q(z), we sample from N(0, I) and utilize the reparameterization trick. Since q(π) contains an expectation over orderings, we first sample o ∼ Uniform(O) and then employ algorithm 1 with p_i(v; a_i, b_i) ≡ Kumaraswamy(x; α_i, \sum_{j=i+1}^K α_j), for which we use inverse-CDF sampling. In both cases, gradients are well defined with respect to the variational parameters.

We can decompose D_{KL}(MV-Kumaraswamy(α_φ(x)) || Dirichlet(α)) into a sum over the corresponding Kumaraswamy and Beta stick-breaking distributions as in [20]. Let α_φ^{(j)}(x) be the jth concentration parameter of the inference network, and α^{(j)} be the jth parameter of the Dirichlet prior. If, as above, we let p(o) = Uniform(O) for the set of all orderings O, then

D_{KL}(q(\pi; \alpha_\phi(x)) \,||\, p(\pi; \alpha)) = \mathbb{E}_{p(o)}\left[ \sum_{i=1}^{K-1} D_{KL}\Big( \text{Kumaraswamy}\big(\alpha_\phi^{(o_i)}(x), \textstyle\sum_{j=i+1}^K \alpha_\phi^{(o_j)}(x)\big) \;\Big|\Big|\; \text{Beta}\big(\alpha^{(o_i)}, \textstyle\sum_{j=i+1}^K \alpha^{(o_j)}\big) \Big) \right].

We compute D_{KL}(Kumaraswamy(a, b) || Beta(a′, b′)) analytically as in [20] with a Taylor approximation order of 5. We too approximate this expectation with far fewer than K! samples from p(o). Please see section 7.4 of the appendix for a reproduction of this KL-divergence's mathematical form.

5 Experiments

We consider a variety of baselines for our semi-supervised model. Since our work expounds and resolves the order dependence of the original Kumaraswamy stick-breaking construction [20], which uses a fixed and constant ordering, we employ their construction (Kumar-SB) as a baseline, for which we force our implementation to use a fixed and constant order during the stick-breaking procedure. As noted in section 1, our model is similar to the M2 model [13]. We too consider it an important baseline for our semi-supervised experiments. Additionally, we use the Softmax-Dirichlet sampling approximation [25]. 
This approximation forces logits sampled from a Normal variational posterior onto the simplex via the softmax function. In this case, the Dirichlet prior is approximated with a prior on the Gaussian logits [25]. However, this softmax approximation struggles to capture sparsity because the Gaussian prior cannot achieve the multi-modality available to the Dirichlet [22]. Lastly, we include a comparison to Implicit Reparameterization Gradients (IRG) [4]. Here, we set q(π; α_φ(x)) = Dirichlet(π; α_φ(x)) in our semi-supervised model with the same architecture. IRG uses independent Gamma samples to construct Beta and Dirichlet samples. IRG's principal contribution for gradient reparameterization is that it side-steps the need to invert the standardization function (i.e. the CDF). However, IRG still requires Gamma CDF gradients w.r.t. the variational parameters. These gradients do not have a known analytic form, mandating the application of forward-mode automatic differentiation to a numerical method. In our IRG baseline, both the prior and variational posterior are Dirichlet distributions, yielding an analytic KL-divergence. We mention but do not test [9], which similarly constructs Dirichlet samples from normalized Gamma samples. They too employ implicit differentiation to avoid differentiating the inverse CDF, but necessarily fall back to numerically differentiating the Gamma CDF.

Our source code can be found at https://github.com/astirn/MV-Kumaraswamy. For our latest experimental results, please refer to https://arxiv.org/abs/1905.12052. In our generative process and eqs. (6) and (7), we referred generally to our data likelihood as p(x|f_θ(y, z)). In all of our experiments, we assume p(x|f_θ(y, z)) = N(x; μ_θ(y, z), Σ_θ(y, z)), where μ_θ(y, z) and Σ_θ(y, z) are outputs of a neural network with parameters θ operating on the latent variables. 
We use diagonal covariance for Σ_θ(y, z). Across all of our experiments, we maintain consistent recognition and generative network architectures, which we detail in section 7.5 of the appendix.

We do not use any explicit regularization. Our models are implemented in TensorFlow and were trained using ADAM with a batch size B = 250 and 5 Monte-Carlo samples for each training example. We use learning rates 1 × 10⁻³ and 1 × 10⁻⁴ respectively for MNIST and SVHN. Other optimizer parameters were kept at TensorFlow defaults. We utilized GPU acceleration and found that cards with ~8 GB of memory were sufficient. We utilize the TensorFlow Datasets API, from which we source our data. For all experiments, we split our data into 4 subsets: unlabeled training (U) data, labeled training (L) data, validation data, and test data. For MNIST: |U| = 49,400, |L| = 600, |validation| = |test| = 10,000. For SVHN: |U| = 62,257, |L| = 1,000, |validation| = 10,000, |test| = 26,032. When constructing L, we enforce label balancing. We allow all trials to train for a maximum of 750 epochs, but use validation-set performance to enable early stopping whenever the loss (eq. (8)) and classification error have not improved in the previous 15 epochs. All reported metrics were collected from the test set during the validation set's best epoch; we do this independently for classification error and log likelihood. For each trial, all models utilize the same random data split except where noted†. We translate the uint8-encoded pixel intensities to [0, 1] by dividing by 255, but perform no other preprocessing.

Table 1: Held-out test set classification errors and log likelihoods. A “−−” for a p-value indicates it was unavailable, either because it was with respect to itself or because the corresponding data and/or number of trials were missing. 
Since [13] did not report log likelihoods, we did not collect them with our implementation.

MNIST, 10 trials, 600 labels, dim(z) = 0
  Method         | Error         | p-value       | Log Likelihood | p-value
  MV-Kum.        | 0.099 ± 0.011 | −−            | −6.4 ± 6.3     | −−
  IRG [4]        | 0.097 ± 0.008 | 0.72          | −7.8 ± 7.1     | 0.64
  Kumar-SB [20]  | 0.248 ± 0.009 | 1.05 × 10⁻¹⁷  | −6.5 ± 6.3     | 0.95
  Softmax        | 0.093 ± 0.009 | 0.24          | −6.5 ± 6.2     | 0.95

MNIST, 10 trials, 600 labels, dim(z) = 2
  Method         | Error         | p-value       | Log Likelihood | p-value
  MV-Kum.        | 0.043 ± 0.005 | −−            | 45.06 ± 0.92   | −−
  IRG [4]        | 0.044 ± 0.006 | 0.89          | 45.69 ± 0.38   | 0.06
  M2 (ours)      | 0.098 ± 0.014 | 5.37 × 10⁻¹⁰  | Not collected  | −−
  Kumar-SB [20]  | 0.138 ± 0.015 | 1.65 × 10⁻¹³  | 44.33 ± 1.65   | 0.24
  Softmax        | 0.042 ± 0.003 | 0.40          | 45.14 ± 0.73   | 0.82

MNIST, 10 trials, 600 labels, dim(z) = 50
  Method         | Error         | p-value       | Log Likelihood | p-value
  MV-Kum.        | 0.018 ± 0.004 | −−            | 116.58 ± 0.68  | −−
  IRG [4]        | 0.018 ± 0.004 | 0.98          | 116.57 ± 0.43  | 0.97
  M2 (ours)      | 0.020 ± 0.003 | 0.32          | Not collected  | −−
  Kumar-SB [20]  | 0.071 ± 0.008 | 2.58 × 10⁻¹³  | 116.22 ± 0.33  | 0.15
  Softmax        | 0.018 ± 0.003 | 0.87          | 116.24 ± 0.45  | 0.21
  M2† [13]       | 0.049 ± 0.001 | −−            | Not reported   | −−
  M1 + M2† [13]  | 0.026 ± 0.005 | −−            | Not reported   | −−

SVHN, 4 trials, 1000 labels, dim(z) = 50
  Method         | Error         | p-value       | Log Likelihood | p-value
  MV-Kum.        | 0.296 ± 0.014 | −−            | 669.37 ± 0.57  | −−
  IRG [4]        | 0.288 ± 0.008 | 0.38          | 669.84 ± 0.84  | 0.39
  M2 (ours)      | 0.406 ± 0.027 | 3.64 × 10⁻⁴   | Not collected  | −−
  Kumar-SB [20]  | 0.702 ± 0.011 | 7.42 × 10⁻⁹   | 669.44 ± 0.77  | 0.89
  Softmax        | 0.300 ± 0.007 | 0.61          | 669.51 ± 0.72  | 0.78
  M1 + M2† [13]  | 0.360 ± 0.001 | −−            | Not reported   | −−

For the semi-supervised learning 
task, we present classification and reconstruction performances in table 1 using our algorithm as well as the baselines discussed previously. We organize our results by experiment group. All reported p-values are with respect to our MV-Kumaraswamy model's performance for the corresponding dim(z). We write "M2 (ours)" whenever we use the generative process of [13] with our neural network architecture. For a subset of experiments, we present results from [13]; since we do not know how many trials they ran, we cannot compute the corresponding p-values. We recognize that numerous works [21, 1, 26, 15, 10, 7, 24, 2, 16, 17] offer superior performance on these tasks; however, we abstain from reporting those performances whenever the models are not variational Bayesian, use adversarial training, lack explicit generative processes, use architectures vastly larger than ours, or use a different number of labeled examples (≠ 600 for MNIST and ≠ 1000 for SVHN).

In fig. 6, we plot the latent space representation of our MV-Kumaraswamy model for MNIST when dim(z) = 2. Each digit's manifold spans the square with corners (-1.5, -1.5) and (1.5, 1.5), which corresponds to ±1.5 standard deviations from the prior. The only difference in latent encoding between corresponding manifold positions is the label provided to the generative network. Interestingly, the model learns to use z in a qualitatively similar way to represent character transformations across classes.

Figure 6: Latent space for MV-Kumaraswamy model with dim(z) = 2.

6 Discussion

The statistically significant classification performance gains of MV-Kumaraswamy (approximate integration over all orderings) over Kumar-SB [20] (fixed and constant ordering) validate the impact of our contribution.
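These pairwise comparisons can be checked directly from the summary statistics in table 1. Below is a minimal sketch, assuming the reported ± values are sample standard deviations over independent trials and that a two-sided Welch's t-test is used; the paper does not state the exact test, so the computed values need not match table 1 exactly.

```python
from scipy.stats import ttest_ind_from_stats

def error_p_value(mean1, std1, n1, mean2, std2, n2):
    """Two-sided Welch's t-test p-value computed from summary statistics
    (means, sample standard deviations, and trial counts)."""
    _, p = ttest_ind_from_stats(mean1, std1, n1, mean2, std2, n2,
                                equal_var=False)
    return p

# MV-Kum. vs. Kumar-SB [20] on MNIST, dim(z) = 0, 10 trials each:
# the gap is enormous relative to the trial-to-trial variability.
print(error_p_value(0.099, 0.011, 10, 0.248, 0.009, 10))

# MV-Kum. vs. IRG [4] on the same task: not statistically distinguishable.
print(error_p_value(0.099, 0.011, 10, 0.097, 0.008, 10))
```

As in the table, the Kumar-SB comparison yields a vanishingly small p-value while the IRG comparison does not approach significance.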
Kumar-SB's worse performance is likely due to the over-allocation of probability mass to the final stick during sampling (fig. 1). When the class-assignment posterior has high entropy, the fixed-order sampling will bias the last label dimension. Further, MV-Kumaraswamy beats [13] on both classification tasks despite our single-model approach and minimal preprocessing. Interestingly, our implementation of M2 seemingly requires a larger dim(z) to match the classification performance of MV-Kumaraswamy. Lastly, IRG's classification performance is not statistically distinguishable from ours. A distinct advantage of deep learning frameworks (e.g. TensorFlow, PyTorch, Theano, CNTK, MXNet, Chainer) is that they do not require user-computed gradients. We argue that methods requiring numerical gradients [4, 9] do not admit a straightforward implementation for the common practitioner, as they require additional (often non-trivial) code to supply the gradient estimates to the framework's optimizer. Conversely, our method has analytic gradients, enabling easy integration into any deep learning framework. To the best of our knowledge, IRG for the Gamma, Beta, and Dirichlet distributions exists only in TensorFlow (IRG was developed at DeepMind).

VAEs offer scalable and efficient learning for a subset of Bayesian models. Applied Bayesian modeling, however, makes heavy use of distributions outside this subset. In particular, the Dirichlet, without some form of accommodation or approximation, will render a VAE intractable, since gradients with respect to its variational parameters are challenging to compute. Efficient approximation of such gradients is an active area of research. However, explicit reparameterization is advantageous in terms of simplicity and efficiency.
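The analytic-gradient claim rests on the Kumaraswamy's closed-form inverse CDF: since F(x) = 1 - (1 - x^a)^b, a draw is x = (1 - (1 - u)^(1/b))^(1/a) for u ~ Uniform(0, 1), a smooth function of (a, b) that any framework's autodiff can differentiate. A minimal NumPy sketch (the sanity checks are ours, not from the paper; a deep learning framework would replace the finite difference with autodiff):

```python
import numpy as np

def sample_kumaraswamy(a, b, u):
    """Explicit reparameterization of Kumaraswamy(a, b): invert the CDF
    F(x) = 1 - (1 - x**a)**b at uniform noise u. For fixed u the sample
    is a smooth function of (a, b), so pathwise gradients follow by the
    chain rule -- no score-function or implicit-gradient machinery needed."""
    return (1.0 - (1.0 - u) ** (1.0 / b)) ** (1.0 / a)

rng = np.random.default_rng(0)
u = rng.uniform(size=100_000)

# Sanity check: a = b = 1 reduces to Uniform(0, 1), whose mean is 1/2.
x = sample_kumaraswamy(1.0, 1.0, u)
print(x.mean())  # approximately 0.5

# Pathwise gradient of E[x] w.r.t. a at (a, b) = (2, 3): differentiate the
# sampler itself under common random numbers (central finite difference
# stands in for autodiff here). Increasing a shifts mass right, so g > 0.
eps = 1e-6
g = (sample_kumaraswamy(2.0 + eps, 3.0, u).mean()
     - sample_kumaraswamy(2.0 - eps, 3.0, u).mean()) / (2 * eps)
```

Because every operation above is elementary and differentiable, the same sampler written in any framework yields exact reparameterization gradients automatically.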
In this article, we present and develop theory for a computationally efficient and explicitly reparameterizable Dirichlet surrogate that has similar sparsity-inducing capabilities and identical exchangeability properties to the Dirichlet it approximates. We confirm its candidacy as a surrogate through a range of semi-supervised auto-encoding tasks. We look forward to utilizing our new distribution to scale inference in more structured probabilistic models such as topic models. We hope others will use our distribution not only as a surrogate for a Dirichlet posterior but also as a prior. The latter might yield a more exact divergence between the variational posterior and its prior.

Acknowledgments

This work was supported in part by NSF grant III-1526914.

References

[1] David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin Raffel. MixMatch: A holistic approach to semi-supervised learning. arXiv:1905.02249, 2019.

[2] Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, and Pieter Abbeel. InfoGAN: Interpretable representation learning by information maximizing generative adversarial nets. In Advances in Neural Information Processing Systems 29, pages 2172–2180. Curran Associates, Inc., 2016.

[3] Djork-Arné Clevert, Thomas Unterthiner, and Sepp Hochreiter. Fast and accurate deep network learning by exponential linear units (ELUs). International Conference on Learning Representations (ICLR), 2016.

[4] Mikhail Figurnov, Shakir Mohamed, and Andriy Mnih. Implicit reparameterization gradients. In Advances in Neural Information Processing Systems 31, pages 441–452. Curran Associates, Inc., 2018.

[5] Bela A. Frigyik, Amol Kapila, and Maya R. Gupta.
Introduction to the Dirichlet distribution and related processes. Technical report, University of Washington, Department of Electrical Engineering, 2010.

[6] Michiel Hazewinkel, editor. Encyclopaedia of Mathematics, volume 6. Springer Netherlands, 1st edition, 1990.

[7] Tobias Hinz and Stefan Wermter. Inferencing based on unsupervised learning of disentangled representations. CoRR, abs/1803.02627, 2018.

[8] Wassily Hoeffding. A class of statistics with asymptotically normal distributions. Annals of Mathematical Statistics, 19(3):293–325, 1948.

[9] Martin Jankowiak and Fritz Obermeyer. Pathwise derivatives beyond the reparameterization trick. In Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pages 2235–2244. PMLR, 2018.

[10] Xu Ji, João F. Henriques, and Andrea Vedaldi. Invariant information distillation for unsupervised image segmentation and clustering. CoRR, abs/1807.06653, 2018.

[11] Michael I. Jordan, Zoubin Ghahramani, Tommi S. Jaakkola, and Lawrence K. Saul. An introduction to variational methods for graphical models. Machine Learning, 37(2):183–233, 1999.

[12] Diederik P. Kingma and Max Welling. Auto-encoding variational Bayes. International Conference on Learning Representations (ICLR), 2014.

[13] Durk P. Kingma, Shakir Mohamed, Danilo Jimenez Rezende, and Max Welling. Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems 27, pages 3581–3589. Curran Associates, Inc., 2014.

[14] Ponnambalam Kumaraswamy. A generalized probability density function for double-bounded random processes. Journal of Hydrology, 1980.

[15] Samuli Laine and Timo Aila. Temporal ensembling for semi-supervised learning. CoRR, abs/1610.02242, 2016.

[16] Alireza Makhzani and Brendan J. Frey.
PixelGAN autoencoders. In Advances in Neural Information Processing Systems 30, pages 1975–1985. Curran Associates, Inc., 2017.

[17] Alireza Makhzani, Jonathon Shlens, Navdeep Jaitly, and Ian Goodfellow. Adversarial autoencoders. In International Conference on Learning Representations (ICLR), 2016.

[18] Christian A. Naesseth, Francisco J. R. Ruiz, Scott W. Linderman, and David M. Blei. Reparameterization gradients through acceptance-rejection sampling algorithms. In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, 2017.

[19] Eric Nalisnick, Lars Hertel, and Padhraic Smyth. Approximate inference for deep latent Gaussian mixtures. Workshop on Bayesian Deep Learning, NIPS, 2016.

[20] Eric Nalisnick and Padhraic Smyth. Stick-breaking variational autoencoders. International Conference on Learning Representations (ICLR), 2017.

[21] Antti Rasmus, Harri Valpola, Mikko Honkala, Mathias Berglund, and Tapani Raiko. Semi-supervised learning with ladder networks. CoRR, abs/1507.02672, 2015.

[22] Francisco J. R. Ruiz, Michalis Titsias, and David Blei. The generalized reparameterization gradient. In Advances in Neural Information Processing Systems 29, pages 460–468. Curran Associates, Inc., 2016.

[23] Tim Salimans and David A. Knowles. Fixed-form variational posterior approximation through stochastic linear regression. Bayesian Analysis, 8, 2013.

[24] Jost Tobias Springenberg. Unsupervised and semi-supervised learning with categorical generative adversarial networks. International Conference on Learning Representations (ICLR), 2016.

[25] Akash Srivastava and Charles Sutton. Autoencoding variational inference for topic models.
In International Conference on Learning Representations (ICLR), 2017.

[26] Antti Tarvainen and Harri Valpola. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In Advances in Neural Information Processing Systems 30, pages 1195–1204. Curran Associates, Inc., 2017.

[27] Martin J. Wainwright and Michael I. Jordan. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1-2):1–305, 2008.