TensorFlowMiddleTechnical

Как работает batch normalization и как она применяется в TensorFlow?

Batch normalization нормализует активации по батчу во время обучения и использует накопленные moving_mean/moving_variance при инференсе. В кастомном цикле нужно явно передавать training=True/False — иначе статистики не обновляются.

Что такое batch normalization

Batch normalization (BN) — слой, который нормализует активации внутри батча, чтобы ускорить обучение и снизить чувствительность к инициализации весов. Для каждого признака слой вычисляет среднее и дисперсию по батчу, нормализует значения, затем применяет обучаемые параметры gamma (масштаб) и beta (сдвиг):

x_norm = (x - mean(x)) / sqrt(var(x) + epsilon)
y = gamma * x_norm + beta

Во время инференса вместо статистик батча используются скользящие средние moving_mean и moving_variance, накопленные за обучение.

Режимы train и inference

Ключевой нюанс: BN ведёт себя по-разному в зависимости от аргумента training. В режиме training=True нормализация идёт по текущему батчу и обновляются скользящие средние. В training=False — используются накопленные moving_mean / moving_variance.

Пример в Keras

import tensorflow as tf
import numpy as np

# Создаём модель с BatchNormalization
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, use_bias=False),  # bias не нужен — BN добавляет beta
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dense(64, use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Синтетические данные
X = np.random.randn(1000, 20).astype('float32')
y = np.random.randint(0, 10, size=1000)

model.fit(X, y, epochs=5, batch_size=32, validation_split=0.2)

# Проверяем, что moving_mean накопились
bn_layer = model.layers[1]
print("moving_mean:", bn_layer.moving_mean.numpy()[:4])
print("moving_variance:", bn_layer.moving_variance.numpy()[:4])

Ручное управление режимом training

При использовании кастомного цикла обучения нужно явно передавать training=True/False:

@tf.function
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)  # <-- важно
        loss = loss_fn(y_batch, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function
def eval_step(x_batch, y_batch):
    logits = model(x_batch, training=False)  # <-- важно
    return loss_fn(y_batch, logits)

Если передать training=False во время обучения, moving_mean не обновится и инференс будет некорректным. Это частая и трудно обнаруживаемая ошибка.

BatchNormalization и Functional API

inputs = tf.keras.Input(shape=(20,))
x = tf.keras.layers.Dense(64, use_bias=False)(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

func_model = tf.keras.Model(inputs, outputs)
func_model.summary()
# Обратить внимание: BN добавляет 4 переменных:
# gamma, beta (trainable) + moving_mean, moving_variance (non-trainable)

Подводные камни

  • Маленький батч (batch size < 8): статистики батча нестабильны, BN ухудшает обучение. Альтернативы — Layer Normalization или Group Normalization.
  • Забыть training=True в кастомном цикле: moving_mean не обновляется, модель правильно обучается, но выдаёт плохие результаты на инференсе.
  • use_bias=True перед BN: BN вычитает среднее, поэтому bias Dense-слоя перед BatchNormalization бесполезен — он полностью компенсируется параметром beta.
  • BN с RNN / последовательными данными: BN плохо работает с переменной длиной последовательностей и рекуррентными сетями — для них используется Layer Normalization.
  • Сохранение модели: при model.save() сохраняются и moving_mean/moving_variance. Если загрузить только веса (load_weights), эти переменные могут остаться нулевыми.
  • Transfer learning — заморозка BN: при fine-tuning часто замораживают BN-слои базовой модели (layer.trainable = False), чтобы не испортить накопленную статистику маленьким батчем домена.
  • Детерминированность GPU: на GPU порядок редукций нестрого детерминирован, поэтому статистики батча могут незначительно отличаться между запусками даже при фиксированном seed.
  • Mixed precision: при tf.keras.mixed_precision гамма/бета хранятся в float32, а активации — в float16. Убедитесь, что compute_dtype и variable_dtype у BN-слоя правильные.

Common mistakes

  • Объяснять batch normalization только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать batch normalization с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics