{"title": "Direct Optimization through $\\arg \\max$ for Discrete Variational Auto-Encoder", "book": "Advances in Neural Information Processing Systems", "page_first": 6203, "page_last": 6214, "abstract": "Reparameterization of variational auto-encoders with continuous random variables is an effective method for reducing the variance of their gradient estimates. In the discrete case, one can perform reparametrization using the Gumbel-Max trick, but the resulting objective relies on an $\\arg \\max$ operation and is non-differentiable. In contrast to previous works which resort to \\emph{softmax}-based relaxations, we propose to optimize it directly by applying the \\emph{direct loss minimization} approach. Our proposal extends naturally to structured discrete latent variable models when evaluating the $\\arg \\max$ operation is tractable. We demonstrate empirically the effectiveness of the direct loss minimization technique in variational autoencoders with both unstructured and structured discrete latent variables.", "full_text": "Direct Optimization through arg max for Discrete\n\nVariational Auto-Encoder\n\nGuy Lorberbom\n\nTechnion\n\nAndreea Gane\n\nMIT\n\nTommi Jaakkola\n\nMIT\n\nTamir Hazan\n\nTechnion\n\nAbstract\n\nReparameterization of variational auto-encoders with continuous random variables\nis an effective method for reducing the variance of their gradient estimates. In the\ndiscrete case, one can perform reparametrization using the Gumbel-Max trick, but\nthe resulting objective relies on an arg max operation and is non-differentiable. In\ncontrast to previous works which resort to softmax-based relaxations, we propose to\noptimize it directly by applying the direct loss minimization approach. Our proposal\nextends naturally to structured discrete latent variable models when evaluating\nthe arg max operation is tractable. 
We demonstrate empirically the effectiveness of the direct loss minimization technique in variational autoencoders with both unstructured and structured discrete latent variables.\n\n1 Introduction\n\nModels with discrete latent variables drive extensive research in machine learning applications, including language classification and generation [42, 11, 34], molecular synthesis [19], and game solving [25]. Compared to their continuous counterparts, discrete latent variable models can decrease the computational complexity of inference calculations, for instance, by discarding alternatives in hard attention models [21]; they can improve interpretability by illustrating which terms contributed to the solution [27, 42]; and they can facilitate the encoding of inductive biases in the learning process, such as images consisting of a small number of objects [8] or tasks requiring intermediate alignments [25]. Finally, in some cases, discrete latent variables are natural choices, for instance when modeling datasets with discrete classes [32, 12, 23].\nPerforming maximum likelihood estimation of latent variable models is challenging due to the requirement to marginalize over the latent variables. Instead, one can maximize a variational lower bound on the data log-likelihood, defined via an (approximate) posterior distribution over the latent variables, an approach followed by latent Dirichlet allocation [3], learning hidden Markov models [28] and variational auto-encoders [16]. The maximization can be carried out by alternately computing the (approximate) posterior distribution corresponding to the current model parameter estimates and estimating the new model parameters. Variational auto-encoders (VAEs) are generative latent variable models where the approximate posterior is a (neural network based) parameterized distribution which is estimated jointly with the model parameters. 
Maximization is performed via stochastic gradient ascent, provided that one can compute gradients with respect to both the model parameters and the approximate posterior parameters.\nLearning VAEs with discrete $n$-dimensional latent variables is computationally challenging since the size of the support of the posterior distribution is exponential in $n$. Although the score function estimator (also known as REINFORCE) [39] enables computing the required gradients with respect to the approximate posterior, in both the continuous and discrete latent variable case, it is known to have high variance. The reparametrization trick provides an appealing alternative to the score function estimator and recent work has shown its effectiveness for continuous latent spaces [17, 30]. In the discrete case, despite being able to perform reparametrization via the Gumbel-Max trick, the resulting mapping remains non-differentiable due to the presence of $\\arg \\max$ operations. Recently, Maddison et al. [23] and Jang et al. [12] have used a relaxation of the reparametrized objective, replacing the $\\arg \\max$ operation with a softmax operation. The proposed Gumbel-Softmax reformulation results in a smooth objective function, similar to the continuous latent variable case. \n\n33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.\n\n
Unfortunately, the softmax operation introduces bias into the gradient computation and becomes computationally intractable when using high-dimensional structured latent spaces, because the softmax normalization relies on a summation over all possible latent assignments.\nThis paper proposes optimizing the reparameterized discrete VAE objective directly, by applying the direct loss minimization approach [24, 14, 35], originally proposed for learning discriminative models. The cited work proves that a (biased) gradient estimator of the $\\arg \\max$ operation can be obtained from the difference between two maximization operations, over the original and over a perturbed objective, respectively. We apply the proposed estimator to the $\\arg \\max$ operation obtained from applying the Gumbel-Max trick. Compared to the Gumbel-Softmax estimator, our approach relies on maximization over the latent variable assignments, rather than summation, which is computationally more efficient. In particular, performing maximization exactly or approximately is possible in many structured cases, even when summation remains intractable. We demonstrate empirically the effectiveness of the direct optimization technique on high-dimensional discrete VAEs, with unstructured and structured discrete latent variables.\nOur technical contributions can be summarized as follows: (1) we apply the direct loss minimization approach to learning generative models; (2) we provide an alternative proof for the direct loss minimization approach, which does not rely on regularity assumptions; (3) we extend the proposed direct optimization-based estimator to discrete VAEs with structured latent spaces.\n\n2 Related work\n\nReparameterization is an effective method to reduce the variance of gradient estimates in learning latent variable models with continuous latent representations [17, 30, 29, 2, 26, 10]. The success of these works led to reparameterization approaches in discrete latent spaces. 
Rolfe et al. [32] and Vahdat and collaborators [38, 37, 1] represent the marginal distribution per binary latent variable with a continuous variable in the unit interval. This reparameterization approach allows propagating gradients through the continuous representation, but these works are restricted to binary random variables and, as a by-product, they require high-dimensional representations for which inference is exponential in the dimension size. Djolonga and Krause used the Lovász extension to relax a discrete submodular decision in order to propagate gradients through its continuous representation [7].\nMost relevant to our work, Maddison et al. [23] and Jang et al. [12] use the Gumbel-Max trick to reparameterize the discrete VAE objective, but, unlike our work, they relax the resulting formulation, replacing the $\\arg \\max$ with a softmax operation. In particular, they introduce the continuous Concrete (Gumbel-Softmax) distribution and replace the discrete random variables with continuous ones. Instead, our reparameterized objective remains non-differentiable and we use the direct optimization approach to propagate gradients through the $\\arg \\max$ using the difference of two maximization operations.\nRecent work [25, 5] tackles the challenges associated with learning VAEs with structured discrete latent variables, but it can only handle specific structures. For instance, the Gumbel-Sinkhorn approach [25] extends the Gumbel-Softmax distribution to model permutations and matchings. The Perturb-and-Parse approach [5] focuses on latent dependency parses and iteratively replaces every $\\arg \\max$ with a softmax operation in a spanning tree algorithm. In contrast, our framework is not restricted to a particular class of structures. Similar to our work, Johnson et al. [13] use the VAE encoder network to compute local potentials to be used in a structured potential function. 
Unlike the cited work, which makes use of message passing in graphical models with conjugacy structure, we use the Gumbel-Max trick, which enables us to apply our method whenever the two maximization operations can be computed efficiently.\n\n3 Background\n\nTo model the data-generating distribution, we consider samples $S = \\{x_1, ..., x_m\\}$ from a potentially high-dimensional set $x_i \\in X$, originating from an unknown underlying distribution. We estimate the parameters $\\theta$ of a model $p_{\\theta}(x)$ by minimizing its negative log-likelihood. We consider latent variable models of the form $p_{\\theta}(x) = \\sum_{z \\in Z} p_{\\theta}(z) p_{\\theta}(x|z)$, with high-dimensional discrete variables $z \\in Z$, whose log-likelihood computation requires marginalizing over the latent representation. Variational autoencoders utilize an auxiliary distribution $q_{\\phi}(z|x)$ to upper bound the negative log-likelihood of the observed data points:\n\n$-\\sum_{x \\in S} \\log p_{\\theta}(x) \\le \\sum_{x \\in S} -E_{z \\sim q_{\\phi}} \\log p_{\\theta}(x|z) + \\sum_{x \\in S} KL(q_{\\phi}(z|x) || p_{\\theta}(z)).$ (1)\n\nIn discrete VAEs, the posterior distribution $q_{\\phi}(z|x)$ and the data distribution conditioned on the latent representation $p_{\\theta}(x|z)$ are modeled via the Gibbs distribution, namely $q_{\\phi}(z|x) = e^{h_{\\phi}(x,z)}$ and $p_{\\theta}(x|z) = e^{f_{\\theta}(x,z)}$. We use $h_{\\phi}(x, z)$ and $f_{\\theta}(x, z)$ to denote the (normalized) log-probabilities. Both quantities are modeled via differentiable (neural network based) mappings.\nParameter estimation of $\\theta$ and $\\phi$ is carried out by performing gradient descent on the right-hand side of Equation (1). 
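The bound in Equation (1) can be checked numerically for a tiny latent space. The sketch below uses arbitrary illustrative distributions (not a trained model) and verifies that the variational bound dominates the exact negative log-likelihood for any choice of $q$:

```python
import numpy as np

# Minimal numeric check of Eq. (1) for a single data point x and a tiny
# latent space of K values; all distributions are illustrative placeholders.
K = 4
rng = np.random.default_rng(0)
prior = np.full(K, 1.0 / K)              # p_theta(z), uniform prior
lik = rng.dirichlet(np.ones(3), size=K)  # p_theta(x|z) over 3 possible observations
q = rng.dirichlet(np.ones(K))            # an arbitrary posterior q_phi(z|x)
x = 1                                    # the observed data point (an index here)

# exact negative log-likelihood: -log sum_z p(z) p(x|z)
neg_log_px = -np.log(prior @ lik[:, x])

# right-hand side of Eq. (1): -E_q[log p(x|z)] + KL(q || prior)
bound = -(q * np.log(lik[:, x])).sum() + (q * np.log(q / prior)).sum()
```

Since the bound holds for every choice of `q`, the assertion below succeeds regardless of the random seed.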
Unfortunately, computing the gradient of the first term $E_{z \\sim q_{\\phi}} \\log p_{\\theta}(x|z)$ in a high-dimensional discrete latent space $z = (z_1, ..., z_n)$ is challenging because the expectation enumerates over all possible latent assignments:\n\n$\\nabla_{\\phi} E_{z \\sim q_{\\phi}} \\log p_{\\theta}(x|z) = \\sum_{z \\in Z} e^{h_{\\phi}(x,z)} \\nabla_{\\phi} h_{\\phi}(x, z) f_{\\theta}(x, z)$ (2)\n\nAlternatively, the score function estimator (REINFORCE) requires sampling from the high-dimensional structured latent space, which can be computationally challenging, and has high variance, necessitating many samples.\n\n3.1 Gumbel-Max reparameterization\n\nThe Gumbel-Max trick provides an alternative representation of the Gibbs distribution $q_{\\phi}(z|x)$ that is based on the extreme value statistics of Gumbel-distributed random variables. Let $\\gamma$ be a random function that associates an independent random variable $\\gamma(z)$ with each input $z \\in Z$. When the random variables follow the zero-mean Gumbel distribution law, whose probability density function is $g(\\gamma) = \\prod_{z \\in Z} e^{-(\\gamma(z) + c + e^{-(\\gamma(z) + c)})}$ for the Euler constant $c \\approx 0.57$, we obtain the following identity (cf. [18]):\n\n$e^{h_{\\phi}(x,z)} = P_{\\gamma \\sim g}[z^* = z]$, where $z^* \\triangleq \\arg \\max_{\\hat{z} \\in Z} \\{h_{\\phi}(x, \\hat{z}) + \\gamma(\\hat{z})\\}$ (3)\n\nNotably, samples from the Gibbs distribution can be obtained by drawing samples from the Gumbel distribution (which does not depend on learnable parameters) and applying a parameterized mapping based on the $\\arg \\max$ operation. For completeness, a proof of the above equality appears in the supplementary material.\nIn the context of variational autoencoders, the Gumbel-Max formulation enables rewriting the expectation $E_{z \\sim q_{\\phi}} \\log p_{\\theta}(x|z)$ with respect to the Gumbel distribution, similar to the application of the reparametrization trick in the continuous latent variable case [16]. 
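The identity in Equation (3) can be verified empirically: drawing Gumbel noise and taking the argmax reproduces the Gibbs distribution. A minimal sketch with illustrative scores:

```python
import numpy as np

rng = np.random.default_rng(0)
h = np.array([1.0, 0.5, -0.5, 0.0])      # encoder scores h_phi(x, z) for 4 latent values
target = np.exp(h) / np.exp(h).sum()     # the Gibbs distribution q_phi(z|x)

# Gumbel-Max trick (Eq. 3): z* = argmax_z { h(z) + gamma(z) } with i.i.d. Gumbel noise.
# NumPy samples the standard Gumbel (mean equal to the Euler constant); the constant
# shift does not change the argmax, so the sampled law matches the zero-mean case.
n = 200_000
gamma = rng.gumbel(size=(n, h.size))
z_star = np.argmax(h + gamma, axis=1)    # h broadcasts over the n draws
empirical = np.bincount(z_star, minlength=h.size) / n
```

With 200,000 draws the empirical frequencies match the Gibbs probabilities to within about one percent.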
Unfortunately, the parameterized mapping is non-differentiable, as the $\\arg \\max$ function is piecewise constant. In response, the Gumbel-Softmax estimator [23, 12] approximates the $\\arg \\max$ via the softmax operation\n\n$P_{\\gamma \\sim g}[z^* = z] = E_{\\gamma \\sim g}[1_{z^* = z}] \\approx E_{\\gamma \\sim g}\\left[\\frac{e^{(h_{\\phi}(x,z) + \\gamma(z))/\\tau}}{\\sum_{\\hat{z} \\in Z} e^{(h_{\\phi}(x,\\hat{z}) + \\gamma(\\hat{z}))/\\tau}}\\right]$ (4)\n\nfor a temperature parameter $\\tau$ (treated as a hyper-parameter), which produces a smooth objective function. Nevertheless, the approximated Gumbel-Softmax objective introduces bias, uses continuous rather than discrete variables (requiring discretization at test time), and its dependence on the softmax function can be computationally prohibitive when considering structured latent spaces $z = (z_1, ..., z_n)$, as the normalization constant in Equation (4) sums over all possible latent variable realizations $\\hat{z}$.\n\nFootnote 1: The set $\\arg \\max_{\\hat{z} \\in Z} \\{h_{\\phi}(x, \\hat{z}) + \\gamma(\\hat{z})\\}$ contains all maximizing assignments (possibly more than one). However, since the Gumbel distribution is continuous, the set of $\\gamma$ for which the set of maximizing assignments contains multiple elements has measure zero. For notational convenience, when we consider integrals (or probability distributions), we ignore measure zero sets.\n\n3.2 Direct loss minimization\n\nThe direct loss minimization approach has been introduced for learning discriminative models [24, 14, 35]. In the discriminative setting, the goal is to estimate a set of parameters $\\phi$, used to predict a label for each (high-dimensional) input $x \\in X$ via $y^* = \\arg \\max_{y \\in Y} h_{\\phi}(x, y)$, where $Y$ is the set of continuous or discrete candidate labels. 
The score function $h_{\\phi}(x, y)$ can be non-linear as a function of the parameters $\\phi$, as developed by [14, 35].\nGiven training data tuples $(x, y)$ sampled from an unknown underlying data distribution $D$, the goodness of fit of the learned predictor is measured by a loss function $f(y, y^*)$, which is not necessarily differentiable. This is the case, for instance, when the labels $Y$ are discrete, such as object labels in object recognition or action labels in action classification in videos [35]. As a result, the expected loss $E_{(x,y) \\sim D}[f(y, y^*)]$ cannot always be optimized using standard methods such as gradient descent.\nThe typical solution is to replace the desired objective with a surrogate differentiable loss, such as the cross-entropy loss between the targets and the predicted distribution over labels. However, the direct loss minimization approach proposes to minimize the desired objective directly. The proposed gradient estimator uses a loss-perturbed predictor $y^*(\\epsilon) = \\arg \\max_{\\hat{y}} \\{h_{\\phi}(x, \\hat{y}) + \\epsilon f(y, \\hat{y})\\}$ and takes the following form:\n\n$\\nabla_{\\phi} E_{(x,y) \\sim D}[f(y, y^*)] = \\lim_{\\epsilon \\to 0} \\frac{1}{\\epsilon} E_{(x,y) \\sim D}[\\nabla_{\\phi} h_{\\phi}(x, y^*(\\epsilon)) - \\nabla_{\\phi} h_{\\phi}(x, y^*)]$ (5)\n\nIn other words, the gradient estimator is obtained by performing pairs of maximization operations, one over the original objective (second term) and one over a perturbed objective (first term). The estimator becomes unbiased as the perturbation parameter $\\epsilon$ approaches 0. In practice, the parameter $\\epsilon$ is assigned a small value, treated as a hyper-parameter, which introduces bias.\nUnfortunately, the standard direct loss minimization approach predicts a single label $y^*$ for an input $x$ and, therefore, cannot generate a posterior distribution over samples $y$, i.e., it lacks a generative model. 
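The two-argmax estimator of Equation (5) can be sketched for a toy linear scorer; the function below is illustrative (the scorer, labels and 0/1 loss are not from the paper) and returns the finite-$\epsilon$ gradient with respect to the score parameters:

```python
import numpy as np

# Sketch of the direct loss minimization gradient (Eq. 5) for a linear scorer
# h_phi(x, y) = phi[y, :] @ x over discrete labels; names are illustrative.
def direct_loss_grad(phi, x, y, loss, eps=0.1):
    """phi: (labels, features); y: true label; loss: matrix loss[y, y_hat]."""
    scores = phi @ x                                # h_phi(x, y') for every label y'
    y_star = int(np.argmax(scores))                 # the predictor y*
    y_eps = int(np.argmax(scores + eps * loss[y]))  # loss-perturbed predictor y*(eps)
    # (grad_phi h at y*(eps) - grad_phi h at y*) / eps; for a linear scorer,
    # grad_phi h(x, y') is x placed in row y'.
    g = np.zeros_like(phi)
    g[y_eps] += x / eps
    g[y_star] -= x / eps
    return y_star, y_eps, g
```

When the perturbation flips the prediction, the estimator returns a non-zero gradient whose rows cancel; otherwise it is exactly zero, matching the behavior discussed for small $\epsilon$.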
In our work we inject the Gumbel random variable to create a posterior over the label space, enabling the application of this method to learning generative models. The Gumbel random variable allows us to overcome the general position assumption and the regularity conditions of [24, 14, 35].\n\n4 Gumbel-Max reparameterization and direct optimization\n\nWe use the Gumbel-Max trick to rewrite the expected log-likelihood in the variational autoencoder objective $E_{z \\sim q_{\\phi}} \\log p_{\\theta}(x|z)$ in the following form:\n\n$E_{z \\sim q_{\\phi}} \\log p_{\\theta}(x|z) = \\sum_{z \\in Z} P_{\\gamma \\sim g}[z^* = z] f_{\\theta}(x, z) = E_{\\gamma \\sim g}[f_{\\theta}(x, z^*)]$ (6)\n\nwhere $z^*$ is the maximizing assignment defined in Equation (3). The equality results from the identity $P_{\\gamma \\sim g}[z^* = z] = E_{\\gamma \\sim g}[1_{z^* = z}]$, the linearity of expectation, $\\sum_{z \\in Z} E_{\\gamma \\sim g}[1_{z^* = z}] f_{\\theta}(x, z) = E_{\\gamma \\sim g}[\\sum_{z \\in Z} 1_{z^* = z} f_{\\theta}(x, z^*)]$, and the fact that $\\sum_{z \\in Z} 1_{z^* = z} = 1$.\nThe gradient of $f_{\\theta}(x, z^*)$ with respect to the decoder parameters $\\theta$ can be derived by the chain rule. The main challenge is evaluating the gradient of $E_{\\gamma \\sim g}[f_{\\theta}(x, z^*)]$ with respect to the encoder parameters $\\phi$, since $z^*$ relies on an $\\arg \\max$ operation which is not differentiable. Our main result is presented in Theorem 1 and proposes a gradient estimator for the expectation $E_{\\gamma \\sim g}[f_{\\theta}(x, z^*)]$ with respect to the encoder parameters $\\phi$. In the following, we omit $\\gamma \\sim g$ to avoid notational overhead.\n\nFootnote 2: We match the notation of the parameters $\\phi$ of the posterior distribution to highlight the connection between the two objectives.\n\nTheorem 1. Assume that $h_{\\phi}(x, z)$ is a smooth function of $\\phi$. Let $z^* \\triangleq \\arg \\max_{\\hat{z} \\in Z} \\{h_{\\phi}(x, \\hat{z}) + \\gamma(\\hat{z})\\}$ and $z^*(\\epsilon) \\triangleq \\arg \\max_{\\hat{z} \\in Z} \\{\\epsilon f_{\\theta}(x, \\hat{z}) + h_{\\phi}(x, \\hat{z}) + \\gamma(\\hat{z})\\}$ be two random variables. 
Then\n\n$\\nabla_{\\phi} E_{\\gamma}[f_{\\theta}(x, z^*)] = \\lim_{\\epsilon \\to 0} \\frac{1}{\\epsilon} E_{\\gamma}[\\nabla_{\\phi} h_{\\phi}(x, z^*(\\epsilon)) - \\nabla_{\\phi} h_{\\phi}(x, z^*)]$ (7)\n\nProof sketch: We use a prediction generating function $G(\\phi, \\epsilon) = E_{\\gamma}[\\max_{\\hat{z} \\in Z} \\{\\epsilon f_{\\theta}(x, \\hat{z}) + h_{\\phi}(x, \\hat{z}) + \\gamma(\\hat{z})\\}]$, whose derivatives are functions of the predictions $z^*, z^*(\\epsilon)$. The proof is composed of three steps: (i) We prove that $G(\\phi, \\epsilon)$ is a smooth function of $\\phi, \\epsilon$. Therefore, the Hessian of $G(\\phi, \\epsilon)$ exists and is symmetric, namely $\\partial_{\\phi} \\partial_{\\epsilon} G(\\phi, \\epsilon) = \\partial_{\\epsilon} \\partial_{\\phi} G(\\phi, \\epsilon)$. (ii) We show that the encoder gradient is apparent in the Hessian: $\\partial_{\\phi} \\partial_{\\epsilon} G(\\phi, 0) = \\nabla_{\\phi} E_{\\gamma}[f_{\\theta}(x, z^*)]$. (iii) We rely on the smoothness of $G(\\phi, \\epsilon)$ and derive our update rule as the complement representation of the Hessian: $\\partial_{\\epsilon} \\partial_{\\phi} G(\\phi, 0) = \\lim_{\\epsilon \\to 0} \\frac{1}{\\epsilon} (E_{\\gamma}[\\nabla_{\\phi} h_{\\phi}(x, z^*(\\epsilon)) - \\nabla_{\\phi} h_{\\phi}(x, z^*)])$. The complete proof is included in the supplementary material.\n\nFigure 1: Highlights the bias-variance tradeoff of the direct optimization estimate as a function of $\\epsilon$, compared to the Gumbel-Softmax gradient estimate as a function of its temperature $\\tau$. In both cases, the architecture consists of an encoder $X \\to FC(300) \\to ReLU \\to FC(K)$ and a matching decoder. The parameters were learned using the unbiased gradient in Equation (2) to ensure both the direct and GSM estimators have the same (unbiased) reference point. From its optimal parameters we estimate the gradient randomly 10,000 times. Left: the bias from the analytic gradient. Right: the average standard deviation of the gradient estimate.\n\nThe gradient estimator proposed in Theorem 1 requires two maximization operations. While computing $z^*$ is straightforward, realizing $z^*(\\epsilon)$ requires evaluating $f_{\\theta}(x, z)$ for each $z \\in Z$, i.e., evaluating the decoder network multiple times. Nevertheless, the resulting computational overhead can be reduced by performing these operations in parallel (we used batched operations in our implementation).\nThe gradient estimator is unbiased in the limit $\\epsilon \\to 0$. However, for small $\\epsilon$ values the gradient is either zero, when $z^*(\\epsilon) = z^*$, or very large, since the gradients' difference is multiplied by $1/\\epsilon$. In practice we use $\\epsilon \\ge 0.1$, which means that the gradient estimator is biased. In Figure 1 we compare the bias-variance tradeoff of the direct optimization estimator as a function of $\\epsilon$ with the Gumbel-Softmax gradient estimator as a function of its temperature $\\tau$. Figure 1 shows that while $\\epsilon$ and $\\tau$ are the sources of bias in these two estimates, they have different impact in each framework.\nAlgorithm 1 highlights the proposed approach. Each iteration begins with drawing a minibatch $x$ and computing the corresponding latent representations by mapping $x$ to $h_{\\phi}(x, \\hat{z})$ and sampling from the resulting posterior distribution $q_{\\phi}(z|x)$ (lines 3-5). The gradients w.r.t. $\\theta$ are obtained via standard backpropagation (line 7). The gradients w.r.t. $\\phi$ are obtained by reusing the computed $z^*$ (line 5) and evaluating the loss-perturbed predictor (lines 6, 8).\nNotably, the $\\arg \\max$ operations can be solved via non-differentiable solvers (e.g. 
branch and bound, max-flow).\n\nAlgorithm 1 Direct Optimization for discrete VAEs\n1: $\\phi, \\theta \\leftarrow$ Initialize parameters\n2: while $\\phi, \\theta$ not converged do\n3: $x \\leftarrow$ Random minibatch\n4: $\\gamma \\leftarrow$ Random variables drawn from the Gumbel distribution\n5: $z^* \\leftarrow \\arg \\max_{\\hat{z}} \\{h_{\\phi}(x, \\hat{z}) + \\gamma(\\hat{z})\\}$\n6: $z^*(\\epsilon) \\leftarrow \\arg \\max_{\\hat{z}} \\{\\epsilon f_{\\theta}(x, \\hat{z}) + h_{\\phi}(x, \\hat{z}) + \\gamma(\\hat{z})\\}$\n7: Compute $\\theta$-gradient: $g_{\\theta} \\leftarrow \\nabla_{\\theta} f_{\\theta}(x, z^*)$\n8: Compute $\\phi$-gradient (Eq. 7): $g_{\\phi} \\leftarrow \\frac{1}{\\epsilon} (\\nabla_{\\phi} h_{\\phi}(x, z^*(\\epsilon)) - \\nabla_{\\phi} h_{\\phi}(x, z^*))$\n9: $\\phi, \\theta \\leftarrow$ Update parameters using gradients $g_{\\phi}, g_{\\theta}$\n10: end while\n\n4.1 Structured latent spaces\n\nDiscrete latent variables often carry semantic meaning. For example, in the CelebA dataset there are $n$ possible attributes for an image, e.g., Eyeglasses, Smiling; see Figure 5. Assigning a binary random variable to each of the attributes, namely $z = (z_1, ..., z_n)$, allows us to generate images with certain attributes turned on or off. In this example, the number of possible realizations of $z$ is $2^n$.\nLearning a discrete structured space may be computationally expensive. The Gumbel-Softmax estimator, as described in Equation (4), depends on the softmax normalization constant, which requires summing over exponentially many terms (exponential in $n$). This computational complexity can be relaxed by ignoring structural relations within the encoder $h_{\\phi}(x, z)$ and decomposing it according to its dimensions, i.e., $h_{\\phi}(x, z) = \\sum_{i=1}^n h_i(x, z_i; \\phi)$. In this case the normalization constant requires only linearly many terms (linear in $n$). 
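The per-latent core of Algorithm 1 can be sketched compactly. The function below is a minimal illustration for a single categorical latent, operating directly on encoder logits and decoder log-likelihoods; the networks themselves and the chain rule into their weights are omitted, and all names are illustrative:

```python
import numpy as np

# Sketch of one step of Algorithm 1 for a categorical latent z in {0..K-1}.
# h: encoder logits h_phi(x, z); f: decoder log-likelihoods f_theta(x, z).
def direct_vae_step(h, f, eps=0.1, rng=None):
    rng = rng if rng is not None else np.random.default_rng()
    gamma = rng.gumbel(size=h.shape)             # line 4: Gumbel noise
    z_star = int(np.argmax(h + gamma))           # line 5: z*
    z_eps = int(np.argmax(eps * f + h + gamma))  # line 6: loss-perturbed z*(eps)
    # line 8 (Eq. 7), taken w.r.t. the logits h; backprop into phi would follow
    g_h = np.zeros_like(h)
    g_h[z_eps] += 1.0 / eps
    g_h[z_star] -= 1.0 / eps
    return z_star, z_eps, g_h
```

Note that when the decoder scores are constant the perturbation cannot change the argmax, so the estimator is exactly zero, and in general the two contributions cancel in the sum.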
Such a decomposed encoder, however, does not account for correlations between the variables in the structured latent space.\nGumbel-Max reparameterization can account for structural relations in the latent space $h_{\\phi}(x, z)$ without suffering from the exponential cost of the softmax operation, since computing the $\\arg \\max$ is often more efficient than summing over all exponentially many options.\nFor computational efficiency we model only pairwise interactions in the structured encoder:\n\n$h_{\\phi}(x, z) = \\sum_{i=1}^n h_i(x, z_i; \\phi) + \\sum_{i,j=1}^n h_{i,j}(x, z_i, z_j; \\phi)$ (8)\n\nThe additional modeling power of $h_{i,j}(x, z_i, z_j; \\phi)$ allows the encoder to better calibrate the dependences of the structured latent space that are fed into the decoder. In general, recovering the $\\arg \\max$ under pairwise correlations requires a quadratic integer program solver, such as CPLEX. However, efficient max-flow solvers may be used when the pairwise correlations have special structural restrictions, e.g., $h_{i,j}(x, z_i, z_j; \\phi) = \\alpha_{i,j}(x) z_i z_j$ for $\\alpha_{i,j}(x) \\ge 0$.\nThe gradient realization in Theorem 1 also holds in the structured setting, whenever the structure of $\\gamma$ follows the structure of $h_{\\phi}$. This gradient realization requires computing $z^*, z^*(\\epsilon)$. While $z^*$ only depends on the structured encoder, the $\\arg \\max$ perturbation $z^*(\\epsilon)$ involves the structured decoder $f_{\\theta}(x, z_1, ..., z_n)$, which does not necessarily decompose according to the structured encoder. We use the fact that we can compute $z^*$ efficiently and apply the low-dimensional approximation $\\tilde{f}_{\\theta}(x, z) = \\sum_{i=1}^n \\tilde{f}_i(x, z_i; \\theta)$, where $\\tilde{f}_i(x, z_i; \\theta) = f_{\\theta}(x, z^*_1, ..., z_i, ..., z^*_n)$. With this in mind, we approximate $z^*(\\epsilon)$ with $\\tilde{z}^*(\\epsilon)$, which is computed by replacing $f_{\\theta}(x, z)$ with $\\tilde{f}_{\\theta}(x, z)$. In our implementation we use the batch operation to compute $\\tilde{f}_{\\theta}(x, z)$ efficiently.\n\n4.2 Semi-supervised learning\n\nDirect optimization naturally extends to semi-supervised learning, where we may add to the learning objective a loss function $\\ell(z, z^*)$ for supervised samples $(x, z) \\in S_1$, to better control the prediction of the latent space. The semi-supervised discrete VAE objective function is\n\n$\\sum_{x \\in S} E_{\\gamma}[f_{\\theta}(x, z^*)] + \\sum_{(x,z) \\in S_1} E_{\\gamma}[\\ell(z, z^*)] + \\sum_{x \\in S} KL(q_{\\phi}(z|x) || p_{\\theta}(z))$ (9)\n\nThe supervised component is explicitly handled by Theorem 1. Our supervised component is intimately related to direct loss minimization [24, 35]. The added random perturbation $\\gamma$ allows us to use a generative model for prediction, namely, we can randomly generate different explanations $z^*$, while direct loss minimization allows a single explanation for a given $x$.\n\n5 Experimental evaluation\n\nWe begin our experiments by comparing the test loss of direct optimization, the Gumbel-Softmax (GSM) and the unbiased gradient computation in Equation (2). We performed these experiments using the binarized MNIST dataset [33], Fashion-MNIST [40] and Omniglot [20]. The architecture consists of an encoder $X \\to FC(300) \\to ReLU \\to FC(K)$, a matching decoder $K \\to FC(300) \\to ReLU \\to FC(X)$ and a BCE loss. Following [12] we set our learning rate to $1e-3$ and the annealing rate to $1e-5$, and we used their annealing schedule every 1000 steps, setting the minimal $\\epsilon$ to be 0.1. The results appear in Table 1. When considering MNIST and Omniglot, direct optimization achieves similar test loss to the unbiased method, which uses the analytical gradient computation in Equation (2). 
Also, direct optimization achieves a better result than GSM, despite the fact that both direct optimization and GSM use biased gradient descent: direct optimization uses a biased gradient for the exact objective in Equation (1), while GSM uses an exact gradient for an approximated objective. Surprisingly, on Fashion-MNIST, direct optimization achieves a better test loss than the unbiased method.\n\n[Table 1: test losses of VAEs with categorical variables $z \\in \\{1, ..., k\\}$ for $k = 10, 20, 30, 40, 50$, comparing GSM, direct optimization and the unbiased method on MNIST, Fashion-MNIST and Omniglot.] Table 1 caption: Compares the test loss of VAEs with different categorical variables $z \\in \\{1, ..., k\\}$. Direct optimization achieves similar test loss to the unbiased method (Equation (2)) and achieves a better test loss than GSM, despite the fact that both direct optimization and GSM use biased gradient descent.\n\nFigure 2 (panels: MNIST, Fashion-MNIST, Omniglot): Comparing the decrease of the test loss for $k = 10$. Top row: test loss as a function of the learning epoch. Bottom row: test loss as a function of the learning wall-clock time. An incomplete plot in the bottom row indicates the algorithm required less time to finish 300 epochs.\n\nTo further explore this phenomenon, in Figure 2 one can see that the unbiased method takes more epochs to converge, and eventually it achieves similar and often better test loss than direct optimization on MNIST and Omniglot. 
In contrast, on Fashion-MNIST, direct optimization is better than the unbiased gradient method, which we attribute to the slower convergence of the unbiased method; see the supplementary material for more evidence.\nIt is important to compare the wall-clock time of each approach. The unbiased method requires $k$ computations of the encoder and the decoder in a forward and backward pass. GSM requires a single forward pass and a single backward pass (encapsulating the $k$ computations of the softmax normalization within the code). In contrast, our approach requires a single forward pass, but $k$ computations of the decoder $f_{\\theta}(x, z)$ for $z = 1, ..., k$ in the backward pass. In our implementation we use the batch operation to compute $f_{\\theta}(x, z)$ efficiently. Figure 2 compares the test loss as a function of the wall-clock time and shows that while our method is 1.5 times slower than GSM, its test loss is lower than that of GSM at any time.\nNext we perform a set of experiments on Fashion-MNIST using discrete structured latent spaces $z = (z_1, ..., z_n)$ where each $z_i$ is binary, i.e., $z_i \\in \\{0, 1\\}$. In the following experiments we consider a structured decoder $f_{\\theta}(x, z) = f_{\\theta}(x, z_1, ..., z_n)$. The decoder architecture consists of the modules $(2 \\times 15) \\to FC(300) \\to ReLU \\to FC(X)$ and a BCE loss. For $n = 15$ the computational cost of the softmax in GSM is high (exponential in $n$) and therefore one cannot use a structured encoder with GSM.\nOur first experiment with a structured decoder considers an unstructured encoder $h_{\\phi}(x, z) = \\sum_{i=1}^n h_i(x, z_i; \\phi)$ for GSM and direct optimization. This experiment demonstrates the effectiveness of our low-dimensional approximation $\\tilde{f}_{\\theta}(x, z) = \\sum_{i=1}^n \\tilde{f}_i(x, z_i; \\theta)$, where $\\tilde{f}_i(x, z_i; \\theta) = f_{\\theta}(x, z^*_1, ..., z_i, ..., z^*_n)$, for applying direct optimization to structured decoders (Section 4.1). We also compare against the unbiased estimators REBAR [36] and RELAX [9] and the recent ARM estimator [41] (see Footnote 3). The results appear in Figure 3 and may suggest that, using the approximated $\\tilde{z}^*(\\epsilon)$, the gradient estimate of direct optimization still points towards a direction of descent for the exact objective.\n\nFigure 3: Left: test loss of an unstructured encoder and a structured decoder as a function of the learning epoch. Middle: using structured decoders and comparing unstructured encoders to structured encoders, $h_{i,j}(x, z_i, z_j; \\phi) = \\alpha_{i,j}(x) z_i z_j$, both for general $\\alpha_{i,j}(x)$ (recovering the $\\arg \\max$ using CPLEX) and for $\\alpha_{i,j}(x) \\ge 0$ (recovering the $\\arg \\max$ using max-flow). Right: comparing the wall-clock time of decomposable and structured encoders.\n\nTable 2: Semi-supervised VAE on MNIST and Fashion-MNIST with 50/100/300/600/1200 labeled examples out of the 50,000 training examples.\n#labels | MNIST accuracy (direct / GSM) | MNIST bound (direct / GSM) | Fashion-MNIST accuracy (direct / GSM) | Fashion-MNIST bound (direct / GSM)\n50 | 92.6% / 84.7% | 90.24 / 91.23 | 63.3% / 61.2% | 129.66 / 129.813\n100 | 95.4% / 88.4% | 90.93 / 90.64 | 67.2% / 64.2% | 130.822 / 129.054\n300 | 96.4% / 91.7% | 90.39 / 90.01 | 70.0% / 69.3% | 130.653 / 130.371\n600 | 96.7% / 92.3% | 90.78 / 89.77 | 72.1% / 71.6% | 130.81 / 129.973\n1200 | 96.8% / 92.7% | 90.45 / 90.37 | 73.7% / 73.2% | 130.921 / 130.063\n\nOur second experiment uses a structured decoder with structured encoders, which may account for correlations between latent random variables, $h_{\\phi}(x, z) = \\sum_{i=1}^n h_i(x, z_i; \\phi) + \\sum_{i,j=1}^n h_{i,j}(x, z_i, z_j; \\phi)$. In this experiment we compare two structured encoders with pairwise functions $h_{i,j}(x, z_i, z_j; \\phi) = \\alpha_{i,j}(x) z_i z_j$. We use a general pairwise structured encoder where the $\\arg \\max$ is recovered using the CPLEX algorithm [6]. 
We also apply a super-modular encoder, where α_{i,j}(x) ≥ 0 is enforced using the softplus transfer function, and the arg max is recovered using the maxflow algorithm [4]. In Figure 3 we compare the general and super-modular structured encoders with an unstructured encoder (α_{i,j}(x) = 0), all learned using direct optimization. One can see that structured encoders achieve better bounds, while the wall-clock time of learning the super-modular structured encoder using maxflow (α_{i,j}(x) ≥ 0) is comparable to that of learning unstructured encoders. One can also see that the general structured encoder, with arbitrary α_{i,j}(x), achieves a better test loss than the super-modular structured encoder. However, this comes at a computational price: the maxflow algorithm is orders of magnitude faster than CPLEX, and the structured encoder with CPLEX becomes better than maxflow only at epoch 85; see Figure 3.

Finally, we perform a set of semi-supervised experiments, for which we use a mixed continuous-discrete architecture [15, 12]. The architecture of the base encoder is (28 × 28) → FC(400) → ReLU → FC(200). The output of this layer is fed both to a discrete encoder h_d and a continuous encoder h_c. The discrete latent space is z_d ∈ {1, ..., 10} and its encoder h_d is 200 → FC(100) → ReLU → FC(10). The continuous latent space considers k = 10, c = 20, and its encoder h_c consists of 200 → FC(100) → ReLU → FC(66) → FC(40) to estimate the mean and variance of the 20-dimensional Gaussian random variables z_1, ..., z_10. The mixed discrete-continuous latent space consists of the matrix diag(z*_d) · z_c, i.e., if z*_d = i then this matrix is all zero, except for the i-th row. The parameters of z_c are shared across the rows z = 1, ..., k through the batch operation.

³For REBAR and RELAX we used the code in https://github.com/duvenaud/relax,
and for ARM we used the code in https://github.com/mingzhang-yin/ARM-gradient.

Figure 4: Comparing an unsupervised to a semi-supervised VAE on MNIST, for which the discrete latent variable has 10 values, i.e., z ∈ {1, ..., 10}. Weak supervision helps the VAE capture the class information and consequently improves the image generation process.

Figure 5: Learning attribute representations in CelebA, using our semi-supervised setting, by calibrating our arg max prediction using a loss function. The images are generated while setting their attributes to obtain the desired image (woman/man, with/without smile, with/without glasses). The i-th row consists of generations of the same continuous latent variable for all the attributes.

We conducted a quantitative experiment with weak supervision on MNIST and Fashion-MNIST with 50/100/300/600/1200 labeled examples out of the 50,000 training examples. For labeled examples, we set the perturbed label z*(ε) to be the true label. This is equivalent to using the indicator loss function over the space of correct predictions. A comparison of direct optimization with GSM appears in Table 2. Figure 4 shows the importance of weak supervision in a semantic latent space, as it allows the VAE to better capture the class information.

Supervision in generative models also helps to control discrete semantics within images. We learn to generate images using k = 8 discrete attributes of the CelebA dataset (cf. [22]) with our semi-supervised VAE. For this task, we use convolutional layers for both the encoder and the decoder, except for the last two layers of the continuous latent model, which are linear layers that share parameters over the 8 possible representations of the image.
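The mixed discrete-continuous latent space diag(z*_d) · z_c described above can be sketched in a few lines; this is a minimal illustration assuming a one-hot discrete choice and a k × c matrix of continuous samples (0-based indexing; the names are ours, not the paper's code):

```python
import numpy as np

def mixed_latent(z_d_star, z_c):
    # diag(z_d) @ z_c: all-zero matrix except the row selected by the discrete arg max.
    # z_d_star: index of the selected discrete value in {0, ..., k-1}.
    # z_c:      (k, c) matrix of continuous samples, one row per discrete value;
    #           the rows share parameters via the batch operation, as in the text.
    k = z_c.shape[0]
    z_d = np.zeros(k)
    z_d[z_d_star] = 1.0          # one-hot encoding of the arg max z_d*
    return np.diag(z_d) @ z_c    # keeps only row z_d_star

rng = np.random.default_rng(0)
k, c = 10, 20                    # as in the experiments: 10 classes, 20-dim Gaussians
z_c = rng.normal(size=(k, c))    # placeholder for the reparameterized Gaussian samples
latent = mixed_latent(3, z_c)
```

The decoder then consumes the resulting k × c matrix, so only the continuous representation of the selected discrete value contributes.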
In Figure 5, we show generated images with discrete semantics turned on and off (with/without glasses, with/without smile, woman/man).

6 Discussion and future work

In this work, we use the Gumbel-Max trick to reparameterize discrete VAEs using the arg max prediction and show how to propagate gradients through the non-differentiable arg max function. We show that this approach compares favorably to state-of-the-art methods, and we extend it to structured encoders and semi-supervised learning.

These results can be taken in a number of different directions. Our gradient estimator is biased in practice, while REINFORCE is an unbiased estimator. As a result, our method may benefit from the REBAR/RELAX framework, which directs biased gradients towards the unbiased gradient [36, 31]. There are also optimization-related questions that arise from our work, such as exploring the interplay between the ε parameter and the learning rate.

References

[1] Evgeny Andriyash, Arash Vahdat, and Bill Macready. Improved gradient-based optimization over discrete distributions. arXiv preprint arXiv:1810.00116, 2018.

[2] David M. Blei, Alp Kucukelbir, and Jon D. McAuliffe. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859–877, 2017.

[3] David M. Blei, Andrew Y. Ng, and Michael I. Jordan. Latent Dirichlet allocation. Journal of Machine Learning Research, 3(Jan):993–1022, 2003.

[4] Y. Boykov, O. Veksler, and R. Zabih. Fast approximate energy minimization via graph cuts. PAMI, 2001.

[5] Caio Corro and Ivan Titov. Differentiable perturb-and-parse: Semi-supervised parsing with a structured variational autoencoder. In International Conference on Learning Representations, 2019.

[6] IBM ILOG CPLEX. V12.1: User's manual for CPLEX. International Business Machines Corporation, 46(53):157, 2009.

[7] Josip Djolonga and Andreas Krause.
Differentiable learning of submodular models. In Advances in Neural Information Processing Systems, pages 1013–1023, 2017.

[8] S. M. Ali Eslami, Nicolas Heess, Theophane Weber, Yuval Tassa, David Szepesvari, Geoffrey E. Hinton, et al. Attend, infer, repeat: Fast scene understanding with generative models. In Advances in Neural Information Processing Systems, pages 3225–3233, 2016.

[9] Will Grathwohl, Dami Choi, Yuhuai Wu, Geoff Roeder, and David Duvenaud. Backpropagation through the void: Optimizing control variates for black-box gradient estimation. In International Conference on Learning Representations, 2018.

[10] Shixiang Gu, Sergey Levine, Ilya Sutskever, and Andriy Mnih. MuProp: Unbiased backpropagation for stochastic neural networks. arXiv preprint arXiv:1511.05176, 2015.

[11] Zhiting Hu, Zichao Yang, Xiaodan Liang, Ruslan Salakhutdinov, and Eric P. Xing. Toward controlled generation of text. In International Conference on Machine Learning, pages 1587–1596, 2017.

[12] Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with Gumbel-Softmax. arXiv preprint arXiv:1611.01144, 2016.

[13] Matthew J. Johnson, David K. Duvenaud, Alex Wiltschko, Ryan P. Adams, and Sandeep R. Datta. Composing graphical models with neural networks for structured representations and fast inference. In Advances in Neural Information Processing Systems, pages 2946–2954, 2016.

[14] J. Keshet, D. McAllester, and T. Hazan. PAC-Bayesian approach for minimization of phoneme error rate. In ICASSP, 2011.

[15] Diederik P. Kingma, Shakir Mohamed, Danilo Jimenez Rezende, and Max Welling. Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems, pages 3581–3589, 2014.

[16] Diederik P. Kingma and Max Welling. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013.

[17] Diederik P. Kingma and Max Welling. Auto-encoding variational Bayes. arXiv preprint
arXiv:1312.6114, 2013.

[18] S. Kotz and S. Nadarajah. Extreme value distributions: theory and applications. World Scientific Publishing Company, 2000.

[19] Matt J. Kusner, Brooks Paige, and José Miguel Hernández-Lobato. Grammar variational autoencoder. arXiv preprint arXiv:1703.01925, 2017.

[20] Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. Human-level concept learning through probabilistic program induction. Science, 350(6266):1332–1338, 2015.

[21] Dieterich Lawson, Chung-Cheng Chiu, George Tucker, Colin Raffel, Kevin Swersky, and Navdeep Jaitly. Learning hard alignments with variational inference. In 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pages 5799–5803. IEEE, 2018.

[22] Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of the IEEE International Conference on Computer Vision, pages 3730–3738, 2015.

[23] Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete distribution: A continuous relaxation of discrete random variables. In International Conference on Learning Representations, 2017.

[24] D. McAllester, T. Hazan, and J. Keshet. Direct loss minimization for structured prediction. Advances in Neural Information Processing Systems, 23:1594–1602, 2010.

[25] Gonzalo Mena, David Belanger, Scott Linderman, and Jasper Snoek. Learning latent permutations with Gumbel-Sinkhorn networks. In International Conference on Learning Representations, 2018.

[26] Andriy Mnih and Karol Gregor. Neural variational inference and learning in belief networks. arXiv preprint arXiv:1402.0030, 2014.

[27] Igor Mordatch and Pieter Abbeel. Emergence of grounded compositional language in multi-agent populations. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018.

[28] Lawrence R. Rabiner and Biing-Hwang Juang.
An introduction to hidden Markov models. IEEE ASSP Magazine, 3(1):4–16, 1986.

[29] Rajesh Ranganath, Sean Gerrish, and David Blei. Black box variational inference. In Artificial Intelligence and Statistics, pages 814–822, 2014.

[30] Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the 31st International Conference on Machine Learning, volume 32, pages 1278–1286, 2014.

[31] Geoffrey Roeder, Yuhuai Wu, and David K. Duvenaud. Sticking the landing: Simple, lower-variance gradient estimators for variational inference. In Advances in Neural Information Processing Systems, pages 6925–6934, 2017.

[32] Jason Tyler Rolfe. Discrete variational autoencoders. arXiv preprint arXiv:1609.02200, 2016.

[33] Ruslan Salakhutdinov and Iain Murray. On the quantitative analysis of deep belief networks. In Proceedings of the 25th International Conference on Machine Learning, pages 872–879. ACM, 2008.

[34] Dinghan Shen, Qinliang Su, Paidamoyo Chapfuwa, Wenlin Wang, Guoyin Wang, Lawrence Carin, and Ricardo Henao. NASH: Toward end-to-end neural architecture for generative semantic hashing. arXiv preprint arXiv:1805.05361, 2018.

[35] Y. Song, A. G. Schwing, R. Zemel, and R. Urtasun. Training deep neural networks via direct loss minimization. In Proc. ICML, 2016.

[36] George Tucker, Andriy Mnih, Chris J. Maddison, John Lawson, and Jascha Sohl-Dickstein. REBAR: Low-variance, unbiased gradient estimates for discrete latent variable models. In Advances in Neural Information Processing Systems, pages 2624–2633, 2017.

[37] Arash Vahdat, Evgeny Andriyash, and William Macready. DVAE#: Discrete variational autoencoders with relaxed Boltzmann priors.
In Advances in Neural Information Processing Systems, pages 1864–1874, 2018.

[38] Arash Vahdat, William G. Macready, Zhengbing Bian, and Amir Khoshaman. DVAE++: Discrete variational autoencoders with overlapping transformations. arXiv preprint arXiv:1802.04920, 2018.

[39] Ronald J. Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. In Reinforcement Learning, pages 5–32. Springer, 1992.

[40] Han Xiao, Kashif Rasul, and Roland Vollgraf. Fashion-MNIST: A novel image dataset for benchmarking machine learning algorithms, 2017.

[41] Mingzhang Yin and Mingyuan Zhou. ARM: Augment-REINFORCE-merge gradient for stochastic binary networks. In International Conference on Learning Representations, 2019.

[42] Dani Yogatama, Phil Blunsom, Chris Dyer, Edward Grefenstette, and Wang Ling. Learning to compose words into sentences with reinforcement learning. arXiv preprint arXiv:1611.09100, 2016.