← all posts
AI 2026.05.03 · 11 min read Advanced

PyTorch autograd는 어떻게 gradient를 계산하는가

Forward-mode JVP와 reverse-mode VJP의 비용 분석부터 computation graph의 동적 생성, custom Function 구현, double backward까지 — autograd의 설계 철학을 추적한다.


loss.backward() 한 줄이 수백만 개의 파라미터 gradient를 동시에 계산한다. 이것이 가능한 이유는 PyTorch가 **reverse-mode automatic differentiation(AD)**를 선택했기 때문이다. 그렇다면 왜 하필 reverse-mode인가? 그리고 이 선택이 computation graph, custom Function, double backward라는 세 가지 구조를 어떻게 결정하는가?

JVP vs VJP — 같은 chain rule의 두 구현

자동 미분에는 두 가지 모드가 있다. 둘 다 동일한 chain rule을 구현하지만, 행렬 곱셈의 결합 순서가 다르다.

Forward-mode (JVP): tangent vector x˙Rn\dot{x} \in \mathbb{R}^n를 입력 방향으로 전파한다.

y˙=Jf(x)x˙\dot{y} = J_f(x)\,\dot{x}

Reverse-mode (VJP): cotangent vector yˉRm\bar{y} \in \mathbb{R}^m를 출력에서 입력 방향으로 역전파한다.

xˉ=Jf(x)Tyˉ\bar{x} = J_f(x)^T\,\bar{y}

두 식은 같은 Jacobian JfJ_f의 서로 다른 곱이다. 전체 gradient를 구하는 비용은 모드에 따라 달라진다.

Costforward=nO(N),Costreverse=mO(N)\text{Cost}_{\text{forward}} = n \cdot O(N), \quad \text{Cost}_{\text{reverse}} = m \cdot O(N)

여기서 nn은 입력 차원, mm은 출력 차원, NN은 연산 수다. Neural network training에서 m=1m = 1(scalar loss), n106n \approx 10^6이므로 reverse-mode가 nn배 효율적이다. loss.backward()가 reverse-mode인 이유는 이 비용 분석의 직접적 귀결이다 (Baydin et al. 2018).

Forward-mode가 유리한 경우

출력 차원이 입력보다 훨씬 많은 경우 — 예를 들어 generator zR512imageR3×1024×1024z \in \mathbb{R}^{512} \to \text{image} \in \mathbb{R}^{3 \times 1024 \times 1024} — 에서는 forward-mode(jax.jacfwd)가 더 효율적이다.

Computation Graph — 동적으로 만들어지는 DAG

PyTorch의 computation graph는 define-by-run 방식이다. Forward pass를 실행하는 순간 그래프 노드가 생성된다.

x = torch.tensor(2.0, requires_grad=True)  # leaf, grad_fn=None
y = x ** 2                                  # intermediate, grad_fn=PowBackward0
z = y + 1                                   # intermediate, grad_fn=AddBackward0

그래프는 항상 DAG(Directed Acyclic Graph)다. 만약 cycle이 존재하면 xx를 정의하는 데 xx가 필요한 모순이 생기므로 PyTorch가 이를 보장한다.

명제 1 · Leaf gradient accumulation

scalar loss LL에 대해 동일한 leaf tensor xx가 여러 경로로 도달할 때, backward는 각 경로의 gradient를 누적한다.

x.grad=pathLxpathx.\text{grad} = \sum_{\text{path}} \frac{\partial L}{\partial x}\bigg|_{\text{path}}
▷ 증명

Backward는 reverse topological order로 각 노드의 local VJP를 호출한다. 같은 leaf에 도달하는 여러 경로의 gradient는 leaf 노드에서 +=로 합산된다. 미분의 선형성으로부터 이 합이 전체 편미분과 동일함을 알 수 있다. \square

Intermediate tensor의 .grad는 backward 후 None이다. PyTorch는 메모리 절약을 위해 leaf tensor의 gradient만 보존한다. y.retain_grad()를 명시적으로 호출해야 중간값 gradient를 얻을 수 있다.

Backward — Reverse Topological Sort

Backward 알고리즘의 본질은 단순하다.

1. topo_order = ReverseTopologicalSort(G, from=L)
2. gradient[L] = 1
3. For each node n in topo_order:
4.   For each input i of n:
5.     gradient[i] += local_vjp(n, gradient[n])
Backward=ReverseTopologicalSort(G)+LocalVJP accumulation\boxed{\text{Backward} = \text{ReverseTopologicalSort}(\mathcal{G}) + \text{LocalVJP accumulation}}

