PyTorchMiddleTechnical

Объясните batch normalization — что она делает и как реализована в PyTorch?

Batch Normalization нормализует активации по батчу: вычитает среднее, делит на std, затем масштабирует через обучаемые γ и β. В PyTorch: nn.BatchNorm1d/2d. Критично: разное поведение в train и eval режимах.

Что делает Batch Normalization

BatchNorm нормализует входные активации по размерности батча, что ускоряет обучение и снижает зависимость от инициализации весов. Формула для признака j:

μⱼ = (1/m) Σ xᵢⱼ          # среднее по батчу
σ²ⱼ = (1/m) Σ (xᵢⱼ - μⱼ)²  # дисперсия по батчу
x̂ᵢⱼ = (xᵢⱼ - μⱼ) / √(σ²ⱼ + ε)   # нормализация
yᵢⱼ = γⱼ · x̂ᵢⱼ + βⱼ        # масштаб и сдвиг (обучаемые)

γ и β — обучаемые параметры, позволяющие сети «отменить» нормализацию если это нужно. ε (~1e-5) предотвращает деление на ноль.

Реализация в PyTorch

import torch
import torch.nn as nn

# BatchNorm1d — для FC-слоёв, вход (batch, features) или (batch, features, seq)
batch_size, num_features = 16, 64
bn1d = nn.BatchNorm1d(num_features, eps=1e-5, momentum=0.1, affine=True)

x = torch.randn(batch_size, num_features)
out = bn1d(x)
print(out.shape)  # (16, 64)
print(bn1d.weight.shape)  # gamma: (64,)
print(bn1d.bias.shape)    # beta:  (64,)

# BatchNorm2d — для свёрточных сетей, вход (batch, channels, H, W)
# нормализует по batch+H+W, отдельно для каждого канала
batch_size, channels, H, W = 8, 32, 28, 28
bn2d = nn.BatchNorm2d(channels)
x_conv = torch.randn(batch_size, channels, H, W)
out_conv = bn2d(x_conv)
print(out_conv.shape)  # (8, 32, 28, 28)

# Разница train vs eval
bn1d.train()  # использует статистику текущего батча + обновляет running_mean/var
out_train = bn1d(x)

bn1d.eval()   # использует running_mean/running_var (накопленную статистику)
out_eval = bn1d(x)
print(torch.allclose(out_train, out_eval))  # False! Разные значения

# Накопленная статистика (exponential moving average)
print(bn1d.running_mean.shape)  # (64,) — обновляется при train()
print(bn1d.running_var.shape)   # (64,)
# running_mean = (1 - momentum) * running_mean + momentum * batch_mean

# Встраивание в архитектуру: BN обычно ДО активации
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),  # bias=False!
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

# Пример: влияние размера батча на статистику
small_batch = torch.randn(2, 64)   # батч из 2 — нестабильная статистика
large_batch = torch.randn(64, 64)  # батч из 64 — стабильнее

bn_test = nn.BatchNorm1d(64)
bn_test.train()
out_small = bn_test(small_batch)
out_large = bn_test(large_batch)
# При batch_size=1 BatchNorm выбросит RuntimeError: Expected more than 1 value per channel

# Инспекция параметров
for name, param in bn1d.named_parameters():
    print(name, param.shape)  # weight (gamma), bias (beta)
for name, buf in bn1d.named_buffers():
    print(name, buf.shape)    # running_mean, running_var, num_batches_tracked

BatchNorm vs LayerNorm vs GroupNorm

BatchNorm нормализует по батчу — эффективен для CNN с большими батчами. LayerNorm нормализует по признакам одного примера — используется в Transformer (не зависит от размера батча). GroupNorm нормализует по группам каналов — работает при batch_size=1 (сегментация, detection).

Подводные камни

  • Забыть model.eval() при инференсе: в режиме train() BatchNorm использует статистику текущего батча. При инференсе на одном примере это даёт случайные результаты. Всегда вызывайте model.eval() перед inference.
  • bias=True в Conv перед BatchNorm: BatchNorm имеет собственное смещение β, поэтому bias у Conv-слоя перед BN бесполезен и тратит память. Стандарт: nn.Conv2d(..., bias=False).
  • Малый батч: при batch_size=1 в режиме train() PyTorch бросает RuntimeError. При batch_size=2–4 статистика нестабильна и обучение расходится — используйте GroupNorm или LayerNorm.
  • Несоответствие train/eval статистики: если модель обучается на данных с одним распределением, а на инференсе получает другое, running_mean/var будет некорректным — BN деградирует до бесполезного слоя. Следите за distribution shift.
  • Перенос весов без running stats: при загрузке state_dict с strict=False можно случайно не загрузить running_mean и running_var. Модель будет работать в eval() с нулевыми буферами — все предсказания неверны.
  • Fine-tuning замороженных BN: при замораживании слоёв BatchNorm через requires_grad=False веса γ/β не обновляются, но running stats продолжают обновляться если слой в режиме train(). Для полной заморозки нужно явно вызвать bn_layer.eval().
  • Gradient через BN во время backward: BN делает нормализацию дифференцируемой, но при маленьком батче градиент по дисперсии нестабилен и может привести к взрывному росту gradients на первых эпохах.
  • DataParallel и синхронизация BN: при multi-GPU обучении через nn.DataParallel каждый GPU считает статистику только по своему sub-batch. При маленьких батчах используйте nn.SyncBatchNorm.convert_sync_batchnorm(model) с DistributedDataParallel.

Common mistakes

  • Объяснять batch normalization только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать batch normalization с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics