PyTorchMiddleCoding
Как реализовать gradient accumulation в PyTorch и зачем это нужно?
Gradient accumulation накапливает градиенты за N итераций перед optimizer.step(), эмулируя большой batch size. Главное: делить loss на accumulation_steps и обнулять градиенты только после step(), а не после каждого мини-батча.
Gradient Accumulation в PyTorch: реализация и назначение
Gradient accumulation — техника, при которой градиенты не обнуляются и не применяются после каждого батча, а накапливаются в течение нескольких шагов. Шаг оптимизатора делается только раз в N итераций. Это позволяет эффективно обучать с большим логическим batch size, когда весь батч не помещается в GPU-память.
Зачем это нужно
- Ограниченная GPU-память: современные большие модели (LLM, ViT) не помещаются в память при batch size 32+ — накопление позволяет эмулировать большие батчи.
- Стабильность обучения: SGD и Adam стабильнее при больших батчах — оценка градиента более точна.
- Multi-GPU эмуляция без DDP: на одной GPU можно воспроизвести поведение multi-GPU обучения по градиентам.
Базовая реализация
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
dataset = TensorDataset(
torch.randn(1024, 128),
torch.randint(0, 10, (1024,))
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)
accumulation_steps = 4 # логический batch = 16 * 4 = 64
model.train()
optimizer.zero_grad() # обнуляем один раз перед первым шагом
for step, (x, y) in enumerate(loader):
x, y = x.cuda(), y.cuda()
output = model(x)
# Делим loss на число шагов накопления для правильного масштаба
loss = criterion(output, y) / accumulation_steps
loss.backward()
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# Обрабатываем оставшиеся градиенты в конце эпохи
remainder = len(loader) % accumulation_steps
if remainder != 0:
optimizer.step()
optimizer.zero_grad()
Реализация с Mixed Precision
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
x, y = x.cuda(), y.cuda()
with autocast(dtype=torch.float16):
output = model(x)
loss = criterion(output, y) / accumulation_steps
scaler.scale(loss).backward()
if (step + 1) % accumulation_steps == 0:
# Unscale перед gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
Использование через Hugging Face Accelerate
Библиотека Accelerate автоматизирует gradient accumulation и устраняет типичные ошибки ручной реализации.
from accelerate import Accelerator
accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
for x, y in loader:
with accelerator.accumulate(model):
output = model(x)
loss = criterion(output, y)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
Проверка корректности
Убедитесь, что gradient accumulation даёт те же градиенты, что и полный батч:
# Метод 1: один большой батч
big_x = torch.randn(64, 128).cuda()
big_y = torch.randint(0, 10, (64,)).cuda()
optimizer.zero_grad()
loss_full = criterion(model(big_x), big_y)
loss_full.backward()
full_grad = model.weight.grad.clone()
# Метод 2: 4 малых батча с накоплением
optimizer.zero_grad()
for i in range(4):
mini_x = big_x[i*16:(i+1)*16]
mini_y = big_y[i*16:(i+1)*16]
loss_mini = criterion(model(mini_x), mini_y) / 4
loss_mini.backward()
accum_grad = model.weight.grad.clone()
print(torch.allclose(full_grad, accum_grad, atol=1e-5)) # True
Подводные камни
- Отсутствие деления loss на accumulation_steps: без него каждый мини-батч вносит полный вклад, и суммарный градиент будет в N раз больше, чем должен — эффективный LR возрастает в N раз.
- zero_grad() после каждого мини-батча: аннулирует всё накопленное. Обнулять нужно только после optimizer.step().
- BatchNorm и маленькие мини-батчи: BatchNorm вычисляет статистики по текущему мини-батчу, а не по логическому. При batch_size=4 статистики нестабильны — лучше использовать LayerNorm или GroupNorm с accumulation.
- Gradient clipping до unscale при AMP: при использовании GradScaler нужно вызвать
scaler.unscale_(optimizer)передclip_grad_norm_, иначе clipping применяется к масштабированным градиентам. - Остаток батчей в конце эпохи: если длина датасета не делится на accumulation_steps, последние градиенты не применяются без явной обработки остатка.
- Scheduler шаги: LR scheduler нужно шагать только при реальном шаге оптимизатора, а не каждую итерацию.
- Несинхронизированный Dropout: при использовании gradient accumulation с DDP каждый мини-батч получает свою маску Dropout — поведение отличается от полного батча.
Common mistakes
- Объяснять
gradient accumulationтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
gradient accumulationс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.