PyTorchSeniorSystem design
Что такое torch.compile() (представлен в PyTorch 2.0) и какие преимущества по производительности он даёт?
torch.compile() в PyTorch 2.0 JIT-компилирует модель через TorchDynamo и Inductor, генерируя оптимизированные Triton/C++ ядра — типичное ускорение 1.5–4× на inference и обучении без изменения кода модели.
torch.compile() в PyTorch 2.0+
Архитектура
torch.compile() состоит из трёх компонентов:
- TorchDynamo — перехватывает Python bytecode, строит граф из PyTorch-операций, обрабатывая Python control flow через «graph breaks».
- AOTAutograd — разворачивает autograd граф для joint forward+backward трассировки.
- Inductor — backend по умолчанию; генерирует Triton-ядра для GPU и C++/OpenMP для CPU с fusion, tiling и другими оптимизациями.
Базовое использование
import torch
import torch.nn as nn
model = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True),
num_layers=6,
).cuda()
# Однострочная компиляция
compiled_model = torch.compile(model)
# Первый вызов запускает компиляцию (медленно)
x = torch.randn(32, 128, 512, device="cuda")
out = compiled_model(x)
# Последующие вызовы — скомпилированный код
for _ in range(100):
out = compiled_model(x)
Режимы компиляции
# default — баланс скорости компиляции и производительности
compiled = torch.compile(model, mode="default")
# reduce-overhead — минимизация Python overhead (лучший выбор для inference)
compiled = torch.compile(model, mode="reduce-overhead")
# max-autotune — перебирает tile конфигурации (долго, но максимальная скорость)
compiled = torch.compile(model, mode="max-autotune")
# Отключить fallback на eager (строго только граф)
compiled = torch.compile(model, fullgraph=True)
Диагностика graph breaks
import torch._dynamo
torch._dynamo.config.verbose = True # логи graph breaks
# Или через explain
explanation = torch._dynamo.explain(model)(x)
print(explanation.graphs) # список подграфов
print(explanation.break_reasons) # почему произошёл разрыв
Динамические формы
# Без dynamic=True каждый новый shape перекомпилируется
compiled = torch.compile(model, dynamic=True)
# Или через torch.export с явными dynamic dims
from torch.export import export, Dim
batch = Dim("batch", min=1, max=64)
seq = Dim("seq", min=1, max=512)
exported = export(model, (x,), dynamic_shapes={"x": {0: batch, 1: seq}})
Реальные ускорения
- ResNet-50, batch 64, A100: 1.8× faster training, 1.5× faster inference.
- GPT-2 medium, A100: до 2.5× faster training step.
- SDXL UNet, 4090: до 3.5× faster (с
mode="max-autotune").
Подводные камни
- Первый forward pass медленнее в 5–100 раз из-за трассировки и компиляции Triton-ядер; необходим прогрев перед бенчмарком или продакшен-трафиком.
- Graph breaks из-за Python-специфичных конструкций (динамические условия на тензорах, некоторые сторонние операции) дробят модель на маленькие подграфы и снижают эффект.
- Скомпилированную модель нельзя сохранить через
torch.save(compiled_model); нужно сохранять исходные веса и перекомпилировать при загрузке. - С
fullgraph=Trueмодель падает при любом graph break вместо fallback на eager; полезно для диагностики, но не для продакшена с неизвестными входами. - На Windows поддержка Inductor ограничена; рекомендуется Linux для production.
- Несовместимость с некоторыми операциями: custom CUDA extensions без torch.library, torch.autograd.Function со сложным backward.
- В режиме
max-autotuneкомпиляция может занять 10–30 минут на большой модели — неприемлемо для динамического деплоя.
Common mistakes
- Объяснять
torch compileтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
torch compileс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.
- Предлагает ли observability, rollback, ограничения стоимости и стратегию incident replay.