LLM 사전학습이 불안정한 이유는 하나다
Loss spike의 4가지 근인부터 Embedding LR 분리, QK-norm, z-loss, RMSNorm, AdamW ε까지 — LLM 훈련 안정화 기법들이 공유하는 하나의 진단 프레임을 추적한다.
- 01 LLM 학습 규모는 어떻게 결정되는가
- 02 LLM 사전학습의 설계 결정들은 어디서 오는가
- 03 LLM 사전학습이 불안정한 이유는 하나다
- 04 LLM 사전학습 데이터는 어떻게 설계되는가
- 05 토큰화는 왜 모델의 성능을 결정하는가
- 06 LLM 아키텍처 설계의 다섯 가지 선택
LLM 사전학습 도중 loss가 갑자기 치솟으면 대부분의 조언은 같다 — “LR을 낮춰라”, “gradient clipping을 켜라”. 하지만 같은 증상이라도 근인이 다르면 처방도 달라진다. 왜 실제로는 spike가 사라지지 않는가?
Loss Spike는 4가지 다른 병이다
Wortsman et al. (2024)은 대규모 LLM 훈련 로그를 분석해 loss spike를 4가지로 분류했다.
┌──────────────────┬────────────────────┬──────────────────────┐
│ 유형 │ 근인 │ 해결책 │
├──────────────────┼────────────────────┼──────────────────────┤
│ Catastrophic │ Embedding 희귀 토큰│ Embedding LR 분리 │
│ Transient │ Attention logit │ QK-norm │
│ │ overflow │ │
│ NaN Explosion │ Output logit 폭주 │ z-loss │
│ Slow Divergence │ Weight norm drift │ RMSNorm + weight │
│ │ │ decay │
└──────────────────┴────────────────────┴──────────────────────┘
각각의 수학적 정의는 엄밀하다. Catastrophic spike는 loss가 이전 K 스텝 평균의 5배를 넘고 이후에도 계속 오르는 상황이다. Transient spike는 2배 이상 뛰었다가 수십 스텝 내에 원래 수준으로 내려온다. NaN Explosion은 부동소수 overflow다. Slow Divergence는 loss가 급격하지 않지만 지속적으로 증가한다.
“LR을 낮춰라”는 Catastrophic spike에는 즉효지만 Transient나 NaN에는 근인을 건드리지 않는다. 진단 없이 처방하면 숨어 있던 문제가 더 늦게, 더 크게 터진다.
폭주의 4가지 출처
각 spike 유형은 모델의 서로 다른 부위에서 시작된다.
Embedding: 희귀 토큰의 gradient 분산 폭발. 자연언어 토큰은 Zipfian 분포를 따른다. 상위 1,000개 토큰이 corpus의 약 64%를 차지하고, 나머지 49,000개는 매우 드물게 등장한다. 배치 크기 32, 시퀀스 길이 2,048 기준으로 희귀 토큰은 대부분의 배치에서 0회 등장한다. 그러다 등장하면 gradient가 엄청나게 크다. 고정 LR에서 이 업데이트는 제어 불가능하다.
Attention: logit이 BF16 안전 범위를 초과. Q, K의 내적을 로 나눠도, sequence length 가 커지면 max logit은 에 비례해 증가한다. , 이면 이 값은 수천을 넘어 BF16 안전 범위 를 훨씬 벗어난다. Softmax 입력이 30을 넘으면 exp 계산에서 overflow가 발생한다.
Output logit: partition function의 발산. 훈련 중 출력 weight 와 hidden state norm이 함께 커지면 가 폭발한다. 이면 softmax의 분모 가 overflow되어 NaN이 된다.
Weight norm drift: 느린 누적. AdamW는 weight decay 항 를 포함하지만 이면 decay rate는 스텝당 1%에 불과하다. 훈련 후기에 gradient가 0에 가까워지면 decay가 주도해야 하지만, 속도가 너무 느려서 가 지속적으로 증가한다. Activation scale이 따라 오르고 gradient가 다시 커진다.
4가지 처방의 수학
각 근인에 대응하는 해결책은 원인을 직접 제거한다.
Embedding LR 분리. μP(Maximal Update Parameterization)에서 embedding layer는 output layer처럼 width-independent LR을 사용한다. 실험적으로 확립된 값은 다음과 같다.
Pythia(114M20B), Gemma(2B7B) 모두 이 비율을 사용한다. 희귀 토큰의 gradient variance가 100~1000배 크더라도, LR을 10분의 1로 낮추면 catastrophic update의 진폭이 같은 비율로 줄어든다.
QK-norm. Dehghani et al. (2023, ViT-22B)에서 도입한 방법이다. Q, K를 unit norm으로 정규화하고 학습 가능한 scale 를 곱한다.
Cauchy-Schwarz에 의해 정규화 후 내적은 안에 속하므로 logit 범위가 로 명시적으로 제한된다. 의 초기값을 로 설정하면 표준 attention의 실제 scale을 학습하면서도 overflow를 막을 수 있다.
z-loss. Chowdhery et al. (2022, PaLM)에서 도입한 보조 손실이다.
z-loss의 gradient는 다음과 같다.
gradient가 에 비례하므로 모든 토큰의 logit을 균등하게 감쇠한다. softmax 분포의 shape는 유지하면서 절대적인 logit scale만 낮춘다. PaLM, Gemini, Grok 모두 이 기법을 사용한다.
RMSNorm. Zhang et al. (2019)이 제안한 정규화 방식이다.
LayerNorm과의 결정적 차이는 centering(평균 빼기)의 제거다. 입력 를 배 스케일하면 다음이 성립한다.
정확히 불변이다. 가 훈련 중 증가해도 RMSNorm이 activation scale을 흡수하므로 slow divergence가 완화된다. LLaMA, Gemma, Mistral, Qwen 등 최신 LLM이 모두 RMSNorm을 선택한 이유다.
양수 상수 에 대해 이 성립한다.
AdamW . 분모 에서 이 너무 작으면 희귀 토큰 embedding처럼 인 파라미터에서 effective LR이 까지 폭발한다. 반대로 너무 크면 적응형 LR의 이점이 사라진다. BF16 훈련에서 실전 권장값은 다.
트레이드오프
- Embedding LR 분리: parameter group 설정 복잡도 증가. tied embedding(output projection = embedding) 사용 시 조정 필요.
- QK-norm: head별 추가 파라미터 , normalization 연산 오버헤드. cross-head scale 차이는 여전히 무시.
- z-loss: 훈련 loss에 noise 추가(이면 약 0.25%). inference 시에는 제거.
- RMSNorm: centering 제거로 일부 zero-mean 가정 기반 분석이 성립하지 않음. 실제로는 대부분의 경우 문제 없음.
- 큰 : 희귀 파라미터의 adaptive rate 감소 → 수렴 다소 느림.
정리
- Loss spike는 하나의 원인이 아니다. 근인(Embedding, Attention, Output, Weight)이 다르면 처방도 달라진다.
- Embedding LR 분리(), QK-norm, z-loss, RMSNorm, 적절한 은 각각 다른 폭주 경로를 직접 차단한다.
- 7B 모델에서 spike 발생 빈도가 약 3%인 데 비해 70B에서는 약 12%다. 규모가 커질수록 이 기법들은 선택이 아니라 필수가 된다.
- 진단 없이 LR만 낮추는 것은 근인을 건드리지 않는 증상 치료다.
다음 글에서는 사전학습 데이터의 품질 선별(data curation)이 loss curve에 어떤 흔적을 남기는지 추적한다.