9.1 Supervised Fine-Tuning (SFT) Fundamentals
거대한 사전 학습(Pre-training) 단계를 거친 파운데이션 모델은 방대한 세계 지식과 복잡한 통사적 이해력을 갖춘 언어의 대가입니다. 하지만 본질적으로 이 모델은 여전히 다음 토큰을 예측하는 모델일 뿐입니다. 가공되지 않은 베이스 모델에 “어떻게 케이크를 굽나요?”라고 프롬프트를 입력하면, 질문에 답하는 대신 “어떻게 파이를 굽나요? 어떻게 빵을 굽나요?”와 같이 웹에서 본 목록의 구조를 모방하여 계속해서 질문을 나열할 수 있습니다.
가공되지 않은 텍스트 예측기를 유용하고 대화식인 어시스턴트로 전환하려면 Supervised Fine-Tuning (SFT) 을 거쳐야 합니다. 이는 사후 학습 파이프라인의 첫 번째 단계로, 모델에게 지시를 따르고 특정 페르소나를 채택하는 방법을 가르칩니다.
비유: 학자와 인턴
사전 학습된 베이스 모델을 세계의 모든 책을 읽은 똑똑한 학자로 생각해보세요. 그들은 모든 것에 대한 사실을 알고 있지만 사회적 기술이 없으며 질문에 직접 답하는 방법을 모릅니다. 그들에게 질문을 던지면 관련 시를 읊기 시작할 수도 있습니다.
Supervised Fine-Tuning은 이 학자를 데려다가 유용한 인턴이 되도록 교육 프로그램에 참여시키는 것과 같습니다. 우리는 그들에게 질문과 이상적인 답변의 수천 가지 예시(“사용자가 X를 물으면 Y라고 답해야 한다”)를 보여줍니다. 학자는 새로운 사실을 배우지 않지만(그들은 이미 모든 것을 알고 있습니다), 유용한 어시스턴트가 되는 형식을 배웁니다.
SFT의 메커니즘
수학적으로 SFT는 단순히 고도로 큐레이션된 데이터 분포에서 계속해서 사전 학습을 하는 것입니다. 그러나 손실을 계산하는 방식에는 결정적인 차이가 있습니다.
사전 학습 중에는 시퀀스의 모든 토큰에 대해 손실을 계산합니다. 반면 SFT에서는 일반적으로 모델이 프롬프트 토큰이 아닌 응답 토큰 에 대해서만 손실을 계산하도록 합니다. 우리는 모델이 사용자의 질문을 예측하는 것이 아니라 답변을 생성하는 법을 배우기를 원하기 때문입니다.
마스킹된 손실 함수 (Masked Loss Function)
입력 시퀀스를 프롬프트 와 응답 의 연결로 나타내겠습니다. 토큰 시퀀스는 (프롬프트)에 이어 (응답)이 됩니다.
SFT 손실은 표준 자기회귀 교차 엔트로피(Cross-Entropy) 손실이지만 응답 토큰에 대해서만 합산하도록 마스킹됩니다.
프롬프트 토큰 를 손실에서 제외함으로써 모델이 훈련 세트에 있는 질문의 특정 문구를 “정답처럼” 예측하도록 학습되는 것을 방지하고, 고품질 답변을 생성하는 데 학습 신호를 집중할 수 있습니다. 단, 표준 causal LM 학습에서는 프롬프트 토큰도 순전파와 역전파의 문맥으로 사용됩니다. labels=-100은 손실 항을 제거할 뿐, 그 토큰들의 활성화 메모리를 자동으로 없애 주지는 않습니다.
표면적 정렬 가설 (The Superficial Alignment Hypothesis)
SFT를 이해하는 데 있어 중요한 개념은 표면적 정렬 가설 (Superficial Alignment Hypothesis) 입니다 [1]. 이 가설은 모델의 지식과 능력은 거의 전적으로 사전 학습(Pre-training) 중에 습득되는 반면, 정렬(SFT)은 모델이 사용자와 상호작용할 때 어떤 형식의 *하위 분포(sub-distribution)*를 사용해야 하는지 가르칠 뿐이라고 가정합니다.
다시 말해, SFT는 모델을 더 똑똑하게 만드는 것이 아니라, 단지 유용한 어시스턴트처럼 행동하도록 가르칠 뿐입니다. 이는 SFT를 위해 수백만 개의 예시가 필요하지 않음을 시사합니다. LIMA 논문 [1]에서 보여주었듯이, 고품질의 다양한 인스트럭션-응답 쌍으로 구성된 소규모 데이터셋만으로도 모델을 정렬하기에 충분합니다.
챗 템플릿 및 대화 구조 (Chat Templates and Conversation Structure)
프로덕션 환경의 SFT에서는 단순히 원시 문자열을 이어 붙이지 않습니다. 사용자(User), 어시스턴트(Assistant), 그리고 시스템 프롬프트(System prompt)를 구분하기 위해 구조화된 형식을 사용합니다. 업계 표준은 ChatML 이나 Hugging Face의 특정 Jinja 템플릿과 같은 형식으로 이동하고 있습니다.
일반적인 학습 시퀀스는 다음과 같습니다:
<|im_start|>system
You are a helpful AI assistant.<|im_end|>
<|im_start|>user
How do I bake a cake?<|im_end|>
<|im_start|>assistant
To bake a cake, follow these steps...<|im_end|>
<|im_start|> 및 <|im_end|> (또는 Llama 3의 <|begin_of_text|>, <|start_header_id|> 등)와 같은 특수 토큰이 어휘 사전에 추가됩니다. SFT 동안에는 이러한 특수 토큰이 올바르게 처리되는지, 그리고 손실(Loss)이 종료 태그를 포함하여 오직 어시스턴트의 응답에 대해서만 계산되는지 확인해야 합니다.
SFT 데이터셋을 설계하는 법
SFT의 성패는 학습 알고리즘보다 데이터셋 설계에서 더 자주 갈립니다. 좋은 SFT 데이터는 “정답이 긴 데이터”가 아니라, 모델이 제품에서 보여야 할 행동을 압축해서 보여주는 데이터입니다.
실무에서는 다음 축을 분리해서 관리하는 것이 좋습니다.
| 데이터 축 | 왜 중요한가 | 흔한 실패 |
|---|---|---|
| instruction following | 사용자의 명령을 정확히 따르는 능력 | 질문 일부만 답하거나, 요구한 형식을 무시함 |
| domain style | 회사 문서, 고객지원, 법률, 의료, 코드 리뷰 등 도메인별 말투 | 일반 챗봇 말투로 답해 전문성이 떨어짐 |
| refusal / safety | 답하면 안 되는 요청을 짧고 명확하게 거부 | 과잉 거부 또는 장황한 설교체 |
| tool-call format | JSON schema, function name, argument type 준수 | 유효하지 않은 JSON, 없는 tool 호출 |
| multi-turn repair | 사용자가 정정했을 때 이전 답을 고치는 능력 | 이전 실수를 고집하거나 문맥을 잃음 |
| concision / verbosity | 제품에 맞는 답변 길이 | 모든 답을 블로그 글처럼 길게 씀 |
특히 enterprise assistant는 “친절한 범용 답변”보다 “정확한 형식과 경계”가 더 중요할 때가 많습니다. 예를 들어 비용 정산 봇은 재미있는 설명보다 금액, 근거 문서, 승인 상태를 정확한 JSON으로 반환하는 편이 훨씬 가치 있습니다.
Packing, EOS, Loss Masking의 함정
SFT를 돌릴 때 데이터 로더 최적화를 위해 여러 짧은 대화를 하나의 긴 시퀀스로 packing하는 경우가 많습니다. 이때 세 가지 실수가 자주 나옵니다.
- 대화 사이 attention 누수: 서로 다른 샘플을 한 시퀀스에 붙였는데 attention mask를 막지 않으면, 두 번째 대화가 첫 번째 대화 내용을 문맥으로 볼 수 있습니다.
- EOS 누락: assistant 응답 끝에 EOS나 turn boundary를 제대로 학습시키지 않으면 모델이 답변을 멈추지 못하거나 다음 사용자 발화까지 생성합니다.
- 마스킹 범위 오류: system/user 토큰까지 loss에 포함하면 모델이 사용자 질문을 흉내 내는 방향으로 학습됩니다. 반대로 assistant의 종료 태그까지 모두 마스킹하면 멈추는 법을 덜 배웁니다.
간단한 검증 방법은 학습 배치 하나를 사람이 읽을 수 있는 텍스트로 디코딩하고, loss가 걸린 토큰만 색칠해 보는 것입니다. SFT 버그는 loss curve보다 이 시각화에서 더 빨리 잡히는 경우가 많습니다.
SFT 이후 바로 봐야 할 평가
SFT가 끝났다고 바로 RLHF/DPO로 넘어가면 안 됩니다. 먼저 아래 지표를 확인해야 합니다.
- format exact match: JSON, Markdown table, function call처럼 구조가 있는 출력이 정확한가.
- instruction coverage: 여러 조건이 있는 요청에서 빠뜨린 조건이 없는가.
- refusal calibration: 위험한 요청은 거부하되, 정상적인 보안/의학/법률 설명까지 막지는 않는가.
- verbosity drift: 답변 길이가 제품 요구에 맞는가.
- base capability retention: 코딩, 수학, 번역, 요약 같은 기본 능력이 SFT 후 떨어지지 않았는가.
- tool-call dry run: 실제 tool 실행 전 schema validation과 mock execution을 통과하는가.
SFT는 모델에게 제품의 “말하는 방식”을 입히는 단계입니다. 이 단계에서 포맷과 경계를 제대로 잡아두면 이후 preference tuning이 훨씬 안정적입니다.
SFT 루프 엔지니어링 (PyTorch)
아래는 손실 계산 중에 프롬프트 토큰을 무시하기 위해 마스크를 적용하는 방법을 보여주는 현실적인 PyTorch 구현입니다. 이는 Hugging Face의 TRL (Transformer Reinforcement Learning)과 같은 프레임워크에서 표준적으로 사용되는 방식입니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SFTTrainer:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
self.criterion = nn.CrossEntropyLoss(ignore_index=-100) # -100은 표준 무시 인덱스입니다
def train_step(self, prompt, response, optimizer):
self.model.train()
optimizer.zero_grad()
# 1. 프롬프트와 응답 토큰화
prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
response_ids = self.tokenizer.encode(response, add_special_tokens=False)
# 2. 연결 및 레이블 생성
# 우리는 응답을 예측하기를 원하므로 프롬프트에 대한 레이블은 무시됩니다 (-100)
input_ids = torch.tensor([prompt_ids + response_ids]).to(self.model.device)
# 레이블: 프롬프트 토큰 무시, 응답 토큰 유지
labels = torch.tensor([[-100] * len(prompt_ids) + response_ids]).to(self.model.device)
# 3. 순전파 (Forward pass)
outputs = self.model(input_ids)
logits = outputs.logits
# 4. 마스킹된 손실 계산
# 자기회귀 예측을 위해 로짓과 레이블을 시프트합니다
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = self.criterion(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
# 5. 역전파 (Backward pass)
loss.backward()
optimizer.step()
return loss.item()
Quizzes
Quiz 1: 인과 관계 LLM 파인튜닝 환경에서 프롬프트 마스킹(Prompt-masking)을 적용하면 표준 구현에서 활성화 메모리가 얼마나 절약됩니까?
일반적인 causal LM SFT에서 labels=-100 프롬프트 마스킹은 손실 계산에서 프롬프트 토큰을 제외할 뿐입니다. 프롬프트 토큰은 여전히 응답 토큰을 예측하기 위한 문맥으로 순전파에 들어가고, attention 계산과 hidden state 저장에도 참여합니다. 따라서 표준 full fine-tuning 구현에서는 활성화 메모리가 거의 절약되지 않습니다. 메모리 절약을 얻으려면 prefix 부분을 별도 no-grad로 처리하거나, KV 재사용, prefix-LM 전용 최적화, sequence packing 최적화처럼 별도의 시스템 기법이 필요합니다.
Quiz 2: SFT 모델이 SFT 데이터셋에 없는 작업에 대해 제로샷(Zero-shot) 능력을 일부 잃을 수 있다는 관찰이 있습니다 (Alignment Tax). “Superficial Alignment Hypothesis”는 이를 어떻게 설명할까요?
Superficial Alignment Hypothesis는 SFT가 모델에게 새로운 능력을 가르치는 것이 아니라 새로운 형식의 하위 분포(예: 어시스턴트가 되는 방법)를 가르칠 뿐이라고 말합니다. SFT 데이터셋이 좁거나 잘못된 형식을 포함하고 있다면 모델은 출력을 너무 엄격하게 제한하도록 학습하여 사전 학습 중에 획득한 방대한 세계 지식을 효과적으로 숨기거나 억제할 수 있습니다.
Quiz 3: 시스템 관점에서 볼 때 데이터셋이 작더라도(예: 1,000개 예시) 전체 70B 파라미터 모델에 대한 SFT가 계산적으로 비싼 이유는 무엇일까요?
Full Fine-Tuning은 모든 파라미터를 업데이트해야 하기 때문입니다. 시스템은 모델 가중치뿐만 아니라 모든 70B 파라미터에 대한 옵티마이저 상태(예: Adam의 모멘트)와 그래디언트를 저장해야 하므로 방대한 VRAM(1TB 이상)이 필요합니다. 작은 데이터셋 크기는 훈련 시간(단계 수)을 줄여주지만 단계당 필요한 최대 메모리를 줄여주지는 않습니다.
References
- Zhou, C., et al. (2023). LIMA: Less Is More for Alignment. arXiv:2305.11206.
- Ouyang, L., et al. (2022). Training language models to follow instructions with human feedback. arXiv:2203.02155.