PyTorch Dispatcher는 어떻게 동작하는가
aten::add 한 호출이 CPU·CUDA·Autograd kernel 중 어느 것으로 실행될지 결정하는 Dispatcher의 설계 철학부터 functorch의 함수형 변환까지, PyTorch 내부 구조를 추적한다.
torch.add(x, y) 한 줄을 실행하면 PyTorch는 조용히 50개가 넘는 DispatchKey 중에서 올바른 kernel을 골라 실행한다. CPU인가 CUDA인가, autograd graph를 만들어야 하는가, mixed precision으로 변환해야 하는가 — 이 모든 결정이 Dispatcher라는 routing system을 통해 자동으로 내려진다. 이 메커니즘을 모르면 custom op를 만들거나 torch.compile의 동작을 이해하는 것이 불가능에 가깝다. 왜 PyTorch는 이 구조를 선택했는가?
세 계층: c10, ATen, torch
PyTorch의 내부는 세 계층으로 명확하게 나뉜다.
┌──────────────────────────────────────┐
│ torch (Python binding) │ ← 사용자 호출
└─────────────┬────────────────────────┘
│ pybind11
┌─────────────▼────────────────────────┐
│ ATen (operations + kernels) │ ← add, matmul, conv2d ...
└─────────────┬────────────────────────┘
│
┌─────────────▼────────────────────────┐
│ c10 (core types) │ ← Tensor, Storage, Device
└──────────────────────────────────────┘
c10은 Tensor의 metadata(shape, stride, device, dtype)와 DispatchKey enum만 정의한다. ATen에 대한 의존이 전혀 없다. 이 독립성 덕분에 새로운 backend(TPU, Neuron)를 추가할 때 c10은 손댈 필요가 없다 — ATen의 RegisterTPU.cpp만 추가하면 된다.
ATen은 native_functions.yaml에서 모든 op의 schema를 선언적으로 정의하고, 이로부터 CPU·CUDA kernel 등록 코드를 자동 생성한다.
# native_functions.yaml
- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
dispatch:
CPU: native::add_cpu
CUDA: native::add_cuda
CompositeImplicitAutograd: native::add
이 한 줄의 schema로부터 C++ signature, Python binding, 각 backend의 kernel 등록 scaffold가 자동으로 만들어진다. schema-driven 설계의 핵심은 한 번 선언으로 전 경로의 type safety를 보장한다는 것이다.
Dispatcher: DispatchKeySet와 Fallthrough Chain
Dispatcher의 핵심 자료구조는 두 가지다.
DispatchKey는 50개 이상의 enum 값(CPU, CUDA, Autograd, AutocastCUDA, Functionalize, BackendSelect …)으로 구성된다. DispatchKeySet은 이를 uint64 bit mask로 표현한다.
Dispatcher는 DispatchKeySet의 최상위 set bit부터 순차적으로 kernel을 탐색한다(fallthrough chain). 이 탐색은 __builtin_clzll()(count leading zeros)로 구현되어 O(1)이다.
DispatchKeySet의 최상위 bit부터 kernel을 탐색할 때, 등록된 kernel이 있으면 실행하고 없으면 다음 key로 fallthrough한다. 결과적으로 단 하나의 execution path가 결정된다.
DispatchKeySet 에서 (priority order). Dispatcher는 에 대한 kernel을 먼저 탐색한다. 등록되어 있으면 실행 후 반환, 없으면 로 fallthrough, 반복한다. 모든 key가 미등록이면 error. 따라서 priority order에 따라 정확히 하나의 path가 결정된다.
torch.add(x_cuda, y_cuda, requires_grad=True) 호출 시 DispatchKeySet은 {Autograd[CUDA], CUDA, FP32, ...}가 된다. 가장 높은 priority인 Autograd[CUDA]가 먼저 선택되어 backward metadata를 설정한 뒤, Autograd key를 제거하고 CUDA kernel로 redispatch한다.
Autograd Key: Forward와 Backward의 분리
Autograd[CUDA] kernel이 하는 일은 실제 연산이 아니라 metadata layer다.
torch.no_grad() context는 DispatchKeySet에서 Autograd key를 제거함으로써 동작한다. C++ 내부에서 GradMode::set_enabled(false) 호출 시 compute_dispatch_keys()가 Autograd key를 strip한다. x.detach()는 다르다 — Tensor의 requires_grad 플래그와 grad_fn을 제거하는 Tensor 수준 연산이다.
x.add_(y) 같은 inplace op는 Autograd key를 통해 backward node를 생성하지만, x의 storage 자체가 변한다. backward 시 저장된 original value와 현재 값이 달라 gradient가 부정확해질 수 있다. requires_grad=True인 tensor에는 inplace 연산을 피하라.
Custom Op: TORCH_LIBRARY의 통합
Built-in op과 동일한 dispatcher 메커니즘으로 custom op를 등록하는 것이 TORCH_LIBRARY + TORCH_LIBRARY_IMPL이다.
// Step 1: schema 정의
TORCH_LIBRARY(mylib, m) {
m.def("my_op(Tensor x) -> Tensor");
}
// Step 2: CUDA kernel 등록
TORCH_LIBRARY_IMPL(mylib, CUDA, m) {
m.impl("my_op", &my_op_cuda);
}
// Step 3: Autograd 등록 (optional)
TORCH_LIBRARY_IMPL(mylib, Autograd, m) {
m.impl("my_op", torch::autograd::Function<MyOpAutograd>::apply);
}
이후 Python에서 torch.ops.mylib.my_op(x)로 호출하면, Dispatcher가 자동으로 device를 감지하고 올바른 kernel을 선택한다.
torch.compile 호환을 위해서는 Meta kernel도 등록해야 한다. Meta kernel은 실제 연산 없이 shape/dtype 정보만 전파하는 “dummy” kernel로, TorchDynamo의 symbolic execution이 이를 사용한다.
TORCH_LIBRARY_IMPL(mylib, Meta, m) {
m.impl("my_op", [](Tensor x) { return torch::empty_like(x); });
}
functorch: 함수형 변환의 합성
functorch(torch.func)는 JAX의 철학을 PyTorch에 도입한 것이다 — pure function에 grad, vmap, jacrev, jacfwd 같은 변환을 composable하게 적용한다.
vmap(grad(loss_fn)) 조합이 DP-SGD에서 핵심이다.
import torch.func as F
# Per-sample gradient
per_sample_grads = F.vmap(
F.grad(loss_fn, argnums=0),
in_dims=(None, 0, 0) # params broadcast, xs/ys batched
)(params, xs, ys)
# shape: [batch_size, d_out, d_in]
Jacobian 계산에서는 방향 선택이 복잡도를 결정한다.
functorch는 pure function을 요구한다. in-place 연산, global state mutation, torch.nn 모듈의 stateful forward는 그대로 쓸 수 없다. torch.nn.utils.stateless (또는 torch.func.functional_call)로 파라미터를 외부에서 주입하는 패턴이 필요하다. 유연성과 함수형 순수성 사이의 트레이드오프는 여전히 존재한다.
정리
- PyTorch의 모든 op 호출은 Dispatcher를 거친다. DispatchKeySet의 최상위 bit부터 kernel을 탐색하는 fallthrough chain이 device·dtype·backward kernel을 O(1)로 결정한다.
- c10 → ATen → torch의 계층 분리와
native_functions.yaml의 schema-driven 설계가 새 backend 추가 비용을 극소화한다. Autograd key는 실제 연산이 아닌 metadata layer다. forward에서 backward node를 붙이고 CUDA kernel로 redispatch한다.TORCH_LIBRARY+ Meta kernel 등록으로 custom op가torch.compile과 완전히 통합된다.vmap(grad(...))합성은 DP-SGD의 per-sample gradient를 loop 없이 vectorize한다.
Dispatcher를 이해하면 PyTorch의 “magic”이 사라진다 — 그 자리에 명확한 routing table이 남는다.