torch.compile은 Python 코드를 어떻게 GPU 커널로 바꾸는가
Dynamo의 바이트코드 캡처부터 AOTAutograd의 심볼릭 역전파, Inductor의 커널 퓨전, 분산 학습과의 통합까지, PT 2.0 컴파일 파이프라인의 설계 철학을 추적한다.
- 01 PyTorch Tensor는 왜 Storage와 Metadata로 분리되어 있는가
- 02 PyTorch autograd는 어떻게 gradient를 계산하는가
- 03 PyTorch Dispatcher는 어떻게 동작하는가
- 04 GPU 커널 성능은 무엇이 결정하는가
- 05 PyTorch Custom Kernel의 핵심은 HBM을 피하는 것이다
- 06 Mixed Precision Training의 수학 — FP16은 왜 위험하고 BF16은 왜 안전한가
- 07 torch.compile은 Python 코드를 어떻게 GPU 커널로 바꾸는가
torch.compile(model) 한 줄을 추가하면 모델이 20~50% 빨라진다. 코드는 바뀌지 않는다. 어떻게 가능한가? 그리고 왜 첫 번째 호출은 수십 초가 걸리고, 그 이후는 거짓말처럼 빨라지는가?
파이프라인의 전체 구조
PT 2.0의 핵심 아이디어는 단순하다: Python 코드를 그대로 유지하면서 내부 실행 경로만 바꾼다.
Eager mode에서는 forward 한 번에 Python 인터프리터가 수십 번 개입하고, 커널 런치가 10개 이상 발생한다. torch.compile은 첫 호출에서 이 Python 코드를 정적 그래프로 변환하고, 이후 호출에서는 그 그래프를 단일 Triton 커널로 실행한다. 런치 횟수가 10+에서 1로 줄고, Python 인터프리터가 사라진다.
Dynamo — 바이트코드를 가로채다
TorchDynamo는 PEP 523의 frame evaluation API를 사용한다. CPython은 함수 호출 프레임마다 커스텀 evaluator를 등록할 수 있고, Dynamo가 이를 활용해 모든 Python 함수의 바이트코드 실행 시점에 훅을 삽입한다.
# LOAD_FAST 'x', BINARY_MATRIX_MULTIPLY, STORE_FAST 'y' 같은
# 바이트코드를 한 줄씩 읽으면서 FX graph node로 변환한다.
이 과정에서 두 가지 중요한 개념이 등장한다.
Graph break: Python control flow(if, for)나 추적 불가능한 연산을 만나면 그래프가 분할된다. 개의 break가 발생하면 compile time은 최소 배 증가한다. 각 segment가 독립적으로 Dynamo 파싱과 Inductor codegen을 거치기 때문이다.
Guard: 컴파일된 커널은 입력의 shape, dtype, contiguity 조건을 기억한다. 런타임에 이 조건이 바뀌면 recompilation이 트리거된다. batch size가 매번 바뀌는 환경이라면 torch.compile(model, dynamic=True)로 symbolic shape를 활성화해야 한다.
if use_relu: x = relu(x) 같은 Python 조건문은 graph break를 만든다. torch.where나 F.dropout 같은 PyTorch native ops로 교체하면 break 없이 단일 그래프로 컴파일된다.
AOTAutograd — forward와 backward를 미리 합치다
Eager autograd는 forward 실행 중에 intermediate를 tape에 저장하고, backward에서 tape를 역순 재생한다. 두 phase가 분리되어 있으므로 fusion이 불가능하다.
AOTAutograd는 컴파일 시점에 forward 그래프로부터 symbolic chain rule을 적용해 backward 그래프를 미리 생성한다.
이 merged graph 하나가 Inductor로 전달되고, forward와 backward가 하나의 커널로 융합된다. 여기서 세 가지 변환이 핵심이다.
Functionalization: x.add_(1) 같은 in-place 연산을 x = x.add(1)로 변환한다. in-place op은 storage aliasing을 만들어 fusion heuristic을 복잡하게 한다. functional form으로 바꾸면 data flow가 명확해지고 최적화 공간이 넓어진다.
Decomposition: log_softmax 같은 고수준 op을 max, sub, exp, sum, log 같은 primitive로 분해한다. Inductor가 더 자유롭게 fuse할 수 있다.
Min-cut partition: backward에서 어떤 intermediate를 저장하고 어떤 것을 재계산할지 결정한다. 저장 비용과 재계산 비용의 trade-off를 DAG min-cut으로 최적화한다.
AOTAutograd로 계산한 gradient는 eager autograd의 gradient와 numerically 동일하다.
AOTAutograd는 forward 그래프의 각 node에 chain rule을 symbolic하게 적용해 backward node를 생성한다. 최종 gradient는 backward 그래프의 topological sort로 계산되며, chain rule의 결합법칙에 의해 eager과 동일한 값이 나온다.
Inductor — FX 그래프를 Triton 커널로
Inductor는 merged FX 그래프를 실제 실행 가능한 코드로 변환한다.
┌──────────────────────────────────────────┐
│ EAGER: 3개 커널, HBM 왕복 3회 │
│ matmul → [HBM] → add_bias → [HBM] → relu │
└──────────────────────────────────────────┘
┌──────────────────────────────────────────┐
│ FUSED: 1개 커널, HBM 왕복 2회 │
│ x, w 로드 → matmul + add + relu (SRAM) → z 저장 │
└──────────────────────────────────────────┘
Pointwise + reduction + pointwise 패턴이 단일 커널로 융합되면 intermediate HBM 왕복이 사라진다. Roofline model 관점에서 bandwidth 사용량이 약 절반으로 줄어든다.
matmul 처리에서 Inductor는 상황에 따라 backend를 선택한다. 순수 matmul은 cuBLAS, matmul + activation 융합이 필요한 경우는 Triton을 사용한다. 생성된 Triton 소스 코드는 TORCH_LOGS="output_code" 환경변수로 확인할 수 있다.
분산 학습과의 통합
DDP에서 eager mode는 backward 완료 후 all-reduce를 blocking으로 호출한다. GPU는 all-reduce가 끝날 때까지 대기한다.
Compiled DDP는 backward 그래프에 all-reduce를 FX node로 포함시킨다. Inductor가 두 연산을 다른 CUDA stream에 배치해 computation과 communication을 overlap한다.
일 때 최대 에 근접하지만, 실제로는 5~10% 수준이다. FSDP의 경우 PT 2.1+부터 all-gather와 reduce-scatter가 그래프에 표현되어 parameter sharding 스케줄링도 Inductor가 최적화한다.
torch.compile이 만능은 아니다. CPU 학습: 2~5% 향상에 그친다. 동적 shape: recompilation overhead가 누적된다. graph break: 많을수록 compile time이 선형으로 증가한다. custom op: torch.autograd.Function은 PT 2.1+ 이후 부분 지원된다. 모델이 고정 shape의 GPU 학습이라면 reduce-overhead 모드가 가장 안정적이다.
정리
torch.compile의 파이프라인은 Dynamo(캡처) → AOTAutograd(그래프 생성) → Inductor(코드 생성)의 3단계다.- 첫 호출의 10~50초 지연은 Triton JIT 컴파일 비용이다. 이후 호출은 캐시된 커널을 재사용한다.
- AOTAutograd는 forward와 backward를 컴파일 시점에 하나의 그래프로 합쳐 Inductor가 함께 최적화하게 한다.
- graph break를 줄이고 고정 shape를 유지하는 것이 실무에서 가장 중요한 최적화 전략이다.
Python 코드를 건드리지 않고 GPU 커널 수준의 최적화를 달성한다는 약속 — 그 안에는 바이트코드 인터셉션부터 symbolic chain rule까지, 컴파일러 이론의 여러 층이 숨어 있다.