Foundation Model Engineering

9.1 Supervised Fine-Tuning (SFT) Fundamentals

After the massive pre-training phase, a Foundation Model is a master of language, possessing vast world knowledge and complex syntactic understanding. However, it is fundamentally still just a next-token predictor. If you prompt a raw base model with “How do I bake a cake?”, it might not answer you; instead, it might continue with “How do I bake a pie? How do I bake bread?”, mimicking the structure of a list it saw on the web.

To transform a raw text predictor into a helpful, conversational assistant, we apply Supervised Fine-Tuning (SFT). This is the first step of the post-training pipeline popularized by OpenAI's InstructGPT [2], in which we teach the model to follow instructions and adopt a specific persona.


The Analogy: The Scholar vs. The Intern

Think of a pre-trained base model as a brilliant scholar who has read every book in the world library. They know facts about everything, but they have no social skills and don’t know how to answer questions directly. If you ask them a question, they might just start reciting a related poem.

Supervised Fine-Tuning is like taking this scholar and putting them through a training program to become a helpful intern. We show them thousands of examples of questions and ideal answers (“When a user asks X, you should reply with Y”). The scholar doesn’t learn new facts (they already know everything), but they learn the format of how to be a helpful assistant.


The Mechanics of SFT

Mathematically, SFT is just continued pre-training, but on a highly curated distribution of data. However, there is a crucial difference in how we compute the loss.

During pre-training, we compute the loss on every token in the sequence. In SFT, we typically only compute the loss on the response tokens, not the instruction tokens. We want the model to learn to generate the answer, not to predict the user’s question.
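A minimal PyTorch sketch of that difference follows. It assumes random logits as a stand-in for a real model's output, and the helper next_token_loss is hypothetical. The objective is identical in both cases. Only the mask that decides which positions enter the average changes.

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits, tokens, loss_mask):
    """Mean next-token cross-entropy over positions where loss_mask is True.

    logits:    (T, V) scores; row t is the prediction for token t+1
    tokens:    (T,) token ids of the full sequence (prompt, then response)
    loss_mask: (T,) bool, True for tokens whose prediction is trained
    """
    # Position t predicts token t+1: drop the last logit row and the first token
    per_token = F.cross_entropy(logits[:-1], tokens[1:], reduction="none")
    mask = loss_mask[1:]
    return per_token[mask].mean()

torch.manual_seed(0)
vocab, prompt_len, response_len = 50, 6, 4
T = prompt_len + response_len
tokens = torch.randint(0, vocab, (T,))
logits = torch.randn(T, vocab)  # stand-in for model(tokens).logits

all_tokens = torch.ones(T, dtype=torch.bool)    # pre-training: every token counts
response_only = torch.arange(T) >= prompt_len   # SFT: only response tokens count

pretrain_loss = next_token_loss(logits, tokens, all_tokens)
sft_loss = next_token_loss(logits, tokens, response_only)
print(f"pre-training loss: {pretrain_loss.item():.3f}, SFT loss: {sft_loss.item():.3f}")
```

The first response token is still trained, because it is predicted from the last prompt position. The mask only removes the prompt tokens themselves as prediction targets.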

The Masked Loss Function

Let the input sequence be a concatenation of the prompt $X$ and the response $Y$. The sequence of tokens is $x_1, x_2, \dots, x_t$ (prompt) followed by $y_1, y_2, \dots, y_m$ (response).

The SFT loss is the standard autoregressive cross-entropy loss, but masked so that we only sum over the response tokens:

$$\mathcal{L}_{SFT} = - \sum_{i=1}^{m} \log P_\theta(y_i \mid y_{<i}, X)$$

By excluding the prompt tokens $X$ from the loss, we keep the model from spending capacity on reproducing the specific phrasing of the questions in our training set and focus training on generating high-quality answers.
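To check that a -100 label mask computes exactly this sum, the sketch below uses random logits as a stand-in for a model, with arbitrary sizes. It reads $\log P_\theta(y_i \mid y_{<i}, X)$ from the log-softmax at the position just before each $y_i$, sums the negatives, and compares the result with F.cross_entropy using ignore_index=-100 and reduction="sum".

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, t, m = 32, 5, 3                     # t prompt tokens, m response tokens
prompt = torch.randint(0, vocab, (t,))     # x_1 ... x_t
response = torch.randint(0, vocab, (m,))   # y_1 ... y_m
seq = torch.cat([prompt, response])
logits = torch.randn(t + m, vocab)         # stand-in for model(seq).logits
log_probs = F.log_softmax(logits, dim=-1)

# L_SFT = -sum_i log P(y_i | y_<i, X). y_i sits at 0-based index t+i-1,
# so its probability is read from the logits one position earlier (t+i-2).
manual = -sum(log_probs[t + i - 2, response[i - 1]] for i in range(1, m + 1))

# Same quantity via masked cross-entropy: prompt labels are -100
labels = torch.cat([torch.full((t,), -100, dtype=torch.long), response])
masked = F.cross_entropy(logits[:-1], labels[1:], ignore_index=-100, reduction="sum")
print(manual.item(), masked.item())
```

For a single sequence, PyTorch's default reduction="mean" divides this sum by the number of unmasked targets, $m$. Most training loops therefore optimize $\mathcal{L}_{SFT}/m$, which has the same gradient direction.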


The Superficial Alignment Hypothesis

A critical concept in understanding SFT is the Superficial Alignment Hypothesis [1]. It posits that a model’s knowledge and capabilities are almost entirely learned during pre-training, while alignment (SFT) teaches the model which sub-distribution of formats to use when interacting with users.

In other words, SFT doesn't make the model smarter; it teaches the model to act like a helpful assistant. This implies that we do not need millions of examples for SFT. A small set of high-quality, diverse instruction-response pairs can be enough to align the model, as the LIMA paper [1] demonstrated by fine-tuning a 65B-parameter LLaMA model on only 1,000 carefully curated examples.


Chat Templates and Conversation Structure

In production SFT, we don't just concatenate raw strings. We use structured formats to distinguish between the system prompt, the user, and the assistant. Common choices include ChatML and the model-specific Jinja chat templates that Hugging Face tokenizers store as a chat_template and apply with tokenizer.apply_chat_template().

A typical training sequence looks like this:

<|im_start|>system
You are a helpful AI assistant.<|im_end|>
<|im_start|>user
How do I bake a cake?<|im_end|>
<|im_start|>assistant
To bake a cake, follow these steps...<|im_end|>

Special tokens like <|im_start|> and <|im_end|> (or <|begin_of_text|>, <|start_header_id|>, etc. in Llama 3) are added to the vocabulary. During SFT, we must ensure these special tokens are handled correctly and that loss is only computed on the assistant’s response, including its closing tag.
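The masking rule can be sketched for the ChatML layout shown above. The following is a minimal illustration, not a production template. build_chatml_example is a hypothetical helper, and the character-level toy tokenizer stands in for a real one that would map <|im_start|> and <|im_end|> to single special-token ids. Each segment is tokenized separately to keep the boundaries exact, which can differ slightly from tokenizing the whole string at once.

```python
from typing import Callable, List, Tuple

IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss

def build_chatml_example(
    messages: List[dict], tokenize: Callable[[str], List[int]]
) -> Tuple[List[int], List[int]]:
    """Render messages in ChatML and return (input_ids, labels).

    Only assistant content plus its closing <|im_end|> is supervised.
    Role headers, system/user turns, and newline separators get -100.
    """
    input_ids: List[int] = []
    labels: List[int] = []
    for msg in messages:
        header = f"<|im_start|>{msg['role']}\n"
        body = f"{msg['content']}<|im_end|>"
        segments = ((header, False), (body, msg["role"] == "assistant"), ("\n", False))
        for text, train_on in segments:
            ids = tokenize(text)
            input_ids += ids
            labels += ids if train_on else [IGNORE_INDEX] * len(ids)
    return input_ids, labels

def toy_tokenize(text: str) -> List[int]:
    # Hypothetical character-level tokenizer, for illustration only
    return [ord(c) for c in text]

messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "How do I bake a cake?"},
    {"role": "assistant", "content": "To bake a cake, follow these steps..."},
]
ids, labels = build_chatml_example(messages, toy_tokenize)
supervised = "".join(chr(t) for t in labels if t != IGNORE_INDEX)
print(supervised)  # the assistant reply followed by <|im_end|>
```

The closing <|im_end|> stays in the supervised span so the model learns when to end its turn.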


Engineering the SFT Loop (PyTorch)

Below is a simplified PyTorch implementation showing how to mask prompt tokens out of the loss. It uses the same -100 labeling convention as frameworks such as Hugging Face's TRL (Transformer Reinforcement Learning), but the class is a standalone illustration, not TRL's own SFTTrainer.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SFTTrainer:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100) # -100 is standard ignore index

    def train_step(self, prompt, response, optimizer):
        self.model.train()
        optimizer.zero_grad()

        # 1. Tokenize prompt and response separately so the boundary is known
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        response_ids = self.tokenizer.encode(response, add_special_tokens=False)
        # Append EOS so the model learns when to stop generating
        # (assumes the tokenizer defines eos_token_id)
        if self.tokenizer.eos_token_id is not None:
            response_ids = response_ids + [self.tokenizer.eos_token_id]

        # 2. Concatenate and create labels
        # Labels are aligned with input_ids; prompt positions get -100 (ignored)
        device = self.model.device
        input_ids = torch.tensor([prompt_ids + response_ids], device=device)
        labels = torch.tensor([[-100] * len(prompt_ids) + response_ids], device=device)

        # 3. Forward pass
        outputs = self.model(input_ids)
        logits = outputs.logits

        # 4. Compute masked loss
        # Shift logits and labels for autoregressive prediction
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        
        loss = self.criterion(
            shift_logits.view(-1, shift_logits.size(-1)), 
            shift_labels.view(-1)
        )

        # 5. Backward pass
        loss.backward()
        optimizer.step()

        return loss.item()
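
The trainer above processes one example at a time. To batch examples, sequences are padded to a common length. The following minimal collate sketch assumes right padding and a caller-supplied pad_id (for example, the tokenizer's pad token). collate_sft is a hypothetical name. It pads input_ids with pad_id, pads labels with -100 so padded positions never enter the loss, and builds an attention_mask so the model ignores them.

```python
from typing import Dict, List

import torch

IGNORE_INDEX = -100

def collate_sft(batch: List[Dict[str, List[int]]], pad_id: int) -> Dict[str, torch.Tensor]:
    """Right-pad a list of {'input_ids', 'labels'} examples into batch tensors."""
    max_len = max(len(ex["input_ids"]) for ex in batch)
    input_ids, labels, attention_mask = [], [], []
    for ex in batch:
        pad = max_len - len(ex["input_ids"])
        input_ids.append(ex["input_ids"] + [pad_id] * pad)
        # Padded positions get -100 so CrossEntropyLoss(ignore_index=-100) skips them
        labels.append(ex["labels"] + [IGNORE_INDEX] * pad)
        # 1 = real token, 0 = padding
        attention_mask.append([1] * len(ex["input_ids"]) + [0] * pad)
    return {
        "input_ids": torch.tensor(input_ids),
        "labels": torch.tensor(labels),
        "attention_mask": torch.tensor(attention_mask),
    }

# Two toy examples whose prompt positions already hold -100, as in train_step
batch = [
    {"input_ids": [5, 6, 7, 8], "labels": [-100, -100, 7, 8]},
    {"input_ids": [9, 10], "labels": [-100, 10]},
]
out = collate_sft(batch, pad_id=0)
print(out["labels"])
```

Because the loss is driven by labels rather than input_ids, the same pad id can safely appear in input_ids (even when it equals the EOS id) without being trained on as padding.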

Quizzes

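Quiz 1: How much activation-gradient memory does prompt masking save in a causal LLM fine-tuning setup? Suppose the prompt has $T_p$ tokens, the response has $T_r$ tokens, and standard activation memory per layer is $M_{act} \times (T_p + T_r)$. A tempting answer is that masking skips the prompt's gradients and saves $T_p \times M_{act}$ per layer, a fraction $\frac{T_p}{T_p + T_r}$ (75% for $T_p = 3000$, $T_r = 1000$). That answer is wrong. Masking removes the prompt positions from the loss, not from the backward pass. In a causal model, every response token attends to the prompt's keys and values at every layer, so gradients still flow into all $T_p$ prompt positions and their activations must still be stored. Per-layer activation memory stays at $M_{act} \times (T_p + T_r)$, and the saving in the transformer layers is essentially zero. The only real saving is at the output head, and only if the implementation computes logits for the response-predicting positions alone: the logit tensor then shrinks from $(T_p + T_r) \times |V|$ to $T_r \times |V|$, a $\frac{T_p}{T_p + T_r} = 75\%$ reduction in this example. Implementations that compute logits everywhere and merely ignore prompt labels (ignore_index=-100) save no memory at all.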
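A tiny autograd experiment illustrates why. It assumes a single randomly initialized causal self-attention layer and uses a squared-norm stand-in for the cross-entropy on response positions. Even though only response positions enter the loss, the prompt embeddings receive nonzero gradients because the response positions attend to them.

```python
import torch

torch.manual_seed(0)
T_p, T_r, d = 6, 2, 8                         # prompt length, response length, hidden size
T = T_p + T_r
x = torch.randn(T, d, requires_grad=True)     # token activations, prompt first
W_q, W_k, W_v = (torch.randn(d, d) / d**0.5 for _ in range(3))

# One causal self-attention layer
q, k, v = x @ W_q, x @ W_k, x @ W_v
scores = (q @ k.T) / d**0.5
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))  # block attention to future tokens
h = torch.softmax(scores, dim=-1) @ v

# "Masked" loss: only response positions contribute (stand-in for cross-entropy)
loss = h[T_p:].pow(2).sum()
loss.backward()

prompt_grad_norm = x.grad[:T_p].norm().item()
print(f"gradient norm on prompt activations: {prompt_grad_norm:.4f}")
```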

Quiz 2: A common observation is that SFT models can lose some of their zero-shot capabilities on tasks not represented in the SFT dataset (Alignment Tax). How does the "Superficial Alignment Hypothesis" explain this? The Superficial Alignment Hypothesis states that SFT does not teach the model new capabilities, but merely teaches it a new sub-distribution of formats (e.g., how to be an assistant). If the SFT dataset is narrow or poorly formatted, the model may learn to constrain its outputs too narrowly, effectively hiding or suppressing the vast world knowledge it acquired during pre-training.

Quiz 3: From a systems perspective, why is SFT on a full 70B parameter model computationally expensive even if the dataset is small (e.g., only 1,000 examples)? Because Full Fine-Tuning requires updating all parameters. The system must store not only the model weights but also the optimizer states (e.g., Adam’s moments) and gradients for all 70B parameters, which requires massive VRAM (over 1TB). The small dataset size reduces the training time (fewer steps), but it does not reduce the peak memory required per step.
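The "over 1TB" figure can be reproduced with back-of-the-envelope arithmetic. This sketch assumes the common mixed-precision Adam accounting of 16 bytes per parameter: bf16 weights (2) + bf16 gradients (2) + fp32 master weights (4) + Adam first moment (4) + Adam second moment (4). It excludes activations, and full_ft_memory_gb is a hypothetical helper.

```python
def full_ft_memory_gb(n_params: float, bytes_per_param: int = 16) -> float:
    """Rough static memory (GB) for full fine-tuning with mixed-precision Adam.

    Assumed 16 bytes/param: bf16 weights (2) + bf16 grads (2)
    + fp32 master weights (4) + Adam m (4) + Adam v (4). Activations excluded.
    """
    return n_params * bytes_per_param / 1e9

mem = full_ft_memory_gb(70e9)
print(f"{mem:.0f} GB")  # prints "1120 GB" (before activations)
```

This total is independent of dataset size: 1,000 examples or 1,000,000, every optimizer step touches the same weights, gradients, and optimizer states.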


References

  1. Zhou, C., et al. (2023). LIMA: Less Is More for Alignment. arXiv:2305.11206.
  2. Ouyang, L., et al. (2022). Training language models to follow instructions with human feedback. arXiv:2203.02155.