LSTM은 어떻게 vanishing gradient를 피하는가
Hochreiter 1997의 CEC 비전부터 forget gate 초기화, GRU의 단순화, variants의 ablation 결과까지 — LSTM 설계 철학의 핵심을 추적한다.
- 01 RNN은 왜 sequence를 기억하는가
- 02 RNN 학습은 왜 이렇게 설계됐는가
- 03 RNN Gradient 소멸은 왜 피할 수 없었나
- 04 LSTM은 어떻게 vanishing gradient를 피하는가
- 05 RNN 변형들이 공유하는 하나의 질문
- 06 Attention은 어떻게 Seq2Seq의 병목을 뚫었는가
- 07 RNN이 Transformer에 밀린 이유, 그리고 Mamba가 돌아온 이유
Plain RNN의 gradient는 시간 축을 거슬러 올라갈수록 행렬 곱의 거듭제곱으로 소멸한다. LSTM은 이 구조를 부수는 대신, 곱셈적 누적을 덧셈적 누적으로 바꾼다. 어떻게? 그리고 그 선택이 만든 trade-off는 무엇인가?
문제의 뿌리: 행렬 곱의 지수적 붕괴
Plain RNN의 BPTT gradient는 다음 형태를 가진다.
spectral radius 이면 가 커질수록 이 곱은 지수적으로 0에 수렴한다. , 라면 후 gradient norm은 이다. float32의 정밀도 한계()를 훨씬 밑돈다.
핵심 통찰은 단순하다. 곱셈적 누적이 문제라면, 덧셈적 누적으로 바꾸면 된다.
이것이 Hochreiter & Schmidhuber 1997의 Constant Error Carousel(CEC) 비전이다.
CEC: 덧셈이 gradient를 보존한다
LSTM의 cell state update는 다음과 같다.
여기서 는 모두 과 의 함수이며, 에 직접 의존하지 않는다. 따라서 direct partial derivative는 다음과 같이 단순화된다.
에서 는 에 직접 의존하지 않는다. 따라서 이고, 첫 항의 direct partial만 남는다.
chain rule을 step 전체에 적용하면 cell path의 gradient product는 행렬 곱이 아닌 element-wise scalar 곱이 된다.
이면 이 곱은 이다. 어떤 에서도 gradient가 보존된다. Plain RNN의 과 LSTM()의 의 차이는 배다.
4개 gate의 역할 분담
단순한 additive update만으로는 부족하다. “선택성”이 없으면 모든 입력이 무한 누적되고, 불필요한 정보를 지울 수도 없다. LSTM은 여기에 gate를 도입한다.
sigmoid gate는 “얼마나”(degree)를, tanh는 “어떤 값”(signed magnitude)을 담당한다. 두 역할이 분리되어야 gate가 정보의 양과 내용을 독립적으로 제어할 수 있다.
파라미터 수는 로 plain RNN의 4배다. 학습 비용이 늘지만, 이 비용이 long-range dependency 해결 능력을 산다.
Forget Bias 초기화: 학습 시작의 결정적 차이
CEC는 수학적으로 gradient 보존을 보장하지만, 학습 초기에도 그 보장이 작동해야 한다. 여기서 Jozefowicz et al. 2015의 핵심 발견이 등장한다.
학습 시작 시 weight가 small random값이라면 이므로 forget gate는 bias에 의해 결정된다.
| 후 cell gradient | ||
|---|---|---|
| 0 (PyTorch default) | 0.500 | |
| 1 (Jozefowicz 권장) | 0.731 | |
| 2 | 0.881 |
이면 초기 forget gate가 0.5 — plain RNN과 거의 같은 vanishing이 발생한다. 은 이를 배 개선하여 cold start를 회피한다. 설정 방법은 간단하다.
def set_forget_bias(lstm_module, value=1.0):
for i in range(lstm_module.num_layers):
H = lstm_module.hidden_size
bias_ih = getattr(lstm_module, f'bias_ih_l{i}')
bias_hh = getattr(lstm_module, f'bias_hh_l{i}')
with torch.no_grad():
# PyTorch order: i, f, g, o
bias_ih[H:2*H].fill_(value)
bias_hh[H:2*H].fill_(0.0)
nn.LSTM의 forget bias default는 0이다. 이 값은 long-range task에서 학습 실패의 원인이 된다. LSTM을 쓰는 모든 코드에 set_forget_bias(lstm, 1.0)을 추가하는 것이 best practice다.
GRU: 75%의 파라미터로 같은 역할
Cho et al. 2014의 GRU는 질문 하나에서 출발한다. “4개 gate가 모두 필요한가?”
GRU는 cell state와 hidden state를 통합하고, forget/input gate를 coupled update gate 로 단순화한다.
이면 (완전 보존), 이면 (완전 교체). GRU의 는 LSTM의 forget gate에, 는 input gate에 대응한다. 단지 두 gate가 합이 1이 되도록 coupled되어 있을 뿐이다.
reset gate 는 LSTM에 없는 메커니즘이다. candidate를 계산할 때 이전 hidden의 일부를 무시하는 “fresh start” 능력을 제공한다. 이는 topic shift나 문장 경계처럼 context를 끊어야 하는 상황에서 유리하다.
결과는 파라미터 수 — LSTM의 75%다. Chung et al. 2014의 empirical 비교는 GRU와 LSTM이 대부분의 task에서 comparable 성능을 보임을 보였다.
트레이드오프
| Plain RNN | LSTM | GRU | |
|---|---|---|---|
| Gradient | 행렬 곱, exponential 소멸 | element-wise 곱, 보존 가능 | element-wise, LSTM과 유사 |
| 파라미터 | RNN | RNN | |
| Long-range | 사실상 불가 | 가능 ( 유지 시) | 가능 (comparable) |
| 속도 | 빠름 | 느림 | LSTM보다 ~30% 빠름 |
| 해석성 | 낮음 | cell/hidden 분리로 높음 | 통합 state로 낮음 |
Greff et al. 2017의 LSTM: A Search Space Odyssey는 8가지 LSTM variant를 ablation한 결과, 놀랍게도 vanilla LSTM이 가장 robust했다고 보고한다. input gate를 제거하면 PTB perplexity가 84.7에서 92.3으로 오른다. forget gate 제거는 88.5, output gate 제거는 89.1. 반면 peephole이나 coupled gate(CIFG)는 marginal 차이만 낸다.
이 결과는 역설처럼 보인다 — vanilla LSTM이 내부적으로 optimal이라면 왜 Transformer가 이를 대체했는가? 답은 scale이 다른 문제다. Greff 2017은 RNN family 내부의 비교다. Transformer는 RNN framework 자체를 버리고 sequential dependency를 attention으로 대체했다. sequential 계산을 parallel 계산으로 바꾼 것이 2017 이후의 패러다임 전환이다.
정리
- Plain RNN의 gradient는 행렬 곱의 지수적 소멸로 무너진다. LSTM의 cell state는 이를 element-wise 곱으로 바꿔 보존한다.
- CEC의 보장은 수학적이지만, 실제 학습에서는 forget bias 초기화가 cold start를 막는 결정적 조건이다.
- GRU는 LSTM의 75% 파라미터로 comparable 성능을 낸다. resource 제약이 있거나 데이터가 적을 때 GRU가 합리적 선택이다.
- vanilla LSTM의 4-gate 구조는 RNN family 내에서 remarkably robust하다. Transformer의 우위는 다른 차원의 문제다.
LSTM과 Transformer의 관계는 ‘더 나은 같은 것’이 아니라 ‘다른 것’이다. 그리고 2023년 Mamba의 등장은 linear recurrence라는 이름으로 LSTM의 핵심 통찰이 다시 살아났음을 보여준다.