Как работает mixed-precision training в TensorFlow (tf.keras.mixed_precision)?
Mixed precision ускоряет обучение, вычисляя в float16 и сохраняя веса в float32. Включается через mixed_precision.set_global_policy("mixed_float16"); loss scaling предотвращает underflow градиентов.
Mixed-precision training в TensorFlow
Mixed precision — техника обучения, при которой вычисления выполняются в float16 (или bfloat16), а веса хранятся в float32. Это ускоряет обучение на GPU (Tensor Cores) и TPU, снижает потребление памяти примерно вдвое, при этом точность модели не страдает благодаря сохранению мастер-копии весов в float32.
Включение mixed precision
import tensorflow as tf
from tensorflow.keras import mixed_precision
# Устанавливаем политику: вычисления в float16, переменные в float32
mixed_precision.set_global_policy("mixed_float16")
# Для TPU и A100 bfloat16 предпочтительнее:
# mixed_precision.set_global_policy("mixed_bfloat16")
print("Compute dtype:", mixed_precision.global_policy().compute_dtype) # float16
print("Variable dtype:", mixed_precision.global_policy().variable_dtype) # float32
Построение модели и loss scaling
После установки политики все слои автоматически работают в float16. Выходной слой классификатора нужно явно перевести в float32, чтобы избежать численной нестабильности при вычислении softmax/sigmoid.
inputs = tf.keras.Input(shape=(224, 224, 3))
x = tf.keras.layers.Conv2D(32, 3, activation="relu")(inputs) # вычисления в fp16
x = tf.keras.layers.GlobalAveragePooling2D()(x)
# Явно приводим к float32 перед финальной активацией
outputs = tf.keras.layers.Dense(1000, activation="softmax", dtype="float32")(x)
model = tf.keras.Model(inputs, outputs)
Loss scaling
float16 имеет диапазон ~5.96e-8 до 65504. Градиенты могут уйти в underflow (стать нулями) или overflow (Inf/NaN). Loss scaling умножает loss на большой коэффициент перед backward pass, масштабирует градиенты обратно перед обновлением весов.
# При использовании model.compile() loss scaling включается автоматически
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Явная обёртка LossScaleOptimizer (если нужен ручной контроль)
optimizer = mixed_precision.LossScaleOptimizer(optimizer)
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
model.fit(x_train, y_train, epochs=10, batch_size=256)
Кастомный тренировочный цикл с mixed precision
optimizer = mixed_precision.LossScaleOptimizer(
tf.keras.optimizers.Adam(1e-3)
)
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
predictions = model(x, training=True)
loss = loss_fn(y, predictions)
# Масштабируем loss
scaled_loss = optimizer.get_scaled_loss(loss)
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
# Возвращаем градиенты к исходному масштабу
grads = optimizer.get_unscaled_gradients(scaled_grads)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
Проверка ускорения
import time
# Замер throughput
start = time.time()
model.fit(x_train, y_train, epochs=3, batch_size=256, verbose=0)
elapsed = time.time() - start
print(f"Training time: {elapsed:.2f}s")
# Проверка текущего коэффициента масштабирования
print("Loss scale:", optimizer.loss_scale)
Подводные камни
- Финальный слой (Dense с softmax/sigmoid) должен явно иметь
dtype="float32"— иначе softmax в float16 даёт Inf при больших логитах. mixed_float16эффективно работает только на GPU с Tensor Cores (Volta, Turing, Ampere и новее). На CPU и старых GPU прирост отсутствует, а замедление возможно из-за конверсий типов.- bfloat16 предпочтительнее float16 на TPU и A100/H100, так как имеет тот же диапазон, что и float32, и не требует dynamic loss scaling.
- BatchNormalization накапливает скользящие средние в float32 независимо от политики — это ожидаемое поведение, но потребление памяти для BN-слоёв не уменьшится.
- Dynamic loss scaling может временно привести к пропуску шагов обновления, если обнаружены Inf/NaN в градиентах — это нормально, но если скиппинг постоянный, значит модель нестабильна.
- Некоторые кастомные слои или операции могут не поддерживать float16 — возникнет ошибка типа. Явно добавляйте
tf.cast(x, self.compute_dtype)вcall(). - При сохранении модели в формате SavedModel веса сохраняются в float32 (мастер-копия), что корректно, но файл занимает столько же места, как при float32-обучении.
- Установка глобальной политики через
set_global_policyвлияет на все слои в текущем процессе — при запуске нескольких экспериментов в одном процессе необходимо сбрасывать политику явно.
Common mistakes
- Объяснять
mixed precisionтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
mixed precisionс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.
- Предлагает ли observability, rollback, ограничения стоимости и стратегию incident replay.