실전에서 중요한 두 가지:

# BAD: 매 iteration마다 gradient가 누적됨
for epoch in range(100):
    loss = criterion(model(x), y)
    loss.backward()          # gradient 누적!
    optimizer.step()

# GOOD
for epoch in range(100):
    optimizer.zero_grad(set_to_none=True)   # None 할당 → 더 빠름
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

set_to_none=Truezero_()(메모리 재사용)와 달리 gradient tensor 자체를 해제하므로 약 2배 빠르다.

Custom Function — VJP Rule의 명시적 구현

PyTorch가 제공하지 않는 연산을 만들 때는 torch.autograd.Function으로 local VJP rule을 직접 정의한다.

class MySoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        sigma = torch.softmax(x, dim=-1)
        ctx.save_for_backward(sigma)      # ctx가 유일한 전달 통로
        return sigma

    @staticmethod
    def backward(ctx, grad_output):       # grad_output = upstream ∂L/∂y
        sigma, = ctx.saved_tensors
        # VJP: bar_x = sigma * (grad - (sigma * grad).sum())
        return sigma * (grad_output - (sigma * grad_output).sum(dim=-1, keepdim=True))

핵심 패턴 두 가지:

  • ctx.save_for_backward()에 저장하지 않고 closure에서 참조하면 graph가 유지되어 메모리 누수가 발생한다.
  • Backward의 return tuple 길이는 반드시 forward의 input 개수와 일치해야 한다.

구현 후에는 반드시 gradcheck로 numerical gradient와 비교 검증한다.

x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(MySoftmax.apply, (x,), eps=1e-6, atol=1e-4)
트레이드오프

Custom Function의 비용: forward/backward를 직접 구현하므로 PyTorch의 자동 최적화(operation fusion 등)를 받지 못한다. 단, CUDA kernel과 연동하거나 수치 안정성이 중요한 연산(log-softmax 등)에서는 명시적 구현이 필수다.

Double Backward — Hessian을 O(p)O(p)로 계산하기

create_graph=True를 사용하면 backward 자체가 differentiable operation으로 graph에 기록된다. 이를 통해 2차 미분을 효율적으로 계산할 수 있다.

가장 실용적인 패턴은 **Hessian-Vector Product(HVP)**다.

Hv=(Lv)\boxed{Hv = \nabla(\nabla L \cdot v)}

Explicit Hessian HRp×pH \in \mathbb{R}^{p \times p}O(p2)O(p^2) 메모리를 요구하지만, HVP는 단 1회의 추가 backward로 O(p)O(p) 메모리에 계산된다. 파라미터가 10610^6개라면 10610^6배 효율적이다 (Pearlmutter 1994).

def hvp(f, x, v):
    fx = f(x)
    grad_fx = torch.autograd.grad(fx, x, create_graph=True)[0]
    return torch.autograd.grad((grad_fx * v).sum(), x)[0]

Physics-Informed NN에서 PDE 항 2ux2\frac{\partial^2 u}{\partial x^2}를 계산하는 것도 같은 원리다.

u_x  = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
u_xx = torch.autograd.grad(u_x.sum(), x)[0]
residual = ((u_t - nu * u_xx) ** 2).mean()

각 단계에서 create_graph=True가 graph를 유지해야 다음 단계의 backward가 가능하다.

정리

  • Reverse-mode(VJP)는 m=1m=1(scalar loss)일 때 forward-mode보다 nn배 효율적이다. ML에서 이 선택은 비용 분석의 직접적 귀결이다.
  • Computation graph는 forward 실행 시 동적으로 생성되는 DAG다. Backward는 이 DAG의 reverse topological order로 local VJP를 누적한다.
  • Custom Function은 ctx를 통해 forward/backward를 연결하며, gradcheck로 수치 검증이 가능하다.
  • Double backward(create_graph=True)는 Hessian 행렬 없이 HVP를 O(p)O(p) 비용으로 계산한다.

loss.backward() 뒤에는 수십 줄의 알고리즘이 있다. 그 알고리즘을 알면, 왜 gradient가 누적되는지, 왜 inference_mode()가 더 빠른지, 왜 PINN이 동작하는지 모두 같은 프레임으로 설명된다.