PyTorchSeniorSystem design

Что такое torch.compile() (представлен в PyTorch 2.0) и какие преимущества по производительности он даёт?

torch.compile() в PyTorch 2.0 JIT-компилирует модель через TorchDynamo и Inductor, генерируя оптимизированные Triton/C++ ядра — типичное ускорение 1.5–4× на inference и обучении без изменения кода модели.

torch.compile() в PyTorch 2.0+

Архитектура

torch.compile() состоит из трёх компонентов:

  • TorchDynamo — перехватывает Python bytecode, строит граф из PyTorch-операций, обрабатывая Python control flow через «graph breaks».
  • AOTAutograd — разворачивает autograd граф для joint forward+backward трассировки.
  • Inductor — backend по умолчанию; генерирует Triton-ядра для GPU и C++/OpenMP для CPU с fusion, tiling и другими оптимизациями.

Базовое использование

import torch
import torch.nn as nn

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True),
    num_layers=6,
).cuda()

# Однострочная компиляция
compiled_model = torch.compile(model)

# Первый вызов запускает компиляцию (медленно)
x = torch.randn(32, 128, 512, device="cuda")
out = compiled_model(x)

# Последующие вызовы — скомпилированный код
for _ in range(100):
    out = compiled_model(x)

Режимы компиляции

# default — баланс скорости компиляции и производительности
compiled = torch.compile(model, mode="default")

# reduce-overhead — минимизация Python overhead (лучший выбор для inference)
compiled = torch.compile(model, mode="reduce-overhead")

# max-autotune — перебирает tile конфигурации (долго, но максимальная скорость)
compiled = torch.compile(model, mode="max-autotune")

# Отключить fallback на eager (строго только граф)
compiled = torch.compile(model, fullgraph=True)

Диагностика graph breaks

import torch._dynamo
torch._dynamo.config.verbose = True  # логи graph breaks

# Или через explain
explanation = torch._dynamo.explain(model)(x)
print(explanation.graphs)         # список подграфов
print(explanation.break_reasons)  # почему произошёл разрыв

Динамические формы

# Без dynamic=True каждый новый shape перекомпилируется
compiled = torch.compile(model, dynamic=True)

# Или через torch.export с явными dynamic dims
from torch.export import export, Dim
batch = Dim("batch", min=1, max=64)
seq = Dim("seq", min=1, max=512)
exported = export(model, (x,), dynamic_shapes={"x": {0: batch, 1: seq}})

Реальные ускорения

  • ResNet-50, batch 64, A100: 1.8× faster training, 1.5× faster inference.
  • GPT-2 medium, A100: до 2.5× faster training step.
  • SDXL UNet, 4090: до 3.5× faster (с mode="max-autotune").

Подводные камни

  • Первый forward pass медленнее в 5–100 раз из-за трассировки и компиляции Triton-ядер; необходим прогрев перед бенчмарком или продакшен-трафиком.
  • Graph breaks из-за Python-специфичных конструкций (динамические условия на тензорах, некоторые сторонние операции) дробят модель на маленькие подграфы и снижают эффект.
  • Скомпилированную модель нельзя сохранить через torch.save(compiled_model); нужно сохранять исходные веса и перекомпилировать при загрузке.
  • С fullgraph=True модель падает при любом graph break вместо fallback на eager; полезно для диагностики, но не для продакшена с неизвестными входами.
  • На Windows поддержка Inductor ограничена; рекомендуется Linux для production.
  • Несовместимость с некоторыми операциями: custom CUDA extensions без torch.library, torch.autograd.Function со сложным backward.
  • В режиме max-autotune компиляция может занять 10–30 минут на большой модели — неприемлемо для динамического деплоя.

Common mistakes

  • Объяснять torch compile только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать torch compile с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.
  • Предлагает ли observability, rollback, ограничения стоимости и стратегию incident replay.

Sources

Related topics