TensorFlowMiddleSystem design
Как отлаживать TensorFlow-модели — какие инструменты и техники доступны?
Для отладки TF-моделей: tf.print внутри @tf.function (не print), tf.debugging.assert_* для runtime-проверок shape/NaN, TensorBoard с histogram_freq=1 для мониторинга весов, tf.config.run_functions_eagerly(True) для использования обычного debugger, TerminateOnNaN callback для автостопа.
Инструменты отладки TensorFlow-моделей
Отладка TF-моделей делится на несколько уровней: eager execution для быстрой инспекции, TensorBoard для визуализации обучения, tf.debugging для runtime-проверок, и профайлер для производительности.
1. Eager execution и tf.print
По умолчанию TensorFlow 2.x работает в eager-режиме — тензоры можно вычислять и инспектировать как NumPy-массивы. Это главное отличие от TF 1.x.
import tensorflow as tf
import numpy as np
# Eager: операции выполняются немедленно
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
print(x.numpy()) # Доступно напрямую
# Внутри @tf.function: используйте tf.print, а не print()
@tf.function
def forward(x):
y = x @ tf.transpose(x)
tf.print("shape:", tf.shape(y), "min:", tf.reduce_min(y), "max:", tf.reduce_max(y))
return y
forward(x)
2. tf.debugging — runtime assertions
@tf.function
def safe_log(x):
tf.debugging.assert_positive(x, message="Входные значения должны быть > 0")
tf.debugging.assert_all_finite(x, message="NaN или Inf во входе")
return tf.math.log(x)
# Проверка shape
def model_call(inputs):
tf.debugging.assert_rank(inputs, 2, message="Ожидается 2D тензор")
tf.debugging.assert_shapes([(inputs, ('B', 10))],
message="Второе измерение должно быть 10")
return inputs
3. TensorBoard — визуализация метрик и графа
import datetime
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=log_dir,
histogram_freq=1, # Гистограммы весов каждую эпоху
write_graph=True, # Граф модели
write_images=True, # Визуализация весов как изображения
profile_batch='500,520' # Профилирование батчей 500-520
)
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
X = np.random.randn(1000, 20).astype('float32')
y = np.random.randint(0, 10, size=1000)
model.fit(X, y, epochs=5, callbacks=[tensorboard_callback])
# Запустить: tensorboard --logdir logs/fit
4. Кастомный callback для отладки NaN
class NaNDetector(tf.keras.callbacks.Callback):
def on_batch_end(self, batch, logs=None):
loss = logs.get('loss')
if loss is not None and (np.isnan(loss) or np.isinf(loss)):
print(f"\nNaN/Inf loss на батче {batch}: {loss}. Останавливаем обучение.")
self.model.stop_training = True
# Встроенный аналог
model.fit(X, y, epochs=10, callbacks=[
tf.keras.callbacks.TerminateOnNaN()
])
5. Профайлер TensorFlow
# Профилирование конкретного участка кода
with tf.profiler.experimental.Profile('logdir'):
model(tf.random.normal([32, 20]), training=False)
# Через TensorBoard профайлер (более детально):
# tf.profiler.experimental.start('logdir')
# ... код ...
# tf.profiler.experimental.stop()
6. Отключение @tf.function для пошаговой отладки
# Временно отключить граф для обычного Python-дебаггера
tf.config.run_functions_eagerly(True)
# Теперь можно использовать breakpoint() внутри @tf.function
@tf.function
def buggy_fn(x):
breakpoint() # Работает только при run_functions_eagerly(True)
return x * 2
buggy_fn(tf.constant([1.0, 2.0]))
# Вернуть режим графа
tf.config.run_functions_eagerly(False)
Подводные камни
- print() внутри @tf.function: срабатывает только при трассировке графа, а не при каждом вызове. Используйте
tf.print— он встраивается в граф и выполняется при каждом forward pass. - NaN на первой эпохе: чаще всего причина — слишком большой learning rate, плохая инициализация весов или логарифм от нуля. Проверьте
tf.debugging.assert_all_finiteна выходе каждого слоя. - Граф трассируется несколько раз:
@tf.functionсоздаёт новый граф при изменении shape или dtype входа. Добавьтеinput_signature, чтобы зафиксировать форму:@tf.function(input_signature=[tf.TensorSpec(shape=[None, 20], dtype=tf.float32)]). - TensorBoard не показывает гистограммы: забыли передать
histogram_freq=1в TensorBoard callback или не дождались конца первой эпохи. - Профайлер на GPU требует CUPTI: без установленного CUDA Profiling Tools Interface профайлер TF не сможет собирать GPU-метрики — будет собирать только CPU-события.
- run_functions_eagerly в production: режим eager в 5-10 раз медленнее графового. Используйте только для отладки, не забудьте отключить.
- assert_* внутри @tf.function с optimization: TensorFlow может оптимизировать граф и убрать «бесполезные» операции. Если assertion не используется дальше по графу, убедитесь, что результат передаётся вниз по вычислениям.
- Несоответствие поведения в eager и graph режимах: Python-ветвления по значениям тензоров работают в eager, но не в graph (только при трассировке). Это приводит к трудно воспроизводимым багам при переходе к @tf.function.
Common mistakes
- Объяснять
debugging modelsтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
debugging modelsс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.