Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
A model too big for one GPU still trains if you split each layer's matrix multiplies across many.
Cut each weight matrix the right way and the GPUs exchange data at only two points per layer; everything else runs locally on each shard. That is enough to train an 8.3-billion-parameter transformer in plain PyTorch.
By late 2019 the recipe for a better language model was to "make it bigger," and the bottleneck was no longer ideas or data. It was that the model no longer fit on a single chip.
A 32 GB GPU was the largest memory you could buy, and a few billion parameters with their optimizer state already overflowed it. The two standard ways to use more GPUs both fail here. Data parallelism gives each GPU its own copy of the entire model and a different slice of the batch, so every GPU still has to hold all of it. Splitting the model itself across GPUs had been done before, but only with custom compilers and frameworks that most people could not adopt.
Megatron-LM, from NVIDIA, makes the second option cheap. It splits the individual matrix multiplies inside each transformer layer across several GPUs, with communication inserted at exactly the points where it is unavoidable, and it does this by adding a handful of operations to an ordinary PyTorch transformer, with no new compiler and no rewrite. Those operations are all-reduces: a collective in which every GPU contributes its partial result and they all come away holding the identical sum. The paper calls the scheme intra-layer model parallelism; the field later settled on the name tensor parallelism, and it is now a permanent ingredient in training the largest models. The whole approach comes down to a few ideas: why one GPU is not enough, how to cut a single matrix multiply two different ways, why one of those ways needs no communication, and how to assemble a block, an embedding, and a deep stack out of those cuts.
The memory wall
Start with the arithmetic that forces everything else. Training with the Adam optimizer in mixed precision keeps several numbers per parameter alive at once: a half-precision weight and its gradient at two bytes each, a full-precision master copy of the weight at four bytes, and Adam's two running averages (the mean and variance of the gradient) at four bytes each. That is about 16 bytes per parameter, and it sits in memory for the entire run. (This accounting is standard mixed-precision practice rather than something Megatron introduced; the ZeRO paper later made it exact.)
Sixteen bytes per parameter is a hard floor that has nothing to do with how clever your code is. An 8.3-billion-parameter model needs about 133 GB for those tensors alone, four times what a 32 GB V100 holds, and the optimizer and master copy alone (12 bytes per parameter) already require 100 GB. Activations, the intermediate values saved during the forward pass, pile on top of that, though those are the recoverable part: activation checkpointing throws most of them away and recomputes them in the backward pass, trading compute for memory. The 16 bytes of weights and optimizer state are not recoverable that way. They have to live somewhere.
This is exactly why data parallelism cannot help. Replicating the model across more GPUs means every replica still carries the full 133 GB, so you never get under the ceiling; you only train the same too-large model on more batches at once. The only way down is to make each GPU hold a fraction of the model. Split the same model $t$ ways and each GPU's share of those 133 GB drops to roughly $133/t$ GB, so an 8-way split brings it to about 17 GB before activations, under the 32 GB line.
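The memory arithmetic above fits in a few lines of code. This is a minimal sketch with three assumptions: the 2 + 2 + 4 + 4 + 4 byte breakdown described earlier, an even split of every tensor across `t` GPUs, and no activations. The name `state_gb` is hypothetical.

```python
# Bytes per parameter under mixed-precision Adam, as described above.
BYTES_PER_PARAM = {
    "fp16 weight": 2,
    "fp16 gradient": 2,
    "fp32 master weight": 4,
    "Adam mean": 4,
    "Adam variance": 4,
}
V100_GB = 32

def state_gb(n_params: float, t: int = 1) -> float:
    """Weights + gradients + optimizer state per GPU, in GB (1e9 bytes).

    Assumes every tensor is split evenly t ways and ignores activations
    and the small replicated tensors (layer norms, biases).
    """
    return n_params * sum(BYTES_PER_PARAM.values()) / t / 1e9

n = 8.3e9
for t in (1, 2, 4, 8):
    gb = state_gb(n, t)
    print(f"t={t}: {gb:6.1f} GB per GPU, under {V100_GB} GB: {gb < V100_GB}")
```

Under these assumptions, a 4-way split still lands slightly above 32 GB for the 8.3B model (about 33.2 GB before activations). An 8-way split gives about 16.6 GB, which is consistent with the 8-way split used for this model.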
There are two ways to split a model across devices, and they are orthogonal. Pipeline parallelism cuts the stack by depth, putting whole layers on different GPUs like stages on an assembly line; it works, but the line has bubbles where some stages sit idle waiting for others. Intra-layer parallelism instead reaches inside a single layer and splits its matrix multiplies. Mesh-TensorFlow had already shown how to express that, but it required a special compiler. Megatron's contribution is to get the same split with a few all-reduce operations in native PyTorch, and the two kinds of parallelism stack, so you can use both at once. The rest of this piece is about the intra-layer half.
Splitting one matrix multiply
Everything rests on one small question. Suppose you compute $Y = \mathrm{GeLU}(XA)$: a matrix multiply followed by GeLU, the nonlinear activation a transformer applies to every entry of the result. If two GPUs are to share that work, which way do you cut the weight matrix $A$? There are two choices, and they are not symmetric. One needs the GPUs to talk before they can finish; the other does not. Get this right for both matrix multiplies in the block, and the rest of tensor parallelism follows.
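Cut $A$ by rows. Then the input $X$ splits by columns to match, and each GPU computes a partial product:

$$X = [X_1, X_2], \quad A = \begin{bmatrix} A_1 \\ A_2 \end{bmatrix}, \quad XA = X_1 A_1 + X_2 A_2.$$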
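Neither $X_1 A_1$ nor $X_2 A_2$ equals the output; you have to add them. Because GeLU is nonlinear, you cannot apply it to the pieces and add afterward. GeLU does not distribute over a sum:

$$\mathrm{GeLU}(X_1 A_1 + X_2 A_2) \neq \mathrm{GeLU}(X_1 A_1) + \mathrm{GeLU}(X_2 A_2).$$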
So the two GPUs must sum their partials, a communication step, before the nonlinearity can run. A sync point lands right in the middle of the block.
Cut $A$ by columns instead, $A = [A_1, A_2]$. Now $X$ is replicated and each GPU owns a different set of output columns:

$$[Y_1, Y_2] = [\mathrm{GeLU}(X A_1), \mathrm{GeLU}(X A_2)].$$
Here each output column depends only on its own slice of $A$, so the GeLU on the first GPU's columns never needs the second GPU's numbers. Each GPU applies the nonlinearity to its own shard, locally, and the output stays split as $Y = [Y_1, Y_2]$. No communication at all. This works because an elementwise nonlinearity commutes with a column split but not with a row split. A column split gives each GPU complete output columns, so it never hands the nonlinearity a partial sum the way a row split would.
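Both cuts can be checked numerically on one machine. In this sketch NumPy arrays stand in for the two GPUs, and an ordinary `+` stands in for the all-reduce. The shapes are arbitrary. The tanh form of GeLU is an illustrative choice; the argument does not depend on it.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GeLU (illustrative; the exact form uses erf)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6))   # batch x hidden
A = rng.standard_normal((6, 8))   # hidden x output
Y = gelu(X @ A)                   # unsplit reference

# Row cut: A = [A1; A2], X = [X1, X2]; each "GPU" holds a partial product.
X1, X2 = X[:, :3], X[:, 3:]
P1, P2 = X1 @ A[:3, :], X2 @ A[3:, :]
row_wrong = gelu(P1) + gelu(P2)   # GeLU before summing: incorrect
row_right = gelu(P1 + P2)         # sum (the all-reduce) first, then GeLU

# Column cut: A = [A1, A2], X replicated; GeLU runs locally on each shard.
Y1, Y2 = gelu(X @ A[:, :4]), gelu(X @ A[:, 4:])
col_result = np.concatenate([Y1, Y2], axis=1)

print(np.allclose(row_wrong, Y), np.allclose(row_right, Y), np.allclose(col_result, Y))
```

The row cut only matches the reference after the partial products are summed. The column cut matches with no communication at all.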
That handles a single matrix multiply. The MLP block has two of them in a row (it widens the hidden size from $h$ to $4h$ and back, with a GeLU in the middle), and the cuts compose beautifully.
A whole block, one all-reduce
Make the first matrix multiply column-parallel, so its GeLU runs locally and the output is already split into column shards. Then feed those shards straight into the second matrix multiply, cut by rows. The row cut lines up exactly with the column shards coming in, so each GPU multiplies its piece with no rearranging, and only at the very end do the partial outputs need to be summed. Two matrix multiplies, one all-reduce. The middle of the block, where a careless split would force a sync, is free.
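The same single-process simulation extends to the whole MLP block. This sketch assumes `t` divides `4h`. The names `mlp_reference` and `mlp_tensor_parallel` are hypothetical, and a Python `sum` over the per-GPU partial outputs again replaces the all-reduce.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GeLU (illustrative choice)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp_reference(X, A, B):
    return gelu(X @ A) @ B

def mlp_tensor_parallel(X, A, B, t):
    """Simulate a t-way split: A by columns, B by rows, one sum at the end."""
    A_shards = np.split(A, t, axis=1)   # column-parallel first GEMM
    B_shards = np.split(B, t, axis=0)   # row-parallel second GEMM, aligned with A's shards
    partials = []
    for A_i, B_i in zip(A_shards, B_shards):   # each iteration plays one "GPU"
        Y_i = gelu(X @ A_i)                    # local GeLU, no communication
        partials.append(Y_i @ B_i)             # full-shape partial output
    return sum(partials)                       # the single all-reduce (operator g)

rng = np.random.default_rng(1)
h, t = 8, 4
X = rng.standard_normal((5, h))
A = rng.standard_normal((h, 4 * h))
B = rng.standard_normal((4 * h, h))
print(np.allclose(mlp_tensor_parallel(X, A, B, t), mlp_reference(X, A, B)))
```

Inside the loop, each iteration touches only its own shards. The one cross-GPU step is the final sum.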
That single forward all-reduce gets a name, $g$, and it comes with a conjugate twin, $f$, that handles the backward pass. They are the only communication in the block, and they are the same operation seen from two directions:
- $g$ sits at the block's output: it is an all-reduce in the forward pass (sum the row-parallel partials) and does nothing in the backward pass.
- $f$ sits at the block's input: it does nothing in the forward pass and is an all-reduce in the backward pass (sum the gradient that fans back to the replicated input).
Why conjugates? Broadcasting one value to several GPUs in the forward direction means their gradients have to be summed back together in the backward direction; that is the chain rule applied to a copy. So $f$ and $g$ run the identical all-reduce on opposite passes, which is why each is a tiny wrapper and the second follows directly from the first:
```python
# The two communication operators (paper, Code 1), as PyTorch autograd functions.
# f and g do the SAME all-reduce, on opposite passes.
import torch
import torch.distributed as dist

class f(torch.autograd.Function):  # column-parallel input
    @staticmethod
    def forward(ctx, x): return x  # identity
    @staticmethod
    def backward(ctx, grad): dist.all_reduce(grad); return grad  # sum grads over GPUs

class g(torch.autograd.Function):  # row-parallel output
    @staticmethod
    def forward(ctx, x): dist.all_reduce(x); return x  # sum partials in place
    @staticmethod
    def backward(ctx, grad): return grad  # identity
```

The self-attention block has the same shape, and the reason is worth seeing rather than asserting: attention heads are independent. Head 3 attends over the sequence using only its own query, key, and value projections; it never reads head 5's numbers. So you column-split the combined query/key/value projection, which hands each GPU a whole set of heads. Every head's attention then computes locally, including its softmax, which runs over the un-split sequence. The output projection that recombines the heads is row-parallel, and like the MLP it needs a single all-reduce at the end. Per transformer layer, then, attention and the MLP contribute one forward all-reduce each. That makes two all-reduces in the forward pass and two in the backward pass, four in all, no matter how many GPUs you split across.
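The head argument can be checked the same way. This simplified sketch assumes no causal mask, no biases, and no dropout. It uses separate `Wq`, `Wk`, `Wv` matrices instead of one fused query/key/value matrix; splitting a fused matrix along head boundaries would give each GPU the same columns. The function names are hypothetical, and the loop over `i` stands in for the `t` GPUs.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # numerical stabilization
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def heads_out(X, Wq, Wk, Wv, head_ids, d):
    """Attention for the listed heads only; returns their outputs side by side."""
    outs = []
    for j in head_ids:
        cols = slice(j * d, (j + 1) * d)    # this head's columns of Q, K, V
        Q, K, V = X @ Wq[:, cols], X @ Wk[:, cols], X @ Wv[:, cols]
        outs.append(softmax(Q @ K.T / np.sqrt(d)) @ V)  # softmax over the full sequence
    return np.concatenate(outs, axis=1)

def attention_reference(X, Wq, Wk, Wv, Wo, n_heads):
    d = X.shape[1] // n_heads
    return heads_out(X, Wq, Wk, Wv, range(n_heads), d) @ Wo

def attention_tensor_parallel(X, Wq, Wk, Wv, Wo, n_heads, t):
    d = X.shape[1] // n_heads
    per_gpu = n_heads // t
    partials = []
    for i in range(t):                                   # each iteration plays one "GPU"
        my_heads = range(i * per_gpu, (i + 1) * per_gpu)
        local = heads_out(X, Wq, Wk, Wv, my_heads, d)    # column-parallel QKV, no comm
        rows = slice(i * per_gpu * d, (i + 1) * per_gpu * d)
        partials.append(local @ Wo[rows, :])             # row-parallel output projection
    return sum(partials)                                 # one all-reduce (operator g)

rng = np.random.default_rng(2)
s, h, n_heads, t = 6, 16, 4, 2
X = rng.standard_normal((s, h))
Wq, Wk, Wv, Wo = (rng.standard_normal((h, h)) * 0.1 for _ in range(4))
print(np.allclose(attention_tensor_parallel(X, Wq, Wk, Wv, Wo, n_heads, t),
                  attention_reference(X, Wq, Wk, Wv, Wo, n_heads)))
```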
It helps to see the actual shapes. Take the 8.3-billion-parameter model: 72 layers, hidden size 3072, 32 attention heads of dimension 96, split $t = 8$ ways. Each GPU holds a thin vertical or horizontal slice of every weight matrix and four of the 32 heads. The only data that crosses the wire is one summed activation per block:

```text
# 8.3B config: hidden h=3072, MLP 4h=12288, 32 heads (dim 96),
# split t=8 ways. Each GPU holds a thin slice; one all-reduce per block.
MLP:  A    3072x12288 -> column-split -> 3072x1536 per GPU (GeLU local)
      B    12288x3072 -> row-split    -> 1536x3072 per GPU (g: all-reduce)
Attn: QKV  3072x9216  -> column-split -> 3072x1152 per GPU, 4 of 32 heads (local)
      out  3072x3072  -> row-split    -> 384x3072 per GPU (g: all-reduce)
```

Everything else in the layer (the layer norm, the dropout, the residual add) is cheap and identical on every GPU, so Megatron recomputes it on each one rather than communicating a result. Each GPU also keeps and updates only its own slice of the weights, so there is no parameter-synchronization traffic either. Each layer runs as if it were local, punctuated by two short all-reduces in the forward pass.
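The shape bookkeeping generalizes to any configuration. This small helper has a hypothetical name and assumes the head count and `4h` both divide evenly by `t`. Called with hidden size 3072, 32 heads, and `t = 8`, it returns the per-GPU shapes for one layer of the 8.3B model.

```python
def per_gpu_shapes(h: int, n_heads: int, t: int) -> dict:
    """Per-GPU weight shapes for one transformer layer under a t-way split.

    First MLP GEMM and QKV are cut by columns; second MLP GEMM and the
    attention output projection are cut by rows.
    """
    assert (4 * h) % t == 0 and n_heads % t == 0, "split must divide evenly"
    head_dim = h // n_heads
    return {
        "mlp_A": (h, 4 * h // t),                     # column-parallel
        "mlp_B": (4 * h // t, h),                     # row-parallel
        "qkv": (h, 3 * h // t),                       # column-parallel, whole heads
        "heads_per_gpu": n_heads // t,
        "attn_out": ((n_heads // t) * head_dim, h),   # row-parallel
    }

shapes = per_gpu_shapes(h=3072, n_heads=32, t=8)
for name, shape in shapes.items():
    print(name, shape)
```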
The embedding bookends
The layers are handled. Two pieces sit outside them: the input embedding that turns tokens into vectors, and the output layer that turns the final vectors back into a probability over the vocabulary. In a language model these share one weight matrix, of shape hidden-size by vocabulary, and the vocabulary is large (GPT-2 used 50,257 tokens). That matrix is worth splitting too, and Megatron splits it along the vocabulary dimension, giving each GPU a slice of the token columns.
The output side hides a trap. After the parallel multiply, each GPU holds the logits (one raw, pre-softmax score per vocabulary token) for its own slice of the vocabulary. The naive way to compute the loss is to gather all those logits onto one GPU. That moves $b \times s \times v$ numbers: batch size times sequence length times vocabulary size. With a vocabulary in the tens of thousands, that is an enormous amount of data to ship every step.

Megatron instead fuses the parallel multiply with the cross-entropy loss itself. Cross-entropy needs only a couple of reductions over the vocabulary: the largest logit, and the sum of exponentials for the softmax denominator. It also needs the single target logit. All of those can be reduced from per-GPU partials by communicating just a handful of scalars per token. So only on the order of $b \times s$ numbers cross the wire instead of $b \times s \times v$, a reduction by roughly the vocabulary factor $v$. This is not an approximation that trades accuracy for bandwidth. The fused loss is mathematically the same as the one computed from gathered logits, differing at most by floating-point rounding.
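The fused loss can be sketched with each vocabulary shard as a separate array. This is a minimal single-process illustration, not Megatron's implementation. It assumes each shard owns a contiguous slice of the vocabulary. The three cross-shard combinations (a max, a sum of exponentials, and a sum of target logits) would each be an all-reduce on real hardware. The function names are hypothetical.

```python
import numpy as np

def full_cross_entropy(logits, targets):
    """Reference: per-token loss with all v logits in one place."""
    m = logits.max(axis=1, keepdims=True)
    lse = np.log(np.exp(logits - m).sum(axis=1)) + m[:, 0]
    return lse - logits[np.arange(len(targets)), targets]

def vocab_parallel_cross_entropy(logit_shards, targets):
    """Combine only per-token scalars across shards, never the logits themselves."""
    n_tok = len(targets)
    # 1) max-reduce of each shard's local max
    global_max = np.max([s.max(axis=1) for s in logit_shards], axis=0)
    # 2) sum-reduce of each shard's local sum of exponentials
    sum_exp = sum(np.exp(s - global_max[:, None]).sum(axis=1) for s in logit_shards)
    # 3) sum-reduce of the target logit, contributed only by the shard that owns it
    target_logit = np.zeros(n_tok)
    start = 0
    for s in logit_shards:
        end = start + s.shape[1]
        owned = (targets >= start) & (targets < end)
        local_idx = np.where(owned, targets - start, 0)
        target_logit += np.where(owned, s[np.arange(n_tok), local_idx], 0.0)
        start = end
    return np.log(sum_exp) + global_max - target_logit

rng = np.random.default_rng(3)
n_tok, v, t = 10, 64, 4          # tokens (b*s flattened), vocab size, shards
logits = rng.standard_normal((n_tok, v)) * 3
targets = rng.integers(0, v, size=n_tok)
shards = np.split(logits, t, axis=1)
print(np.allclose(vocab_parallel_cross_entropy(shards, targets),
                  full_cross_entropy(logits, targets)))
```

Only arrays of length `n_tok` are ever combined across shards. The `n_tok × v/t` logit slices never leave their owner.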
One small detail that pays off later: the vocabulary is padded from 50,257 up to 51,200, which is divisible by 1024 so that each GPU's share lands on a clean multiple of 128. That keeps the matrix multiplies at sizes the hardware runs efficiently, the same reason the head dimension was held fixed across the scaling study.
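The padding rule can be written as a ceiling division. This sketch assumes the rule is "round up to a multiple of 128 times the split factor," which reproduces 51,200 for 8 GPUs. The function name is hypothetical.

```python
def padded_vocab_size(vocab: int, t: int, multiple: int = 128) -> int:
    """Round vocab up so each of t GPUs gets a slice that is a multiple of `multiple`."""
    step = multiple * t
    return -(-vocab // step) * step   # ceiling division, then scale back up

padded = padded_vocab_size(50257, t=8)
print(padded, padded // 8, (padded // 8) % 128)   # 51200 6400 0
```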
Making a deep model converge
Splitting the math across GPUs lets a giant model fit. It does not by itself make a giant model train. Two changes turned out to be load-bearing once the models got deep, and the more interesting one is about where you put the layer normalization.
BERT had a known wall: past its largest standard size, making it bigger made it worse, not better. Megatron traced this to the placement of the layer normalization relative to the residual connection. The original transformer puts the normalization after the residual add, so the value handed to the next layer is $\mathrm{LN}(x + F(x))$, where $F$ is the sublayer (attention or the MLP). That means a normalization sits directly on the residual stream and re-scales it at every single layer. The rearranged version moves the normalization inside the residual branch, computing $x + F(\mathrm{LN}(x))$. The skip connection becomes a clean identity path from input to output, and the normalization only ever feeds the sublayer. (These two orderings are now called post-norm and pre-norm; the names and the gradient analysis are from Xiong and collaborators, who studied them concurrently. The paper also adds one final normalization before the output head, so it is a small rearrangement, not a one-line move.)
That direct identity path explains why deep models become trainable. When the residual stream is an uninterrupted sum, a gradient can reach the early layers without being rescaled by a normalization at every step. In a post-norm block, each layer's normalization multiplies the gradient on its way back. Across many layers those factors compound, so in a deep enough stack the early-layer gradients are badly scaled at initialization. The unscaled skip carries no such product. With the rearrangement, BERT's downstream accuracy goes back to improving monotonically as the model grows, all the way to 3.9 billion parameters.
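The difference between the two orderings is easiest to see with a sublayer that contributes nothing. In this sketch layer norm has no learned gain or bias. The `silent` sublayer is an illustrative device, not part of any real model. With it, a pre-norm block passes the residual stream through unchanged, while a post-norm block still re-centers and re-scales it.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # no learned gain or bias, for simplicity
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(x, sublayer):
    return layer_norm(x + sublayer(x))      # original ordering: LN on the residual stream

def pre_norm_block(x, sublayer):
    return x + sublayer(layer_norm(x))      # rearranged ordering: LN only feeds the sublayer

rng = np.random.default_rng(4)
x = rng.standard_normal((3, 8)) * 5 + 2     # a residual stream far from unit scale
silent = lambda z: np.zeros_like(z)         # a sublayer that contributes nothing

print(np.allclose(pre_norm_block(x, silent), x))    # True: identity path intact
print(np.allclose(post_norm_block(x, silent), x))   # False: stream was re-scaled
```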
The second change is a quieter one about initialization. Every block adds its output onto the residual stream, so after many layers the stream is a sum of many contributions and its variance grows with depth. The size of the fix follows from a clean cancellation. The variances of the roughly independent contributions add, so the stream's variance is inflated by about $2N$; scaling each contribution down by $1/\sqrt{2N}$ divides the variance back by $2N$ and restores it. Here $N$ is the number of transformer layers, and the 2 counts the two residual adds per layer (one from attention, one from the MLP). This is the same scale GPT-2 introduced, written a different way: GPT-2 uses $1/\sqrt{N_{\text{res}}}$, with $N_{\text{res}} = 2N$ counting residual layers directly. The base weights are drawn from a normal distribution with standard deviation 0.02; only the projections feeding the residual stream get the extra downscaling.
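The variance argument can be checked by simulation. This sketch models each residual contribution as an independent unit-variance Gaussian, the same simplification the argument makes. `init_std` and `residual_variance` are hypothetical names, and `N = 72` matches the 8.3B model's layer count.

```python
import numpy as np

BASE_STD = 0.02   # base init standard deviation, as described above

def init_std(n_layers: int, feeds_residual: bool) -> float:
    """Init std for a weight matrix; projections that add onto the residual
    stream (attention output, second MLP GEMM) get an extra 1/sqrt(2N)."""
    return BASE_STD / np.sqrt(2 * n_layers) if feeds_residual else BASE_STD

def residual_variance(n_layers: int, scale: float, trials: int = 20_000, seed: int = 0) -> float:
    """Empirical variance of a sum of 2N independent unit-variance
    contributions, each multiplied by `scale`."""
    rng = np.random.default_rng(seed)
    contributions = rng.standard_normal((trials, 2 * n_layers)) * scale
    return contributions.sum(axis=1).var()

N = 72
print(residual_variance(N, scale=1.0))                  # roughly 2N = 144
print(residual_variance(N, scale=1 / np.sqrt(2 * N)))   # roughly 1
print(init_std(N, feeds_residual=True))                 # 0.02 / 12, about 0.00167
```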
How far it scales
With the model split and converging, the question is how much of the extra hardware actually turns into useful work. The paper reports three numbers that are easily conflated, and they are worth keeping apart because they measure three different things.
Megatron deliberately uses weak scaling on the model: as GPUs are added, the model grows with them rather than the batch. That is the question that actually matters here, since the goal is to train a model that did not fit before, not to run a fixed model faster. The single-GPU baseline is already strong, sustaining 39 teraflops, about 30% of the V100's tensor-core peak. That 30% sounds low only if you read the peak as reachable. The peak counts nothing but tensor-core multiply-adds, while a full training step also spends time on memory-bound work such as layer norms, softmax, and data movement. So 30% sustained across a whole transformer is a demanding baseline to scale from.

Splitting the 8.3-billion model 8 ways with tensor parallelism alone holds 77% of perfect linear scaling. Adding 64-way data parallelism on top to reach 512 GPUs holds 74%. The whole 512-GPU run sustains 15.1 petaflops, which is 76% of what 512 independent copies of that strong single-GPU baseline would do. None of these is a peak-versus-sustained figure; they are three sustained efficiencies against three different baselines.
Two further measurements show the texture behind the headline. In strong scaling, the model is held fixed at 1.2 billion parameters and GPUs are added purely to go faster. That gives speedups of 1.64, 2.34, and 2.98 on 2, 4, and 8 GPUs: real acceleration, but with the expected diminishing returns, because each GPU's share of the work shrinks and communication takes up a growing fraction of each step. The number of attention heads also matters more than you might expect. For the 8.3-billion model split 8 ways, going from 16 to 24 to 32 heads slips efficiency from 82% to 80% to 77%. More heads mean smaller per-head matrix multiplies and a larger softmax, both of which the hardware likes less.
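The efficiency figures are simple ratios, which makes it easy to see what each one is measured against. This sketch uses hypothetical function names and plugs in the numbers quoted above.

```python
def weak_scaling_efficiency(total_flops: float, n_gpus: int, single_gpu_flops: float) -> float:
    """Sustained throughput relative to n_gpus perfect copies of the 1-GPU baseline."""
    return total_flops / (n_gpus * single_gpu_flops)

def strong_scaling_efficiency(speedup: float, n_gpus: int) -> float:
    """Speedup on a fixed-size model relative to ideal linear speedup."""
    return speedup / n_gpus

# 512 GPUs sustaining 15.1 PFLOPs against a 39 TFLOPs single-GPU baseline.
print(round(weak_scaling_efficiency(15.1e15, 512, 39e12), 3))   # about 0.756

# Strong scaling of the fixed 1.2B model, using the reported speedups.
for n, s in [(2, 1.64), (4, 2.34), (8, 2.98)]:
    print(n, strong_scaling_efficiency(s, n))
```

The 76% figure is the first ratio. The strong-scaling speedups correspond to roughly 82%, 59%, and 37% of ideal.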
Bigger models, better results
All of this engineering is in service of one bet: scale improves a model, and now you can build at scale. On the generative side, the GPT-2-style models improve monotonically with size and pass the prior state of the art. The 8.3-billion model reaches a WikiText103 test perplexity of 10.81, well under the previous best of 15.79 (lower perplexity is better; it means the model is less surprised by held-out text), and a LAMBADA cloze accuracy of 66.51%, up from 63.24%. The validation perplexity on their own corpus, a separate number, was 9.27.
On the understanding side, the rearranged BERT models climb the same way. The 3.9-billion model sets a new state of the art on the RACE reading-comprehension test: 89.5% as a single model and 90.9% as a five-model ensemble, past the prior 89.4%. MNLI, QQP, and SQuAD all improve in step.
What lasted is the method more than any benchmark number: the intra-layer split works and keeps working as you scale. GPT-3, a year later, cites this paper and partitions its model across both width and depth. Turing-NLG, at 17 billion parameters, was trained with Megatron's tensor parallelism. Megatron-Turing NLG, at 530 billion, combined it with pipeline and data parallelism into a three-dimensional split. The technique outlived its first paper. The direct follow-up went further inside the same layer, adding sequence parallelism and selective recomputation to cut the activation memory that this paper left on the table. What started as a way to fit one stubborn model on the hardware of 2019 became a standard way to train everything larger.
Questions you might still have
Is this the same as splitting a model across GPUs with pipeline parallelism?
No, and the two are complementary. Pipeline parallelism cuts the stack by depth (whole layers on different GPUs, like an assembly line with idle bubbles). Tensor parallelism, this paper, reaches inside one layer and splits its matrix multiplies. You can and do use both at once; later 530B-scale models combine tensor, pipeline, and data parallelism into a three-dimensional split.
Why not split the matrix the simple way and not worry about column versus row?
Because of the nonlinearity. If you split the first matrix by rows, each GPU has only a partial product, and since GeLU is nonlinear you cannot apply it to the pieces and add later. You would be forced to all-reduce in the middle of every block. Splitting by columns first keeps the GeLU local, and pairing it with a row split on the second matrix leaves just one all-reduce at the end.
Does fusing the loss into the output matrix multiply change the answer?
No. The fused cross-entropy computes the same loss as gathering all the logits and computing it centrally; the two differ at most by floating-point rounding. It works because the softmax needs only a maximum and a sum of exponentials over the vocabulary, plus the one target logit, and all of these can be reduced from per-GPU partials. It moves on the order of b times s scalars across the wire instead of b times s times v logits.
Why do more attention heads make scaling worse?
At a fixed model size, more heads means each head is smaller, so the per-head matrix multiplies shrink and the hardware runs them less efficiently, while the attention softmax (which grows with the number of heads) gets larger. The paper measures the slip directly: 82%, 80%, 77% efficiency at 16, 24, 32 heads on the 8.3B model.
Did Megatron invent tensor parallelism?
It invented the practical, transformer-specific version. Intra-layer tensor splitting existed in Mesh-TensorFlow, but that needed a custom compiler; the residual-init scaling came from GPT-2, activation checkpointing from Chen and collaborators, mixed precision from Micikevicius and collaborators. Megatron’s contribution is doing the intra-layer split with a few all-reduce operations in native PyTorch, and the name tensor parallelism is a retroactive one the field adopted later.
Footnotes & further reading
- The paper: Shoeybi, Patwary, Puri, LeGresley, Casper, Catanzaro, Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism (NVIDIA, 2019). Code.
- The transformer the splits operate on: Vaswani et al., Attention Is All You Need (2017), for the two-GEMM MLP and multi-head attention.
- The intra-layer idea before a compiler-free implementation: Shazeer et al., Mesh-TensorFlow (2018); and the orthogonal pipeline approach, Huang et al., GPipe (2018).
- The memory levers: Chen et al., activation checkpointing (2016), and Micikevicius et al., mixed-precision training (2017). The 16-bytes-per-parameter accounting is made exact in Rajbhandari et al., ZeRO (2019).
- The layer-norm placement: Ba et al., Layer Normalization (2016), and the post-norm vs pre-norm analysis, Xiong et al., On Layer Normalization in the Transformer Architecture (2020).
- Where it went next: BERT, one of the two model families scaled here, and GPT-3 a year later; and the direct successor, Korthikanti et al., Reducing Activation Recomputation in Large Transformer Models (2022), which adds sequence parallelism and selective recomputation inside the same layer.