Что такое torch.cuda.amp и как mixed-precision training улучшает производительность?
torch.cuda.amp предоставляет autocast и GradScaler для автоматического смешения float16/bfloat16 и float32. FP16-операции на Tensor Core в 2–4× быстрее и требуют вдвое меньше памяти, GradScaler предотвращает underflow градиентов.
Что такое torch.cuda.amp
torch.cuda.amp — модуль PyTorch для автоматического смешанного обучения (Automatic Mixed Precision). Он предоставляет два инструмента:
- autocast — контекстный менеджер, который автоматически понижает тип операций до
float16(илиbfloat16) там, где это безопасно, и оставляетfloat32для численно чувствительных вычислений. - GradScaler — масштабирует loss перед backward(), чтобы предотвратить underflow градиентов в FP16, затем автоматически обнуляет шаг оптимизатора при обнаружении inf/nan и снижает scale.
Почему FP16 быстрее
На GPU с архитектурой Volta и новее (A100, V100, H100) аппаратные блоки Tensor Core выполняют матричные умножения в FP16 в 2–8× быстрее, чем FP32. Кроме того, FP16-тензоры занимают вдвое меньше памяти, что позволяет увеличить batch size и сократить число итераций.
Полный рабочий пример
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
device = torch.device("cuda")
model = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True),
num_layers=6,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler() # масштабировщик градиентов
for batch_idx, (inputs, targets) in enumerate(dataloader):
inputs = inputs.to(device) # float32 на входе
targets = targets.to(device)
optimizer.zero_grad()
# autocast автоматически выбирает FP16 для линейных слоёв,
# attention, свёрток и оставляет FP32 для softmax, BatchNorm
with autocast(dtype=torch.float16):
logits = model(inputs) # большинство ops → float16
loss = criterion(logits, targets) # loss тоже float16
# Масштабируем loss, чтобы градиенты не стали нулями в FP16
scaler.scale(loss).backward()
# Unscale перед clip_grad_norm_, иначе порог неверный
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Шаг делается только если нет inf/nan
scaler.step(optimizer)
scaler.update() # корректирует scale на следующую итерацию
if batch_idx % 100 == 0:
print(f"step {batch_idx}, loss={loss.item():.4f}, "
f"scale={scaler.get_scale():.0f}")
autocast: какие операции понижаются
PyTorch разделяет операции на три категории:
- float16-safe: matmul, conv2d, linear, bmm, attention — работают быстрее и не теряют точность критически.
- float32-only: softmax, layer_norm, batch_norm, loss-функции, exp, log — остаются в FP32, чтобы не терять численную устойчивость.
- promote: если в операцию попадают оба типа, тензоры приводятся к более широкому.
bfloat16 vs float16
На A100/H100 предпочтительнее bfloat16: у него тот же диапазон экспоненты, что у float32 (меньше underflow/overflow), но меньше точность мантиссы. GradScaler при bfloat16 не нужен, но его наличие не вредит.
with autocast(device_type="cuda", dtype=torch.bfloat16):
output = model(inputs)
Распределённое обучение (DDP + AMP)
from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model, device_ids=[local_rank])
scaler = GradScaler()
with autocast(dtype=torch.float16):
loss = criterion(model(inputs), targets)
scaler.scale(loss).backward() # DDP синхронизирует масштабированные градиенты
scaler.step(optimizer)
scaler.update()
GradScaler работает корректно с DDP: каждый процесс масштабирует самостоятельно, all-reduce происходит до unscale.
Подводные камни
- Забыть unscale_ перед clip_grad_norm_ — норма будет вычислена по масштабированным градиентам, порог окажется в тысячи раз выше реального, клиппинг не сработает.
- Использовать .half() вместо autocast — ручное преобразование не учитывает float32-only операции, возникают NaN в softmax или layer_norm.
- Отключить GradScaler при bfloat16 на Volta — Volta не поддерживает bfloat16 аппаратно, операции падают на FP32 и производительность не растёт.
- scale постоянно уменьшается — признак переполнения (inf/nan) в градиентах на каждом шаге; нужно снизить learning rate или проверить инициализацию весов.
- Смешивать autocast и ручные .to(float32) — если внутри autocast принудительно привести тензор к float32 и передать его обратно в FP16-слой, происходит нежелательный promote и замедление.
- Сохранять state_dict модели в FP16 — при загрузке на CPU или без autocast инференс будет в FP16 с меньшей точностью; рекомендуется сохранять в float32 и активировать AMP только при обучении.
- Использовать AMP с кастомными CUDA-ядрами без регистрации — autocast не знает о пользовательских операторах; нужно явно регистрировать их через
torch.amp.register_op_dtype. - Не проверять поддержку Tensor Core — на CPU или старых GPU (Pascal и ниже) autocast работает, но никакого ускорения нет; стоит логировать
torch.cuda.get_device_capability()и предупреждать, если < (7, 0).
Common mistakes
- Объяснять
amp mixed precisionтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
amp mixed precisionс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.
- Предлагает ли observability, rollback, ограничения стоимости и стратегию incident replay.