PyTorchMiddleCoding

Как реализовать gradient accumulation в PyTorch и зачем это нужно?

Gradient accumulation накапливает градиенты за N итераций перед optimizer.step(), эмулируя большой batch size. Главное: делить loss на accumulation_steps и обнулять градиенты только после step(), а не после каждого мини-батча.

Gradient Accumulation в PyTorch: реализация и назначение

Gradient accumulation — техника, при которой градиенты не обнуляются и не применяются после каждого батча, а накапливаются в течение нескольких шагов. Шаг оптимизатора делается только раз в N итераций. Это позволяет эффективно обучать с большим логическим batch size, когда весь батч не помещается в GPU-память.

Зачем это нужно

  • Ограниченная GPU-память: современные большие модели (LLM, ViT) не помещаются в память при batch size 32+ — накопление позволяет эмулировать большие батчи.
  • Стабильность обучения: SGD и Adam стабильнее при больших батчах — оценка градиента более точна.
  • Multi-GPU эмуляция без DDP: на одной GPU можно воспроизвести поведение multi-GPU обучения по градиентам.

Базовая реализация

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

dataset = TensorDataset(
    torch.randn(1024, 128),
    torch.randint(0, 10, (1024,))
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

accumulation_steps = 4   # логический batch = 16 * 4 = 64

model.train()
optimizer.zero_grad()    # обнуляем один раз перед первым шагом

for step, (x, y) in enumerate(loader):
    x, y = x.cuda(), y.cuda()

    output = model(x)
    # Делим loss на число шагов накопления для правильного масштаба
    loss = criterion(output, y) / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Обрабатываем оставшиеся градиенты в конце эпохи
remainder = len(loader) % accumulation_steps
if remainder != 0:
    optimizer.step()
    optimizer.zero_grad()

Реализация с Mixed Precision

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
optimizer.zero_grad()

for step, (x, y) in enumerate(loader):
    x, y = x.cuda(), y.cuda()

    with autocast(dtype=torch.float16):
        output = model(x)
        loss = criterion(output, y) / accumulation_steps

    scaler.scale(loss).backward()

    if (step + 1) % accumulation_steps == 0:
        # Unscale перед gradient clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

Использование через Hugging Face Accelerate

Библиотека Accelerate автоматизирует gradient accumulation и устраняет типичные ошибки ручной реализации.

from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    with accelerator.accumulate(model):
        output = model(x)
        loss = criterion(output, y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

Проверка корректности

Убедитесь, что gradient accumulation даёт те же градиенты, что и полный батч:

# Метод 1: один большой батч
big_x = torch.randn(64, 128).cuda()
big_y = torch.randint(0, 10, (64,)).cuda()
optimizer.zero_grad()
loss_full = criterion(model(big_x), big_y)
loss_full.backward()
full_grad = model.weight.grad.clone()

# Метод 2: 4 малых батча с накоплением
optimizer.zero_grad()
for i in range(4):
    mini_x = big_x[i*16:(i+1)*16]
    mini_y = big_y[i*16:(i+1)*16]
    loss_mini = criterion(model(mini_x), mini_y) / 4
    loss_mini.backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-5))  # True

Подводные камни

  • Отсутствие деления loss на accumulation_steps: без него каждый мини-батч вносит полный вклад, и суммарный градиент будет в N раз больше, чем должен — эффективный LR возрастает в N раз.
  • zero_grad() после каждого мини-батча: аннулирует всё накопленное. Обнулять нужно только после optimizer.step().
  • BatchNorm и маленькие мини-батчи: BatchNorm вычисляет статистики по текущему мини-батчу, а не по логическому. При batch_size=4 статистики нестабильны — лучше использовать LayerNorm или GroupNorm с accumulation.
  • Gradient clipping до unscale при AMP: при использовании GradScaler нужно вызвать scaler.unscale_(optimizer) перед clip_grad_norm_, иначе clipping применяется к масштабированным градиентам.
  • Остаток батчей в конце эпохи: если длина датасета не делится на accumulation_steps, последние градиенты не применяются без явной обработки остатка.
  • Scheduler шаги: LR scheduler нужно шагать только при реальном шаге оптимизатора, а не каждую итерацию.
  • Несинхронизированный Dropout: при использовании gradient accumulation с DDP каждый мини-батч получает свою маску Dropout — поведение отличается от полного батча.

Common mistakes

  • Объяснять gradient accumulation только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать gradient accumulation с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics