{"title": "Using Statistics to Automate Stochastic Optimization", "book": "Advances in Neural Information Processing Systems", "page_first": 9540, "page_last": 9550, "abstract": "Despite the development of numerous adaptive optimizers, tuning the learning rate of stochastic gradient methods remains a major roadblock to obtaining good practical performance in machine learning. Rather than changing the learning rate at each iteration, we propose an approach that automates the most common hand-tuning heuristic: use a constant learning rate until \"progress stops,\" then drop. We design an explicit statistical test that determines when the dynamics of stochastic gradient descent reach a stationary distribution. This test can be performed easily during training, and when it fires, we decrease the learning rate by a constant multiplicative factor. Our experiments on several deep learning tasks demonstrate that this statistical adaptive stochastic approximation (SASA) method can automatically find good learning rate schedules and match the performance of hand-tuned methods using default settings of its parameters. The statistical testing helps to control the variance of this procedure and improves its robustness.", "full_text": "Using Statistics to Automate Stochastic Optimization\n\nHunter Lang\n\nPengchuan Zhang\nMicrosoft Research AI\n\nRedmond, WA 98052, USA\n\nLin Xiao\n\n{hunter.lang, penzhan, lin.xiao}@microsoft.com\n\nAbstract\n\nDespite the development of numerous adaptive optimizers, tuning the learning\nrate of stochastic gradient methods remains a major roadblock to obtaining good\npractical performance in machine learning. Rather than changing the learning\nrate at each iteration, we propose an approach that automates the most common\nhand-tuning heuristic: use a constant learning rate until \u201cprogress stops\u201d, then\ndrop. 
We design an explicit statistical test that determines when the dynamics of stochastic gradient descent reach a stationary distribution. This test can be performed easily during training, and when it fires, we decrease the learning rate by a constant multiplicative factor. Our experiments on several deep learning tasks demonstrate that this statistical adaptive stochastic approximation (SASA) method can automatically find good learning rate schedules and match the performance of hand-tuned methods using default settings of its parameters. The statistical testing helps to control the variance of this procedure and improves its robustness.

1 Introduction
Stochastic approximation methods, including stochastic gradient descent (SGD) and its many variants, serve as the workhorses of machine learning with big data. Many tasks in machine learning can be formulated as the stochastic optimization problem:

minimize_{x ∈ R^n} F(x) := E_ξ[f(x, ξ)],

where ξ is a random variable representing data sampled from some (unknown) probability distribution, x ∈ R^n represents the parameters of the model (e.g., the weight matrices in a neural network), and f is a loss function. In this paper, we focus on the following variant of SGD with momentum,

d_{k+1} = (1 − β_k) g_k + β_k d_k,
x_{k+1} = x_k − α_k d_{k+1},  (1)

where g_k = ∇_x f(x_k, ξ_k) is a stochastic gradient, α_k > 0 is the learning rate, and β_k ∈ [0, 1) is the momentum coefficient. This approach can be viewed as an extension of the heavy-ball method (Polyak, 1964) to the stochastic setting.1 To distinguish it from the classical SGD, we refer to the method (1) as SGM (Stochastic Gradient with Momentum).
Theoretical conditions on the convergence of stochastic approximation methods are well established (see, e.g., Wasan, 1969; Kushner and Yin, 2003, and references therein). Unfortunately, these asymptotic conditions are insufficient in practice. 
For example, the classical rule α_k = a/(k + b)^c, where a, b > 0 and 1/2 < c ≤ 1, often gives poor performance even when a, b, and c are hand-tuned. Additionally, despite the advent of numerous adaptive variants of SGD and SGM (e.g., Duchi et al., 2011; Tieleman and Hinton, 2012; Kingma and Ba, 2014, and other variants), achieving good performance in practice often still requires considerable hand-tuning (Wilson et al., 2017).

1For fixed values of α and β, this "normalized" update formula is equivalent to the more common updates d_{k+1} = g_k + β d_k, x_{k+1} = x_k − α′ d_{k+1} with the reparametrization α = α′/(1 − β).

33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.

Figure 1: Smoothed training loss (a), test accuracy (b), and (global) learning rate schedule (c) for an 18-layer ResNet model (He et al., 2016) trained on the CIFAR-10 dataset using four different methods (with constant momentum β = 0.9). Adam: α = 0.0001; SGM-const: α = 1.0; SGM-poly: a = 1.0, b = 1, c = 0.5; SGM-hand: α = 1.0, drop by 10 every 50 epochs.

Algorithm 1: General SASA method
Input: {x_0, α_0, M, β, ζ}
1 for j ∈ {0, 1, . . .} do
2   for k ∈ {jM, . . . , (j + 1)M − 1} do
3     Sample ξ_k.
4     Compute g_k = ∇_x f(x_k, ξ_k).
5     d_{k+1} = (1 − β) g_k + β d_k
6     x_{k+1} = x_k − α d_{k+1}
7     // collect statistics
8   end
9   if test(statistics) then
10    α ← ζα
11    // reset statistics
12  end
13 end

Figure 1 shows the training loss and test accuracy of a typical deep learning task using four different methods: SGM with constant step size (SGM-const), SGM with diminishing O(1/k) step size (SGM-poly), Adam (Kingma and Ba, 2014), and hand-tuned SGM with learning rate scheduling (SGM-hand). 
For the last method, we\ndecrease the step size by a multiplicative factor\nafter a suitably long number of epochs (\u201cconstant-\nand-cut\u201d). The relative performance depicted in\nFigure 1 is typical of many tasks in deep learn-\ning. In particular, SGM with a large momentum\nand a constant-and-cut step-size schedule often\nachieves the best performance. Many former\nand current state-of-the-art results use constant-\nand-cut schedules during training, such as those\nin image classi\ufb01cation (Huang et al., 2018), ob-\nject detection (Szegedy et al., 2015), machine\ntranslation (Gehring et al., 2017), and speech\nrecognition (Amodei et al., 2016). Additionally, some recent theoretical evidence indicates that\nin some (strongly convex) scenarios, the constant-and-cut scheme has better \ufb01nite-time last-iterate\nconvergence performance than other methods (Ge et al., 2019).\nInspired by the success of the \u201cconstant-and-cut\u201d scheduling approach, we develop an algorithm that\ncan automatically decide when to drop \u21b5. Most common heuristics try to identify when \u201ctraining\nprogress has stalled.\u201d We formalize stalled progress as when the SGM dynamics in (1), with constant\nvalues of \u21b5 and , reach a stationary distribution. The existence of such a distribution seems to\nmatch well with many empirical results (e.g., Figure 1), though it may not exist in general. Since\nSGM generates a rich set of information as it runs (i.e. {x0, g0, . . . , xk, gk}), a natural approach is\nto collect some statistics from this information and perform certain tests on them to decide whether\nthe process (1) has reached a stationary distribution. We call this general method SASA: statistical\nadaptive stochastic approximation.\nAlgorithm 1 summarizes the general SASA method. It performs the SGM updates (1) in phases of\nM iterations, in each iteration potentially computing some additional statistics. 
After M iterations\nare complete, the algorithm performs a statistical test to decide whether to drop the learning rate by\na factor \u21e3< 1. Dropping \u21b5 after a \ufb01xed number of epochs and dropping \u21b5 based on the loss of a\nheld-out validation set correspond to heuristic versions of Algorithm 1. In the rest of this work, we\ndetail how to perform the \u201ctest\u201d procedure and evaluate SASA on a wide range of deep learning tasks.\n\n1.1 Related Work and Contributions\nThe idea of using statistical testing to augment stochastic optimization methods goes back at least to\nP\ufb02ug (1983), who derived a stationary condition for the dynamics of SGD on quadratic functions\n\n2\n\n\fand designed a heuristic test to determine when the dynamics had reached stationary. He used this\ntest to schedule a \ufb01xed-factor learning rate drop. Chee and Toulis (2018) recently re-investigated\nP\ufb02ug\u2019s method for general convex functions. P\ufb02ug\u2019s stationary condition relies heavily on a quadratic\napproximation to F and limiting noise assumptions, as do several other recent works that derive a\nstationary condition (e.g., Mandt et al., 2017; Chee and Toulis, 2018). Additionally, P\ufb02ug\u2019s test\nassumes no correlation exists between consecutive samples in the optimization trajectory. 
Neither is true in practice, which we show in Appendix C.2 can lead to poor predictivity of this condition.
Yaida (2018) derived a very general stationary condition that does not depend on any assumption about the underlying function F and applies under general noise conditions regardless of the size of α. Like Pflug (1983), Yaida (2018) used this condition to determine when to decrease α, and showed good performance compared to hand-tuned SGM for two deep learning tasks with small models. However, Yaida's method does not account for the variance of the terms involved in the test, which we show can cause large variance in the learning rate schedules in some cases. This variance can in turn cause poor empirical performance.
In this work, we show how to more rigorously perform statistical hypothesis testing on samples collected from the dynamics of SGM. We combine this statistical testing with Yaida's stationary condition to develop an adaptive "constant-and-cut" optimizer (SASA) that we show is more robust than present methods. Finally, we conduct large-scale experiments on a variety of deep learning tasks to demonstrate that SASA is competitive with the best hand-tuned and validation-tuned methods without requiring additional tuning.

2 Stationary Conditions
To design a statistical test that fires when SGM reaches a stationary distribution, we first need to derive a condition that holds at stationarity and consists of terms that we can estimate during training. To do so, we analyze the long-run behavior of SGM with constant learning rate and momentum parameter:

d_{k+1} = (1 − β) g_k + β d_k,
x_{k+1} = x_k − α d_{k+1},  (2)

where α > 0 and 0 ≤ β < 1. This process starts with d_0 = 0 and arbitrary x_0. Since α and β are constant, the sequence {x_k} does not converge to a local minimum, but the distribution of {x_k} may converge to a stationary distribution. 
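For concreteness, the constant-(α, β) recursion above is easy to simulate; the following minimal NumPy sketch (our own illustration, not the authors' code) runs it on a noisy quadratic, where the iterates fluctuate around the minimum instead of converging to it:

```python
import numpy as np

def sgm_constant(grad_fn, x0, alpha=0.1, beta=0.9, num_steps=1000, seed=0):
    """Run the SGM recursion (2) with constant learning rate and momentum.

    grad_fn(x, rng) returns a stochastic gradient at x.
    Returns the full trajectory of iterates, shape (num_steps + 1, dim).
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    d = np.zeros_like(x)                  # d_0 = 0
    traj = [x.copy()]
    for _ in range(num_steps):
        g = grad_fn(x, rng)               # stochastic gradient g_k
        d = (1 - beta) * g + beta * d     # d_{k+1}
        x = x - alpha * d                 # x_{k+1}
        traj.append(x.copy())
    return np.array(traj)

# Noisy quadratic F(x) = 0.5 ||x||^2: true gradient x plus i.i.d. noise.
def noisy_quad_grad(x, rng):
    return x + 0.1 * rng.standard_normal(x.shape)

traj = sgm_constant(noisy_quad_grad, x0=np.ones(5))
# With constant alpha the iterates do not converge to the minimizer;
# they hover around it, consistent with a stationary distribution.
print(traj[-1])
```

Running it shows the final iterate near, but not exactly at, the minimizer, which is the behavior the stationarity tests in the following sections try to detect.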
Letting F_k denote the σ-algebra defined by the history of the process (2) up to time k, i.e., F_k = σ(d_0, . . . , d_k; x_0, . . . , x_k), we denote by E_k[·] := E[· | F_k] the expectation conditioned on that history. Assuming that g_k is Markovian and unbiased, i.e.,

P[g_k | F_k] = P[g_k | d_k, x_k],  E[g_k | d_k, x_k] = ∇F(x_k),  (3)

then the SGM dynamics (2) form a homogeneous2 Markov chain (Bach and Moulines, 2013; Dieuleveut et al., 2017) with continuous state (d_k, x_k, g_k) ∈ R^{3n}. These assumptions are always satisfied when g_k = ∇_x f(x_k, ξ_k) for an i.i.d. sample ξ_k. We further assume that the SGM process converges to a stationary distribution, denoted as π(d, x, g)3. With this notation, we need a relationship E_π[X] = E_π[Y] for certain functions X and Y of (x_k, d_k, g_k) that we can compute during training. Then, if we assume the Markov chain is ergodic, we have that:

z̄_N = (1/N) Σ_{i=0}^{N−1} z_i = (1/N) Σ_{i=0}^{N−1} ( X(x_i, d_i, g_i) − Y(x_i, d_i, g_i) ) → 0.  (4)

Then we can check the magnitude of the time-average z̄_N to see how close the dynamics are to reaching their stationary distribution. Next, we consider two different stationary conditions.

2.1 Pflug's condition
Assuming F(x) = (1/2) xᵀAx, where A is positive definite with maximum eigenvalue L, and that the stochastic gradient g_k satisfies g_k = ∇F(x_k) + r_k, with E[r_k] = 0 and r_k independent of x_k, Pflug (1983) derived a stationary condition for the SGD dynamics. His condition can be extended to the SGM dynamics. For appropriate α and β, the generalized Pflug stationary condition says

E_π[⟨g, d⟩] = − (α(1 − β) / (2(1 + β))) tr(AΣ_r) + O(α²),  (5)

where Σ_r is the covariance of the noise r.

2"Homogeneous" means that the transition kernel is time independent.
3As stated in Section 1, this need not be true in general, but seems to often be the case in practice.
One can estimate the left-hand side during training by computing the inner product ⟨g_k, d_k⟩ in each iteration. Pflug (1983) also designed a clever estimator for the right-hand side, so it is possible to compute estimators for both sides of (5).
The Taylor expansion in α used to derive (5) means that the relationship may only be accurate for small α, but α is typically large in the first phase of training. This, together with the other assumptions required for Pflug's condition, is too strong to make the condition (5) useful in practice.

2.2 Yaida's condition
Yaida (2018) showed that as long as the stationary distribution π exists, the following relationship holds exactly:

E_π[⟨x, ∇F(x)⟩] = (α/2)((1 + β)/(1 − β)) E_π[⟨d, d⟩].

In particular, this holds for general functions F and arbitrary values of α. Because the stochastic gradients g_k are unbiased, one can further show that:

E_π[⟨x, g⟩] = (α/2)((1 + β)/(1 − β)) E_π[⟨d, d⟩].  (6)

In the quadratic, i.i.d. noise setting of Section 2.1, the left-hand side of (6) is simply E_π[xᵀAx], twice the average loss value at stationarity. So this condition can be considered as a generalization of "testing for when the loss is stationary." We can estimate both sides of (6) by computing ⟨x_k, g_k⟩ and ⟨d_k, d_k⟩ = ||d_k||² at each iteration and updating the running mean z̄_N with their difference. 
That is, we let

z_k = ⟨x_k, g_k⟩ − (α/2)((1 + β)/(1 − β)) ⟨d_k, d_k⟩,  z̄_N = (1/N) Σ_{k=B}^{N+B−1} z_k.  (7)

Here B is the number of samples discarded as part of a "burn-in" phase to reduce bias that might be caused by starting far away from the stationary distribution; we typically take B = N/2, so that we use the most recent N/2 samples.
Yaida's condition has two key advantages over Pflug's: it holds with no approximation for arbitrary functions F and any learning rate α, and both sides can be estimated with negligible cost. In Appendix C.2, we show in Figure 14 that even on a strongly convex function, the error term in (5) is large, whereas z̄_N in (7) quickly converges to zero. Given these advantages, in the next section, we focus on how to test (6), i.e., that z̄_N defined in (7) is approximately zero.

3 Testing for Stationarity
By the Markov chain law of large numbers, we know that z̄_N → 0 as N grows, but there are multiple ways to determine whether z̄_N is "close enough" to zero that we should drop the learning rate.

Deterministic test. If in addition to z̄_N in (7), we keep track of

v̄_N = (1/N) Σ_{i=B}^{N+B−1} (α/2)((1 + β)/(1 − β)) ⟨d_i, d_i⟩,  (8)

a natural idea is to test

|z̄_N| < δ v̄_N,  or equivalently  |z̄_N / v̄_N| < δ,  (9)

to detect stationarity, where δ > 0 is an error tolerance. The v̄_N term is introduced to make the error term relative to the scale of z̄ and v̄ (v̄_N is always nonnegative). If z̄_N satisfies (9), then the dynamics (2) are "close" to stationarity. This is precisely the method used by Yaida (2018).
However, because z̄_N is a random variable, there is some potential for error in this procedure due to its variance, which is unaccounted for by (9). 
Especially when we aim to make a critical decision based on the outcome of this test (i.e., dropping the learning rate), it seems important to more directly account for this variance. To do so, we can appeal to statistical hypothesis testing.

I.i.d. t-test. The simplest approach to accounting for the variance in z̄_N is to assume each sample z_i is drawn i.i.d. from the same distribution. Then by the central limit theorem, we have that √N z̄_N → N(0, σ_z²), and moreover σ̂_N² = (1/(N−1)) Σ_{i=1}^{N} (z_i − z̄_N)² ≈ σ_z² for large N. So we can estimate the variance of z̄_N's sampling distribution using the sample variance of the z_i's. Using this variance estimate, we can form the (1 − γ) confidence interval

z̄_N ± t*_{1−γ/2} σ̂_N / √N,

where t*_{1−γ/2} is the (1 − γ/2) quantile of the Student's t-distribution with N − 1 degrees of freedom. Then we can check whether

[ z̄_N − t*_{1−γ/2} σ̂_N/√N,  z̄_N + t*_{1−γ/2} σ̂_N/√N ] ⊆ (−δ v̄_N, δ v̄_N).  (10)

If so, we can be confident that z̄_N is close to zero. The method of Pflug (1983, Algorithm 4.2) is also a kind of i.i.d. test that tries to account for the variance of z̄_N, but in a more heuristic way than (10). The procedure (10) can be thought of as a relative equivalence test in statistical hypothesis testing (e.g., Streiner, 2003). When σ̂_N = 0 (no variance) or γ = 1 (t*_{1−γ/2} = 0, no confidence), this recovers (9).
Unfortunately, in our case, samples z_i evaluated at nearby points are highly correlated (due to the underlying Markov dynamics), which makes this procedure inappropriate. 
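A sketch of the relative equivalence test (10), assuming the running means and a variance estimate σ̂_N are supplied. For simplicity this illustration uses the normal quantile from the standard library in place of the Student-t quantile (a close approximation for large N), so it is not the paper's exact procedure:

```python
from statistics import NormalDist

def equivalence_test(zbar, vbar, sigma_hat, n, delta=0.02, gamma=0.2):
    """Check whether the (1 - gamma) confidence interval for the mean
    statistic zbar lies inside the interval (-delta * vbar, delta * vbar),
    in the spirit of Eq. (10).

    Uses the normal quantile as a large-n stand-in for the t quantile.
    """
    t_star = NormalDist().inv_cdf(1 - gamma / 2)
    half_width = t_star * sigma_hat / n ** 0.5
    lower, upper = zbar - half_width, zbar + half_width
    return -delta * vbar < lower and upper < delta * vbar

# A mean near zero with a small variance estimate passes...
print(equivalence_test(zbar=0.001, vbar=1.0, sigma_hat=0.05, n=1000))  # True
# ...but the same mean with a large variance estimate does not.
print(equivalence_test(zbar=0.001, vbar=1.0, sigma_hat=5.0, n=1000))   # False
```

The second call illustrates the point of the section: a small z̄_N alone is not enough to fire the test when its sampling variance is large.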
To deal with correlated samples, we appeal to a stronger Markov chain result than the Markov chain law of large numbers (4).

Markov chain t-test. Under suitable conditions, Markov chains admit the following analogue of the central limit theorem:
Theorem 1 (Markov Chain CLT (informal); (Jones et al., 2006)). Let X = {X_0, X_1, . . .} be a Harris ergodic Markov chain with state space X, and with stationary distribution π, that satisfies any one of a number of additional ergodicity criteria (see Jones et al. (2006), page 6). For suitable functions z : X → R, we have that:

√N (z̄_N − E_π z) → N(0, σ_z²),

where z̄_N = (1/N) Σ_{i=0}^{N−1} z(X_i) is the running mean over time of z(X_i), and σ_z² ≠ var_π z in general due to correlations in the Markov chain.
This shows that in the presence of correlation, the sample variance is not the correct estimator for the variance of z̄_N's sampling distribution. In light of Theorem 1, if we are given a consistent estimator σ̂_N² → σ_z², we can properly perform the test (10). All that remains is to construct such an estimator.

Batch mean variance estimator. Methods for estimating the asymptotic variance of the history average estimator, e.g., z̄_N in (7), on a Markov chain are well-studied in the MCMC (Markov chain Monte Carlo) literature. They can be used to set a stopping time for an MCMC simulation and to determine the simulation's random error (Jones et al., 2006). We present one of the simplest estimators for σ_z², the batch means estimator.
Given N samples {z_i}, divide them into b batches each of size m, and compute the batch means z̄_j = (1/m) Σ_{i=jm}^{(j+1)m−1} z_i for each batch j. Then let

σ̂_N² = (m/(b − 1)) Σ_{j=0}^{b−1} (z̄_j − z̄_N)².  (11)

Here σ̂_N² is simply the variance of the batch means around the full mean z̄_N. When used in the test (10), it has b − 1 degrees of freedom. 
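The batch means estimator (11) takes only a few lines; the sketch below (function name ours) uses the b = m = √N choice discussed in the text:

```python
def batch_means_variance(z, num_batches=None):
    """Batch means estimator of the asymptotic variance sigma_z^2, Eq. (11).

    Splits the N samples into b batches of size m (b = m = sqrt(N) by
    default, as suggested by Jones et al. (2006)) and returns
    (m / (b - 1)) * sum_j (zbar_j - zbar_N)^2.
    """
    n = len(z)
    b = num_batches if num_batches is not None else int(n ** 0.5)
    m = n // b
    z = z[: b * m]                      # drop any leftover samples
    zbar_n = sum(z) / len(z)
    batch_means = [sum(z[j * m:(j + 1) * m]) / m for j in range(b)]
    return m / (b - 1) * sum((zj - zbar_n) ** 2 for zj in batch_means)

# On i.i.d. samples the estimate agrees with the ordinary sample variance;
# on correlated samples it is typically much larger.
import random
random.seed(0)
iid = [random.gauss(0, 1) for _ in range(10_000)]
print(batch_means_variance(iid))  # roughly 1.0
```

On a correlated sequence (e.g., an AR(1) process), the same function returns a value well above the per-sample variance, which is exactly the correction Theorem 1 calls for.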
Intuitively, when b and m are both large enough, these batch means are roughly independent because of the mixing of the Markov chain, so their unbiased sample variance gives a good estimator of σ_z². Jones et al. (2006) survey the formal conditions under which σ̂_N² is a strongly consistent estimator of σ_z², and suggest taking b = m = √N (the theoretically correct sizes of b and m depend on the mixing of the Markov chain). Flegal and Jones (2010) prove strong consistency for a related method called overlapping batch means (OLBM) that has better asymptotic variance. The OLBM estimator is similar to (11), but uses n − b + 1 overlapping batches of size b and has n − b degrees of freedom.

Algorithm 2: SASA
Input: {x_0, α_0, M, β, δ, γ, ζ}
1 zQ = HalfQueue()
2 vQ = HalfQueue()
3 for j ∈ {0, 1, 2, . . .} do
4   for k ∈ {jM, . . . , (j + 1)M − 1} do
5     Sample ξ_k and compute g_k = ∇_x f(x_k, ξ_k)
6     d_{k+1} = (1 − β) g_k + β d_k
7     x_{k+1} = x_k − α d_{k+1}
8     zQ.push(⟨x_k, g_k⟩ − (α/2)((1 + β)/(1 − β)) ||d_{k+1}||²)
9     vQ.push((α/2)((1 + β)/(1 − β)) ||d_{k+1}||²)
10  end
11  if test(zQ, vQ, δ, γ) then
12    α ← ζα
13    zQ.reset()
14    vQ.reset()
15  end
16 end

Algorithm 3: Test
Input: {zQ, vQ, δ, γ}
Output: boolean (whether to drop)
1 z̄_N = (1/zQ.N) Σ_i zQ[i]
2 v̄_N = (1/vQ.N) Σ_i vQ[i]
3 m = b = √(zQ.N)
4 for i ∈ {0, . . . , b − 1} do
5   z̄_i = (1/m) Σ_{t=im}^{(i+1)m−1} zQ[t]
6 end
7 σ̂_N² = (m/(b − 1)) Σ_{i=0}^{b−1} (z̄_i − z̄_N)²
8 L = z̄_N − t*_{1−γ/2} σ̂_N / √(zQ.N)
9 U = z̄_N + t*_{1−γ/2} σ̂_N / √(zQ.N)
10 return [L, U] ⊆ (−δ v̄_N, δ v̄_N)

3.1 Statistical adaptive stochastic approximation (SASA)
Finally, we turn the above analysis into an adaptive algorithm for detecting stationarity of SGM and decreasing α, and discuss implementation details. 
Algorithm 2 describes our full SASA algorithm. To diminish the effect of "initialization bias" due to starting outside of the stationary distribution, we only keep track of the latter half of samples z_i and v_i. That is, if N total iterations of SGM have been run, the "HalfQueues" zQ and vQ contain the most recent N/2 values of z_i and v_i; these queues "pop" every other time they "push." If we decrease the learning rate, we empty the queues; otherwise, we keep collecting more samples. To compute the batch means estimator, we need O(N) space, but in deep learning the total number of training iterations (the worst-case size of these queues) is usually small compared to the number of parameters of the model. Collection of the samples z_i and v_i only requires two more inner products per iteration than SGM.
The "test" algorithm follows the Markov chain t-test procedure discussed above. Lines 1-2 compute the running means z̄_N and v̄_N; lines 3-7 compute the variance estimator σ̂_N² according to (11), and lines 8-10 determine whether the corresponding confidence interval for z̄_N is within the acceptable interval (−δ v̄_N, δ v̄_N). Like the sample collection, the test procedure is computationally efficient: the batch means and overlapping batch means estimators can both be computed with a 1D convolution.
For all experiments, we use default values δ = 0.02 and γ = 0.2. In equivalence testing, γ is typically taken larger than usual to increase the power of the test (Streiner, 2003). We discuss the apparent multiple testing problem of this sequential testing procedure in Appendix D.

4 Experiments
To evaluate the performance of SASA, we run Algorithm 2 on several models from deep learning. We compare SASA to tuned versions of Adam and SGM. 
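The half-queue behavior described above (popping one old sample on every other push, so that roughly the most recent N/2 samples are retained) can be sketched as follows; the class name mirrors Algorithm 2, but the implementation details are our own:

```python
from collections import deque

class HalfQueue:
    """Keeps roughly the most recent half of all pushed samples.

    Pops one old sample on every other push, so after N total pushes
    about N/2 of the most recent samples remain. This implements the
    burn-in discard (B = N/2) described in Section 3.1.
    """
    def __init__(self):
        self.data = deque()
        self.total = 0                 # total number of pushes so far

    def push(self, value):
        self.data.append(value)
        self.total += 1
        if self.total % 2 == 0:        # pop every other push
            self.data.popleft()

    def reset(self):
        self.data.clear()
        self.total = 0

    @property
    def N(self):
        return len(self.data)

q = HalfQueue()
for i in range(100):
    q.push(float(i))
print(q.N, q.data[0])  # 50 samples remain, the oldest 50 were dropped
```

After 100 pushes the queue holds the 50 most recent values, so statistics computed over its contents automatically exclude the early, biased part of each phase.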
Many adaptive optimizers do not compare\nto SGM with hand-tuned step size scheduling, (e.g., Schaul et al., 2013; Zhang and Mitliagkas, 2017;\nBaydin et al., 2018), and instead compare to SGM with a \ufb01xed \u21b5 or to SGM with tuned polynomial\ndecay. As detailed in Section 1, tuned constant-and-cut schedules are typically a stronger baseline.\nThroughout this section, we do not tune the SASA parameters , ,M, instead using the default settings\nof = 0.02 and = 0.2, and setting M = one epoch (we test the statistics once per epoch). In each\nexperiment, we use the same \u21b50 and \u21e3 as for the best SGM baseline. We stress that SASA is not fully\nautomatic: it requires choices of \u21b50 and \u21e3, but we show in Appendix A that SASA achieves good\nperformance for dierent values of \u21e3. We use weight decay in every experiment\u2014without weight\ndecay, there are simple examples where the process (2) does not converge to a stationary distribution,\n\n6\n\n\fFigure 2: Training loss, test accuracy, and learning rate schedule for SASA, SGM, and Adam on\ndierent datasets. Top: ResNet18 on CIFAR-10. Middle: ResNet18 on ImageNet. Bottom: RNN\nmodel on WikiText-2. In all cases, starting with the same \u21b50, SASA achieves similar performance to\nthe best hand-tuned or validation-tuned SGM result. Across three independent runs, the variance\nof each optimizer\u2019s best test accuracy was never larger than 1%, and the relative orderings between\noptimizers held for every run. Figure 5 studies the variance of SASA in a semi-synthetic setting.\n\n(a)\n\n(b)\n\n(c)\n\n(d)\n\nFigure 3: Evolution of the dierent statistics for SASA over the course of training ResNet18 on\nCIFAR-10 using the default parameters = 0.02, = 0.2,\u21e3 = 0.1. Panel (a) shows the raw data for\nboth sides of condition (6). 
That is, it shows the values of ⟨x_k, g_k⟩ and (α/2)((1 + β)/(1 − β))⟨d_k, d_k⟩ at each iteration. Panel (b) shows z̄_N with its lower and upper confidence interval [lci, uci] and the "right hand side" (rhs) (−δ v̄_N, δ v̄_N) (see Eqn. (10)). Panel (c) shows a zoomed-in version of (b) to show the drop points in more detail. Panel (d) depicts the different variance estimators (i.i.d., batch means, overlapping batch means) over the course of training. The i.i.d. variance (green) is a poor estimate of σ_z².

such as with logistic regression on separable data. While weight decay does not guarantee convergence to a stationary distribution, it at least rules out this simple case. Finally, we conduct an experiment on CIFAR-10 that shows directly accounting for the variance of the test statistic, as in (10), improves the robustness of this procedure compared to (9).
For hand-tuned SGM (SGM-hand), we searched over "constant-and-cut" schemes for each experiment by tuning α_0, the drop frequency, and the drop amount ζ with grid search. In all experiments, SASA and SGM use a constant β = 0.9. For Adam, we tuned the initial global learning rate as in Wilson et al. (2017) and used β_1 = 0.9, β_2 = 0.999. We also allowed Adam to have access to a "warmup" phase to prevent it from decreasing the learning rate too quickly. To "warm up" Adam, we initialize it with the 
Appendix A contains a full list of the hyperparameters used in each experiment, additional\nresults for object detection, sensitivity analysis for and , and plots of the dierent estimators for\nthe variance 2\nz .\nCIFAR-10. We trained an 18-layer ResNet model4 He et al.\n(He et al., 2016) on CIFAR-10\n(Krizhevsky and Hinton, 2009) with random cropping and random horizontal \ufb02ipping for data\naugmentation and weight decay 0.0005. Row 1 of Figure 2 compares the best performance of each\nmethod. Here SGM-hand uses \u21b50 = 1.0 and = 0.9 and drops \u21b5 by a factor of 10 (\u21e3 = 0.1) every\n50 epochs. SASA uses = 0.2 and = 0.02, as always. Adam has a tuned global learning rate\n\u21b50 = 0.0001 and a tuned \u201cwarmup\u201d phase of 50 epochs, but is unable to match SASA and tuned SGM.\nEvolution of statistics. Figure 3 shows the evolution of SASA\u2019s dierent statistics over the course of\ntraining the ResNet18 model on CIFAR-10 using the default parameter settings = 0.02, = 0.2,\u21e3 =\n0.1. In each phase, the running average of the dierence between the statistics, \u00afzN, decays toward\nzero. The learning rate \u21b5 drops once \u00afzN and its con\ufb01dence interval are contained in (\u00afvN, \u00afvN); see\nEqn (10). After the drop, the statistics increase in value and enter another phase of convergence. The\nbatch means variance estimator (BM) and overlapping batch means variance estimator (OLBM) give\nvery similar estimates of the variance, while the i.i.d. variance estimator, as expected, gives quite\ndierent values.\nImageNet. Unlike CIFAR-10, reaching a good performance level on ImageNet (Deng et al., 2009)\nseems to require more gradual annealing. Even when tuned and allowed to have a long warmup phase,\nAdam failed to match the generalization performance of SGM. On the other hand, SASA was able to\nmatch the performance of hand-tuned SGM using the default values of its parameters. 
We again used an 18-layer ResNet model with random cropping, random flipping, normalization, and weight decay 0.0001. Row 2 of Figure 2 shows the performance of the different optimizers.
RNN. We also evaluate SASA on a language modeling task using an RNN. In particular, we train the PyTorch word-level language model example (2019) on the Wikitext-2 dataset (Merity et al., 2016). We used 600-dimensional embeddings, 600 hidden units, tied weights, dropout 0.65, and gradient clipping with threshold 2.0 (note that this model is not state-of-the-art for Wikitext-2). We compare against SGM and Adam with (global) learning rate tuned using a validation set. These baselines drop the learning rate α by a factor of 4 when the validation loss stops improving. Row 3 of Figure 2 shows that without using the validation set, SASA is competitive with these baselines.
Adaptation to the drop factor. At first glance, the choice of the drop factor ζ seems critical. However, Figure 4 shows that SASA automatically adapts to different values of ζ. When ζ is larger, so α decreases slower, the dynamics converge more quickly to the stationary distribution, so the overall rate of decrease stays roughly constant across different values of ζ. Aside from the different choices of ζ, all other hyperparameters were the same as in the ImageNet experiment of Figure 2.
Variance. Figure 5 shows the variance in learning rate schedule and training loss for the two tests in (9) (top row) and (10) (bottom row) with a fixed testing frequency M = 400 iterations, across five independent runs. The model is ResNet18 trained on CIFAR-10 using the same procedure as in the previous CIFAR experiment, but with different batch sizes. The left two columns use batch size four, and the right two use batch size eight. With the same testing frequency and the same value of δ = 0.02, the test (9) is much more sensitive to the level of noise in these small-batch examples. When the batch size is four, only one of the training runs using the test (9) achieves training loss on the same scale as the others. Appendix B contains additional discussion comparing these two tests.

4In our experiments on CIFAR-10, we used the slightly modified ResNet model of https://github.com/kuangliu/pytorch-cifar, which we found to give a small performance gain over the model of He et al. (2016) for all optimizers we tested. The first convolutional layer in this model uses filter size 3 with stride 1 and padding 1, rather than 7, 2, and 3, respectively.

Figure 5: Variance in learning rate schedule and training loss for the two tests (9) (top row) and (10) (bottom row) with fixed testing frequency M, across five independent runs. The left two columns use batch size four, and the right two use batch size eight. With the same testing frequency and the same value of δ (0.02), the test (9) is much more sensitive to the level of noise. In row 1, column 2, only one of the five runs (plotted in red) achieves a low training loss because of the high variance in schedule (row 1, column 1).

5 Conclusion
We provide a theoretically grounded statistical procedure for automatically determining when to decrease the learning rate α in constant-and-cut methods. On the tasks we tried, SASA was competitive with the best hand-tuned schedules for SGM, and it came close to the performance of SGM and Adam when they were tuned using a validation set. The statistical testing procedure controls the variance of the method and makes it more robust than other more heuristic tests. 
Our experiments across several different tasks and datasets did not require any adjustment to the parameters δ, γ, or M. We believe these practical results indicate that automatic "constant-and-cut" algorithms are a promising direction for future research in adaptive optimization. We used a simple statistical test to check Yaida's stationary condition (6). However, there may be better tests that more properly control the false discovery rate (Blanchard and Roquain, 2009; Lindquist and Mejia, 2015), or more sophisticated conditions that also account for non-stationary dynamics like overfitting or limit cycles (Yaida, 2018). Such techniques could make the SASA approach more broadly useful.

References

PyTorch word language model. https://github.com/pytorch/examples/tree/master/word_language_model, 2019.

Dario Amodei, Sundaram Ananthanarayanan, Rishita Anubhai, Jingliang Bai, Eric Battenberg, Carl Case, Jared Casper, Bryan Catanzaro, Qiang Cheng, Guoliang Chen, et al. Deep Speech 2: End-to-end speech recognition in English and Mandarin. In International Conference on Machine Learning, pages 173–182, 2016.

Francis Bach and Eric Moulines. Non-strongly-convex smooth stochastic approximation with convergence rate O(1/n). In Advances in Neural Information Processing Systems, pages 773–781, 2013.

Atilim Güneş Baydin, Robert Cornish, David Martínez Rubio, Mark Schmidt, and Frank Wood. Online learning rate adaptation with hypergradient descent. In Proceedings of the Sixth International Conference on Learning Representations (ICLR), Vancouver, Canada, 2018.

Yoav Benjamini and Yosef Hochberg. Controlling the false discovery rate: a practical and powerful approach to multiple testing. Journal of the Royal Statistical Society: Series B (Methodological), 57(1):289–300, 1995.

Gilles Blanchard and Étienne Roquain. Adaptive false discovery rate control under independence and dependence.
Journal of Machine Learning Research, 10(Dec):2837–2871, 2009.

Jerry Chee and Panos Toulis. Convergence diagnostics for stochastic gradient descent with constant learning rate. In International Conference on Artificial Intelligence and Statistics, pages 1476–1485, 2018.

J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei. ImageNet: A large-scale hierarchical image database. In CVPR, 2009.

Aymeric Dieuleveut, Alain Durmus, and Francis Bach. Bridging the gap between constant step size stochastic gradient descent and Markov chains. arXiv preprint arXiv:1707.06386, 2017.

John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12(Jul):2121–2159, 2011.

James M. Flegal and Galin L. Jones. Batch means and spectral variance estimators in Markov chain Monte Carlo. The Annals of Statistics, 38(2):1034–1070, 2010.

Rong Ge, Sham M. Kakade, Rahul Kidambi, and Praneeth Netrapalli. The step decay schedule: A near optimal, geometrically decaying learning rate procedure. arXiv preprint arXiv:1904.12838, 2019.

Jonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N. Dauphin. Convolutional sequence to sequence learning. In Proceedings of the 34th International Conference on Machine Learning, pages 1243–1252, 2017.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the 29th IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770–778, 2016.

Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. Mask R-CNN. In Proceedings of the IEEE International Conference on Computer Vision, pages 2961–2969, 2017.

Yanping Huang, Yonglong Cheng, Dehao Chen, HyoukJoong Lee, Jiquan Ngiam, Quoc V. Le, and Zhifeng Chen.
GPipe: Efficient training of giant neural networks using pipeline parallelism. arXiv preprint arXiv:1811.06965, 2018.

Galin L. Jones, Murali Haran, Brian S. Caffo, and Ronald Neath. Fixed-width output analysis for Markov chain Monte Carlo. Journal of the American Statistical Association, 101(476):1537–1547, 2006.

Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.

Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. Technical report, University of Toronto, 2009.

Harold J. Kushner and G. George Yin. Stochastic Approximation and Recursive Algorithms and Applications. Springer, 2nd edition, 2003.

Tsung-Yi Lin, Michael Maire, Serge Belongie, James Hays, Pietro Perona, Deva Ramanan, Piotr Dollár, and C. Lawrence Zitnick. Microsoft COCO: Common objects in context. In European Conference on Computer Vision, pages 740–755. Springer, 2014.

Tsung-Yi Lin, Piotr Dollár, Ross Girshick, Kaiming He, Bharath Hariharan, and Serge Belongie. Feature pyramid networks for object detection. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 2117–2125, 2017.

Martin A. Lindquist and Amanda Mejia. Zen and the art of multiple comparisons. Psychosomatic Medicine, 77(2):114, 2015.

Stephan Mandt, Matthew D. Hoffman, and David M. Blei. Stochastic gradient descent as approximate Bayesian inference. The Journal of Machine Learning Research, 18(1):4873–4907, 2017.

Francisco Massa and Ross Girshick. maskrcnn-benchmark: Fast, modular reference implementation of Instance Segmentation and Object Detection algorithms in PyTorch. https://github.com/facebookresearch/maskrcnn-benchmark, 2018.

John H. McDonald. Handbook of Biological Statistics, volume 2. 2009.

Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models.
arXiv preprint arXiv:1609.07843, 2016.

Georg Ch. Pflug. On the determination of the step size in stochastic quasigradient methods. Collaborative Paper CP-83-025, International Institute for Applied Systems Analysis (IIASA), Laxenburg, Austria, 1983.

Georg Ch. Pflug. Non-asymptotic confidence bounds for stochastic approximation algorithms with constant step size. Monatshefte für Mathematik, 110:297–314, 1990.

Boris T. Polyak. Some methods of speeding up the convergence of iteration methods. USSR Computational Mathematics and Mathematical Physics, 4(5):1–17, 1964.

Tom Schaul, Sixin Zhang, and Yann LeCun. No more pesky learning rates. In International Conference on Machine Learning, pages 343–351, 2013.

David L. Streiner. Unicorns do exist: A tutorial on "proving" the null hypothesis. The Canadian Journal of Psychiatry, 48(11):756–761, 2003.

Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. Going deeper with convolutions. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 1–9, 2015.

Tijmen Tieleman and Geoffrey Hinton. Lecture 6.5 - RMSProp: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural Networks for Machine Learning, 4(2):26–31, 2012.

M. T. Wasan. Stochastic Approximation. Cambridge University Press, 1969.

Ashia C. Wilson, Rebecca Roelofs, Mitchell Stern, Nati Srebro, and Benjamin Recht. The marginal value of adaptive gradient methods in machine learning. In Advances in Neural Information Processing Systems, pages 4148–4158, 2017.

Sho Yaida. Fluctuation-dissipation relations for stochastic gradient descent. arXiv preprint arXiv:1810.00004, 2018.

Jian Zhang and Ioannis Mitliagkas. YellowFin and the art of momentum tuning.
arXiv preprint arXiv:1706.03471, 2017.