{"title": "Efficient Rematerialization for Deep Networks", "book": "Advances in Neural Information Processing Systems", "page_first": 15172, "page_last": 15181, "abstract": "When training complex neural networks, memory usage can be an important bottleneck. The question of when to rematerialize, i.e., to recompute intermediate values rather than retaining them in memory, becomes critical to achieving the best time and space efficiency. In this work we consider the rematerialization problem and devise efficient algorithms that use structural characterizations of computation graphs---treewidth and pathwidth---to obtain provably efficient rematerialization schedules. Our experiments demonstrate the performance of these algorithms on many common deep learning models.", "full_text": "Ef\ufb01cient Rematerialization for Deep Networks\n\nRavi Kumar\n\nGoogle Research\n\nMountain View, CA 94043\n\nravi.k53@gmail.com\n\nManish Purohit\nGoogle Research\n\nMountain View, CA 94043\nmpurohit@google.com\n\nZoya Svitkina\nGoogle Research\n\nMountain View, CA 94043\n\nzoya@google.com\n\nErik Vee\n\nGoogle Research\n\nMountain View, CA 94043\n\nerikvee@google.com\n\nJoshua R. Wang\nGoogle Research\n\nMountain View, CA 94043\njoshuawang@google.com\n\nAbstract\n\nWhen training complex neural networks, memory usage can be an important\nbottleneck. The question of when to rematerialize, i.e., to recompute intermediate\nvalues rather than retaining them in memory, becomes critical to achieving the best\ntime and space ef\ufb01ciency. In this work we consider the rematerialization problem\nand devise ef\ufb01cient algorithms that use structural characterizations of computation\ngraphs\u2014treewidth and pathwidth\u2014to obtain provably ef\ufb01cient rematerialization\nschedules. Our experiments demonstrate the performance of these algorithms on\nmany common deep learning models.\n\n1\n\nIntroduction\n\nThe world of deep learning is moving toward bigger model architectures. 
The recent successes in speech, language understanding, vision, and others have repeatedly demonstrated that bigger and deeper models yield the best results for a task, thereby advancing the state of the art. In addition to the size, the models themselves and the methods to train them are becoming increasingly complex and intricate in terms of data dependencies, gradient propagation, optimization steps, etc. Specialized hardware such as GPUs and AI accelerators has been vastly in\ufb02uential in training these complex models. It is particularly helpful from a computational point of view, but is limited by memory capacity that falls short of the peak demands of training these large models. Since memory turns out to be a bottleneck, it becomes an issue of feasibility\u2014can a given model be trained at all?\nWhile the growing model complexity is the root cause of severe demands on memory, the actual schedule in which the computation is carried out also plays a critical role in determining peak memory requirements. To see why, it is helpful to view the computational steps in training these models as a directed acyclic graph (e.g., Figure 1) whose nodes represent operations and directed edges represent data dependencies. (In TensorFlow parlance, this is a data\ufb02ow graph.) Each node consumes a set of inputs from its incoming edges, does some computation, and outputs the result of this computation on its outgoing edges; it is assumed that both inputs and outputs of this computation are to be held in memory. The order in which the nodes are computed, i.e., the schedule, will determine the peak memory usage.\nFigure 1: The schedule \u27e8A, B, C, D, E\u27e9 needs four units of memory\u2014while computing D, two units are needed for inputs to D, one for output from D, and one unit to keep the output of A as an input to E. The schedule \u27e8A, B, C, D, A, E\u27e9 needs three units of memory\u2014at node D the output of A need not be retained in memory since it will be recomputed right after computing D.\n33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.\nIndeed, consider Figure 1, where the output of each node occupies one unit of memory. Computing the nodes in the order \u27e8A, B, C, D, E\u27e9 would need four units of memory, whereas computing them in the order \u27e8A, B, C, D, A, E\u27e9 would only need three units of memory (see caption of Figure 1). This latter order involves rematerializing the output of node A instead of keeping it in memory. As this example illustrates, there can be a time-memory trade-off in our choice of schedule, where recalculating intermediate results can reduce what we store in memory. Judiciously choosing an appropriate schedule may make larger models feasible. In this paper we consider this rematerialization problem: given a computation graph as an input, construct a schedule, possibly rematerializing some nodes, that uses as little peak memory as possible1.\nWhen studied on computation graphs derived from training neural networks (i.e., graphs with forward computation and backward computations), rematerialization is often referred to as gradient checkpointing [10, 6, 8, 1, 15]. Of course, there are many other techniques to try to reduce memory usage, such as reusing memory regions [19] and trying to use both GPU and CPU memory [17, 16]. Rematerialization is a particularly nice approach because it only changes how the computation is done, but has no risk of changing the \ufb01nal result.\nCompared to the gradient checkpointing line of work, we do not assume we have a forward/backward computation, but rather show how certain structural properties of the graph can be used to obtain a good solution. 
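The memory accounting in the Figure 1 example can be reproduced with a small simulator. The sketch below is illustrative, not from the paper: it uses unit-size tensors, hypothetical helper names, and the edge set A\u2192B, A\u2192C, B\u2192D, C\u2192D, A\u2192E, D\u2192E suggested by the caption. A tensor is held from the step that computes it until its last use before any recomputation:

```python
def peak_memory(schedule, deps):
    """Peak number of unit-size tensors live under a (possibly rematerializing) schedule.

    A tensor computed at step p is live at step i >= p if some later step j
    consumes it and it is not recomputed between p and j (last-use semantics).
    """
    def prev(u, j):  # most recent step < j at which u was computed
        return max(i for i in range(j) if schedule[i] == u)

    peak = 0
    for i, u in enumerate(schedule):
        live = {u} | set(deps[u])  # output of step i plus its inputs
        for j in range(i + 1, len(schedule)):
            for x in deps[schedule[j]]:
                if prev(x, j) <= i:  # held across step i until its use at step j
                    live.add(x)
        peak = max(peak, len(live))
    return peak

# Figure 1 (assumed edges): B and C depend on A; D on B, C; E on A, D.
deps = {"A": [], "B": ["A"], "C": ["A"], "D": ["B", "C"], "E": ["A", "D"]}
print(peak_memory(["A", "B", "C", "D", "E"], deps))       # 4 units
print(peak_memory(["A", "B", "C", "D", "A", "E"], deps))  # 3 units
```

Recomputing A between D and E is what drops the peak from four units to three, at the cost of one extra operation.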
In particular, we identify the treewidth of this graph as a key quantity that can be algorithmically exploited to yield a schedule with provable bounds on its length and peak memory usage. Informally, our main result is that there is a polynomial-time algorithm that, given an n-node computation graph with treewidth k and unit memory output at each node, constructs a schedule of length O(kn^{log k}) and peak memory usage of O(k log n). This algorithm uses a tree decomposition of the computation graph, which yields balanced separators and offers a natural way to partition the computations into independent sub-computations while allowing us to bound the memory use through a charging argument. Note that while \ufb01nding the optimal tree decomposition is computationally hard, there are ef\ufb01cient approximation algorithms and heuristics, which makes our algorithm ef\ufb01cient, practical, and easy to realize. We demonstrate its ef\ufb01cacy by applying it to training large networks including feedforward, residual, and transformer networks. In all these cases, our schedule yields signi\ufb01cant savings in peak memory over baseline schedules, both with and without rematerialization.\nWe also design a different algorithm that produces schedules that are asymptotically more ef\ufb01cient. This algorithm relies on the path decomposition of the computation graph and is more intricate, with an involved analysis, but is currently less practical. This result, however, hints at the intriguing possibility of another structural property of the graph that better captures its rematerialization potential.\n2 Preliminaries\n2.1 Computation Graphs and Schedules\nThe input to our algorithms is a computation graph. Each node in this graph represents an operation that takes as input zero or more tensors and produces a single tensor as an output (this assumption is for simplicity). Let G = (V, E) be a directed acyclic computation graph. 
For u, v \u2208 V , a directed edge (u, v) \u2208 E represents a data dependency, meaning that the output of node u is an input to node v. We are also given a \ufb01nal node f \u2208 V whose output tensor is required to be held in memory at the end of the computation. We assume, without loss of generality, that f has out-degree zero (i.e., no other operations use the tensor produced by f) and all nodes in G are needed to compute f. For any node u \u2208 V , let in(u) denote the immediate predecessors of u, i.e., in(u) = {u\u2032 | (u\u2032, u) \u2208 E}. Let n = |V |, m = |E|, and [n] = {1, . . . , n}. Throughout, log(\u00b7) means log2(\u00b7).\nA schedule for a computation graph G = (V, E) is a sequence \u03c3 = \u03c3(G) = \u27e8u1, . . . , ut\u27e9 of nodes in V with the following properties: (i) the \ufb01nal node f is represented in the schedule, and (ii) each node in the schedule occurs only after all of its predecessors, i.e., for each j \u2208 [t] and for each u\u2032 \u2208 in(uj), there is some j\u2032 < j such that uj\u2032 = u\u2032. Let prev(u, j) = max{j\u2032 < j | uj\u2032 = u} be the most recent time u occurs in the schedule before step j. Note that a node in G can occur more than once in \u03c3(G) and this is precisely what enables the schedule length\u2013memory usage trade-off.\n1 Of course, the real goal is to keep peak memory under the memory available while minimizing the time to compute the schedule. Our results are easier to understand when viewed from a purely memory-minimization standpoint, but it is possible to stop our recursion early to obtain other trade-off points.\nA schedule naturally implies time and memory bounds for computing G. Let L(u) be the length of node u, representing the time required to execute the corresponding operation. The length of a schedule is given by L(\u03c3) = \u2211_{i=1}^{t} L(ui). Let Tonepass = \u2211_{u\u2208V} L(u) be the time required to execute every operation of the graph once. It lower bounds the length of any valid schedule.\nThe peak memory usage of the schedule, M (\u03c3), though intuitive, is a bit cumbersome to formalize. For i \u2208 [t], \ufb01rst de\ufb01ne the set of tensors that need to be held in memory at step i as\nUi = {ui} \u222a in(ui) \u222a \u22c3_{j>i} {u\u2032 \u2208 in(uj) | prev(u\u2032, j) \u2264 i}.\nLet s(u) denote the size of the tensor output by node u. Now, the memory of the schedule at step i is M (\u03c3, i) = \u2211_{u\u2032\u2208Ui} s(u\u2032). Finally, M (\u03c3) = max_{i\u2208[t]} M (\u03c3, i). The goal of an algorithm alg is to produce a schedule alg(G) of G that minimizes the peak memory. Let Min = max_{u\u2208V} {\u2211_{u\u2032\u2208in(u)} s(u\u2032)} be the maximum input size needed to compute any tensor. Let Mmax = max_{u\u2208V} s(u) be the maximum size of any tensor. Clearly, for any schedule \u03c3, M (\u03c3) \u2265 max{Min, Mmax}.\n2.2 Treewidth and Tree Decompositions\nTreewidth is a well-studied graph parameter expressing how close an undirected graph G = (V, E) is to a tree. Intuitively, if a problem is easy on trees, then one might hope that it remains easy on graphs of small treewidth. Formally, the treewidth of a graph is de\ufb01ned via the notion of tree decompositions. A tree decomposition of an undirected graph G = (V, E) is a pair (X , T ), where X \u2286 2^V is a set of bags, with each bag a subset of the nodes, and T is a tree on the bags X . 
The bags and tree must\nsatisfy the following three properties: (i) each node in V is in some bag of X , (ii) for each edge\n(u, v) \u2208 E, both endpoints are together in some bag of X , and (iii) for each node v \u2208 V , the bags\ncontaining it (i.e., {X \u2208 X | v \u2208 X}) form a connected subgraph of T .\nNaturally, there are many tree decompositions of a particular graph G, including the trivial one that\nplaces all nodes into a single giant bag (X = {V }, T = {}). We measure a tree decomposition by\nits width, which is the maximum bag size minus one: maxX\u2208X |X| \u2212 1. The treewidth tw(G) of G\nis the minimum width of any tree decomposition. We refer to |X| as the size of the decomposition.\nNote that tw(G) can range from 1 (a tree) to n \u2212 1 (a clique).\nWe will use treewidth and tree decompositions of our directed computation graphs. When doing\nso, we are actually referring to the undirected graph obtained by forgetting the direction of every\nedge. It is known that series-parallel graphs have a treewidth of two and control-\ufb02ow graphs of all\nprograms written in C (without goto statements) have a treewidth of at most six [20]. We postulate\nthat computation graphs of neural networks in the inference mode also have similarly low treewidth\nand that, given a computation graph G for a neural network in the inference mode, the computation\ngraph for training the network via backpropagation has treewidth at most twice as that of the original\ngraph. Experimentally, we observe that computation graphs for training many common deep network\narchitectures (ResNet, Transformer, and feedforward networks) have small treewidth (see Table 1).\nOur results fall under the purview of \ufb01xed-parameter tractability, which studies the complexity\nof problems under particular parameters. Typically, we would hope to \ufb01nd an exact algorithm\n(computing the absolute memory-minimizing schedule) when treewidth is small. 
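The three tree-decomposition properties can be checked mechanically. A minimal validity checker (an illustrative sketch, not from the paper; bags are given as a list of sets and the tree as pairs of bag indices):

```python
def is_tree_decomposition(nodes, edges, bags, tree_edges):
    """Check the three tree-decomposition properties for an undirected graph."""
    # (i) every node of the graph appears in some bag
    if any(all(v not in b for b in bags) for v in nodes):
        return False
    # (ii) both endpoints of every edge appear together in some bag
    if any(all(not (u in b and v in b) for b in bags) for u, v in edges):
        return False
    # (iii) the bags containing any given node form a connected subtree
    adj = {i: set() for i in range(len(bags))}
    for a, b in tree_edges:
        adj[a].add(b)
        adj[b].add(a)
    for v in nodes:
        holders = {i for i, b in enumerate(bags) if v in b}
        stack, seen = [next(iter(holders))], set()
        while stack:  # walk the tree, staying within bags that contain v
            i = stack.pop()
            if i in seen:
                continue
            seen.add(i)
            stack.extend((adj[i] & holders) - seen)
        if seen != holders:
            return False
    return True

# A 4-cycle has treewidth 2: two bags of size 3 suffice.
cycle = [(1, 2), (2, 3), (3, 4), (4, 1)]
print(is_tree_decomposition([1, 2, 3, 4], cycle,
                            [{1, 2, 3}, {1, 3, 4}], [(0, 1)]))  # True
```

The 4-cycle example also shows why width is defined as maximum bag size minus one: the decomposition above has bags of size three and width two, matching the cycle's treewidth.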
Unfortunately, this\nseems unlikely; such results typically come from Courcelle\u2019s theorem [7], which states that if a graph\nproperty can be expressed in second-order monadic logic, then it can be checked for in \ufb01xed-parameter\ntractable time relative to treewidth. Rematerialization is known to be PSPACE-complete [9]. If it\nwere expressible in second-order monadic logic, then it would lie in the polynomial hierarchy (PH)\nand then PSPACE would collapse to PH. Hence, we must settle for approximation algorithms.\n\n3 Ef\ufb01cient Rematerialization via Tree Decomposition\n\nOur main algorithm uses a tree decomposition of the computation graph for a divide-and-conquer\napproach. The tree decomposition of a graph allows us to \ufb01nd balanced separators of small size.\nAdditionally, the connectivity property of the tree decomposition guarantees that the nodes in the\ndifferent components can be computed independently of each other except for interactions via the\nseparator. Using these ideas, we recursively compute a memory-ef\ufb01cient schedule.\n\n3\n\n\fFirst, we consider the size of a tree decomposition and argue that it can be bounded.\nLemma 1. Given an undirected graph G = (V, E) and its tree decomposition (X , T ) of width k, we\ncan \ufb01nd another tree decomposition (X (cid:48), T (cid:48)) of width k and size at most n in O(|X| \u00b7 (k + \u03b1(|X|)))\ntime, where \u03b1(\u00b7) is the inverse Ackermann function.\nProof. The idea is to post-process (X , T ) by repeatedly merging every adjacent pair of bags for\nwhich one bag is a subset of the other. This can be done with a single pass over all edges in T , since\nany adjacent pair of bags which cannot be merged at any time can never be merged in the future.\nFor the sake of contradiction, imagine that bags X1 and X2 could not be merged due to some node\nv \u2208 X1, v (cid:54)\u2208 X2. 
This problematic node will always be in X1 since merging two bags only results in\nthe addition of nodes to a bag. At the same time, it can never get added to X2 because all bags that\ncontain v are connected and hence X1 is the only bag in the neighborhood of X2 that contains v.\nWe can keep track of these merges using a standard Union-Find data structure on the |X| bags,\nwhich costs O(\u03b1(|X|)) time per operation. We perform at most |X| merges, which cost a total of\nO(|X| \u00b7 \u03b1(|X|)) time. To check whether one bag is a subset of another, we can put the larger bag in a\nhash set and perform k + 1 membership checks. Hence we can perform all these checks in O(|X| \u00b7 k)\ntime. Hence the overall time is the claimed O(|X| \u00b7 (k + \u03b1(|X|))).\nWe can see why this post-processing procedure works by taking the resulting tree decomposition\n(X (cid:48), T (cid:48)) and rooting it at an arbitrary bag. Each non-root bag must contain a node not found in its\nparent bag because otherwise the bag should have been merged with its parent bag. Since the set\nof bags containing a node is connected, this assigns a unique node v \u2208 V to every non-root bag.\nFurthermore, the root cannot be empty since then it would have been merged, and its nodes cannot be\nassigned to any other bag due to the same property. Hence we can assign it one of these nodes. Since\nwe have assigned each bag a unique node v \u2208 V , there can be at most n bags.\n\nA classic result shows that a tree always has a balanced node separator.\nTheorem 2 (Jordan [14]). Any tree on n nodes has a node whose removal disconnects the tree into\ncomponents of size at most n/2.\nApplying Jordan\u2019s theorem on the tree decomposition (X , T ) directly yields the following lemma.\nLemma 3 (Balanced Separator). 
Given a tree decomposition (X , T ), we can \ufb01nd, in time O(|X|), a\nbag X (cid:63) \u2208 X such that each connected component of (X , T ) \\ {X (cid:63)} contains at most |X|/2 bags.\nOur divide-and-conquer approach chooses a balanced separator X (cid:63) of the tree decomposition so\nthat removing it results in subtrees with at most |X|/2 bags each. Combining with Lemma 1, this\nguarantees that there are at most log n levels of recursion. Finding such a bag is a standard technique.\nWith these two ideas, we present Algorithm 1, which is a recursive function that schedules a subset\nV (cid:48) of nodes with a requirement that the schedule contains all nodes in a speci\ufb01ed subset S. It breaks\nthe graph using the balanced separator, and schedules the predecessors of a node v in each of the\nresulting components before scheduling v itself. The produced schedule includes annotations about\nwhich tensors to keep in memory or to remove, which is just for ease of analysis, as in practice\nmemory usage can be inferred from a schedule of operations. Initially, the function is called with\narguments (G, V, (X , T ),{f}), where f \u2208 V is the \ufb01nal node.\nLemma 4. Algorithm 1 produces a valid rematerialization schedule.\n\nProof. The base case of the recursion is when there is a single bag in the tree decomposition, in\nwhich case we make no recursive calls and simply compute the desired outputs in some topological\norder. Inductively, we assume that the algorithm works correctly on tree decompositions with less\nthan b bags, and show that it also works when there are b bags.\nThe reasoning centers around what happens when we remove the balanced separator X (cid:63) from the\ntree decomposition. Since the bags containing any node v \u2208 V form a connected component, if v is\nin two or more components of C, it must also be in the separator X (cid:63). 
Hence this separator partitions our graph: for each subgraph (X\u2032, T\u2032) \u2208 C, we can de\ufb01ne the nodes in it to be V\u2032 := \u22c3_{X\u2208X\u2032} X, and we know that these V\u2032 together with X\u22c6 form a partition of V . Furthermore, by the de\ufb01nition of tree decomposition, we know that each edge must be present in some bag, so the only edges involving some V\u2032 go to other nodes in the same V\u2032 or to X\u22c6.\nAlgorithm 1: Ef\ufb01cient Rematerialization via Tree Decomposition.\nFunction TWRemat(G, V\u2032, (X , T ), S):\nData: G = (V, E) a computation graph, V\u2032 \u2286 V a subset of nodes to restrict to, (X , T ) a tree decomposition of G restricted to V\u2032, S \u2286 V\u2032 a subset of nodes to compute.\nResult: An annotated schedule consisting of nodes in V\u2032 that contains all nodes in S.\nif this is the top-level recursive call then\n    Shrink the size of the tree decomposition to at most n bags using Lemma 1;\nFind a balanced separator (bag) X\u22c6 \u2208 X using Lemma 3;\nMake a copy of (X , T ), removing bag X\u22c6 and removing nodes of X\u22c6 from every other bag. Let C be the set of connected components that result (each a tree decomposition (X\u2032, T\u2032));\nInitialize schedule = \u27e8\u27e9;\nfor node v \u2208 X\u22c6 in any topological order (according to G) do\n    for connected component (X\u2032, T\u2032) \u2208 C do\n        Let S\u2032 = in(v) \u2229 \u22c3_{X\u2208X\u2032} X and V\u2033 = V\u2032 \u2229 \u22c3_{X\u2208X\u2032} X;\n        Extend schedule with TWRemat(G, V\u2033, (X\u2032, T\u2032), S\u2032) to compute the inputs of v in this component;\n        Add annotation to schedule to keep S\u2032 in memory;\n    Add v to schedule, keeping it in memory, and freeing all of its inputs not in X\u22c6;\nfor connected component (X\u2032, T\u2032) \u2208 C do\n    Let S\u2032 = (S \\ X\u22c6) \u2229 \u22c3_{X\u2208X\u2032} X and V\u2033 = V\u2032 \u2229 \u22c3_{X\u2208X\u2032} X;\n    Extend schedule with TWRemat(G, V\u2033, (X\u2032, T\u2032), S\u2032) to compute the remaining outputs in this subgraph;\n    Add annotation to schedule to keep S\u2032 in memory;\nAdd annotation to schedule to free the unneeded balanced separator nodes X\u22c6 \\ S;\nreturn schedule;\nWe claim that whenever a recursive call to TWRemat is made (with arguments V\u2033 and S\u2032), all predecessors of S\u2032 which are not in V\u2033 are already in the memory of the caller\u2019s schedule. Consider some node u \u2208 S\u2032 and its predecessor u\u2032 \u2209 V\u2033. It must be that u\u2032 \u2208 X\u22c6 by the preceding discussion that an edge involving u \u2208 V\u2033 can only go to V\u2033 or to X\u22c6. Suppose that the recursive call is made from the nested for loops in which the outer loop is processing a node v \u2208 X\u22c6. Since u\u2032 is a predecessor of u and u is a predecessor of v (which we know from u \u2208 S\u2032), u\u2032 must come before v in a topological order of G. Thus, it has already been scheduled in a previous iteration of the outer for loop. If the recursive call is made from the other for loop, then all nodes of X\u22c6 are scheduled and in memory by that time.\nWe conclude that the precedence constraints are respected by the schedule\u2014with respect to nodes in V\u2033 by induction, and with respect to nodes in X\u22c6 by the above discussion. Furthermore, all nodes of S are scheduled in the later loop.\nTheorem 5. 
Given a computation graph G = (V, E), its tree decomposition (X , T ) of width at most k, and S \u2286 V a subset of nodes to compute, Algorithm 1 runs in time O(|X | \u00b7 (k + \u03b1(|X |)) + kn log n + kn^{1+log(k+2)}) and computes a rematerialization schedule of length O(Tonepass \u00b7 kn^{log(k+2)}) that requires O((Min + kMmax) log n) memory.\nProof. We begin with the running time. We pay an upfront cost of O(|X | \u00b7 (k + \u03b1(|X |))) to invoke Lemma 1. We pay a total time of O(n log n) to invoke Lemma 3, since (i) each invocation requires linear time and (ii) we recurse into subcalls that partition the tree decomposition into pieces that are at most half the current size. Note that we will need to memoize these balanced separators to avoid recomputing them over and over. As a result we have O(log n) levels of recursion and over all subcalls in a level we do O(n) work. The processing of tree decompositions (removing a bag, removing the nodes of a bag from other bags) can be done in O(kn) time and follows the same recursion as \ufb01nding balanced separators (i.e., subcalls partition the tree decomposition and have at most O(log n) depth), for a total of O(kn log n) work. Finally, the output is O(kn^{1+log(k+2)}) in size (see the schedule length analysis), and we spend linear time to compute it.\nNext, we check the schedule length. At each level, we make a recursive call to a particular subgraph at most (|X\u22c6| + 1) \u2264 (k + 2) times, so we wind up amplifying the total work by a factor of at most k + 2 at each recursive level (except for the \ufb01nal recursive level, where we make no recursive calls). Carefully counting, we need at most \u2308log n\u2309 + 1 levels of recursion, so we have ampli\ufb01ed the computation time by O((k + 2)^{\u2308log n\u2309}). Since a^{log b} = b^{log a}, this is an ampli\ufb01cation of O(kn^{log(k+2)}). 
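The identity a^{log b} = b^{log a} used in this step (to rewrite (k + 2)^{\u2308log n\u2309} as roughly n^{log(k+2)}) is easy to sanity-check numerically; the script below is purely illustrative, with k and n values borrowed from Table 1:

```python
import math

# a**log2(b) == b**log2(a): taking log2 of both sides, both equal
# log2(a) * log2(b). Check at treewidth-like k and graph sizes n.
for k in (6, 11, 18):
    for n in (64, 3217, 17705):
        lhs = (k + 2) ** math.log2(n)
        rhs = n ** math.log2(k + 2)
        assert math.isclose(lhs, rhs, rel_tol=1e-6)
print("identity holds")
```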
In other words, we make at most O(kn^{log(k+2)}) copies of any operation, so this takes at most O(Tonepass \u00b7 kn^{log(k+2)}) time as claimed.\nFinally, we check the memory needed by the schedule. Consider a particular segment of the schedule and the TWRemat function call that added it. The content of memory at this place in the schedule can be charged to the active function calls at that point of execution as follows: we charge to a recursive level everything that it annotated to keep in memory except its outputs, which are charged to its caller. The balanced separator requires O(kMmax) memory, and one set of inputs to a balanced separator node requires O(Min) memory (since we free these, we only need to hold one set of inputs at a time). Since there are O(log n) levels of recursion, this results in a total memory of O((Min + kMmax) log n).\nWhat remains is to compute a tree decomposition ef\ufb01ciently. Our corollary utilizes an approximation algorithm that runs in n \u00b7 2^{O(tw(G))} time and computes a decomposition of width at most (5tw(G) + 4) and size O(n) [2]. Our actual implementation uses a minimum \ufb01ll-in heuristic [3], which yields good tree decompositions.\nCorollary 6. Given a computation graph G = (V, E), there is an algorithm that runs in 2^{O(tw(G))}n + O(n \u00b7 (tw(G) + \u03b1(n)) + tw(G)n log n) time and computes a rematerialization schedule that requires computation time O(Tonepass \u00b7 tw(G) \u00b7 n^{log2(5tw(G)+6)}) and memory O((Min + tw(G)Mmax) log n).\n4 Experiments\nWe experimentally evaluate the performance of our rematerialization algorithm on computation graphs for training commonly used deep neural networks. We remark that the memory optimizations proposed in this paper ensure that the computation graph is faithfully executed; this ensures that the gradients obtained at each training step are exactly equivalent to those obtained without any optimization, and hence do not affect convergence. 
We measure the theoretical peak memory usage of a schedule\nvia an optimal static memory allocation plan. Since the primary purpose of these experiments is to\nevaluate the effect of rematerialization on memory usage, we do not consider other heuristic memory\noptimizations such as in-place operations, operation fusion, and buffer aliasing. Finally, we also\nmeasure the length of the schedule obtained by the different algorithms. For simplicity, in these\nexperiments, we assume that each operation takes unit cost.\nAlgorithms. We compare the performance of the following three algorithms.\n(i) NoRemat: Schedules all operations in a topological sort without any rematerialization.\n(ii) GreedyRemat: This is an implementation of a greedy heuristic for rematerialization used by\nXLA2 that works as follows. Starting with a topological sort of all operations, it processes each\noperation sequentially. At each stage, if the current memory usage is over a speci\ufb01ed memory limit,\nthe algorithm attempts to rematerialize an already scheduled operation. In particular, the operation\nwhose rematerialization maximizes the amount of reduction in memory usage is chosen greedily at\neach step. If the memory usage cannot be reduced, the algorithm moves on to the next operation.\n(iii) TWRemat: This is an implementation of Algorithm 1 that uses a tree decomposition; we use the\nminimum \ufb01ll-in heuristic [3] to \ufb01nd the tree decomposition.\nModels and Setup. We evaluate all algorithms on different families of widely used deep networks.\n(i) Deep Residual Networks (ResNet): We \ufb01rst consider deep residual networks (ResNet) [13] as an\nexample of convolutional networks for image classi\ufb01cation. We use the of\ufb01cial implementation of\nthe ResNet model for the ImageNet task in TensorFlow3. 
We use different con\ufb01gurations to measure the effect of network depth (number of convolutional layers) on memory requirements of schedules obtained by the algorithms.\n2www.tensorflow.org/xla\n3github.com/tensorflow/models/blob/master/official/resnet/imagenet_main.py\n(ii) Feed-forward networks (FFN): We consider a simple feed-forward neural network to illustrate the trends in peak memory usage of the schedules obtained by the different algorithms as a function of the network depth. For this experiment, we set up a simple feed-forward network with ReLU activations (the number of hidden layers is varied) and randomly generated inputs and outputs. We use mean squared error loss and train using standard gradient descent.\n(iii) Transformer: We also evaluate the memory savings obtained by our rematerialization algorithms for training the Transformer [21] network. Again, we use the of\ufb01cial implementation of Transformer in TensorFlow4 with all hyperparameters set to recommended defaults.\nTable 1 gives summary statistics for representative models from each family. Crucially, we observe that even the largest graphs have tree decompositions with small width.\nTable 1: Computation graph statistics.\nModel            |      n |      m | tw\nResNet200        | 17,705 | 27,312 | 11\nFFN (100 layers) |  3,217 |  4,447 |  6\nTransformer Base | 15,842 | 21,771 | 18\n4.1 Effect on Peak Memory Usage\nWe \ufb01rst demonstrate the effect of the depth of the network on the peak memory usage required for training the network. Figure 2 compares the performance of the three algorithms on the ResNet and feed-forward models described above. As expected, we observe that the peak memory usage of NoRemat, which does not perform any rematerialization, increases linearly with the number of layers on both model families. The GreedyRemat algorithm yields modest improvements (\u2248 2x) in memory usage for the ResNet models but still shows a linear growth with the number of layers. 
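The linear-vs-sublinear contrast has a simple illustration. The sketch below is not the paper's TWRemat: it simulates classic divide-and-conquer gradient checkpointing on a length-n chain (forward a[i] = F(a[i-1]), backward g[i] = B(a[i], g[i+1]), unit-size tensors, hypothetical helper names) and compares its peak memory against storing every activation:

```python
class Mem:
    """Track live unit-size tensors and the peak count."""
    def __init__(self):
        self.live, self.peak, self.ops = set(), 0, 0
    def hold(self, t):
        self.live.add(t)
        self.peak = max(self.peak, len(self.live))
    def free(self, t):
        self.live.discard(t)

def forward_to(mem, lo, hi):
    """Recompute a[hi] from resident a[lo], keeping only one transient value."""
    cur = lo
    for i in range(lo + 1, hi + 1):
        mem.ops += 1
        mem.hold(("a", i))
        if cur != lo:
            mem.free(("a", cur))
        cur = i

def backward(mem, lo, hi):
    """Given resident a[lo] and g[hi], compute g[lo]; frees g[hi]."""
    if hi == lo + 1:
        mem.ops += 1
        mem.hold(("g", lo))
        mem.free(("g", hi))
        return
    mid = (lo + hi) // 2
    forward_to(mem, lo, mid)   # checkpoint a[mid]
    backward(mem, mid, hi)     # g[mid] from a[mid] and g[hi]
    mem.free(("a", mid))
    backward(mem, lo, mid)     # g[lo] from a[lo] and g[mid]

def train_chain(n, remat):
    mem = Mem()
    mem.hold(("a", 0))                 # network input
    if remat:
        forward_to(mem, 0, n)          # one transient forward value
        mem.hold(("g", n))             # loss gradient, derived from a[n]
        mem.free(("a", n))
        backward(mem, 0, n)            # recursive checkpointing
    else:
        for i in range(1, n + 1):      # store every activation
            mem.ops += 1
            mem.hold(("a", i))
        mem.hold(("g", n))
        for i in range(n - 1, -1, -1):
            mem.ops += 1
            mem.hold(("g", i))         # needs a[i] and g[i+1]
            mem.free(("g", i + 1))
            mem.free(("a", i + 1))
    return mem.peak, mem.ops

print(train_chain(64, remat=False))  # (67, 128): peak memory linear in depth
print(train_chain(64, remat=True))   # (9, 320): logarithmic peak, longer schedule
```

Peak memory drops from n + 3 to about log2(n) + 3 while the operation count grows by roughly a log factor, the same qualitative time-memory trade-off reported for TWRemat.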
We observe that GreedyRemat yields very little memory savings on the feed-forward network. On the other hand, the TWRemat algorithm consistently gives memory savings on both model families (up to 10x), and the growth in peak memory usage is distinctly sublinear.\nFigure 2: Peak memory usage vs. model depth. (a) ResNet. (b) Feed-forward network.\nTable 2 shows the memory usage and relative lengths of the schedules obtained by the three algorithms on two con\ufb01gurations of the Transformer network. The TWRemat algorithm yields a 3.48x and 4.59x reduction in peak memory usage, respectively, albeit at a cost of up to 10.6x in the schedule length.\nTable 2: Transformer: Peak memory usage and relative schedule lengths.\nModel            | NoRemat Mem. (GiB) | GreedyRemat Mem. (GiB) | GreedyRemat Rel. Len. | TWRemat Mem. (GiB) | TWRemat Rel. Len.\nTransformer Base |               3.97 |                   2.92 |                  1.21 |               1.14 |             10.61\nTransformer Big  |              13.25 |                  10.12 |                  1.27 |               2.89 |             10.64\n4.2 Effect on Schedule Length\nOur algorithms are speci\ufb01cally designed to minimize peak memory consumption at the expense of additional computation. Figure 3 illustrates the increase in schedule length relative to NoRemat.\n4github.com/tensorflow/models/blob/master/official/transformer/transformer_main.py\nWe observe that GreedyRemat consistently yields schedules that are only marginally longer than the corresponding schedules of NoRemat. On the other hand, the schedules obtained via TWRemat are around 3x-4x longer. 
Despite the longer schedules, we expect the schedules produced by TWRemat to be beneficial in practice as the reduced memory usage allows the use of specialized hardware accelerators.

Figure 3: Schedule length vs. model depth. (a) ResNet. (b) Feed-forward network.

4.3 Trading off Memory Usage for Schedule Length

Algorithm 1 (TWRemat) uses the tree decomposition to find a balanced separator that breaks up the tree decomposition into smaller subtrees, and then recursively computes schedules to compute the required nodes in these subtrees. We observe that we can obtain a trade-off between memory usage and schedule length by preemptively stopping the recursion when the tree decomposition has few bags remaining. For any integer k, let TWRemat(k) be a variant of Algorithm 1 that stops the recursion when the tree decomposition has fewer than k bags. In the base case, we schedule the required nodes in an arbitrary topological order. In this notation, our TWRemat algorithm can be written as TWRemat(1). Indeed, by varying the recursion limit from k = 1 to k = n, we can interpolate between the TWRemat and NoRemat algorithms. Figure 4 shows the memory usage vs. schedule length trade-off obtained for the ResNet200 model.

Figure 4: ResNet200: Memory vs. length.

5 Stronger Guarantees via Path Decomposition

Related to the treewidth of a graph is the notion of pathwidth, which is defined as the minimum width of any path decomposition, where a path decomposition is a tree decomposition (X, T) under the additional constraint that T must be a path. We can order the bags according to the path and instead use the tuple X = (X1, X2, ..., X|X|) to represent the path decomposition, where each Xi ⊆ V is a bag and (path decomposition) edges run between Xi and Xi+1. We denote the pathwidth of a graph G by pw(G).
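The definition above is easy to check mechanically. A minimal, self-contained sketch (illustrative only; the function names are ours) that verifies the three standard conditions: every vertex appears in some bag, every edge has both endpoints together in some bag, and the bags containing any fixed vertex are contiguous along the path:

```python
# Checking that a sequence of bags (X_1, ..., X_k) is a valid path
# decomposition of an undirected graph; illustrative sketch only.

def is_path_decomposition(vertices, edges, bags):
    # (1) Every vertex appears in some bag.
    if set(vertices) != set().union(*bags):
        return False
    # (2) Every edge has both endpoints together in some bag.
    for u, v in edges:
        if not any(u in bag and v in bag for bag in bags):
            return False
    # (3) The bags containing any fixed vertex form a contiguous interval.
    for v in vertices:
        idx = [i for i, bag in enumerate(bags) if v in bag]
        if idx != list(range(idx[0], idx[-1] + 1)):
            return False
    return True

def decomposition_width(bags):
    # Width of a decomposition: largest bag size minus one.
    return max(len(bag) for bag in bags) - 1
```

For the path graph 1-2-3-4, the bags ({1,2}, {2,3}, {3,4}) form a valid path decomposition of width 1, matching the fact that paths have pathwidth 1.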
Assuming that a computation graph has a small constant pathwidth allows us to design an algorithm for rematerialization that leverages the path decomposition to yield stronger theoretical guarantees than in Theorem 5. In this section, we sketch the primary ideas, deferring the full algorithm and analysis to the Supplementary Material.

We first show that one can add a directed Hamiltonian path (i.e., a spine) to any graph G so that the pathwidth of G only increases by a factor of ∼ 2. This allows us to prove certain structural properties of the path decomposition. Suppose the vertices of G are ordered according to the spine, let ui ∈ V be the ith node, and let last(X) denote the index of the last node in bag X. We show that if Xc ∈ X is a bag in the path decomposition that contains un, then for all ℓ < ℓ′ < c, we have last(Xℓ) ≤ last(Xℓ′) ≤ last(Xc), and for all r > r′ > c, we have last(Xr) ≤ last(Xr′) ≤ last(Xc).

Such a structural characterization allows a divide-and-conquer strategy that recurses on the right and left sides of the path decomposition. Unlike the tree decomposition algorithm, where we argue that the size of the tree decomposition reduces at each recursive call, the additional properties of the path decomposition allow us to argue that both the size and the width of the decomposition decrease. The resulting algorithm yields a schedule that incurs a polylogarithmic increase in length (vs.
polynomial blow-up for the tree decomposition), but at the cost of polylogarithmic memory usage.

6 Related Work

Rematerialization has been considered in very limited settings for training deep networks. The work most relevant to ours is that of Chen et al. [6] and Gruslys et al. [11]. The former shows how to trade off memory and computation cost for simple chain-like networks. Their algorithm at a high level works by dividing a computation of length n into √n many sub-computations, storing the internal states for each sub-computation and at the √n checkpoints; a second pass is needed to complete the computations. By recursing on this idea, one can get an algorithm with O(n log n) total computation using O(log n) memory for chain-like computations. Gruslys et al. [11] consider backpropagation through time and propose a dynamic-programming-based approach for achieving the best time-memory trade-off; their algorithm is tailored to work on RNNs. It is unclear how to extend either of these algorithms to work for general computation graphs, which is the focus of our work. There are some practical heuristics for rematerialization used in open-source efforts such as XLA; in fact, we used it as one of our baselines (GreedyRemat). Other heuristics, including in-place operations and register-sharing memory optimizations, have been used in practice [5]. We, on the other hand, offer a principled approach to these problems.

Tree decomposition has been suggested as a tool to achieve time-memory trade-offs in register allocation problems in compilers [18, 4]. A recent blog post5 informally suggests using tree decomposition for memory saving in deep networks in the context of gradient checkpointing,6 which implements [6]. As noted, control flow graphs of structured programs have treewidth ∼ 6 [20].
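For concreteness, the √n scheme of Chen et al. [6] described above can be simulated with a toy cost model (unit-size activations and unit-cost ops are assumed, and the function name is ours, not theirs):

```python
# Toy cost model for sqrt(n) checkpointing on a chain of n identical ops:
# retain every k-th activation (k ~ sqrt(n)); during the backward pass,
# replay each segment from its checkpoint before using it.
import math

def sqrt_checkpoint_cost(n):
    """Return (peak number of stored activations, total forward evaluations)."""
    k = max(1, math.isqrt(n))            # checkpoint spacing, ~ sqrt(n)
    checkpoints = list(range(0, n, k))   # indices retained in memory
    total_evals = n                      # initial forward pass
    peak = len(checkpoints)
    # Backward pass: recompute each segment from its checkpoint on demand.
    for start in reversed(checkpoints):
        seg_len = min(k, n - start)
        total_evals += seg_len           # replay the segment
        peak = max(peak, len(checkpoints) + seg_len)  # segment held transiently
    return peak, total_evals
```

Under this model, a chain of length 100 needs a peak of about 2√n = 20 stored values and about 2n = 200 evaluations, matching the "second pass" accounting above.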
Here, we work with the data flow graph, which in general may have larger treewidth, to obtain a memory-efficient schedule. View materialization in databases is also somewhat related to rematerialization [12]. The goal there is to pre-compute materialized views in order to efficiently answer future queries. While this is also a computation-memory trade-off, the end goals are clearly different from our setting.

7 Conclusions

We consider the rematerialization problem in the context of memory-efficient training of deep networks and obtain efficient algorithms based on tree decomposition for finding a provably good schedule with rematerialization. Although our path-decomposition-based algorithm yields asymptotically better schedules, the schedule length and memory depend exponentially on the pathwidth. It will be very interesting to make this algorithm more practical. Identifying the precise structural parameter that characterizes rematerialization of a given graph is a tantalizing research question.

References

[1] Olivier Beaumont, Julien Herrmann, Guillaume Pallez, and Alena Shilova. Optimal memory-aware backpropagation of deep join networks. Research Report RR-9273, Inria, 2019.

[2] Hans L. Bodlaender, Pål Grønås Drange, Markus S. Dregi, Fedor V. Fomin, Daniel Lokshtanov, and Michal Pilipczuk. A c^k n 5-approximation algorithm for treewidth. SICOMP, 45(2):317–378, 2016.

[3] Hans L. Bodlaender and Arie M. C. A. Koster. Treewidth computations I: Upper bounds. Information and Computation, 208(3):259–275, 2010.

[4] Preston Briggs, Keith D. Cooper, and Linda Torczon. Rematerialization. In PLDI, pages 311–321, 1992.

5medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
6github.com/openai/gradient-checkpointing

[5] Tianqi Chen, Mu Li, Yutian Li, Min Lin, Naiyan Wang, Minjie Wang, Tianjun Xiao, Bing Xu, Chiyuan Zhang, and Zheng Zhang.
MXNet: A flexible and efficient machine learning library for heterogeneous distributed systems. Technical Report 1512.01274, arXiv, 2015.

[6] Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. Technical Report 1604.06174, arXiv, 2016.

[7] Bruno Courcelle. The monadic second-order logic of graphs. I. Recognizable sets of finite graphs. Information and Computation, 85(1):12–75, 1990.

[8] Jianwei Feng and Dong Huang. Cutting down training memory by re-forwarding. Technical Report 1808.00079, arXiv, 2018.

[9] John R. Gilbert, Thomas Lengauer, and Robert Endre Tarjan. The pebbling problem is complete in polynomial space. In STOC, pages 237–248, 1979.

[10] Andreas Griewank and Andrea Walther. Algorithm 799: Revolve: An implementation of checkpointing for the reverse or adjoint mode of computational differentiation. TOMS, 26(1):19–45, 2000.

[11] Audrūnas Gruslys, Rémi Munos, Ivo Danihelka, Marc Lanctot, and Alex Graves. Memory-efficient backpropagation through time. In NIPS, pages 4132–4140, 2016.

[12] Alon Y. Halevy. Answering queries using views: A survey. The VLDB Journal, 10(4):270–294, 2001.

[13] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770–778, 2016.

[14] Camille Jordan. Sur les assemblages de lignes. J. Reine Angew. Math, 70(81):185–190, 1869.

[15] Mitsuru Kusumoto, Takuya Inoue, Gentaro Watanabe, Takuya Akiba, and Masanori Koyama. A graph theoretic framework of recomputation algorithms for memory-efficient backpropagation. In NeurIPS, 2019.

[16] Chen Meng, Minmin Sun, Jun Yang, Minghui Qiu, and Yang Gu. Training deeper models by GPU memory optimization on TensorFlow.
In Proceedings of the ML Systems Workshop in NeurIPS, 2017.

[17] Minsoo Rhu, Natalia Gimelshein, Jason Clemons, Arslan Zulfiqar, and Stephen W. Keckler. vDNN: Virtualized deep neural networks for scalable, memory-efficient neural network design. In IEEE/ACM International Symposium on Microarchitecture, page 18, 2016.

[18] Ravi Sethi. Complete register allocation problems. SICOMP, 4(3):226–248, 1975.

[19] Koichi Shirahata, Yasumoto Tomita, and Atsushi Ike. Memory reduction method for deep neural network training. In Workshop on Machine Learning for Signal Processing (MLSP), pages 1–6, 2016.

[20] Mikkel Thorup. All structured programs have small tree-width and good register allocation. Information and Computation, 142(2):159–181, 1998.

[21] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NeurIPS, pages 5998–6008, 2017.