Как работает batch normalization и как она применяется в TensorFlow?
Batch normalization нормализует активации по батчу во время обучения и использует накопленные moving_mean/moving_variance при инференсе. В кастомном цикле нужно явно передавать training=True/False — иначе статистики не обновляются.
Что такое batch normalization
Batch normalization (BN) — слой, который нормализует активации внутри батча, чтобы ускорить обучение и снизить чувствительность к инициализации весов. Для каждого признака слой вычисляет среднее и дисперсию по батчу, нормализует значения, затем применяет обучаемые параметры gamma (масштаб) и beta (сдвиг):
x_norm = (x - mean(x)) / sqrt(var(x) + epsilon)
y = gamma * x_norm + beta
Во время инференса вместо статистик батча используются скользящие средние moving_mean и moving_variance, накопленные за обучение.
Режимы train и inference
Ключевой нюанс: BN ведёт себя по-разному в зависимости от аргумента training. В режиме training=True нормализация идёт по текущему батчу и обновляются скользящие средние. В training=False — используются накопленные moving_mean / moving_variance.
Пример в Keras
import tensorflow as tf
import numpy as np
# Создаём модель с BatchNormalization
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, use_bias=False), # bias не нужен — BN добавляет beta
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Dense(64, use_bias=False),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
)
# Синтетические данные
X = np.random.randn(1000, 20).astype('float32')
y = np.random.randint(0, 10, size=1000)
model.fit(X, y, epochs=5, batch_size=32, validation_split=0.2)
# Проверяем, что moving_mean накопились
bn_layer = model.layers[1]
print("moving_mean:", bn_layer.moving_mean.numpy()[:4])
print("moving_variance:", bn_layer.moving_variance.numpy()[:4])
Ручное управление режимом training
При использовании кастомного цикла обучения нужно явно передавать training=True/False:
@tf.function
def train_step(x_batch, y_batch):
with tf.GradientTape() as tape:
logits = model(x_batch, training=True) # <-- важно
loss = loss_fn(y_batch, logits)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
@tf.function
def eval_step(x_batch, y_batch):
logits = model(x_batch, training=False) # <-- важно
return loss_fn(y_batch, logits)
Если передать training=False во время обучения, moving_mean не обновится и инференс будет некорректным. Это частая и трудно обнаруживаемая ошибка.
BatchNormalization и Functional API
inputs = tf.keras.Input(shape=(20,))
x = tf.keras.layers.Dense(64, use_bias=False)(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
func_model = tf.keras.Model(inputs, outputs)
func_model.summary()
# Обратить внимание: BN добавляет 4 переменных:
# gamma, beta (trainable) + moving_mean, moving_variance (non-trainable)
Подводные камни
- Маленький батч (batch size < 8): статистики батча нестабильны, BN ухудшает обучение. Альтернативы — Layer Normalization или Group Normalization.
- Забыть
training=Trueв кастомном цикле:moving_meanне обновляется, модель правильно обучается, но выдаёт плохие результаты на инференсе. - use_bias=True перед BN: BN вычитает среднее, поэтому bias Dense-слоя перед BatchNormalization бесполезен — он полностью компенсируется параметром
beta. - BN с RNN / последовательными данными: BN плохо работает с переменной длиной последовательностей и рекуррентными сетями — для них используется Layer Normalization.
- Сохранение модели: при
model.save()сохраняются иmoving_mean/moving_variance. Если загрузить только веса (load_weights), эти переменные могут остаться нулевыми. - Transfer learning — заморозка BN: при fine-tuning часто замораживают BN-слои базовой модели (
layer.trainable = False), чтобы не испортить накопленную статистику маленьким батчем домена. - Детерминированность GPU: на GPU порядок редукций нестрого детерминирован, поэтому статистики батча могут незначительно отличаться между запусками даже при фиксированном seed.
- Mixed precision: при
tf.keras.mixed_precisionгамма/бета хранятся в float32, а активации — в float16. Убедитесь, чтоcompute_dtypeиvariable_dtypeу BN-слоя правильные.
Common mistakes
- Объяснять
batch normalizationтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
batch normalizationс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.