Foundation Model Engineering

20.6 Diffusion-based LLMs

For years, the standard approach to language modeling has been autoregressive (AR) generation, which predicts tokens strictly one at a time, from left to right. This has been extremely successful, but it has an obvious limitation: once the model commits to a token, it cannot go back and revise it in light of what comes later. It is like writing a complex essay on a typewriter with no eraser.

Enter Diffusion-based LLMs. Inspired by the success of diffusion models in image generation (like Imagen and Stable Diffusion), researchers have adapted diffusion processes to discrete token sequences. This shifts generation from step-by-step commitment to iterative refinement, more like a sculptor gradually chiseling a block of marble into a detailed statue.


1. The Basics: What is Diffusion?

Before diving into how diffusion applies to text, we must understand the core mechanics of diffusion models, which were originally designed for continuous data like images.

Diffusion models operate via a two-stage process:

  1. The Forward Process (Noise Addition): We take a clean data point (e.g., an image) and gradually add Gaussian noise over a series of $T$ timesteps until the data becomes indistinguishable from pure random noise. This process is fixed and requires no learning.
  2. The Reverse Process (Denoising): This is where the neural network learns. The model is trained to predict the noise added at each step and reverse the process, gradually restoring the clean data from pure noise.

Mathematically, the reverse step often involves predicting the score function (the gradient of the log-density of the data) or predicting the noise vector $\epsilon$ itself.
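The forward process described above can be sketched in a few lines. This is a minimal illustration, assuming the common DDPM-style closed form $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon$ and a linear beta schedule; the schedule endpoints and the names `make_alpha_bar` and `forward_noise` are illustrative choices, not taken from this chapter.

```python
import torch

def make_alpha_bar(num_steps: int = 1000,
                   beta_start: float = 1e-4,
                   beta_end: float = 0.02) -> torch.Tensor:
    """Cumulative signal fraction alpha_bar_t for a linear beta schedule (assumed values)."""
    betas = torch.linspace(beta_start, beta_end, num_steps)
    return torch.cumprod(1.0 - betas, dim=0)

def forward_noise(x0: torch.Tensor, t: int, alpha_bar: torch.Tensor):
    """Sample x_t from q(x_t | x_0) in one shot and return the noise that was used."""
    eps = torch.randn_like(x0)
    a = alpha_bar[t]
    x_t = a.sqrt() * x0 + (1.0 - a).sqrt() * eps
    return x_t, eps

torch.manual_seed(0)
alpha_bar = make_alpha_bar()
x0 = torch.ones(4, 8)                       # stand-in for a clean data point
x_early, _ = forward_noise(x0, 10, alpha_bar)     # still close to x0
x_late, eps = forward_noise(x0, 999, alpha_bar)   # almost pure noise

# A noise-prediction network would be trained to regress its output onto `eps`.
print(f"alpha_bar at t=10: {alpha_bar[10].item():.4f}, at t=999: {alpha_bar[999].item():.6f}")
```

Because the forward process has this closed form, training can sample any timestep directly instead of simulating every intermediate step.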


2. The Discrete Challenge: Adapting to Text

Adapting this continuous framework to text introduces a fundamental challenge: text is discrete. You cannot have “half a word” or “noisy text” in a continuous sense, and adding Gaussian noise to a token ID produces nothing meaningful. A word is either “cat” or it is not.

Engineers have developed two primary solutions to bridge this gap:

A. Continuous Relaxation (Embedding Diffusion)

One widely studied approach is to map discrete tokens into a continuous vector space (embeddings) and apply standard continuous diffusion in that space [1, 2].

  • Forward: Map text tokens to embeddings, then add continuous Gaussian noise to the embeddings.
  • Reverse: The model denoises the noisy embeddings. At the final step, a “rounding” operation maps the denoised continuous embeddings back to the nearest discrete tokens in the vocabulary.
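The “rounding” step can be illustrated with a nearest-neighbor lookup against the embedding table. This is a sketch under the assumption that rounding means picking the closest embedding row in Euclidean distance; the helper name `round_to_tokens` and the toy sizes are hypothetical, and published systems may instead learn the rounding step.

```python
import torch
import torch.nn as nn

def round_to_tokens(x: torch.Tensor, embedding: nn.Embedding) -> torch.Tensor:
    """Map each vector in x [batch, seq, dim] to the ID of its nearest embedding row."""
    table = embedding.weight                    # [vocab, dim]
    flat = x.reshape(-1, x.size(-1))            # [batch * seq, dim]
    dists = torch.cdist(flat, table)            # [batch * seq, vocab] Euclidean distances
    return dists.argmin(dim=-1).view(x.shape[:-1])

torch.manual_seed(0)
embed = nn.Embedding(100, 16)                   # toy vocabulary of 100 tokens
ids = torch.randint(0, 100, (2, 5))

with torch.no_grad():
    # Pretend the reverse process ended very close to the clean embeddings.
    almost_clean = embed(ids) + 0.01 * torch.randn(2, 5, 16)
    recovered = round_to_tokens(almost_clean, embed)

print("Recovered original tokens:", torch.equal(recovered, ids))
```

When the residual noise is small relative to the spacing between embeddings, rounding recovers the intended tokens; with larger residual noise, several positions can snap to the wrong token.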

B. Discrete/Categorical Diffusion

Instead of mapping to continuous space, these models operate directly on the discrete probability distribution of tokens. At each step, tokens have a small probability of transitioning to a random token or a special [MASK] token. The reverse network learns to predict the correct token from the corrupted sequence.
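Both corruption styles mentioned above can be sketched directly on token IDs. This is a minimal illustration that applies a single corruption level in one shot rather than stepping through a schedule; `MASK_ID`, `mask_corrupt`, and `uniform_corrupt` are hypothetical names chosen for the example.

```python
import torch

MASK_ID = 0  # hypothetical ID reserved for the [MASK] token

def mask_corrupt(ids: torch.Tensor, mask_prob: float) -> torch.Tensor:
    """Absorbing-state corruption: each token independently becomes [MASK]."""
    replace = torch.rand(ids.shape, device=ids.device) < mask_prob
    return torch.where(replace, torch.full_like(ids, MASK_ID), ids)

def uniform_corrupt(ids: torch.Tensor, vocab_size: int, replace_prob: float) -> torch.Tensor:
    """Uniform corruption: each token is independently resampled from the vocabulary."""
    replace = torch.rand(ids.shape, device=ids.device) < replace_prob
    random_ids = torch.randint(0, vocab_size, ids.shape, device=ids.device)
    return torch.where(replace, random_ids, ids)

torch.manual_seed(0)
ids = torch.randint(1, 1000, (1, 12))           # clean sequence (ID 0 is reserved for [MASK])
for p in (0.1, 0.5, 0.9):
    print(f"mask_prob={p}:", mask_corrupt(ids, p).tolist())

# The reverse network would be trained with cross-entropy to predict `ids`
# from the corrupted sequence.
```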


3. Scaling Diffusion-Style Models to Language

While small-scale diffusion language models (like Diffusion-LM [1]) demonstrated proof-of-concept, a major open question is whether diffusion-style architectures can scale to language settings where large autoregressive models currently dominate.

In this paradigm, the model does not generate text token-by-token. Instead, it produces every position of a fixed-length sequence in parallel and refines all of them together over several denoising steps.

Why do this? The Benefits of Global Planning

  1. Controllability: If you ask an AR model to “Write a poem where the 5th word is ‘blue’”, it must generate the first 4 words before it knows whether the whole sequence will fit that constraint naturally. A diffusion-style model can, in principle, refine the sequence around fixed constraints.
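  2. Non-Autoregressive Speedups: An AR model needs one sequential step per token, so a 1000-token output takes 1000 steps. A diffusion-style model updates every position in parallel and may need only 20–50 refinement steps regardless of length. Whether this is faster in wall-clock time depends on the cost per step: each refinement step processes the full sequence, whereas an AR step with a KV cache processes only one new token.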
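The “5th word is ‘blue’” example can be made concrete with a simple clamping scheme: after every refinement step, the constrained position is overwritten with the embedding of the required token. This is only one possible way to impose a hard constraint, sketched under stated assumptions; `BLUE_ID`, `toy_denoise_step`, and `clamp_constraints` are hypothetical, and the denoiser is a placeholder rather than a trained model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, dim, seq_len = 50, 16, 8
embed = nn.Embedding(vocab_size, dim)

BLUE_ID = 7                                     # hypothetical token ID for "blue"
fixed_ids = torch.zeros(1, seq_len, dtype=torch.long)
fixed_mask = torch.zeros(1, seq_len, dtype=torch.bool)
fixed_ids[0, 4] = BLUE_ID                       # the 5th word sits at index 4
fixed_mask[0, 4] = True

def toy_denoise_step(x: torch.Tensor) -> torch.Tensor:
    """Placeholder for one learned refinement step; here it only shrinks x."""
    return 0.9 * x

def clamp_constraints(x: torch.Tensor) -> torch.Tensor:
    """Overwrite constrained positions with the embedding of their required token."""
    target = embed(fixed_ids)
    return torch.where(fixed_mask.unsqueeze(-1), target, x)

with torch.no_grad():
    x = torch.randn(1, seq_len, dim)            # start from noise
    for _ in range(20):
        x = toy_denoise_step(x)
        x = clamp_constraints(x)                # re-impose the constraint every step
    # Round each position to its nearest embedding row.
    token_ids = torch.cdist(x[0], embed.weight).argmin(dim=-1)

print("Token at position 5:", token_ids[4].item())
```

With a real bidirectional denoiser, the unconstrained positions would see the clamped position at every step and could adapt to it, which is the property an AR model lacks when the constraint lies to the right of the tokens it is currently generating.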

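Below is a simplified PyTorch sketch demonstrating the core concept of continuous relaxation diffusion: mapping tokens to embeddings, adding noise, and a mock denoising step. It omits timestep conditioning and the full reverse loop, so it illustrates the data flow rather than a working generator.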

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContinuousDiffusionLM(nn.Module):
    def __init__(self, vocab_size=30000, embed_dim=768, seq_len=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.seq_len = seq_len
        
        # A simple Transformer backbone to predict the noise
        self.backbone = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8, batch_first=True),
            num_layers=6
        )
        
        # Projection back to vocabulary for rounding
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def add_noise(self, embeddings, timestep, total_steps=1000):
        """
        Simulates the forward process: adds Gaussian noise based on the timestep.
        """
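        # Toy schedule: `beta` is the cumulative noise level at this timestep
        # (the role 1 - alpha_bar_t plays in DDPM), rising linearly from 0 to 1.
        # as_tensor lets `timestep` be a Python number or a tensor.
        beta = torch.as_tensor(timestep / total_steps, dtype=embeddings.dtype, device=embeddings.device)
        noise = torch.randn_like(embeddings)
        # x_t = sqrt(1-beta)*x_0 + sqrt(beta)*noise
        noisy_embeddings = torch.sqrt(1 - beta) * embeddings + torch.sqrt(beta) * noise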
        return noisy_embeddings, noise

    def forward(self, input_ids, timestep):
        # 1. Map discrete tokens to continuous embeddings
        x = self.embed(input_ids) # Shape: [Batch, Seq_Len, Embed_Dim]
        
        # 2. Add noise (simulating a specific step in the forward process)
        noisy_x, true_noise = self.add_noise(x, timestep)
        
        # 3. Predict the noise using the backbone
        # In reality, we would also condition the backbone on the timestep
        predicted_noise = self.backbone(noisy_x)
        
        # 4. Calculate loss between true noise and predicted noise
        loss = F.mse_loss(predicted_noise, true_noise)
        
        return loss

    def generate(self, batch_size, device='cpu'):
        """
        Mock reverse process: starts from pure noise and denoises (simplified).
        """
        self.eval()
        with torch.no_grad():
            # Start with pure random noise in the embedding space
            x = torch.randn(batch_size, self.seq_len, self.embed.embedding_dim, device=device)
            
            # In a real model, we would loop from T down to 0
            # and update x based on the predicted noise.
            # Here we simulate the final "rounding" step.
            
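            # Rounding: project back to vocabulary logits.
            # Note: lm_head receives no gradient from the noise-prediction loss in
            # forward(), so a real model needs an additional rounding objective.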
            logits = self.lm_head(x) # Shape: [Batch, Seq_Len, Vocab_Size]
            predicted_ids = torch.argmax(logits, dim=-1)
            
            return predicted_ids

# Example usage
model = ContinuousDiffusionLM()
dummy_tokens = torch.randint(0, 30000, (2, 128)) # Batch size 2, seq len 128
timestep = torch.tensor(500.0) # Middle of the diffusion process

loss = model(dummy_tokens, timestep)
print(f"Training Loss: {loss.item():.4f}")

generated = model.generate(batch_size=1)
print(f"Generated token IDs shape: {generated.shape}")
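The listing above leaves the reverse loop as a comment. One way to fill it in is a deterministic DDIM-style update: estimate the clean embeddings from the predicted noise, then re-noise that estimate to the next, lower noise level. The sketch below is illustrative only. It assumes a standard cumulative schedule $\bar{\alpha}_t$ rather than the toy linear schedule used in the listing, and `ToyEpsModel` is a hypothetical untrained stand-in for a noise predictor, so the output is not meaningful text.

```python
import torch
import torch.nn as nn

class ToyEpsModel(nn.Module):
    """Hypothetical stand-in for a trained noise predictor eps_theta(x_t, t)."""
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # A real model would also condition on t (e.g., via a timestep embedding).
        return torch.tanh(self.proj(x))

def ddim_style_sample(eps_model, shape, alpha_bar, num_steps: int = 50) -> torch.Tensor:
    """Deterministic reverse loop over a strided subset of timesteps."""
    total = alpha_bar.numel()
    timesteps = torch.linspace(total - 1, 0, num_steps).long()   # noisiest to cleanest
    x = torch.randn(shape)                                       # start from pure noise
    for i, t in enumerate(timesteps):
        a_t = alpha_bar[t]
        eps = eps_model(x, t)
        # Estimate of the clean embeddings implied by the predicted noise.
        x0_hat = (x - (1 - a_t).sqrt() * eps) / a_t.sqrt()
        if i + 1 < len(timesteps):
            a_prev = alpha_bar[timesteps[i + 1]]
            # Move to the next (lower) noise level, reusing the same eps.
            x = a_prev.sqrt() * x0_hat + (1 - a_prev).sqrt() * eps
        else:
            x = x0_hat
    return x

torch.manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 1000)          # assumed linear beta schedule
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

with torch.no_grad():
    sample = ddim_style_sample(ToyEpsModel(32), (2, 16, 32), alpha_bar)
print(sample.shape)
```

The returned tensor lives in embedding space; a rounding step would then map each position to a token ID.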

Summary & Next Steps

Diffusion-based language modeling represents an important alternative to left-to-right generation. It is especially attractive when global constraints and iterative refinement matter more than token-by-token commitment. At the same time, it remains an emerging direction rather than a settled replacement for autoregressive LLMs.

In the next section, we will explore how these advanced modeling techniques contribute to the ultimate goal: the path to AGI and the creation of comprehensive World Models.


Quizzes

Quiz 1: What is the primary advantage of Continuous Relaxation (Embedding Diffusion) over Autoregressive generation for text?
Answer: Continuous relaxation allows for global planning. Because the model operates on the entire sequence simultaneously and refines it iteratively, it can naturally handle complex global constraints and bidirectional context that are difficult for left-to-right autoregressive models.

Quiz 2: What is the “rounding” operation in Continuous Relaxation diffusion, and why is it necessary?
Answer: The rounding operation maps the continuous, denoised embeddings back to the nearest discrete tokens in the vocabulary. It is necessary because diffusion operates in a continuous vector space, but text output must be composed of discrete words or tokens.

Quiz 3: Why can inference be slower in Diffusion-based LLMs than in Autoregressive models for long sequences?
Answer: An autoregressive model generates one token per forward pass, and with a KV cache each pass only has to process the newest token. A diffusion model updates the whole sequence at once but needs multiple forward passes (denoising steps, e.g., 20-50) to refine noise into coherent text, and each pass reprocesses every position. For long sequences, the total computation can therefore exceed that of cached autoregressive decoding, even though the number of sequential steps is smaller.

Quiz 4: In discrete/categorical diffusion for text, the forward process is modeled using a categorical transition matrix. Mathematically formulate the transition matrix $Q_t$ for a uniform diffusion schedule over a vocabulary of size $V$, and derive the probability of a token $x_t$ given the original token $x_0$.
Answer: In categorical diffusion, the transition from $x_{t-1}$ to $x_t$ is defined by a matrix $Q_t \in \mathbb{R}^{V \times V}$, where $Q_t[i, j] = q(x_t = j | x_{t-1} = i)$. For a uniform diffusion schedule where a token is kept with probability $1 - \beta_t$ and resampled uniformly from the whole vocabulary (possibly landing on the same token) with probability $\beta_t$, the transition matrix is formulated as:

$$Q_t = (1 - \beta_t) I + \beta_t \frac{1}{V} \mathbf{1}\mathbf{1}^T$$

To find the multi-step transition probability $q(x_t | x_0)$, we multiply the matrices sequentially. Because $\frac{1}{V} \mathbf{1}\mathbf{1}^T$ is idempotent, every product of such matrices keeps the same form. Using the cumulative parameter $\bar{\alpha}_t = \prod_{s=1}^t (1 - \beta_s)$, the cumulative transition matrix $\bar{Q}_t = Q_1 Q_2 \cdots Q_t$ becomes:

$$\bar{Q}_t = \bar{\alpha}_t I + (1 - \bar{\alpha}_t) \frac{1}{V} \mathbf{1}\mathbf{1}^T$$

Thus, the transition probability directly from $x_0$ is:

$$q(x_t | x_0) = \bar{\alpha}_t x_0 + (1 - \bar{\alpha}_t) \frac{1}{V} \mathbf{1}$$

where $x_0$ is represented as a one-hot vector.
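The closed form in Quiz 4 can be checked numerically by multiplying the per-step matrices and comparing the result with the formula. The vocabulary size, number of steps, and beta values below are arbitrary choices made for the check.

```python
import torch

V, T = 5, 10
betas = torch.linspace(0.05, 0.5, T, dtype=torch.float64)   # arbitrary example schedule
I = torch.eye(V, dtype=torch.float64)
U = torch.full((V, V), 1.0 / V, dtype=torch.float64)        # (1/V) * 1 1^T

# Multiply the per-step matrices Q_1 Q_2 ... Q_T explicitly.
Q_bar_product = I.clone()
for beta in betas:
    Q_t = (1 - beta) * I + beta * U
    Q_bar_product = Q_bar_product @ Q_t

# Closed form using alpha_bar_T = prod_s (1 - beta_s).
alpha_bar = torch.prod(1 - betas)
Q_bar_closed = alpha_bar * I + (1 - alpha_bar) * U

print("Closed form matches product:", torch.allclose(Q_bar_product, Q_bar_closed))

# Row i of Q_bar is the distribution q(x_T | x_0 = i); each row sums to 1.
print("Row sums:", Q_bar_closed.sum(dim=1).tolist())
```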


References

  1. Li, X., et al. (2022). Diffusion-LM Improves Controllable Text Generation. arXiv:2205.14217.
  2. Gulrajani, I., & Tatsuno, K. (2023). Continuous Diffusion for Categorical Data. arXiv:2306.06546.