In the previous posts, we trained a digit classifier to 97.8% accuracy with AdamW and cosine decay. That's great on MNIST's clean test set. But what happens when the model encounters messier data, or when we make the network deeper? This post covers the techniques that make training robust.
We'll keep evolving the same DigitClassifier - adding dropout, normalization, and residual connections, and measuring the impact of each.
Overfitting - The Core Problem
When a model performs perfectly on training data but fails on new data, it has overfit. It memorized the noise and peculiarities of the training set rather than learning the underlying pattern.
A model with more parameters than training examples can literally memorize every example. GPT-3 has 175 billion parameters; even for models trained on trillions of tokens, overfitting is a constant threat.
The entire field of regularization exists to answer one question: how do you force a model to learn general patterns instead of specific memorization?
Weight Decay (L2 Regularization)
The simplest regularization: add a penalty for large weights to the loss function.
The penalty term $\lambda \sum_i w_i^2$ (L2 regularization) pushes all weights toward zero. Large weights that memorize specific training examples get penalized. The model is forced to find solutions with smaller, more distributed weights.
L1 regularization ($\lambda \sum_i |w_i|$) is the alternative. While L2 makes weights small, L1 makes weights exactly zero - effectively pruning connections from the network. L1 produces sparse models; L2 produces smooth models.
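In PyTorch terms, adding either penalty to the loss by hand looks like this (a sketch; the model and the `lambda_l2` / `lambda_l1` strengths are illustrative, not tuned):

```python
import torch
import torch.nn as nn

model = nn.Linear(784, 10)          # stand-in model
criterion = nn.CrossEntropyLoss()

x = torch.randn(32, 784)
y = torch.randint(0, 10, (32,))

lambda_l2 = 1e-4                    # illustrative strength
lambda_l1 = 1e-5

base_loss = criterion(model(x), y)

# L2: sum of squared weights -> shrinks all weights toward zero
l2_penalty = sum((p ** 2).sum() for p in model.parameters())

# L1: sum of absolute weights -> drives weights to exactly zero (sparsity)
l1_penalty = sum(p.abs().sum() for p in model.parameters())

loss = base_loss + lambda_l2 * l2_penalty   # or + lambda_l1 * l1_penalty
loss.backward()
```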
In practice, Transformers use weight decay (the AdamW formulation from the previous post) rather than L2 regularization. The effect is similar - shrink weights toward zero - but the implementation interacts correctly with adaptive optimizers.
Typical values: $10^{-4}$ to $10^{-1}$. Too much and the model underfits. Too little and it overfits.
We already used this in the previous post - AdamW's weight_decay=0.01 is exactly this decoupled weight decay.
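As a reminder, the setup looked like this (a sketch; the model here is a stand-in for the DigitClassifier):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# Decoupled weight decay: applied directly to the weights at each step,
# rather than mixed into the gradient like classic L2 regularization.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
```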
Dropout - The Art of Forgetting
Dropout (Srivastava et al., 2014) is deceptively simple: during each training step, randomly set a percentage of neuron outputs to zero.
Typical dropout rate: 10-20% for Transformers, up to 50% for smaller models.
This feels counterintuitive - why cripple the network while it's trying to learn? Three reasons:
Forces redundancy. If neuron A might disappear, the network can't rely on it alone to carry critical information. The same knowledge gets distributed across many neurons.
Implicit ensemble. Each training step uses a different random subset of neurons - effectively training a different subnetwork. The final model is an ensemble of exponentially many subnetworks, averaged together. Ensembles almost always generalize better than individual models.
Breaks co-adaptation. Without dropout, neurons can form brittle partnerships where neuron B only works if neuron A fires first. Dropout prevents these fragile dependencies.
During inference, dropout is turned off. All neurons are active, but their outputs are scaled by $1 - p$ (the keep probability) to account for the fact that more neurons are active than during training. In practice, frameworks like PyTorch use "inverted dropout": they scale the surviving activations by $1/(1-p)$ during training instead, so inference needs no adjustment.
Let's add dropout to our digit classifier:
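Here is one way the change might look - a sketch assuming the DigitClassifier from the first post is a plain 784→256→10 MLP:

```python
import torch
import torch.nn as nn

class DigitClassifier(nn.Module):
    def __init__(self, dropout: float = 0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(dropout),   # zeroes 20% of activations during training
            nn.Linear(256, 10),
        )

    def forward(self, x):
        return self.net(x)

model = DigitClassifier()
x = torch.randn(8, 784)

model.train()
out_train = model(x)   # dropout active: random neurons zeroed each call

model.eval()
out_eval = model(x)    # dropout off: deterministic
```

Note that `model.train()` / `model.eval()` is what toggles dropout on and off - forgetting `eval()` at test time is a classic bug.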
To see dropout's effect, we need to overfit first. Let's train on only 1,000 examples:
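One way to run that experiment (a sketch; random tensors stand in for the 1,000-example MNIST subset, and exact numbers will vary by run):

```python
import torch
import torch.nn as nn

# Stand-in for a 1,000-example MNIST subset (real loading shown in post 1).
x_small = torch.randn(1000, 784)
y_small = torch.randint(0, 10, (1000,))

def make_model(dropout: float) -> nn.Module:
    return nn.Sequential(
        nn.Linear(784, 256), nn.ReLU(), nn.Dropout(dropout),
        nn.Linear(256, 10),
    )

def train(model: nn.Module, epochs: int = 30) -> float:
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(model(x_small), y_small)
        loss.backward()
        opt.step()
    return loss.item()

# The no-dropout model drives training loss down faster (memorization);
# the dropout model resists fitting the small set.
loss_no_dropout = train(make_model(dropout=0.0))
loss_dropout = train(make_model(dropout=0.5))
```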
Dropout prevents the model from memorizing the small training set, trading training accuracy for better generalization.
In Transformers, dropout is applied in two places:
- After the attention weights (before multiplying by V)
- After each sub-layer (attention and FFN), before the residual addition
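Schematically, those two sites look like this (a simplified single-head sketch, not a full implementation; `attn_dropout` and `resid_dropout` are illustrative names):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

attn_dropout = nn.Dropout(0.1)      # site 1: on the attention weights
resid_dropout = nn.Dropout(0.1)     # site 2: on the sub-layer output

def attention(q, k, v):
    d_k = q.size(-1)
    weights = F.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)
    weights = attn_dropout(weights)          # dropout before multiplying by V
    return weights @ v

x = torch.randn(2, 5, 64)                    # (batch, seq_len, d_model)
out = x + resid_dropout(attention(x, x, x))  # dropout before the residual add
```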
The Vanishing Gradient Problem
Regularization keeps models from memorizing. But there's a deeper problem: making deep networks trainable at all.
Backpropagation multiplies gradients layer by layer. When those factors are less than 1 - say each layer scales the gradient by 0.9 - the product shrinks exponentially:
After 20 layers: $0.9^{20} \approx 0.12$. After 50 layers: $0.9^{50} \approx 0.005$. After 100 layers: $0.9^{100} \approx 3 \times 10^{-5}$ - effectively zero. Early layers stop learning entirely.
This is why deep networks were considered impractical for decades. And it's exactly what killed vanilla RNNs for long sequences - processing 500 tokens means 500 layers of gradient multiplication.
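The arithmetic is easy to verify directly:

```python
# Gradient magnitude after n layers, if each layer multiplies it by `factor`.
def gradient_after(factor: float, n_layers: int) -> float:
    return factor ** n_layers

print(gradient_after(0.9, 20))    # ~0.12   - an order of magnitude down
print(gradient_after(0.9, 100))   # ~2.7e-5 - effectively zero
print(gradient_after(1.1, 100))   # ~13781  - the exploding case
```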
The exploding gradient problem is the flip side: when factors are greater than 1, gradients grow exponentially, causing numerical overflow. Gradient clipping (capping the gradient norm at a threshold, typically 1.0) is the standard fix.
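In PyTorch, clipping is one call between `backward()` and `step()` (the model and data below are stand-ins):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(4, 10), torch.randn(4, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Rescale the global gradient norm to at most 1.0 before the update.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```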
Residual Connections - The Gradient Highway
He et al. (2015) proposed the fix that enabled truly deep networks: skip connections (also called residual connections).
Instead of learning the output directly, the layer learns a residual $F(x)$, and the block outputs $y = x + F(x)$ - the difference between the desired output and the input. The identity shortcut means the gradient has a direct path backward that bypasses the layer entirely.
The gradient through a residual connection:

$$\frac{\partial y}{\partial x} = I + \frac{\partial F}{\partial x}$$

That $I$ (the identity matrix) means the gradient can never fully vanish. Even if $\frac{\partial F}{\partial x} \approx 0$, the gradient still flows through the skip connection unchanged.
This one idea enabled:
- ResNet (152 layers, 2015) - won ImageNet with 3.57% error
- Transformers (6-96 layers) - every sub-layer uses residual connections
- Modern LLMs (up to 128 layers in some architectures)
Normalization - Taming Activation Drift
Even with residual connections, activations can drift to extreme values as they pass through many layers. Small biases compound. Without intervention, deeper layers receive inputs with wildly different scales than shallower layers, making training unstable.
Batch Normalization
Batch Norm (Ioffe & Szegedy, 2015) normalizes activations across the batch dimension: for each feature, compute the mean and variance across all examples in the batch, then normalize.
Then apply learned scale ($\gamma$) and shift ($\beta$) parameters so the network can undo the normalization if needed.
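A quick demonstration of the mechanics (feature count is illustrative):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)             # one mean/var pair per feature
x = torch.randn(32, 4) * 5 + 3     # batch of 32, deliberately shifted and scaled

bn.train()
out = bn(x)

# Each feature column is normalized across the batch:
print(out.mean(dim=0))   # close to 0 per feature
print(out.std(dim=0))    # close to 1 per feature
```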
Batch Norm was revolutionary for CNNs but has problems with sequences: the statistics depend on the batch, which varies between training and inference. It also requires reasonably large batch sizes to compute stable statistics.
Layer Normalization
Layer Norm (Ba et al., 2016) normalizes across the feature dimension instead of the batch dimension: for each individual example, compute the mean and variance across all features.
This is batch-size independent - it works the same whether your batch has 1 example or 1,000. This is critical for Transformers, where batch sizes vary and autoregressive generation processes one token at a time.
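A quick check of that batch-size independence (sizes are illustrative):

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(8)

# Works identically for a batch of 1 and a batch of 1,000:
single = ln(torch.randn(1, 8))
big = ln(torch.randn(1000, 8))

# Each *row* (example) is normalized across its own features,
# so no batch statistics are involved at all.
print(single.mean(dim=-1))   # ~0 even for the lone example
```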
Every Transformer block uses Layer Norm. The original paper applied it after each sub-layer ("Post-LN"): $x \leftarrow \text{LayerNorm}(x + \text{Sublayer}(x))$. Modern implementations typically use "Pre-LN": $x \leftarrow x + \text{Sublayer}(\text{LayerNorm}(x))$, which is more stable for very deep networks.
RMSNorm
RMSNorm (Zhang & Sennrich, 2019) simplifies Layer Norm by dropping the mean centering - it only divides by the root mean square:

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot \gamma$$
Cheaper to compute (no mean subtraction), and empirically just as effective. Llama 2 and many modern LLMs use RMSNorm instead of full Layer Norm.
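A minimal RMSNorm module might look like this (a sketch; the `eps` value is a common default, not prescribed by the paper):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned scale, like LayerNorm's gamma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root mean square of the features - no mean subtraction.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

norm = RMSNorm(8)
out = norm(torch.randn(4, 8))
```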
The Add & Norm Pattern
Let's put it all together - a deeper version of our classifier with residual connections, layer normalization, and dropout:
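Here is one way that deeper classifier might look (a sketch; the block count and widths are illustrative, keeping the 784-input/10-class setup from the earlier posts):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Pre-LN residual block: x + Dropout(FFN(LayerNorm(x)))."""
    def __init__(self, dim: int, dropout: float = 0.2):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.ReLU(), nn.Linear(dim * 4, dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return x + self.dropout(self.ffn(self.norm(x)))  # residual add

class DeepDigitClassifier(nn.Module):
    def __init__(self, dim: int = 256, n_blocks: int = 6):
        super().__init__()
        self.embed = nn.Linear(784, dim)
        self.blocks = nn.Sequential(*[ResidualBlock(dim) for _ in range(n_blocks)])
        self.head = nn.Linear(dim, 10)

    def forward(self, x):
        return self.head(self.blocks(self.embed(x)))

model = DeepDigitClassifier()
logits = model(torch.randn(8, 784))
```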
This is the exact pattern used inside every Transformer block: LayerNorm stabilizes the input to each sub-layer, the FFN transforms it, dropout regularizes, and the residual connection ensures gradients flow directly backward. Together, they enable stacking 6, 12, 32, or even 96 identical blocks without training instability.
This pattern appears twice per Transformer block:
- After multi-head attention
- After the feed-forward network
It's the architectural glue that makes deep Transformers possible. Without it, training a 32-layer model would be as hopeless as training a 32-layer vanilla network in 2010.
Putting It Together
| Technique | What it prevents | How it works |
|---|---|---|
| Weight decay | Overfitting | Shrinks weights toward zero |
| Dropout | Overfitting, co-adaptation | Randomly disables neurons during training |
| Gradient clipping | Exploding gradients | Caps gradient norm at a threshold |
| Residual connections | Vanishing gradients | Adds skip connections around layers |
| Layer Norm / RMSNorm | Activation drift | Normalizes features per-example |
Every modern Transformer uses all five simultaneously. They're not optional extras - they're load-bearing infrastructure.
Our model's journey so far
| Blog | What we changed | Test accuracy |
|---|---|---|
| Blog 1 | Basic network + SGD | 93.6% |
| Blog 2 | Swapped to Adam | 97.4% |
| Blog 2 | AdamW + weight decay | 97.6% |
| Blog 2 | + cosine LR schedule | 97.8% |
| Blog 3 | + dropout + LayerNorm + residuals | 98.1% |
From 93.6% to 98.1% - without changing the network size or the data. Every technique in these three posts contributed to that improvement.
Next: how do we apply all of this to language? Continue to Language Modeling & Recurrent Networks.