{"title": "Learning Wake-Sleep Recurrent Attention Models", "book": "Advances in Neural Information Processing Systems", "page_first": 2593, "page_last": 2601, "abstract": "Despite their success, convolutional neural networks are computationally expensive because they must examine all image locations. Stochastic attention-based models have been shown to improve computational efficiency at test time, but they remain difficult to train because of intractable posterior inference and high variance in the stochastic gradient estimates. Borrowing techniques from the literature on training deep generative models, we present the Wake-Sleep Recurrent Attention Model, a method for training stochastic attention networks which improves posterior inference and which reduces the variability in the stochastic gradients. We show that our method can greatly speed up the training time for stochastic attention networks in the domains of image classification and caption generation.", "full_text": "Learning Wake-Sleep Recurrent Attention Models\n\nJimmy Ba\n\nUniversity of Toronto\n\nRoger Grosse\n\nUniversity of Toronto\n\njimmy@psi.toronto.edu\n\nrgrosse@cs.toronto.edu\n\nRuslan Salakhutdinov\nUniversity of Toronto\n\nBrendan Frey\n\nUniversity of Toronto\n\nrsalskhu@cs.toronto.edu\n\nfrey@psi.toronto.edu\n\nAbstract\n\nDespite their success, convolutional neural networks are computationally expen-\nsive because they must examine all image locations. Stochastic attention-based\nmodels have been shown to improve computational ef\ufb01ciency at test time, but they\nremain dif\ufb01cult to train because of intractable posterior inference and high vari-\nance in the stochastic gradient estimates. 
Borrowing techniques from the literature\non training deep generative models, we present the Wake-Sleep Recurrent Attention\nModel, a method for training stochastic attention networks which improves\nposterior inference and which reduces the variability in the stochastic gradients.\nWe show that our method can greatly speed up the training time for stochastic\nattention networks in the domains of image classification and caption generation.\n\n1 Introduction\n\nConvolutional neural networks, trained end-to-end, have been shown to substantially outperform\nprevious approaches to various supervised learning tasks in computer vision (e.g. [1]). Despite their\nwide success, convolutional nets are computationally expensive when processing high-resolution\ninput images, because they must examine all image locations at a fine scale. This has motivated\nrecent work on visual attention-based models [2, 3, 4], which reduce the number of parameters and\ncomputational operations by selecting informative regions of an image to focus on. In addition to\ncomputational speedups, attention-based models can also add a degree of interpretability, as one can\nunderstand what signals the algorithm is using by seeing where it is looking. One such approach\nwas recently used by [5] to automatically generate image captions and highlight which image region\nwas relevant to each word in the caption.\nThere are two general approaches to attention-based image understanding: hard and soft attention.\nSoft attention based models (e.g. [5]) obtain features from a weighted average of all image locations,\nwhere locations are weighted based on a model\u2019s saliency map. By contrast, a hard attention model\n(e.g. [2, 3]) chooses, typically stochastically, a series of discrete glimpse locations. Soft attention\nmodels are computationally expensive, as they have to examine every image location; we believe\nthat the computational gains of attention require a hard attention model. 
Unfortunately, this comes\nat a cost: while soft attention models can be trained with standard backpropagation [6, 5], this does\nnot work for hard attention models, whose glimpse selections are typically discrete.\nTraining stochastic hard attention models is difficult because the loss gradient involves intractable\nposterior expectations, and because the stochastic gradient estimates can have high variance. (The\nlatter problem was also observed by [7] in the context of memory networks.) In this work, we propose\nthe Wake-Sleep Recurrent Attention Model (WS-RAM), a method for training stochastic recurrent\nattention models which deals with the problems of intractable inference and high-variance gradients\nby taking advantage of several advances from the literature on training deep generative models:\ninference networks [8], the reweighted wake-sleep algorithm [9], and control variates [10, 11]. During\ntraining, the WS-RAM approximates posterior expectations using importance sampling, with a\nproposal distribution computed by an inference network. Unlike the prediction network, the inference\nnetwork has access to the object category label, which helps it choose better glimpse locations.\nAs the name suggests, we train both networks using the reweighted wake-sleep algorithm. In addition,\nwe reduce the variance of the stochastic gradient estimates using carefully chosen control\nvariates. In combination, these techniques constitute an improved training procedure for stochastic\nattention models.\nThe main contributions of our work are the following. First, we present a new learning algorithm for\nstochastic attention models and compare it with a training method based on variational inference [2].\nSecond, we develop a novel control variate technique for gradient estimation which further speeds\nup training. 
Finally, we demonstrate that our stochastic attention model can learn to (1) classify\ntranslated and scaled MNIST digits, and (2) generate image captions by attending to the relevant\nobjects in images and their corresponding scale. Our model achieves similar performance to the\nvariational method [2], but with much faster training times.\n\n2 Related work\n\nIn recent years, there has been a \ufb02urry of work on attention-based neural networks. Such models\nhave been applied successfully in image classi\ufb01cation [12, 4, 3, 2], object tracking [13, 3], machine\ntranslation [6], caption generation [5], and image generation [14, 15]. Attention has been shown\nboth to improve computational ef\ufb01ciency [2] and to yield insight into the network\u2019s behavior [5].\nOur work is most closely related to stochastic hard attention models (e.g. [2]). A major dif\ufb01culty\nof training such models is that computing the gradient requires taking expectations with respect\nto the posterior distribution over saccades, which is typically intractable. This dif\ufb01culty is closely\nrelated to the problem of posterior inference in training deep generative models such as sigmoid\nbelief networks [16]. Since our proposed method draws heavily from the literature on training deep\ngenerative models, we overview various approaches here.\nOne of the challenges of training a deep (or recurrent) generative model is that posterior inference\nis typically intractable due to the explaining away effect. One way to deal with intractable inference\nis to train a separate inference network whose job it is to predict the posterior distribution. A classic\nexample was the Helmholtz machine [8], where the inference network predicts a mean \ufb01eld approx-\nimation to the posterior.1 The generative and inference networks are trained with the wake-sleep\nalgorithm: in the wake phase, the generative model is updated to increase a variational lower bound\non the data likelihood. 
In the sleep phase, data are generated from the model, and the inference\nnetwork is trained to predict the latent variables used to generate the observations.\nThe wake-sleep approach was limited by the fact that the wake and sleep phases were minimizing\ntwo unrelated objective functions. More recently, various methods have been proposed which unify\nthe training of the generative and inference networks into a single objective function. Neural variational\ninference and learning (NVIL) [11] trains both networks to maximize a variational lower\nbound on the log-likelihood. Since the stochastic gradient estimates in NVIL are very noisy, the\nmethod of control variates is used to reduce the variance. In particular, one uses an algorithm from\nreinforcement learning called REINFORCE [17], which attempts to infer a reward baseline for each\ninstance. The choice of baseline is crucial to good performance; NVIL uses a separate neural network\nto compute the baseline, an approach also used by [3] in the context of attention networks.\nControl variates are discussed in more detail in Section 4.4.\n\nFootnote 1: In the literature, the inference network is often called a recognition network; we avoid this terminology to\nprevent confusion with the task of image classification.\n\nFigure 1: The Wake-Sleep Recurrent Attention Model.\n\nThe reweighted wake-sleep approach [9] is similar to traditional wake-sleep, but uses importance\nsampling in place of mean field inference to approximate the posterior. Reweighted wake-sleep is\ndescribed more formally in Section 4.3. Another method based on inference networks is variational\nautoencoders [18, 19], which exploit a clever reparameterization of the probabilistic model in order\nto improve the signal in the stochastic gradients. NVIL, reweighted wake-sleep, and variational\nautoencoders have all been shown to achieve considerably higher test log-likelihoods compared to\ntraditional wake-sleep. 
The term \u201cHelmholtz machine\u201d is often used loosely to refer to the entire\ncollection of techniques which simultaneously learn a generative network and an inference network.\n\n3 Wake-Sleep Recurrent Attention Model\n\nWe now describe our wake-sleep recurrent attention model (WS-RAM). Given an image I, the net-\nwork \ufb01rst chooses a sequence of glimpses a = (a1, . . . , aN ), and after each glimpse, receives an\nobservation xn computed by a mapping g(an,I). This mapping might, for instance, extract an\nimage patch at a given scale. The \ufb01rst glimpse is based on a low-resolution version of the input,\nwhile subsequent glimpses are chosen based on information acquired from previous glimpses. The\nglimpses are chosen stochastically according to a distribution p(an | a1:n\u22121,I, \u03b8), where \u03b8 denotes\nthe parameters of the network. This is in contrast with soft attention models, which deterministi-\ncally allocate attention across all image locations. After the last glimpse, the network predicts a\ndistribution p(y | a,I, \u03b8) over the target y (for instance, the caption or image category).\nAs shown in Figure 1, the core of the attention network is a two-layer recurrent network, which we\nterm the \u201cprediction network\u201d, where the output at each time step is an action (saccade) which is\nused to compute the input at the next time step. A low-resolution version of the input image is fed\nto the network at the \ufb01rst time step, and the network predicts the class label at the \ufb01nal time step.\nImportantly, the low-resolution input is fed to the second layer, while the class label prediction is\nmade by the \ufb01rst layer, preventing information from propagating directly from the low-resolution\nimage to the output. 
This prevents local optima where the network learns to predict y directly from\nthe low-resolution input, disregarding attention completely.\nOn top of the prediction network is an inference network, which receives both the class label and\nthe attention network\u2019s top layer representation as inputs. It tries to predict the posterior distribution\nq(an+1 | y, a1:n,I, \u03b7), parameterized by \u03b7, over the next saccade, conditioned on the image\ncategory being correctly predicted. Its job is to guide the posterior sampler during training time,\nthereby acting as a \u201cteacher\u201d for the attention network. The inference network is described further\nin Section 4.3.\nOne of the benefits of stochastic attention models is that the mapping g can be localized to a small\nimage region or coarse granularity, which means it can potentially be made very efficient. Furthermore,\ng need not be differentiable, which allows for operations (such as choosing a scale) which\nwould be difficult to implement in a soft attention network. The cost of this flexibility is that standard\nbackpropagation cannot be applied, so instead we use novel algorithms described in the next\nsection.\n\n4 Learning\n\nIn this work, we assume that we have a dataset with labels y for the supervised prediction task\n(e.g. object category). In contrast to the supervised saliency prediction task (e.g. [20, 21]), there are\nno labels for where to attend. 
Instead, we learn an attention policy based on the idea that the best\nlocations to attend to are the ones which most robustly lead the model to predict the correct category.\nIn particular, we aim to maximize the probability of the class label (or equivalently, minimize the\ncross-entropy) by marginalizing over the actions at each glimpse:\n\n$$\\ell = \\log p(y \\mid I, \\theta) = \\log \\sum_{a} p(a \\mid I, \\theta)\\, p(y \\mid a, I, \\theta). \\tag{1}$$\n\nWe train the attention model by maximizing a lower bound on \u2113. In Section 4.1, we first describe\na previous approach which maximized a variational lower bound. We then introduce our proposed\nmethod which directly estimates the gradients of \u2113. As shown in Section 4.2, our method can be\nseen as maximizing a tighter lower bound on \u2113.\n\n4.1 Variational lower bound\n\nWe first outline the approach of [2], who trained the model to maximize a variational lower bound\non \u2113. Let q(a| y,I) be an approximating distribution. The lower bound on \u2113 is then given by:\n\n$$\\ell = \\log \\sum_{a} p(a \\mid I, \\theta)\\, p(y \\mid a, I, \\theta) \\geq \\sum_{a} q(a \\mid y, I) \\log p(y, a \\mid I, \\theta) + H[q] = F. \\tag{2}$$\n\nIn the case where q(a| y,I) = p(a|I, \u03b8) is the prior, as considered by [2], this reduces to\n\n$$F = \\sum_{a} p(a \\mid I, \\theta) \\log p(y \\mid a, I, \\theta). \\tag{3}$$\n\nThe learning rules can be derived by taking derivatives of Eqn. 
3 with respect to the model parameters:\n\n$$\\frac{\\partial F}{\\partial \\theta} = \\sum_{a} p(a \\mid I, \\theta) \\left[ \\frac{\\partial \\log p(y \\mid a, I, \\theta)}{\\partial \\theta} + \\log p(y \\mid a, I, \\theta) \\frac{\\partial \\log p(a \\mid I, \\theta)}{\\partial \\theta} \\right]. \\tag{4}$$\n\nThe summation can be approximated using M Monte Carlo samples $\\tilde{a}^m$ from p(a|I, \u03b8):\n\n$$\\frac{\\partial F}{\\partial \\theta} \\approx \\frac{1}{M} \\sum_{m=1}^{M} \\left[ \\frac{\\partial \\log p(y \\mid \\tilde{a}^m, I, \\theta)}{\\partial \\theta} + \\log p(y \\mid \\tilde{a}^m, I, \\theta) \\frac{\\partial \\log p(\\tilde{a}^m \\mid I, \\theta)}{\\partial \\theta} \\right]. \\tag{5}$$\n\nThe partial derivative terms can each be computed using standard backpropagation. This suggests a\nsimple gradient-based training algorithm: For each image, one first computes the samples $\\tilde{a}^m$ from\nthe prior p(a|I, \u03b8), and then updates the parameters according to Eqn. 5. As observed by [2], one\nmust carefully use control variates in order to make this technique practical; we defer discussion of\ncontrol variates to Section 4.4.\n\n4.2 An improved lower bound on the log-likelihood\n\nThe variational method described above has some counterintuitive properties early in training. First,\nbecause it averages the log-likelihood over actions, it greatly amplifies the differences in probabilities\nassigned to the true category by different bad glances. For instance, a glimpse sequence which\nleads to 0.01 probability assigned to the correct class is considered much worse than one which leads\nto 0.02 probability under the variational objective, even though in practice they may be equally bad\nsince they have both missed the relevant information. A second odd behavior is that all glimpse\nsequences are weighted equally in the log-likelihood gradient. It would be better if the training\nprocedure focused its effort on using those glances which contain the relevant information. 
Both of\nthese effects contribute noise in the training procedure, especially in the early stages of training.\nInstead, we adopt an approach based on the wake-p step of reweighted wake-sleep [9], where we\nattempt to maximize the marginal log-probability \u2113 directly. We differentiate the marginal log-likelihood\nobjective in Eqn. 1 with respect to the model parameters:\n\n$$\\frac{\\partial \\ell}{\\partial \\theta} = \\frac{1}{p(y \\mid I, \\theta)} \\sum_{a} p(a \\mid I, \\theta)\\, p(y \\mid a, I, \\theta) \\left[ \\frac{\\partial \\log p(y \\mid a, I, \\theta)}{\\partial \\theta} + \\frac{\\partial \\log p(a \\mid I, \\theta)}{\\partial \\theta} \\right]. \\tag{6}$$\n\nThe summation and normalizing constant are both intractable to evaluate, so we estimate them using\nimportance sampling. We must define a proposal distribution q(a| y,I), which ideally should be\nclose to the posterior p(a| y,I, \u03b8). One reasonable choice is the prior p(a|I, \u03b8), but another choice\nis described in Section 4.3. Normalized importance sampling gives a biased but consistent estimator\nof the gradient of \u2113. Given samples $\\tilde{a}^1, \\ldots, \\tilde{a}^M$ from q(a| y,I), the (unnormalized) importance\nweights are computed as:\n\n$$\\tilde{w}^m = \\frac{p(\\tilde{a}^m \\mid I, \\theta)\\, p(y \\mid \\tilde{a}^m, I, \\theta)}{q(\\tilde{a}^m \\mid y, I)}. \\tag{7}$$\n\nThe Monte Carlo estimate of the gradient is given by:\n\n$$\\frac{\\partial \\ell}{\\partial \\theta} \\approx \\sum_{m=1}^{M} w^m \\left[ \\frac{\\partial \\log p(y \\mid \\tilde{a}^m, I, \\theta)}{\\partial \\theta} + \\frac{\\partial \\log p(\\tilde{a}^m \\mid I, \\theta)}{\\partial \\theta} \\right], \\tag{8}$$\n\nwhere $w^m = \\tilde{w}^m / \\sum_{i=1}^{M} \\tilde{w}^i$ are the normalized importance weights. 
When q is chosen to be the\nprior, this approach is equivalent to the method of [22] for learning generative feed-forward networks.\nOur importance sampling based estimator can also be viewed as the gradient ascent update on the\nobjective function $E\\left[\\log \\frac{1}{M} \\sum_{m=1}^{M} \\tilde{w}^m\\right]$. Combining Jensen\u2019s inequality with the unbiasedness of\nthe $\\tilde{w}^m$ shows that this is a lower bound on the log-likelihood:\n\n$$E\\left[\\log \\frac{1}{M} \\sum_{m=1}^{M} \\tilde{w}^m\\right] \\leq \\log E\\left[\\frac{1}{M} \\sum_{m=1}^{M} \\tilde{w}^m\\right] = \\log E[\\tilde{w}^m] = \\ell. \\tag{9}$$\n\nWe relate this to the previous section by noting that $F = E[\\log \\tilde{w}^m]$. Another application of Jensen\u2019s\ninequality shows that our proposed bound is at least as accurate as F:\n\n$$F = E[\\log \\tilde{w}^m] = E\\left[\\frac{1}{M} \\sum_{m=1}^{M} \\log \\tilde{w}^m\\right] \\leq E\\left[\\log \\frac{1}{M} \\sum_{m=1}^{M} \\tilde{w}^m\\right]. \\tag{10}$$\n\nBurda et al. [23] further analyzed a closely related importance sampling based estimator in the context\nof generative models, bounding the mean absolute deviation and showing that the bias decreases\nmonotonically with the number of samples.\n\n4.3 Training an inference network\n\nLate in training, once the attention model has learned an effective policy, the prior distribution\np(a|I, \u03b8) is a reasonable choice for the proposal distribution q(a| y,I), as it puts significant probability\nmass on good actions. But early in training, the model may have only a small probability of\nchoosing a good set of glimpses, and the prior may have little overlap with the posterior. To deal\nwith this, we train an inference network to predict, given the observations as well as the class label,\nwhere the network should look to correctly predict that class (see Figure 1). 
With this additional\ninformation, the inference network can act as a \u201cteacher\u201d for the attention policy.\nThe inference network predicts a sequence of glimpses stochastically:\n\n$$q(a \\mid y, I, \\eta) = \\prod_{n=1}^{N} q(a_n \\mid y, I, \\eta, a_{1:n-1}). \\tag{11}$$\n\nThis distribution is analogous to the prior, except that each decision also takes into account the class\nlabel y. We denote the parameters for the inference network as \u03b7. During training, the prediction network\nis learnt by following the gradient of the estimator in Eqn. 8 with samples $\\tilde{a}^m \\sim q(a \\mid y, I, \\eta)$\ndrawn from the inference network output.\nOur training procedure for the inference network parallels the wake-q step of reweighted wake-sleep [9].\nIntuitively, the inference network is most useful if it puts large probability density over\nlocations in an image that are most informative for predicting class labels. We therefore train the\ninference weights \u03b7 to minimize the Kullback-Leibler divergence between the recognition model\nprediction q(a| y,I, \u03b7) and posterior distribution from the attention model p(a| y,I, \u03b8):\n\n$$\\min_{\\eta} D_{\\mathrm{KL}}(p \\,\\|\\, q) = \\min_{\\eta} \\; -\\sum_{a} p(a \\mid y, I, \\theta) \\log q(a \\mid y, I, \\eta). \\tag{12}$$\n\nThe gradient update for the recognition weights can be obtained by taking the derivatives of Eqn. 12\nwith respect to the recognition weights \u03b7:\n\n$$\\frac{\\partial D_{\\mathrm{KL}}(p \\,\\|\\, q)}{\\partial \\eta} = E_{p(a \\mid y, I, \\theta)}\\left[ \\frac{\\partial \\log q(a \\mid y, I, \\eta)}{\\partial \\eta} \\right]. \\tag{13}$$\n\nSince the posterior expectation is intractable, we estimate it with importance sampling. In fact, we\nreuse the importance weights computed for the prediction network update (see Eqn. 
7) to obtain the\nfollowing gradient estimate for the recognition network:\n\n$$\\frac{\\partial D_{\\mathrm{KL}}(p \\,\\|\\, q)}{\\partial \\eta} \\approx \\sum_{m=1}^{M} w^m \\frac{\\partial \\log q(\\tilde{a}^m \\mid y, I, \\eta)}{\\partial \\eta}. \\tag{14}$$\n\n4.4 Control variates\n\nThe speed of convergence of gradient ascent with the gradients defined in Eqns. 8 and 14 suffers\nfrom high variance of the stochastic gradient estimates. Past work using similar gradient updates has\nfound significant benefit from the use of control variates, or reward baselines, to reduce the variance\n[17, 10, 3, 11, 2]. Choosing effective control variates for the stochastic gradient estimators amounts\nto finding a function that is highly correlated with the gradient vectors, and whose expectation is\nknown or tractable to compute [10, 24]. Unfortunately, a good choice of control variate is highly\nmodel-dependent.\nWe first note that:\n\n$$E_{q(a \\mid y, I, \\eta)}\\left[ \\frac{p(a \\mid I, \\theta)}{q(a \\mid y, I, \\eta)} \\frac{\\partial \\log p(a \\mid I, \\theta)}{\\partial \\theta} \\right] = 0, \\qquad E_{q(a \\mid y, I, \\eta)}\\left[ \\frac{\\partial \\log q(a \\mid y, I, \\eta)}{\\partial \\eta} \\right] = 0. \\tag{15}$$\n\nThe terms inside the expectations are very similar to the gradients in Eqns. 8 and 14, suggesting\nthat stochastic estimates of these expectations would make good control variates. To increase the\ncorrelation between the gradients and the control variates, we reuse the same set of samples and\nimportance weights for the gradients and control variates. 
Using these control variates in the\ngradient estimates for the prediction and recognition networks, we obtain:\n\n$$\\frac{\\partial \\log p(a \\mid I, \\theta)}{\\partial \\theta} \\approx \\sum_{m=1}^{M} \\left( w^m - \\frac{\\frac{p(\\tilde{a}^m \\mid I, \\theta)}{q(\\tilde{a}^m \\mid y, I, \\eta)}}{\\sum_{i=1}^{M} \\frac{p(\\tilde{a}^i \\mid I, \\theta)}{q(\\tilde{a}^i \\mid y, I, \\eta)}} \\right) \\frac{\\partial \\log p(\\tilde{a}^m \\mid I, \\theta)}{\\partial \\theta}, \\tag{16}$$\n\n$$\\frac{\\partial D_{\\mathrm{KL}}(p \\,\\|\\, q)}{\\partial \\eta} \\approx \\sum_{m=1}^{M} \\left( w^m - \\frac{1}{M} \\right) \\frac{\\partial \\log q(\\tilde{a}^m \\mid y, I, \\eta)}{\\partial \\eta}. \\tag{17}$$\n\nOur use of control variates does not bias the gradient estimates (beyond the bias which is present\ndue to importance sampling). However, as we show in the experiments, the resulting estimates have\nmuch lower variance than those of Eqns. 8 and 14.\nFollowing the analogy with reinforcement learning highlighted by [11], these control variates can\nalso be viewed as reward baselines:\n\n$$b_p = \\frac{\\frac{p(a \\mid I, \\theta)}{q(a \\mid y, I, \\eta)} \\, E_{q(a \\mid y, I, \\eta)}[p(y \\mid a, I, \\theta)]}{M \\cdot E_{q(a \\mid y, I, \\eta)}\\left[ \\frac{p(a \\mid I, \\theta)}{q(a \\mid y, I, \\eta)} \\right] E_{q(a \\mid y, I, \\eta)}[p(y \\mid a, I, \\theta)]} \\approx \\frac{\\frac{p(\\tilde{a}^m \\mid I, \\theta)}{q(\\tilde{a}^m \\mid y, I, \\eta)}}{\\sum_{i=1}^{M} \\frac{p(\\tilde{a}^i \\mid I, \\theta)}{q(\\tilde{a}^i \\mid y, I, \\eta)}}, \\tag{18}$$\n\n$$b_q = \\frac{E_{p(a \\mid I, \\theta)}[p(y \\mid a, I, \\theta)]}{M \\cdot E_{p(a \\mid I, \\theta)}[p(y \\mid a, I, \\theta)]} = \\frac{1}{M}, \\tag{19}$$\n\nwhere M is the number of samples drawn for proposal q.\n\nFigure 2: Left: Training error as a function of the number of updates. Middle: variance of the gradient\nestimates. Right: effective sample size (max = 5). Horizontal axis: thousands of updates. 
VAR: variational\nbaseline; WS-RAM: our proposed method; +q: uses the inference networks for the proposal distribution; +c:\nuses control variates.\n\n4.5 Encouraging exploration\n\nSimilarly to other methods based on reinforcement learning, stochastic attention networks face the\nproblem of encouraging the method to explore different actions. Since the gradient in Eqn. 8 only\nrewards or punishes glimpse sequences which are actually performed, any part of the space which\nis never visited will receive no reward signal. [2] introduced several heuristics to encourage ex-\nploration, including: (1) raising the temperature of the proposal distribution, (2) regularizing the\nattention policy to encourage viewing all image locations, and (3) adding a regularization term to\nencourage high entropy in the action distribution. We have implemented all three heuristics for\nthe WS-RAM and for the baselines. While these heuristics are important for good performance of\nthe baselines, we found that they made little difference to the WS-RAM because the basic method\nalready explores adequately.\n\n5 Experimental results\n\nTo measure the effectiveness of the proposed WS-RAM method, we \ufb01rst investigated a toy classi\ufb01-\ncation task involving a variant of the MNIST handwritten digits dataset [25] where transformations\nwere applied to the images. We then evaluated the proposed method on a substantially more dif\ufb01cult\nimage caption generation task using the Flickr8k [26] dataset.\n\n5.1 Translated scaled MNIST\n\nWe generated a dataset of randomly translated and scaled handwritten digits from the MNIST\ndataset [25]. Each digit was placed in a 100x100 black background image at a random location\nand scale. The task was to identify the digit class. The attention models were allowed four glimpses\nbefore making a classi\ufb01cation prediction. 
The goal of this experiment was to evaluate the effectiveness\nof our proposed WS-RAM model compared with the variational approach of [2].\nFor both the WS-RAM and the baseline, the architecture was a stochastic attention model which\nused ReLU units in all recurrent layers. The actions included both continuous and discrete latent\nvariables, corresponding to glimpse location and scale, respectively. The distribution over actions\nwas represented as a Gaussian random variable for the location and an independent multinomial\nrandom variable for the scale. All networks were trained using Adam [27], with the learning rate set\nto the highest value that allowed the model to successfully converge to a sensible attention policy.\nThe classification performance results are shown in Table 1. In Figure 2, the WS-RAM is compared\nwith the variational baseline, each using the same number of samples (in order to make computation\ntime roughly equivalent). We also show comparisons against ablated versions of the WS-RAM\nwhere the control variates and inference network were removed. When the inference network was\nremoved, the prior p(a|I, \u03b8) was used for the proposal distribution.\nIn addition to the classification results, we measured the effective sample size (ESS) of our method\nwith and without control variates and the inference network. ESS is a standard metric for evaluating\nimportance samplers, and is defined as $1 / \\sum_{m} (w^m)^2$, where $w^m$ denotes the normalized importance\nweights. Results are shown in Figure 2. 
Using the inference network reduced the variances in\ngradient estimation, although this improvement did not reflect itself in the ESS. Control variates\nimproved both metrics.\nIn Section 4.5, we described heuristics which encourage the models to explore the action space. Figure 3\ncompares the training with and without these heuristics. Without the heuristics, the variational\nmethod quickly fell into a local minimum where the model predicted only one glimpse scale over\nall images; the exploration heuristics fixed this problem. By contrast, the WS-RAM did not appear\nto have this problem, so the heuristics were not necessary.\n\nTest err.   VAR     WS-RAM   WS-RAM + q\nno c.v.     3.11%   4.23%    2.59%\n+c.v.       1.81%   1.85%    1.62%\n\nTable 1: Classification error rate comparison for the\nattention models trained using different algorithms on\ntranslated scaled MNIST. The numbers are reported\nafter 10 million updates using 5 samples.\n\nFigure 3: The effect of the exploration heuristics on\nthe variational baseline and the WS-RAM.\n\nMethod        BLEU1   BLEU2   BLEU3   BLEU4\nVAR           62.3    41.6    26.9    17.2\nWS-RAM+Qnet   61.1    40.4    26.9    17.8\n\nTable 2: BLEU score performance on the Flickr8K\ndataset for our WS-RAM and the variational method.\n\nFigure 4: Training negative log-likelihood on\nFlickr8K for the first 10,000 updates. See Figure 2\nfor the labels.\n\n5.2 Generating captions using multi-scale attention\n\nWe also applied the WS-RAM method to learn a stochastic attention model similar to [5] for generating\nimage captions. We report results on the widely-used Flickr8k dataset. 
The training/validation/test\nsplit followed the same protocol as used in previous work [28].\nThe goal of this experiment was to examine the improvement of the WS-RAM over the variational\nmethod for learning with realistic images. Similarly to [5], we first ran a convolutional network,\nand the attention network then determined which part of the convolutional net representation to\nattend to. The attention network predicted both which layer to attend to and a location within the\nlayer, in contrast with [5], where the scale was held fixed. Because a convolutional net shrinks\nthe representation with max-pooling, choosing a layer is analogous to choosing a scale. At each\nglimpse, the inference network was given the immediately preceding word in the target sentences.\nWe compare the BLEU scores of our WS-RAM and the variational method in Table 2. Figure 4\nshows training curves for both models. We observe that WS-RAM obtained similar performance to\nthe variational method, but trained more efficiently.\n\n6 Conclusions\n\nIn this paper, we introduced the Wake-Sleep Recurrent Attention Model (WS-RAM), an efficient\nmethod for training stochastic attention models. This method improves upon prior work by using the\nreweighted wake-sleep algorithm [9] to approximate expectations from the posterior over glimpses.\nWe also introduced control variates to reduce the variability of the stochastic gradients. 
Our method\nreduces the variance in the gradient estimates and accelerates training of attention networks for both\ninvariant handwritten digit recognition and image caption generation.\n\nAcknowledgments\nThis work was supported by the Fields Institute, Samsung, ONR Grant N00014-14-1-0232 and the\nhardware donation of NVIDIA Corporation.\n\nReferences\n[1] A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural\n\nnetworks. In Neural Information Processing Systems, 2012.\n\n[2] J. Ba, V. Mnih, and K. Kavukcuoglu. Multiple object recognition with visual attention. In International\n\nConference on Learning Representations, 2015.\n\n[3] V. Mnih, N. Heess, A. Graves, and K. Kavukcuoglu. Recurrent models of visual attention. In Neural\n\nInformation Processing Systems, 2014.\n\n[4] Y. Tang, N. Srivastava, and R. Salakhutdinov. Learning generative models with visual attention. In Neural\n\nInformation Processing Systems, 2014.\n\n[5] K. Xu, J. Ba, R. Kiros, K. Cho, A. Courville, R. Salakhutdinov, R. S. Zemel, and Y. Bengio. Show, attend,\nand tell: neural image caption generation with visual attention. In International Conference on Machine\nLearning, 2015.\n\n[6] D. Bahdanau, K. Cho, and Y. Bengio. Neural machine translation by jointly learning to align and translate.\n\nIn International Conference on Learning Representations, 2015.\n\n[7] W. Zaremba and I. Sutskever. Reinforcement learning neural Turing machines. arXiv:1505.00521, 2015.\n[8] P. Dayan, G. E. Hinton, R. M. Neal, and R. S. Zemel. The Helmholtz machine. Neural Computation,\n\n7:889\u2013904, 1995.\n\n[9] J. Bornschein and Y. Bengio. Reweighted wake-sleep. arXiv:1406.2751, 2014.\n[10] J. Paisley, D. M. Blei, and M. I. Jordan. 
Variational Bayesian inference with stochastic search. In International Conference on Machine Learning, 2012.\n[11] A. Mnih and K. Gregor. Neural variational inference and learning in belief networks. In International Conference on Machine Learning, 2014.\n[12] H. Larochelle and G. E. Hinton. Learning to combine foveal glimpses with a third-order Boltzmann machine. In Neural Information Processing Systems, 2010.\n[13] M. Denil, L. Bazzani, H. Larochelle, and N. de Freitas. Learning where to attend with deep architectures for image tracking. Neural Computation, 24(8):2151\u201384, April 2012.\n[14] A. Graves. Generating sequences with recurrent neural networks. arXiv:1308.0850, 2014.\n[15] K. Gregor, I. Danihelka, A. Graves, and D. Wierstra. DRAW: a recurrent neural network for image generation. arXiv:1502.04623, 2015.\n[16] R. M. Neal. Connectionist learning of belief networks. Artificial Intelligence, 1992.\n[17] R. J. Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8:229\u2013256, 1992.\n[18] D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In International Conference on Learning Representations, 2014.\n[19] D. J. Rezende, S. Mohamed, and D. Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, 2014.\n[20] L. Itti, C. Koch, and E. Niebur. A model of saliency-based visual attention for rapid scene analysis. IEEE Transactions on Pattern Analysis and Machine Intelligence, 20(11):1254\u201359, November 1998.\n[21] T. Judd, K. Ehinger, F. Durand, and A. Torralba. Learning to predict where humans look. In International Conference on Computer Vision, 2009.\n[22] Y. Tang and R. Salakhutdinov. Learning stochastic feedforward neural networks. In Neural Information Processing Systems, 2013.\n[23] Y. Burda, R. Grosse, and R. Salakhutdinov. 
Importance weighted autoencoders. arXiv:1509.00519, 2015.\n[24] L. Weaver and N. Tao. The optimal reward baseline for gradient-based reinforcement learning. In Proceedings of the Seventeenth Conference on Uncertainty in Artificial Intelligence, pages 538\u2013545. Morgan Kaufmann Publishers Inc., 2001.\n[25] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278\u20132324, 1998.\n[26] M. Hodosh, P. Young, and J. Hockenmaier. Framing image description as a ranking task: Data, models and evaluation metrics. Journal of Artificial Intelligence Research, pages 853\u2013899, 2013.\n[27] D. Kingma and J. L. Ba. Adam: a method for stochastic optimization. arXiv:1412.6980, 2014.\n[28] A. Karpathy and L. Fei-Fei. Deep visual-semantic alignments for generating image descriptions. arXiv:1412.2306, 2014.\n", "award": [], "sourceid": 1520, "authors": [{"given_name": "Jimmy", "family_name": "Ba", "institution": "University of Toronto"}, {"given_name": "Russ", "family_name": "Salakhutdinov", "institution": "University of Toronto"}, {"given_name": "Roger", "family_name": "Grosse", "institution": "University of Toronto"}, {"given_name": "Brendan", "family_name": "Frey", "institution": "U. Toronto"}]}