12.4 Speculative Decoding
Large Language Model (LLM) inference is notoriously slow and expensive. Generation is autoregressive: the model predicts tokens one at a time, with each step depending on all previous tokens. As we will see, this makes inference memory-bandwidth bound rather than compute bound.
Speculative Decoding, introduced independently by Leviathan et al. [1] and Chen et al. [2], is a powerful technique to accelerate inference without changing the model’s output distribution. It achieves this by trading compute for memory bandwidth.
The Bottleneck: Memory Bandwidth vs. Compute
During LLM inference, moving the massive model weights from GPU memory (VRAM) to the on-chip compute units takes orders of magnitude more time than the matrix multiplications themselves. This is what makes decoding memory-bandwidth bound.
When generating tokens one by one, we are forced to load the entire model for every single token. This results in low GPU utilization. However, if we could process a batch of tokens at once, the cost of loading the weights would be amortized across the batch, leading to much higher efficiency.
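To make the bandwidth argument concrete, here is a back-of-the-envelope calculation; the model size and bandwidth figures are illustrative assumptions, not measurements:

```python
# Rough decode-speed ceiling for a memory-bandwidth-bound model.
# Both figures below are assumptions chosen for illustration.
weight_bytes = 70e9 * 2    # assumed 70B-parameter model in fp16 (2 bytes/param)
bandwidth = 2.0e12         # assumed ~2 TB/s of GPU memory bandwidth

# Autoregressive decoding re-reads every weight for every token
time_per_token = weight_bytes / bandwidth
print(f"~{time_per_token * 1e3:.0f} ms/token, ~{1 / time_per_token:.0f} tokens/s")

# Processing a batch of K positions in one pass reads the weights once,
# amortizing the load across the batch
K = 4
print(f"amortized: ~{time_per_token / K * 1e3:.1f} ms per position across {K} positions")
```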
The Core Concept: Draft and Verify
Speculative Decoding breaks the autoregressive bottleneck by using two models:
- Draft Model ($M_q$): A small, fast, and lightweight model (e.g., a few hundred million parameters).
- Target Model ($M_p$): The large, high-quality model we actually want to use (e.g., tens or hundreds of billions of parameters).
The process works in cycles:
- Drafting: The small draft model generates $K$ candidate tokens autoregressively. This is very fast because the model is small.
- Verification: The large target model processes all $K$ candidate tokens in a single forward pass (parallel decoding). It computes the logits for all positions simultaneously.
- Correction: We compare the predictions of both models. We accept the draft tokens as long as they match the target model’s distribution. The first rejected token determines where we stop, and we use the target model’s prediction for that position. We then repeat the process.
Even if the draft model agrees with the target model only part of the time, generating several tokens per target-model forward pass yields significant speedups (often $2\times$ to $3\times$).
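The cycle above can be sketched end to end with toy stand-ins for both models. The `draft_step` and `target_forward` callables below are illustrative assumptions, not a real model API, and the acceptance rule is a simplified greedy match (the exact rejection-sampling rule is covered in the next subsection):

```python
import torch

torch.manual_seed(0)
V, K = 8, 4  # toy vocabulary size and draft lookahead

def draft_step(tokens):
    # Toy draft model: a distribution over the next token
    return torch.softmax(torch.randn(V), dim=-1)

def target_forward(tokens, k):
    # Toy target model: ONE parallel pass yields a distribution for each of
    # the k draft positions plus one bonus position
    return torch.softmax(torch.randn(k + 1, V), dim=-1)

prompt = [1, 2, 3]

# 1) Drafting: the small model proposes K tokens autoregressively
draft_tokens = []
for _ in range(K):
    q = draft_step(prompt + draft_tokens)
    draft_tokens.append(int(torch.multinomial(q, 1)))

# 2) Verification: the large model scores all positions in a single pass
target_probs = target_forward(prompt + draft_tokens, K)

# 3) Correction: accept the matching prefix, replace the first mismatch
accepted = []
for i, t in enumerate(draft_tokens):
    if int(torch.argmax(target_probs[i])) == t:
        accepted.append(t)
    else:
        accepted.append(int(torch.argmax(target_probs[i])))
        break
print("drafted:", draft_tokens, "emitted:", accepted)
```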
The Mathematics of Rejection Sampling
To ensure that the output distribution of Speculative Decoding is identical to sampling directly from the target model $p$, we use a specialized rejection sampling scheme.
Let $q(x)$ be the probability distribution over the next token predicted by the draft model $M_q$, and $p(x)$ be the distribution predicted by the target model $M_p$.
Suppose the draft model proposes a token $x \sim q$. We accept $x$ with probability:

$$\min\left(1, \frac{p(x)}{q(x)}\right)$$

- If $p(x) \ge q(x)$, the target model agrees or is more confident, so we always accept (acceptance probability $1$).
- If $p(x) < q(x)$, we accept with probability $p(x)/q(x)$.

If the token is rejected, we stop the verification for this cycle. To maintain the exact distribution of the target model, we must sample the replacement token from a modified distribution $p'$:

$$p'(x) = \frac{\max\big(0,\, p(x) - q(x)\big)}{\sum_{x'} \max\big(0,\, p(x') - q(x')\big)}$$
This ensures that the final sampled token follows the exact distribution of the large model, making Speculative Decoding mathematically lossless.
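A quick numeric check makes the losslessness tangible. With toy three-token distributions (values assumed for illustration), combining the accepted probability mass with the residual mass recovers the target distribution exactly:

```python
import torch

# Toy target and draft distributions over a 3-token vocabulary (assumed values)
p = torch.tensor([0.5, 0.3, 0.2])  # target model
q = torch.tensor([0.2, 0.5, 0.3])  # draft model

# Acceptance probability per token: min(1, p/q)
accept = torch.clamp(p / q, max=1.0)

# Overall probability that a draft sample is rejected
p_reject = 1.0 - (q * accept).sum()

# Residual distribution p'(x), proportional to max(0, p(x) - q(x))
residual = torch.clamp(p - q, min=0.0)
residual = residual / residual.sum()

# Law of total probability: P(emit x) = q(x)*min(1, p(x)/q(x)) + P(reject)*p'(x)
output_dist = q * accept + p_reject * residual
print(output_dist)  # equals p: the scheme is lossless
```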
When It Works Well in Practice
The speedup from speculative decoding is controlled less by the elegance of the math than by the acceptance rate in real workloads.
- If the draft model is close to the target model, easy continuations are often accepted in long runs.
- If the prompt is highly ambiguous, the temperature is high, or the task is brittle (for example, exact code syntax), acceptance rates usually drop.
- If the target model is extremely large and memory-bandwidth bound, even a modest acceptance rate can still be valuable because each target-model pass is so expensive.
This is why production systems tune speculative decoding per workload instead of treating it as a universal switch. The best draft model is not simply the smallest one. It is the one that gives the most accepted tokens per unit of extra cost.
Serving Trade-offs and Failure Modes
Speculative decoding also changes the serving stack.
Choosing the Lookahead Length
A larger lookahead $K$ creates more room for acceleration, but only if the draft model remains accurate deep into the lookahead window. If most rejections happen after one or two positions, pushing $K$ higher only adds wasted draft computation and larger verification tensors.
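Under the simple i.i.d.-acceptance model derived in Quiz 4 below, the expected number of tokens per cycle is $(1 - \alpha^{K+1})/(1 - \alpha)$, which makes the diminishing returns easy to tabulate:

```python
# Expected tokens emitted per cycle when each of K draft tokens is accepted
# independently with probability alpha (the model derived in Quiz 4)
def expected_tokens(alpha, K):
    return (1 - alpha ** (K + 1)) / (1 - alpha)

# Low alpha plateaus almost immediately; high alpha keeps benefiting from larger K
for alpha in (0.5, 0.8):
    row = [round(expected_tokens(alpha, K), 2) for K in (1, 2, 4, 8)]
    print(f"alpha={alpha}: K=1,2,4,8 -> {row}")
```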
KV Cache Coordination
The draft model and target model maintain different KV caches and may advance through the sequence at different speeds. A production implementation must keep these caches synchronized carefully, especially when some draft tokens are accepted and others are replaced. This makes speculative decoding more complex than ordinary batch decoding even though the core idea is simple.
Throughput vs. Tail Latency
Speculative decoding often improves average throughput, but it can complicate tail-latency behavior. Different requests may have very different acceptance rates depending on task type, prompt entropy, and temperature. In multi-tenant serving, operators therefore track not only tokens-per-second, but also acceptance-rate distributions and latency variance across request classes.
PyTorch Implementation of Verification
Here is a PyTorch implementation demonstrating the core verification and rejection sampling logic of Speculative Decoding.
```python
import torch
import torch.nn.functional as F

def verify_draft_tokens(draft_logits, target_logits, draft_tokens, temperature=1.0):
    """
    Verify draft tokens and apply rejection sampling.

    Args:
        draft_logits: Logits from the draft model. Shape (batch_size, K, vocab_size)
        target_logits: Logits from the target model. Shape (batch_size, K + 1, vocab_size)
        draft_tokens: Tokens proposed by the draft model. Shape (batch_size, K)
        temperature: Sampling temperature.

    Returns:
        accepted_tokens: Tensor of emitted tokens. Shape (batch_size, n)
        num_accepted: Number of tokens emitted this cycle (accepted draft
            tokens plus the correction or bonus token).
    """
    batch_size, K = draft_tokens.shape
    batch_idx = torch.arange(batch_size)

    # Apply temperature
    draft_probs = F.softmax(draft_logits / temperature, dim=-1)
    target_probs = F.softmax(target_logits[:, :K, :] / temperature, dim=-1)

    accepted_tokens = []
    for i in range(K):
        # Probabilities both models assign to the proposed token at position i
        q = draft_probs[batch_idx, i, draft_tokens[:, i]]
        p = target_probs[batch_idx, i, draft_tokens[:, i]]

        # Acceptance probability min(1, p/q); clamp q to avoid division by zero
        p_accept = torch.clamp(p / q.clamp(min=1e-10), max=1.0)

        # Roll the dice
        accepted = torch.rand_like(p_accept) < p_accept
        if accepted.all():
            accepted_tokens.append(draft_tokens[:, i])
        else:
            # Rejection! Sample from the adjusted (residual) distribution,
            # proportional to max(0, p(x) - q(x)).
            # For simplicity this branch assumes batch_size=1; a full
            # implementation tracks acceptance per batch item.
            diff = target_probs[:, i, :] - draft_probs[:, i, :]
            adjusted_probs = torch.clamp(diff, min=0.0)
            if adjusted_probs.sum() > 0:
                adjusted_probs = adjusted_probs / adjusted_probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(adjusted_probs, 1)
            else:
                # p == q everywhere: the residual is empty, so fall back to
                # sampling from the target distribution directly
                next_token = torch.multinomial(target_probs[:, i, :], 1)
            accepted_tokens.append(next_token.squeeze(-1))
            break
    else:
        # All K tokens accepted: also take the target model's prediction for
        # the (K+1)-th position, which was computed for free!
        last_probs = F.softmax(target_logits[:, K, :] / temperature, dim=-1)
        next_token = torch.multinomial(last_probs, 1)
        accepted_tokens.append(next_token.squeeze(-1))

    return torch.stack(accepted_tokens, dim=1), len(accepted_tokens)

# Example usage (batch size 1)
torch.manual_seed(0)
K = 4  # Draft lookahead
vocab_size = 1000
draft_logits = torch.randn(1, K, vocab_size)
target_logits = torch.randn(1, K + 1, vocab_size)

# Simulate the draft model by sampling its proposals from its own distribution
draft_probs = F.softmax(draft_logits, dim=-1)
draft_tokens = torch.multinomial(draft_probs.view(-1, vocab_size), 1).view(1, K)

accepted, count = verify_draft_tokens(draft_logits, target_logits, draft_tokens)
print(f"Accepted {count} tokens.")
print(f"Accepted token sequence: {accepted}")
```
Quizzes
Quiz 1: Why is the target model’s forward pass in Speculative Decoding faster than generating tokens autoregressively, even though it processes more tokens?
The target model processes all draft tokens in a single forward pass using parallel matrix multiplications. Because inference is memory-bandwidth bound, the time taken to load the model weights dominates the execution time. Loading the weights once to process tokens in parallel takes almost the same time as loading the weights to process a single token.
Quiz 2: What happens to the speedup of Speculative Decoding if the draft model is very small but has extremely low accuracy?
If the draft model has low accuracy, the target model will reject the draft tokens at early positions in most cycles. This means the model will only generate 1 or 2 tokens per cycle, similar to standard autoregressive decoding, but with the added overhead of running the draft model. The speedup will disappear, and inference may even become slower.
Quiz 3: Explain how the adjusted distribution guarantees that the final output follows the target model’s distribution exactly.
The adjusted distribution $p'$ concentrates on the probability mass that the target model assigns to tokens the draft model under-weighted: it is proportional to $\max(0, p(x) - q(x))$. By sampling from this residual distribution when a draft token is rejected, we compensate exactly for the draft model's under-sampling, ensuring the total probability of generating any token matches the target model exactly.
Quiz 4: Derive the formula for the expected number of generated tokens per iteration in Speculative Decoding. Let the number of draft tokens proposed be $K$, and the independent probability of the target model accepting each draft token be $\alpha$.
For an iteration where $K$ draft tokens are proposed, the output is the sequence of accepted tokens plus one token generated by the target model (either a replacement for the first rejected token, or an additional token if all $K$ are accepted).
Let $N$ be the number of accepted draft tokens; $N$ can take values from $0$ to $K$.
For $N = n < K$: $P(N = n) = \alpha^n (1 - \alpha)$. In this case, $n + 1$ tokens are generated.
For $N = K$: $P(N = K) = \alpha^K$. In this case, $K + 1$ tokens are generated.
The expected number of tokens generated is:

$$E[\#\text{tokens}] = \sum_{n=0}^{K-1} (n+1)\,\alpha^n (1-\alpha) + (K+1)\,\alpha^K.$$

Algebraic simplification of this geometric sum reduces precisely to:

$$E[\#\text{tokens}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}.$$
This mathematical formula demonstrates that increasing the draft length has diminishing returns if the acceptance rate is low.
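As a sanity check on the Quiz 4 derivation, a short Monte Carlo simulation of the accept/reject process converges to the closed form:

```python
import random

random.seed(0)
alpha, K, trials = 0.7, 4, 200_000

total = 0
for _ in range(trials):
    n = 0
    # Each draft token is accepted independently with probability alpha;
    # the first rejection ends the run
    while n < K and random.random() < alpha:
        n += 1
    # n accepted draft tokens plus one target-model token
    # (either the replacement or the free bonus token)
    total += n + 1

closed_form = (1 - alpha ** (K + 1)) / (1 - alpha)
print(f"simulated: {total / trials:.3f}, closed form: {closed_form:.3f}")
```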
References
- [1] Leviathan, Y., et al. (2023). Fast Inference from Transformers via Speculative Decoding. arXiv:2211.17192.
- [2] Chen, C., et al. (2023). Accelerating Large Language Model Decoding with Speculative Sampling. arXiv:2302.01318.