PyTorchSeniorSystem design

Что такое torch.cuda.amp и как mixed-precision training улучшает производительность?

torch.cuda.amp предоставляет autocast и GradScaler для автоматического смешения float16/bfloat16 и float32. FP16-операции на Tensor Core в 2–4× быстрее и требуют вдвое меньше памяти, GradScaler предотвращает underflow градиентов.

Что такое torch.cuda.amp

torch.cuda.amp — модуль PyTorch для автоматического смешанного обучения (Automatic Mixed Precision). Он предоставляет два инструмента:

  • autocast — контекстный менеджер, который автоматически понижает тип операций до float16 (или bfloat16) там, где это безопасно, и оставляет float32 для численно чувствительных вычислений.
  • GradScaler — масштабирует loss перед backward(), чтобы предотвратить underflow градиентов в FP16, затем автоматически обнуляет шаг оптимизатора при обнаружении inf/nan и снижает scale.

Почему FP16 быстрее

На GPU с архитектурой Volta и новее (A100, V100, H100) аппаратные блоки Tensor Core выполняют матричные умножения в FP16 в 2–8× быстрее, чем FP32. Кроме того, FP16-тензоры занимают вдвое меньше памяти, что позволяет увеличить batch size и сократить число итераций.

Полный рабочий пример

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

device = torch.device("cuda")

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True),
    num_layers=6,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()  # масштабировщик градиентов

for batch_idx, (inputs, targets) in enumerate(dataloader):
    inputs = inputs.to(device)   # float32 на входе
    targets = targets.to(device)

    optimizer.zero_grad()

    # autocast автоматически выбирает FP16 для линейных слоёв,
    # attention, свёрток и оставляет FP32 для softmax, BatchNorm
    with autocast(dtype=torch.float16):
        logits = model(inputs)           # большинство ops → float16
        loss = criterion(logits, targets)  # loss тоже float16

    # Масштабируем loss, чтобы градиенты не стали нулями в FP16
    scaler.scale(loss).backward()

    # Unscale перед clip_grad_norm_, иначе порог неверный
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Шаг делается только если нет inf/nan
    scaler.step(optimizer)
    scaler.update()  # корректирует scale на следующую итерацию

    if batch_idx % 100 == 0:
        print(f"step {batch_idx}, loss={loss.item():.4f}, "
              f"scale={scaler.get_scale():.0f}")

autocast: какие операции понижаются

PyTorch разделяет операции на три категории:

  • float16-safe: matmul, conv2d, linear, bmm, attention — работают быстрее и не теряют точность критически.
  • float32-only: softmax, layer_norm, batch_norm, loss-функции, exp, log — остаются в FP32, чтобы не терять численную устойчивость.
  • promote: если в операцию попадают оба типа, тензоры приводятся к более широкому.

bfloat16 vs float16

На A100/H100 предпочтительнее bfloat16: у него тот же диапазон экспоненты, что у float32 (меньше underflow/overflow), но меньше точность мантиссы. GradScaler при bfloat16 не нужен, но его наличие не вредит.

with autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model(inputs)

Распределённое обучение (DDP + AMP)

from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_ids=[local_rank])
scaler = GradScaler()

with autocast(dtype=torch.float16):
    loss = criterion(model(inputs), targets)

scaler.scale(loss).backward()  # DDP синхронизирует масштабированные градиенты
scaler.step(optimizer)
scaler.update()

GradScaler работает корректно с DDP: каждый процесс масштабирует самостоятельно, all-reduce происходит до unscale.

Подводные камни

  • Забыть unscale_ перед clip_grad_norm_ — норма будет вычислена по масштабированным градиентам, порог окажется в тысячи раз выше реального, клиппинг не сработает.
  • Использовать .half() вместо autocast — ручное преобразование не учитывает float32-only операции, возникают NaN в softmax или layer_norm.
  • Отключить GradScaler при bfloat16 на Volta — Volta не поддерживает bfloat16 аппаратно, операции падают на FP32 и производительность не растёт.
  • scale постоянно уменьшается — признак переполнения (inf/nan) в градиентах на каждом шаге; нужно снизить learning rate или проверить инициализацию весов.
  • Смешивать autocast и ручные .to(float32) — если внутри autocast принудительно привести тензор к float32 и передать его обратно в FP16-слой, происходит нежелательный promote и замедление.
  • Сохранять state_dict модели в FP16 — при загрузке на CPU или без autocast инференс будет в FP16 с меньшей точностью; рекомендуется сохранять в float32 и активировать AMP только при обучении.
  • Использовать AMP с кастомными CUDA-ядрами без регистрации — autocast не знает о пользовательских операторах; нужно явно регистрировать их через torch.amp.register_op_dtype.
  • Не проверять поддержку Tensor Core — на CPU или старых GPU (Pascal и ниже) autocast работает, но никакого ускорения нет; стоит логировать torch.cuda.get_device_capability() и предупреждать, если < (7, 0).

Common mistakes

  • Объяснять amp mixed precision только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать amp mixed precision с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.
  • Предлагает ли observability, rollback, ограничения стоимости и стратегию incident replay.

Sources

Related topics