← all posts
AI 2026.05.03 · 12 min read Advanced

LSTM은 어떻게 vanishing gradient를 피하는가

Hochreiter 1997의 CEC 비전부터 forget gate 초기화, GRU의 단순화, variants의 ablation 결과까지 — LSTM 설계 철학의 핵심을 추적한다.


Plain RNN의 gradient는 시간 축을 거슬러 올라갈수록 행렬 곱의 거듭제곱으로 소멸한다. LSTM은 이 구조를 부수는 대신, 곱셈적 누적을 덧셈적 누적으로 바꾼다. 어떻게? 그리고 그 선택이 만든 trade-off는 무엇인가?

문제의 뿌리: 행렬 곱의 지수적 붕괴

Plain RNN의 BPTT gradient는 다음 형태를 가진다.

hTh0=j=1TWhhdiag(σ(zj))\frac{\partial h_T}{\partial h_0} = \prod_{j=1}^{T} W_{hh}^\top \mathrm{diag}(\sigma'(z_j))

spectral radius ρ<1\rho < 1이면 TT가 커질수록 이 곱은 지수적으로 0에 수렴한다. ρ=0.9\rho = 0.9, σ0.5\sigma' \approx 0.5라면 T=100T = 100 후 gradient norm은 (0.9×0.5)1001030(0.9 \times 0.5)^{100} \approx 10^{-30}이다. float32의 정밀도 한계(10810^{-8})를 훨씬 밑돈다.

핵심 통찰은 단순하다. 곱셈적 누적이 문제라면, 덧셈적 누적으로 바꾸면 된다.

ct=ct1+(new info)    ctct1=1c_t = c_{t-1} + (\text{new info}) \implies \frac{\partial c_t}{\partial c_{t-1}} = 1

이것이 Hochreiter & Schmidhuber 1997의 Constant Error Carousel(CEC) 비전이다.

CEC: 덧셈이 gradient를 보존한다

LSTM의 cell state update는 다음과 같다.

ct=ftct1+itc~tc_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t

여기서 ft,it,c~tf_t, i_t, \tilde{c}_t는 모두 ht1h_{t-1}xtx_t의 함수이며, ct1c_{t-1}직접 의존하지 않는다. 따라서 direct partial derivative는 다음과 같이 단순화된다.

정리 1 · Cell-to-Cell Direct Partial
ctct1direct=ft(element-wise)\frac{\partial c_t}{\partial c_{t-1}}\bigg|_{\text{direct}} = f_t \quad \text{(element-wise)}
▷ 증명

ct=ftct1+itc~tc_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t에서 ft,it,c~tf_t, i_t, \tilde{c}_tct1c_{t-1}에 직접 의존하지 않는다. 따라서 (itc~t)/ct1direct=0\partial(i_t \odot \tilde{c}_t)/\partial c_{t-1}\big|_{\text{direct}} = 0이고, 첫 항의 direct partial만 남는다.

(ftct1)ct1=ft\frac{\partial(f_t \odot c_{t-1})}{\partial c_{t-1}} = f_t \qquad \square

chain rule을 TT step 전체에 적용하면 cell path의 gradient product는 행렬 곱이 아닌 element-wise scalar 곱이 된다.

cTc0cell path=t=1Tft\frac{\partial c_T}{\partial c_0}\bigg|_{\text{cell path}} = \prod_{t=1}^{T} f_t

ft1f_t \approx 1이면 이 곱은 1T=11^T = 1이다. 어떤 TT에서도 gradient가 보존된다. Plain RNN의 103010^{-30}과 LSTM(f=0.99f = 0.99)의 0.991000.370.99^{100} \approx 0.37의 차이는 102910^{29}배다.

4개 gate의 역할 분담

단순한 additive update만으로는 부족하다. “선택성”이 없으면 모든 입력이 무한 누적되고, 불필요한 정보를 지울 수도 없다. LSTM은 여기에 gate를 도입한다.

ft=σ(Wf[ht1;xt]+bf)(forget: 이전 cell 보존?)it=σ(Wi[ht1;xt]+bi)(input: 새 정보 수용?)c~t=tanh(Wc[ht1;xt]+bc)(candidate: 어떤 새 정보?)ot=σ(Wo[ht1;xt]+bo)(output: cell을 얼마나 노출?)ct=ftct1+itc~tht=ottanh(ct)\begin{aligned} f_t &= \sigma(W_f [h_{t-1}; x_t] + b_f) & \text{(forget: 이전 cell 보존?)} \\ i_t &= \sigma(W_i [h_{t-1}; x_t] + b_i) & \text{(input: 새 정보 수용?)} \\ \tilde{c}_t &= \tanh(W_c [h_{t-1}; x_t] + b_c) & \text{(candidate: 어떤 새 정보?)} \\ o_t &= \sigma(W_o [h_{t-1}; x_t] + b_o) & \text{(output: cell을 얼마나 노출?)} \\ c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\ h_t &= o_t \odot \tanh(c_t) \end{aligned}

sigmoid gate는 “얼마나”(degree)를, tanh는 “어떤 값”(signed magnitude)을 담당한다. 두 역할이 분리되어야 gate가 정보의 양과 내용을 독립적으로 제어할 수 있다.

파라미터 수는 4H(D+H+1)4H(D + H + 1)로 plain RNN의 4배다. 학습 비용이 늘지만, 이 비용이 long-range dependency 해결 능력을 산다.

Forget Bias 초기화: 학습 시작의 결정적 차이

CEC는 수학적으로 gradient 보존을 보장하지만, 학습 초기에도 그 보장이 작동해야 한다. 여기서 Jozefowicz et al. 2015의 핵심 발견이 등장한다.

학습 시작 시 weight가 small random값이라면 Wf[h,x]0W_f \cdot [h, x] \approx 0이므로 forget gate는 bias에 의해 결정된다.

bfb_fσ(bf)\sigma(b_f)T=100T=100 후 cell gradient
0 (PyTorch default)0.5001030\approx 10^{-30}
1 (Jozefowicz 권장)0.7311013\approx 10^{-13}
20.881105.5\approx 10^{-5.5}

bf=0b_f = 0이면 초기 forget gate가 0.5 — plain RNN과 거의 같은 vanishing이 발생한다. bf=1b_f = 1은 이를 101610^{16}배 개선하여 cold start를 회피한다. 설정 방법은 간단하다.

def set_forget_bias(lstm_module, value=1.0):
    for i in range(lstm_module.num_layers):
        H = lstm_module.hidden_size
        bias_ih = getattr(lstm_module, f'bias_ih_l{i}')
        bias_hh = getattr(lstm_module, f'bias_hh_l{i}')
        with torch.no_grad():
            # PyTorch order: i, f, g, o
            bias_ih[H:2*H].fill_(value)
            bias_hh[H:2*H].fill_(0.0)
PyTorch default의 함정

nn.LSTM의 forget bias default는 0이다. 이 값은 long-range task에서 학습 실패의 원인이 된다. LSTM을 쓰는 모든 코드에 set_forget_bias(lstm, 1.0)을 추가하는 것이 best practice다.

GRU: 75%의 파라미터로 같은 역할

Cho et al. 2014의 GRU는 질문 하나에서 출발한다. “4개 gate가 모두 필요한가?”

GRU는 cell state와 hidden state를 통합하고, forget/input gate를 coupled update gate ztz_t로 단순화한다.

zt=σ(Wz[ht1;xt]+bz)rt=σ(Wr[ht1;xt]+br)h~t=tanh(W[rtht1;xt]+b)ht=(1zt)ht1+zth~t\begin{aligned} z_t &= \sigma(W_z [h_{t-1}; x_t] + b_z) \\ r_t &= \sigma(W_r [h_{t-1}; x_t] + b_r) \\ \tilde{h}_t &= \tanh(W [r_t \odot h_{t-1}; x_t] + b) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{aligned}

zt=0z_t = 0이면 ht=ht1h_t = h_{t-1} (완전 보존), zt=1z_t = 1이면 ht=h~th_t = \tilde{h}_t (완전 교체). GRU의 1zt1 - z_t는 LSTM의 forget gate에, ztz_t는 input gate에 대응한다. 단지 두 gate가 합이 1이 되도록 coupled되어 있을 뿐이다.

reset gate rtr_t는 LSTM에 없는 메커니즘이다. candidate를 계산할 때 이전 hidden의 일부를 무시하는 “fresh start” 능력을 제공한다. 이는 topic shift나 문장 경계처럼 context를 끊어야 하는 상황에서 유리하다.

결과는 파라미터 수 3H(D+H+1)3H(D+H+1) — LSTM의 75%다. Chung et al. 2014의 empirical 비교는 GRU와 LSTM이 대부분의 task에서 comparable 성능을 보임을 보였다.

트레이드오프

트레이드오프 요약
Plain RNNLSTMGRU
Gradient행렬 곱, exponential 소멸element-wise 곱, 보존 가능element-wise, LSTM과 유사
파라미터H(D+H+1)H(D+H+1)4×4\times RNN3×3\times RNN
Long-range사실상 불가가능 (f1f \approx 1 유지 시)가능 (comparable)
속도빠름느림LSTM보다 ~30% 빠름
해석성낮음cell/hidden 분리로 높음통합 state로 낮음

Greff et al. 2017의 LSTM: A Search Space Odyssey는 8가지 LSTM variant를 ablation한 결과, 놀랍게도 vanilla LSTM이 가장 robust했다고 보고한다. input gate를 제거하면 PTB perplexity가 84.7에서 92.3으로 오른다. forget gate 제거는 88.5, output gate 제거는 89.1. 반면 peephole이나 coupled gate(CIFG)는 marginal 차이만 낸다.

이 결과는 역설처럼 보인다 — vanilla LSTM이 내부적으로 optimal이라면 왜 Transformer가 이를 대체했는가? 답은 scale이 다른 문제다. Greff 2017은 RNN family 내부의 비교다. Transformer는 RNN framework 자체를 버리고 sequential dependency를 attention으로 대체했다. O(T)O(T) sequential 계산을 O(logT)O(\log T) parallel 계산으로 바꾼 것이 2017 이후의 패러다임 전환이다.

정리

  • Plain RNN의 gradient는 행렬 곱의 지수적 소멸로 무너진다. LSTM의 cell state는 이를 element-wise 곱으로 바꿔 보존한다.
  • CEC의 보장은 수학적이지만, 실제 학습에서는 forget bias bf=1b_f = 1 초기화가 cold start를 막는 결정적 조건이다.
  • GRU는 LSTM의 75% 파라미터로 comparable 성능을 낸다. resource 제약이 있거나 데이터가 적을 때 GRU가 합리적 선택이다.
  • vanilla LSTM의 4-gate 구조는 RNN family 내에서 remarkably robust하다. Transformer의 우위는 다른 차원의 문제다.

LSTM과 Transformer의 관계는 ‘더 나은 같은 것’이 아니라 ‘다른 것’이다. 그리고 2023년 Mamba의 등장은 linear recurrence라는 이름으로 LSTM의 핵심 통찰이 다시 살아났음을 보여준다.

REF
Hochreiter, S. and Schmidhuber, J. · 1997 · Long Short-Term Memory · Neural Computation
REF
Chung, J., Gulcehre, C., Cho, K., and Bengio, Y. · 2014 · An Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling · NIPS 2014 Workshop