Layer Normalization
Normalize each example across its own features, not across the batch.
Batch normalization normalizes each feature using statistics computed across the whole batch, so it cannot run without a batch, and it behaves differently at training and test time. Layer normalization normalizes each example using only that example's own features, so it needs no batch, runs at any batch size, and steadies recurrent networks.
Explaining the paper: Layer Normalization
Batch normalization made deep networks train faster, then asked for something not every model can spare: a batch. Layer normalization asks only for the layer in front of it.
A year before the Transformer existed, three researchers in Toronto wrote down a one-line change to a popular training trick, aimed at the one place that trick worked badly: recurrent networks. The trick was batch normalization, which standardizes the summed input to each neuron using the mean and variance of that neuron across the cases in a mini-batch. Feedforward networks trained with it converge much faster. But the statistics come from the batch, and that single design decision drags three problems behind it.
The batch cannot be too small, because a handful of cases gives a noisy mean and variance, and at a batch of one there is nothing to average at all. Training and test do different arithmetic, because a single test example has no batch, so at test time you freeze a running average of the statistics collected during training and use that instead. And a recurrent network, which reuses the same weights at every step of a sequence whose length changes from example to example, has no clean batch to pool over: the obvious port keeps separate statistics for each time-step, which falls apart the moment a test sequence runs longer than any training one.
Layer normalization keeps the standardization and changes one thing: the axis the statistics come from. The paper's own phrase is that it transposes batch normalization. A few ideas carry it from there: why that axis swap buys training at batch size one, the gain and bias that keep normalization from costing the layer anything, the handful of transformations the result is provably blind to, and why all of it pays off most inside a recurrent network.
Batch norm needs a batch
Start with what a layer does without any normalization. The $i$-th neuron forms a summed input $a_i = w_i^\top x$ from the input $x$ arriving from the layer below, and the layer's output is $h_i = f(a_i + b_i)$ for some elementwise nonlinearity $f$. Everything normalization touches happens to those summed inputs $a_i$, before the nonlinearity.
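Batch normalization rescales each $a_i$ by the mean and standard deviation of that one neuron, measured over the data:
$$\bar a_i = \frac{g_i}{\sigma_i}\left(a_i - \mu_i\right), \qquad \mu_i = \mathbb{E}_{x \sim P(x)}\left[a_i\right], \qquad \sigma_i = \sqrt{\mathbb{E}_{x \sim P(x)}\left[\left(a_i - \mu_i\right)^2\right]}$$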
The expectations run over the whole training distribution, which you cannot compute each step, so in practice $\mu_i$ and $\sigma_i$ are estimated from the current mini-batch. That estimate ties batch normalization to the batch. Each neuron carries its own pair of statistics, pooled across the cases in the batch, and the gain $g_i$ is a learned scale applied afterward. The method earns its speedup, and on a feedforward convolutional network it remains hard to beat. The friction shows up only when the batch is small, when training and test must match, or when the model is recurrent, which is exactly the setting this paper set out to fix.
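The train-test split is easiest to see in code. Below is a minimal NumPy sketch of batch normalization over a `[N cases, H units]` batch of summed inputs. The class name `BatchNorm1D`, the momentum of 0.1, and eps = 1e-5 are illustrative assumptions, not a library API, and real implementations differ in details such as whether the running variance uses the biased estimate. Training mode reads the statistics off the mini-batch and updates running averages; test mode uses only the frozen averages.

```python
import numpy as np


class BatchNorm1D:
    """Minimal batch-norm sketch over inputs of shape [N cases, H units] (hypothetical, not a library API)."""

    def __init__(self, H, momentum=0.1, eps=1e-5):
        self.g = np.ones(H)            # learned gain, one per unit
        self.b = np.zeros(H)           # learned bias, one per unit
        self.running_mean = np.zeros(H)
        self.running_var = np.ones(H)
        self.momentum = momentum
        self.eps = eps

    def __call__(self, A, training):
        if training:
            # per-unit statistics estimated from the current mini-batch (down axis 0)
            mu, var = A.mean(0), A.var(0)
            # exponential running averages, frozen and reused at test time
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mu
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # test time: no batch statistics, only the stored estimates
            mu, var = self.running_mean, self.running_var
        return self.g * (A - mu) / np.sqrt(var + self.eps) + self.b


rng = np.random.default_rng(0)
bn = BatchNorm1D(H=4)
for _ in range(200):  # "training" on hypothetical batches of 64 cases
    bn(rng.normal(loc=2.0, scale=3.0, size=(64, 4)), training=True)
learned_mean = bn.running_mean.copy()  # close to 2
learned_var = bn.running_var.copy()    # close to 9

x = rng.normal(loc=2.0, scale=3.0, size=(1, 4))  # one test example
test_out = bn(x, training=False)  # uses the running averages
train_out = bn(x, training=True)  # batch of one: variance 0, every unit collapses to b
```

Calling the same example in training mode shows why a batch of one fails: each unit's batch mean is the value itself, the variance is zero, and the output collapses to the bias.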
Normalize the other axis
Picture the summed inputs of a layer as a grid: one row per neuron, one column per training case in the batch. Batch normalization takes a row and averages it across the columns. Layer normalization takes a column and averages it across the rows, computing one mean and one standard deviation from all $H$ units of the layer, on a single case:
$$\mu = \frac{1}{H}\sum_{i=1}^{H} a_i, \qquad \sigma = \sqrt{\frac{1}{H}\sum_{i=1}^{H}\left(a_i - \mu\right)^2}$$
Every unit in the layer now shares the same $\mu$ and $\sigma$, and a different training case gets a different $\mu$ and $\sigma$. Nothing in this calculation looks at any other example, so the batch axis drops out of the math entirely. Two consequences follow at once. Layer normalization runs at a batch of one, which makes it usable for online learning and for models too large to fit more than a single case in memory. And it does identical arithmetic at training and at test, because there are no batch statistics to estimate during training and freeze for later, so every forward pass computes the same thing.
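A short NumPy sketch of that claim, assuming H = 6 units and an illustrative eps of 1e-5 (the function name `layer_norm` is ours): normalizing a whole mini-batch and normalizing one case on its own give the same answer for that case, so there is nothing to store for test time.

```python
import numpy as np


def layer_norm(A, g, b, eps=1e-5):
    """Layer norm over the last axis of A, shape [N cases, H units]."""
    mu = A.mean(-1, keepdims=True)   # one mean per case, over its H units
    var = A.var(-1, keepdims=True)   # biased variance (divide by H)
    return g * (A - mu) / np.sqrt(var + eps) + b


rng = np.random.default_rng(0)
A = rng.normal(size=(32, 6))         # hypothetical summed inputs: 32 cases, 6 units
g, b = np.ones(6), np.zeros(6)

full = layer_norm(A, g, b)           # whole mini-batch, as during training
single = layer_norm(A[3:4], g, b)    # one case alone, a batch of one, as at test time
print(np.allclose(full[3], single[0]))
```

The rows never interact, so the batch size only changes how many cases are processed at once, not what any one of them gets.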
The flip is easy to picture as grading on a curve. Batch normalization curves each exam question against the whole class, one adjustment per question, pooled over students. Layer normalization curves each student against the spread of their own subject scores, one adjustment per student, pooled over subjects. The curving is the same operation applied along the other axis, which is also the literal sense in which one is the transpose of the other: batch norm reduces along the batch dimension, layer norm along the feature dimension. In code the only difference is the axis you reduce over:
```python
import numpy as np

# x: a mini-batch of summed inputs, shape [N cases, H units]
x = np.random.default_rng(0).normal(size=(8, 4))  # example data: N=8, H=4
# batch norm: average each feature DOWN the batch (axis 0)
bn = (x - x.mean(0)) / x.std(0)
# layer norm: average each case ACROSS its units (axis 1)
ln = (x - x.mean(1, keepdims=True)) / x.std(1, keepdims=True)
```

Layer norm's group is one case's $H$ units, so it does not change as the batch size $N$ shrinks. Batch norm's group is one unit's values across the $N$ cases, and at $N = 1$ that group holds a single number, so its variance is zero and there is nothing left to normalize (in the code above, with $N = 1$, `x.std(0)` would be zero and the division would return NaN).
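The transpose claim can be checked directly. In this NumPy sketch (the function names and eps = 1e-5 are illustrative choices), batch norm applied to the transposed matrix and transposed back reproduces layer norm, and batch norm on a single case returns all zeros because each feature's variance over one case is zero.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    # statistics per feature (column), pooled DOWN the batch axis 0
    return (x - x.mean(0)) / np.sqrt(x.var(0) + eps)


def layer_norm(x, eps=1e-5):
    # statistics per case (row), pooled ACROSS its own features, axis 1
    return (x - x.mean(1, keepdims=True)) / np.sqrt(x.var(1, keepdims=True) + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 7))                           # hypothetical [N=5 cases, H=7 units]
print(np.allclose(layer_norm(x), batch_norm(x.T).T))  # same op, other axis
print(batch_norm(x[:1]))                              # N = 1: every entry is 0
print(layer_norm(x[:1]).std())                        # close to 1: one case is enough
```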
What the layer computes
Standardizing pins every layer's summed inputs to mean zero and unit variance, which would throw away any scale or offset the layer actually wanted. So the same kind of per-unit gain and bias that batch normalization uses is put back here, after the normalization and before the nonlinearity:
$$h_i = f\left(\frac{g_i}{\sigma}\left(a_i - \mu\right) + b_i\right)$$
If the operation stopped at $\hat a_i = (a_i - \mu)/\sigma$, it would force every layer's pre-activations into the same fixed shape and remove two degrees of freedom the network might need. The gain and bias hand exactly those two back, per unit, so normalization costs the layer nothing it cannot recover. (Weight normalization, a third method this paper compares against, slots into the same template: set $\mu = 0$ and give each unit its own $\sigma_i = \lVert w_i \rVert_2$, the norm of its incoming weights.)
By default the gain starts at $g = 1$ and the bias at $b = 0$, so layer normalization begins life as a pure standardizer and the network learns from there how much scale and shift to restore. (The paper's appendix states this initialization backwards in one sentence, in a way that would zero every activation. The version that works, and the one everyone uses, is gain one, bias zero. See the note in Provenance below.)
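All three methods fit one template, $h_i = f\big(\tfrac{g_i}{\sigma_i}(a_i - \mu_i) + b_i\big)$, and differ only in where $\mu_i$ and $\sigma_i$ come from. The NumPy sketch below makes that concrete. The helper name `normalized_layer`, the choice of `tanh` as the default $f$, and the layer sizes are illustrative assumptions; the gain and bias use the default initialization of one and zero, and an identity $f$ is passed so the normalized pre-activations can be inspected.

```python
import numpy as np


def normalized_layer(A, mu, sigma, g, b, f=np.tanh):
    """Shared template h_i = f(g_i / sigma_i * (a_i - mu_i) + b_i); A has shape [N, H]."""
    return f(g / sigma * (A - mu) + b)


def identity(z):
    return z  # skip f to inspect the pre-activations


rng = np.random.default_rng(0)
N, D, H = 5, 3, 4                  # hypothetical sizes: cases, inputs, units
X = rng.normal(size=(N, D))
W = rng.normal(size=(H, D))        # row i holds unit i's incoming weights w_i
A = X @ W.T                        # summed inputs a_i = w_i . x, shape [N, H]
g, b = np.ones(H), np.zeros(H)     # default init: gain 1, bias 0

# batch norm: mu_i, sigma_i per unit, pooled over the N cases (axis 0)
bn = normalized_layer(A, A.mean(0), A.std(0), g, b, identity)
# layer norm: one mu, sigma per case, pooled over the H units and shared by all of them
ln = normalized_layer(A, A.mean(1, keepdims=True), A.std(1, keepdims=True), g, b, identity)
# weight norm: mu_i = 0, sigma_i = ||w_i||, no data statistics at all
wn = normalized_layer(A, 0.0, np.linalg.norm(W, axis=1), g, b, identity)
```

Dropping the `identity` argument applies the default `np.tanh` and gives the full layer output.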
A concrete pass makes the shapes unambiguous. Take a four-unit layer whose summed inputs on one case are . The mean is and the standard deviation is , so the standardized vector is , which has mean zero and unit variance by construction. At the default gain and bias that vector goes straight into $f$. The forward pass is six lines:
```python
import numpy as np

# assumed setup (illustrative values): one case with H = 4 units, tanh for f
a = np.array([0.5, -1.0, 2.0, 1.5])
g, b, eps, f = np.ones(4), np.zeros(4), 1e-5, np.tanh

# layer norm on one case's summed inputs a (shape [H])
mu = a.mean()                          # average over the H units
var = ((a - mu) ** 2).mean()           # biased: divide by H, not H-1
a_hat = (a - mu) / np.sqrt(var + eps)  # standardize: mean 0, variance close to 1
h = f(g * a_hat + b)                   # learned gain g, bias b, then f
# init g=1, b=0 -> h = f(a_hat)
```

Two details in that listing matter. The variance divides by $H$, the biased estimator, not by $H-1$. And the printed equations in the paper carry no $\epsilon$ inside the square root, while every real implementation adds a small one to keep the division safe when a layer's units happen to agree. At the default $g = 1$, $b = 0$ the gain and bias leave the standardized vector untouched; any other setting reshapes the output's spread and center while $\hat a$ itself stays fixed.
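The same arithmetic with round numbers of our own choosing (the vector below is a hypothetical example picked so the results come out exact, and eps is omitted for the same reason), followed by a non-default gain and bias:

```python
import numpy as np

a = np.array([1.0, 1.0, 5.0, 5.0])  # hypothetical summed inputs, H = 4
mu = a.mean()                       # (1 + 1 + 5 + 5) / 4 = 3.0
var = ((a - mu) ** 2).mean()        # (4 + 4 + 4 + 4) / 4 = 4.0, dividing by H
a_hat = (a - mu) / np.sqrt(var)     # standardized: [-1, -1, 1, 1]
print(mu, var, a_hat)               # 3.0 4.0 [-1. -1.  1.  1.]

# NumPy's default var is this same biased (ddof=0) estimate; ddof=1 divides by H - 1
print(a.var(), a.var(ddof=1))       # 4.0 versus 16/3

# a non-default gain and bias rescale the spread and shift the center
g, b = 2.0, 0.5                     # hypothetical learned values
z = g * a_hat + b                   # [-1.5, -1.5, 2.5, 2.5]
print(z.mean(), z.std())            # 0.5 2.0
```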
What it leaves unchanged
The paper's real contribution is a catalogue of changes the formula is blind to. If you can perturb the weights or the inputs in some way and the output does not move, then learning cannot be thrown off by drift of that kind. The three methods land differently:
| Invariant to | Weight matrix re-scale | Weight matrix re-center | Weight vector re-scale | Dataset re-scale | Dataset re-center | Single case re-scale |
|---|---|---|---|---|---|---|
| Batch norm | yes | no | yes | yes | yes | no |
| Weight norm | yes | no | yes | no | no | no |
| Layer norm | yes | yes | no | yes | no | yes |
Two of layer norm's entries are its own, and both repay a closer look. Take the first, invariance to re-scaling a single training case. Multiply one example by $\delta > 0$, and every summed input $a_i$ scales by $\delta$ too. So do the mean and the standard deviation, and the factor cancels in the ratio:
$$\frac{\delta a_i - \delta\mu}{\delta\sigma} = \frac{a_i - \mu}{\sigma}$$
The model becomes blind to a per-example change of scale, the way a camera auto-exposing each shot makes a dim photo and a bright copy of the same scene come out matched. Batch normalization has no such property, because its mean and variance are pooled across the batch, not read off the single case in front of it. As $\delta$ grows or shrinks, the raw summed inputs run off to large values or collapse toward zero, while the layer-normalized values never move.
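A quick numerical check of the per-case invariance, under illustrative assumptions (random data, a 4-unit layer, $\delta = 50$ applied to the first case only, eps = 1e-5). Layer norm's outputs stay put, while batch norm's outputs change even for the cases that were left alone, because the rescaled case drags the pooled statistics.

```python
import numpy as np


def layer_norm(A, eps=1e-5):
    # statistics per case (row), over that case's H units
    mu = A.mean(1, keepdims=True)
    return (A - mu) / np.sqrt(A.var(1, keepdims=True) + eps)


def batch_norm(A, eps=1e-5):
    # statistics per unit (column), pooled over the N cases in the batch
    return (A - A.mean(0)) / np.sqrt(A.var(0) + eps)


rng = np.random.default_rng(0)
X = rng.normal(size=(6, 3))      # hypothetical batch: 6 cases, 3 inputs
W = rng.normal(size=(4, 3))      # hypothetical layer: 4 units
X_scaled = X.copy()
X_scaled[0] *= 50.0              # re-scale only the first case, delta = 50

A, A_scaled = X @ W.T, X_scaled @ W.T
ln_same = np.allclose(layer_norm(A), layer_norm(A_scaled), atol=1e-3)
# compare batch norm on the cases that were NOT re-scaled (rows 1 onward)
bn_same = np.allclose(batch_norm(A)[1:], batch_norm(A_scaled)[1:], atol=1e-3)
print(ln_same, bn_same)          # layer norm unchanged; batch norm shifted
```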
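Now the re-centering invariance. Shift every row of $W$ by one common vector $\gamma$ (and scale the matrix by $\delta > 0$ while you are at it), so $W' = \delta W + \mathbf{1}\gamma^\top$. That added term lands the same scalar $\gamma^\top x$ on every neuron's summed input. Centering subtracts exactly that shared amount, and the $\delta$ cancels against $\sigma' = \delta\sigma$ as before, so the layer's output is unchanged:
$$h_i' = f\left(\frac{g_i}{\delta\sigma}\left(\delta a_i + \gamma^\top x - \delta\mu - \gamma^\top x\right) + b_i\right) = f\left(\frac{g_i}{\sigma}\left(a_i - \mu\right) + b_i\right) = h_i$$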
Layer norm has one subtlety that is easy to get backwards. It is invariant to re-scaling the whole matrix but not to re-scaling a single weight vector, the incoming weights of one neuron, where both batch norm and weight norm are invariant. Scaling one row changes that neuron's $a_i$ relative to its neighbors, which moves the $\mu$ and $\sigma$ the whole layer shares, so the output does shift. (The paper's prose at one point says all three methods are invariant to re-scaling the dataset, which contradicts its own table for weight norm. The table is the version to trust.)
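Both halves of that can be checked in a few lines. The sketch below uses arbitrary illustrative values for $W$, $x$, $\delta$, and $\gamma$, and omits eps so the invariance is exact. It confirms that layer norm ignores $W' = \delta W + \mathbf{1}\gamma^\top$ but not a rescaling of one neuron's weight vector.

```python
import numpy as np


def layer_norm(a):
    # eps omitted so the matrix invariances hold exactly
    return (a - a.mean()) / a.std()


rng = np.random.default_rng(1)
W = rng.normal(size=(4, 3))       # hypothetical weights, row i = w_i
x = rng.normal(size=3)
delta, gamma = 3.0, rng.normal(size=3)

W_shifted = delta * W + np.outer(np.ones(4), gamma)  # W' = delta W + 1 gamma^T
W_one_row = W.copy()
W_one_row[0] *= 3.0                                   # re-scale only w_1

base = layer_norm(W @ x)
print(np.allclose(base, layer_norm(W_shifted @ x)))   # invariant
print(np.allclose(base, layer_norm(W_one_row @ x)))   # not invariant
```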
Scale and the geometry
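Invariance says the output does not change. The paper goes one step further and asks how the $\sigma$ in the denominator changes learning, even when two parameterizations express the same function. The idea goes back to Amari: measure a parameter step $\Delta$ by how much it moves the model's output distribution, the KL divergence (how far apart two distributions are) between before and after, rather than by its raw length in parameter space. To second order that distance is a quadratic form in the Fisher information matrix $F$:
$$\mathbb{E}_{x}\,\mathrm{KL}\big[P(y \mid x;\theta)\,\big\|\,P(y \mid x;\theta+\Delta)\big] \approx \tfrac{1}{2}\,\Delta^\top F(\theta)\,\Delta, \qquad F(\theta) = \mathbb{E}_{x,\,y}\Big[\nabla_\theta \log P(y \mid x;\theta)\,\nabla_\theta \log P(y \mid x;\theta)^\top\Big]$$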
The Fisher matrix is the curvature of that KL distance, the local curvature of the function the network computes in the coordinates of its own parameters. What matters for normalization: the $\sigma$ dividing each neuron also scales the curvature along that neuron's weight vector. Let a weight vector's norm grow and $\sigma$ grows with it, so the same gradient step moves the function less. The network gets an automatic, per-direction learning-rate decay: a direction that has already grown large takes smaller and smaller effective steps, an implicit early stopping that the paper argues damps runaway weights and stabilizes convergence.
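The second-order KL formula can be checked numerically on the simplest model with a closed-form Fisher matrix. This sketch assumes a logistic-regression unit, $P(y=1 \mid x) = \mathrm{sigmoid}(\theta^\top x)$, for which $F(\theta) = \mathbb{E}_x[p(1-p)\,xx^\top]$; the data, $\theta$, and the step are made up. For a small step the averaged KL and $\tfrac12\Delta^\top F\Delta$ agree closely.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def kl_bernoulli(p, q):
    # KL[Bernoulli(p) || Bernoulli(q)], elementwise
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))


rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 3))          # hypothetical inputs x ~ P(x)
theta = np.array([0.5, -1.0, 2.0])      # hypothetical parameters
step = 1e-3 * rng.normal(size=3)        # a small parameter step Delta

p = sigmoid(X @ theta)
q = sigmoid(X @ (theta + step))
kl = kl_bernoulli(p, q).mean()          # KL between output distributions, averaged over x

# Fisher information for this model: E_x[p (1 - p) x x^T]
F = (X * (p * (1 - p))[:, None]).T @ X / len(X)
quad = 0.5 * step @ F @ step

print(kl, quad, abs(kl - quad) / quad)  # relative gap is small for small steps
```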
Two limits sit on this argument. The paper's clean "curvature halves when the norm doubles" is illustrative, not a theorem: it assumes $\sigma$ grows exactly in step with the weight norm and holds the rest of the layer fixed, neither of which is quite true. And the tidy conclusion that a neuron's gain can be learned from the prediction error alone, decoupled from the scale of the input, holds for the generalized linear model the analysis assumes, a single linear layer feeding a fixed nonlinearity, not as an exact statement about a deep nonlinear stack. As geometry under simplifying assumptions it explains a real effect; as a precise law it is hedged on purpose.
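The effective-learning-rate effect does hold exactly in one special case: scaling an entire weight matrix, where $\sigma$ grows in step with it and layer norm's output is unchanged. The sketch below (random illustrative $W$, $x$, and step, with eps omitted so the invariance is exact) applies the same raw step to $W$ and to $10W$. The function is identical at both points, but the step moves the normalized output about ten times less at the larger scale.

```python
import numpy as np


def layer_norm(a):
    # eps omitted so scaling the whole matrix leaves the output exactly unchanged
    return (a - a.mean()) / a.std()


rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))          # hypothetical layer: 4 units, 3 inputs
x = rng.normal(size=3)
dW = 1e-4 * rng.normal(size=(4, 3))  # one fixed raw parameter step


def output_change(scale):
    Ws = scale * W                   # same function for every scale > 0
    return np.linalg.norm(layer_norm((Ws + dW) @ x) - layer_norm(Ws @ x))


print(np.allclose(layer_norm(W @ x), layer_norm(10 * W @ x)))  # same function
print(output_change(1.0) / output_change(10.0))                # close to 10
```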
Why RNNs gain most
A recurrent network reuses the same weights at every step, so whatever the recurrence does to the scale of the hidden state, it does again the next step, and the next. If the typical magnitude of the summed inputs is multiplied by a little more than one each step, it grows without bound over a long sequence; a little less, and it decays toward zero. The hidden state's scale drifts, and the gradients riding along it tend to explode or vanish with it. (The link to the gradient is a heuristic one, not an identity: what actually governs exploding and vanishing gradients is the singular values of the recurrent matrix, and controlling the forward scale only steadies the forward pass directly.)
Layer norm's invariance to re-scaling all of a layer's summed inputs pins that scale at every step. It re-standardizes the hidden vector to a fixed size set by the gain, whatever magnitude it arrived with, keeping the direction of the centered vector and resetting its length. It works like the automatic gain control on a microphone, which lifts each frame to a target level so the downstream chain always sees a steady signal. One gain and one bias are shared across all time-steps, so nothing per-step has to be stored and the parameter count does not grow with sequence length, which is also why a longer test sequence is no problem. Recurrent batch normalization, by contrast, keeps separate statistics per step and breaks past the longest sequence it trained on.
Sweep the per-step factor across one and the contrast is stark. Below one the un-normalized magnitude sinks toward zero; above one it grows without bound; the layer-normalized magnitude holds flat through both, because every step resets the scale.
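A toy simulation of that contrast, under deliberately simple assumptions: a hypothetical 8-unit recurrence $a_t = rQh_{t-1}$ with $Q$ a random orthogonal matrix, so each step multiplies the state's length by exactly $r$, with no input term and no nonlinearity. Without normalization the length goes as $r^t$; with layer norm (gain 1, bias 0, eps = 1e-5) it stays near $\sqrt{H}$, because every standardized vector has unit variance across its $H$ units.

```python
import numpy as np


def layer_norm(a, g=1.0, b=0.0, eps=1e-5):
    mu = a.mean()
    return g * (a - mu) / np.sqrt(((a - mu) ** 2).mean() + eps) + b


def state_norms(r, normalize, steps=50, H=8, seed=0):
    rng = np.random.default_rng(seed)
    Q, _ = np.linalg.qr(rng.normal(size=(H, H)))  # orthogonal: preserves length
    h = rng.normal(size=H)
    norms = []
    for _ in range(steps):
        a = r * (Q @ h)                            # summed inputs: length scaled by exactly r
        h = layer_norm(a) if normalize else a
        norms.append(np.linalg.norm(h))
    return np.array(norms)


for r in (0.9, 1.1):
    raw = state_norms(r, normalize=False)
    ln = state_norms(r, normalize=True)
    # raw[-1] / raw[0] equals r ** 49; the normalized norms stay near sqrt(8)
    print(r, raw[-1] / raw[0], ln.min(), ln.max())
```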
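Wiring it into an LSTM takes a little care, and the appendix is specific about it. The two projections that feed the gates are normalized separately, each with its own gain and bias, written here as $\mathrm{LN}(\cdot\,;\alpha,\beta)$ with gain $\alpha$ and bias $\beta$ (and $\operatorname{sigm}$ for the logistic sigmoid):
$$\begin{pmatrix} f_t \\ i_t \\ o_t \\ g_t \end{pmatrix} = \mathrm{LN}\left(W_h h_{t-1};\, \alpha_1, \beta_1\right) + \mathrm{LN}\left(W_x x_t;\, \alpha_2, \beta_2\right) + b$$
$$c_t = \operatorname{sigm}(f_t) \odot c_{t-1} + \operatorname{sigm}(i_t) \odot \tanh(g_t), \qquad h_t = \operatorname{sigm}(o_t) \odot \tanh\big(\mathrm{LN}(c_t;\, \alpha_3, \beta_3)\big)$$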
Keeping a separate layer norm on each projection, rather than one on their sum, lets each input stream carry its own scale into the gates, and a third layer norm wraps the cell state on the way out. The gated recurrent unit gets the same treatment. Long sequences and small batches, the regime where these models are usually trained, are exactly where batch norm struggles and layer norm has the most to offer.
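One LSTM step with layer normalization, as a NumPy sketch. Assumptions worth flagging: the gate order (f, i, o, g), the weight shapes, and the 0.1 weight scale are our own choices; each four-gate projection is normalized as a single 4H-dimensional vector; the names `init_params` and `ln_lstm_step` are hypothetical; gains start at 1 and biases at 0. Because every statistic is read from one case, a batch gives the same per-case result as running each case alone.

```python
import numpy as np


def layer_norm(z, gain, bias, eps=1e-5):
    # statistics over the last axis, separately for each case
    mu = z.mean(-1, keepdims=True)
    var = z.var(-1, keepdims=True)
    return gain * (z - mu) / np.sqrt(var + eps) + bias


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_params(D, H, rng):
    return {
        "W_x": 0.1 * rng.normal(size=(D, 4 * H)),    # input projection
        "W_h": 0.1 * rng.normal(size=(H, 4 * H)),    # recurrent projection
        "b": np.zeros(4 * H),                         # shared gate bias
        "g1": np.ones(4 * H), "b1": np.zeros(4 * H),  # LN on the recurrent projection
        "g2": np.ones(4 * H), "b2": np.zeros(4 * H),  # LN on the input projection
        "g3": np.ones(H), "b3": np.zeros(H),          # LN on the cell state
    }


def ln_lstm_step(x, h_prev, c_prev, p):
    # normalize each projection separately, then add them with the bias
    z = (layer_norm(h_prev @ p["W_h"], p["g1"], p["b1"])
         + layer_norm(x @ p["W_x"], p["g2"], p["b2"])
         + p["b"])
    f, i, o, g = np.split(z, 4, axis=-1)  # four gate pre-activations
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(layer_norm(c, p["g3"], p["b3"]))  # third LN on the cell
    return h, c


rng = np.random.default_rng(0)
D, H, N = 3, 5, 4
p = init_params(D, H, rng)
xs = rng.normal(size=(N, D))
h0 = rng.normal(size=(N, H))
c0 = rng.normal(size=(N, H))
h, c = ln_lstm_step(xs, h0, c0, p)                # a batch of N cases
h1, c1 = ln_lstm_step(xs[:1], h0[:1], c0[:1], p)  # the first case alone
print(np.allclose(h[0], h1[0]), np.allclose(c[0], c1[0]))
```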
What it buys
The experiments lean on recurrent models, and the headline is consistent across them: faster convergence, and usually a better endpoint too, not a trade between the two. On an image-and-sentence ranking model (order embeddings on MSCOCO), layer norm reaches its best validation score in of the time the baseline takes, and lifts the final numbers rather than just arriving sooner (caption recall@1 from to , image recall@1 from to ). On a question-answering attentive reader it beats both the baseline and recurrent batch normalization, and where recurrent BN needed its gain carefully initialized to to work at all, layer norm was insensitive to that choice, with the natural gain of $1$ (the pass-through default) performing best.
The pattern repeats on a generative model (the DRAW attention model on binarized MNIST converges almost twice as fast and reaches against the baseline's nats of test negative log-likelihood, where lower is better), on unsupervised sentence representations, and on handwriting generation, where the sequences run around 700 steps and a batch of 8 makes stable dynamics matter most. A feedforward check on permutation-invariant MNIST shows the batch-size story directly: layer norm held its convergence as the batch dropped from 128 to 4, where batch norm degrades. One design detail there points back at the invariances: layer norm is applied to the hidden layers but deliberately not to the output, because the scale of the output logits encodes the model's confidence, and layer norm's per-case rescale invariance would erase exactly that.
On convolutional networks the result flips: layer norm beats no normalization but loses to batch norm. The authors give the reason: in a fully connected or recurrent layer the units contribute similarly, so one shared mean and variance fits them, but in a convolutional layer the units whose receptive fields sit near the image boundary fire rarely and have very different statistics, which a single shared statistic models poorly. "Further research is needed," they wrote, and that remains roughly the state of things.
One caveat on the motivation, not the results. The paper, following the batch normalization paper, frames normalization as reducing "covariate shift," the correlated drift in a layer's inputs as the layers below it learn. That explanation was later contested: Santurkar and colleagues argued in 2018 that batch normalization helps less by reducing covariate shift than by smoothing the loss landscape, and showed you can inject covariate shift back in and still train fast. The dispute is about the why, and that evidence concerns batch norm, not layer norm specifically. Either way it leaves layer norm's concrete wins, the invariances and the steadier recurrent dynamics, standing on their own.
What came after
Layer normalization was written for recurrent networks, and a year later the Transformer made it inescapable, since it is the normalization in every Transformer block. Two things about that modern usage are not in this paper. The placement: the original Transformer put the norm after the residual addition (Post-LN), while putting it inside the residual branch instead (Pre-LN) trains more stably and is now the default, a refinement that came years later. And the simplification: RMSNorm drops the mean subtraction, and usually the bias, dividing only by the root-mean-square of the vector. LLaMA and Mistral use it, on the wager that the re-centering was not earning its cost. Adding the same offset to every component of a vector isolates what the mean subtraction was doing: layer norm removes the offset, and RMSNorm does not.
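The difference comes down to one subtraction. In this NumPy sketch (gains omitted, eps = 1e-6 as an illustrative choice, input vector arbitrary), adding a constant offset to every component leaves layer norm's output unchanged but changes RMSNorm's, while both still ignore a positive rescaling.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(((x - mu) ** 2).mean(-1, keepdims=True) + eps)


def rms_norm(x, eps=1e-6):
    # RMSNorm: no mean subtraction, divide by the root-mean-square (gain omitted)
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)


x = np.array([1.0, -2.0, 0.5, 3.0])
shifted = x + 5.0                                      # same offset on every component

print(np.allclose(layer_norm(x), layer_norm(shifted)))  # offset removed
print(np.allclose(rms_norm(x), rms_norm(shifted)))      # offset not removed
print(np.allclose(rms_norm(x), rms_norm(3 * x)))        # scale still ignored
```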
One more difference trips up readers who learned layer norm from a Transformer. The 2016 paper normalizes the pre-activation summed inputs and then applies the nonlinearity, $h = f(g \odot \hat a + b)$. The modern version (PyTorch's nn.LayerNorm, the Transformer residual stream) normalizes the activation vector itself with no nonlinearity following it. Same statistics, computed at a slightly different spot in the layer.
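The placements differ only in where the same normalization sits. Below is a NumPy sketch contrasting the 2016 form $f(\mathrm{LN}(Wx))$ with the two Transformer residual arrangements. The sublayer is a hypothetical stand-in (a fixed small linear map), gains and biases are left at their defaults, and all names are ours. One visible consequence: with the sublayer switched off, Pre-LN passes its input through untouched, while Post-LN still normalizes it.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)


def paper_2016(W, x, f=np.tanh):
    # 2016 form: normalize the summed inputs a = W x, then apply the nonlinearity
    return f(layer_norm(W @ x))


def post_ln(x, sublayer):
    # Post-LN (original Transformer): normalize after the residual addition
    return layer_norm(x + sublayer(x))


def pre_ln(x, sublayer):
    # Pre-LN: normalize inside the residual branch; the skip path is left untouched
    return x + sublayer(layer_norm(x))


rng = np.random.default_rng(0)
x = rng.normal(size=6)               # hypothetical residual-stream vector
M = 0.1 * rng.normal(size=(6, 6))    # hypothetical sublayer weights
W = rng.normal(size=(6, 4))          # hypothetical 2016-style layer: 4 inputs, 6 units
v = rng.normal(size=4)


def linear(u):
    return M @ u                     # stand-in sublayer


def off(u):
    return np.zeros_like(u)          # sublayer switched off


h = paper_2016(W, v)                     # nonlinearity follows the normalization
print(np.allclose(pre_ln(x, off), x))    # identity path preserved
print(np.allclose(post_ln(x, off), x))   # the stream itself is re-normalized
```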
It all reduces to a choice of axis. Average a neuron down the batch and you get batch normalization, and everything the batch brings with it: a size floor, a train-test split, a bad fit to recurrence. Average a case across its own units and the batch drops out of the math, training and test agree, and a recurrent network stops drifting. The gain, the bias, the invariances, and the geometry all follow from that one decision, to read the statistics off the layer in front of you instead of the batch around you.
Questions you might still have
Why not just keep using batch normalization?
Batch norm reads its mean and variance from the batch, so it needs a reasonably large one, does different arithmetic at training and test (frozen running averages), and fits recurrent networks badly. Layer norm reads the statistics from one case’s own units, which removes all three constraints.
Does it really work at batch size 1?
Yes. The mean and standard deviation come from the H units of the layer on a single case, so a batch of one is enough. That also makes the train and test computations identical, since there is nothing batch-dependent to freeze.
Where do the gain and bias go, and how are they set?
After standardizing and before the nonlinearity: h = f(g·â + b), per unit. They start at gain 1, bias 0, so layer norm begins as a pure standardizer. (The paper’s appendix states this init backwards in one sentence; the correct version is gain 1, bias 0.)
Is this the same LayerNorm as in Transformers?
Same statistics, but this paper is from 2016 and is about RNNs. The Transformer (2017) used Post-LN; Pre-LN came later and is now standard; RMSNorm (2019, used in LLaMA and Mistral) drops the mean. None of those variants is in the original paper.
Does it reduce “internal covariate shift”?
That was the inherited motivation, and it is contested. Santurkar et al. (2018) argued batch norm helps mainly by smoothing the loss landscape, not by reducing covariate shift. Layer norm’s solid story is its invariances and its steadier recurrent dynamics.
Footnotes & further reading
- The paper: Ba, Kiros, Hinton, Layer Normalization (University of Toronto, 2016).
- The method it transposes: Ioffe & Szegedy, Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (2015).
- The third method compared: Salimans & Kingma, Weight Normalization (2016); and recurrent batch norm: Cooijmans et al., Recurrent Batch Normalization (2016).
- The geometry: Amari, Natural Gradient Works Efficiently in Learning (1998). The covariate-shift dispute: Santurkar et al., How Does Batch Normalization Help Optimization? (2018), and the layer-norm-specific follow-up, Xu et al., Understanding and Improving Layer Normalization (2019).
- Later developments: Vaswani et al., Attention Is All You Need (2017); the Pre-LN analysis, Xiong et al., On Layer Normalization in the Transformer Architecture (2020); and RMSNorm, Zhang & Sennrich, Root Mean Square Layer Normalization (2019), as used in LLaMA.