RNN 학습은 왜 이렇게 설계됐는가
Cyclic 구조를 DAG로 펼치는 unrolling부터 BPTT 유도, truncation의 bias-memory 트레이드오프, 그리고 RTRL이 왜 다시 주목받는지까지, RNN 학습 알고리즘의 설계 결정을 추적한다.
- 01 RNN은 왜 sequence를 기억하는가
- 02 RNN 학습은 왜 이렇게 설계됐는가
- 03 RNN Gradient 소멸은 왜 피할 수 없었나
- 04 LSTM은 어떻게 vanishing gradient를 피하는가
- 05 RNN 변형들이 공유하는 하나의 질문
- 06 Attention은 어떻게 Seq2Seq의 병목을 뚫었는가
- 07 RNN이 Transformer에 밀린 이유, 그리고 Mamba가 돌아온 이유
RNN은 본질적으로 cyclic한 구조다 — $h_t = f(h_{t-1}, x_t)$가 자기참조한다. 그런데 학습은 acyclic 그래프 위의 chain rule을 요구한다. 이 모순을 어떻게 해결하는가? 그리고 그 해결책이 메모리, 속도, 온라인 학습 능력에 어떤 트레이드오프를 만들어내는가?
Unrolling — cyclic을 DAG로
Unrolling은 RNN의 cyclic 구조를 시간축으로 펼쳐 DAG로 만드는 변환이다. $h_t = f(h_{t-1}, x_t)$에서 각 $h_t$를 distinct 노드로 만들면, 순환 참조가 사라지고 topological order $h_0, x_1, h_1, \ldots, h_T$가 생긴다.
Cyclic: Unrolled (DAG):
┌──────────┐ x₁→[cell]→h₁→[cell]→h₂→[cell]→h₃
x → [cell] →h │ │ │
└──┘ y₁ y₂ y₃
self-loop
이 변환이 중요한 이유는 세 가지다. 첫째, chain rule이 DAG에서 잘 정의되므로 BPTT가 standard backprop으로 환원된다. 둘째, 모든 cell이 동일한 $W_{hh}, W_{xh}$를 공유하므로 gradient는 자동으로 합산된다.
셋째, forward에 $O(TH)$의 activation을 모두 보존해야 backward가 가능하다 — 이것이 메모리 병목의 출발점이다.
BPTT — 시간을 거슬러 흐르는 delta
BPTT는 두 가지 핵심 수식으로 요약된다.
Delta의 역방향 전파:
Weight gradient의 합산:
$L_t$의 $W_{hh}$에 대한 gradient는 모든 과거 경로의 합이다.
$L_t \leftarrow h_t \leftarrow h_{t-1} \leftarrow \cdots \leftarrow h_k \leftarrow W_{hh}의 모든 경로에 chain rule을 적용한다. 각 경로의 contribution이
이고, 공유 weight의 gradient는 모든 에 대해 합산된다.
이 Jacobian 곱 $\prod J_j$의 spectral radius가 1보다 작으면 gradient가 exponentially 감소(vanishing)하고, 크면 폭발(exploding)한다. LSTM의 CEC(Constant Error Carousel)는 이 곱을 $f_t$로 대체해 문제를 완화한다 — 이는 다음 챕터의 주제다.
Truncated BPTT — bias와 memory의 교환
Full BPTT는 $O(TH)$ 메모리를 요구한다. $T = 10000$, $H = 1000$, $B = 32$, float32라면 forward activation만 1.28 GB다. Truncated BPTT(TBPTT)는 이 제약을 detach()로 해결한다.
for s in range(0, T, k):
h = h.detach() # gradient flow 차단
logits, h = model(x[s:s+k], h)
loss = criterion(logits, target[s:s+k])
loss.backward()
optimizer.step()
detach()는 이전 chunk로의 gradient flow를 끊는다. 메모리가 $O(TH)$에서 $O(kH)$로 줄고, 매 $k$ step마다 update가 일어난다.
TBPTT는 길이 $> k$인 long-range dependency의 gradient를 무시한다. 그러나 spectral radius $\rho < 1$이면 bias가 $O(\rho^k)$로 exponentially 감소한다. $\rho = 0.9$, $k = 25$에서 bias는 약 7% — 실용적으로 무시 가능하다. Karpathy의 char-RNN이 $k = 25$를 택한 이유가 여기 있다.
복잡도와 병렬성의 한계
BPTT의 총 비용은 forward와 backward 각각 $O(TH^2)$, 메모리 $O(TH)$다. Gradient checkpointing(Chen 2016)은 $\sqrt{T}$개 checkpoint만 보존하고 segment마다 forward를 재실행해 메모리를 $O(\sqrt{T}H)$로 줄인다 — 최적 segment 길이 $s^* = \sqrt{T}$는 $M(s) = TH/s + sH$를 미분해 구한다.
더 근본적인 한계는 sequence 내부 병렬성 부재다. $h_t는 $h_{t-1}$이 끝나야 계산할 수 있다. GPU의 수천 개 코어는 batch 차원과 hidden 차원의 matmul에는 활용되지만, sequence 차원 자체는 sequential이다. Transformer가 $O(T^2H)$의 비용을 감수하면서 attention matrix를 한 번에 계산하는 이유가 바로 이것이다.
RTRL — 뒤돌아보지 않는 학습
RTRL(Williams & Zipser 1989)은 BPTT의 거울상이다. Sensitivity matrix $S_t = \partial h_t / \partial \theta$를 forward와 함께 propagate한다.
이 수식이 BPTT와 수학적으로 동일한 gradient를 생성한다는 것은 chain rule의 결과다. 차이는 계산 순서다 — forward-mode AD는 weight 차원($|\theta| = O(H^2)$)을 propagate하므로 per-step 비용이 $O(H^4)$다. BPTT의 $O(H^2)$에 비해 $H^2$배 느리다.
그러나 RTRL에는 BPTT가 갖지 못한 것이 있다: 매 step 즉시 update가 가능하다. BPTT는 episode가 끝나야 backward를 실행한다. RTRL은 $L_t$를 관찰하는 즉시 $\theta$를 갱신한다. UORO(Tallec & Ollivier 2017)는 Rademacher random projection으로 sensitivity를 rank-1로 근사해 비용을 $O(H^2)$로 낮추면서 unbiasedness를 유지한다.
BPTT는 supervised learning의 효율성 챔피언이다 — $O(H^2)$ per step, autograd 생태계 완전 지원. RTRL은 online learning과 biological plausibility의 자연스러운 framework다 — streaming RL, edge AI, neuromorphic chip의 e-prop이 이 계열이다. 두 패러다임은 경쟁이 아니라 다른 use case에서 공존한다.
정리
- Unrolling은 cyclic RNN을 DAG로 변환한다. 이것이 BPTT가 standard backprop으로 환원되는 이유다.
- BPTT의 핵심은 Jacobian 곱
$\prod J_j$의 누적이다. Spectral radius가 vanishing/exploding을 결정한다. - TBPTT는
detach()로 gradient flow를 자르고 메모리를$O(kH)$로 줄인다. Bias는$O(\rho^k)$로 exponentially 감소한다. - RNN의 본질적 한계는 sequence-internal sequential dependency다. Transformer와 Mamba는 각각 다른 방식으로 이를 극복한다.
- RTRL은 BPTT와 동일한 gradient를
$O(H^4)$로 계산하지만, online update가 가능하다는 근본적 차이가 있다.
다음 글에서는 Jacobian 곱의 spectral radius 분석(Pascanu 2013)을 통해 vanishing/exploding gradient의 정확한 조건을 추적하고, LSTM의 CEC가 이를 어떻게 우회하는지 살펴본다.