Foundation Model Engineering

3.4 Layer Normalization & Residuals

Building deep Transformers requires more than just self-attention. To ensure stable training and allow networks to grow to hundreds of layers, two critical components are used: Layer Normalization and Residual Connections.


Motivation: The Deep Network Dilemma

As we stack more layers in a neural network:

  • Internal Covariate Shift: The distribution of inputs to deeper layers keeps changing during training as the earlier layers update, which makes the deeper layers harder to train.
  • Degradation Problem: Adding more layers can sometimes lead to higher training error, not because of overfitting, but because very deep stacks are harder to optimize (gradient signals, for instance, can weaken as they flow back through many layers).

Layer Norm and Residuals are the stabilization pillars that make deep learning actually work at scale.


The Metaphor: The Choir and The Shortcut

Imagine you are leading a large choir.

  • Layer Normalization is like a Volume Controller. If some singers are shouting and others are whispering, the harmony is ruined. Layer Norm adjusts everyone’s volume to a standard, controlled level at each step, keeping the signals from blowing up or dying out.
  • Residual Connections are like building a highway next to a winding side street. If you want to send a message from one end of the city to the other, routing it through every side street (layer) takes a long time, and the message might get lost along the way. A residual connection provides a highway (the skip connection) on which the message can travel directly without distortion.

Residual Connections: The Highway

Introduced in ResNet (He et al., 2016) [1], residual connections simply add the input of a sublayer to its output.

Behind-the-Scenes Story: When Kaiming He and his team proposed ResNet in 2015, they were motivated by a counter-intuitive discovery: adding more layers to an already deep network made the training error go up. This wasn’t due to overfitting, but to a fundamental difficulty in optimization. Their fix was simple: add the block’s input directly to its output, so the stacked layers only have to learn a residual $F(x) = H(x) - x$ on top of the identity instead of the full mapping $H(x)$. This let them train networks with more than 100 layers (152 layers on ImageNet) that outperformed their shallower counterparts.

$$\text{Output} = \text{Sublayer}(x) + x$$

[Figure: Transformer Residual and Layer Norm. Source: The Illustrated Transformer by Jay Alammar]

This simple addition lets gradients flow directly through the identity mapping ($+x$), bypassing the sublayer’s non-linearities during backpropagation. This greatly mitigates the vanishing gradient problem in very deep networks.
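A toy way to see the effect of the $+x$ term is to differentiate a deep stack by hand. The sketch below is an illustration under our own assumptions, not a Transformer: each "sublayer" is the scalar function $F(x) = \tanh(w x)$, and the depth (50), the weight $w = 0.5$ and the starting value are arbitrary. It compares the chain-rule product for a plain stack, $x_{l+1} = F(x_l)$, with a residual stack, $x_{l+1} = x_l + F(x_l)$.

```python
import math

def stack_gradients(depth=50, w=0.5, x0=0.5):
    """Return d(x_depth)/d(x_0) for a plain stack and for a residual stack.

    Each toy 'sublayer' is F(x) = tanh(w * x) on a single number, so the
    chain rule can be applied by hand: F'(x) = w * (1 - tanh(w * x) ** 2).
    """
    # Plain stack: x_{l+1} = F(x_l), so each chain-rule factor is F'(x_l)
    x, grad_plain = x0, 1.0
    for _ in range(depth):
        grad_plain *= w * (1.0 - math.tanh(w * x) ** 2)
        x = math.tanh(w * x)

    # Residual stack: x_{l+1} = x_l + F(x_l), so each factor is 1 + F'(x_l)
    x, grad_residual = x0, 1.0
    for _ in range(depth):
        grad_residual *= 1.0 + w * (1.0 - math.tanh(w * x) ** 2)
        x = x + math.tanh(w * x)

    return grad_plain, grad_residual

plain, residual = stack_gradients()
print(f"plain stack:    d(output)/d(input) = {plain:.2e}")
print(f"residual stack: d(output)/d(input) = {residual:.2e}")
```

In the plain stack every factor $F'(x_l)$ is at most $w = 0.5$, so the product shrinks geometrically; in the residual stack every factor is $1 + F'(x_l) \ge 1$. In real networks $F$ is vector-valued and its Jacobian can point in any direction, so this is intuition rather than a guarantee.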


Layer Normalization: Keeping it Stable

Unlike Batch Normalization, which normalizes across the batch dimension, Layer Normalization (Ba et al., 2016) [2] normalizes the inputs across the features for each training case independently.

For a vector $\mathbf{x}$ of dimension $d$:

$$\mu = \frac{1}{d} \sum_{i=1}^d x_i$$

$$\sigma^2 = \frac{1}{d} \sum_{i=1}^d (x_i - \mu)^2$$

$$\text{LN}(\mathbf{x}) = \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta$$

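Here, $\gamma$ and $\beta$ are learnable per-feature scale and shift vectors (typically initialized to ones and zeros), and $\epsilon$ is a small constant for numerical stability. Layer Norm is preferred in NLP because its statistics are computed per token, independently of the rest of the batch, so it works well with variable sequence lengths and small batch sizes.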
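The three formulas translate almost line by line into code. This is a minimal, dependency-free sketch; the function name layer_norm, the token values and $\epsilon = 10^{-5}$ are our choices ($10^{-5}$ matches the default eps of PyTorch’s nn.LayerNorm, which also uses the biased variance). Each token is normalized using only its own features.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Layer Normalization of one token vector x, following the formulas above."""
    d = len(x)
    mu = sum(x) / d                             # mean over the d features
    var = sum((xi - mu) ** 2 for xi in x) / d   # population (biased) variance over features
    inv_std = 1.0 / math.sqrt(var + eps)
    # Normalize, then apply the per-feature scale (gamma) and shift (beta)
    return [(xi - mu) * inv_std * g + b for xi, g, b in zip(x, gamma, beta)]

# Two hypothetical tokens with d = 4 features each
tokens = [[2.0, -1.0, 10.0, 3.0],
          [0.1, 0.2, 0.3, 0.4]]
d = 4
gamma = [1.0] * d  # scale initialized to ones
beta = [0.0] * d   # shift initialized to zeros

outputs = [layer_norm(t, gamma, beta) for t in tokens]
for y in outputs:
    mean = sum(y) / d
    var = sum((v - mean) ** 2 for v in y) / d
    print([round(v, 3) for v in y], f"mean={mean:.1e}", f"var={var:.4f}")
```

With $\gamma$ set to ones and $\beta$ to zeros, each output has mean 0 and variance just under 1; the shortfall comes from $\epsilon$ and is most visible for the second token, whose raw variance is small.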


PyTorch Implementation

Here is how these are combined in a standard Transformer block.

import torch
import torch.nn as nn

class TransformerSublayer(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        # Simulated sublayer (e.g., self-attention or feed-forward)
        self.sublayer = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # Post-LN architecture (as in original paper)
        # 1. Sublayer operation
        out = self.sublayer(x)
        # 2. Dropout and Residual addition
        out = x + self.dropout(out)
        # 3. Layer Normalization
        out = self.norm(out)
        return out

# Example usage
d_model = 64
layer = TransformerSublayer(d_model)
x = torch.randn(2, 10, d_model)
output = layer(x)
print("Output Shape:", output.shape)

Note: Modern Transformers often use Pre-LN, where normalization is applied to the sublayer’s input: x + sublayer(norm(x)). This arrangement is generally more stable to train.
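For comparison, here is a minimal sketch of the Pre-LN arrangement from the note, mirroring the TransformerSublayer example above. The class name PreLNSublayer and the dropout argument are our own choices, and nn.Linear again stands in for self-attention or the feed-forward network.

```python
import torch
import torch.nn as nn

class PreLNSublayer(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        # Simulated sublayer (e.g., self-attention or feed-forward)
        self.sublayer = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Pre-LN: normalize only the sublayer's input, then add the result
        # back onto the *unnormalized* x, so the residual path stays an identity
        return x + self.dropout(self.sublayer(self.norm(x)))

# Example usage
d_model = 64
layer = PreLNSublayer(d_model)
x = torch.randn(2, 10, d_model)
output = layer(x)
print("Output Shape:", output.shape)  # Output Shape: torch.Size([2, 10, 64])
```

Because the residual path is never normalized, the block’s output is not normalized either; Pre-LN models therefore typically apply one final LayerNorm after the last block (GPT-2 does this, for example).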


Pre-LN vs. Post-LN: The Architectural Shift

The original Transformer paper placed Layer Normalization after the residual addition (Post-LN). However, most modern Transformers (such as GPT-2, GPT-3, and Llama) use Pre-LN; Llama places RMSNorm, a simplified variant of Layer Norm, in the same pre-sublayer position. Let’s compare them:

Post-LN (Original)

$$\mathbf{x}_{l+1} = \text{LayerNorm}(\mathbf{x}_l + \text{Sublayer}(\mathbf{x}_l))$$

  • Pros: The output of each layer is well-normalized, which can lead to better performance if trained successfully.
  • Cons: At initialization, gradient magnitudes differ sharply across layers, so training is unstable unless the learning rate is ramped up with a careful warm-up schedule.
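As a concrete example of warm-up, the original Transformer paper (Vaswani et al., 2017) increased the learning rate linearly for the first warmup_steps steps and then decayed it with the inverse square root of the step number: $\text{lr} = d_{\text{model}}^{-0.5} \cdot \min(\text{step}^{-0.5}, \text{step} \cdot \text{warmup\_steps}^{-1.5})$, with warmup_steps = 4000 and $d_{\text{model}} = 512$ for the base model. A minimal sketch of that formula follows; the function name is ours, and in a real training loop it would be wrapped in an optimizer scheduler.

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Warm-up schedule from the original Transformer paper (Vaswani et al., 2017).

    Rises linearly for the first warmup_steps steps, then decays as 1 / sqrt(step).
    step counts from 1 (step 0 would raise ZeroDivisionError).
    """
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in (1, 1000, 4000, 8000, 16000):
    print(f"step {step:>5}: lr = {transformer_lr(step):.2e}")
```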

Pre-LN (Modern Standard)

$$\mathbf{x}_{l+1} = \mathbf{x}_l + \text{Sublayer}(\text{LayerNorm}(\mathbf{x}_l))$$

  • Pros: Training is much more stable. Gradients can flow directly through the residual connection without being altered by normalization layers. This allows for training much deeper networks without complex warm-up schedules.
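  • Cons: Nothing normalizes the residual stream itself, so its magnitude grows with depth and each additional sublayer’s update becomes relatively smaller. In very deep models this can make later layers contribute little and their representations increasingly similar (sometimes called representation collapse), though in practice Pre-LN performs very well.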

Researchers found that in Post-LN at initialization, the expected gradient norm is much smaller for layers near the input than for layers near the output. Early layers therefore receive weak gradients, while the large gradients near the output make training unstable without warm-up. Pre-LN avoids this by keeping the gradient norm well-behaved across all layers.


Example: Normalization Effect

Layer Normalization transforms a vector with wild values into a normalized vector with mean 0 and variance 1 (before the learnable $\gamma$ and $\beta$ are applied).

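The normalization effect described above can be reproduced with a short sketch. The helper normalize is ours and applies the Layer Norm formula without $\gamma$ and $\beta$; the raw vector is an arbitrary example with a wild range of values.

```python
from statistics import fmean, pstdev, pvariance

def normalize(x, eps=1e-5):
    """Layer Norm without the learnable gamma/beta: (x - mean) / sqrt(var + eps)."""
    mu = fmean(x)
    sigma = (pvariance(x) + eps) ** 0.5
    return [(v - mu) / sigma for v in x]

raw = [1200.0, -35.0, 0.5, 480.0, -900.0]  # arbitrary values on wildly different scales
normed = normalize(raw)

print("raw:       ", raw)
print("normalized:", [round(v, 3) for v in normed])
print(f"mean = {fmean(normed):.1e}, std = {pstdev(normed):.4f}")

# Scaling the whole vector by a positive constant and shifting it barely changes the output
shifted = normalize([10.0 * v + 7.0 for v in raw])
max_diff = max(abs(a - b) for a, b in zip(normed, shifted))
print(f"max difference after scaling and shifting: {max_diff:.1e}")
```

The last check illustrates a related property: multiplying the whole vector by a positive constant and adding an offset leaves the Layer Norm output unchanged, apart from the tiny effect of $\epsilon$.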

Quizzes

Quiz 1: Why is Layer Normalization preferred over Batch Normalization in NLP?
Answer: Batch Normalization computes statistics across the batch dimension. In NLP, sequences often have variable lengths, and the padding needed to batch them distorts those statistics. Also, the small batch sizes often used because of the memory demands of large models make Batch Norm’s estimates noisy. Layer Norm computes statistics across features for each token independently, so it does not depend on batch size or sequence length.
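The batch-size argument can be made concrete with a small dependency-free sketch. Both helper functions are ours and omit $\gamma$ and $\beta$; batch_norm uses training-mode statistics of the current batch and ignores the running averages that real implementations keep for inference. The token values are arbitrary.

```python
from statistics import fmean, pvariance

EPS = 1e-5

def layer_norm(token):
    """Normalize one token across its own features (no gamma/beta)."""
    mu, var = fmean(token), pvariance(token)
    return [(v - mu) / (var + EPS) ** 0.5 for v in token]

def batch_norm(batch):
    """Normalize each feature across the batch (training-mode statistics, no gamma/beta)."""
    stats = [(fmean(col), pvariance(col)) for col in zip(*batch)]  # one (mean, var) per feature
    return [[(v - mu) / (var + EPS) ** 0.5 for v, (mu, var) in zip(row, stats)]
            for row in batch]

token = [1.0, 2.0, 3.0, 4.0]
batch_a = [token, [0.0, 0.0, 1.0, 1.0]]
batch_b = [token, [9.0, -3.0, 5.0, 2.0], [4.0, 4.0, 0.0, 1.0]]

# Layer Norm: the token's output does not depend on which batch it is in
print(layer_norm(batch_a[0]) == layer_norm(batch_b[0]))  # True

# Batch Norm: the same token's output changes with its batch-mates
bn_a = batch_norm(batch_a)[0]
bn_b = batch_norm(batch_b)[0]
print("BN in batch A:", [round(v, 3) for v in bn_a])
print("BN in batch B:", [round(v, 3) for v in bn_b])
```

The two-row batch also shows why tiny batches are a problem: with only two values per feature, Batch Norm maps every feature to roughly $\pm 1$ (whenever the two values differ), regardless of the actual inputs.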

Quiz 2: How do Residual Connections solve the vanishing gradient problem?
Answer: In a residual block $\text{Output} = F(x) + x$, the derivative with respect to the input contains a $+1$ from the identity mapping: $\frac{\partial \text{Output}}{\partial x} = \frac{\partial F(x)}{\partial x} + 1$ (for vector inputs, the $+1$ becomes the identity matrix $I$). This $+1$ term ensures that gradients can flow back directly even when the derivative of $F(x)$ is very small, preventing the gradient from vanishing.

Quiz 3: What is the difference between Pre-LN and Post-LN architectures?
Answer: In Post-LN (the original Transformer), normalization is applied after the residual addition: $\text{LayerNorm}(x + F(x))$. In Pre-LN, normalization is applied to the input before the sublayer, and the output is added to the original input: $x + F(\text{LayerNorm}(x))$. Pre-LN is found to be more stable for training very deep networks without warm-up.

Quiz 4: Why do we need the learnable parameters $\gamma$ and $\beta$ in Layer Normalization?
Answer: Without $\gamma$ and $\beta$, Layer Normalization would force the activations of every token to have mean 0 and variance 1, which might limit the expressive power of the network. The learnable per-feature scale and shift let the network move the normalized values to whatever range and offset is most useful for learning, partially “undoing” the normalization when that is beneficial. (Because $\gamma$ and $\beta$ are shared across tokens, they cannot exactly recover each token’s original statistics.)

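Quiz 5: In the Post-LN formula $\mathbf{x}_{l+1} = \text{LayerNorm}(\mathbf{x}_l + \text{Sublayer}(\mathbf{x}_l))$, what happens to the variance of the hidden state at each layer?
Answer: The sublayer output is added to the input before normalization, so (assuming independence) the variance of the sum is the sum of the variances: with a unit-variance input and a sublayer output of variance $\sigma^2$, the sum has variance $1 + \sigma^2$. The LayerNorm then squashes it back to variance 1. The scale therefore does not accumulate with depth; instead, every LayerNorm sits on the residual path and rescales the identity branch by roughly $1/\sqrt{1 + \sigma^2}$. Signals and gradients travelling between distant layers must pass through all of these normalizations, which makes optimization harder than in Pre-LN.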

Quiz 6: Mathematically formulate the variance of the hidden state $\mathbf{x}_L$ at layer $L$ for both Post-LN and Pre-LN architectures, assuming the sublayer output $F(\mathbf{x})$ has a variance of $\sigma^2$ for a unit-variance input.
Answer: In a standard residual branch without LayerNorm, the variance propagates as $\text{Var}(\mathbf{x}_{l+1}) = \text{Var}(\mathbf{x}_l) + \text{Var}(F(\mathbf{x}_l))$, assuming independence.
For Post-LN, $\mathbf{x}_{l+1} = \text{LN}(\mathbf{x}_l + F(\mathbf{x}_l))$. LayerNorm restores the variance to 1 at each step (with $\gamma = 1$, $\beta = 0$ as at initialization), so the variance entering each layer is always 1 and the pre-normalization sum has the same variance at every layer: $\text{Var}(\mathbf{x}_l + F(\mathbf{x}_l)) = 1 + \sigma^2$. Nothing accumulates with depth, so $\text{Var}(\mathbf{x}_L) = 1$.
For Pre-LN, $\mathbf{x}_L = \mathbf{x}_0 + \sum_{l=0}^{L-1} F(\text{LN}(\mathbf{x}_l))$. Each of the $L$ sublayers sees a unit-variance input and contributes $\sigma^2$, so the variance grows linearly with depth: $\text{Var}(\mathbf{x}_L) = \text{Var}(\mathbf{x}_0) + L\sigma^2$. Because the identity path in Pre-LN never passes through a LayerNorm, $\partial \mathbf{x}_L / \partial \mathbf{x}_l$ always contains an identity term, so the gradients reaching early layers do not vanish with depth and Pre-LN can be trained with little or no warm-up.
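The two variance formulas can be checked with a toy simulation in plain Python. The setup is our assumption, not part of the original analysis: every sublayer is a random linear map $W$ with entries drawn from $\mathcal{N}(0, \sigma^2/d)$, so a unit-variance input produces an output of variance about $\sigma^2$; the same weights are used for both variants; and $d = 256$, $L = 24$, $\sigma^2 = 1$.

```python
import random
from statistics import fmean, pvariance

random.seed(0)
d, L, sigma2, eps = 256, 24, 1.0, 1e-5  # assumed width, depth, sublayer variance, LN epsilon

def layer_norm(x):
    """Layer Norm without gamma/beta: output has mean 0 and variance ~1."""
    mu, var = fmean(x), pvariance(x)
    return [(v - mu) / (var + eps) ** 0.5 for v in x]

def random_sublayer():
    """Toy sublayer F(x) = W x with W_ij ~ N(0, sigma2 / d), so Var(F(x)) ~ sigma2 for unit-variance x."""
    std = (sigma2 / d) ** 0.5
    W = [[random.gauss(0.0, std) for _ in range(d)] for _ in range(d)]
    return lambda x: [sum(w * v for w, v in zip(row, x)) for row in W]

x0 = [random.gauss(0.0, 1.0) for _ in range(d)]  # Var(x_0) ~ 1
post, pre = list(x0), list(x0)
post_vars, pre_vars = [], []

for _ in range(L):
    F = random_sublayer()  # shared weights for both variants at this layer
    post = layer_norm([a + b for a, b in zip(post, F(post))])  # Post-LN: LN(x + F(x))
    pre = [a + b for a, b in zip(pre, F(layer_norm(pre)))]     # Pre-LN:  x + F(LN(x))
    post_vars.append(pvariance(post))
    pre_vars.append(pvariance(pre))

print(f"Post-LN: Var(x_L) = {post_vars[-1]:.3f} (formula: 1)")
print(f"Pre-LN:  Var(x_L) = {pre_vars[-1]:.1f} (formula: {1 + L * sigma2:.0f})")
```

Post-LN stays pinned near 1 by its LayerNorm, while Pre-LN lands near $1 + L\sigma^2 = 25$; the exact Pre-LN value varies slightly with the random seed.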


References

  1. He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).
  2. Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer normalization. arXiv:1607.06450.