Объясните batch normalization — что она делает и как реализована в PyTorch?
Batch Normalization нормализует активации по батчу: вычитает среднее, делит на std, затем масштабирует через обучаемые γ и β. В PyTorch: nn.BatchNorm1d/2d. Критично: разное поведение в train и eval режимах.
Что делает Batch Normalization
BatchNorm нормализует входные активации по размерности батча, что ускоряет обучение и снижает зависимость от инициализации весов. Формула для признака j:
μⱼ = (1/m) Σ xᵢⱼ # среднее по батчу
σ²ⱼ = (1/m) Σ (xᵢⱼ - μⱼ)² # дисперсия по батчу
x̂ᵢⱼ = (xᵢⱼ - μⱼ) / √(σ²ⱼ + ε) # нормализация
yᵢⱼ = γⱼ · x̂ᵢⱼ + βⱼ # масштаб и сдвиг (обучаемые)
γ и β — обучаемые параметры, позволяющие сети «отменить» нормализацию если это нужно. ε (~1e-5) предотвращает деление на ноль.
Реализация в PyTorch
import torch
import torch.nn as nn
# BatchNorm1d — для FC-слоёв, вход (batch, features) или (batch, features, seq)
batch_size, num_features = 16, 64
bn1d = nn.BatchNorm1d(num_features, eps=1e-5, momentum=0.1, affine=True)
x = torch.randn(batch_size, num_features)
out = bn1d(x)
print(out.shape) # (16, 64)
print(bn1d.weight.shape) # gamma: (64,)
print(bn1d.bias.shape) # beta: (64,)
# BatchNorm2d — для свёрточных сетей, вход (batch, channels, H, W)
# нормализует по batch+H+W, отдельно для каждого канала
batch_size, channels, H, W = 8, 32, 28, 28
bn2d = nn.BatchNorm2d(channels)
x_conv = torch.randn(batch_size, channels, H, W)
out_conv = bn2d(x_conv)
print(out_conv.shape) # (8, 32, 28, 28)
# Разница train vs eval
bn1d.train() # использует статистику текущего батча + обновляет running_mean/var
out_train = bn1d(x)
bn1d.eval() # использует running_mean/running_var (накопленную статистику)
out_eval = bn1d(x)
print(torch.allclose(out_train, out_eval)) # False! Разные значения
# Накопленная статистика (exponential moving average)
print(bn1d.running_mean.shape) # (64,) — обновляется при train()
print(bn1d.running_var.shape) # (64,)
# running_mean = (1 - momentum) * running_mean + momentum * batch_mean
# Встраивание в архитектуру: BN обычно ДО активации
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False), # bias=False!
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.block(x)
# Пример: влияние размера батча на статистику
small_batch = torch.randn(2, 64) # батч из 2 — нестабильная статистика
large_batch = torch.randn(64, 64) # батч из 64 — стабильнее
bn_test = nn.BatchNorm1d(64)
bn_test.train()
out_small = bn_test(small_batch)
out_large = bn_test(large_batch)
# При batch_size=1 BatchNorm выбросит RuntimeError: Expected more than 1 value per channel
# Инспекция параметров
for name, param in bn1d.named_parameters():
print(name, param.shape) # weight (gamma), bias (beta)
for name, buf in bn1d.named_buffers():
print(name, buf.shape) # running_mean, running_var, num_batches_tracked
BatchNorm vs LayerNorm vs GroupNorm
BatchNorm нормализует по батчу — эффективен для CNN с большими батчами. LayerNorm нормализует по признакам одного примера — используется в Transformer (не зависит от размера батча). GroupNorm нормализует по группам каналов — работает при batch_size=1 (сегментация, detection).
Подводные камни
- Забыть model.eval() при инференсе: в режиме train() BatchNorm использует статистику текущего батча. При инференсе на одном примере это даёт случайные результаты. Всегда вызывайте
model.eval()перед inference. - bias=True в Conv перед BatchNorm: BatchNorm имеет собственное смещение β, поэтому bias у Conv-слоя перед BN бесполезен и тратит память. Стандарт:
nn.Conv2d(..., bias=False). - Малый батч: при batch_size=1 в режиме train() PyTorch бросает RuntimeError. При batch_size=2–4 статистика нестабильна и обучение расходится — используйте GroupNorm или LayerNorm.
- Несоответствие train/eval статистики: если модель обучается на данных с одним распределением, а на инференсе получает другое, running_mean/var будет некорректным — BN деградирует до бесполезного слоя. Следите за distribution shift.
- Перенос весов без running stats: при загрузке
state_dictсstrict=Falseможно случайно не загрузитьrunning_meanиrunning_var. Модель будет работать в eval() с нулевыми буферами — все предсказания неверны. - Fine-tuning замороженных BN: при замораживании слоёв BatchNorm через
requires_grad=Falseвеса γ/β не обновляются, но running stats продолжают обновляться если слой в режиме train(). Для полной заморозки нужно явно вызватьbn_layer.eval(). - Gradient через BN во время backward: BN делает нормализацию дифференцируемой, но при маленьком батче градиент по дисперсии нестабилен и может привести к взрывному росту gradients на первых эпохах.
- DataParallel и синхронизация BN: при multi-GPU обучении через
nn.DataParallelкаждый GPU считает статистику только по своему sub-batch. При маленьких батчах используйтеnn.SyncBatchNorm.convert_sync_batchnorm(model)с DistributedDataParallel.
Common mistakes
- Объяснять
batch normalizationтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
batch normalizationс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.