Training Deep Nets with Sublinear Memory Cost
Trade one extra forward pass for square-root memory.
Most of what training keeps in GPU memory, it keeps only to use once, much later. Gradient checkpointing throws most of it away and recomputes it on demand, dropping the memory cost of an $n$-layer network from $O(n)$ to $O(\sqrt{n})$ for the price of a single extra forward pass.
Explaining the paper: Training Deep Nets with Sublinear Memory Cost
To train a network you have to remember its whole forward pass. What if you just recomputed the parts you threw away?
A fact that has quietly shaped how big a model you can train: backpropagation makes you remember everything. Every layer's output activation is held in memory from the moment the forward pass computes it until the backward pass comes back around, much later, to use it. Train a 100-layer model and you pay for 100 layers' worth of activations at once. The culprit is the chain rule. To get the gradient at a layer you need that layer's own forward values: a stored activation is the price of being able to compute its gradient later. So memory grows linearly with depth, and that wall, not arithmetic and not data, is usually what decides how deep a network fits on a GPU. The paper opens on exactly this pain: a 1,000-layer residual network that needs 48 GB it doesn't have.
Look closely at what those stored activations are, though: scratch work. The forward pass used each one to compute the next and then left it sitting idle for the entire rest of the run, alive in memory but untouched, waiting for one brief moment when the backward pass needs it. Keeping it around is convenient, not necessary. If an activation is cheap to rebuild, you could throw it away and recompute it the instant the gradient asks for it. That one trade, a little recomputation in exchange for a lot of memory, is the entire paper: keep a few checkpoints, drop the rest, and rebuild the dropped ones on the way back. Done right it takes the memory bill from $O(n)$ down to $O(\sqrt{n})$, at the cost of running the forward pass roughly one extra time.
The technique now has a name, gradient checkpointing, and it is everywhere a large model is trained. Two ideas stack to get there. First, treat memory the way a compiler treats registers and reuse it aggressively, which is a constant-factor win. Second, drop activations and recompute them, which is the win that actually beats the linear trend. We'll take them in that order.
The memory wall
Start with why training is so much hungrier than inference. At inference you run the forward pass and discard each activation as soon as the next layer has consumed it, so a network of any depth can run in nearly constant memory. Training cannot do that, and the reason is the chain rule.
To update layer $i$, backprop needs two things: the gradient flowing back from above, and the layer's own local derivative. That local derivative has to be evaluated at the actual values the forward pass produced. A ReLU's derivative is 1 where its input was positive and 0 where it was negative, so you cannot apply it without a record of which inputs were positive: either the input itself or the ReLU's output, which is positive in exactly those places. A matmul's weight gradient is the upstream gradient times the layer's input, so you need that input too, and in general it cannot be recovered from the layer's output. None of these values can be rebuilt from what the backward pass has in hand, so the framework keeps every layer's forward activation alive. The consequence is the picture below: at the moment the backward pass begins, all $n$ activations are in memory at once. That simultaneous peak is the wall, and it scales linearly with depth.
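Here is a minimal scalar sketch of that point in plain Python. The helper names are hypothetical and no framework API is implied. It has a one-weight linear layer followed by a ReLU. The weight gradient is built from the layer's saved input, and the ReLU's backward reads its saved output, which is positive exactly where its input was. Either way, something from the forward pass has to be kept until the backward pass arrives.

```python
def linear_forward(x, w):
    return w * x

def linear_backward(grad_out, x, w):
    # weight gradient needs the layer's stored input x; input gradient needs w
    return grad_out * w, grad_out * x   # (dL/dx, dL/dw)

def relu_forward(h):
    return max(h, 0.0)

def relu_backward(grad_out, y):
    # slope is 1 exactly where the output is positive, i.e. where the input was
    return grad_out if y > 0 else 0.0

def grad_of_w(x, w):
    """d/dw of relu(w * x), computed only from values saved in the forward pass."""
    saved_x = x                          # kept alive for the linear layer's backward
    saved_y = relu_forward(linear_forward(x, w))  # kept alive for the ReLU's backward
    g = relu_backward(1.0, saved_y)
    _, grad_w = linear_backward(g, saved_x, w)
    return grad_w

print(grad_of_w(3.0, 2.0), grad_of_w(3.0, -2.0))  # 3.0 0.0
```

A finite-difference check on `relu(w * x)` agrees with `grad_of_w`, but only because `saved_x` and `saved_y` were still around when the gradient was formed.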
That linearity is the enemy. It means the deepest network you can train is set by a division: GPU memory divided by the memory per layer. Buy a bigger GPU and you push the wall out by a constant; you do not move it.
Cheap wins: the computation graph
Before recomputing anything, the paper squeezes the easy savings out of how memory is allocated. The trick is to stop thinking of a network as code and start thinking of it as a computation graph: nodes are operations, edges are the values that flow between them, and the backward pass is just more nodes and edges in the same graph. Once it is a graph, allocating memory to it is the same problem a compiler solves when it assigns a finite set of registers to a program's variables.
The key notion borrowed from there is liveness: a value is live from the moment it is produced until its last use, and two values whose live ranges do not overlap can safely share one buffer. That unlocks two optimizations. An in-place operation writes its output directly over an input that is already dead, so an activation function costs no new memory. Sharing recycles a buffer the moment its value is no longer needed by anyone downstream. The textbook version of finding the best such assignment, graph coloring, is actually NP-complete, so the paper does not solve it exactly: it sweeps the graph once with a liveness counter, freeing each buffer when its count of pending uses hits zero, which is linear time and good enough in practice. (The same analysis lets MXNet, which the authors built on, allocate every buffer statically before a single op runs, so it can report the exact memory a plan will use.)
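Below is a minimal sketch of that single liveness sweep. It assumes every tensor is the same size, so any free buffer fits any value, and it leaves out in-place reuse. The graph format is a topologically ordered list of `(output, inputs)` pairs, and both that format and the helper names are hypothetical, not MXNet's API. Each value starts with a pending count equal to its number of consumers. When the count reaches zero, the value's buffer goes back to a free pool for the next allocation.

```python
from collections import Counter

def plan_buffers(ops, keep):
    """Greedy buffer sharing driven by liveness counts.

    ops:  list of (output_name, [input_names]) in topological order.
    keep: names that must stay alive at the end (e.g. the final gradient).
    Returns (value -> buffer id, number of buffers used).
    """
    pending = Counter(name for _, inputs in ops for name in inputs)
    for name in keep:
        pending[name] += 1                     # never free the requested outputs
    free_pool, assignment, n_buffers = [], {}, 0
    for out, inputs in ops:
        # allocate the output before releasing inputs, so no in-place reuse here
        if free_pool:
            assignment[out] = free_pool.pop()  # recycle a dead value's buffer
        else:
            assignment[out] = n_buffers        # nothing free: allocate a new one
            n_buffers += 1
        for name in inputs:                    # this op used up one pending use
            pending[name] -= 1
            if pending[name] == 0 and name in assignment:
                free_pool.append(assignment[name])
    return assignment, n_buffers

def chain_graph(n, training):
    """An n-layer chain x -> h1 -> ... -> hn, optionally with its backward ops."""
    h = ["x"] + [f"h{i}" for i in range(1, n + 1)]
    ops = [(h[i], [h[i - 1]]) for i in range(1, n + 1)]
    if not training:
        return ops, [h[n]]
    ops.append((f"g{n}", [h[n]]))              # loss gradient at the top
    for i in range(n, 0, -1):                  # layer i's backward reads g_i and its input
        ops.append((f"g{i - 1}", [f"g{i}", h[i - 1]]))
    return ops, ["g0"]

print(plan_buffers(*chain_graph(10, training=False))[1])  # 2
print(plan_buffers(*chain_graph(10, training=True))[1])   # 11
```

On a 10-layer chain, inference needs only two buffers, and that number holds at any depth. Once the backward ops are added, every forward activation stays live until its gradient step. The count becomes $n + 1$, which is the constant-factor ceiling the next paragraph describes.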
The saving is real but bounded: in-place and sharing cut the feature-map memory by a factor of two to three. That is the lower bar in Figure 1, and the catch is right there in the picture. It is a constant factor. Both meters still climb in lockstep with depth, so a deep enough network hits the wall regardless of how cleverly you pack the buffers. To bend the line itself, not just lower it, you need an idea that changes the exponent.
Trade compute for memory
That idea is recomputation, and it goes like this: cut the chain of layers into contiguous segments; on the forward pass, keep only the activation at each segment boundary and drop everything inside; on the backward pass, when you arrive at a segment, replay its short forward run from the saved boundary to rebuild its internal activations, run the gradient through them, then free them again before moving to the next segment. You never hold more than one segment's worth of internals at a time, plus the handful of saved boundaries. Figure 2 below scrubs through the whole sequence. In code it is anticlimactic:
```python
# checkpointed backprop over an n-layer chain, k segments (Alg 1)
# segments: list of k lists of layers; each layer has forward(x) and
# backward(grad, x), which stores its weight gradient and returns dL/dx
def checkpointed_backprop(segments, x, target, loss_grad):
    checkpoint = []
    v = x
    for seg in segments:                 # FORWARD: keep only the boundaries
        checkpoint.append(v)             # save this segment's input activation
        for layer in seg:                # run the segment, drop its internals
            v = layer.forward(v)
    grad = loss_grad(v, target)
    # BACKWARD: one segment at a time, last segment first
    for seg, start in zip(reversed(segments), reversed(checkpoint)):
        v = start                        # restore the saved boundary
        cache = []
        for layer in seg:                # RE-COMPUTE the dropped activations
            cache.append(v)
            v = layer.forward(v)
        for layer, inp in zip(reversed(seg), reversed(cache)):
            grad = layer.backward(grad, inp)  # backprop through the segment
        del cache                        # peak: one segment + k boundaries
    return grad
```

Watch it run below. The forward pass keeps only the amber checkpoints and discards the rest. Then the backward pass walks the segments from right to left: each one lights up as it is recomputed from its checkpoint (teal), the gradient sweeps through it, and it is freed. The memory trace underneath is the whole point. Instead of the plain-backprop ramp up to $n$, it is a low sawtooth, each tooth one segment's recompute, with a peak nowhere near the top line.
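A small counting model reproduces the sawtooth. It is a sketch that assumes every activation costs one unit and ignores gradients, parameters and workspace. `memory_trace` is a hypothetical helper, not from the paper. It walks the same schedule as Alg 1 and records how many activations are live after each layer step: the stored checkpoints, plus either the one activation in flight on the forward pass or the recomputed cache of the current segment on the way back.

```python
def memory_trace(n, k):
    """Live activation count after each step of Alg 1 on n layers in k segments.

    Unit-size activations; gradients, parameters and workspace are ignored.
    """
    bounds = [round(i * n / k) for i in range(k + 1)]  # segment s = layers bounds[s]..bounds[s+1]-1
    trace, checkpoints = [], 0
    for s in range(k):                          # FORWARD: store only each segment's input
        checkpoints += 1
        for _ in range(bounds[s], bounds[s + 1]):
            trace.append(checkpoints + 1)       # plus the one activation in flight
    for s in reversed(range(k)):                # BACKWARD, segment by segment
        cache = 0
        for _ in range(bounds[s], bounds[s + 1]):   # recompute: cache each layer input
            cache += 1
            trace.append(checkpoints + cache)
        for _ in range(bounds[s], bounds[s + 1]):   # backprop frees the cache layer by layer
            cache -= 1
            trace.append(checkpoints + cache)
        checkpoints -= 1                        # this segment's boundary is done
    return trace

for k in (1, 10, 100):
    print(k, max(memory_trace(100, k)))
```

With 100 layers in 10 segments, the peak is 20 units. One giant segment ($k = 1$) and a checkpoint on every layer ($k = 100$) both peak at 101, which is no better than storing every activation. These two extremes are the two ends of the $n/k + k$ trade-off.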
Two things to take from the trace. The first is what it costs. Every activation that was dropped is recomputed exactly once, during its segment's backward step, so the total extra work is precisely one more forward pass over the network. That lands at about 30% more wall-clock time, not the 50% it sounds like if you picture a training step as one forward plus one equally expensive backward. A backward pass actually costs about twice a forward pass. For each matrix multiply in the forward pass it performs two: one to send the error back to the layer's input and one to form the weight gradient. So an ordinary training step is roughly one forward plus two forwards' worth of backward, three units in all. Adding one more forward makes four, about a third more. The paper measures 30% on real hardware, a hair under the unit model's 33% because the recomputed forward runs a little cheaper than the original.
The second is what it does not cost. The recomputed activations come from bit-for-bit the same arithmetic as the originals, so the gradients are identical to what plain backprop would have produced, as long as any randomness, such as a dropout mask, is replayed identically during recomputation. The paper states that the method "gives equivalent weight gradient": this is a pure memory-for-compute trade with no effect on the model you get out. You are not approximating anything. You are just choosing to rebuild some numbers rather than store them.
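Both claims, the one extra forward pass and the identical gradients, can be checked by running plain and checkpointed backprop side by side. This is a minimal sketch with a hypothetical scalar layer $y = \sin(w x)$ standing in for a real one. It counts forward calls and compares the weight gradients for exact equality.

```python
import math

class SinLayer:
    """Hypothetical toy layer y = sin(w * x) on a scalar; counts forward calls."""
    forward_calls = 0

    def __init__(self, w):
        self.w = w
        self.grad_w = 0.0

    def forward(self, x):
        SinLayer.forward_calls += 1
        return math.sin(self.w * x)

    def backward(self, grad_out, x):
        slope = math.cos(self.w * x)        # local derivative at the stored input
        self.grad_w = grad_out * x * slope  # dL/dw for this layer
        return grad_out * self.w * slope    # dL/dx, handed to the layer below

def plain_backprop(layers, x):
    inputs = []
    for layer in layers:                    # store every layer's input
        inputs.append(x)
        x = layer.forward(x)
    grad = 1.0                              # loss = final output, so dL/dout = 1
    for layer, inp in zip(reversed(layers), reversed(inputs)):
        grad = layer.backward(grad, inp)
    return [layer.grad_w for layer in layers]

def checkpointed_backprop(layers, x, seg_len):
    segments = [layers[i:i + seg_len] for i in range(0, len(layers), seg_len)]
    checkpoints = []
    for seg in segments:                    # forward: store only segment inputs
        checkpoints.append(x)
        for layer in seg:
            x = layer.forward(x)
    grad = 1.0
    for seg, start in zip(reversed(segments), reversed(checkpoints)):
        v, inputs = start, []
        for layer in seg:                   # recompute this segment's inputs
            inputs.append(v)
            v = layer.forward(v)
        for layer, inp in zip(reversed(seg), reversed(inputs)):
            grad = layer.backward(grad, inp)
    return [layer.grad_w for layer in layers]

layers = [SinLayer(0.5 + 0.1 * i) for i in range(16)]
SinLayer.forward_calls = 0
g_plain = plain_backprop(layers, 0.3)
plain_calls = SinLayer.forward_calls
SinLayer.forward_calls = 0
g_ckpt = checkpointed_backprop(layers, 0.3, seg_len=4)
ckpt_calls = SinLayer.forward_calls
print(plain_calls, ckpt_calls, g_plain == g_ckpt)   # 16 32 True
```

Plain backprop makes 16 forward calls and the checkpointed run makes 32, which is exactly one extra pass. The two gradient lists compare equal with `==`, not just approximately, because the recomputation repeats the same floating-point operations on the same inputs.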
The √n sweet spot
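One dial is left: how many segments? Cut the network into $k$ of them and the peak memory has two parts, which the paper writes as

$$\text{cost-total} = \max_{i=1,\ldots,k} \text{cost-of-segment}(i) + O(k) = O\!\left(\frac{n}{k}\right) + O(k) \tag{1}$$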
Read the two terms off the mechanism. The $n/k$ is the size of the single segment you have to hold and recompute at the worst moment, when one segment's internals are all live for its backward step. The $k$ is the pile of saved checkpoints, one per boundary, that you carry the whole time. They pull in opposite directions: few segments means each one is enormous to recompute ($n/k$ large), many segments means the checkpoints themselves become the cost ($k$ large). You want the $k$ that makes the total smallest, so set the derivative to zero:

$$\frac{d}{dk}\left(\frac{n}{k} + k\right) = -\frac{n}{k^{2}} + 1 = 0 \quad\Longrightarrow\quad k = \sqrt{n}$$
The balance lands exactly where the two terms are equal, $n/k = k$, which is $k = \sqrt{n}$. Split an $n$-layer network into $\sqrt{n}$ segments of $\sqrt{n}$ layers each and the peak is $\sqrt{n} + \sqrt{n} = 2\sqrt{n} = O(\sqrt{n})$. Drag the slider below and watch the two component curves cross right under the bottom of the U; both ends of the slider, one giant segment and a checkpoint on every layer, cost the full $O(n)$.
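A brute-force search confirms the calculus. This sketch assumes unit-cost layers and segments split as evenly as possible, so the live segment holds $\lceil n/k \rceil$ layers and the peak is $\lceil n/k \rceil + k$. The function names are illustrative.

```python
def peak_memory(n, k):
    """Peak units for n layers in k segments: one live segment plus k checkpoints."""
    return -(-n // k) + k          # ceil(n / k) + k, in exact integer arithmetic

def best_segments(n):
    """The k in 1..n with the smallest peak (the first one, if there is a tie)."""
    return min(range(1, n + 1), key=lambda k: peak_memory(n, k))

k = best_segments(10_000)
print(k, peak_memory(10_000, k), peak_memory(10_000, 1))   # 100 200 10001
```

For 10,000 layers the search lands on $k = 100 = \sqrt{10{,}000}$, with a peak of 200 units against 10,001 for a single segment.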
That is the headline result. Memory $O(\sqrt{n})$, compute one extra forward pass. The square root is doing a lot of work: a network whose activations would have needed a million units of memory now needs a thousand, and the bill for that is a third more time. One caution on what the $O(\sqrt{n})$ covers. It is the cost of the intermediate feature maps, the activation tensors flowing through each layer, which dominate training memory for deep conv nets and unrolled RNNs. The parameters and the scratch space a convolution needs are a separate, unchanged line on the bill; checkpointing leaves them alone.
Pay even less: recursion
One more turn of the screw, and it is too pretty to leave out: a segment is itself just a chain of layers, so apply the very same trick inside it. Checkpoint within the segment, drop its sub-internals, recompute them recursively. Let $g(n)$ be the memory to do a forward-and-backward over $n$ layers when you store $k$ intermediate results and recurse on the pieces between them. Each level costs $k$ checkpoints and hands a chain of length $n/(k+1)$ to the next level down:

$$g(n) = k + g\!\left(\frac{n}{k+1}\right) \tag{2}$$
Unrolling that recursion stacks up one $k$ per level, and the number of levels is how many times you can divide $n$ by $k+1$ before reaching a single layer, which is $\log_{k+1} n$. So

$$g(n) = k \log_{k+1} n \tag{3}$$
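The closed form can be checked directly against the recursion $g(n) = k + g(n/(k+1))$. This sketch rounds piece lengths up and takes $g(1) = 0$, which ignores the bottom layer's own activation. Both the rounding and the base case are assumptions, and the function names are hypothetical.

```python
import math

def recursive_memory(n, k):
    """Checkpoints held across all levels when each level stores k and recurses.

    Follows g(n) = k + g(n / (k + 1)) with g(1) = 0, rounding piece lengths up.
    """
    levels = 0
    while n > 1:
        n = -(-n // (k + 1))       # ceil(n / (k + 1)): chain handed to the next level
        levels += 1
    return k * levels

def closed_form(n, k):
    return k * math.log(n, k + 1)  # g(n) = k * log_{k+1}(n)

print(recursive_memory(1024, 1), recursive_memory(1_000_000, 1))   # 10 20
```

With $k = 1$, a million-layer chain holds just 20 checkpoints, since the halving takes 20 steps to reach a single layer.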
Choosing where to place the checkpoints at every level of that recursion is the problem the revolve algorithm from the automatic-differentiation literature solves: it computes the optimal schedule for trading recomputation against memory. Push it to the extreme. With $k = 1$ you keep a single checkpoint at each level, halving the remaining chain every time, so (3) gives $g(n) = \log_2 n$. Logarithmic memory. The price is one more forward pass per level, so about $\log_2 n$ forward passes in total instead of one. This instantiates a general theorem from that same literature (Griewank and Walther): for any $k$, you can train in $O(kn)$ compute and $O(k\,n^{1/k})$ memory. Plain backprop is $k = 1$; the $\sqrt{n}$ scheme is $k = 2$; driving $k$ up to $\log n$ takes memory down to $O(\log n)$.
Drag the slider to spend $k$ forward passes and watch the memory collapse. The shape is the lesson: the curve has a sharp knee at two passes. That first extra forward pass, the move from storing everything to the $\sqrt{n}$ scheme, buys the bulk of the saving. Everything past it, the slow walk from $\sqrt{n}$ down to $\log n$, costs many more passes for only a sliver of additional savings.
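The frontier can be tabulated from the Griewank and Walther bound: $k$ forward passes buy $O(k\,n^{1/k})$ memory. This sketch drops the hidden constants, so the numbers show the shape of the curve rather than exact costs. `frontier` is a hypothetical helper.

```python
def frontier(n, max_passes):
    """(passes, memory) pairs from the O(k * n**(1/k)) bound, constants dropped."""
    return [(k, k * n ** (1 / k)) for k in range(1, max_passes + 1)]

if __name__ == "__main__":
    for passes, memory in frontier(1_000_000, 20):
        print(f"{passes:2d} forward passes -> ~{memory:,.0f} units")
```

For a million layers the table starts at 1,000,000 units for one pass and drops to 2,000 for two. After that comes a long tail that bottoms out in the high 30s around 14 passes. Without constants, $k\,n^{1/k}$ is smallest near $k = \ln n$, where it equals $e \ln n$, and that minimum is the $O(\log n)$ end of the frontier.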
Which is exactly why $\sqrt{n}$ is the default and $\log n$ is a curiosity. The paper says as much: the logarithmic version "may not be used commonly" because running the forward pass $\log n$ times is a steep price for the last factor of memory. It is a theorem about how far the trade can be pushed, not a recommendation.
What it costs, what it buys
The experiments cash the whole thing out. On a 1,000-layer residual network, checkpointing takes the runtime memory from 48 GB down to 7 GB, a touch under a 7× reduction, while adding about 30% to the wall-clock time. The gap is decisive in practice: with even the best linear allocation plan, the largest ResNet the authors' GPU could hold was a couple hundred layers; with checkpointing, a thousand layers fit in 7 GB. The figure below makes that concrete. Drag depth and watch the linear plans cross the GPU's memory ceiling after a few hundred layers while the sublinear plan stays under it into the thousands.
The second toggle in that figure shows this is not a vision trick: the authors run it on a four-layer LSTM with 1,024 hidden units unrolled over a long sequence, the setting where memory grows with the number of timesteps rather than layers, and the sublinear plan gives more than a 4× reduction over the best plan that does not recompute. (On recurrent nets the in-place optimization pulls extra weight, because it lets the per-timestep weight gradients accumulate directly into one buffer instead of allocating fresh space at every step.)
A note on what the numbers are measuring, since two different quantities are in play. The clean curve is the feature-map estimate from MXNet's static allocator, which is why it is exact and smooth. The 48 GB to 7 GB headline is the total runtime memory read off nvidia-smi, parameters and workspace and all, which is why it is a roughly 7× drop rather than a literal square root. Both are correct; they count different things.
Step back and place the idea among its relatives, because it is one of three ways people now get memory-efficient training, and they are easy to conflate. Checkpointing recomputes dropped activations from saved checkpoints, and it is the general, drop-in one: it works on any network with no change to the architecture. Reversible networks instead reconstruct each layer's activations from the next layer's, reaching constant memory in depth, but only by constraining the architecture so the layers are invertible. Neural ODEs get constant memory a third way, by solving an adjoint equation backward instead of storing a trajectory. FlashAttention also recomputes in its backward pass, which is why it is often filed next to checkpointing, but its central idea is something else, IO-aware tiling to keep the attention computation in fast on-chip memory, and it only recomputes the small softmax statistics rather than full activations. Checkpointing is the one that asks nothing of your model.
It is also the one that won. The trade is now standard equipment, sitting behind a single flag in every major framework (PyTorch's torch.utils.checkpoint, the activation-checkpointing switches in DeepSpeed and Megatron), and it is a large part of how anyone fits a big Transformer onto the hardware they have. One thing the title does not tell you: the technique was not invented here. The core idea, and the optimal recursive scheduling that gives the logarithmic-memory result, come from the automatic-differentiation literature of the 1990s. This paper's contribution was to bring it to general deep networks, automate the planning, and measure that it pays, which is why it is the one everyone cites when they reach for it.
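The sketch below shows how little code the modern flag takes, using PyTorch's `torch.utils.checkpoint.checkpoint_sequential`. It assumes a recent PyTorch 2.x, where that function accepts the `use_reentrant` keyword. The model and its sizes are arbitrary choices for illustration. The script computes gradients twice, once with ordinary backprop and once with the chain split into four checkpointed segments, and then compares them.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
# A 16-module chain (8 x Linear + Tanh); sizes are arbitrary for illustration.
model = nn.Sequential(*[m for _ in range(8) for m in (nn.Linear(32, 32), nn.Tanh())])
x = torch.randn(4, 32)

# Plain backprop: autograd keeps every intermediate activation.
model(x).sum().backward()
plain_grads = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# Checkpointed: 4 segments; segment inputs are saved, internals recomputed in backward.
out = checkpoint_sequential(model, 4, x, use_reentrant=False)
out.sum().backward()
ckpt_grads = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(plain_grads, ckpt_grads)))
```

Only the forward call changes. The loss, the `backward()` call and the optimizer step stay the same, which is what makes the technique a drop-in.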
The argument fits in a sentence. An activation you can cheaply recompute is not worth the memory to store it. Keep $\sqrt{n}$ of them, recompute the rest, pay about a third more compute, and depth stops being a memory problem.
Questions you might still have
If you recompute activations, isn’t that just doing the work twice?
Only one extra forward pass, not double everything. A backward pass already costs about twice a forward, so a normal step is ~3 forward-units; one extra forward makes it ~4, about a third more time. You pay ~30% compute to take memory from n down to √n.
Why √n, and not something smaller?
Peak memory is n/k + k: one live segment to recompute plus k stored checkpoints. The sum is smallest when the two are equal, at k=√n. You can go below √n only by recomputing recursively, which the frontier shows costs many more forward passes for very little more memory.
Does checkpointing change the trained model?
No. Recomputed activations are the same arithmetic as the originals, so the gradients are bit-for-bit what plain backprop would compute. The paper calls it "equivalent weight gradient." It is a pure memory-for-compute trade, with no effect on the result.
Is this the same thing FlashAttention does?
Related, not the same. Both recompute in the backward pass, but FlashAttention’s core idea is IO-aware tiling to keep attention in fast on-chip memory, and it only recomputes the small softmax statistics. Checkpointing is the general version that recomputes whole segments of any network.
Footnotes & further reading
- The paper: Chen, Xu, Zhang, Guestrin, Training Deep Nets with Sublinear Memory Cost (2016). The implementation was built on MXNet, whose static allocator lets the exact feature-map memory of a plan be read off before training runs.
- The origin of checkpointing and the logarithmic-memory result: Griewank & Walther, Algorithm 799: revolve (ACM TOMS, 2000), and the general compute/memory theorem in their book Evaluating Derivatives (2008).
- Register allocation and liveness analysis, the compiler analogy: Aho, Lam, Sethi & Ullman, Compilers: Principles, Techniques, and Tools (the "Dragon Book").
- The 1,000-layer benchmark was made trainable by He et al., Identity Mappings in Deep Residual Networks (2016).
- The recompute-vs-reconstruct-vs-adjoint cousins: FlashAttention's IO-aware recomputation (Dao et al., 2022), reversible residual networks (Gomez et al., 2017), and the adjoint method of Neural ODEs (Chen et al., 2018).
- The modern incarnation: PyTorch's torch.utils.checkpoint, and the activation-checkpointing options in DeepSpeed and Megatron-LM.