Foundation Model Engineering

3.1 Self-Attention Mathematics

The Transformer architecture, introduced in the seminal paper “Attention Is All You Need” (Vaswani et al., 2017), revolutionized AI by replacing recurrent structures with a mechanism called Self-Attention. This allows the model to process all tokens in a sequence simultaneously and model dependencies regardless of their distance.

Behind-the-Scenes Story: The title of the paper, ‘Attention Is All You Need,’ was a bold and provocative claim at the time. It suggested that one could completely discard RNNs and LSTMs, which had dominated sequence modeling for years, and rely solely on attention mechanisms. The authors at Google were motivated by the frustration with the slow training speeds of RNNs in machine translation. Their provocative title turned out to be a prophecy that changed the course of AI history.

Let’s understand the core mechanism through a familiar metaphor.


The Metaphor: The Filing Cabinet System

Imagine you are doing research in a library with a highly efficient filing system.

  • Query ($Q$): This is what you are looking for. You have a topic in mind (e.g., “How do birds fly?”).
  • Key ($K$): These are the labels or tags on the filing folders. Each folder has a summary of what’s inside.
  • Value ($V$): This is the actual content inside the folder.

To find the information you need:

  1. You compare your Query against all the Keys to see which folders are relevant.
  2. You calculate a relevance score (Attention weight) for each folder.
  3. You extract the Values from the folders, giving more weight to the values from folders with high relevance scores.

In Self-Attention, every word in a sentence acts as a Query, a Key, and a Value to interact with every other word.
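To make the metaphor concrete, here is a tiny numeric sketch (all vectors are made-up toy numbers): one query is scored against three folders' keys, and the result is a weighted mix of their values, dominated by the best-matching folder.

```python
import torch

# One query vector and three "folders", each with a key (label) and a value (content).
# Shapes: query (4,), keys (3, 4), values (3, 2). Toy numbers, chosen by hand.
query = torch.tensor([1.0, 0.0, 1.0, 0.0])
keys = torch.tensor([[1.0, 0.0, 1.0, 0.0],   # folder closely matching the query
                     [0.0, 1.0, 0.0, 1.0],   # unrelated folder
                     [0.5, 0.0, 0.5, 0.0]])  # partially related folder
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])

# Steps 1-2: relevance score of the query against every key, then softmax
scores = keys @ query              # shape (3,)
weights = torch.softmax(scores, dim=0)

# Step 3: weighted sum of values -- mostly the content of folder 0
result = weights @ values
print(weights)  # highest weight on folder 0
print(result)
```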


From Embeddings to Q, K, V: Linear Projections

In practice, we don’t use the input word embeddings directly as Queries, Keys, and Values. Instead, we project them into different spaces using learnable weight matrices.

Given an input sequence represented as a matrix $X \in \mathbb{R}^{T \times d_{model}}$, where $T$ is the sequence length and $d_{model}$ is the embedding dimension:

$$Q = XW_Q \qquad K = XW_K \qquad V = XW_V$$

Where the learnable weight matrices are:

  • $W_Q \in \mathbb{R}^{d_{model} \times d_k}$
  • $W_K \in \mathbb{R}^{d_{model} \times d_k}$
  • $W_V \in \mathbb{R}^{d_{model} \times d_v}$

Why do we do this? If we used $X$ directly for all three roles, the self-attention operation would be based purely on the static embeddings. Linear projections let the model learn to extract different aspects of the same word for different roles: a word can have one representation when acting as a Query looking for other words, and another when acting as a Key being looked up. This significantly increases the expressiveness of the model.
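As a small sketch (all dimensions chosen arbitrarily), the three projections can be written as bias-free linear layers, matching the formulas above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, d_model, d_k, d_v = 4, 16, 8, 8

# Bias-free linear layers standing in for W_Q, W_K, W_V
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_v, bias=False)

X = torch.randn(T, d_model)   # input embeddings, one row per token
Q, K, V = W_q(X), W_k(X), W_v(X)

print(Q.shape, K.shape, V.shape)  # each (T, 8): three different views of the same X
```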


The Mathematics of Scaled Dot-Product Attention

The core operation of Self-Attention is the Scaled Dot-Product Attention. Given a set of Queries, Keys, and Values packed into matrices $Q$, $K$, and $V$, the operation is defined as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

[Figure: Self-attention matrix calculation. Source: The Illustrated Transformer by Jay Alammar]

Where:

  • $Q \in \mathbb{R}^{T \times d_k}$ (Matrix of Queries)
  • $K \in \mathbb{R}^{T \times d_k}$ (Matrix of Keys)
  • $V \in \mathbb{R}^{T \times d_v}$ (Matrix of Values)
  • $T$ is the sequence length.
  • $d_k$ is the dimension of the keys (and queries).
  • $\sqrt{d_k}$ is the scaling factor.

Step-by-Step Breakdown

  1. Dot Product ($QK^T$): Measures the raw similarity between each query and all keys.
  2. Scaling ($\frac{1}{\sqrt{d_k}}$): Prevents the dot products from growing too large in magnitude, which would push the softmax function into regions with extremely small gradients.
  3. Softmax: Converts the scaled scores into probabilities (attention weights) that sum to 1.
  4. Weighted Sum ($V$): Multiplies the attention weights by the values to get the final output.

PyTorch Implementation

Let’s implement this operation from scratch in PyTorch. This is the core building block of the Transformer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute Scaled Dot-Product Attention.

    query, key:  (batch, heads, seq_len, d_k)
    value:       (batch, heads, seq_len, d_v)
    mask:        optional, broadcastable to (batch, heads, seq_len, seq_len);
                 positions where mask == 0 are blocked.
    """
    d_k = query.size(-1)

    # Step 1 & 2: Dot product and scaling
    # key.transpose(-2, -1) shape: (batch, heads, d_k, seq_len)
    # scores shape: (batch, heads, seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    # Optional: Apply mask (e.g., for causal autoregressive decoding)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Step 3: Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # Step 4: Weighted sum of values
    # output shape: (batch, heads, seq_len, d_v)
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

# Example usage
batch_size = 1
seq_len = 4
d_k = 8
d_v = 8

# Random tensors with a single attention head
Q = torch.randn(batch_size, 1, seq_len, d_k)
K = torch.randn(batch_size, 1, seq_len, d_k)
V = torch.randn(batch_size, 1, seq_len, d_v)

output, weights = scaled_dot_product_attention(Q, K, V)
print("Output Shape:", output.shape)              # (1, 1, 4, 8)
print("Attention Weights Shape:", weights.shape)  # (1, 1, 4, 4)
```
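The example above does not exercise the mask argument. Here is a self-contained sketch of causal masking, using the same convention as the function (1 = keep, 0 = block):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 4, 8
Q = torch.randn(1, 1, seq_len, d_k)
K = torch.randn(1, 1, seq_len, d_k)
V = torch.randn(1, 1, seq_len, d_k)

# Causal mask: 1 on and below the diagonal (attention allowed), 0 above it
mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
scores = scores.masked_fill(mask == 0, -1e9)
weights = F.softmax(scores, dim=-1)

# Every weight above the diagonal is (numerically) zero, so position i can
# only attend to positions 0..i; each row still sums to 1.
print(weights[0, 0])
```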

Example: Attention Weight Matrix

Simulated attention weights for the word “it” in the sentence "The animal didn't cross the street because it was too tired." (illustrative values, not from a trained model):

  • The: 1%
  • animal: 45%
  • didn't: 2%
  • cross: 5%
  • the: 1%
  • street: 15%
  • because: 1%
  • it: 20%
  • was: 2%
  • too: 3%
  • tired: 5%

Quizzes

Quiz 1: Why is the scaling factor $\sqrt{d_k}$ necessary in the attention formula? As the dimension $d_k$ grows, the magnitude of the dot products grows with it, which pushes the softmax function into regions where the gradient is extremely small (vanishing gradients). Dividing by $\sqrt{d_k}$ scales the variance of the dot products back to approximately 1, ensuring stable gradients.
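A quick empirical check of this answer: for random vectors with unit-variance entries, the variance of $q \cdot k$ grows linearly with $d_k$, while the scaled version stays near 1.

```python
import torch

torch.manual_seed(0)

# Variance of dot products of random unit-variance vectors, raw vs. scaled
for d_k in (16, 256, 4096):
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    dots = (q * k).sum(dim=-1)
    print(f"d_k={d_k:5d}  var={dots.var().item():8.1f}  "
          f"scaled var={(dots / d_k ** 0.5).var().item():.3f}")
```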

Quiz 2: What is the difference between Self-Attention and standard Attention (like Bahdanau Attention)? Standard attention typically aligns a decoder state with encoder states (cross-attention). Self-attention relates different positions of a single sequence to compute a representation of the same sequence (e.g., words in a sentence attending to other words in the same sentence).

Quiz 3: In the interactive example, why does the word “it” attend strongly to “animal”? This demonstrates coreference resolution. The model learns that “it” refers to the “animal” in this context, because of the semantic relationship and the word “tired” appearing later. This is the power of Self-Attention—capturing long-range dependencies and context.

Quiz 4: Why do we use linear projections to create Q, K, V instead of using the input embeddings directly? Linear projections allow the model to learn different representations for the same word depending on its role (Query, Key, or Value). This increases the model’s capacity to capture complex relationships compared to using static embeddings directly.

Quiz 5: How does Self-Attention handle variable-length sequences? Self-attention handles variable-length sequences naturally because the attention operation (weighted sum) does not depend on a fixed sequence length. The output dimension is determined by the input sequence length, and the same parameters (projection matrices) are used regardless of length.
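This length-independence is easy to demonstrate: one fixed set of projection parameters (depending only on $d_{model}$) serves sequences of any length, while the attention map's shape follows the input. A sketch with arbitrary dimensions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16

# Parameters depend only on d_model, not on the sequence length
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# The same layers handle sequences of different lengths without any change
for T in (3, 7, 50):
    X = torch.randn(T, d_model)
    Q, K, V = W_q(X), W_k(X), W_v(X)
    weights = torch.softmax(Q @ K.T / d_model ** 0.5, dim=-1)
    out = weights @ V
    print(T, tuple(weights.shape), tuple(out.shape))  # (T, T) and (T, d_model)
```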

Quiz 6: Formally derive the number of parameters and the total floating-point operations (FLOPs) for a single self-attention layer with sequence length $T$, embedding dimension $d_{model}$, and $d_k = d_v = d_{model}$.

Parameters: The layer requires projection matrices $W_Q, W_K, W_V$, each of dimension $d_{model} \times d_{model}$. The total number of parameters is $3d_{model}^2$ (excluding biases for simplicity).

FLOPs:

  1. Linear projections: $XW_Q$, $XW_K$, $XW_V$ each require $2Td_{model}^2$ FLOPs (multiply and add). Total: $6Td_{model}^2$.
  2. Attention matrix $QK^T$: Multiplying a $T \times d_{model}$ matrix by a $d_{model} \times T$ matrix requires $2T^2d_{model}$ FLOPs.
  3. Softmax and scaling: $O(T^2)$ operations.
  4. Multiplying with $V$: Multiplying the $T \times T$ attention matrix by the $T \times d_{model}$ value matrix requires $2T^2d_{model}$ FLOPs.

Total dominant FLOPs: $6Td_{model}^2 + 4T^2d_{model}$.
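The derivation in Quiz 6 can be packaged as a small counting helper (a sketch: it counts only the Q/K/V projections and the two attention matmuls, as in the quiz answer, ignoring softmax and biases):

```python
def attention_params_and_flops(T: int, d_model: int) -> tuple[int, int]:
    """Parameter and dominant-FLOP counts for one self-attention layer
    with d_k = d_v = d_model (biases and output projection excluded)."""
    params = 3 * d_model ** 2             # W_Q, W_K, W_V
    proj_flops = 6 * T * d_model ** 2     # three (T x d)(d x d) matmuls, 2Td^2 each
    attn_flops = 4 * T ** 2 * d_model     # QK^T plus the weighted sum with V
    return params, proj_flops + attn_flops

# The quadratic T^2 term dominates once T exceeds 1.5 * d_model
print(attention_params_and_flops(T=1024, d_model=512))
```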


References

  1. Vaswani, A., et al. (2017). Attention is all you need. In Advances in neural information processing systems (pp. 5998-6008). arXiv:1706.03762.