Diffusion Model의 손실함수는 단순한 MSE처럼 보인다. 그런데 그 MSE가 어디서 왔는지 추적하면, Jensen 부등식 → KL 분해 → 노이즈 예측 → 가중치 제거까지 일련의 설계 결정이 쌓여 있다. 이론과 실무가 이렇게 선명하게 갈리는 손실함수가 또 있을까?
VLB — 왜 하한인가
생성 모델의 목표는 logpθ(x0)를 최대화하는 것이다. 하지만 이 값은 직접 계산할 수 없다. 잠재 변수 x1:T에 대한 적분이 intractable하기 때문이다.
해결책은 보조 분포 q(x1:T∣x0)를 도입하는 것이다. 이 분포는 가우시안 잡음을 단계적으로 더하는 순방향 과정으로 고정된다.
결국 역방향 평균 μθ가 posterior 평균 μ~t를 얼마나 잘 근사하는가가 핵심이다.
노이즈 예측 — 표현 방식의 선택
μθ를 어떻게 parameterize할 것인가? 세 가지 선택지가 있다.
표현 방식
예측 대상
수치 안정성
Noise prediction
ϵθ∈Rd
높음 (unit variance)
X0 prediction
x0,θ 직접 예측
낮음 (큰 범위)
Velocity prediction
vθ=x˙t
중간
세 표현은 이론적으로 동등하다. xt=αˉtx0+1−αˉtϵ이라는 관계로 서로 일대일 변환이 가능하기 때문이다. 그러나 신경망 근사 관점에서는 noise prediction이 압도적으로 선호된다 — ϵ∼N(0,I)이므로 스케일이 항상 일정하고, batch normalization 없이도 학습이 안정적이다.
Noise prediction을 사용하면 Lt−1은 다음과 같이 단순화된다.
Lt−1=2σt2αt(1−αˉt)βt2E[∥ϵ−ϵθ(xt,t)∥2]
Lsimple — 이론을 버리고 얻은 것
위 식에서 가중치 wt=2σt2αt(1−αˉt)βt2는 역 SNR에 비례한다.
wt∝αˉt1−αˉt=SNR(t)1
즉, 잡음이 많고 신호가 약한 초기 시간 단계에 더 큰 가중치를 부여한다. 그런데 Ho et al. (2020)은 이 가중치를 전부 제거하는 것이 실제로 더 좋은 이미지를 생성한다고 보고했다.
Lsimple=Et,x0,ϵ[∥ϵ−ϵθ(xt,t)∥2]
✎ 트레이드오프
이론적으로 최적인 손실(가중 MSE)이 실무에서 최적이 아닌 이유: 초기 저SNR 단계에서는 신호가 너무 약해 정확한 예측이 어렵다. 가중치를 높여봤자 신경망은 학습하기 힘든 항에 용량을 낭비한다. 가중치를 제거하면 신경망이 고SNR 영역(시각적으로 의미있는 구조)에 자연스럽게 집중하게 된다. 이론적 타당성은 낮지만, 지각 품질(perceptual quality)은 높아진다.
Improved DDPM의 세 가지 결정
Nichol & Dhariwal (2021)은 Lsimple을 기반으로 세 가지를 추가한다.
코사인 분산 스케줄. 선형 스케줄은 초기 시간 단계에서 SNR이 급격히 떨어진다. 코사인 스케줄은 초기에 완만하게, 후기에 급격히 감소한다.
αˉt=f(0)f(t),f(t)=cos2(1+st/T+s⋅2π)
SNR 분포가 균등해져 학습이 안정된다. 현재 사실상 업계 표준이다.
학습 가능한 분산. 역방향 분산 Σθ(t)를 고정값 대신 신경망이 예측하도록 한다. βt와 posterior 분산 β~t 사이를 보간하는 스칼라 v∈[0,1]를 학습한다.
Σθ(t)=exp(vlogβt+(1−v)logβ~t)
샘플 품질(FID)은 소폭 향상되고, 로그 우도(NLL)는 유의미하게 향상된다.
하이브리드 손실. 두 목표를 동시에 달성하기 위해 Lsimple과 Lvlb를 섞는다.
Lhybrid=Lsimple+λLvlb,λ=0.001
λ가 작은 이유는 명확하다 — Lvlb를 지나치게 강조하면 FID가 악화된다. 0.001은 로그 우도를 개선하면서 샘플 품질을 유지하는 균형점이다.
정리
VLB는 Jensen 부등식으로 얻은 계산 가능한 하한이다. T=1이면 VAE ELBO와 같다.
ELBO 3항 분해에서 LT는 상수, 학습의 실질적 대상은 Lt−1과 L0이다.
Noise prediction은 이론적으로 x0/velocity prediction과 동등하지만, unit variance 덕분에 수치 안정성이 높다.
Lsimple은 이론 가중치를 제거했지만, 신경망을 고SNR 구조에 집중시켜 지각 품질을 높인다.
Improved DDPM의 코사인 스케줄 + 학습 분산 + 하이브리드 손실은 이후 모든 주요 구현의 기준이 됐다.
다음 글에서는 ϵθ가 실제로 무엇을 근사하는지 — score function ∇xtlogp(xt)와의 연결, 그리고 이것이 연속시간 SDE 관점에서 어떻게 통합되는지 추적한다.