Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
Re-center and re-scale every layer's inputs as the network trains, and it learns far faster.
Deep networks used to train slowly and need careful babysitting. Batch Normalization fixes the mean and variance of each layer's inputs on every mini-batch, which lets you use much higher learning rates and revives the saturating nonlinearities everyone had given up on.
Nudge one early layer and you reshuffle the numbers every later layer sees. Stop the reshuffling and a network that used to crawl takes off.
Deep networks are trained by stochastic gradient descent: show the network a mini-batch of examples, measure how wrong it is, push every weight a little in the direction that lowers the loss, and repeat a few million times. It works, and in the years before 2015 it was also finicky. You set the learning rate low, you initialized the weights with care, and even then the deepest networks trained painfully slowly. Stacking many layers made the optimization fragile in a way nobody could fully pin down.
Ioffe and Szegedy put a name to one cause. Every layer's input is the previous layer's output, so the instant you update an early layer, the distribution of numbers arriving at every layer above it moves. The layers above are chasing a target that slides around underneath them, and they waste effort re-adapting to it. The paper calls this internal covariate shift: the change in the distribution of a layer's inputs caused by training the layers below. That explanation turned out to be the most disputed claim in the paper, while the operation it motivated stuck.
Batch Normalization makes normalization a step in the network itself: standardize each layer's inputs to mean zero and variance one over the current mini-batch, then let the layer learn to scale and shift them back however it likes. A few ideas carry the rest, and each is simple alone: why a sliding input distribution stalls training, why you cannot just normalize and walk away, the exact normalize-then-rescale operation, how gradients flow through it, what has to change at test time, and why faster training at learning rates that used to blow up follows.
Each layer's inputs keep shifting
Start with the concrete damage a sliding distribution does, using the nonlinearity the paper keeps returning to. A sigmoid, $\sigma(x) = 1/(1 + e^{-x})$, squashes any input into $(0, 1)$. Its slope is $\sigma'(x) = \sigma(x)\,(1 - \sigma(x))$, which peaks at $1/4$ when $x = 0$ and falls toward zero as $x$ moves either way. A few units out from zero the slope is already down to a few thousandths. Since backpropagation multiplies by that local slope on the way down, an input sitting far out on the sigmoid's flat shoulder passes almost no gradient to the weights that produced it. Those weights stop learning. This is the saturation problem, and it is exactly what internal covariate shift triggers.
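To put numbers on the flat shoulder, here is a minimal pure-Python sketch that evaluates the sigmoid's slope $\sigma(x)(1-\sigma(x))$ at a few points. The helper names `sigmoid` and `sigmoid_slope` are illustrative choices and do not come from the paper.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid: squashes any real x into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_slope(x: float) -> float:
    """Derivative of the sigmoid: sigma(x) * (1 - sigma(x))."""
    s = sigmoid(x)
    return s * (1.0 - s)


# The slope is largest at 0 and decays quickly as |x| grows.
for x in [0.0, 1.0, 2.0, 5.0, 8.0]:
    print(f"x = {x:4.1f}   slope = {sigmoid_slope(x):.5f}")
```

The slope is exactly 0.25 at zero. By x = 5 it has fallen below 0.01, which is the few-thousandths regime where backprop passes almost nothing down.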
Walk the failure once. Early in training the pre-activations sit near zero, in the steep middle of the sigmoid, and gradients flow. A few thousand steps later the layers below have drifted the batch mean several units away from zero, and most of the batch now lands on the flat part where the slope is a few thousandths. The gradient reaching those weights is a rounding error, and that layer quietly stalls. Deeper networks make it worse, because the drift compounds layer over layer. The usual escapes (ReLU instead of sigmoid, careful initialization, a smaller learning rate) treat the symptom. Batch Normalization goes after the cause by keeping the pre-activation distribution centered, so the batch stays where the slope lives.
Drag the drift in the figure below and watch the batch climb the curve. As its mean pushes into the flat tail, the points pile up where the sigmoid is level and the gradient-signal readout falls toward zero. Then flip Batch Norm on: the batch snaps back to the steep middle no matter how far the layers below have pushed it.
So a moving input distribution is not merely untidy. When it wanders into saturation it starves whole layers of gradient, and the network grinds to a halt. Hold the distribution still and the saturating sigmoid, long written off as too hard to train deep, becomes usable again, a claim the results section gets to put a number on.
Why you can't just normalize
The obvious move is to standardize each layer's inputs from the data, before each step, as a preprocessing pass. The authors tried a stripped version of exactly that and the model blew up, with one parameter racing off to infinity. Seeing why is what forces normalization to live inside the network.
Suppose a layer adds a learned bias $b$ to its input $u$, so its value is $x = u + b$, and you normalize by subtracting the mean over the training data: $\hat{x} = x - \mathrm{E}[x]$. Now take a gradient step on $b$ that ignores the fact that $\mathrm{E}[x]$ itself depends on $b$. The step nudges $b$ by some $\Delta b$, proportional to $-\partial \ell / \partial \hat{x}$ (the negative gradient with respect to the normalized value). Recompute the normalized output afterward:

$$u + (b + \Delta b) - \mathrm{E}[u + (b + \Delta b)] = u + b - \mathrm{E}[u + b] \tag{1}$$

The $\Delta b$ cancels itself. It shifts $\mathrm{E}[x]$ by exactly the amount it shifts $x$, so the mean-subtracted output is unchanged, and so is the loss. Because the output is unchanged, the gradient $\partial \ell / \partial \hat{x}$ is also the same as before. The next step therefore nudges $b$ the same way again, and $b$ grows without bound until the arithmetic overflows. The gradient step is pushing along a direction that the normalization makes invisible.
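The runaway can be reproduced in a few lines. This is a toy sketch, not the paper's experiment. The inputs `u`, the targets, the learning rate and the squared-error loss are all made-up stand-ins. The gradient on `b` deliberately treats the mean as a constant, which is the same mistake the naive update makes.

```python
def mean(values):
    return sum(values) / len(values)


def normalize(u, b):
    """Mean-subtract x = u + b: x_hat_i = x_i - E[x] (no variance step, as in the example)."""
    x = [ui + b for ui in u]
    mu = mean(x)
    return [xi - mu for xi in x]


def loss(x_hat, target):
    """Hypothetical squared-error loss on the normalized values."""
    return 0.5 * sum((xh - t) ** 2 for xh, t in zip(x_hat, target))


u = [0.5, -1.0, 2.0, 0.3]       # fixed inputs u (made-up values)
target = [1.0, 1.0, 1.0, 1.0]   # targets summing to 4; mean-subtracted outputs always sum to 0
b, lr = 0.0, 0.1

history = []
for _ in range(50):
    x_hat = normalize(u, b)
    # Naive gradient: pretend E[x] is a constant, so d x_hat_i / d b = 1 for every i.
    grad_b = sum(xh - t for xh, t in zip(x_hat, target))
    b -= lr * grad_b
    history.append((b, loss(normalize(u, b), target)))

print(f"b after {len(history)} steps: {b:.3f}")
print(f"loss at first and last step: {history[0][1]:.6f} {history[-1][1]:.6f}")
```

After 50 steps, `b` has walked from 0 to about 20, yet every recorded loss is identical. The mean-subtracted outputs always sum to zero, so they can never match targets that sum to 4. As a result, the naive gradient keeps reporting the same constant push.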
Scrub the training steps in the figure below. In the naive mode $b$ climbs forever while the layer output and the loss sit dead flat, a useless drift that continues until the numbers explode. Switch the gradient to aware and $b$ stops moving. Once the gradient accounts for the normalization, changing $b$ provably cannot change the loss, so its gradient is exactly zero.
So normalization cannot be a bookkeeping pass bolted on between gradient updates. It has to be a layer the network differentiates through, so backprop accounts for how the mean and variance depend on the parameters. That dependence is the term the naive version dropped. Once it is included, normalizing makes the additive bias redundant (its gradient through the mean subtraction is zero, and the learned shift covers what it used to do), so batch-normalized layers simply drop the bias.
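The sketch below uses the same hypothetical toy setup but differentiates through the mean instead of treating it as a constant. The mean moves one-for-one with the bias, so the chain-rule term for `b` is exactly zero. A central finite difference of the loss, used as an independent check, agrees.

```python
def mean(values):
    return sum(values) / len(values)


def normalize(u, b):
    """Mean-subtract x = u + b over the data: x_hat_i = x_i - E[x]."""
    x = [ui + b for ui in u]
    mu = mean(x)
    return [xi - mu for xi in x]


def loss(x_hat, target):
    """Hypothetical squared-error loss on the normalized values."""
    return 0.5 * sum((xh - t) ** 2 for xh, t in zip(x_hat, target))


def aware_grad_b(u, b, target):
    """Gradient of the loss w.r.t. b, chaining through E[x] as well as x."""
    x_hat = normalize(u, b)
    dl_dxhat = [xh - t for xh, t in zip(x_hat, target)]
    dx_db = 1.0                    # each x_i = u_i + b moves one-for-one with b
    dmean_db = 1.0                 # so does E[x] = E[u] + b
    dxhat_db = dx_db - dmean_db    # = 0: the two paths cancel exactly
    return sum(g * dxhat_db for g in dl_dxhat)


def finite_diff_grad_b(u, b, target, h=1e-3):
    """Central difference of the loss in b, as an independent check."""
    up = loss(normalize(u, b + h), target)
    down = loss(normalize(u, b - h), target)
    return (up - down) / (2 * h)


u = [0.5, -1.0, 2.0, 0.3]
target = [1.0, 1.0, 1.0, 1.0]
print(aware_grad_b(u, 0.7, target), finite_diff_grad_b(u, 0.7, target))
```

Both numbers come out zero, with the finite difference zero up to rounding. This is the formal version of the claim that a batch-normalized layer can drop its bias.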
The batch-normalizing transform
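Now the actual operation, applied to a single feature (one neuron's pre-activation). Across a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$ of $m$ examples you have $m$ values of that feature. Take their mean and variance, standardize, then apply a learned scale and shift:

$$\begin{aligned}
\mu_\mathcal{B} &= \frac{1}{m}\sum_{i=1}^{m} x_i \\
\sigma_\mathcal{B}^2 &= \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2 \\
\hat{x}_i &= \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} \\
y_i &= \gamma \hat{x}_i + \beta
\end{aligned} \tag{2}$$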
Reading the four lines: $\mu_\mathcal{B}$ and $\sigma_\mathcal{B}^2$ are the mean and variance of that feature over the batch. The variance divides by $m$, the biased estimator. That choice is deliberate, and it returns in the inference section. The $\epsilon$ is a small constant added to the variance inside the square root so the division stays safe when a feature is nearly constant. The standardized values $\hat{x}_i$ have mean zero and variance one across the batch (exactly one in the limit $\epsilon \to 0$).
That last fact looks like it should hurt the network. Force a sigmoid's inputs to mean zero and variance one and you have trapped it in its near-linear middle, throwing away its ability to saturate when saturating is the right answer. So the transform does not stop at standardizing. It adds two learned parameters per feature, a scale $\gamma$ and a shift $\beta$, and the layer's real output is $y_i = \gamma \hat{x}_i + \beta$. The network can learn back any mean and spread it wants. It can even learn the identity: set

$$\gamma = \sqrt{\mathrm{Var}[x]}, \qquad \beta = \mathrm{E}[x]$$

and you recover the original pre-normalization activation exactly (neglecting $\epsilon$). So normalization costs the layer nothing in representational power: at worst it learns to undo itself, and at best it starts from a better-conditioned place. Drag $\gamma$ and $\beta$ in the figure below, or press "set to identity," and watch the output distribution detach from the fixed normalized one.
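Here is a minimal pure-Python sketch of the batch-normalizing transform for one feature, followed by the identity check. The function `bn_forward` and the sample batch are illustrative assumptions. The default `eps = 1e-5` is only a common choice, not a value taken from the paper.

```python
import math


def bn_forward(x, gamma, beta, eps=1e-5):
    """Batch-normalizing transform for one feature over a mini-batch x[0..m-1]."""
    m = len(x)
    mu = sum(x) / m                              # batch mean
    var = sum((xi - mu) ** 2 for xi in x) / m    # biased batch variance (divide by m)
    x_hat = [(xi - mu) / math.sqrt(var + eps) for xi in x]   # standardize
    y = [gamma * xh + beta for xh in x_hat]                  # learned scale and shift
    return y, x_hat, mu, var


x = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
y, x_hat, mu, var = bn_forward(x, gamma=1.0, beta=0.0)
print(mu, var)   # → 5.0 4.0

# The identity setting: gamma = sqrt(var + eps), beta = mu undoes the standardization.
eps = 1e-5
y_id, _, _, _ = bn_forward(x, gamma=math.sqrt(var + eps), beta=mu, eps=eps)
print(max(abs(a - b) for a, b in zip(y_id, x)))   # only floating-point rounding remains
```

The sample batch has mean 5 and biased variance 4. The standardized `x_hat` therefore has mean zero and a variance just under one, because `eps` shaves a hair off. The identity setting hands back the original batch to within rounding.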
Two choices in (2) are worth a sentence each, because the textbook version of "normalize" is more ambitious. The classical move is to whiten: decorrelate the features and give them unit covariance, which needs the full covariance matrix and its inverse square root over the data. That is expensive (a matrix inverse-root every step), not always differentiable, and singular whenever the batch is smaller than the number of features. Batch Normalization makes two cheap simplifications instead: normalize each feature on its own (a mean and a variance, no cross-feature terms), and estimate those statistics from the current mini-batch rather than the entire dataset. The mini-batch estimate is precisely what lets the statistics ride along in backpropagation, the non-negotiable property the runaway-bias example forced.
In a convnet, and where to insert it
Batch Normalization goes immediately before the nonlinearity and normalizes the pre-activation, so $z = g(Wu + b)$ becomes $z = g(\mathrm{BN}(Wu))$, with the bias dropped as above. It normalizes $Wu + b$, not the raw layer input $u$, on purpose. The input $u$ is the output of an earlier nonlinearity whose distribution shape keeps changing. By contrast, $Wu + b$ is a sum of many terms and tends to be smoother and more symmetric, the kind of distribution that fixing the first two moments actually helps.
For a convolutional layer you want the same normalization at every spatial position of a feature map (the convolutional property). The statistics are therefore pooled over the batch and all locations. A feature map of size $p \times q$ in a batch of $m$ gives an effective batch of $m' = m \cdot pq$ values, with one learned pair $(\gamma^{(k)}, \beta^{(k)})$ per feature map rather than per activation. That is the entire operating layer: one mean, one variance, a standardize, and a learned scale and shift, slotted in ahead of each nonlinearity.
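For the convolutional case, here is a sketch in plain nested lists indexed `x[n][c][h][w]`. The NCHW layout and the helper names `channel_stats` and `bn_conv_forward` are choices made here for illustration. The sketch shows the pooling: each channel gets one mean and one variance over the batch and every spatial position, plus a single `gamma`/`beta` pair.

```python
import math


def channel_stats(x, c):
    """Mean, biased variance and count of channel c, pooled over batch and positions."""
    vals = [x[n][c][i][j]
            for n in range(len(x))
            for i in range(len(x[0][0]))
            for j in range(len(x[0][0][0]))]
    mu = sum(vals) / len(vals)
    return mu, sum((v - mu) ** 2 for v in vals) / len(vals), len(vals)


def bn_conv_forward(x, gamma, beta, eps=1e-5):
    """BN for conv feature maps stored as nested lists x[n][c][h][w].

    One mean and variance per channel c (effective batch m' = m * p * q),
    and one (gamma[c], beta[c]) pair per channel.
    """
    N, C, H, W = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    out = [[[[0.0] * W for _ in range(H)] for _ in range(C)] for _ in range(N)]
    for c in range(C):
        mu, var, _ = channel_stats(x, c)
        scale = gamma[c] / math.sqrt(var + eps)
        for n in range(N):
            for i in range(H):
                for j in range(W):
                    out[n][c][i][j] = scale * (x[n][c][i][j] - mu) + beta[c]
    return out


# A batch of 2 examples, 3 channels, 2x2 feature maps (made-up values).
x = [[[[float(10 * c + 7 * n + 2 * i + j) for j in range(2)] for i in range(2)]
      for c in range(3)] for n in range(2)]
gamma, beta = [1.0, 2.0, 0.5], [0.0, -1.0, 3.0]
out = bn_conv_forward(x, gamma, beta)
for c in range(3):
    print(c, channel_stats(out, c))   # (mean ~ beta[c], variance ~ gamma[c]**2, count)
```

With 2 examples and 2×2 maps, each channel's statistics come from 2 · 2 · 2 = 8 values. Each channel's output lands at mean `beta[c]` with variance just under `gamma[c]**2`.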
Pushing gradients through the batch
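For the transform to live inside the network, every step in (2) has to be differentiable, including the batch statistics. The unusual part is that the output for example $i$ depends on all the other examples in the batch, through the shared $\mu_\mathcal{B}$ and $\sigma_\mathcal{B}^2$. The gradient for one example therefore picks up paths routed through those shared quantities. The chain rule gives the full set (before any simplification):

$$\begin{aligned}
\frac{\partial \ell}{\partial \hat{x}_i} &= \frac{\partial \ell}{\partial y_i}\,\gamma \\
\frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i}\,(x_i - \mu_\mathcal{B})\cdot\left(-\tfrac{1}{2}\right)(\sigma_\mathcal{B}^2 + \epsilon)^{-3/2} \\
\frac{\partial \ell}{\partial \mu_\mathcal{B}} &= \left(\sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}\right) + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2}\cdot\frac{\sum_{i=1}^{m} -2(x_i - \mu_\mathcal{B})}{m} \\
\frac{\partial \ell}{\partial x_i} &= \frac{\partial \ell}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2}\cdot\frac{2(x_i - \mu_\mathcal{B})}{m} + \frac{\partial \ell}{\partial \mu_\mathcal{B}}\cdot\frac{1}{m} \\
\frac{\partial \ell}{\partial \gamma} &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i}\,\hat{x}_i \\
\frac{\partial \ell}{\partial \beta} &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i}
\end{aligned}$$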
Reading the lines: the incoming gradient $\partial \ell / \partial y_i$ is scaled by $\gamma$ to reach $\partial \ell / \partial \hat{x}_i$. The $\sigma_\mathcal{B}^2$ and $\mu_\mathcal{B}$ lines collect how every example's normalized value shifts when the shared variance and mean shift. The gradient to the input $x_i$ sums three paths: the direct one through $\hat{x}_i$, plus the two indirect ones through the batch variance and mean. The learned $\gamma$ and $\beta$ get the obvious sums. It is the chain rule applied through a couple of averages. It gives backprop a path through the normalization itself, the property the runaway-bias argument said was required. Because these derivatives are cheap, a BN layer trains with whatever optimizer you already use: SGD, momentum, or Adagrad.
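Here is a minimal pure-Python sketch of the backward pass for one feature, written term by term from the chain rule through the batch mean and variance. It is followed by a central-difference check. The loss (a fixed weighted sum of the outputs, so that its gradient with respect to `y` is just the weights `w`) and all names are illustrative assumptions.

```python
import math


def bn_forward(x, gamma, beta, eps=1e-5):
    """Forward pass for one feature; returns outputs and a cache for backward."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]
    y = [gamma * xh + beta for xh in x_hat]
    return y, (x, x_hat, mu, var, inv_std, gamma, eps)


def bn_backward(dy, cache):
    """Backward pass, one line per chain-rule term (no algebraic simplification)."""
    x, x_hat, mu, var, inv_std, gamma, eps = cache
    m = len(x)
    dx_hat = [g * gamma for g in dy]
    dvar = sum(d * (xi - mu) for d, xi in zip(dx_hat, x)) * -0.5 * (var + eps) ** -1.5
    dmu = sum(-d * inv_std for d in dx_hat) + dvar * sum(-2.0 * (xi - mu) for xi in x) / m
    dx = [d * inv_std + dvar * 2.0 * (xi - mu) / m + dmu / m for d, xi in zip(dx_hat, x)]
    dgamma = sum(g * xh for g, xh in zip(dy, x_hat))
    dbeta = sum(dy)
    return dx, dgamma, dbeta


def loss(x, gamma, beta, w):
    """Hypothetical loss: a fixed weighted sum of the BN outputs, so dloss/dy = w."""
    y, _ = bn_forward(x, gamma, beta)
    return sum(wi * yi for wi, yi in zip(w, y))


x, gamma, beta = [1.0, 2.0, 4.0, 7.0], 1.5, 0.2
w = [0.3, -1.0, 0.5, 2.0]
_, cache = bn_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_backward(w, cache)

# Central-difference gradient of the loss with respect to each input x_i.
h = 1e-6
numeric = []
for i in range(len(x)):
    xp = list(x)
    xm = list(x)
    xp[i] += h
    xm[i] -= h
    numeric.append((loss(xp, gamma, beta, w) - loss(xm, gamma, beta, w)) / (2 * h))

print([round(a, 6) for a in dx])
print([round(a, 6) for a in numeric])
```

The analytic and numerical gradients should agree to many decimal places. The input gradients also sum to zero. Shifting every input by the same constant cannot change a mean-subtracted output, so the gradient has no component along that direction.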
Train on batches, test on the world
As defined so far, a BN layer's output for an example depends on the rest of its mini-batch, through the batch mean and variance. During training that coupling is fine, and even useful. At inference you want a deterministic answer that depends only on the input in front of you, not on whichever examples happen to be batched with it. So once training finishes, BN swaps its per-batch statistics for fixed population ones. These are estimated over the training data (in practice with moving averages kept during training) and then frozen:

$$\hat{x} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}, \qquad \mathrm{E}[x] = \mathrm{E}_\mathcal{B}[\mu_\mathcal{B}], \qquad \mathrm{Var}[x] = \frac{m}{m - 1}\,\mathrm{E}_\mathcal{B}[\sigma_\mathcal{B}^2]$$
Here $\mathrm{Var}[x]$ is the unbiased population variance: average the per-batch (biased, divide-by-$m$) variances over training batches, then multiply by $m/(m-1)$ to remove the bias. This is the deliberate asymmetry flagged back in (2). Training normalizes by the biased variance because that is the statistic the gradient was computed through, so consistency demands it. Inference wants the best estimate of the true variance, which is the unbiased one. With the statistics frozen, BN is a fixed linear map per feature, and you can fold the standardize together with $\gamma$ and $\beta$ into a single affine transform:

$$y = \frac{\gamma}{\sqrt{\mathrm{Var}[x] + \epsilon}}\cdot x + \left(\beta - \frac{\gamma\,\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}\right)$$
so at deployment BN costs essentially nothing, merging into the neighboring linear layer. In code the only thing that changes between training and inference is which mean and variance you use:
```python
import numpy as np

# batch norm for one feature, over a mini-batch x[1..m] (a 1-D NumPy array)
def batch_norm(x, gamma, beta, eps, training, run_mean, run_var, mom):
    if training:
        mu = x.mean()                  # batch mean, divide by m
        var = x.var(ddof=0)            # biased batch variance, divide by m
        run_mean = (1 - mom) * run_mean + mom * mu            # track for test
        run_var = (1 - mom) * run_var + mom * x.var(ddof=1)   # unbiased, m/(m-1)
    else:
        mu, var = run_mean, run_var    # frozen population stats
    x_hat = (x - mu) / np.sqrt(var + eps)   # standardize to mean 0, var 1
    y = gamma * x_hat + beta                # learned scale + shift
    return y, run_mean, run_var             # hand back the updated running stats
```

(Some frameworks take a small shortcut: they keep an exponential moving average of the biased batch variance and skip the correction at test time. This is a gap between paper and practice, not a mistake in either.) BN behaving differently in training and inference is by design. Its dependence on the batch during training is the source of its mild regularizing effect. The mismatch between batch statistics and frozen ones is behind the small-batch failure mode that motivated Group and Layer Norm.
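Here is a pure-Python sketch of the inference path under the paper's recipe. It averages the batch means, averages the biased batch variances and applies the m/(m-1) correction, then folds everything into one multiply and one add. The helpers `population_stats`, `bn_inference` and `fold_bn` are hypothetical, as is the tiny two-batch example. It also assumes every batch has the same size m.

```python
import math


def population_stats(batches):
    """E[x] = average batch mean; Var[x] = m/(m-1) * average biased batch variance."""
    m = len(batches[0])                      # assumes every batch has the same size m
    means, variances = [], []
    for batch in batches:
        mu = sum(batch) / m
        means.append(mu)
        variances.append(sum((v - mu) ** 2 for v in batch) / m)   # biased: divide by m
    e_x = sum(means) / len(means)
    var_x = m / (m - 1) * sum(variances) / len(variances)          # unbiased correction
    return e_x, var_x


def bn_inference(x, gamma, beta, e_x, var_x, eps=1e-5):
    """Test-time BN: standardize with frozen population statistics."""
    return gamma * (x - e_x) / math.sqrt(var_x + eps) + beta


def fold_bn(gamma, beta, e_x, var_x, eps=1e-5):
    """Collapse test-time BN into y = a * x + c."""
    a = gamma / math.sqrt(var_x + eps)
    c = beta - a * e_x
    return a, c


e_x, var_x = population_stats([[1.0, 3.0], [2.0, 6.0]])
print(e_x, var_x)   # → 3.0 5.0
a, c = fold_bn(gamma=1.5, beta=-0.5, e_x=e_x, var_x=var_x)
for x in [-2.0, 0.0, 3.5, 10.0]:
    print(x, bn_inference(x, 1.5, -0.5, e_x, var_x), a * x + c)
```

For the batches `[1, 3]` and `[2, 6]`, the batch means are 2 and 4 and the biased variances are 1 and 4. The population estimates are therefore E[x] = 3 and Var[x] = 2/1 · 2.5 = 5. If the BN layer follows a linear layer `w * u + b`, the fold extends one step further: the deployed layer is just `(a*w) * u + (a*b + c)`.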
Why it lets you train faster
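The most useful thing BN buys is a much higher learning rate without divergence. In a plain network a too-large rate makes gradients explode or vanish and the weights blow up. BN damps that through a scale invariance: multiply a layer's weights by any positive constant $a$ and the normalized output is unchanged (ignoring $\epsilon$), because standardizing divides the scale right back out:

$$\mathrm{BN}(Wu) = \mathrm{BN}((aW)u), \qquad \frac{\partial\,\mathrm{BN}((aW)u)}{\partial u} = \frac{\partial\,\mathrm{BN}(Wu)}{\partial u}, \qquad \frac{\partial\,\mathrm{BN}((aW)u)}{\partial (aW)} = \frac{1}{a}\,\frac{\partial\,\mathrm{BN}(Wu)}{\partial W}$$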
The gradient to the layer below is untouched by the weight scale, and the gradient to the weights themselves is scaled by $1/a$. So if a step makes the weights larger, the next step's gradient on them is proportionally smaller, and weight growth becomes self-limiting. That is why a batch-normalized network tolerates a learning rate five, even thirty times higher than the un-normalized one without flying apart. This particular effect survives scrutiny: later work reads it as BN automatically tuning the effective learning rate as the weight norm grows (van Laarhoven 2017; Arora et al. 2018).
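Here is a numerical check of the scale invariance, as a sketch. A two-input linear unit with no bias feeds BN, using made-up inputs and a made-up weighted-sum loss. The learned scale and shift are left at 1 and 0. `eps` is set to zero so the invariance is exact; with a positive `eps` it holds only approximately.

```python
import math


def bn(values, eps=0.0):
    """Standardize a batch of pre-activations (gamma = 1, beta = 0 for simplicity)."""
    m = len(values)
    mu = sum(values) / m
    var = sum((v - mu) ** 2 for v in values) / m
    return [(v - mu) / math.sqrt(var + eps) for v in values]


def layer_loss(w, inputs, v):
    """Hypothetical loss: weighted sum of BN(w . u_i) over the batch, with no bias."""
    pre = [w[0] * u0 + w[1] * u1 for u0, u1 in inputs]
    return sum(vi * yi for vi, yi in zip(v, bn(pre)))


def numerical_grad(f, w, h=1e-6):
    """Central-difference gradient of f at w."""
    grad = []
    for k in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[k] += h
        w_minus[k] -= h
        grad.append((f(w_plus) - f(w_minus)) / (2 * h))
    return grad


inputs = [(1.0, 0.5), (-0.3, 2.0), (0.8, -1.2), (2.2, 0.1)]   # made-up batch of u_i
v = [1.0, -0.5, 0.25, 2.0]                                      # made-up loss weights
w = [0.7, -0.4]
a = 10.0
w_scaled = [a * wk for wk in w]


def f(ww):
    return layer_loss(ww, inputs, v)


g = numerical_grad(f, w)
g_scaled = numerical_grad(f, w_scaled)
print(f(w), f(w_scaled))                  # same loss
print(g, [a * gk for gk in g_scaled])     # a * (gradient at a*w) matches gradient at w
```

Scaling `w` by 10 leaves the loss unchanged and divides each weight gradient by 10. The gradient is also orthogonal to `w` itself, because the loss cannot change along the direction of scaling. A plain gradient step therefore tends to lengthen `w`, which shrinks the size of the steps that follow.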
The paper offers a second, more theoretical reason, and it is the right place to be careful. It conjectures (the authors' word) that BN might push each layer's input-output Jacobian toward singular values near one, which would keep gradient magnitudes steady through depth. The paper itself flags this as unproven and "an area of further study," and later analysis went the other way: Yang et al. (2019) showed that at initialization BN drives those singular values away from one and can make gradients explode with depth. So treat the Jacobian story as an early guess that did not hold, not as an established mechanism.
The third reason is the one Figure 1 already made concrete: by keeping pre-activations in the steep region, BN makes saturating nonlinearities trainable again. A sigmoid network that was hopeless on its own becomes competitive once BN holds the inputs near the slope, with the ImageNet numbers below pinning down how much.
For intuition on the high-learning-rate claim, the two landscapes below are descended with one shared rate. The left, without BN, is jagged and high-curvature; the right, with BN, is a smooth bowl. Gradient descent on a quadratic stays stable only while the learning rate is below $2/\lambda_{\max}$, where $\lambda_{\max}$ is the largest curvature (the top eigenvalue of the Hessian). The jagged side therefore diverges at a rate the smooth side still rides down. Drag the rate up to see it, then push it to the very top, where both diverge, since BN raises the ceiling rather than removing it. That smooth-versus-jagged picture is the modern reading of why BN helps, picked up again once the ICS story is unwound.
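The $2/\lambda_{\max}$ threshold is easy to see in one dimension. In this sketch, `curvature` stands in for λ on the quadratic loss ½λθ². The two values 1 and 20 are hypothetical stand-ins for the smooth and jagged landscapes, not measurements from any network.

```python
def run_gd(curvature, lr, theta0=1.0, steps=100):
    """Gradient descent on the quadratic loss 0.5 * curvature * theta**2."""
    theta = theta0
    for _ in range(steps):
        grad = curvature * theta
        theta -= lr * grad            # theta <- (1 - lr * curvature) * theta
    return theta


smooth, jagged = 1.0, 20.0            # hypothetical curvatures: limits 2/1 = 2.0 and 2/20 = 0.1
for lr in [0.05, 0.5, 2.5]:
    print(lr, abs(run_gd(smooth, lr)), abs(run_gd(jagged, lr)))
```

Each step multiplies θ by 1 − lr·λ. The iterate shrinks exactly when that factor has magnitude below 1, that is, when lr < 2/λ. At lr = 0.5 the smooth bowl (limit 2.0) converges while the sharp one (limit 0.1) blows up; at lr = 2.5 both diverge.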
14× fewer steps, and past human raters
The testbed is a variant of the Inception network on ImageNet, the 1000-class image-classification benchmark, with 13.6 million parameters, trained by momentum SGD at batch size 32. Batch Normalization goes before every nonlinearity, and the rest of the architecture is held fixed so the comparison is clean.
Adding BN alone (call it BN-Baseline) reaches Inception's accuracy in under half the steps and tops out a touch higher, at 72.7% against 72.2%. Then the authors spent the headroom BN buys. They raised the learning rate 5×, removed Dropout, cut the L2 weight penalty by a factor of 5, decayed the learning rate faster, and dropped local response normalization. That model, BN-x5, reaches 72.2% in 2.1 million steps against Inception's 31.0 million. That is the abstract's 14× fewer steps, about 7% of the training. In wall-clock terms that is the difference between a run that takes two weeks and one that takes a day. Push the learning rate to 30× the original (BN-x30) and it trains a little slower at first but climbs higher, to 74.8%.
Drag the target-accuracy line in the figure below. At 72.2%, Inception's best, the horizontal gap between BN-x5 and Inception is that 14×. Raise the target and BN-x30 pulls ahead as the others flatten against their ceilings.
The saturating-sigmoid claim gets its number here too. BN-x5 with a sigmoid in place of ReLU reaches 69.8%, while Inception with a sigmoid and no BN never escapes chance, one correct guess in a thousand. That gap comes entirely from keeping the pre-activations off the flat shoulders.
Combine six BN networks and the ensemble reaches 4.9% top-5 error on the ImageNet validation set (4.82% on the held-out test set). That beats the previous best results: GoogLeNet's ensemble at 6.67%, Deep Image at 5.98%, and a contemporaneous 4.94%. The paper says this exceeds the accuracy of human raters, and that phrase deserves precision. The human number it beats is one carefully trained annotator's estimate of about 5.1% top-5 error (Russakovsky et al.). A second, less-practiced annotator scored around 12%, and the annotator who set the 5.1% figure stressed that human accuracy sits on a speed-versus-effort tradeoff. So "past human raters" means edging past one trained human's best effort, not a robust superhuman result. With that caveat, the headline holds. With the same architecture and the same parameter count, only the training changed, and the model went from competitive to state of the art while training an order of magnitude faster.
What batch norm actually does
Batch Normalization plainly works. Why it works is a separate question, and the paper's own answer, internal covariate shift, is the part that did not survive. Three years on, Santurkar and colleagues tested the ICS story directly, and it failed three ways.
First, they added noise after each BN layer, deliberately re-introducing a large, constantly changing distributional shift. Each layer's inputs ended up less stable than in a network with no normalization at all. If BN helped by stabilizing distributions, this should have wrecked it. Instead, the noisy-BN network trained almost as well as ordinary BN. Second, under a precise gradient-based definition of internal covariate shift, BN networks did not show less of it, and sometimes showed more, while still training far better. In their words, the link between ICS and optimization is tenuous at best. Third, what BN does change is the optimization landscape. It makes the loss and its gradients smoother, so a gradient stays predictive over a longer step and larger learning rates remain stable. And the mean-zero, variance-one detail is not special. Other normalizations (L1, L2, L-infinity) that do not produce unit-Gaussian activations help comparably, which points at smoothing rather than distribution-fixing. (Scope it honestly: the smoothness theorems are proved for specific settings, and the result where L1 beats BN is for deep linear networks, not general ones.)
Be careful what that does and does not say. It does not say BN is useless or its gains are imaginary; they are large and real. It does not say internal covariate shift is fictional. It says the label on the box names the wrong mechanism. The Figure 4 picture, a jagged landscape made smooth, is closer to what is happening than the moving-distribution story the title leads with.
There is a second, more practical asterisk. BN ties every example's normalization to the rest of its batch. That coupling is what gives BN its mild regularizing noise. Each example is seen alongside different neighbors every epoch, a little like Dropout, which is why the paper could remove Dropout and lose nothing. The same coupling is its weak spot. With small or skewed batches the per-batch statistics get noisy, the train-versus-inference gap widens, and accuracy falls. On ResNet-50 trained on ImageNet, the error climbs from about 23.6% at batch 32 to about 34.7% at batch 2. Batch-independent successors like Group Normalization stay flat near 24% over the same range. That single dependence on the batch is why later architectures, the Transformer above all, reach for Layer Normalization instead. Layer Normalization normalizes across a layer's features within one example, so there is no batch to depend on.
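The batch coupling, and why Layer Normalization sidesteps it, fits in a short sketch. The same example is normalized two ways inside two different made-up batches. One pass is BN-style in training mode (per feature, across the batch); the other is LN-style (per example, across its own features). The function names are illustrative, and the learned scale and shift are omitted to keep the comparison bare.

```python
import math


def standardize(values, eps=1e-5):
    """Shift to mean 0 and scale to (about) variance 1."""
    m = len(values)
    mu = sum(values) / m
    var = sum((v - mu) ** 2 for v in values) / m
    return [(v - mu) / math.sqrt(var + eps) for v in values]


def batch_norm_train(batch):
    """BN-style (training mode): each feature k is normalized across the batch."""
    num_features = len(batch[0])
    cols = [standardize([row[k] for row in batch]) for k in range(num_features)]
    return [[cols[k][i] for k in range(num_features)] for i in range(len(batch))]


def layer_norm(batch):
    """LN-style: each example is normalized across its own features."""
    return [standardize(row) for row in batch]


example = [1.0, 2.0, 6.0]
batch_a = [example, [0.0, 0.0, 0.0], [2.0, 4.0, 1.0]]
batch_b = [example, [10.0, -3.0, 5.0], [7.0, 7.0, 0.0]]

print("BN:", batch_norm_train(batch_a)[0], batch_norm_train(batch_b)[0])
print("LN:", layer_norm(batch_a)[0], layer_norm(batch_b)[0])
```

The BN output for `example` changes when its batchmates change. The LN output is bit-for-bit identical, because LN never looks outside the example.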
What survives is the operation, not the explanation. Standardize each feature over the batch, hand it back a learned scale and shift, differentiate through the batch statistics so backprop carries them, and freeze them at test time. That one layer let networks train at learning rates that used to be suicidal, revived the saturating nonlinearities everyone had abandoned, and turned a strong image model into a faster and better one without adding anything to the architecture beyond two numbers per feature. The mechanism named in the title is the part the field stopped believing. The layer it produced is in almost everything since.
Questions you might still have
If batch norm doesn’t reduce internal covariate shift, why does it work?
Later analysis (Santurkar et al. 2018) points at the optimization landscape: BN makes the loss and its gradients smoother, so a larger step still lands somewhere sensible. That lets you raise the learning rate, which is most of the speedup. The mean-0/variance-1 detail is not special; other normalizations help comparably.
Why does the layer need γ and β if normalization already standardizes everything?
Standardizing to mean 0, variance 1 would pin a sigmoid to its near-linear middle and throw away useful range. γ and β let the layer put any mean and spread back, and with γ = √Var[x], β = E[x] they recover the original activation exactly, so normalization never costs representational power.
Why use the biased variance in training but the unbiased one at test time?
During training you must normalize by the same statistic the gradient is computed through, which is the biased (÷m) batch variance. At inference you want the best estimate of the true population variance, so you use the unbiased (÷(m−1)) correction. Many frameworks instead keep a moving average of the biased variance and skip the correction; a practice gap, not an error.
Why does batch norm break with small batch sizes?
Each example’s output depends on the rest of its batch through the shared mean and variance. With few examples those statistics are noisy, and the gap between batch statistics (training) and frozen statistics (inference) widens, so accuracy drops. Batch-independent successors like Group and Layer Normalization avoid the batch entirely.
Footnotes & further reading
- The paper: Ioffe & Szegedy, Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (Google, ICML 2015).
- The mechanism reckoning: Santurkar, Tsipras, Ilyas & Madry, How Does Batch Normalization Help Optimization? (NeurIPS 2018), the noise-injection experiment and the loss-smoothing account.
- The Section 3.3 conjecture, disproved at initialization: Yang, Pennington, Rao, Sohl-Dickstein & Schoenholz, A Mean Field Theory of Batch Normalization (ICLR 2019).
- The small-batch failure and batch-free fixes: Wu & He, Group Normalization; Ioffe, Batch Renormalization; and Ba, Kiros & Hinton, Layer Normalization.
- The scale-invariance descendant, effective-learning-rate tuning: van Laarhoven, L2 Regularization versus Batch and Weight Normalization, and Arora, Li & Lyu, Theoretical Analysis of Auto Rate-Tuning by Batch Normalization.
- Foundations the paper builds on: Shimodaira (2000) for the term "covariate shift," Glorot & Bengio (2010) for Xavier initialization, and the human-rater estimate from Russakovsky et al. (2014) (the ImageNet challenge paper).