
Reducing Activation Recomputation in Large Transformer Models

Cut transformer activation memory 5×, and skip most of the recompute tax.

Training a giant transformer means storing a mountain of intermediate tensors so the backward pass can use them. The usual fix throws them away and recomputes them, at a steep price. This paper shows most of that price was never necessary.

Explaining the paper: Reducing Activation Recomputation in Large Transformer Models. Korthikanti, Casper, Lym, McAfee, Andersch, Shoeybi, Catanzaro · NVIDIA · MLSys 2023 · arXiv:2205.05198

What if most of the compute you spend recomputing activations during training never needed to be spent at all?

Ask anyone who has tried to train a large model what stops them, and the answer is seldom "not enough arithmetic." It is "out of memory." And the thing that fills the GPU is rarely the weights. It is the activations: every intermediate tensor the forward pass produces that the backward pass needs in order to compute a gradient. When you run a layer forward, you hold on to its inputs, its attention scores, the output of its non-linearity, and so on, because backpropagation reaches back for every one of them to work out how to nudge the weights. That stored scratch work is the activation memory, and at scale it is enormous: it grows with the batch, the sequence length, the width, and the depth all at once, so for a model with hundreds of billions of parameters the activations dwarf the parameters themselves and blow past the 80GB an A100 gives you.

The standard escape is activation recomputation, also called gradient checkpointing: during the forward pass you deliberately forget most of those tensors, keeping only a few, and when the backward pass needs one you recompute it on the spot with an extra forward pass. It is a clean trade, memory for compute, and it has been the workhorse of large-model training for years. The trouble is the bill. Recomputing every layer's activations means running the whole forward pass twice, and in NVIDIA's training runs that adds 30–40% to wall-clock time. On a training run that takes GPU-months, paying that much extra for redundant arithmetic is not a rounding error.

This paper, from the NVIDIA team behind Megatron-LM, makes a quietly aggressive claim: most of that recomputation is unnecessary, because you can cut the memory enough that you barely have to recompute at all. It does so with two techniques that are embarrassingly simple, sequence parallelism and selective activation recomputation, layered on top of the tensor parallelism Megatron already uses. Together they cut activation memory by about 5× and recover over 90% of the time recomputation used to cost. On a 530B model across 2240 GPUs, that turns a 42.1% model-FLOPs utilization (the fraction of the GPU's peak arithmetic spent on work the model actually needs) into 54.2%, a 29% speedup with no change to the model.

To see how, we build up the argument in the order the paper does: count exactly how many bytes one transformer layer costs and notice its two very different halves, watch what tensor parallelism can and cannot divide, split the stubborn leftover along a dimension nobody was using, and finally recompute only the half that is cheap to redo. None of the pieces is hard. Stacked up, they explain the whole result.

The activation memory wall

Be precise about the word, because the whole paper turns on it. An activation here is any tensor created in the forward pass and needed for the backward pass, and nothing else: not the model weights, not the optimizer state, but, yes, the humble dropout mask, because backprop needs to know which entries were zeroed. Naming it this way carves out exactly the memory this paper attacks. Weights and optimizer state are fixed by your model and your optimizer; activations are the part you have freedom over, because you can always recompute one instead of storing it.

Why is there nothing else to lean on? The two other forms of parallelism each fail to help with activations specifically. Tensor parallelism (splitting a layer across GPUs) reduces them somewhat, as we will see, but cannot be pushed past a handful of GPUs before its communication overwhelms the gains. Pipeline parallelism (splitting the layers into groups across GPUs) looks like it should divide activation memory by the number of pipeline stages $p$, and it does not. To keep the pipeline busy and avoid idle bubbles, the first stage has to hold $p$ microbatches in flight at once, and since it owns $L/p$ layers, it stores $p \times L/p = L$ layers' worth of activations no matter how large $p$ is. So depth parallelism buys the first stage, the one that runs out of memory first, nothing on its activation budget. That is why the storage of activations, not arithmetic and not parameters, is what decides how large a model you can fit.
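A quick arithmetic check of the first-stage argument, as a minimal sketch. It assumes a 1F1B schedule in which stage `i` (counting from 0) has at most `p - i` microbatches in flight; that per-stage count is our reading of the schedule, not a figure quoted from the paper, and `layers_stored` is a hypothetical helper. With GPT-3's depth of 96 layers, stage 0 stores exactly 96 layers' worth whatever `p` is, while later stages store less.

```python
def layers_stored(L: int, p: int, stage: int) -> float:
    """Peak layers' worth of activations held by one pipeline stage.

    Assumption (1F1B schedule): stage `stage` (0-indexed) has p - stage
    microbatches in flight at its peak, each holding activations for the
    L / p layers that stage owns.
    """
    in_flight = p - stage
    layers_per_stage = L / p
    return in_flight * layers_per_stage

L = 96  # GPT-3 depth
for p in (1, 2, 4, 8, 16):
    first = layers_stored(L, p, 0)
    last = layers_stored(L, p, p - 1)
    print(f"p={p:2d}: first stage {first:5.1f} layers, last stage {last:5.1f} layers")
```

Raising `p` shrinks the last stage's budget but never the first stage's, which is why pipelining alone does not lift the activation wall.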

Counting the bytes

So count them. Take one transformer layer: a self-attention block and a two-layer MLP, each preceded by a layer norm, wired with residual connections. Write $s$ for the sequence length, $b$ for the microbatch size, $h$ for the hidden width, and $a$ for the number of attention heads. Activations are stored in 16-bit precision, so two bytes each, except dropout masks, which need only one. Walk the layer and add up what the backward pass will reach for: a matrix multiply has to keep its input, the softmax has to keep its output, the GeLU has to keep its input, each dropout keeps its one-byte mask. Doing this bookkeeping for the attention block, the MLP, and the two layer norms gives a single tidy formula:

$$\text{activation bytes per layer} = sbh\left(34 + 5\,\frac{as}{h}\right) \tag{1}$$

What matters in this formula is the shape of its two terms. The $34sbh$ is the structural part: the inputs to the big matrix multiplies, the GeLU, the layer norms, the dropout masks. It grows linearly with the sequence length $s$. The $5\,as/h$ term is written relative to that outer $sbh$; multiply the $sbh$ back through and the $h$ cancels while a second $s$ appears, leaving $5as^2b$ bytes: the attention part, the score tensors that the $QK^\top$ product, the softmax, and the attention dropout all produce. There is one $s\times s$ score matrix per head, so this part grows like $as^2$, quadratically in the sequence length. Which of the two terms dominates depends on how long your sequences are.
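The bookkeeping can be written down as a ledger. This is a sketch of our reading of the per-tensor accounting behind Equation (1): each hypothetical `LEDGER` entry gives its coefficient on $sbh$ and on $as^2b$, and the two columns sum to the 34 and the 5.

```python
# Per-layer activation ledger, in bytes, written as coefficients:
# each entry contributes  c_sbh * s*b*h  +  c_as2b * a*s*s*b  bytes.
# fp16 tensors cost 2 bytes per element; dropout masks cost 1.
LEDGER = [
    # attention block
    ("QKV projection input",             2, 0),
    ("Q and K (for the QK^T backward)",  4, 0),
    ("softmax output",                   0, 2),
    ("softmax dropout mask",             0, 1),
    ("softmax dropout output",           0, 2),
    ("V (for attention over values)",    2, 0),
    ("output projection input",          2, 0),
    ("post-attention dropout mask",      1, 0),
    # MLP block (intermediate width 4h)
    ("first MLP matmul input",           2, 0),
    ("GeLU input",                       8, 0),
    ("second MLP matmul input",          8, 0),
    ("post-MLP dropout mask",            1, 0),
    # the two layer norms
    ("layer-norm inputs (two of them)",  4, 0),
]

C_SBH = sum(c for _, c, _ in LEDGER)    # coefficient of s*b*h
C_AS2B = sum(c for _, _, c in LEDGER)   # coefficient of a*s*s*b

def layer_bytes(s: int, b: int, h: int, a: int) -> int:
    """34*sbh + 5*a*s^2*b, which is sbh * (34 + 5as/h) from Eq. (1)."""
    return C_SBH * s * b * h + C_AS2B * a * s * s * b

print(C_SBH, C_AS2B)  # 34 5
```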

Watch the crossover in the figure below. On its log axes the amber attention line has twice the slope of the teal structural line, so past a sequence length of about 870 (for GPT-3's width and head count) it overtakes and never looks back. By the trained length of 2048, attention is already about 70% of the per-layer budget, and out at long-context lengths it is nearly the whole thing. That 70% is exactly the slice selective recomputation will later throw away.

Figure 1 · the activation budget
Per-layer activation memory (GPT-3 width, one sample, fp16) split into the structural part (34sbh, linear in s) and the attention part (5as²b, quadratic in s). On log axes the attention line is twice as steep, crosses the structural line near s≈870, and dominates from there; the gap is why long context is so memory-hungry.

So a single layer of GPT-3 at sequence length 2048 costs nearly three gigabytes of activations for one sample, and a real model stacks dozens of such layers and runs a microbatch through all of them at once. That is the wall. The rest of the paper is three moves to climb it.
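The numbers in this section can be checked directly from Equation (1). A minimal sketch, assuming GPT-3's published width $h = 12288$ and $a = 96$ heads, one sample ($b = 1$), and $s = 2048$; `layer_bytes` is a hypothetical helper name.

```python
def layer_bytes(s, b, h, a):
    """Eq. (1), split into its two terms (bytes)."""
    structural = 34 * s * b * h      # linear in s
    attention = 5 * a * s * s * b    # quadratic in s
    return structural, attention

h, a = 12288, 96                     # GPT-3 175B width and head count
structural, attention = layer_bytes(s=2048, b=1, h=h, a=a)
total = structural + attention
print(f"total per layer: {total / 1e9:.2f} GB")            # 2.87 GB
print(f"attention share: {attention / total:.1%}")         # 70.2%
print(f"crossover s = 34h/(5a) = {34 * h / (5 * a):.0f}")  # 870
```

The crossover is where $5as/h = 34$, which is why it moves with the ratio $h/a$ rather than with model size alone.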

Tensor parallelism, and the floor it leaves

The first move is one Megatron already had: tensor parallelism. The idea is to split each layer's big matrix multiplies across $t$ GPUs, so each holds a slice of the weights and computes a slice of the work. Megatron does this cleverly: the MLP's first matrix is split by columns and its second by rows, so the non-linearity in between needs no communication, and attention is split across heads. One all-reduce per block in the forward pass, and one in the backward pass, stitch the slices back together. The win for us is that splitting the matrix also splits the activations inside the block: each GPU only stores its slice of the attention scores and the MLP's wide intermediate. With $t$-way tensor parallelism the per-layer cost becomes:

$$\text{activation bytes per layer} = sbh\left(10 + \frac{24}{t} + 5\,\frac{as}{ht}\right) \tag{2}$$

Look at what divides and what does not. The attention term and most of the structural work now carry a $1/t$, so adding GPUs shrinks them. But a stubborn $10sbh$ sits out front with no $t$ underneath it, and that floor is the crux of this section. It is the activation memory that tensor parallelism structurally cannot touch, and it is worth seeing exactly which tensors make it up: the two layer-norm inputs ($4sbh$), the two full, replicated inputs to the blocks (the input to the query-key-value projection and the input to the MLP's first matrix, another $4sbh$), and the two dropout masks after attention and after the MLP ($2sbh$). Add them: 4 plus 4 plus 2 is 10.

Why can't tensor parallelism just split these too? Because they live at the block boundaries, on the full hidden vector, with no big matrix multiply to chop along the hidden dimension. A dropout touches each element on its own, and a layer norm needs the mean and variance of each token's whole hidden vector; neither has a weight matrix to distribute. You could force a split along the hidden dimension anyway (this is what some other systems do), but then the layer norm's statistics have to be computed across devices, which means extra communication on an operation that is supposed to be cheap. So tensor parallelism leaves these replicated: every one of the $t$ GPUs keeps a full copy of the same layer-norm and dropout activations. Drag $t$ in the figure and watch the two divisible terms melt away while the floor stays exactly where it is. At $t=8$, the paper's setting, the floor for GPT-3 at $s = 2048$ is already 10 of the remaining $23sbh$, as large as the whole divided attention term.
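Equation (2) is easy to tabulate. A minimal sketch in units of $sbh$, assuming GPT-3's $h = 12288$, $a = 96$ and $s = 2048$ (so $5as/h = 80$); `tp_layer_terms` is a hypothetical helper. The floor's share climbs from under a tenth at $t = 1$ to a majority at $t = 16$.

```python
def tp_layer_terms(s, h, a, t):
    """Eq. (2) in units of sbh: (replicated floor, divided structural, divided attention)."""
    floor = 10                       # layer-norm inputs, block inputs, dropout masks
    structural = 24 / t
    attention = 5 * a * s / (h * t)
    return floor, structural, attention

for t in (1, 2, 4, 8, 16):
    f, st, at = tp_layer_terms(s=2048, h=12288, a=96, t=t)   # GPT-3 at s = 2048
    total = f + st + at
    print(f"t={t:2d}: {total:6.2f} sbh, floor share {f / total:4.0%}")
```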

Figure 2 · what tensor parallelism can't split
Per-layer activation memory under t-way tensor parallelism (Eq 2), in units of sbh, for GPT-3. The attention scores and the rest of the block shrink as 1/t, but the replicated floor of 10 (layer norms, dropout masks, block inputs) never divides. As t grows, the bar collapses toward the dashed floor, which then dominates.

That floor is the whole reason the next technique exists. Tensor parallelism alone is stuck: pour in more GPUs and the replicated layer norms and dropouts become the bottleneck. We need a way to split them too, without paying for a multi-device layer norm.

Splitting the leftover by sequence

The escape is a different axis. The replicated operations, the layer norms and the dropouts, share a quiet property: they act on each position in the sequence independently. A layer norm normalizes each token's vector on its own; dropout zeroes entries one at a time. Nothing in those operations couples one token to another. So instead of trying to split them across the hidden dimension, where there is nothing to split, sequence parallelism splits them across the sequence: hand each of the $t$ GPUs a different stretch of tokens, and let it store and normalize only its stretch. The floor that would not divide along the hidden dimension divides cleanly along the sequence.
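A minimal NumPy check of that independence claim. The `layer_norm` helper below is our own and omits the learned scale and bias, which act per hidden element and so can sit whole on every rank. Normalizing sequence shards separately reproduces the full layer norm exactly; normalizing hidden-dimension shards does not, because each shard computes its mean and variance from only $h/t$ features.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each token's hidden vector (last axis); no learned scale or bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
s, b, h, t = 16, 2, 8, 4
x = rng.standard_normal((s, b, h))            # layout [s, b, h]
full = layer_norm(x)

# Sequence parallelism: each of t ranks normalizes only its stretch of tokens.
seq_shards = np.split(x, t, axis=0)
by_sequence = np.concatenate([layer_norm(sh) for sh in seq_shards], axis=0)

# Naive hidden split: each rank sees h/t features per token, so its
# statistics are wrong unless the ranks communicate.
hid_shards = np.split(x, t, axis=-1)
by_hidden = np.concatenate([layer_norm(sh) for sh in hid_shards], axis=-1)

print(np.allclose(full, by_sequence))  # True
print(np.allclose(full, by_hidden))    # False
```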

The catch is that the layer now has two kinds of regions with two different splits. The attention and MLP blocks are split along the hidden dimension (that is tensor parallelism). The layer norms and dropouts are now split along the sequence. Between them you need a converter that re-shards the activation from one layout to the other. Going into a block, you must gather the full sequence back together so the matrix multiply can see all of it: an all-gather. Coming out, you must sum the parallel partial results and scatter them back into sequence shards: a reduce-scatter. Call these two converters $g$ and $\bar{g}$.

A reader doing the bookkeeping would expect this to double the communication. It does not. Plain tensor parallelism already paid for one all-reduce per block in each direction (at the block's output in the forward pass, at its input in the backward pass): the collective that sums a tensor across all $t$ GPUs and hands every GPU back the identical summed copy. A ring all-reduce, the standard fast implementation, is not a single primitive: it is a reduce-scatter followed by an all-gather. So when sequence parallelism replaces each all-reduce with a reduce-scatter and an all-gather, it is splitting the all-reduce into the two halves it was always made of, not adding a second collective on top of the first. The total number of bytes crossing the network is identical. Sequence parallelism shards the entire layer, floor included, for no extra communication volume.
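A single-process NumPy simulation makes the byte count concrete. This is our own sketch, not NCCL or Megatron code: the hypothetical `reduce_scatter` and `all_gather` helpers are the two ring phases, and chaining them yields an all-reduce in which each rank sends $2(t-1)/t$ of the buffer. Those are the same elements sent when sequence parallelism calls the two phases separately at the block boundaries.

```python
import numpy as np

def reduce_scatter(bufs):
    """Ring reduce-scatter over t simulated ranks.

    Rank r ends up owning the fully summed chunk (r + 1) % t.
    Returns (owned_chunks, elements_sent_per_rank).
    """
    t = len(bufs)
    chunks = [np.split(buf.astype(float), t) for buf in bufs]  # chunks[rank][chunk]
    sent = 0
    for step in range(t - 1):
        # Rank r sends chunk (r - step) % t to neighbour r + 1, which adds it in.
        # Build every message first: the sends happen simultaneously.
        msgs = [(r, (r - step) % t, chunks[r][(r - step) % t].copy()) for r in range(t)]
        for r, c, data in msgs:
            chunks[(r + 1) % t][c] += data
        sent += msgs[0][2].size
    return [chunks[r][(r + 1) % t] for r in range(t)], sent

def all_gather(owned):
    """Ring all-gather: rank r starts with chunk (r + 1) % t and ends with all t."""
    t = len(owned)
    have = [{(r + 1) % t: owned[r]} for r in range(t)]
    sent = 0
    for step in range(t - 1):
        # Rank r forwards the chunk it received most recently (its own at step 0).
        for r in range(t):
            c = (r + 1 - step) % t
            have[(r + 1) % t][c] = have[r][c]
        sent += owned[0].size
    return [np.concatenate([have[r][c] for c in range(t)]) for r in range(t)], sent

def ring_all_reduce(bufs):
    """An all-reduce built from exactly the two halves above."""
    owned, sent_rs = reduce_scatter(bufs)
    gathered, sent_ag = all_gather(owned)
    return gathered, sent_rs + sent_ag

rng = np.random.default_rng(0)
t, n = 4, 1024                                # n must divide evenly by t here
bufs = [rng.standard_normal(n) for _ in range(t)]
result, sent = ring_all_reduce(bufs)
expected = np.sum(bufs, axis=0)
print(all(np.allclose(r, expected) for r in result))  # True
print(sent, 2 * (t - 1) * n // t)                      # 1536 1536
```

The simulation counts elements only; as the provenance note below says, two launches instead of one still carry a little extra latency.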

Toggle the figure below to feel it. With tensor parallelism alone, the layer-norm and dropout regions show four full copies, one on every rank, the wasted replication. Turn sequence parallelism on and those regions split into four shards, the converters relabel from all-reduce to all-gather and reduce-scatter, and the per-layer memory drops with them.

Figure 3 · sequence parallelism
One transformer layer across t=4 ranks. The Attention region is split along the hidden dimension; the LayerNorm and Dropout regions are the leftover floor. Tensor parallelism alone keeps four full copies of them; add sequence parallelism and they split into four shards along the sequence. The converters (all-gather g, reduce-scatter ḡ) move the same bytes a single all-reduce did.

Sharding the floor is exact and clean: it turns the undivided $10sbh$ into $10sbh/t$, so now every term in Equation (2) carries a $1/t$, and the per-layer cost collapses to Equation (1) divided by the tensor parallel size:

$$\text{activation bytes per layer} = \frac{sbh}{t}\left(34 + 5\,\frac{as}{h}\right) \tag{4}$$
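As a worked step, sharding the floor changes only the first term of Equation (2), from $10$ to $10/t$, after which the common factor pulls out. For GPT-3 at $s = 2048$ (where $5as/h = 80$) and $t = 8$, this gives $114/8 = 14.25\,sbh$ per layer, against $23\,sbh$ under tensor parallelism alone.

$$
sbh\left(\frac{10}{t} + \frac{24}{t} + 5\,\frac{as}{ht}\right)
= \frac{sbh}{t}\left(10 + 24 + 5\,\frac{as}{h}\right)
= \frac{sbh}{t}\left(34 + 5\,\frac{as}{h}\right)
$$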

Tensor parallelism on its own divided most of the layer by $t$; sequence parallelism finishes the job and divides the rest. It even brings a small speedup, because the layer norms and dropouts now run on $1/t$ of the data on each GPU instead of on the full tensor.

Recompute the cheap half

Even divided by $t$, the activations of the largest models still do not fit, so some recomputation is unavoidable. The old way recomputed everything, paying the full 30–40% tax. The insight here is that the two halves of Equation (1) differ in how they scale, and they differ just as much in what they cost to rebuild.

Go back to the attention part, the $5as/h$ term. Those $s\times s$ score tensors are huge in memory, because there is one per head and they grow quadratically in sequence length. But the operations that produce them, the $QK^\top$ product, the softmax, the dropout, and the multiply against the values, are cheap in arithmetic: a couple of matrix multiplies whose floating-point cost is small next to the rest of the layer. They are memory-heavy and compute-light. The structural $34sbh$ part is the opposite: it is anchored by the big weight matrices, which are expensive to recompute. So instead of recomputing the whole layer, selective activation recomputation recomputes exactly the attention scores and stores everything else. You keep the queries, keys, and values (those are part of the 34), throw away the score tensors they produce, and rebuild only those in the backward pass.

The trade is lopsided in your favor, which the figure makes plain: drag the sequence length and watch the memory bar shed most of its length while the compute bar barely grows. Dropping the attention term cuts per-layer activation memory by $(5as/h)/(34 + 5as/h)$, which works out to 70% for GPT-3 and 65% for the 530B model. The compute it costs is one extra partial forward pass over just those two matrix multiplies, which raises total training FLOPs by a factor of roughly $1 + s/(6h)$, about 2.7% for GPT-3 and 1.6% for the 530B model. Two-thirds of the memory back for a couple of percent of compute. The formula is honest about the limit: at very long sequences the recompute cost climbs, because $s/(6h)$ grows with $s$, so the trade is best in the regime these models actually train in.
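A sketch of both sides of the trade, assuming the published GPT-3 ($h = 12288$, $a = 96$) and MT-NLG 530B ($h = 20480$, $a = 128$) shapes, which we take to be the paper's. The memory fractions reproduce 70% and 65%. The bare ratio $s/(6h)$ comes to about 2.8% and 1.7% at $s = 2048$; the paper's 2.7% and 1.6% are slightly lower, presumably because they are measured against the whole iteration's FLOP count. Treat the longer-sequence rows as an illustration of the trend only. `selective_tradeoff` is a hypothetical helper.

```python
MODELS = {                 # (hidden width h, attention heads a)
    "GPT-3 175B": (12288, 96),
    "MT-NLG 530B": (20480, 128),
}

def selective_tradeoff(s, h, a):
    """(fraction of Eq. (1) memory dropped, bare extra-compute ratio s/(6h))."""
    attn = 5 * a * s / h               # attention term of Eq. (1), in units of sbh
    memory_cut = attn / (34 + attn)
    extra_compute = s / (6 * h)
    return memory_cut, extra_compute

for name, (h, a) in MODELS.items():
    for s in (2048, 8192, 32768):
        cut, extra = selective_tradeoff(s, h, a)
        print(f"{name:12s} s={s:6d}: memory cut {cut:6.1%}, s/(6h) = {extra:6.1%}")
```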

Figure 4 · the lopsided trade
Selective recomputation keeps the structural 34 activations and recomputes the attention scores 5as/h (top, the dashed block is rebuilt, not stored). The compute cost (bottom) is one extra partial forward, a factor 1+s/6h. At s=2048 that is a 70% memory cut for 2.7% more compute. As s grows, the memory savings climb toward 100%, while the compute factor grows linearly with s.

Why target attention and not, say, the MLP's wide intermediate, which is also large? Because the MLP's intermediate is produced by the big $h\to 4h$ matrix multiply, which is exactly the kind of expensive arithmetic you do not want to redo. Selective recomputation is a rule about cost density, not size: recompute the activations that are large in bytes but small in FLOPs, store the ones that are cheap to keep but dear to rebuild. The attention scores are the clearest case of the former in the whole layer.

In code it is a one-line change to a normal transformer layer: wrap the attention-score core in a recompute checkpoint, which keeps its inputs and re-runs it in the backward pass, and leave everything else stored as usual.

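# Selective activation recomputation, sketched in PyTorch. Wrap ONLY the
# attention-score core in a recompute checkpoint: its inputs (Q, K, V) are
# stored, the big a×s×s score tensors are dropped and rebuilt in the
# backward pass. That is the whole change from a normal transformer layer.
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class Layer(nn.Module):
    def __init__(self, h, heads, p_drop=0.1):
        super().__init__()
        self.heads, self.p_drop = heads, p_drop
        self.ln1, self.ln2 = nn.LayerNorm(h), nn.LayerNorm(h)
        self.qkv, self.proj = nn.Linear(h, 3 * h), nn.Linear(h, h)
        self.mlp = nn.Sequential(nn.Linear(h, 4 * h), nn.GELU(), nn.Linear(4 * h, h))
        self.drop = nn.Dropout(p_drop)

    def attn_core(self, q, k, v):                     # the 5as²b memory hog
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5  # [b, a, s, s]
        scores = F.dropout(scores.softmax(dim=-1), self.p_drop, self.training)
        return scores @ v                             # attention over values

    def forward(self, x):                             # x: [s, b, h]
        s, b, h = x.shape
        q, k, v = self.qkv(self.ln1(x)).chunk(3, dim=-1)  # kept (part of the 34 sbh)
        q, k, v = (t.reshape(s, b, self.heads, -1).permute(1, 2, 0, 3) for t in (q, k, v))
        # checkpoint replays the RNG state, so the recompute draws the same dropout mask
        ctx = checkpoint(self.attn_core, q, k, v, use_reentrant=False)  # recomputed, not stored
        ctx = ctx.permute(2, 0, 1, 3).reshape(s, b, h)
        x = x + self.drop(self.proj(ctx))             # residual add
        return x + self.drop(self.mlp(self.ln2(x)))   # MLP activations kept in full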
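The checkpoint call hides the mechanics, so here they are by hand for a single head, as a NumPy sketch with dropout left out (a real implementation must also replay the same dropout mask). `attn_forward` and `attn_backward` are our own hypothetical helpers: one forward pass stores the $s\times s$ probabilities, the other stores only Q, K, V and rebuilds the probabilities during backward. The gradients agree, while the saved state shrinks from $3sd + s^2$ to $3sd$ elements.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))   # shift for numerical stability
    return e / e.sum(axis=-1, keepdims=True)

def probs(q, k):
    """The s x s attention probabilities: softmax(Q K^T / sqrt(d))."""
    return softmax(q @ k.T / np.sqrt(q.shape[-1]))

def attn_forward(q, k, v, store_scores):
    """Single-head attention core; returns the output and what backward needs."""
    p = probs(q, k)
    saved = {"q": q, "k": k, "v": v}
    if store_scores:
        saved["p"] = p                   # the s x s tensor that full storage keeps
    return p @ v, saved

def attn_backward(d_out, saved):
    """Gradients w.r.t. q, k, v; rebuilds p from q and k if it was not stored."""
    q, k, v = saved["q"], saved["k"], saved["v"]
    p = saved["p"] if "p" in saved else probs(q, k)      # selective recompute
    scale = 1 / np.sqrt(q.shape[-1])
    d_v = p.T @ d_out
    d_p = d_out @ v.T
    d_s = p * (d_p - (d_p * p).sum(axis=-1, keepdims=True))  # softmax backward
    return d_s @ k * scale, d_s.T @ q * scale, d_v

def saved_elements(saved):
    return sum(x.size for x in saved.values())

rng = np.random.default_rng(0)
s, d = 128, 16
q, k, v, d_out = (rng.standard_normal((s, d)) for _ in range(4))

_, stored = attn_forward(q, k, v, store_scores=True)
_, recomputed = attn_forward(q, k, v, store_scores=False)
g_stored = attn_backward(d_out, stored)
g_recomputed = attn_backward(d_out, recomputed)

print(all(np.allclose(x, y) for x, y in zip(g_stored, g_recomputed)))  # True
print(saved_elements(stored), saved_elements(recomputed))             # 22528 6144
```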

With sequence parallelism and selective recomputation combined, the total activation memory on the first pipeline stage, the one that has to hold $L$ layers' worth, drops to:

$$\text{total activation memory} = 34\,\frac{sbhL}{t} \tag{6}$$

The attention term is gone, which has a tidy consequence: activation memory now scales linearly with sequence length instead of quadratically, and it no longer depends on the number of attention heads at all. The single most explosive term in the budget has been spent down to a recompute that costs a couple of percent.

What it buys

Stack the moves and measure. The cleanest way to see the memory story is as a ladder against the tensor-parallel baseline. Sequence parallelism and selective recomputation each cut roughly half on their own; together they bring a layer's activation memory down to under 20% of the baseline, the headline 5× reduction. Full recomputation reaches even lower, around 10%, but that is the point of comparison the paper wants you to make: the combined techniques use only about twice the memory of full recompute while avoiding nearly all of its compute tax. Switch models in the figure and the ladder holds across every scale from 22B to a trillion parameters.
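The ladder follows from the equations above. A sketch for GPT-3 at $s = 2048$ and $t = 8$, in units of $sbh$. The full-recompute rung assumes that only each layer's fp16 input ($2sbh$, replicated across tensor-parallel ranks) is kept; that is our assumption, as is the hypothetical `ladder` helper. Sequence parallelism alone lands at about 62% of the baseline and selective recomputation alone at about 57%, which is what "roughly half" means here. Together they reach about 18.5%, roughly twice full recompute's 8.7%.

```python
def ladder(s, h, a, t):
    """Per-layer activation memory in units of sbh under t-way tensor parallelism."""
    attn = 5 * a * s / h                             # attention term of Eq. (1)
    return {
        "tensor parallel only (baseline)": 10 + 24 / t + attn / t,  # Eq. (2)
        "sequence parallel alone":         (34 + attn) / t,         # Eq. (4)
        "selective recompute alone":       10 + 24 / t,             # Eq. (2), scores dropped
        "both (present work)":             34 / t,                  # Eq. (6), per layer
        "full recompute":                  2,   # assumption: only the fp16 layer input kept
    }

rows = ladder(s=2048, h=12288, a=96, t=8)           # GPT-3 175B at t = 8
base = rows["tensor parallel only (baseline)"]
for name, cost in rows.items():
    print(f"{name:32s} {cost:6.2f} sbh  {cost / base:6.1%}")
```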

Figure 5 · the memory ladder
Per-layer activation memory as a percentage of the tensor-parallel baseline (t=8). Sequence parallelism and selective recomputation each cut about half; together (the present work) they reach ~18%, a ~5× cut, at roughly twice the memory of full recomputation but without its 30–40% compute tax.

The reason that two-times-the-memory comparison matters is that the memory was never the goal; it was the obstacle. Once you no longer have to recompute, the redundant forward pass disappears and the training runs faster. The paper measures the per-layer overhead of full recomputation at 39% and the overhead of the combined techniques at 4%, and on the full end-to-end iteration the throughput improves by 29–32% across all four model sizes. Click through the models below: the gap between the two bars is the recompute pass you stopped paying for, and the labels track the model FLOPs utilization climbing with scale to 56.3% on the trillion-parameter model.

Figure 6 · the throughput payoff
Iteration time, full recomputation versus the present work, normalized within each model. The present work runs in about 77% of the time, a 29–32% throughput gain, while reaching higher model FLOPs utilization (up to 56.3% at 1T, against the A100's 312 TFLOP/s peak).

One number deserves a closer look, because it is what makes the whole effort worthwhile. Model FLOPs utilization rises with model size and reaches 56.3% at a trillion parameters. Recomputation drags that number down because the recomputed FLOPs are real arithmetic the hardware performs but the model does not need; cutting them is most of why the present work's utilization beats full recomputation's. Run the 530B model on its full 2240-GPU configuration and it sustains 54.2% utilization, up from 42.1% with recomputation, which is the 29% end-to-end speedup the abstract opens with.
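Because model FLOPs utilization counts only the model's own FLOPs, two runs doing the same model work per iteration have throughput in the ratio of their MFUs. This short sketch turns the 530B figures quoted above into the speedup and the iteration-time ratio; the only other input is the 312 TFLOP/s A100 peak cited in the provenance notes, and `speedup_from_mfu` is a hypothetical helper.

```python
PEAK_TFLOPS = 312.0  # A100 peak, as cited in the provenance notes

def speedup_from_mfu(mfu_new, mfu_old):
    """Same model FLOPs per iteration, so throughput scales with MFU."""
    return mfu_new / mfu_old

mfu_recompute, mfu_present = 0.421, 0.542     # 530B on 2240 GPUs
gain = speedup_from_mfu(mfu_present, mfu_recompute)
print(f"speedup: {gain - 1:.1%}")                                    # 28.7%
print(f"iteration-time ratio: {1 / gain:.1%}")                       # 77.7%
print(f"model TFLOP/s per GPU: {mfu_present * PEAK_TFLOPS:.0f}")     # 169
```

The 28.7% rounds to the quoted 29%, and the 77.7% time ratio matches the "about 77%" in Figure 6.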

Step back and the argument is four facts long. Activation memory is what bounds large-model training, and it splits into a structural part that is linear in sequence length and an attention part that is quadratic. Tensor parallelism divides most of it but leaves a replicated floor of layer norms and dropouts. Sequence parallelism shards that floor along the one axis those operations leave free, for no extra communication, dividing the whole layer by the tensor parallel size. And selective recomputation spends down the single largest remaining term, the attention scores, recomputing the half that is huge in memory and cheap in arithmetic. The recomputation everyone was paying for, it turns out, was mostly avoidable. They just had not counted the bytes carefully enough to see it.

Provenance · verified against primary literature
Megatron-LM (2019) · Tensor parallelism: column/row-split attention and MLP, the f and g collectives.
Chen et al. (2016) · Activation recomputation (gradient checkpointing): trade compute to save memory.
Megatron-LM (2021) · 1F1B / interleaved pipeline parallelism and the first-stage activation pressure.
PaLM (2022) · Model- vs hardware-FLOPs utilization (MFU / HFU); A100 peak 312 TFLOP/s.
Correction · “No communication overhead” means the same bytes move, not the same wall-clock: a ring all-reduce already equals a reduce-scatter plus an all-gather, so the volume is unchanged, but two collectives carry slightly more latency than one. We say so rather than claim it is strictly free.

Questions you might still have

If activations are the problem, why not just recompute all of them?
You can, and people did. But recomputing every layer means a second full forward pass, which costs 30–40% in wall-clock time. The point of this paper is that sequence parallelism and selective recomputation cut the memory enough that you barely have to recompute at all, recovering over 90% of that tax.

Why can’t tensor parallelism split the layer norms too?
Layer norms and dropouts act on the full hidden vector with no big matrix to chop along the hidden dimension, and forcing a split there means computing the layer norm across devices (extra communication). They are independent across the sequence, so sequence parallelism splits them along that axis instead, for free.

Does sequence parallelism add communication?
No extra bytes. It replaces each all-reduce with a reduce-scatter plus an all-gather, which is exactly what a ring all-reduce already decomposes into, so the total data moved is identical. There is a little more latency from launching two collectives instead of one, but the bandwidth is the same.

Why recompute the attention scores and not the MLP?
The attention score tensors are huge in memory (one s×s matrix per head, growing quadratically with sequence length) but cheap to recompute (a couple of matrix multiplies). The MLP’s intermediate is the opposite: produced by an expensive matrix multiply. Recompute the memory-heavy, compute-light half.

Footnotes & further reading

  1. The paper: Korthikanti, Casper, Lym, McAfee, Andersch, Shoeybi, Catanzaro, Reducing Activation Recomputation in Large Transformer Models (NVIDIA, MLSys 2023). The implementation ships in Megatron-LM.
  2. Tensor (intra-layer) parallelism, the column/row matrix splits and the f and g operators: Shoeybi et al., Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.
  3. Activation recomputation / gradient checkpointing is prior art, not this paper: Chen, Xu, Zhang, Guestrin, Training Deep Nets with Sublinear Memory Cost (2016), building on Griewank & Walther. We cover it separately in the gradient checkpointing explainer.
  4. Pipeline parallelism (the 1F1B and interleaved schedules and their activation pressure): Narayanan et al., Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM, and Memory-Efficient Pipeline-Parallel DNN Training.
  5. The earlier, different "sequence parallelism" (which shards along the sequence everywhere but replicates the parameters, so it does not scale to large models): Li, Xue, Li, You, Sequence Parallelism: Long Sequence Training from System Perspective.
  6. Model- and hardware-FLOPs utilization: Chowdhery et al., PaLM: Scaling Language Modeling with Pathways. The models measured are GPT-3 and the 530B Megatron-Turing NLG.