PyTorch autograd는 어떻게 gradient를 계산하는가
Forward-mode JVP와 reverse-mode VJP의 비용 분석부터 computation graph의 동적 생성, custom Function 구현, double backward까지 — autograd의 설계 철학을 추적한다.
- 01 PyTorch Tensor는 왜 Storage와 Metadata로 분리되어 있는가
- 02 PyTorch autograd는 어떻게 gradient를 계산하는가
- 03 PyTorch Dispatcher는 어떻게 동작하는가
- 04 GPU 커널 성능은 무엇이 결정하는가
- 05 PyTorch Custom Kernel의 핵심은 HBM을 피하는 것이다
- 06 Mixed Precision Training의 수학 — FP16은 왜 위험하고 BF16은 왜 안전한가
- 07 torch.compile은 Python 코드를 어떻게 GPU 커널로 바꾸는가
loss.backward() 한 줄이 수백만 개의 파라미터 gradient를 동시에 계산한다. 이것이 가능한 이유는 PyTorch가 **reverse-mode automatic differentiation(AD)**를 선택했기 때문이다. 그렇다면 왜 하필 reverse-mode인가? 그리고 이 선택이 computation graph, custom Function, double backward라는 세 가지 구조를 어떻게 결정하는가?
JVP vs VJP — 같은 chain rule의 두 구현
자동 미분에는 두 가지 모드가 있다. 둘 다 동일한 chain rule을 구현하지만, 행렬 곱셈의 결합 순서가 다르다.
Forward-mode (JVP): tangent vector 를 입력 방향으로 전파한다.
Reverse-mode (VJP): cotangent vector 를 출력에서 입력 방향으로 역전파한다.
두 식은 같은 Jacobian 의 서로 다른 곱이다. 전체 gradient를 구하는 비용은 모드에 따라 달라진다.
여기서 은 입력 차원, 은 출력 차원, 은 연산 수다. Neural network training에서 (scalar loss), 이므로 reverse-mode가 배 효율적이다. loss.backward()가 reverse-mode인 이유는 이 비용 분석의 직접적 귀결이다 (Baydin et al. 2018).
출력 차원이 입력보다 훨씬 많은 경우 — 예를 들어 generator — 에서는 forward-mode(jax.jacfwd)가 더 효율적이다.
Computation Graph — 동적으로 만들어지는 DAG
PyTorch의 computation graph는 define-by-run 방식이다. Forward pass를 실행하는 순간 그래프 노드가 생성된다.
x = torch.tensor(2.0, requires_grad=True) # leaf, grad_fn=None
y = x ** 2 # intermediate, grad_fn=PowBackward0
z = y + 1 # intermediate, grad_fn=AddBackward0
그래프는 항상 DAG(Directed Acyclic Graph)다. 만약 cycle이 존재하면 를 정의하는 데 가 필요한 모순이 생기므로 PyTorch가 이를 보장한다.
scalar loss 에 대해 동일한 leaf tensor 가 여러 경로로 도달할 때, backward는 각 경로의 gradient를 누적한다.
Backward는 reverse topological order로 각 노드의 local VJP를 호출한다. 같은 leaf에 도달하는 여러 경로의 gradient는 leaf 노드에서 +=로 합산된다. 미분의 선형성으로부터 이 합이 전체 편미분과 동일함을 알 수 있다.
Intermediate tensor의 .grad는 backward 후 None이다. PyTorch는 메모리 절약을 위해 leaf tensor의 gradient만 보존한다. y.retain_grad()를 명시적으로 호출해야 중간값 gradient를 얻을 수 있다.
Backward — Reverse Topological Sort
Backward 알고리즘의 본질은 단순하다.
1. topo_order = ReverseTopologicalSort(G, from=L)
2. gradient[L] = 1
3. For each node n in topo_order:
4. For each input i of n:
5. gradient[i] += local_vjp(n, gradient[n])
실전에서 중요한 두 가지:
# BAD: 매 iteration마다 gradient가 누적됨
for epoch in range(100):
loss = criterion(model(x), y)
loss.backward() # gradient 누적!
optimizer.step()
# GOOD
for epoch in range(100):
optimizer.zero_grad(set_to_none=True) # None 할당 → 더 빠름
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
set_to_none=True는 zero_()(메모리 재사용)와 달리 gradient tensor 자체를 해제하므로 약 2배 빠르다.
Custom Function — VJP Rule의 명시적 구현
PyTorch가 제공하지 않는 연산을 만들 때는 torch.autograd.Function으로 local VJP rule을 직접 정의한다.
class MySoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
sigma = torch.softmax(x, dim=-1)
ctx.save_for_backward(sigma) # ctx가 유일한 전달 통로
return sigma
@staticmethod
def backward(ctx, grad_output): # grad_output = upstream ∂L/∂y
sigma, = ctx.saved_tensors
# VJP: bar_x = sigma * (grad - (sigma * grad).sum())
return sigma * (grad_output - (sigma * grad_output).sum(dim=-1, keepdim=True))
핵심 패턴 두 가지:
ctx.save_for_backward()에 저장하지 않고 closure에서 참조하면 graph가 유지되어 메모리 누수가 발생한다.- Backward의 return tuple 길이는 반드시 forward의 input 개수와 일치해야 한다.
구현 후에는 반드시 gradcheck로 numerical gradient와 비교 검증한다.
x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(MySoftmax.apply, (x,), eps=1e-6, atol=1e-4)
Custom Function의 비용: forward/backward를 직접 구현하므로 PyTorch의 자동 최적화(operation fusion 등)를 받지 못한다. 단, CUDA kernel과 연동하거나 수치 안정성이 중요한 연산(log-softmax 등)에서는 명시적 구현이 필수다.
Double Backward — Hessian을 로 계산하기
create_graph=True를 사용하면 backward 자체가 differentiable operation으로 graph에 기록된다. 이를 통해 2차 미분을 효율적으로 계산할 수 있다.
가장 실용적인 패턴은 **Hessian-Vector Product(HVP)**다.
Explicit Hessian 는 메모리를 요구하지만, HVP는 단 1회의 추가 backward로 메모리에 계산된다. 파라미터가 개라면 배 효율적이다 (Pearlmutter 1994).
def hvp(f, x, v):
fx = f(x)
grad_fx = torch.autograd.grad(fx, x, create_graph=True)[0]
return torch.autograd.grad((grad_fx * v).sum(), x)[0]
Physics-Informed NN에서 PDE 항 를 계산하는 것도 같은 원리다.
u_x = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
u_xx = torch.autograd.grad(u_x.sum(), x)[0]
residual = ((u_t - nu * u_xx) ** 2).mean()
각 단계에서 create_graph=True가 graph를 유지해야 다음 단계의 backward가 가능하다.
정리
- Reverse-mode(VJP)는 (scalar loss)일 때 forward-mode보다 배 효율적이다. ML에서 이 선택은 비용 분석의 직접적 귀결이다.
- Computation graph는 forward 실행 시 동적으로 생성되는 DAG다. Backward는 이 DAG의 reverse topological order로 local VJP를 누적한다.
- Custom Function은
ctx를 통해 forward/backward를 연결하며,gradcheck로 수치 검증이 가능하다. - Double backward(
create_graph=True)는 Hessian 행렬 없이 HVP를 비용으로 계산한다.
loss.backward() 뒤에는 수십 줄의 알고리즘이 있다. 그 알고리즘을 알면, 왜 gradient가 누적되는지, 왜 inference_mode()가 더 빠른지, 왜 PINN이 동작하는지 모두 같은 프레임으로 설명된다.