
Batch Normalization, Explained

Abhay · 5 min read
Photo by Anthony Roberts on Unsplash

Picture a relay race where, between laps, someone keeps secretly moving the handoff zone. Each runner trains for one spot, shows up, and finds the goalposts have quietly moved. That, roughly, is the headache Batch Normalization was invented to cure. Since Sergey Ioffe and Christian Szegedy introduced it in their 2015 paper, it has become one of the most indispensable tricks in deep learning.

The problem: layers keep moving the goalposts

A neural network is a stack of layers, each feeding the next. During training, every layer’s weights are being updated simultaneously. So the moment layer 3 learns something, the distribution of numbers it sends to layer 4 shifts — and layer 4, which had carefully tuned itself to yesterday’s inputs, now has to re-adjust. Multiply that across dozens of layers and millions of updates, and training becomes a game of whack-a-mole.

The original paper called this internal covariate shift: the inputs to each layer keep drifting as the layers below them learn. (Later research has poked holes in this exact explanation — more on that shortly — but the practical fix works regardless.) The symptom is familiar to anyone who has trained a deep net the hard way: you have to use a timid learning rate and obsess over weight initialization, or the whole thing diverges into NaN.

The fix: normalize, then let the network undo it

Batch Norm’s idea is disarmingly simple. For each mini-batch, take the inputs to a layer and standardize them: subtract the batch’s mean, divide by its standard deviation. Now every feature arrives roughly centered at zero with unit variance, no matter how chaotic the layer below has become. The goalposts stop wandering.
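Here is that standardization step written out for a single feature over a mini-batch of m values x_1, …, x_m, following the notation of the original paper. The small constant ε also comes from the paper. It keeps the division safe when a feature happens to have zero variance within a batch.

```latex
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i,
\qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2,
\qquad
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
```

As a worked example, the batch 1, 2, 3 has μ_B = 2 and σ_B² = 2/3. Ignoring ε, it normalizes to roughly −1.22, 0, 1.22.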

But there’s a clever twist. Forcing every layer’s inputs to be zero-mean and unit-variance might throw away useful information — maybe some layer wants its inputs scaled large or shifted off-center. So Batch Norm adds two learnable parameters per feature: a scale (gamma, γ) and a shift (beta, β). After normalizing, it computes γ · x̂ + β. The network can learn to stretch, squash, or even completely reverse the normalization if that’s what minimizes the loss. You give it a clean slate; it decides what to do with it.
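A from-scratch sketch of the full training-mode forward pass makes both steps concrete. It is plain Python, not any framework's implementation. The helper names feature_stats and batch_norm are illustrative, and eps=1e-5 is an assumed default. The last lines check the "undo" claim: setting γ to each feature's batch standard deviation and β to its batch mean hands back the original inputs.

```python
import math


def feature_stats(batch):
    """Per-feature mean and biased variance (divide by m) over the batch."""
    m = len(batch)
    columns = list(zip(*batch))  # transpose: one tuple of values per feature
    means = [sum(col) / m for col in columns]
    variances = [sum((x - mu) ** 2 for x in col) / m
                 for col, mu in zip(columns, means)]
    return means, variances


def batch_norm(batch, gamma, beta, eps=1e-5):
    """Training-mode Batch Norm: standardize each feature, then scale and shift."""
    means, variances = feature_stats(batch)
    return [
        [g * (x - mu) / math.sqrt(var + eps) + b
         for x, mu, var, g, b in zip(row, means, variances, gamma, beta)]
        for row in batch
    ]


batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 60.0]]  # 3 examples x 2 features

# gamma=1, beta=0: plain standardization (each feature ~zero mean, unit variance).
y = batch_norm(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])
print(y)

# gamma=sqrt(var + eps), beta=mean: the layer undoes its own normalization.
means, variances = feature_stats(batch)
restored = batch_norm(batch,
                      gamma=[math.sqrt(v + 1e-5) for v in variances],
                      beta=means)
print(restored)  # the original batch again, up to floating-point rounding
```

In a real layer, gamma and beta start at 1 and 0 and are updated by gradient descent like any other weight. They are passed in by hand here only to show the two extremes.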

Why it actually helps

The benefits show up fast and on multiple fronts:

  • Faster training. Networks converge in far fewer epochs. In the original paper, a Batch Norm version of a state-of-the-art image classifier matched the baseline's accuracy with 14 times fewer training steps.
  • Higher learning rates. With inputs kept in a sane range, you can crank the learning rate up without the gradients exploding, and a bigger learning rate typically means faster convergence.
  • Less fussing over initialization. Batch Norm is forgiving about how you initialize weights, which removes a whole category of dark-arts tuning.
  • Mild regularization. Because each example’s normalization depends on the random batch it landed in, there’s a little noise injected at every step — a gentle, dropout-like regularizing effect (sometimes enough to reduce how much dropout you need).
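Two of these effects show up with a bare-bones normalizer (γ = 1, β = 0; the function name normalize is illustrative). First, the same input value lands on opposite sides of zero depending on its batch-mates, which is the source of the regularizing noise. Second, scaling the preceding layer's weight by a positive constant leaves the normalized output essentially unchanged. The original paper points to this scale invariance as one reason Batch Norm tolerates higher learning rates and careless initialization.

```python
import math


def normalize(values, eps=1e-5):
    """Standardize one feature across a mini-batch (Batch Norm with gamma=1, beta=0)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


# Batch-dependent noise: the input 2.0 normalizes differently depending on
# which other examples share its mini-batch.
print(normalize([2.0, 0.0, 1.0])[0])  # ≈ 1.22 (largest value in its batch)
print(normalize([2.0, 5.0, 8.0])[0])  # ≈ -1.22 (smallest value in its batch)

# Scale invariance: multiplying the previous layer's weight by 100 scales its
# outputs by 100, yet the normalized values barely change (only eps differs).
inputs = [1.0, 2.0, 3.0, 4.0]
w = 0.5  # illustrative weight of a hypothetical one-unit linear layer
small = normalize([w * x for x in inputs])
large = normalize([100 * w * x for x in inputs])
print(small)
print(large)
```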

Worth noting: a well-known 2018 MIT study (Santurkar et al., “How Does Batch Normalization Help Optimization?”) argued the real reason Batch Norm works isn’t covariate shift at all, but that it smooths the loss landscape, making the gradients more predictable. The mechanism is debated; the usefulness isn’t.

Train vs. inference: the part everyone trips on

Here’s the gotcha. During training, Batch Norm uses the statistics of the current mini-batch. But at inference time you might be predicting on a single example — and “the mean of one sample” is a nonsensical, unstable thing to normalize by. Worse, you want predictions to be deterministic: the same input should always give the same output, not depend on whatever else happened to be in the batch.
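To see why batch statistics break down for a single example, here is a minimal training-mode normalizer for one feature. The name batch_norm_train is illustrative, not a library function. With one example, the batch mean is the example itself and the variance is zero, so every input collapses to β. With two examples, the same input's output flips sign depending on its partner.

```python
import math


def batch_norm_train(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode Batch Norm for one feature, using the batch's own statistics."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]


# A "batch" of one: the mean is the example itself and the variance is 0,
# so every input, however different, collapses to beta (0.0 here).
for x in [-3.0, 0.5, 42.0]:
    print(x, "->", batch_norm_train([x]))  # always [0.0]

# Same input, different batch-mate, opposite output.
print(batch_norm_train([42.0, 0.0])[0])    # ≈ 1.0
print(batch_norm_train([42.0, 100.0])[0])  # ≈ -1.0
```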

The solution: during training, Batch Norm keeps a running average (in practice, an exponential moving average) of the mean and variance it sees across batches. At inference, it quietly switches to those frozen population statistics instead of the batch’s own. This is exactly why model.eval() in PyTorch (or training=False in Keras) matters so much. Forget it, and your model will keep normalizing test data with per-batch statistics, and keep updating its running averages as it goes, producing baffling results.
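Here is a sketch of that bookkeeping for a single feature. The update rule and defaults follow PyTorch's BatchNorm documentation: running = (1 − momentum) · running + momentum · batch statistic, with momentum = 0.1, and the unbiased batch variance feeds the running variance. The class BatchNorm1Feature and its training flag are hypothetical stand-ins for a real layer and for model.train()/model.eval(). Keras defines momentum the other way round (default 0.99, weighting the old value), so check your framework's convention.

```python
import math


class BatchNorm1Feature:
    """Hypothetical single-feature Batch Norm with train/eval modes (a sketch)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum, self.eps = momentum, eps
        self.gamma, self.beta = 1.0, 0.0  # learnable in a real layer; fixed here
        self.running_mean, self.running_var = 0.0, 1.0  # initial estimates
        self.training = True  # stands in for model.train() / model.eval()

    def __call__(self, values):
        if self.training:
            n = len(values)
            mean = sum(values) / n
            var = sum((v - mean) ** 2 for v in values) / n  # biased: used to normalize
            unbiased_var = var * n / (n - 1)  # unbiased (needs n >= 2): running estimate
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * unbiased_var
        else:
            # Inference: use the frozen running estimates, not the batch's own stats.
            mean, var = self.running_mean, self.running_var
        return [self.gamma * (v - mean) / math.sqrt(var + self.eps) + self.beta
                for v in values]


bn = BatchNorm1Feature()
for batch in [[4.0, 6.0], [5.0, 7.0], [3.0, 5.0]] * 50:  # a toy "training loop"
    bn(batch)

bn.training = False  # the moral equivalent of model.eval()
print(bn.running_mean, bn.running_var)  # close to 5 and 2 for these batches
print(bn([5.0]), bn([5.0]))  # a single example works, and the output is repeatable
```

Once training mode is switched off, a batch of one is no problem, the running statistics stop moving, and the same input always produces the same output.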

Where it goes, and a snippet

Batch Norm typically sits between a layer’s linear/convolutional transform and its activation function. In modern frameworks it’s one line:

import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 256),
    nn.BatchNorm1d(256),   # normalize the 256 features across the batch
    nn.ReLU(),
    nn.Linear(256, 10),
)

model.train()   # uses batch statistics + updates running averages
model.eval()    # uses stored running averages — don't forget this!
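This placement has a practical consequence, noted in the original paper: a bias in the layer right before Batch Norm does nothing. The mean subtraction cancels any constant offset, and β takes over the bias's role. That is why you will often see nn.Linear(..., bias=False) or nn.Conv2d(..., bias=False) ahead of a BatchNorm layer. The plain-Python check below uses a hypothetical one-unit layer (weight w, illustrative inputs) to show the bias vanishing.

```python
import math


def normalize(values, eps=1e-5):
    """Standardize one feature across the batch (Batch Norm with gamma=1, beta=0)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


inputs = [0.2, -1.0, 3.5, 0.7]  # one input feature, batch of 4 (illustrative)
w = 1.5                          # weight of a hypothetical one-unit linear layer

without_bias = normalize([w * x for x in inputs])
with_bias = normalize([w * x + 5.0 for x in inputs])  # the bias shifts every value by 5.0

print(without_bias)
print(with_bias)  # same values up to rounding: the shift was removed with the mean
```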

A quick word on LayerNorm

If Batch Norm is so great, why don’t transformers use it? They use its cousin, Layer Normalization, instead. The difference is the axis: Batch Norm normalizes each feature across the batch, while Layer Norm normalizes all features within a single example. That independence from the batch is exactly what sequence models want. Batch sizes vary and are often small, and sequence lengths vary, so padded positions would distort batch statistics. Layer Norm also behaves identically in training and inference, with no running-statistics bookkeeping. Batch Norm rules computer vision; Layer Norm rules the transformer world.
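A toy comparison makes the axis difference concrete. This is a plain-Python sketch with γ and β omitted and illustrative function names: batch_norm normalizes each column of a batch, and layer_norm normalizes each row. An example's Layer Norm output is identical whether or not other examples share its batch. Its Batch Norm output is determined by them.

```python
import math


def standardize(values, eps=1e-5):
    """Shift to zero mean and scale to unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def batch_norm(batch):
    """Normalize each feature (column) across the examples in the batch."""
    normalized_columns = [standardize(list(col)) for col in zip(*batch)]
    return [list(row) for row in zip(*normalized_columns)]


def layer_norm(batch):
    """Normalize each example (row) across its own features."""
    return [standardize(row) for row in batch]


batch = [[1.0, 2.0, 3.0],
         [10.0, 20.0, 60.0]]  # 2 examples x 3 features

# Layer Norm: example 0 gets the same output with or without a batch-mate.
print(layer_norm(batch)[0] == layer_norm([batch[0]])[0])  # True

# Batch Norm: example 0's output is set by comparison with example 1.
print(batch_norm(batch)[0])  # ≈ [-1.0, -1.0, -1.0]
```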

The takeaway

Reach for Batch Norm in CNNs and feed-forward nets to train faster, push your learning rate higher, and stop fretting over initialization — drop it between your linear/conv layer and its activation. Just burn one rule into memory: always switch to eval() mode before inference so it uses running statistics, not the batch’s. For transformers and recurrent models, grab LayerNorm instead. Same goal — keep the numbers sane so the network can learn — different axis for a different job.
