PyTorchSeniorSystem design

Как профилировать модель PyTorch для выявления узких мест по памяти и скорости?

Используйте torch.profiler.profile() с ProfilerActivity.CPU/CUDA для замера времени и памяти; результаты смотрите через prof.key_averages() или TensorBoard-плагин.

Профилирование PyTorch: память и скорость

PyTorch предоставляет встроенный профилировщик torch.profiler, заменивший устаревший torch.autograd.profiler. Он охватывает CPU, CUDA, память и позволяет экспортировать трейсы для TensorBoard и Chrome Trace Viewer.

Базовый сценарий

import torch
from torch.profiler import profile, ProfilerActivity, record_function

model = MyModel().cuda()
x = torch.randn(32, 3, 224, 224, device="cuda")

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("model_inference"):
        model(x)

print(prof.key_averages().table(
    sort_by="cuda_time_total", row_limit=20
))

# Экспорт для TensorBoard
prof.export_chrome_trace("/tmp/trace.json")

Планировщик (schedule) для реалистичного профилирования

Первые итерации содержат «прогрев» CUDA. Используйте torch.profiler.schedule, чтобы пропустить их:

from torch.profiler import schedule

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=2, warmup=2, active=5, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_logs"),
    record_shapes=True,
    profile_memory=True,
) as prof:
    for step, batch in enumerate(loader):
        train_step(batch)
        prof.step()  # обязательно вызывать каждый шаг

Анализ памяти

Для детального анализа выделений используйте torch.cuda.memory_snapshot() и визуализатор памяти:

torch.cuda.memory._record_memory_history(max_entries=100_000)

train_loop()

with open("mem_snapshot.pickle", "wb") as f:
    import pickle
    pickle.dump(torch.cuda.memory._snapshot(), f)

# Открыть на https://pytorch.org/memory_viz

Ключевые метрики в выводе key_averages()

  • cpu_time_total — суммарное CPU-время оператора
  • cuda_time_total — суммарное CUDA-время (реальная задержка GPU)
  • self_cpu_memory_usage — память, выделенная самим оператором без дочерних
  • cuda_memory_usage — выделения на GPU

Дополнительные инструменты

  • torch.cuda.memory_summary() — быстрый дамп статистики выделений
  • torch.cuda.max_memory_allocated() — пиковое выделение в байтах
  • NVIDIA Nsight Systems / nvprof — низкоуровневый CUDA-анализ ядер
  • TensorBoard Profiler Plugin — визуальный timeline оператора

Подводные камни

  • Без prof.step() в цикле schedule профилировщик не переходит между фазами и трейс остаётся пустым.
  • profile_memory=True заметно замедляет обучение (до 2×); не оставляйте его включённым в продакшене.
  • CUDA-операции асинхронны: cpu_timecuda_time; всегда смотрите на cuda_time_total для оценки реальной задержки.
  • Первые 1–3 итерации включают JIT-компиляцию CUDA-ядер — всегда используйте фазу warmup.
  • with_stack=True генерирует огромные трейсы; на больших моделях Chrome Trace Viewer может зависнуть.
  • Если модель использует torch.compile, профилировщик показывает скомпилированные фьюзированные ядра, а не исходные операции — сопоставление с кодом затруднено.
  • Memory snapshot API экспериментален: _record_memory_history и _snapshot начинаются с подчёркивания и могут измениться.
  • На multi-GPU нужно профилировать каждый ранк отдельно; суммарный трейс не создаётся автоматически.

Common mistakes

  • Объяснять profiler memory speed только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать profiler memory speed с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.
  • Предлагает ли observability, rollback, ограничения стоимости и стратегию incident replay.

Sources

Related topics