TensorFlowSeniorSystem design

Как работает mixed-precision training в TensorFlow (tf.keras.mixed_precision)?

Mixed precision ускоряет обучение, вычисляя в float16 и сохраняя веса в float32. Включается через mixed_precision.set_global_policy("mixed_float16"); loss scaling предотвращает underflow градиентов.

Mixed-precision training в TensorFlow

Mixed precision — техника обучения, при которой вычисления выполняются в float16 (или bfloat16), а веса хранятся в float32. Это ускоряет обучение на GPU (Tensor Cores) и TPU, снижает потребление памяти примерно вдвое, при этом точность модели не страдает благодаря сохранению мастер-копии весов в float32.

Включение mixed precision

import tensorflow as tf
from tensorflow.keras import mixed_precision

# Устанавливаем политику: вычисления в float16, переменные в float32
mixed_precision.set_global_policy("mixed_float16")

# Для TPU и A100 bfloat16 предпочтительнее:
# mixed_precision.set_global_policy("mixed_bfloat16")

print("Compute dtype:", mixed_precision.global_policy().compute_dtype)   # float16
print("Variable dtype:", mixed_precision.global_policy().variable_dtype)  # float32

Построение модели и loss scaling

После установки политики все слои автоматически работают в float16. Выходной слой классификатора нужно явно перевести в float32, чтобы избежать численной нестабильности при вычислении softmax/sigmoid.

inputs = tf.keras.Input(shape=(224, 224, 3))
x = tf.keras.layers.Conv2D(32, 3, activation="relu")(inputs)  # вычисления в fp16
x = tf.keras.layers.GlobalAveragePooling2D()(x)
# Явно приводим к float32 перед финальной активацией
outputs = tf.keras.layers.Dense(1000, activation="softmax", dtype="float32")(x)

model = tf.keras.Model(inputs, outputs)

Loss scaling

float16 имеет диапазон ~5.96e-8 до 65504. Градиенты могут уйти в underflow (стать нулями) или overflow (Inf/NaN). Loss scaling умножает loss на большой коэффициент перед backward pass, масштабирует градиенты обратно перед обновлением весов.

# При использовании model.compile() loss scaling включается автоматически
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Явная обёртка LossScaleOptimizer (если нужен ручной контроль)
optimizer = mixed_precision.LossScaleOptimizer(optimizer)

model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.fit(x_train, y_train, epochs=10, batch_size=256)

Кастомный тренировочный цикл с mixed precision

optimizer = mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam(1e-3)
)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss = loss_fn(y, predictions)
        # Масштабируем loss
        scaled_loss = optimizer.get_scaled_loss(loss)

    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    # Возвращаем градиенты к исходному масштабу
    grads = optimizer.get_unscaled_gradients(scaled_grads)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

Проверка ускорения

import time

# Замер throughput
start = time.time()
model.fit(x_train, y_train, epochs=3, batch_size=256, verbose=0)
elapsed = time.time() - start
print(f"Training time: {elapsed:.2f}s")

# Проверка текущего коэффициента масштабирования
print("Loss scale:", optimizer.loss_scale)

Подводные камни

  • Финальный слой (Dense с softmax/sigmoid) должен явно иметь dtype="float32" — иначе softmax в float16 даёт Inf при больших логитах.
  • mixed_float16 эффективно работает только на GPU с Tensor Cores (Volta, Turing, Ampere и новее). На CPU и старых GPU прирост отсутствует, а замедление возможно из-за конверсий типов.
  • bfloat16 предпочтительнее float16 на TPU и A100/H100, так как имеет тот же диапазон, что и float32, и не требует dynamic loss scaling.
  • BatchNormalization накапливает скользящие средние в float32 независимо от политики — это ожидаемое поведение, но потребление памяти для BN-слоёв не уменьшится.
  • Dynamic loss scaling может временно привести к пропуску шагов обновления, если обнаружены Inf/NaN в градиентах — это нормально, но если скиппинг постоянный, значит модель нестабильна.
  • Некоторые кастомные слои или операции могут не поддерживать float16 — возникнет ошибка типа. Явно добавляйте tf.cast(x, self.compute_dtype) в call().
  • При сохранении модели в формате SavedModel веса сохраняются в float32 (мастер-копия), что корректно, но файл занимает столько же места, как при float32-обучении.
  • Установка глобальной политики через set_global_policy влияет на все слои в текущем процессе — при запуске нескольких экспериментов в одном процессе необходимо сбрасывать политику явно.

Common mistakes

  • Объяснять mixed precision только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать mixed precision с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.
  • Предлагает ли observability, rollback, ограничения стоимости и стратегию incident replay.

Sources

Related topics