{"title": "Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent", "book": "Advances in Neural Information Processing Systems", "page_first": 8572, "page_last": 8583, "abstract": "A longstanding goal in deep learning research has been to precisely characterize training and generalization. However, the often complex loss landscapes of neural networks have made a theory of learning dynamics elusive. In this work, we show that for wide neural networks the learning dynamics simplify considerably and that, in the infinite width limit, they are governed by a linear model obtained from the first-order Taylor expansion of the network around its initial parameters. Furthermore, mirroring the correspondence between wide Bayesian neural networks and Gaussian processes, gradient-based training of wide neural networks with a squared loss produces test set predictions drawn from a Gaussian process with a particular compositional kernel. While these theoretical results are only exact in the infinite width limit, we nevertheless find excellent empirical agreement between the predictions of the original network and those of the linearized version even for finite practically-sized networks. This agreement is robust across different architectures, optimization methods, and loss functions.", "full_text": "Wide Neural Networks of Any Depth Evolve as\n\nLinear Models Under Gradient Descent\n\nJaehoon Lee\u21e4, Lechao Xiao\u21e4, Samuel S. Schoenholz, Yasaman Bahri\n\nRoman Novak, Jascha Sohl-Dickstein, Jeffrey Pennington\n\nGoogle Brain\n\n{jaehlee, xlc, schsam, yasamanb, romann, jaschasd, jpennin}@google.com\n\nAbstract\n\nA longstanding goal in deep learning research has been to precisely characterize\ntraining and generalization. However, the often complex loss landscapes of neural\nnetworks have made a theory of learning dynamics elusive. 
In this work, we show that for wide neural networks the learning dynamics simplify considerably and that, in the infinite width limit, they are governed by a linear model obtained from the first-order Taylor expansion of the network around its initial parameters. Furthermore, mirroring the correspondence between wide Bayesian neural networks and Gaussian processes, gradient-based training of wide neural networks with a squared loss produces test set predictions drawn from a Gaussian process with a particular compositional kernel. While these theoretical results are only exact in the infinite width limit, we nevertheless find excellent empirical agreement between the predictions of the original network and those of the linearized version even for finite practically-sized networks. This agreement is robust across different architectures, optimization methods, and loss functions.

1 Introduction

Machine learning models based on deep neural networks have achieved unprecedented performance across a wide range of tasks [1, 2, 3]. Typically, these models are regarded as complex systems for which many types of theoretical analyses are intractable. Moreover, characterizing the gradient-based training dynamics of these models is challenging owing to the typically high-dimensional non-convex loss surfaces governing the optimization. As is common in the physical sciences, investigating the extreme limits of such systems can often shed light on these hard problems. For neural networks, one such limit is that of infinite width, which refers either to the number of hidden units in a fully-connected layer or to the number of channels in a convolutional layer. Under this limit, the output of the network at initialization is a draw from a Gaussian process (GP); moreover, the network output remains governed by a GP after exact Bayesian training using squared loss [4, 5, 6, 7, 8]. Aside from its theoretical simplicity, the infinite-width limit is also of practical interest, as wider networks have been found to generalize better [5, 7, 9, 10, 11].

In this work, we explore the learning dynamics of wide neural networks under gradient descent and find that the weight-space description of the dynamics becomes surprisingly simple: as the width becomes large, the neural network can be effectively replaced by its first-order Taylor expansion with respect to its parameters at initialization. For this linear model, the dynamics of gradient descent become analytically tractable. While the linearization is only exact in the infinite width limit, we nevertheless find excellent agreement between the predictions of the original network and those of the linearized version even for finite width configurations. The agreement persists across different architectures, optimization methods, and loss functions.

*Both authors contributed equally to this work. Work done as a member of the Google AI Residency program (https://g.co/airesidency).

33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.

For squared loss, the exact learning dynamics admit a closed-form solution that allows us to characterize the evolution of the predictive distribution in terms of a GP. This result can be thought of as an extension of "sample-then-optimize" posterior sampling [12] to the training of deep neural networks. Our empirical simulations confirm that the result accurately models the variation in predictions across an ensemble of finite-width models with different random initializations.

Here we summarize our contributions:

• Parameter space dynamics: We show that wide network training dynamics in parameter space are equivalent to the training dynamics of a model which is affine in the collection of all network parameters, the weights and biases. 
This result holds regardless of the choice of loss function. For squared loss, the dynamics admit a closed-form solution as a function of time.

• Sufficient conditions for linearization: We formally prove that there exists a threshold learning rate η_critical (see Theorem 2.1), such that gradient descent training trajectories with learning rate smaller than η_critical stay in an O(n^{-1/2})-neighborhood of the trajectory of the linearized network when n, the width of the hidden layers, is sufficiently large.

• Output distribution dynamics: We formally show that the predictions of a neural network throughout gradient descent training are described by a GP as the width goes to infinity (see Theorem 2.2), extending results from Jacot et al. [13]. We further derive explicit time-dependent expressions for the evolution of this GP during training. Finally, we provide a novel interpretation of the result. In particular, it offers a quantitative understanding of the mechanism by which gradient descent differs from Bayesian posterior sampling of the parameters: while both methods generate draws from a GP, gradient descent does not generate samples from the posterior of any probabilistic model.

• Large scale experimental support: We empirically investigate the applicability of the theory in the finite-width setting and find that it gives an accurate characterization of both learning dynamics and posterior function distributions across a variety of conditions, including some practical network architectures such as the wide residual network [14].

• Parameterization independence: We note that the linearization result holds in both the standard and the NTK parameterization (defined in §2.1), while previous work assumed the latter, emphasizing that the effect is due to the increase in width rather than the particular parameterization.

• Analytic ReLU and erf neural tangent kernels: We compute the analytic neural tangent kernel corresponding to fully-connected networks with ReLU or erf nonlinearities.

• Source code: Example code investigating both function space and parameter space linearized learning dynamics described in this work is released as open source code within [15].² We also provide accompanying interactive Colab notebooks for both parameter space³ and function space⁴ linearization.

²Note that the open source library has been expanded since initial submission of this work.
³colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/weight_space_linearization.ipynb
⁴colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/function_space_linearization.ipynb

1.1 Related work

We build on recent work by Jacot et al. [13] that characterizes the exact dynamics of network outputs throughout gradient descent training in the infinite width limit. Their results establish that full batch gradient descent in parameter space corresponds to kernel gradient descent in function space with respect to a new kernel, the Neural Tangent Kernel (NTK). We examine what this implies about dynamics in parameter space, where training updates are actually made.

Daniely et al. [16] study the relationship between neural networks and kernels at initialization. They bound the difference between the infinite width kernel and the empirical kernel at finite width n, which diminishes as O(1/√n). Daniely [17] uses the same kernel perspective to study stochastic gradient descent (SGD) training of neural networks.

Saxe et al. [18] study the training dynamics of deep linear networks, in which the nonlinearities are treated as identity functions. Deep linear networks are linear in their inputs, but not in their parameters. 
In contrast, we show that the outputs of sufficiently wide neural networks are linear in the updates to their parameters during gradient descent, but not usually in their inputs.

Du et al. [19], Allen-Zhu et al. [20, 21], Zou et al. [22] study the convergence of gradient descent to global minima. They proved that for i.i.d. Gaussian initialization, the parameters of sufficiently wide networks move little from their initial values during SGD. This small motion of the parameters is crucial to the effect we present, where wide neural networks behave linearly in terms of their parameters throughout training.

Mei et al. [23], Chizat and Bach [24], Rotskoff and Vanden-Eijnden [25], Sirignano and Spiliopoulos [26] analyze the mean field SGD dynamics of training neural networks in the large-width limit. Their mean field analysis describes distributional dynamics of network parameters via a PDE. However, their analysis is restricted to one-hidden-layer networks with a scaling limit (1/n) different from ours (1/√n), which is commonly used in modern networks [2, 27].

Chizat et al. [28]⁵ argued that infinite width networks are in a 'lazy training' regime and may be too simple to be applicable to realistic neural networks. Nonetheless, we empirically investigate the applicability of the theory in the finite-width setting and find that it gives an accurate characterization of both the learning dynamics and posterior function distributions across a variety of conditions, including some practical network architectures such as the wide residual network [14].

2 Theoretical results

2.1 Notation and setup for architecture and training dynamics

Let D ⊂ R^{n₀} × R^k denote the training set and X = {x : (x, y) ∈ D} and Y = {y : (x, y) ∈ D} denote the inputs and labels, respectively. Consider a fully-connected feed-forward network with L hidden layers with widths n_l, for l = 1, ..., L, and a readout layer with n_{L+1} = k. 
For each x ∈ R^{n₀}, we use h^l(x), x^l(x) ∈ R^{n_l} to represent the pre- and post-activation functions at layer l with input x. The recurrence relation for a feed-forward network is defined as

\[
\begin{cases} h^{l+1} = x^l W^{l+1} + b^{l+1} \\ x^{l+1} = \phi\left(h^{l+1}\right) \end{cases}
\quad\text{and}\quad
W^l_{i,j} = \frac{\sigma_\omega}{\sqrt{n_l}}\,\omega^l_{ij}\,, \qquad b^l_j = \sigma_b\,\beta^l_j\,,
\tag{1}
\]

where \(\phi\) is a point-wise activation function, W^{l+1} ∈ R^{n_l × n_{l+1}} and b^{l+1} ∈ R^{n_{l+1}} are the weights and biases, ω^l_{ij} and β^l_j are the trainable variables, drawn i.i.d. from a standard Gaussian ω^l_{ij}, β^l_j ∼ N(0, 1) at initialization, and σ²_ω and σ²_b are weight and bias variances. Note that this parametrization is non-standard, and we will refer to it as the NTK parameterization. It has already been adopted in several recent works [29, 30, 13, 19, 31]. Unlike the standard parameterization that only normalizes the forward dynamics of the network, the NTK parameterization also normalizes its backward dynamics. We note that the predictions and training dynamics of NTK-parameterized networks are identical to those of standard networks, up to a width-dependent scaling factor in the learning rate for each parameter tensor. As we derive, and support experimentally, in Supplementary Material (SM) §F and §G, our results (linearity in weights, GP predictions) also hold for networks with a standard parameterization.

We define θ^l ≡ vec({W^l, b^l}), the ((n_{l-1} + 1)n_l) × 1 vector of all parameters for layer l. θ = vec(∪_{l=1}^{L+1} θ^l) then indicates the vector of all network parameters, with similar definitions for θ^{≤l} and θ^{>l}. Denote by θ_t the time-dependence of the parameters and by θ₀ their initial values. We use f_t(x) ≡ h^{L+1}(x) ∈ R^k to denote the output (or logits) of the neural network at time t. Let ℓ(ŷ, y) : R^k × R^k → R denote the loss function, where the first argument is the prediction and the second argument the true label. 
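As a concrete illustration, the recurrence in Equation 1 under the NTK parameterization can be sketched in a few lines of NumPy. This is our own minimal sketch, not the paper's released code; the function names and the choice of a tanh nonlinearity are assumptions for illustration:

```python
import numpy as np

def init_params(widths, seed=0):
    """Draw omega^l_ij, beta^l_j ~ N(0, 1) i.i.d., as in Equation 1."""
    rng = np.random.default_rng(seed)
    return [(rng.standard_normal((n_in, n_out)), rng.standard_normal(n_out))
            for n_in, n_out in zip(widths[:-1], widths[1:])]

def forward(x, params, sigma_w=1.0, sigma_b=0.0, phi=np.tanh):
    """NTK-parameterized forward pass:
    W^l = (sigma_w / sqrt(n_l)) * omega^l and b^l = sigma_b * beta^l,
    with no nonlinearity applied to the readout layer."""
    h = x
    for i, (omega, beta) in enumerate(params):
        n_in = omega.shape[0]
        h = h @ (sigma_w / np.sqrt(n_in) * omega) + sigma_b * beta
        if i < len(params) - 1:  # hidden layers only
            h = phi(h)
    return h

params = init_params([3, 256, 256, 1])
x = np.random.default_rng(1).standard_normal((16, 3))
out = forward(x, params, sigma_w=1.5, sigma_b=0.1)  # shape (16, 1)
```

The explicit 1/√n_l factor is the point of the parameterization: pre-activations stay O(1) at any width, so the same learning rate scale can be used irrespective of n.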
In supervised learning, one is interested in learning a θ that minimizes the empirical loss⁶, L = Σ_{(x,y)∈D} ℓ(f_t(x, θ), y).

⁵We note that this is a concurrent work and an expanded version of this note is presented in parallel at NeurIPS 2019.

⁶To simplify the notation for later equations, we use the total loss here instead of the average loss, but for all plots in §3, we show the average loss.

Let η be the learning rate⁷. Via continuous time gradient descent, the evolution of the parameters θ and the logits f can be written as

\[
\dot\theta_t = -\eta\,\nabla_\theta f_t(\mathcal{X})^T \nabla_{f_t(\mathcal{X})} \mathcal{L} \tag{2}
\]
\[
\dot f_t(\mathcal{X}) = \nabla_\theta f_t(\mathcal{X})\,\dot\theta_t = -\eta\,\hat\Theta_t(\mathcal{X},\mathcal{X})\,\nabla_{f_t(\mathcal{X})} \mathcal{L} \tag{3}
\]

where f_t(X) = vec([f_t(x)]_{x∈X}), the k|D| × 1 vector of concatenated logits for all examples, and ∇_{f_t(X)}L is the gradient of the loss with respect to the model's output, f_t(X). Θ̂_t ≡ Θ̂_t(X, X) is the tangent kernel at time t, which is a k|D| × k|D| matrix

\[
\hat\Theta_t = \nabla_\theta f_t(\mathcal{X})\,\nabla_\theta f_t(\mathcal{X})^T = \sum_{l=1}^{L+1} \nabla_{\theta^l} f_t(\mathcal{X})\,\nabla_{\theta^l} f_t(\mathcal{X})^T. \tag{4}
\]

One can define the tangent kernel for general arguments, e.g. Θ̂_t(x, X) where x is a test input. At finite width, Θ̂ will depend on the specific random draw of the parameters and in this context we refer to it as the empirical tangent kernel.

The dynamics of discrete gradient descent can be obtained by replacing θ̇_t and ḟ_t(X) with (θ_{i+1} − θ_i) and (f_{i+1}(X) − f_i(X)) above, and replacing e^{−ηΘ̂₀t} with (1 − (1 − ηΘ̂₀)^i) below.

2.2 Linearized networks have closed form training dynamics for parameters and outputs

In this section, we consider the training dynamics of the linearized network. 
Specifically, we replace the outputs of the neural network by their first order Taylor expansion,

\[
f^{\mathrm{lin}}_t(x) \equiv f_0(x) + \nabla_\theta f_0(x)\big|_{\theta=\theta_0}\,\omega_t\,, \tag{5}
\]

where ω_t ≡ θ_t − θ₀ is the change in the parameters from their initial values. Note that f^{lin}_t is the sum of two terms: the first term is the initial output of the network, which remains unchanged during training, and the second term captures the change to the initial value during training. The dynamics of gradient flow using this linearized function are governed by

\[
\dot\omega_t = -\eta\,\nabla_\theta f_0(\mathcal{X})^T \nabla_{f^{\mathrm{lin}}_t(\mathcal{X})} \mathcal{L} \tag{6}
\]
\[
\dot f^{\mathrm{lin}}_t(x) = -\eta\,\hat\Theta_0(x,\mathcal{X})\,\nabla_{f^{\mathrm{lin}}_t(\mathcal{X})} \mathcal{L}\,. \tag{7}
\]

As ∇_θ f₀(x) remains constant throughout training, these dynamics are often quite simple. In the case of an MSE loss, i.e., ℓ(ŷ, y) = ½‖ŷ − y‖²₂, the ODEs have closed form solutions

\[
\omega_t = -\nabla_\theta f_0(\mathcal{X})^T\,\hat\Theta_0^{-1}\left(I - e^{-\eta\hat\Theta_0 t}\right)\left(f_0(\mathcal{X}) - \mathcal{Y}\right), \tag{8}
\]
\[
f^{\mathrm{lin}}_t(\mathcal{X}) = \left(I - e^{-\eta\hat\Theta_0 t}\right)\mathcal{Y} + e^{-\eta\hat\Theta_0 t} f_0(\mathcal{X})\,. \tag{9}
\]

For an arbitrary point x, f^{lin}_t(x) = μ_t(x) + γ_t(x), where

\[
\mu_t(x) = \hat\Theta_0(x,\mathcal{X})\,\hat\Theta_0^{-1}\left(I - e^{-\eta\hat\Theta_0 t}\right)\mathcal{Y} \tag{10}
\]
\[
\gamma_t(x) = f_0(x) - \hat\Theta_0(x,\mathcal{X})\,\hat\Theta_0^{-1}\left(I - e^{-\eta\hat\Theta_0 t}\right)f_0(\mathcal{X})\,. \tag{11}
\]

Therefore, we can obtain the time evolution of the linearized neural network without running gradient descent. 
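The closed-form MSE dynamics above (Equations 9 and 10) are straightforward to evaluate numerically. Below is a small NumPy sketch of ours: it takes a precomputed tangent kernel Θ̂₀ as input (here stood in by an arbitrary positive-definite matrix rather than an actual network Jacobian), and all function names are assumptions for illustration:

```python
import numpy as np

def expm_sym(K, s):
    """e^{sK} for a symmetric matrix K, via eigendecomposition."""
    w, V = np.linalg.eigh(K)
    return (V * np.exp(s * w)) @ V.T

def f_lin_train(theta0, f0_train, y, eta, t):
    """Eq. 9: f^lin_t(X) = (I - e^{-eta Theta0 t}) Y + e^{-eta Theta0 t} f_0(X)."""
    E = expm_sym(theta0, -eta * t)
    return (np.eye(len(y)) - E) @ y + E @ f0_train

def mu_test(theta0_xX, theta0, y, eta, t):
    """Eq. 10: mu_t(x) = Theta0(x, X) Theta0^{-1} (I - e^{-eta Theta0 t}) Y."""
    E = expm_sym(theta0, -eta * t)
    return theta0_xX @ np.linalg.solve(theta0, (np.eye(len(y)) - E) @ y)

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
theta0 = A @ A.T + 0.5 * np.eye(5)  # stand-in positive-definite kernel
y, f0 = rng.standard_normal(5), rng.standard_normal(5)
preds = f_lin_train(theta0, f0, y, eta=0.1, t=10.0)
```

At t = 0 this recovers f₀(X), and as t → ∞ the train outputs converge to the labels Y, consistent with Equation 9.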
We only need to compute the tangent kernel Θ̂₀ and the outputs f₀ at initialization and use Equations 8, 10, and 11 to compute the dynamics of the weights and the outputs.

2.3 Infinite width limit yields a Gaussian process

As the width of the hidden layers approaches infinity, the Central Limit Theorem (CLT) implies that the outputs at initialization {f₀(x)}_{x∈X} converge to a multivariate Gaussian in distribution. Informally, this occurs because the pre-activations at each layer are a sum of Gaussian random variables (the weights and bias), and thus become a Gaussian random variable themselves. See [32, 33, 5, 34, 35] for more details, and [36, 7] for a formal treatment.

⁷Note that compared to the conventional parameterization, η is larger by a factor of width [31]. The NTK parameterization allows usage of a universal learning rate scale irrespective of network width.

Therefore, randomly initialized neural networks are in correspondence with a certain class of GPs (hereinafter referred to as NNGPs), which facilitates a fully Bayesian treatment of neural networks [5, 6]. More precisely, let f^i_t denote the i-th output dimension and K denote the sample-to-sample kernel function (of the pre-activation) of the outputs in the infinite width setting,

\[
\mathcal{K}_{i,j}(x, x') = \lim_{\min(n_1,\ldots,n_L)\to\infty} \mathbb{E}\!\left[f^i_0(x) \cdot f^j_0(x')\right], \tag{12}
\]

then f₀(X) ∼ N(0, K(X, X)), where K_{i,j}(x, x′) denotes the covariance between the i-th output of x and the j-th output of x′, which can be computed recursively (see Lee et al. [5, §2.3] and SM §E). For a test input x ∈ X_T, the joint output distribution f([x, X]) is also multivariate Gaussian. Conditioning on the training samples⁸, f(X) = Y, the distribution of f(x)|X, Y is also a Gaussian N(μ(x), Σ(x)),

\[
\mu(x) = \mathcal{K}(x,\mathcal{X})\,\mathcal{K}^{-1}\mathcal{Y}, \qquad \Sigma(x) = \mathcal{K}(x, x) - \mathcal{K}(x,\mathcal{X})\,\mathcal{K}^{-1}\mathcal{K}(x,\mathcal{X})^T, \tag{13}
\]

and where K = K(X, X). 
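Equation 13 is the standard GP posterior and can be evaluated directly once a kernel is in hand. A minimal NumPy sketch of ours follows; in the paper K is the NNGP kernel computed recursively (SM §E), but any positive-definite kernel stands in here, and the small jitter term is our addition for numerical stability, not part of Equation 13:

```python
import numpy as np

def gp_posterior(K_train, K_test_train, K_test, y, jitter=1e-10):
    """Eq. 13: mu(x) = K(x,X) K^{-1} Y,
    Sigma(x) = K(x,x) - K(x,X) K^{-1} K(x,X)^T."""
    Kj = K_train + jitter * np.eye(K_train.shape[0])
    mu = K_test_train @ np.linalg.solve(Kj, y)
    Sigma = K_test - K_test_train @ np.linalg.solve(Kj, K_test_train.T)
    return mu, Sigma

def rbf(A, B):
    """Stand-in positive-definite kernel (not the NNGP kernel)."""
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2)

rng = np.random.default_rng(0)
X = rng.standard_normal((6, 2))
y = rng.standard_normal(6)
mu, Sigma = gp_posterior(rbf(X, X), rbf(X[:1], X), rbf(X[:1], X[:1]), y)
```

A useful sanity check on the formula: conditioning at a training input reproduces its label with vanishing posterior variance.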
This is the posterior predictive distribution resulting from exact Bayesian inference in an infinitely wide neural network.

2.3.1 Gaussian processes from gradient descent training

If we freeze the variables θ^{≤L} after initialization and only optimize θ^{L+1}, the original network and its linearization are identical. Letting the width approach infinity, this particular tangent kernel Θ̂₀ will converge to K in probability and Equation 10 will converge to the posterior Equation 13 as t → ∞ (for further details see SM §D). This is a realization of the "sample-then-optimize" approach for evaluating the posterior of a Gaussian process proposed in Matthews et al. [12].

If none of the variables are frozen, in the infinite width setting, Θ̂₀ also converges in probability to a deterministic kernel Θ [13, 37], which we sometimes refer to as the analytic kernel, and which can also be computed recursively (see SM §E). For ReLU and erf nonlinearities, Θ can be computed exactly (SM §C), which we use in §3. Letting the width go to infinity, for any t, the output f^{lin}_t(x) of the linearized network is also Gaussian distributed because Equations 10 and 11 describe an affine transform of the Gaussian [f₀(x), f₀(X)]. Therefore,

Corollary 1. For every test point x ∈ X_T and t ≥ 0, f^{lin}_t(x) converges in distribution as width goes to infinity to a Gaussian with mean and covariance given by⁹

\[
\mu(\mathcal{X}_T) = \Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\left(I - e^{-\eta\Theta t}\right)\mathcal{Y}\,, \tag{14}
\]
\[
\begin{aligned}
\Sigma(\mathcal{X}_T,\mathcal{X}_T) = {} & \mathcal{K}(\mathcal{X}_T,\mathcal{X}_T) + \Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\left(I - e^{-\eta\Theta t}\right)\mathcal{K}\left(I - e^{-\eta\Theta t}\right)\Theta^{-1}\Theta(\mathcal{X},\mathcal{X}_T) \\
& - \left(\Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\left(I - e^{-\eta\Theta t}\right)\mathcal{K}(\mathcal{X},\mathcal{X}_T) + \text{h.c.}\right). \tag{15}
\end{aligned}
\]

Therefore, over random initialization, lim_{t→∞} lim_{n→∞} f^{lin}_t(x) has distribution

\[
\mathcal{N}\Big(\Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\mathcal{Y},\;
\mathcal{K}(\mathcal{X}_T,\mathcal{X}_T) + \Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\mathcal{K}\,\Theta^{-1}\Theta(\mathcal{X},\mathcal{X}_T) - \left(\Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\mathcal{K}(\mathcal{X},\mathcal{X}_T) + \text{h.c.}\right)\Big). \tag{16}
\]

Unlike the case when only θ^{L+1} is optimized, Equations 14 and 15 do not admit an interpretation corresponding to the posterior sampling of a probabilistic model.¹⁰ We contrast the predictive distributions from the NNGP, NTK-GP (i.e. Equations 14 and 15) and ensembles of NNs in Figure 2.

Infinitely-wide neural networks open up ways to study deep neural networks both under fully Bayesian training through the Gaussian process correspondence, and under GD training through the linearization perspective. The resulting distributions over functions are inconsistent (the distribution resulting from GD training does not generally correspond to a Bayesian posterior). We believe understanding the biases over learned functions induced by different training schemes and architectures is a fascinating avenue for future work.

⁸This imposes that h^{L+1} directly corresponds to the network predictions. In the case of softmax readout, variational or sampling methods are required to marginalize over h^{L+1}.
⁹Here "+h.c." is an abbreviation for "plus the Hermitian conjugate".
¹⁰One possible exception is when the NNGP kernel and NTK are the same up to a scalar multiplication. This is the case when the activation function is the identity function and there is no bias term.

Figure 1: Relative Frobenius norm change during training. Three-hidden-layer ReLU networks trained with η = 1.0 on a subset of MNIST (|D| = 128). We measure changes of (input/output/intermediary) weights, empirical Θ̂, and empirical K̂ after T = 2¹⁷ steps of gradient descent updates for varying width. We see that the relative change in input/output weights scales as 1/√n while that of intermediate weights scales as 1/n; this is because the dimension of the input/output does not grow with n. The change in Θ̂ and K̂ is upper bounded by O(1/√n) but is closer to O(1/n). See Figure S6 for the same experiment with 3-layer tanh and 1-layer ReLU networks. See Figures S9 and S10 for additional comparisons of finite width empirical and analytic kernels.

2.4 Infinite width networks are linearized networks

Equations 2 and 3 of the original network are intractable in general, since Θ̂_t evolves with time. However, for the mean squared loss, we are able to prove formally that, as long as the learning rate η < η_critical := 2(λ_min(Θ) + λ_max(Θ))⁻¹, where λ_min/max(Θ) is the min/max eigenvalue of Θ, the gradient descent dynamics of the original neural network fall into its linearized dynamics regime.

Theorem 2.1 (Informal). Let n₁ = ··· = n_L = n and assume λ_min(Θ) > 0. Applying gradient descent with learning rate η < η_critical (or gradient flow), for every x ∈ R^{n₀} with ‖x‖₂ ≤ 1, with probability arbitrarily close to 1 over random initialization,

\[
\sup_{t\ge 0}\left\|f_t(x) - f^{\mathrm{lin}}_t(x)\right\|_2,\;
\sup_{t\ge 0}\frac{\|\theta_t - \theta_0\|_2}{\sqrt{n}},\;
\sup_{t\ge 0}\left\|\hat\Theta_t - \hat\Theta_0\right\|_F = O\!\left(n^{-\frac{1}{2}}\right), \text{ as } n \to \infty\,. \tag{17}
\]

Therefore, as n → ∞, the distributions of f_t(x) and f^{lin}_t(x) become the same. Coupling with Corollary 1, we have

Theorem 2.2. If η < η_critical, then for every x ∈ R^{n₀} with ‖x‖₂ ≤ 1, as n → ∞, f_t(x) converges in distribution to the Gaussian with mean and variance given by Equation 14 and Equation 15.

We refer the readers to Figure 2 for empirical verification of this theorem. The proof of Theorem 2.1 consists of two steps. The first step is to prove the global convergence of overparameterized neural networks [19, 20, 21, 22] and the stability of the NTK under gradient descent (and gradient flow); see SM §G. This stability was first observed and proved in [13] in the gradient flow and sequential limit (i.e. letting n₁ → ∞, ..., n_L → ∞ sequentially) setting under certain assumptions about global convergence of gradient flow. In §G, we show how to use the NTK to provide a self-contained (and cleaner) proof of such global convergence and the stability of the NTK simultaneously. The second step is to couple the stability of the NTK with Grönwall's type arguments [38] to upper bound the discrepancy between f_t and f^{lin}_t, i.e. the first norm in Equation 17. Intuitively, the ODE of the original network (Equation 3) can be considered as a ‖Θ̂_t − Θ̂₀‖_F-fluctuation from the linearized ODE (Equation 7). One expects the difference between the solutions of these two ODEs to be upper bounded by some functional of ‖Θ̂_t − Θ̂₀‖_F; see SM §H. Therefore, for a large width network, the training dynamics can be well approximated by linearized dynamics.

Note that the updates for individual weights in Equation 6 vanish in the infinite width limit, which for instance can be seen from the explicit width dependence of the gradients in the NTK parameterization. Individual weights move by a vanishingly small amount for wide networks in this regime of dynamics, as do hidden layer activations, but they collectively conspire to provide a finite change in the final output of the network, as is necessary for training. 
An additional insight gained from linearization of the network is that the individual instance dynamics derived in [13] can be viewed as a random features method,¹¹ where the features are the gradients of the model with respect to its weights.

2.5 Extensions to other optimizers, architectures, and losses

Our theoretical analysis thus far has focused on fully-connected single-output architectures trained by full batch gradient descent. In SM §B we derive corresponding results for: networks with multi-dimensional outputs, training against a cross entropy loss, and gradient descent with momentum. In addition to these generalizations, there is good reason to suspect the results extend to a much broader class of models and optimization procedures. In particular, a wealth of recent literature suggests that the mean field theory governing the wide network limit of fully-connected models [32, 33] extends naturally to residual networks [35], CNNs [34], RNNs [39], batch normalization [40], and to broad architectures [37]. We postpone the development of these additional theoretical extensions in favor of an empirical investigation of linearization for a variety of architectures.

Figure 2: Dynamics of mean and variance of trained neural network outputs follow analytic dynamics from linearization. Black lines indicate the time evolution of the predictive output distribution from an ensemble of 100 trained neural networks (NNs). The blue region indicates the analytic prediction of the output distribution throughout training (Equations 14, 15). Finally, the red region indicates the prediction that would result from training only the top layer, corresponding to an NNGP (Equations S22, S23). The trained network has 3 hidden layers of width 8192, tanh activation functions, σ²_ω = 1.5, no bias, and η = 0.5. The output is computed for inputs interpolated between two training points (denoted with black dots), x(α) = αx⁽¹⁾ + (1 − α)x⁽²⁾. The shaded region and dotted lines denote 2 standard deviations (∼95% quantile) from the mean denoted in solid lines. Training was performed with full-batch gradient descent with dataset size |D| = 128. For dynamics for individual function initializations, see SM Figure S1.

3 Experiments

In this section, we provide empirical support showing that the training dynamics of wide neural networks are well captured by linearized models. We consider fully-connected, convolutional, and wide ResNet architectures trained with full- and mini-batch gradient descent using learning rates sufficiently small so that the continuous time approximation holds well. We consider two-class classification on CIFAR-10 (horses and planes) as well as ten-class classification on MNIST and CIFAR-10. When using MSE loss, we treat the binary classification task as regression with one class regressing to +1 and the other to −1.

Experiments in Figures 1, 4, S2, S3, S4, S5 and S6 were done in JAX [41]. The remaining experiments used TensorFlow [42]. An open source implementation of this work providing tools to investigate linearized learning dynamics is available at www.github.com/google/neural-tangents [15].

Predictive output distribution: In the case of an MSE loss, the output distribution remains Gaussian throughout training. In Figure 2, the predictive output distribution for input points interpolated between two training points is shown for an ensemble of neural networks and their corresponding GPs. 
The interpolation is given by x(α) = αx⁽¹⁾ + (1 − α)x⁽²⁾, where x⁽¹,²⁾ are two training inputs with different classes.

¹¹We thank Alex Alemi for pointing out a subtlety on correspondence to a random features method.

Figure 3: Full batch gradient descent on a model behaves similarly to analytic dynamics on its linearization, both for network outputs, and also for individual weights. A binary CIFAR classification task with MSE loss and a ReLU fully-connected network with 5 hidden layers of width n = 2048, η = 0.01, |D| = 256, k = 1, σ²_ω = 2.0, and σ²_b = 0.1. Left two panes show dynamics for a randomly selected subset of datapoints or parameters. Third pane shows that the dynamics of loss for training and test points agree well between the original and linearized model. The last pane shows the dynamics of RMSE between the two models on test points. We observe that the empirical kernel Θ̂ gives more accurate dynamics for finite width networks.

We observe that the mean and variance dynamics of neural network outputs during gradient descent training follow the analytic dynamics from linearization well (Equations 14, 15). Moreover the NNGP predictive distribution, which corresponds to exact Bayesian inference, while similar, is noticeably different from the predictive distribution at the end of gradient descent training. For dynamics for individual function draws see SM Figure S1.

Comparison of training dynamics of linearized network to original network: For a particular realization of a finite width network, one can analytically predict the dynamics of the weights and outputs over the course of training using the empirical tangent kernel at initialization. In Figures 3, 4 (see also S2, S3), we compare these linearized dynamics (Equations 8, 9) with the result of training the actual network. In all cases we see remarkably good agreement. We also observe that for finite networks, dynamics predicted using the empirical kernel Θ̂ better match the data than those obtained using the infinite-width, analytic, kernel Θ. To understand this we note that ‖Θ̂₀⁽ⁿ⁾ − Θ̂_T⁽ⁿ⁾‖_F = O(1/n) ≤ O(1/√n) = ‖Θ̂₀⁽ⁿ⁾ − Θ‖_F, where Θ̂⁽ⁿ⁾ denotes the empirical tangent kernel of a width-n network, as plotted in Figure 1.

One can directly optimize parameters of f^{lin} instead of solving the ODE induced by the tangent kernel Θ̂. Standard neural network optimization techniques such as mini-batching, weight decay, and data augmentation can be directly applied. In Figure 4 (S2, S3), we compared the training dynamics of the linearized and original network while directly training both networks.

With direct optimization of the linearized model, we tested full (|D| = 50,000) MNIST digit classification with cross-entropy loss, trained with a momentum optimizer (Figure S3). For cross-entropy loss with softmax output, some logits at late times grow indefinitely, in contrast to MSE loss where logits converge to the target value. The error between the original and linearized model for cross-entropy loss becomes much worse at late times if the two models deviate significantly before the logits enter their late-time steady-growth regime (see Figure S4).

Linearized dynamics successfully describe the training of networks beyond vanilla fully-connected models. To demonstrate the generality of this procedure we show we can predict the learning dynamics of a subclass of Wide Residual Networks (WRNs) [14]. WRNs are a class of model that are popular in computer vision and leverage convolutions, batch normalization, skip connections, and average pooling. In Figure 4, we show a comparison between the linearized dynamics and the true dynamics for a wide residual network trained with MSE loss and SGD with momentum, trained on the full CIFAR-10 dataset. 
We slightly modified the block structure described in Table S1 so that each layer has a constant number of channels (1024 in this case), and otherwise followed the original implementation. As elsewhere, we see strong agreement between the predicted dynamics and the result of training.

Effects of dataset size: The training dynamics of a neural network match those of its linearization when the width is infinite and the dataset is finite. In previous experiments, we chose sufficiently wide networks to achieve small error between neural networks and their linearization for smaller datasets.

Figure 4: A wide residual network and its linearization behave similarly when both are trained by SGD with momentum on MSE loss on CIFAR-10. We adopt the network architecture from Zagoruyko and Komodakis [14]. We use N = 1, channel size 1024, η = 1.0, β = 0.9, k = 10, σ²_w = 1.0, and σ²_b = 0.0. See Table S1 for details of the architecture. Both the linearized and original model are trained directly on full CIFAR-10 (|D| = 50,000), using SGD with batch size 8. Output dynamics for a randomly selected subset of train and test points are shown in the first two panes. The last two panes show training and accuracy curves for the original and linearized networks.

Overall, we observe that as the width grows, the error decreases (Figure S5). Additionally, we see that the error grows with the size of the dataset. Thus, although the error grows with dataset size, it can be counterbalanced by a corresponding increase in model size.

4 Discussion

We showed theoretically that the learning dynamics in parameter space of deep nonlinear neural networks are exactly described by a linearized model in the infinite width limit.
Empirical investigation revealed that this agrees well with actual training dynamics and predictive distributions across fully-connected, convolutional, and even wide residual network architectures, as well as with different optimizers (gradient descent, momentum, mini-batching) and loss functions (MSE, cross-entropy). Our results suggest that a surprising number of realistic neural networks may be operating in the regime we studied. This is further consistent with recent experimental work showing that neural networks are often robust to re-initialization but not re-randomization of layers (Zhang et al. [43]).

In the regime we study, since the learning dynamics are fully captured by the kernel Θ̂ and the target signal, studying the properties of Θ̂ to determine trainability and generalization is an interesting future direction. Furthermore, the infinite width limit gives us a simple characterization of both gradient descent and Bayesian inference. By studying properties of the NNGP kernel K and the tangent kernel Θ, we may shed light on the inductive bias of gradient descent.

Some layers of modern neural networks may be operating far from the linearized regime. Preliminary observations in Lee et al. [5] showed that wide neural networks trained with SGD perform similarly to the corresponding GPs as width increases, while GPs still outperform trained neural networks for both small and large dataset sizes. Furthermore, Novak et al. [7] showed that the comparison of performance between finite- and infinite-width networks is highly architecture-dependent. In particular, it was found that infinite-width networks perform as well as or better than their finite-width counterparts for many fully-connected or locally-connected architectures. However, the opposite was found in the case of convolutional networks without pooling.
Determining the main factors behind these performance gaps remains an open research question. We believe that examining the behavior of infinitely wide networks provides a strong basis from which to build up a systematic understanding of finite-width networks (and/or networks trained with large learning rates).

Acknowledgements

We thank Greg Yang and Alex Alemi for useful discussions and feedback. We are grateful to Daniel Freeman, Alex Irpan, and the anonymous reviewers for providing valuable feedback on the draft. We thank the JAX team for developing a language that makes model linearization and NTK computation straightforward. We would especially like to thank Matthew Johnson for support and debugging help.

References
[1] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems, 2012.
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
[3] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
[4] Radford M. Neal. Priors for infinite networks. Technical Report CRG-TR-94-1, University of Toronto, 1994.
[5] Jaehoon Lee, Yasaman Bahri, Roman Novak, Sam Schoenholz, Jeffrey Pennington, and Jascha Sohl-Dickstein. Deep neural networks as Gaussian processes. In International Conference on Learning Representations, 2018.
[6] Alexander G. de G. Matthews, Jiri Hron, Mark Rowland, Richard E. Turner, and Zoubin Ghahramani. Gaussian process behaviour in wide deep neural networks. In International Conference on Learning Representations, 2018.
URL https://openreview.net/forum?id=H1-nGgWC-.
[7] Roman Novak, Lechao Xiao, Jaehoon Lee, Yasaman Bahri, Greg Yang, Jiri Hron, Daniel A. Abolafia, Jeffrey Pennington, and Jascha Sohl-Dickstein. Bayesian deep convolutional networks with many channels are Gaussian processes. In International Conference on Learning Representations, 2019.
[8] Adrià Garriga-Alonso, Laurence Aitchison, and Carl Edward Rasmussen. Deep convolutional networks as shallow Gaussian processes. In International Conference on Learning Representations, 2019.
[9] Behnam Neyshabur, Ryota Tomioka, and Nathan Srebro. In search of the real inductive bias: On the role of implicit regularization in deep learning. In International Conference on Learning Representations workshop track, 2015.
[10] Roman Novak, Yasaman Bahri, Daniel A. Abolafia, Jeffrey Pennington, and Jascha Sohl-Dickstein. Sensitivity and generalization in neural networks: an empirical study. In International Conference on Learning Representations, 2018.
[11] Behnam Neyshabur, Zhiyuan Li, Srinadh Bhojanapalli, Yann LeCun, and Nathan Srebro. The role of over-parametrization in generalization of neural networks. In International Conference on Learning Representations, 2019.
[12] Alexander G. de G. Matthews, Jiri Hron, Richard E. Turner, and Zoubin Ghahramani. Sample-then-optimize posterior sampling for Bayesian linear models. In NeurIPS Workshop on Advances in Approximate Bayesian Inference, 2017. URL http://approximateinference.org/2017/accepted/MatthewsEtAl2017.pdf.
[13] Arthur Jacot, Franck Gabriel, and Clement Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, 2018.
[14] Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. In British Machine Vision Conference, 2016.
[15] Roman Novak, Lechao Xiao, Jiri Hron, Jaehoon Lee, Jascha Sohl-Dickstein, and Samuel S. Schoenholz.
Neural Tangents: Fast and easy infinite neural networks in Python, 2019. URL http://github.com/google/neural-tangents.
[16] Amit Daniely, Roy Frostig, and Yoram Singer. Toward deeper understanding of neural networks: The power of initialization and a dual view on expressivity. In Advances in Neural Information Processing Systems, 2016.
[17] Amit Daniely. SGD learns the conjugate kernel class of the network. In Advances in Neural Information Processing Systems, 2017.
[18] Andrew M Saxe, James L McClelland, and Surya Ganguli. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. In International Conference on Learning Representations, 2014.
[19] Simon S Du, Jason D Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks. In International Conference on Machine Learning, 2019.
[20] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, 2019.
[21] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. On the convergence rate of training recurrent neural networks. arXiv preprint arXiv:1810.12065, 2018.
[22] Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Stochastic gradient descent optimizes over-parameterized deep ReLU networks. Machine Learning, 2019.
[23] Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018.
[24] Lenaic Chizat and Francis Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. In Advances in Neural Information Processing Systems, 2018.
[25] Grant M Rotskoff and Eric Vanden-Eijnden. Parameters as interacting particles: long time convergence and asymptotic error scaling of neural networks.
In Advances in Neural Information Processing Systems, 2018.
[26] Justin Sirignano and Konstantinos Spiliopoulos. Mean field analysis of neural networks. arXiv preprint arXiv:1805.01053, 2018.
[27] Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural networks. In International Conference on Artificial Intelligence and Statistics, pages 249–256, 2010.
[28] Lenaic Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differentiable programming. arXiv preprint arXiv:1812.07956, 2018.
[29] Twan van Laarhoven. L2 regularization versus batch and weight normalization. arXiv preprint arXiv:1706.05350, 2017.
[30] Tero Karras, Timo Aila, Samuli Laine, and Jaakko Lehtinen. Progressive growing of GANs for improved quality, stability, and variation. In International Conference on Learning Representations, 2018.
[31] Daniel S. Park, Jascha Sohl-Dickstein, Quoc V. Le, and Samuel L. Smith. The effect of network width on stochastic gradient descent and generalization: an empirical study. In International Conference on Machine Learning, 2019.
[32] Ben Poole, Subhaneil Lahiri, Maithra Raghu, Jascha Sohl-Dickstein, and Surya Ganguli. Exponential expressivity in deep neural networks through transient chaos. In Advances in Neural Information Processing Systems, pages 3360–3368, 2016.
[33] Samuel S Schoenholz, Justin Gilmer, Surya Ganguli, and Jascha Sohl-Dickstein. Deep information propagation. In International Conference on Learning Representations, 2017.
[34] Lechao Xiao, Yasaman Bahri, Jascha Sohl-Dickstein, Samuel Schoenholz, and Jeffrey Pennington. Dynamical isometry and a mean field theory of CNNs: How to train 10,000-layer vanilla convolutional neural networks. In International Conference on Machine Learning, 2018.
[35] Ge Yang and Samuel Schoenholz. Mean field residual networks: On the edge of chaos.
In Advances in Neural Information Processing Systems, 2017.
[36] Alexander G de G Matthews, Mark Rowland, Jiri Hron, Richard E Turner, and Zoubin Ghahramani. Gaussian process behaviour in wide deep neural networks. arXiv preprint arXiv:1804.11271, 2018.
[37] Greg Yang. Scaling limits of wide neural networks with weight sharing: Gaussian process behavior, gradient independence, and neural tangent kernel derivation. arXiv preprint arXiv:1902.04760, 2019.
[38] Sever Silvestru Dragomir. Some Gronwall Type Inequalities and Applications. Nova Science Publishers, New York, 2003.
[39] Minmin Chen, Jeffrey Pennington, and Samuel Schoenholz. Dynamical isometry and a mean field theory of RNNs: Gating enables signal propagation in recurrent neural networks. In International Conference on Machine Learning, 2018.
[40] Greg Yang, Jeffrey Pennington, Vinay Rao, Jascha Sohl-Dickstein, and Samuel S. Schoenholz. A mean field theory of batch normalization. In International Conference on Learning Representations, 2019.
[41] Roy Frostig, Peter Hawkins, Matthew Johnson, Chris Leary, and Dougal Maclaurin. JAX: Autograd and XLA. www.github.com/google/jax, 2018.
[42] Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, et al. TensorFlow: A system for large-scale machine learning. In 12th USENIX Symposium on Operating Systems Design and Implementation (OSDI 16), 2016.
[43] Chiyuan Zhang, Samy Bengio, and Yoram Singer. Are all layers created equal? arXiv preprint arXiv:1902.01996, 2019.
[44] Ning Qian. On the momentum term in gradient descent learning algorithms. Neural Networks, 12(1):145–151, 1999.
[45] Weijie Su, Stephen Boyd, and Emmanuel Candes. A differential equation for modeling Nesterov's accelerated gradient method: Theory and insights.
In Advances in Neural Information Processing Systems, pages 2510–2518, 2014.
[46] Youngmin Cho and Lawrence K Saul. Kernel methods for deep learning. In Advances in Neural Information Processing Systems, 2009.
[47] Christopher KI Williams. Computing with infinite networks. In Advances in Neural Information Processing Systems, pages 295–301, 1997.
[48] Roman Vershynin. Introduction to the non-asymptotic analysis of random matrices. arXiv preprint arXiv:1011.3027, 2010.