{"title": "Training Deep Models Faster with Robust, Approximate Importance Sampling", "book": "Advances in Neural Information Processing Systems", "page_first": 7265, "page_last": 7275, "abstract": "In theory, importance sampling speeds up stochastic gradient algorithms for supervised learning by prioritizing training examples. In practice, the cost of computing importances greatly limits the impact of importance sampling. We propose a robust, approximate importance sampling procedure (RAIS) for stochastic gradient de- scent. By approximating the ideal sampling distribution using robust optimization, RAIS provides much of the benefit of exact importance sampling with drastically reduced overhead. Empirically, we find RAIS-SGD and standard SGD follow similar learning curves, but RAIS moves faster through these paths, achieving speed-ups of at least 20% and sometimes much more.", "full_text": "Training Deep Models Faster with\n\nRobust, Approximate Importance Sampling\n\nTyler B. Johnson\n\nUniversity of Washington, Seattle\n\ntbjohns@washington.edu\n\nCarlos Guestrin\n\nUniversity of Washington, Seattle\nguestrin@cs.washington.edu\n\nAbstract\n\nIn theory, importance sampling speeds up stochastic gradient algorithms for super-\nvised learning by prioritizing training examples. In practice, the cost of computing\nimportances greatly limits the impact of importance sampling. We propose a robust,\napproximate importance sampling procedure (RAIS) for stochastic gradient de-\nscent. By approximating the ideal sampling distribution using robust optimization,\nRAIS provides much of the bene\ufb01t of exact importance sampling with drastically\nreduced overhead. Empirically, we \ufb01nd RAIS-SGD and standard SGD follow\nsimilar learning curves, but RAIS moves faster through these paths, achieving\nspeed-ups of at least 20% and sometimes much more.\n\n1\n\nIntroduction\n\nDeep learning models perform excellently on many tasks. 
Training such models is resource-intensive, however, as stochastic gradient descent algorithms can require days or weeks to train effectively. After a short period of training, models usually perform well on some, or even most, training examples. As training continues, frequently reconsidering such "easy" examples slows further improvement.

Importance sampling prioritizes training examples for SGD in a principled way. The technique suggests sampling example i with probability proportional to the norm of loss term i's gradient. This distribution both prioritizes challenging examples and minimizes the stochastic gradient's variance.

SGD with optimal importance sampling is impractical, however, since computing the sampling distribution requires excessive time. [1] and [2] analyze importance sampling for SGD and convex problems; practical versions of these algorithms sample proportional to fixed constants. For deep models, other algorithms attempt closer approximations of gradient norms [3, 4, 5]. But these algorithms are not inherently robust. Without carefully chosen hyperparameters or additional forward passes, these algorithms do not converge, let alone speed up training.

We propose RAIS, an importance sampling procedure for SGD with several appealing qualities. First, RAIS determines each sampling distribution by solving a robust optimization problem. As a result, each sampling distribution is minimax optimal with respect to an uncertainty set. Since RAIS trains this uncertainty set in an adaptive manner, RAIS is not sensitive to hyperparameters.

In addition, RAIS maximizes the benefit of importance sampling by adaptively increasing SGD's learning rate, an effective yet novel idea to our knowledge. This improvement invites the idea that one RAIS-SGD iteration equates to more than one iteration of conventional SGD.
Interestingly, when plotted in terms of "epochs equivalent," the learning curves of the algorithms align closely.

RAIS applies to any model that is trainable with SGD. RAIS also combines nicely with standard "tricks," including data augmentation, dropout, and batch normalization. We show this empirically in §6. In this section, we also demonstrate that RAIS consistently improves training times. To provide context for the paper, we include qualitative results from these experiments in Figure 1.

32nd Conference on Neural Information Processing Systems (NeurIPS 2018), Montréal, Canada.

[Figure 1: image grids of nonpriority and priority training examples with their labels and the model's competing predictions]

Figure 1: Nonpriority and priority training examples for image classification. Left: Examples that RAIS samples infrequently during training. Right: Examples that RAIS prioritizes. Bold denotes the image's label. Parentheses denote a different class that the model considers likely during training. Datasets are CIFAR-10 (top), CIFAR-100 (middle), and rotated MNIST (bottom).

2 Problem formulation

Given loss functions $f_1, f_2, \ldots, f_n$ and a tuning parameter $\lambda \in \mathbb{R}_{\geq 0}$, our task is to efficiently solve

(P)    $\min_{w \in \mathbb{R}^d} F(w)$, where $F(w) = \frac{1}{n} \sum_{i=1}^n f_i(w) + \frac{\lambda}{2} \|w\|^2$.

A standard algorithm for solving (P) is stochastic gradient descent. Let $w^{(t)}$ denote the optimization variables when iteration t begins.
SGD updates these weights via

(1)    $w^{(t+1)} \leftarrow w^{(t)} - \eta^{(t)} g^{(t)}$.

Above, $\eta^{(t)} \in \mathbb{R}_{>0}$ is a learning rate, specified by a schedule: $\eta^{(t)} = \mathrm{lr\_sched}(t)$. The vector $g^{(t)}$ is an unbiased stochastic approximation of the gradient $\nabla F(w^{(t)})$. SGD computes $g^{(t)}$ by sampling a minibatch of $|\mathcal{M}|$ indices from $\{1, 2, \ldots, n\}$ uniformly at random (or approximately so). Denoting this minibatch by $\mathcal{M}^{(t)}$, SGD defines the stochastic gradient as

(2)    $g^{(t)} = \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}^{(t)}} \nabla f_i(w^{(t)}) + \lambda w^{(t)}$.

In this work, we assume an objective function, learning rate schedule, and minibatch size, and we propose a modified algorithm called RAIS-SGD. RAIS prioritizes examples by sampling minibatches non-uniformly, allowing us to train models using fewer iterations and less time.

3 SGD with oracle importance sampling

We now introduce an SGD algorithm with "oracle" importance sampling, which prioritizes examples using exact knowledge of importance values. RAIS-SGD is an approximation of this algorithm. Given $w^{(t)}$, let us define the expected training progress attributable to iteration t as

(3)    $E^{(t)} = \|w^{(t)} - w^\star\|^2 - \mathbb{E}\big[\|w^{(t+1)} - w^\star\|^2\big] = 2\eta^{(t)} \langle \nabla F(w^{(t)}), w^{(t)} - w^\star \rangle - [\eta^{(t)}]^2 \, \mathbb{E}\big[\|g^{(t)}\|^2\big]$.

Here $w^\star$ denotes the solution to (P), and the expectation is with respect to minibatch $\mathcal{M}^{(t)}$. The equality follows from plugging in (1) and applying the fact that $g^{(t)}$ is unbiased.

We refer to our oracle algorithm as O-SGD, and we refer to SGD with uniform sampling as U-SGD. At a high level, O-SGD makes two changes to U-SGD in order to increase $E^{(t)}$. First, O-SGD samples training examples non-uniformly in a way that minimizes the variance of the stochastic gradient. This first change is not new; see [1], for example. Second, to compensate for the first improvement, O-SGD adaptively increases the learning rate.
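The baseline update (1)-(2) can be made concrete with a short sketch. Everything here is illustrative: the function name and the toy quadratic losses in the check below are ours, not the paper's.

```python
import numpy as np

def usgd_step(w, grad_fns, lam, lr, batch_size, rng):
    """One U-SGD iteration per (1)-(2): draw a uniform minibatch,
    average the per-example gradients, add the L2 term, and step."""
    n = len(grad_fns)
    batch = rng.choice(n, size=batch_size)       # uniform sampling
    g = np.mean([grad_fns[i](w) for i in batch], axis=0) + lam * w
    return w - lr * g
```

For instance, with losses $f_i(w) = \frac{1}{2}\|w - x_i\|^2$ and no L2 penalty, repeated steps drive w toward the mean of the $x_i$.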
This second change, which is novel to our knowledge, can be essential for obtaining large speed-ups.

3.1 Maximizing progress with oracle importance sampling

By sampling minibatches non-uniformly, O-SGD prioritizes training examples in order to decrease $\mathbb{E}[\|g_O^{(t)}\|^2]$. During iteration t, O-SGD defines a discrete distribution $p^{(t)} \in \mathbb{R}^n_{\geq 0}$, where $\sum_i p_i^{(t)} = 1$. O-SGD constructs minibatch $\mathcal{M}^{(t)}$ by sampling independently $|\mathcal{M}|$ examples according to $p^{(t)}$. Instead of (2), the resulting stochastic gradient is

(4)    $g_O^{(t)} = \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}^{(t)}} \frac{1}{n p_i^{(t)}} \nabla f_i(w^{(t)}) + \lambda w^{(t)}$.

Scaling the $\nabla f_i$ terms by $(n p_i^{(t)})^{-1}$ ensures $g_O^{(t)}$ remains an unbiased approximation of $\nabla F(w^{(t)})$. O-SGD defines $p^{(t)}$ as the sampling distribution that maximizes (3):

Proposition 3.1 (Oracle sampling distribution). In order to minimize $\mathbb{E}[\|g_O^{(t)}\|^2]$, O-SGD samples each example i with probability proportional to the ith "gradient norm." That is,

$p_i^{(t)} = \|\nabla f_i(w^{(t)})\| \Big/ \sum_{j=1}^n \|\nabla f_j(w^{(t)})\|$.

Proof sketch. Defining $\bar{f}(w) = \frac{1}{n} \sum_{i=1}^n f_i(w)$, we write this second moment as

(5)    $\mathbb{E}\big[\|g_O^{(t)}\|^2\big] = \frac{1}{n^2 |\mathcal{M}|} \sum_{i=1}^n \frac{1}{p_i^{(t)}} \|\nabla f_i(w^{(t)})\|^2 - \frac{1}{|\mathcal{M}|} \|\nabla \bar{f}(w^{(t)})\|^2 + \|\nabla F(w^{(t)})\|^2$.

Finding the distribution $p^{(t)}$ that minimizes (5) is a problem with a closed-form solution. The solution is the distribution defined by Proposition 3.1, which we show in Appendix A.

The oracle sampling distribution is quite intuitive. Training examples with largest gradient norm are most important for further decreasing F, and these examples receive priority. Examples that the model handles correctly have smaller gradient norm, and O-SGD deprioritizes these examples.

3.2 Adapting the learning rate

Because importance sampling reduces the stochastic gradient's variance, possibly by a large amount, we find it important to adaptively increase O-SGD's learning rate compared to U-SGD.
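The variance reduction behind this adjustment is exactly Proposition 3.1: among all distributions, sampling proportional to gradient norms minimizes the leading term of (5). A small numerical check, using hypothetical gradient norms of our choosing:

```python
import numpy as np

# Hypothetical per-example gradient norms v*_i.
v = np.array([0.1, 0.5, 1.0, 4.0])
n = len(v)

def leading_term(p):
    # Leading (p-dependent) term of (5), up to the 1/(n^2 |M|) constant.
    return np.sum(v**2 / p)

p_oracle = v / v.sum()               # Proposition 3.1
p_uniform = np.full(n, 1.0 / n)

# The oracle distribution achieves the optimum (sum_i v_i)^2 ...
assert np.isclose(leading_term(p_oracle), v.sum()**2)
# ... which is never worse than uniform sampling (Cauchy-Schwarz).
assert leading_term(p_oracle) <= leading_term(p_uniform)
```

Substituting $p_i = v_i / \sum_j v_j$ into $\sum_i v_i^2 / p_i$ gives $(\sum_i v_i)^2$, the Cauchy-Schwarz lower bound.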
For O-SGD, we propose a learning rate that depends on the "gain ratio" $r_O^{(t)} \in \mathbb{R}_{\geq 1}$:

(6)    $r_O^{(t)} = \mathbb{E}\big[\|g_U^{(t)}\|^2\big] \Big/ \mathbb{E}\big[\|g_O^{(t)}\|^2\big]$.

Above, $g_U^{(t)}$ is the stochastic gradient defined by uniform sampling. O-SGD adapts the learning rate so that according to (3), one O-SGD iteration results in as much progress as $r_O^{(t)}$ U-SGD iterations. Defining the edge case $r_O^{(0)} = 1$, this learning rate depends on the "effective iteration number"

$\hat{t}_O^{(t)} = \sum_{t'=1}^{t} r_O^{(t'-1)}$.

Since the gain ratio is at least 1, we have $\hat{t}_O^{(t)} \geq t$ for all t. O-SGD defines the learning rate as

$\eta_O^{(t)} = r_O^{(t)} \, \mathrm{lr\_sched}(\hat{t}_O^{(t)})$.

We justify this choice of learning rate schedule with the following proposition:

Proposition 3.2 (Equivalence of gain ratio and expected speed-up). Given $w^{(t)}$, define $E_U^{(t)}$ as the expected progress from iteration t of U-SGD with learning rate $\eta_U^{(t)} = \mathrm{lr\_sched}(t)$. For comparison, define $E_O^{(t)}$ as the expected progress from iteration t of O-SGD with learning rate $\eta_O^{(t)} = r_O^{(t)} \eta_U^{(t)}$. Then $E_O^{(t)} = r_O^{(t)} E_U^{(t)}$. Relative to U-SGD, O-SGD multiplies the expected progress by $r_O^{(t)}$.

Proof. Using (3), we have

$E_U^{(t)} = 2\eta_U^{(t)} \langle \nabla F(w^{(t)}), w^{(t)} - w^\star \rangle - [\eta_U^{(t)}]^2 \, \mathbb{E}\big[\|g_U^{(t)}\|^2\big]$.

For O-SGD, we expect progress

$E_O^{(t)} = 2\eta_O^{(t)} \langle \nabla F(w^{(t)}), w^{(t)} - w^\star \rangle - [\eta_O^{(t)}]^2 \, \mathbb{E}\big[\|g_O^{(t)}\|^2\big] = 2 r_O^{(t)} \eta_U^{(t)} \langle \nabla F(w^{(t)}), w^{(t)} - w^\star \rangle - r_O^{(t)} [\eta_U^{(t)}]^2 \, \mathbb{E}\big[\|g_U^{(t)}\|^2\big] = r_O^{(t)} E_U^{(t)}$.

We remark that the purpose of this learning rate adjustment is not necessarily to speed up training; whether the adjustment results in speed-up depends greatly on the original learning rate schedule. Instead, the purpose of this rescaling is to make O-SGD (and hence RAIS-SGD) suitable as a drop-in replacement for U-SGD.
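In code, the rate rule of §3.2 is a running sum of past gain ratios used to index the schedule. A minimal sketch, with names of our choosing:

```python
def adapted_learning_rates(gain_ratios, lr_sched):
    """O-SGD's rate rule: eta^(t) = r^(t) * lr_sched(t_hat^(t)), where
    t_hat^(t) = sum of r^(0), ..., r^(t-1) and r^(0) = 1 (edge case).
    gain_ratios lists r^(1), r^(2), ... in order."""
    t_hat = 1.0  # t_hat^(1) = r^(0) = 1
    etas = []
    for r in gain_ratios:
        etas.append(r * lr_sched(t_hat))
        t_hat += r  # t_hat^(t+1) = t_hat^(t) + r^(t)
    return etas
```

With all gain ratios equal to 1 this reduces to the ordinary schedule; with larger ratios the schedule is both scaled up and consumed faster.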
We show empirically that this is the case in §6.

4 Robust approximate importance sampling (RAIS)

Determining $p^{(t)}$ and $r_O^{(t)}$ in O-SGD depends on knowledge of many gradient norms ($\nabla f_i(w^{(t)})$ for all examples, $\nabla \bar{f}(w^{(t)})$, and $\nabla F(w^{(t)})$). Computing these norms requires a time-consuming pass over the data. To make importance sampling practical, we propose RAIS-SGD.

4.1 Determining a robust sampling distribution

Like O-SGD, RAIS selects the tth minibatch by sampling indices from a discrete distribution $p^{(t)}$. We denote the stochastic gradient by $g_R^{(t)}$, which takes the same form as $g_O^{(t)}$ in (4).

Let $v^\star_i = \|\nabla f_i(w^{(t)})\|$ and $v^\star = [v^\star_1, v^\star_2, \ldots, v^\star_n]^T$. RAIS defines $p^{(t)}$ by approximating $v^\star$. Naïve algorithms approximate $v^\star$ using a point estimate $\hat{v}$. The sampling distribution becomes a multiple of $\hat{v}$. [3], [4], and [6] propose algorithms based on similar point estimation strategies.

The drawback of the point estimation approach is extreme sensitivity to differences between $\hat{v}$ and $v^\star$. For this reason, [3, 4, 6] incorporate additive smoothing. They introduce a smoothing hyperparameter and sample example i with probability proportional to $\hat{v}_i$ plus this constant. This approach to robustness is unconvincing, however, since performance becomes critically dependent on a hyperparameter. Too small a value risks divergence, while too large a value greatly limits the benefit of importance sampling.

Instead of a point estimate, RAIS approximates $v^\star$ with an uncertainty set $\mathcal{U}^{(t)} \subset \mathbb{R}^n_{\geq 0}$, which we expect contains (or nearly contains) $v^\star$. Given $\mathcal{U}^{(t)}$, RAIS defines $p^{(t)}$ by minimizing the worst-case value of $\mathbb{E}[\|g_R^{(t)}\|^2]$ over all gradient norm possibilities in $\mathcal{U}^{(t)}$.
Noting $\mathbb{E}[\|g_R^{(t)}\|^2] \propto \sum_i \frac{1}{p_i^{(t)}} (v^\star_i)^2 + c$ for some $c \in \mathbb{R}$ (according to (5)), RAIS defines $p^{(t)}$ as the solution to the following problem:

(PRC)    $p^{(t)} = \operatorname{arginf}_p \Big\{ \max_v \big\{ \textstyle\sum_{i=1}^n \frac{1}{p_i} v_i^2 \;\big|\; v \in \mathcal{U}^{(t)} \big\} \;\Big|\; p \in \mathbb{R}^n_{>0}, \; \textstyle\sum_{i=1}^n p_i = 1 \Big\}$.

Such robust optimization problems are common for making decisions with data uncertainty [7]. It turns out (PRC) is straightforward to solve because the minimax theorem applies to (PRC) (we prove this in Appendix D.1, assuming our definition of $\mathcal{U}^{(t)}$ in §4.2). We first minimize over p by defining $p_i = v_i (\sum_{j=1}^n v_j)^{-1}$. Plugging this into (PRC)'s objective leads to the simplified problem

(PRC')    $v^{(t)} = \operatorname{argmax} \big\{ (\textstyle\sum_{i=1}^n v_i)^2 \;\big|\; v \in \mathcal{U}^{(t)} \big\}$.

During each iteration t, RAIS solves (PRC'). After doing so, RAIS recovers the minimax optimal sampling distribution by defining $p_i^{(t)} \propto v_i^{(t)}$ for all training examples.

4.2 Modeling the uncertainty set

To define $\mathcal{U}^{(t)}$, RAIS uses features of SGD's state that are predictive of the true gradient norms. For each example i, we define a feature vector $s_i^{(t)} \in \mathbb{R}^{d_R}_{\geq 0}$. A useful feature for $s_i^{(t)}$ is the gradient norm $\|\nabla f_i(w^{(t')})\|$, where t' is the most recent iteration for which $i \in \mathcal{M}^{(t')}$. Since RAIS-SGD computes $\nabla f_i(w^{(t')})$ during iteration t', constructing this feature during iteration t should add little overhead.

Given $s_i^{(t)}$ for all examples, RAIS defines the uncertainty set as an axis-aligned ellipsoid. Since $v^\star \geq 0$, RAIS also intersects this ellipsoid with the positive orthant. RAIS parameterizes this uncertainty set with two vectors, $c \in \mathbb{R}^{d_R}_{\geq 0}$ and $d \in \mathbb{R}^{d_R}_{\geq 0}$. These vectors map features $s^{(t)}_{1:n}$ to parameters of the ellipsoid. Specifically, RAIS defines the uncertainty set as

$\mathcal{U}^{(t)}_{cd} = \Big\{ v \in \mathbb{R}^n_{\geq 0} \;\Big|\; \frac{1}{n} \sum_{i=1}^n Q_{cd}(s_i^{(t)}, v_i) \leq 1 \Big\}$, where $Q_{cd}(s, v) = \frac{(\langle c, s \rangle - v)^2}{\langle d, s \rangle}$.

Here we denote the uncertainty set by $\mathcal{U}^{(t)}_{cd}$ to emphasize the dependence of $\mathcal{U}^{(t)}$ on c and d. With this definition of $\mathcal{U}^{(t)}_{cd}$, (PRC') has a simple closed-form solution (proven in Appendix B):

Proposition 4.1 (Solution to robust counterpart).
For all i, the solution to (PRC') satisfies

$v_i^{(t)} = \langle c, s_i^{(t)} \rangle + k \langle d, s_i^{(t)} \rangle$, where $k = \sqrt{n \Big/ \sum_{j=1}^n \langle d, s_j^{(t)} \rangle}$.

If we consider $\langle c, s_i^{(t)} \rangle$ an estimate of $v^\star_i$ and $\langle d, s_i^{(t)} \rangle$ a measure of uncertainty in this estimate, then Proposition 4.1 is quite interpretable. RAIS samples example i with probability proportional to $\langle c, s_i^{(t)} \rangle + k \langle d, s_i^{(t)} \rangle$. The first term is the $v^\star_i$ estimate, and the second term adds robustness to error.

4.3 Learning the uncertainty set

The uncertainty set parameters, c and d, greatly influence the performance of RAIS. If $\mathcal{U}^{(t)}_{cd}$ is a small region near $v^\star$, then RAIS's sampling distribution is similar to O-SGD's sampling distribution. If $\mathcal{U}^{(t)}_{cd}$ is less representative of $v^\star$, the variance of the stochastic gradient could become much larger.

In order to make $\mathbb{E}[\|g_R^{(t)}\|^2]$ small but still ensure $v^\star$ likely lies in $\mathcal{U}^{(t)}_{cd}$, RAIS adaptively defines c and d. To do so, RAIS minimizes the size of $\mathcal{U}^{(t)}_{cd}$ subject to a constraint that encourages $v^\star \in \mathcal{U}^{(t)}_{cd}$:

(PT)    $c, d = \operatorname{arginf} \Big\{ \textstyle\sum_{i=1}^n \langle d, s_i^{(t)} \rangle \;\Big|\; c, d \in \mathbb{R}^{d_R}_{\geq 0}, \; \frac{1}{|\mathcal{D}|} \sum_{i=1}^{|\mathcal{D}|} \tilde{w}_i Q_{cd}(\tilde{s}_i, \tilde{v}_i) \leq 1 \Big\}$.

Here we have defined $\mathcal{U}^{(t)}_{cd}$'s "size" as the sum of $\langle d, s_i^{(t)} \rangle$ values. The constraint that encourages $v^\star \in \mathcal{U}^{(t)}$ assumes weighted training data, $(\tilde{w}_i, \tilde{s}_i, \tilde{v}_i)_{i=1}^{|\mathcal{D}|}$. RAIS must define this training set so that

$\frac{1}{|\mathcal{D}|} \sum_{i=1}^{|\mathcal{D}|} \tilde{w}_i Q_{cd}(\tilde{s}_i, \tilde{v}_i) \approx \frac{1}{n} \sum_{i=1}^n Q_{cd}(s_i^{(t)}, \|\nabla f_i(w^{(t)})\|)$.

That is, for any c and d, the mean of $Q_{cd}(\tilde{s}_i, \tilde{v}_i)$ over the weighted training set should approximately equal the mean of $Q_{cd}(s_i^{(t)}, v^\star_i)$, which depends on current (unknown) gradient norms.

To achieve this, RAIS uses gradients from recent minibatches. For entry j of the RAIS training set, RAIS considers an i and t' for which $i \in \mathcal{M}^{(t')}$ and $t' < t$. RAIS defines $\tilde{s}_j = s_i^{(t')}$, $\tilde{v}_j = \|\nabla f_i(w^{(t')})\|$, and $\tilde{w}_j = (n p_i^{(t')})^{-1}$.
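Proposition 4.1 makes computing $p^{(t)}$ a handful of vectorized operations over the feature matrix. A sketch under our own naming (S and the parameter values in the usage check are illustrative, not from the paper):

```python
import numpy as np

def rais_distribution(S, c, d):
    """Minimax-optimal sampling distribution via Proposition 4.1.
    S is an (n, d_R) matrix whose rows are the features s_i^{(t)};
    c and d are the nonnegative uncertainty-set parameters."""
    est = S @ c                        # <c, s_i>: estimate of v*_i
    unc = S @ d                        # <d, s_i>: uncertainty of estimate
    k = np.sqrt(len(est) / unc.sum())
    v = est + k * unc                  # closed-form solution to (PRC')
    return v / v.sum()                 # p_i proportional to v_i
```

The resulting v lies on the boundary of $\mathcal{U}^{(t)}_{cd}$: substituting it into the ellipsoid constraint gives $\frac{1}{n}\sum_i k^2 \langle d, s_i \rangle = 1$ exactly.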
The justification for this choice is that the mean of $Q_{cd}(s_i^{(t)}, \|\nabla f_i(w^{(t)})\|)$ over training examples tends to change gradually with t. Thus, the weighted mean over the RAIS training set approximates the mean of current $Q_{cd}(s_i^{(t)}, \|\nabla f_i(w^{(t)})\|)$ values.

4.4 Approximating the gain ratio

In addition to the sampling distribution, RAIS must approximate the gain ratio in O-SGD. Define $g_{R1}^{(t)}$ as a stochastic gradient of the form (4) using minibatch size 1 and RAIS sampling. Define $g_{U1}^{(t)}$ in the same way but with uniform sampling. From (5), we can work out that the gain ratio satisfies

(7)    $\mathbb{E}\big[\|g_U^{(t)}\|^2\big] \Big/ \mathbb{E}\big[\|g_R^{(t)}\|^2\big] = 1 + \frac{1}{|\mathcal{M}|} \Big( \mathbb{E}\big[\|g_{U1}^{(t)}\|^2\big] - \mathbb{E}\big[\|g_{R1}^{(t)}\|^2\big] \Big) \Big/ \mathbb{E}\big[\|g_R^{(t)}\|^2\big]$.

To approximate the gain ratio, RAIS estimates the three moments on the right side of this equation. RAIS estimates $\mathbb{E}[\|g_R^{(t)}\|^2]$ using an exponential moving average of $\|g_R^{(t)}\|^2$ from recent iterations:

$\mathbb{E}\big[\|g_R^{(t)}\|^2\big] \approx \alpha \Big[ \|g_R^{(t)}\|^2 + (1 - \alpha) \|g_R^{(t-1)}\|^2 + (1 - \alpha)^2 \|g_R^{(t-2)}\|^2 + \ldots \Big]$.

Algorithm 4.1 RAIS-SGD

input objective function F, minibatch size $|\mathcal{M}|$, learning rate schedule lr_sched(·)
input RAIS training set size $|\mathcal{D}|$, exponential smoothing parameter $\alpha$ for gain estimate
initialize $w^{(1)} \in \mathbb{R}^d$; $c, d \in \mathbb{R}^{d_R}_{\geq 0}$; $\hat{t}^{(1)} \leftarrow 1$; r_estimator $\leftarrow$ GainEstimator($\alpha$)
for t = 1, 2, ..., T do
    $v^{(t)} \leftarrow \operatorname{argmax} \{ (\sum_{i=1}^n v_i)^2 \mid v \in \mathcal{U}^{(t)}_{cd} \}$    # see Proposition 4.1 for closed-form solution
    $p^{(t)} \leftarrow v^{(t)} / \|v^{(t)}\|_1$
    $\mathcal{M}^{(t)} \leftarrow$ sample_indices_from_distribution($p^{(t)}$, size = $|\mathcal{M}|$)
    $g_R^{(t)} \leftarrow \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}^{(t)}} (n p_i^{(t)})^{-1} \nabla f_i(w^{(t)}) + \lambda w^{(t)}$
    r_estimator.record_gradient_norms($\|g_R^{(t)}\|$, $(\|\nabla f_i(w^{(t)})\|, p_i^{(t)})_{i \in \mathcal{M}^{(t)}}$)    # see §4.4
    $\hat{r}^{(t)} \leftarrow$ r_estimator.estimate_gain_ratio()
    $\eta^{(t)} \leftarrow \hat{r}^{(t)} \cdot$ lr_sched($\hat{t}^{(t)}$)
    $w^{(t+1)} \leftarrow w^{(t)} - \eta^{(t)} g_R^{(t)}$
    $\hat{t}^{(t+1)} \leftarrow \hat{t}^{(t)} + \hat{r}^{(t)}$
    if mod(t, $\lceil |\mathcal{D}| / |\mathcal{M}| \rceil$) = 0 and $t \geq (n + |\mathcal{D}|) / |\mathcal{M}|$ then
        $c, d \leftarrow$ train_uncertainty_model()    # see §4.2
return $w^{(T+1)}$

RAIS approximates $\mathbb{E}[\|g_{R1}^{(t)}\|^2]$ and $\mathbb{E}[\|g_{U1}^{(t)}\|^2]$ in a similar way. After computing gradients for minibatch t, RAIS estimates these moments using appropriately weighted averages of $\|\nabla f_i(w^{(t)})\|^2$ for each $i \in \mathcal{M}^{(t)}$ (for $\mathbb{E}[\|g_{R1}^{(t)}\|^2]$, RAIS weights terms by $(n p_i^{(t)})^{-2}$; for $\mathbb{E}[\|g_{U1}^{(t)}\|^2]$, RAIS weights terms by $(n p_i^{(t)})^{-1}$). Using the same exponential averaging parameter $\alpha$, RAIS averages these estimates from minibatch t with estimates from prior iterations.

RAIS approximates the gain ratio by plugging these moment estimates into (7). We denote the result by $\hat{r}^{(t)}$. Analogous to O-SGD, RAIS uses learning rate $\eta^{(t)} = \hat{r}^{(t)} \, \mathrm{lr\_sched}(\hat{t}^{(t)})$, where $\hat{t}^{(t)}$ is the effective iteration number: $\hat{t}^{(t)} = \sum_{t'=1}^{t} \hat{r}^{(t'-1)}$. Here we also define the edge case $\hat{r}^{(0)} = 1$.

4.5 Practical considerations

Algorithm 4.1 summarizes our RAIS-SGD algorithm. We next discuss important practical details.

Solving (PT). While computing $p^{(t)}$ requires a small number of length-n operations (see Proposition 4.1), learning the uncertainty set parameters requires more computation. For this reason, RAIS should not solve (PT) during every iteration. Our implementation solves (PT) asynchronously after every $\lceil |\mathcal{D}| / |\mathcal{M}| \rceil$ minibatches, with updates to $w^{(t)}$ continuing during the process.
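The moment bookkeeping of §4.4 reduces to three exponential moving averages. The following is a simplified sketch (our own class and method names; it ignores the $\lambda w$ term and need not match the paper's GainEstimator interface):

```python
import numpy as np

class GainRatioEstimator:
    """EMA estimates of E||g_R||^2, E||g_R1||^2, and E||g_U1||^2,
    combined via equation (7) to approximate the gain ratio."""
    def __init__(self, alpha, batch_size):
        self.alpha, self.m = alpha, batch_size
        self.gR2 = self.gR1 = self.gU1 = None

    def _ema(self, old, new):
        return new if old is None else self.alpha * new + (1 - self.alpha) * old

    def record(self, g_norm_sq, example_norms_sq, probs, n):
        # Per §4.4: weight ||grad f_i||^2 by (n p_i)^{-2} for E||g_R1||^2
        # and by (n p_i)^{-1} for E||g_U1||^2 (samples are drawn from p).
        w = 1.0 / (n * probs)
        self.gR2 = self._ema(self.gR2, g_norm_sq)
        self.gR1 = self._ema(self.gR1, np.mean(w**2 * example_norms_sq))
        self.gU1 = self._ema(self.gU1, np.mean(w * example_norms_sq))

    def estimate_gain_ratio(self):
        # Equation (7): r = 1 + (E||g_U1||^2 - E||g_R1||^2) / (|M| E||g_R||^2)
        return 1.0 + (self.gU1 - self.gR1) / (self.m * self.gR2)
```

Under exactly uniform probabilities the importance weights are all 1, the two minibatch-size-1 moments coincide, and the estimated gain ratio is 1, as expected.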
We describe our algorithm for solving (PT) in Appendix D.2. Since our features $s^{(t)}_{1:n}$ depend on past minibatch updates, we do not use RAIS for the first epoch of training; instead we sample examples sequentially.

Compatibility with common tricks. RAIS combines nicely with standard training tricks for deep learning. With no change, we find RAIS works well with momentum [8, 9]. Incorporating data augmentation, dropout [10], or batch normalization [11] adds variance to the model's outputs and gradient norms. RAIS elegantly compensates for such inconsistency by learning a larger uncertainty set. Since the importance sampling distribution changes over time, we find it important to compute weighted batch statistics when using RAIS with batch normalization. That is, when computing normalization statistics during training, we weight contributions from each example by $(n p_i^{(t)})^{-1}$.

Protecting against outliers. In some cases, typically when the gain ratio is very large, we find $Q_{cd}(s_i^{(t)}, v^\star_i)$ can be quite small for most examples yet large for a small set of outliers. Typically we find RAIS does not require special treatment of such outliers. Even so, it is reasonable to protect against outliers, so that an example with extremely large $Q_{cd}(s_i^{(t)}, v^\star_i)$ cannot greatly increase the stochastic gradient's variance. To achieve this, we use gradient clipping, and RAIS provides a natural way of doing so.

Figure 2: Supplemental plots. Left: Visualization of top-layer gradient norm approximation. The model is an 18 layer ResNet after 30 epochs of training on CIFAR-10. Middle: Oracle importance sampling results for MNIST and LeNet model. Right: RAIS time overhead for rot-MNIST.

We define an "outlier" as any example for which $Q_{cd}(s_i^{(t)}, v^\star_i)$ exceeds a threshold $\tau$.
For each outlier i, we temporarily scale $f_i$ during iteration t until $Q_{cd}(s_i^{(t)}, \|\nabla f_i(w^{(t)})\|) = \tau$. In practice, we use $\tau = 100$; the fraction of outliers is often zero and rarely exceeds 0.1%.

Approximating per-example gradient norms. To train the uncertainty set, RAIS computes $\|\nabla f_i(w^{(t)})\|$ for each example in each minibatch. Unfortunately, existing software tools do not provide efficient access to per-example gradient norms. Instead, libraries are optimized for aggregating gradients over minibatches. Thus, to make RAIS practical, we must approximate the gradient norms. We do so by replacing $\|\nabla f_i(w^{(t)})\|$ with the norm of only the loss layer's gradient (with respect to this layer's inputs). These values correlate strongly, since the loss layer begins the backpropagation chain for computing $\nabla f_i(w^{(t)})$. We show this empirically in Figure 2 (left), and we include additional plots in Appendix E.1. We note this approximation may not work well for all models.

5 Relation to prior work

Prior strategies also consider importance sampling for speeding up deep learning. [3] proposes distributing the computation of sampling probabilities. In parallel with regular training, [4] trains a miniature neural network to predict importance values. [5] approximates importance values using additional forward passes. [12] and [13] apply importance sampling to deep reinforcement learning. With the exception of [5] (which requires considerable time to compute importance values), these prior algorithms are sensitive to errors in importance value estimates. For this reason, all require critical smoothing hyperparameters to converge.
In contrast, RAIS elegantly compensates for approximation error by choosing a sampling distribution that is minimax optimal with respect to an uncertainty set. Since RAIS adaptively trains this uncertainty set, RAIS does not require hyperparameter tuning.

Researchers have also considered other ways to prioritize training examples for deep learning. [14] considers examples in order of increasing difficulty. Other researchers prioritize challenging training examples [15, 16]. And yet others prioritize examples closest to the model's decision boundary [17]. Unlike RAIS, the primary goal of these approaches is improved model performance, not optimization efficiency. Importance sampling may work well in conjunction with these strategies.

There also exist ideas for sampling minibatches non-uniformly outside the context of deep learning. [18, 19] consider sampling diverse minibatches via repulsive point processes. Another strategy uses side information, such as class labels, for approximate importance sampling [6]. By choosing appropriate features for the uncertainty set, RAIS can use side information in the same way.

In the convex setting, there are several importance sampling strategies for SGD with theoretical guarantees. This includes [1] and [2], which sample training examples proportional to Lipschitz constants. Leverage score sampling uses a closely related concept for matrix approximation algorithms [20, 21]. For more general convex problems, some adaptive sampling strategies include [22] and [23].

6 Empirical comparisons

In this section, we demonstrate how RAIS performs in practice.
We consider the very popular task of training a convolutional neural network to classify images.

[Figure 2 plots: full gradient norms vs. top-layer gradient norms; objective value and validation error vs. epochs for Oracle IS, RAIS-SGD, and SGD; elapsed time vs. epochs for RAIS-SGD and SGD]

Figure 3: Learning curve comparison (SVHN, rot-MNIST, CIFAR-10, CIFAR-100). RAIS consistently outperforms SGD with uniform sampling, both in terms of objective value and generalization performance. Curves show the mean of five trials with varying random seeds. Filled areas signify ±1.96 times standard error of the mean.

We first train a LeNet-5 model [24] on the MNIST digits dataset. The model's small size makes it possible to compare with O-SGD. We use learning rate $\eta^{(t)} = 3.4 / \sqrt{100 + t}$, L2 penalty $\lambda = 2.5 \times 10^{-4}$, and batch size 32; the parameters are chosen so that SGD performs well. We do not use momentum or data augmentation. Figure 2 (middle) includes the results of this experiment. Oracle sampling significantly outperforms RAIS, and RAIS significantly outperforms uniform sampling.

For our remaining comparisons, we consider street view house numbers [25], rotated MNIST [26], and CIFAR tiny image [27] datasets. For rot-MNIST, we train a 7 layer CNN with 20 channels per layer, a strong baseline from [28]. Otherwise, we train an 18 layer ResNet preactivation model [29]. CIFAR-100 contains 100 classes, while the other problems contain 10. The number of training examples is $6.0 \times 10^5$ for SVHN, $1.2 \times 10^4$ for rot-MNIST, and $5.0 \times 10^4$ for the CIFAR problems.

We follow standard training procedures to attain good generalization performance. We use batch normalization and standard momentum of 0.9. For rot-MNIST, we follow [28], augmenting data with random rotations and training with dropout.
For the CIFAR problems, we augment the training set with random horizontal reflections and random crops (pad to 40x40 pixels; crop to 32x32). We train the SVHN model with batch size 64 and the remaining models with $|\mathcal{M}| = 128$. For each problem, we approximately optimize $\lambda$ and the learning rate schedule in order to achieve good validation performance with SGD at the end of training. The learning rate schedule decreases by a fixed fraction after each epoch ($n / |\mathcal{M}|$ iterations). This fraction is 0.8 for SVHN, 0.972 for rot-MNIST, 0.96 for CIFAR-10, and 0.96 for CIFAR-100. The initial learning rates are 0.15, 0.09, 0.08, and 0.1, respectively. We use $\lambda = 3 \times 10^{-3}$ for rot-MNIST and $\lambda = 5 \times 10^{-4}$ otherwise.

For RAIS-SGD, we use $|\mathcal{D}| = 2 \times 10^4$ training examples to learn c and d and $\alpha = 0.01$ to estimate $\hat{r}^{(t)}$. The performance of RAIS varies little with these parameters, since they only determine the number of minibatches to consider when training the uncertainty set and estimating the gain ratio. For the uncertainty set features, we use simple moving averages of the most recently computed gradient norms for each example. We use moving averages of different lengths: 1, 2, 4, 8, and 16. For lengths of at least four, we also include the variance and standard deviation of these prior gradient norm values. We also incorporate a bias feature as well as the magnitude of the random crop offset.

We compare training curves of RAIS-SGD and SGD in Figure 3. Notice that RAIS-SGD consistently outperforms SGD. The relative speed-up ranges from approximately 20% for the CIFAR-100 problem to more than 2x for the SVHN problem. Due to varying machine loads, we plot results in terms of epochs (not wall time), but RAIS introduces very little time overhead. For example, Figure 2 (right) includes time overhead results for the rot-MNIST comparison, which we ran on an isolated machine. Figure 4 provides additional details of these results.
In the figure's first row, we see the speed-up in terms of the gain ratio (the blue curve averages the value $(\hat{r}^{(t)} - 1) \cdot 100\%$ over consecutive epochs).

Figure 4: RAIS speed-up and alignment of epochs equivalent (SVHN, rot-MNIST, CIFAR-10, CIFAR-100). Above: Blue shows increase in optimization speed due to RAIS, as measured by estimated gain ratio; purple indicates time overhead due to RAIS. Overhead is small compared to speed-up. Below: Objective value vs. epochs equivalent. For RAIS, epochs equivalent equals $\frac{|\mathcal{M}|}{n} \hat{t}^{(t)}$. The closely aligned curves suggest (i) RAIS-SGD is a suitable drop-in replacement for SGD, and (ii) the gain ratio correctly approximates speed-up.

The gain ratio tends to increase as training progresses, implying RAIS is most useful during later stages of training. We also plot the relative wall time overhead for RAIS, which again is very small. In the second row of Figure 4, we compare RAIS-SGD and SGD in terms of epochs equivalent, the number of epochs measured in terms of effective iterations. Interestingly, the curves align closely. This alignment confirms that our learning rate adjustment is reasonable, as it results in a suitable drop-in replacement for SGD.
This result contrasts starkly with [3], for example, in which case generalization performance differs significantly for the importance sampling and standard algorithms. Table 1 concludes these comparisons with a summary of results:

Table 1: Quantities upon training completion.

Dataset     Algorithm   Epochs equivalent   F(w^(t))   Val. error   Val. loss
SVHN        RAIS-SGD    114                 1.01       0.0201       0.121
SVHN        SGD         24.0                1.02       0.0226       0.121
rot-MNIST   RAIS-SGD    214                 0.431      0.0476       0.149
rot-MNIST   SGD         150.                0.460      0.0512       0.161
CIFAR-10    RAIS-SGD    130.                1.08       0.0590       0.256
CIFAR-10    SGD         100.                1.10       0.0607       0.277
CIFAR-100   RAIS-SGD    138                 1.21       0.236        0.962
CIFAR-100   SGD         100.                1.25       0.236        0.989

7 Discussion

We proposed a relatively simple and very practical importance sampling procedure for speeding up the training of deep models. By using robust optimization to define the sampling distribution, RAIS depends minimally on user-specified parameters. Additionally, RAIS introduces little computational overhead and combines nicely with standard training strategies. All together, RAIS is a promising approach with minimal downside and potential for large improvements in training speed.

Acknowledgements

We thank Marco Tulio Ribeiro, Tianqi Chen, Maryam Fazel, Sham Kakade, and Ali Shojaie for helpful discussion and feedback. This work was supported by PECASE N00014-13-1-0023.

References

[1] P. Zhao and T. Zhang. Stochastic optimization with importance sampling for regularized loss minimization.
In Proceedings of the 32nd International Conference on Machine Learning, 2015.

[2] D. Needell, R. Ward, and N. Srebro. Stochastic gradient descent, weighted sampling, and the randomized Kaczmarz algorithm. In Advances in Neural Information Processing Systems 27, 2014.

[3] G. Alain, A. Lamb, C. Sankar, A. Courville, and Y. Bengio. Variance reduction in SGD by distributed importance sampling. In 4th International Conference on Learning Representations Workshop, 2016.

[4] A. Katharopoulos and F. Fleuret. Biased importance sampling for deep neural network training. arXiv:1706.00043, 2017.

[5] A. Katharopoulos and F. Fleuret. Not all samples are created equal: Deep learning with importance sampling. In Proceedings of the 35th International Conference on Machine Learning, 2018.

[6] S. Gopal. Adaptive sampling for SGD by exploiting side information. In Proceedings of the 33rd International Conference on Machine Learning, 2016.

[7] A. Ben-Tal, L. El Ghaoui, and A. Nemirovski. Robust Optimization. Princeton University Press, 2009.

[8] B. T. Polyak. Some methods of speeding up the convergence of iteration methods. USSR Computational Mathematics and Mathematical Physics, 4(5):1–17, 1964.

[9] I. Sutskever, J. Martens, G. Dahl, and G. Hinton. On the importance of initialization and momentum in deep learning. In Proceedings of the 30th International Conference on Machine Learning, 2013.

[10] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15:1929–1958, 2014.

[11] S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In Proceedings of the 32nd International Conference on Machine Learning, 2015.

[12] T. Schaul, J. Quan, I. Antonoglou, and D. Silver. Prioritized experience replay.
In 4th International Conference on Learning Representations, 2016.

[13] D. Horgan, J. Quan, D. Budden, G. Barth-Maron, M. Hessel, H. van Hasselt, and D. Silver. Distributed prioritized experience replay. In 6th International Conference on Learning Representations, 2018.

[14] Y. Bengio, J. Louradour, R. Collobert, and J. Weston. Curriculum learning. In Proceedings of the 26th International Conference on Machine Learning, 2009.

[15] A. Shrivastava, A. Gupta, and R. Girshick. Training region-based object detectors with online hard example mining. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2016.

[16] S. Shalev-Shwartz and Y. Wexler. Minimizing the maximal loss: How and why. In Proceedings of the 33rd International Conference on Machine Learning, 2016.

[17] H.-S. Chang, E. Learned-Miller, and A. McCallum. Active bias: Training more accurate neural networks by emphasizing high variance samples. In Advances in Neural Information Processing Systems 30, 2017.

[18] C. Zhang, H. Kjellström, and S. Mandt. Determinantal point processes for mini-batch diversification. In Conference on Uncertainty in Artificial Intelligence, 2017.

[19] C. Zhang, C. Öztireli, S. Mandt, and G. Salvi. Active mini-batch sampling using repulsive point processes. arXiv:1804.02772, 2018.

[20] M. Mahoney. Randomized algorithms for matrices and data. Foundations and Trends in Machine Learning, 3(2), 2011.

[21] P. Ma, B. Yu, and M. Mahoney. A statistical perspective on algorithmic leveraging. In Proceedings of the 31st International Conference on Machine Learning, 2014.

[22] S. U. Stich, A. Raj, and M. Jaggi. Safe adaptive importance sampling. In Advances in Neural Information Processing Systems 30, 2017.

[23] Z. Borsos, A. Krause, and K. Y. Levy. Online variance reduction for stochastic optimization. arXiv:1802.04715, 2018.

[24] Y. LeCun, L. Bottou, Y. Bengio, and P.
Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 1998.

[25] Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng. Reading digits in natural images with unsupervised feature learning. In NIPS Workshop on Deep Learning and Unsupervised Feature Learning, 2011.

[26] H. Larochelle, D. Erhan, A. Courville, J. Bergstra, and Y. Bengio. An empirical evaluation of deep architectures on problems with many factors of variation. In Proceedings of the 24th International Conference on Machine Learning, 2007.

[27] A. Krizhevsky. Learning multiple layers of features from tiny images. Technical report, 2009.

[28] T. S. Cohen and M. Welling. Group equivariant convolutional networks. In Proceedings of the 33rd International Conference on Machine Learning, 2016.

[29] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In European Conference on Computer Vision, 2016.