{"title": "A Graph Theoretic Framework of Recomputation Algorithms for Memory-Efficient Backpropagation", "book": "Advances in Neural Information Processing Systems", "page_first": 1163, "page_last": 1172, "abstract": "Recomputation algorithms collectively refer to a family of methods that aims to reduce the memory consumption of the backpropagation by selectively discarding the intermediate results of the forward propagation and recomputing the discarded results as needed. \nIn this paper, we will propose a novel and efficient recomputation method that can be applied to a wider range of neural nets than previous methods.\nWe use the language of graph theory to formalize the general recomputation problem of minimizing the computational overhead under a fixed memory budget constraint, and provide a dynamic programming solution to the problem.\nOur method can reduce the peak memory consumption on various benchmark networks by $36\\%\\sim81\\%$, which outperforms the reduction achieved by other methods.", "full_text": "A Graph Theoretic Framework of Recomputation\nAlgorithms for Memory-Ef\ufb01cient Backpropagation\n\nMitsuru Kusumoto \u2217\nPreferred Networks, Inc.\n\nmkusumoto@preferred.jp\n\nTakuya Inoue \u2217\n\nThe University of Tokyo\n\ninoue-takuya57@g.ecc.u-tokyo.ac.jp\n\nGentaro Watanabe\n\nPreferred Networks, Inc.\ng.wtnb@preferred.jp\n\nTakuya Akiba\n\nPreferred Networks, Inc.\nakiba@preferred.jp\n\nMasanori Koyama\n\nPreferred Networks, Inc.\n\nmasomatics@preferred.jp\n\nAbstract\n\nRecomputation algorithms collectively refer to a family of methods that aims to\nreduce the memory consumption of the backpropagation by selectively discarding\nthe intermediate results of the forward propagation and recomputing the discarded\nresults as needed. 
In this paper, we will propose a novel and efficient recomputation method that can be applied to a wider range of neural nets than previous methods. We use the language of graph theory to formalize the general recomputation problem of minimizing the computational overhead under a fixed memory budget constraint, and provide a dynamic programming solution to the problem. Our method can reduce the peak memory consumption on various benchmark networks by 36% ∼ 81%, which outperforms the reduction achieved by other methods.

1 Introduction

The efficiency of memory usage is always one of the most important issues in the application of deep neural nets. Modern deep neural networks used in commercial applications tend to be large, and they require massive memory for the forward and backward computations. The inputs to the networks can be large as well; this is particularly true for tasks related to computer vision such as object detection and semantic segmentation, where higher-resolution images are generally more useful to detect small objects accurately. Ample free memory is also important for the training process itself; when the memory is insufficient, the user has no choice but to choose a small batch size. This is problematic, especially when using batch normalization [7], as the quality of the batch statistics degrades, and its impact on the resulting model quality is crucial. Recent studies have made it possible to use large batch sizes by reducing the memory consumption or introducing distributed computation, and succeeded in achieving state-of-the-art results for computer vision tasks [12, 20].
Recomputation algorithms collectively refer to a family of methods of smart memory manipulation that can reduce the memory consumption without altering the outcome of the computation or compromising the accuracy of the output. 
In theory, backward computation requires the result of the forward computation, and a naive approach requires the user to cache all the forward computation results for the backward computation. However, we can reduce the peak memory consumption by deliberately discarding some parts of the intermediate results in the forward computation, and recomputing these intermediate results gradually on an as-needed basis during the backpropagation. Naturally, the efficacy of any recomputation method depends on its rules of what to forget and what to cache in what order.

∗Equal contribution.

33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.

Figure 1: (Left) The computation graph of a three-layer perceptron. A part induced by intermediate nodes is denoted by G = (V, E). (Right) Neighborhoods, lower set, and its boundary.

Indeed, the idea of recomputation is not entirely new on its own. For example, Chen et al. [2] have proposed an efficient recomputation framework for a specific family of networks that can be divided into segments, and Gruslys et al. [4] have proposed a method specializing in RNNs. These methods, however, do not investigate neural networks with complex structures in complete generality, and their applications are limited either to a specific family of networks or to a group of networks to which their ad-hoc heuristics can be applied.
In this paper, we will propose a novel and efficient recomputation method that can be applied to theoretically all types of neural nets. To the authors' knowledge, there has been no study to date that tackled this problem in a completely general setting, and a new formulation of the recomputation problem is warranted. 
We will, therefore, define the general recomputation problem with appropriate formality as the minimization of computational overhead under a fixed memory budget (Sections 2 and 3), and provide our dynamic programming (DP) solution to the problem (Sections 4.1 to 4.3). We will also show that, by solving the opposite problem of maximizing the computational overhead, we can reduce the peak memory consumption to a level that cannot be achieved by contemporary methods (Section 4.4). We will demonstrate the efficacy of our DP solution on various benchmark networks (ResNet, DenseNet, U-Net, PSPNet, etc.) and compare its performance against contemporary methods in terms of the computational overhead and the size of the achieved memory reduction (Section 5). Our method can reduce the peak memory consumption by 36% ∼ 81%, which is much greater than the reduction that can be achieved by other methods.

2 Preliminaries

In this section, we present basic concepts and definitions that will be required for understanding our algorithms. Throughout, we will use a directed graph G = (V, E) to represent the architecture of the network, where V is the set of variables and there is an edge (v, w) ∈ E if v ∈ V is directly required for the computation of w ∈ V . In our study, we will exclude the input nodes from the definition of V , because the computations on the intermediate nodes tend to play a more pivotal role than the input nodes for the memory consumption of neural networks. We also exclude the memory required for the model parameters from our considerations. In computation theory, it is customary to use a computational graph to represent the computational dependency among the variables. Figure 1-Left is the computational graph for the forward and backward computation on a neural network. 
Although we do not explicitly use the computational graph in our formulation of the problem, we will develop our algorithm by building a theory on top of it.
For an arbitrary node set S ⊆ V , we will use δ+(S) to denote the set of nodes to which there is a directed edge from some node in S. We define δ−(S) analogously: δ+(S) := {v ∈ V | (s, v) ∈ E for some s ∈ S} and δ−(S) := {v ∈ V | (v, s) ∈ E for some s ∈ S}.
Also, following the definitions in order theory [3], we say L ⊆ V is a lower set of V if there is no edge from V \ L to L, and write L ≺ V . By definition, L is a lower set if and only if δ−(L) ⊆ L. The boundary of L is defined by ∂(L) := δ−(V \ L) ∩ L. The concept of the lower set introduced above plays a pivotal role in our theory. We denote the set of all lower sets by L_G. By a simple argument, we can deduce that #V ≤ #L_G ≤ 2^#V holds for any graph G. See Figure 1-Right for the visual renditions of these definitions.

Figure 2: A network that cannot be divided into multiple segments.

Finally, we also define the forward computation cost Tv > 0 and the memory consumption cost Mv > 0 for each node v ∈ V . Likewise, we define T (S) := Σ_{v∈S} Tv and M (S) := Σ_{v∈S} Mv for S ⊆ V . We do not need to consider the backward computation cost in this study, because we will not recompute the backward computation.

Chen's Recomputation Algorithm  Before explaining our formulation, we would like to introduce the work of Chen et al. [2] in order to provide some intuition for our approach. Chen et al. proposed a recomputation method for a family of neural nets that can be decomposed into segments. For an n-layer network, Chen's algorithm divides the graph into √n segments of length √n. In the forward computation, the algorithm caches the values for the nodes at the boundary of each segment and discards all other nodes from memory. When the time comes for the backpropagation of segment i, it recomputes the required forward propagation results using the cache of segment i−1. This way, the algorithm can keep the peak memory consumption within O(√n) at an additional computational cost that amounts to one round of the forward computation. As mentioned in the beginning, however, their algorithm can be applied only to a specific type of graph.
For example, a neural net with a skip connection that connects the input layer with the output layer cannot be divided into more than one segment (Figure 2), and an extra set of heuristic rules has to be applied in order to deal with such a situation. In general, it is difficult to determine whether a given arbitrary graph can be handled by their heuristics.

3 General Recomputation Problem

In this study, we will extend Chen's framework to a generic graph by reformulating the recomputation problem so that we can develop algorithms that can be applied to all types of graphs. We first need an additional set of definitions and notations.
As mentioned in the introduction, the efficacy of any recomputation method is determined by its rules of caching and its order of computation. Consider a partition V = V1 ∪ · · · ∪ Vk with the intention of computing Vi after Vi−1 in the forward computation. In order to make the computation in the intended order, we need to require that if (v, w) ∈ E, then v ∈ Vi and w ∈ Vj must hold for some i ≤ j. 
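As a concrete illustration of the definitions in Section 2, the neighborhoods δ±, the lower-set test, and the boundary ∂ can be computed directly on a toy DAG. The following Python sketch is illustrative only (it is not the implementation accompanying this paper); nodes are integers and an edge (v, w) means "v is directly required to compute w":

```python
# Illustrative sketch of the graph notions of Section 2 on a toy DAG.

def delta_plus(S, E):
    """delta^+(S): nodes receiving an edge from some node in S."""
    return {w for (v, w) in E if v in S}

def delta_minus(S, E):
    """delta^-(S): nodes sending an edge into some node in S."""
    return {v for (v, w) in E if w in S}

def is_lower_set(L, E):
    """L is a lower set iff no edge enters L from outside, i.e. delta^-(L) ⊆ L."""
    return delta_minus(L, E) <= L

def boundary(L, V, E):
    """boundary(L) = delta^-(V \\ L) ∩ L."""
    return delta_minus(V - L, E) & L

V = {1, 2, 3, 4}
E = {(1, 2), (2, 3), (3, 4), (1, 4)}   # a chain with one skip connection
L = {1, 2}
assert is_lower_set(L, E)
assert boundary(L, V, E) == {1, 2}
```

Here the edge (1, 4) plays the role of a skip connection: node 1 remains on the boundary of L = {1, 2} because a later node still depends on it, exactly the situation that forces the cache on such a node to be kept.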
If this requirement is satisfied for the partition, we can construct an increasing sequence of lower sets {L1 ≺ L2 ≺ . . . ≺ Lk = V } by putting Li := V1 ∪ V2 ∪ . . . ∪ Vi. Conversely, we can do the construction backward by starting from an arbitrary increasing sequence {L1 ≺ L2 ≺ . . . ≺ Lk = V } and defining Vi = Li \ Li−1 (Figure 3-(a)). These building blocks will be the basis of our formulation.
Now, given an arbitrary sequence of lower sets {L1 ≺ . . . ≺ Lk = V }, we define the canonical strategy of recomputation as follows:

Forward computation  After making the evaluations for all nodes in Vi, cache the values associated with ∂(Li) and discard all the values for Vi \ ∂(Li). Using the cache of ∂(Li), compute the values for Vi+1.
See Figure 3-(b) for a visualization of the forward computation. At the time of completion of the computations for Vi, the nodes with solid blue color have been cached, and the nodes with opaque blue colors have been discarded.

Backward computation  Backward computations are to be conducted in the reverse order. After making the evaluations for all nodes in Vi+1 (Figure 3-(c)), we will execute the following commands in order.

1. recover the required forward values for the nodes in Vi based on the cache of ∂(Li−1) and conduct the backpropagation for the nodes in Vi (Figure 3-(d)).
2. cache the nodes in the backward part of the computational graph that correspond to δ+(Li−1) ∩ Vi, and discard the nodes in both the forward and backward parts that correspond to the nodes of Vi that will not be needed in future computations.

Figure 3: The visualization of the canonical strategy for a lower set sequence.

This means that if there is a skip connection into v ∈ Vi, the cache on v will not be discarded. 
(Figure 3-(e))

In the real implementation of this canonical strategy, the gradient of the model with respect to the parameters is to be reported to the user in real time during the process of backpropagation.

In principle, our formulation is based on this canonical strategy. Chen's algorithm can be seen as an instance of the canonical strategy that specializes in a specific form of lower sets.
Now, the efficacy of this canonical strategy will depend solely on the choice of the sequence of the lower sets. Two performance measures are of interest: the amount of computational overhead and the size of the peak memory consumption. First, at the end of the forward computation for Vi, the values for the following nodes Ui have been cached by the algorithm: Ui := ∂(L1) ∪ · · · ∪ ∂(Li).
Thus, the set of nodes subject to the recomputation is given by V \ Uk. The total computational overhead of the recomputation is given by

T (V \ Uk) = Σ_{i=1..k} T (Vi \ ∂(Li)).    (1)

We would like to minimize this value under the prescribed memory budget. Let us also denote the computational overhead for the strategy based on the same sequence by T ({L1 ≺ . . . ≺ Lk}).
By the design of the canonical strategy, the memory consumption shall reach its peak during the backward computation. The breakdown of the memory consumption of the backward computation for Vi is as follows (Figure 3-(f)):

(i) The cache for the forward computations up to Vi requires M (Ui−1).
(ii) Altogether, the forward part and the backward part for Vi require 2M (Vi).
(iii) In backpropagation, we need to take into account the set of nodes in Vj with j > i that are dependent on the nodes in Vi. Such nodes require M (δ+(Li) \ Li).
(iv) When there are edges from v1, v2, v3 to h, one might have to use the forward cache dedicated to v1 and v2 for the gradient at v3. Such a case will require M (δ−(δ+(Li)) \ Li).

In total, the backpropagation for the nodes in Vi requires

M(i) := M (Ui−1) + 2M (Vi) + M (δ+(Li) \ Li) + M (δ−(δ+(Li)) \ Li).    (2)

The peak memory consumption of the canonical strategy is therefore given by max_{i=1,...,k} M(i). Again, this is a value that is solely determined by the sequence of the lower sets used in its evaluation. Let us, therefore, use the notation M ({L1 ≺ . . . ≺ Lk}) to denote the peak memory consumption of the canonical strategy based on {L1 ≺ . . . ≺ Lk = V }. We can formalize our problem as follows:

Definition (General Recomputation Problem). Given a neural network with graph representation G = (V, E) and a prescribed memory budget B, find the increasing sequence {L1 ≺ . . . ≺ Lk = V } of lower sets that minimizes T ({L1 ≺ . . . ≺ Lk}) while satisfying M ({L1 ≺ . . . ≺ Lk}) ≤ B. If B is too small, there can be no solution to this problem.

After solving the general recomputation problem, one can aim to further improve the efficiency by executing the obtained strategy with popular heuristics like liveness analysis [1].
In order to solve the general recomputation problem in practice, we need a relative value of Tv for each node v in the network. We can either directly measure Tv on some time scale or use some form of approximation. In general, convolutional nodes tend to be heavier than other types of nodes in terms of computational cost. 
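For concreteness, the two objectives of the general recomputation problem, Equations (1) and (2), can be evaluated for a candidate lower-set sequence with a short Python sketch (illustrative only, not the paper's code; T_cost and M_cost are per-node cost tables, and the helpers restate the definitions of Section 2):

```python
# Illustrative evaluation of Eq. (1) (total overhead) and Eq. (2)
# (peak memory) for a candidate sequence L_1 ≺ ... ≺ L_k = V.

def delta_plus(S, E):
    return {w for (v, w) in E if v in S}

def delta_minus(S, E):
    return {v for (v, w) in E if w in S}

def boundary(L, V, E):
    return delta_minus(V - L, E) & L

def overhead_and_peak(Ls, V, E, T_cost, M_cost):
    """Return (total overhead, peak memory) for the sequence Ls."""
    T = lambda S: sum(T_cost[v] for v in S)
    M = lambda S: sum(M_cost[v] for v in S)
    overhead, peak = 0, 0
    prev, U = set(), set()            # U holds the cache U_{i-1}
    for L in Ls:
        Vi = set(L) - prev
        # Eq. (1): nodes of V_i outside the cached boundary are recomputed.
        overhead += T(Vi - boundary(set(L), V, E))
        # Eq. (2): cache so far, forward+backward parts of V_i,
        # dependents of L_i, and the extra forward inputs of item (iv).
        Mi = (M(U) + 2 * M(Vi)
              + M(delta_plus(set(L), E) - set(L))
              + M(delta_minus(delta_plus(set(L), E), E) - set(L)))
        peak = max(peak, Mi)
        U |= boundary(set(L), V, E)
        prev = set(L)
    return overhead, peak
```

For example, on the three-node chain V = {1, 2, 3}, E = {(1, 2), (2, 3)} with unit costs, the singleton sequence {1} ≺ {1, 2} ≺ {1, 2, 3} recomputes only the last node.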
In our experiments, we therefore set Tv = 10 for convolutional nodes, and Tv = 1 for all other types of nodes.

4 Methods

In this section, we will provide three ways to solve the general recomputation problem: (1) a naive exhaustive search, (2) an exact dynamic programming (DP) algorithm, and (3) an approximate DP algorithm that can be used to obtain a near-optimal canonical strategy fast. We also present a memory-centric strategy, which prioritizes the reduction of peak memory consumption over the computational overhead.

4.1 Naive Approach: An Exhaustive Search

Let us first consider a rudimentary exhaustive search algorithm based on depth-first search (DFS). Because the number of all lower sets #L_G is finite, one can find a solution to the general recomputation problem in time of order O(#L_G^#V) by simply computing M ({L1 ≺ . . . ≺ Lk}) and T ({L1 ≺ . . . ≺ Lk}) for every sequence {L1 ≺ . . . ≺ Lk = V }. For further discussion, let us extend the definitions of M ({L1 ≺ . . . ≺ Lk}) and T ({L1 ≺ . . . ≺ Lk}) to the prefixes of the lower set sequence using Equations (2) and (1). For i ≤ k, let M ({L1 ≺ . . . ≺ Li}) := max_{j=1,...,i} M(j), and let T ({L1 ≺ . . . ≺ Li}) := Σ_{j=1..i} T (Vj \ ∂(Lj)).
Starting from an arbitrary lower set L1, the DFS proceeds by recursively visiting {L1 ≺ . . . ≺ Li} from {L1 ≺ . . . ≺ Li−1} such that M ({L1 ≺ . . . ≺ Li}) ≤ B. From the definition above, we can deduce T ({L1 ≺ . . . ≺ Li}) = T ({L1 ≺ . . . ≺ Li−1}) + T (Vi \ ∂(Li)). Thus, we can compute the computational overhead incrementally. The memory consumption can be computed incrementally as well. In Equation (2), the only term that depends on the entire sequence {L1 ≺ . . . ≺ Li} is M (Ui−1), and all other terms can be computed directly from Li and Vi = Li \ Li−1.
A keen reader might have noticed at this point that there is no need to retain the information of the entire traversed path {L1 ≺ . . . ≺ Li} in this DFS algorithm. Instead, it suffices to keep track of the triplets (L, t, m) where L = Li, t = T ({L1 ≺ . . . ≺ Li}), and m = M (Ui).

4.2 Exact DP Algorithm

We can use the triplet representation (L, t, m) from the previous subsection to solve the general recomputation problem with DP. Algorithm 1 in the Appendix summarizes our DP procedure.
Let S be the set of all triplet states that can be visited during the DFS previously mentioned. Our DP solution takes advantage of the fact that, if there exist (L, t, m) and (L, t, m′) ∈ S such that m < m′, we do not need to consider (L, t, m′) for further exploration.
Let us define an array opt with entries opt[L, t] := min{m ∈ N | (L, t, m) ∈ S}, where opt[L, t] := ∞ whenever there is no m such that (L, t, m) ∈ S. This will serve as our DP table. Starting with the initial condition opt[∅, 0] = 0, we fill the DP table by visiting the states that satisfy the prescribed memory budget constraint. More precisely, if opt[L, t] is the current entry, the algorithm loops over the set of all L′ such that L ⊊ L′ and updates opt[L′, t′] with t′ = t + T (V ′ \ ∂(L′)).
If opt[L, t] < ∞ after its evaluation, it means that there exists {L1 ≺ . . . ≺ Li} with Li = L such that M ({L1 ≺ . . . ≺ Li}) ≤ B and T ({L1 ≺ . . . ≺ Li}) = t. If t* := min{t0 | opt[V, t0] < ∞} < ∞, we can report t* as the minimum possible computational overhead in the general recomputation problem. 
Conversely, if t* = ∞, it means that there is no solution.
For the sake of argument, let us assume for now that t* < ∞. In order to obtain the choice of {L1 ≺ . . . ≺ Lk = V } that achieves t*, we need to modify the algorithm slightly by making room for another variable optarg[L′, t′] that stores the pointer to the L that precedes L′ in the sequence of the recursive computation that leads to opt[L′, t′]. The optimal sequence of lower sets can thus be obtained by starting from opt[V, t*] and tracing back the pointers of optarg.
Since the bottleneck of our DP algorithm is its iteration process, the computational time of our algorithm is O(T (V ) · #L_G²). From a practical perspective, it is worth noting that the use of a sparse table for opt reduces the computation time by a large constant factor. Further, when t < t′ and opt[L, t] < opt[L, t′], we can skip the iteration for the entry opt[L, t′].

4.3 Approximate DP Algorithm

The DP algorithm we described in the previous section can be used to obtain the optimal canonical strategy. However, the computation time required for the DP algorithm grows with #L_G, which can be very large for models with complex network structures. We therefore also present an algorithm that can be used to obtain a near-optimal canonical strategy fast. We shall emphasize here that any canonical strategy is a legitimate recomputation strategy in the sense that it never alters the network output.
The modification we make to the DP algorithm is simple. Instead of making keys for all members of L_G in the DP table, we use a good small subset of L_G.
Let us say that v is reachable from w if v = w or if there exists a path from w to v. Now, let us define L_G^Pruned := {Lv | v ∈ V }, where Lv := {w ∈ V | v is reachable from w}. By definition, L_G^Pruned ⊆ L_G and #L_G^Pruned = #V . Our approximate DP algorithm makes keys for the members of L_G^Pruned only. This way, we can keep the computation time under O(T (V ) · #V ²).
Indeed, this modification excludes some possibilities from the search pool, and we can no longer guarantee that we will be able to find the best canonical strategy. As we will show in our experiments, however, the near-optimal solutions obtained from our approximate DP are often "good enough" in practice at reducing the peak memory consumption.

4.4 Memory-Centric Strategy

Liveness analysis [1] is a heuristic technique that has an effect on the reduction of peak memory consumption, and much of a given strategy's performance in terms of memory consumption depends on how well this technique works on the corresponding sequence of lower sets. Experimentally, liveness analysis tends to work well when the node set is partitioned coarsely; that is, when each Vi is large. Through trial and error, we discovered that we can realize this coarse partition intentionally by using a strategy with a large computational overhead. In fact, a canonical strategy with maximal computational overhead tends to have exceptionally low peak memory consumption. Given a fixed budget constraint B, the canonical strategy with maximal computational overhead can again be found using DP.
In general, we can find a more economical strategy by setting the budget constraint B to a smaller value. If the reduction of the memory consumption is the first priority, one may set B to the lowest value for which the set of canonical strategies is non-empty. We call the strategy found this way a memory-centric strategy, because it prioritizes the positive effect of liveness analysis over the computational overhead. The computational overhead of the memory-centric strategy is bounded by the computation time for one round of the forward computation. 
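Returning to Section 4.3, the pruned family L_G^Pruned admits a simple construction: one backward traversal per node. The following Python sketch is illustrative only (the function name is ours, not the paper's):

```python
# Illustrative construction of L_v = {w | v is reachable from w}, one
# set per node v, by walking edges backwards from v.  Each L_v is a
# lower set because it is closed under taking predecessors.

def pruned_lower_sets(V, E):
    preds = {v: {u for (u, w) in E if w == v} for v in V}
    out = {}
    for v in V:
        Lv, stack = {v}, [v]
        while stack:
            for u in preds[stack.pop()]:
                if u not in Lv:
                    Lv.add(u)
                    stack.append(u)
        out[v] = frozenset(Lv)
    return out                        # at most #V distinct candidates

V = {1, 2, 3, 4}
E = {(1, 2), (2, 3), (1, 4)}
assert pruned_lower_sets(V, E)[3] == {1, 2, 3}
```

Substituting these #V candidate sets for the full family L_G in the DP table is exactly the pruning that yields the O(T(V) · #V²) bound of Section 4.3.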
We distinguish this strategy from the optimal canonical strategy of the previous section by calling the latter a time-centric strategy.
When applying our methods, we recommend that the user first try the time-centric strategy and prioritize the computational overhead. We suggest trying the memory-centric strategy and pursuing memory reduction only if the solution of the time-centric strategy fails to satisfy the memory budget constraint even with the application of liveness analysis.

5 Experiments

We applied our algorithm to various network structures and investigated their performance in terms of computational overhead and peak memory consumption. All networks were implemented in Chainer [18], and the experiments were conducted on an NVIDIA Tesla K40c with GPU DRAM of size 11.4 GB. The following is the list of networks to which we applied our method: ResNet [5], VGG [16], DenseNet [6], GoogLeNet [17], U-Net [15], and PSPNet [20]. Input dimensions were set to 572 × 572 for U-Net, 713 × 713 for PSPNet, and 224 × 224 for all other networks.
We compare our method against two baselines: (1) a vanilla implementation of the forward and backward propagation without any recomputation methods, and (2) Chen's algorithm implemented with the use of their heuristic techniques. We tested both our method and Chen's method with liveness analysis. In the Appendix, we present an ablation study for the effect of liveness analysis.
Our code is publicly available at https://github.com/pfnet-research/recompute.

5.1 Memory Reduction

Table 1 summarizes the performance of the various methods evaluated in terms of peak memory consumption. ExactDP and ApproxDP are the algorithms of Sections 4.2 and 4.3, respectively. MC stands for the memory-centric strategy of Section 4.4, and TC stands for the time-centric strategy. The peak memory consumption listed in this table includes the memory used by the model parameters themselves. 
Each value in parentheses designates the proportional size of the achieved memory reduction relative to the peak memory consumption of the vanilla run. For each experiment, we selected a batch size so that the memory consumption of the vanilla run roughly falls in the range of 7 ∼ 9 GB. For the memory budget B to be used for our approach, we chose the minimal value B for which the solution of the general recomputation problem exists. This value was determined using binary search.
As we can confirm in the table, our method outperforms the previous method in terms of peak memory consumption. On DenseNet and PSPNet, we succeed in reducing the memory consumption by 81% and 71%, respectively. Our method performs better than Chen's algorithm particularly for complex networks like PSPNet, U-Net, and GoogLeNet. The approximate DP was significantly faster to complete than the exact DP algorithm: the exact DP algorithm required more than 80 seconds to complete for GoogLeNet and PSPNet, while the approximate DP completed within 1 second for all networks. We would like to emphasize that, for all cases we considered, the exact solution and the approximate solution did not differ much in terms of performance. This is to say that the "near-optimal" canonical strategy obtained from the approximate DP is literally "near" optimal for complex networks commonly used in applications, and that it can be used reliably in practice. 2

5.2 Computational Time

We investigated the memory-runtime tradeoff of our algorithm by running the experiments with various batch sizes and compared the results against the other methods. For all the experiments with recomputation methods, we used batch sizes so large that the naive vanilla computation is impossible. 
For each choice of batch size, we repeated each experiment four times and reported the average runtime.

2 For some networks, the approximate DP yielded slightly better results than the exact DP because the effect of the liveness analysis is not taken into account in the definition of optimality in the general recomputation problem.

Table 1: The comparison of peak memory consumption. Each value in parentheses is the achieved memory reduction from the vanilla computation.

Network       ApproxDP + MC   ApproxDP + TC   ExactDP + MC    ExactDP + TC    Chen's [2]      Vanilla   #V   Batch
PSPNet        2.7 GB (-71%)   3.1 GB (-67%)   2.8 GB (-70%)   3.2 GB (-66%)   4.0 GB (-58%)   9.4 GB   385     2
U-Net         5.0 GB (-45%)   6.7 GB (-26%)   4.7 GB (-48%)   5.3 GB (-42%)   7.4 GB (-18%)   9.1 GB    60     8
DenseNet161   3.4 GB (-62%)   4.4 GB (-51%)   3.4 GB (-62%)   4.3 GB (-51%)   3.7 GB (-59%)   8.9 GB   176    96
GoogLeNet     2.3 GB (-75%)   2.5 GB (-73%)   2.3 GB (-75%)   2.5 GB (-73%)   2.4 GB (-74%)   9.2 GB   516    48
ResNet50      4.5 GB (-36%)   5.5 GB (-22%)   4.5 GB (-36%)   5.5 GB (-22%)   4.7 GB (-34%)   7.0 GB    46    64
ResNet152     1.6 GB (-81%)   1.9 GB (-78%)   1.7 GB (-80%)   1.8 GB (-78%)   1.8 GB (-79%)   8.5 GB   568    32
VGG19         5.2 GB (-39%)   5.5 GB (-36%)   5.2 GB (-39%)   5.9 GB (-31%)   6.5 GB (-24%)   8.5 GB   134   256

Figure 4: The tradeoff between batch size and total runtime (forward and backward propagation).

Figure 4 illustrates the results of this set of experiments. Blue curves are the results obtained from the vanilla computation. Dotted blue lines are the linear extrapolation of the vanilla computation results that one might have been able to obtain if there were more memory in the device. Red curves and orange curves respectively designate the results of the "ApproxDP + time-centric" and "ApproxDP + memory-centric" settings. Green curves represent the method obtained by Chen's algorithm. 
As we can see in the plot, our method outperforms Chen's method by a large margin in terms of the runtime-memory tradeoff. On ResNet152, at a batch size double the maximum batch size that can be used in the vanilla computation, our method was 1.16 times faster than Chen's algorithm. These results also verify that our algorithm indeed seeks a strategy with small computational overhead in the presence of ample memory. For PSPNet, our method was able to increase the maximum possible batch size from 2 to 8.

6 Related Work

There are various other methods of reducing the memory consumption that do not belong to the family of recomputation methods. Precision reduction [11] reduces the memory consumption by regulating the precision level of the computations. Network pruning [10], on the other hand, takes the approach of compressing the network itself. These methods are fundamentally different from recomputation methods in that they compromise the accuracy of the output in favor of efficient memory usage, because the name recomputation refers to a family of methods that does not alter the output of the network in any way whatsoever. At the same time, however, the methods we mentioned above can be used in combination with recomputation methods if the reduction of the peak memory consumption is the first priority. Another approach is to selectively transfer some part of the memory between CPU and GPU during the training process. Superneurons [19], vDNN [14], and LayRub [8] are variants of this approach. This approach, too, can be used in combination with recomputation methods.

Some other methods specialize in memory consumption reduction for a specific family of networks. Pleiss et al. [13] developed a method that specializes in DenseNet by storing multiple feature maps at a shared memory location. Gruslys et al. 
[4] developed a method that specializes in RNNs. Chen et al. [2] developed a recomputation framework for a family of graphs that can be divided into segments, and provided a set of heuristic rules that extend the framework to specific networks like LSTM and ResNet. Our framework is more general and powerful than Chen's framework. Chen's framework does not take the computational overhead into account. In contrast, our method is based on a formalized tradeoff relationship between memory usage and computational overhead, and searches a wider space of recomputation strategies than Chen's method. We shall also mention that, concurrently with our publication, Kumar et al. [9] proposed a recomputation method based on pathwidth and treewidth. Lastly, while quite distant from our method as an algorithm, Bulatov recently posted on his blog3 an informal idea of using tree decomposition to generate recomputation strategies.

7 Conclusion

In this study, we proposed a new framework of recomputation methods that can be applied to neural networks of any type, and formulated the goal of recomputation in the language of graph theory. Our framework enables a much simpler treatment of the recomputation problem and possibly opens the door to complex methods with more sophisticated machinery of discrete mathematics. Also, in this study, we only considered the set of strategies that allow at most one recomputation per node. One possible future study is the extension of our current formulation to strategies that allow multiple recomputations per node. While even more challenging, this direction may lead to even more efficient recomputation methods for neural networks.
We shall also note that we only considered static graphs in this study.
One naive way to extend our algorithm to the dynamic setting is, for example, to apply our algorithm in advance to the set of computation graphs that might become necessary in the course of training. In the case that the dimension of the input varies over the dataset, we may prepare a maximal graph and develop a computation strategy for it. Extension of our method to the dynamic setting may open avenues for new ways to optimize the training for heavy tasks like those in NLP and time series analysis.

Acknowledgement

We thank Shinichiro Hamaji and Hiroto Imachi for technical suggestions.

3 https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9

References

[1] Andrew W. Appel and Jens Palsberg. Modern Compiler Implementation in Java. Cambridge University Press, 2002.

[2] Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint, arXiv:1604.06174, 2016.

[3] Brian A. Davey and Hilary A. Priestley. Introduction to Lattices and Order. Cambridge University Press, 2002.

[4] Audrunas Gruslys, Rémi Munos, Ivo Danihelka, Marc Lanctot, and Alex Graves. Memory-efficient backpropagation through time. In Advances in Neural Information Processing Systems (NIPS), pages 4125–4133, 2016.

[5] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770–778, 2016.

[6] Gao Huang, Zhuang Liu, Laurens van der Maaten, and Kilian Q. Weinberger. Densely connected convolutional networks. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 2261–2269, 2017.

[7] Sergey Ioffe and Christian Szegedy. Batch Normalization: Accelerating deep network training by reducing internal covariate shift.
In International Conference on Machine Learning (ICML), pages 448–456, 2015.

[8] Hai Jin, Bo Liu, Wenbin Jiang, Yang Ma, Xuanhua Shi, Bingsheng He, and Shaofeng Zhao. Layer-centric memory reuse and data migration for extreme-scale deep learning on many-core architectures. ACM Transactions on Architecture and Code Optimization, 15(3):37, 2018.

[9] Ravi Kumar, Manish Purohit, Zoya Svitkina, Erik Vee, and Joshua R. Wang. Efficient rematerialization for deep networks. In Advances in Neural Information Processing Systems (NeurIPS), 2019.

[10] Jian-Hao Luo, Jianxin Wu, and Weiyao Lin. ThiNet: A filter level pruning method for deep neural network compression. In IEEE International Conference on Computer Vision (ICCV), pages 5068–5076, 2017.

[11] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory F. Diamos, Erich Elsen, David García, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, and Hao Wu. Mixed precision training. In International Conference on Learning Representations (ICLR), 2018.

[12] Chao Peng, Tete Xiao, Zeming Li, Yuning Jiang, Xiangyu Zhang, Kai Jia, Gang Yu, and Jian Sun. MegDet: A large mini-batch object detector. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 6181–6189, 2018.

[13] Geoff Pleiss, Danlu Chen, Gao Huang, Tongcheng Li, Laurens van der Maaten, and Kilian Q. Weinberger. Memory-efficient implementation of DenseNets. arXiv preprint, arXiv:1707.06990, 2017.

[14] Minsoo Rhu, Natalia Gimelshein, Jason Clemons, Arslan Zulfiqar, and Stephen W. Keckler. vDNN: Virtualized deep neural networks for scalable, memory-efficient neural network design. In IEEE/ACM International Symposium on Microarchitecture (MICRO), page 18, 2016.

[15] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-Net: Convolutional networks for biomedical image segmentation.
In Medical Image Computing and Computer-Assisted Intervention (MICCAI), pages 234–241, 2015.

[16] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. In International Conference on Learning Representations (ICLR), 2015.

[17] Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott E. Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. Going deeper with convolutions. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 1–9, 2015.

[18] Seiya Tokui, Ryosuke Okuta, Takuya Akiba, Yusuke Niitani, Toru Ogawa, Shunta Saito, Shuji Suzuki, Kota Uenishi, Brian Vogel, and Hiroyuki Yamazaki Vincent. Chainer: A deep learning framework for accelerating the research cycle. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 2002–2011. ACM, 2019.

[19] Linnan Wang, Jinmian Ye, Yiyang Zhao, Wei Wu, Ang Li, Shuaiwen Leon Song, Zenglin Xu, and Tim Kraska. Superneurons: Dynamic GPU memory management for training deep neural networks. In ACM SIGPLAN Symposium on Principles and Practice of Parallel Programming, pages 41–53, 2018.

[20] Hengshuang Zhao, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. Pyramid scene parsing network. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 6230–6239, 2017.