20.3 Linear Attention
앞 절의 Mamba는 “어텐션 대신 상태를 업데이트하자”는 방향이었습니다. Linear Attention은 조금 다른 길을 택합니다. 어텐션이라는 인터페이스는 유지하되, 모든 토큰 쌍을 명시적으로 계산하는 softmax attention을 더 싼 형태로 바꾸려는 시도입니다. 즉, SSM이 순환 상태의 언어로 문제를 푼다면, Linear Attention은 어텐션을 연관적으로 다시 묶을 수 있는 커널 계산 으로 바꿉니다.
긴 컨텍스트 모델링은 계속 같은 벽에 부딪힙니다. 표준 어텐션은 우아하고 표현력이 높지만 비쌉니다. 시퀀스 길이가 충분히 길어지면, 모든 토큰 쌍의 상호작용을 정확하게 유지하는 비용은 아키텍처 문제이면서 동시에 시스템 문제이기도 합니다. Linear Attention은 바로 이 긴장 위에 놓여 있습니다. softmax attention을 완전히 대체하는 만능 해법이라기보다, 정확도의 일부를 더 나은 확장성과 바꾸려는 대표적인 시도에 가깝습니다.
1. 핵심 아이디어
표준 어텐션은 시퀀스 전체의 토큰 쌍 상호작용을 물질화합니다.
여기서 비싼 부분은 모든 토큰 쌍에 대한 조밀한 상호작용 패턴입니다. Linear Attention은 softmax를 feature map 로 바꿔 연산 순서를 다시 묶을 수 있게 만듭니다.
실용적인 매력은 분명합니다. 과거를 compact한 running state로 요약할 수 있다면, 전체 시퀀스를 처리하는 동안 메모리 증가를 full quadratic attention보다 훨씬 완만하게 만들 수 있습니다.
2. 왜 이 아이디어가 계속 매력적인가
프롬프트 길이가 수백 토큰에서 수십만 토큰으로 늘어나는 문서 처리 파이프라인을 떠올려 봅시다. 모델 품질이 괜찮더라도, 실제 병목은 serving cost와 KV-cache pressure일 수 있습니다. Linear Attention은 바로 이런 구간에서 매력적입니다.
- 더 낮은 메모리 압박
- 긴 컨텍스트에서 더 나은 확장성
- 시퀀스를 더 recurrent하게 보는 계산 관점
그래서 예전의 관심이 한 번 식었다가도, 긴 컨텍스트 문제가 커질 때마다 다시 돌아옵니다.
3. 초기 변형들이 왜 고전했는가
아이디어가 이렇게 매력적인데도 왜 softmax attention이 오랫동안 주류였을까요?
초기의 linear 계열 모델은 종종 품질 대가를 치렀기 때문입니다.
- 학습 불안정성
- 약한 exact recall
- 정밀한 token-level interaction이 필요한 작업에서의 성능 저하
즉, 과거를 compact state로 요약하는 것은 효율적이지만, 동시에 모델을 좋게 만드는 정보를 과도하게 압축할 수도 있습니다.
4. 최근 연구가 고치려는 것
최근 작업은 recurrent state를 단순 덧셈이 아니라 더 선택적으로 유지하도록 만드는 데 집중하고 있습니다. gating, delta-style update, 더 구조화된 state transition은 모두 같은 문제를 겨냥합니다. 중요한 과거는 남기고, state가 포화되지 않도록 하려는 것입니다.
이런 접근은 이미 확정된 승자 라기보다 떠오르는 설계 패턴 으로 보는 편이 더 정확합니다. 현재까지의 증거는 linear-attention 계열 레이어가 긴 컨텍스트 구간에서 유용할 수 있음을 보여 주지만, 그것만으로 전체 네트워크를 설명하기에는 아직 이른 편입니다.
떠오르는 Hybrid Pattern
현재 상황을 조심스럽게 요약하면 이렇습니다.
- pure linear attention은 효율성 측면에서 매력적이다
- full attention은 여전히 exact하고 고해상도 상호작용에 가치가 있다
- 그래서 hybrid stack은 자연스러운 엔지니어링 절충안이 된다
하지만 이것이 곧 하나의 보편적인 recipe로 정착됐다는 뜻은 아닙니다. 어느 위치에 비싼 full attention을 남기고, 어디에 더 효율적인 메커니즘을 둘지에 대해 여러 팀이 계속 탐색 중이라는 뜻에 가깝습니다.
Linear Attention, SSM, Retention의 공통 언어
최근 논문들을 읽다 보면 Linear Attention, SSM, retention, delta rule이 서로 다른 이름으로 비슷한 구조를 설명하는 경우가 많습니다. 핵심은 모두 다음 형태의 state update로 볼 수 있다는 점입니다.
여기서 는 과거 정보를 담은 compact state, 는 잊기/유지 gate, 는 새로 쓰는 정보입니다. 차이는 이 state를 어떤 feature map으로 만들고, 어떤 gate를 쓰고, 어떤 커널로 GPU에서 실행하느냐에 있습니다. 이 관점을 가지면 Mamba-2의 SSD가 왜 SSM과 Linear Attention을 연결하는지 더 자연스럽게 이해됩니다.
5. 교육용 PyTorch 예제
아래의 단순화된 블록은 핵심 계산 아이디어를 보여 줍니다. history를 compact state로 만든 뒤, 그것을 query에 다시 적용하는 방식입니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class GatedLinearAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_head = d_model // num_heads
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.g_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
def feature_map(self, x: torch.Tensor) -> torch.Tensor:
return F.elu(x) + 1.0
def forward(self, x: torch.Tensor) -> torch.Tensor:
bsz, seq_len, dim = x.shape
q = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.d_head).transpose(1, 2)
k = self.k_proj(x).view(bsz, seq_len, self.num_heads, self.d_head).transpose(1, 2)
v = self.v_proj(x).view(bsz, seq_len, self.num_heads, self.d_head).transpose(1, 2)
g = torch.sigmoid(self.g_proj(x))
q_phi = self.feature_map(q)
k_phi = self.feature_map(k)
kv_state = torch.matmul(k_phi.transpose(-1, -2), v)
z_state = k_phi.sum(dim=2, keepdim=True).transpose(-1, -2)
numerator = torch.matmul(q_phi, kv_state)
denominator = torch.matmul(q_phi, z_state) + 1e-6
out = numerator / denominator
out = out.transpose(1, 2).contiguous().view(bsz, seq_len, dim)
out = out * g
return self.out_proj(out)
이 코드는 교육용입니다. 실제 시스템에서는 더 정교한 normalization, gating, masking, hardware-specific optimization이 추가됩니다.
6. Interactive: Memory Complexity Visualizer
아래 시각화는 이 논쟁을 직관적으로 이해하는 데 도움이 됩니다. 시퀀스 길이가 커질수록 메모리 사용량이 어떻게 증가하는지를 보면, 왜 이런 아키텍처적 시도가 반복해서 등장하는지 훨씬 명확해집니다.
KV Cache Memory Complexity: $O(n)$ vs $O(1)$
Adjust the sequence length to see how the memory footprint of Standard Attention grows linearly with context, while Linear Attention maintains a constant state size.
7. Practical Takeaway
Linear Attention은 표준 어텐션을 이미 대체한 승자라기보다, 긴 컨텍스트 시대에 중요한 효율성 방향으로 이해하는 편이 적절합니다. 매력적인 이유는 긴 컨텍스트가 비싸기 때문이고, 어려운 이유는 compact recurrent state가 어떤 작업에는 꼭 필요한 exact한 정보를 잃게 만들 수 있기 때문입니다. 그래서 가까운 시기의 그럴듯한 시나리오는 “linear attention의 완승”보다는 “정밀도와 효율성이 필요한 위치에 따라 서로 다른 interaction mechanism을 섞는 모델”에 가깝습니다.
Quizzes
Quiz 1: Linear Attention이 근본적으로 해결하려는 문제는 무엇인가요?
긴 컨텍스트를 다룰 때 시퀀스 상호작용 비용을 줄여, 메모리와 연산 측면에서 더 다루기 쉬운 구조를 만드는 것입니다.
Quiz 2: 초기의 Linear Attention 계열이 full attention보다 종종 성능이 낮았던 이유는 무엇인가요?
과거를 compact state로 요약하는 과정에서 exact recall이나 세밀한 token-level interaction에 필요한 정보까지 과도하게 압축해 버리기 쉬웠기 때문입니다.
Quiz 3: hybrid stack이 자연스러운 엔지니어링 절충안인 이유는 무엇인가요?
linear 계열 메커니즘은 효율성에 도움이 되고, full attention은 여전히 정밀한 상호작용에 강점이 있기 때문입니다. 둘을 섞으면 각자의 장점을 필요한 곳에 배치할 수 있습니다.
Quiz 4: 특정한 hybrid 비율 하나를 오늘의 “production standard”라고 부르는 것이 왜 오해를 부를 수 있나요?
현재 분야는 여전히 여러 설계와 trade-off를 탐색 중이기 때문입니다. hybrid design 자체는 떠오르는 패턴이지만, 하나의 정확한 공식이 보편 표준으로 굳었다고 보기는 어렵습니다.
References
- Katharopoulos, A., et al. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. arXiv:2006.16236.
- Yang, S., et al. (2024). DeltaNet: Linear-Time Sequence Modeling with Gated Delta Rule. arXiv:2406.06484.
- Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. arXiv:2405.21060.