TensorFlowMiddleSystem design

Как отлаживать TensorFlow-модели — какие инструменты и техники доступны?

Для отладки TF-моделей: tf.print внутри @tf.function (не print), tf.debugging.assert_* для runtime-проверок shape/NaN, TensorBoard с histogram_freq=1 для мониторинга весов, tf.config.run_functions_eagerly(True) для использования обычного debugger, TerminateOnNaN callback для автостопа.

Инструменты отладки TensorFlow-моделей

Отладка TF-моделей делится на несколько уровней: eager execution для быстрой инспекции, TensorBoard для визуализации обучения, tf.debugging для runtime-проверок, и профайлер для производительности.

1. Eager execution и tf.print

По умолчанию TensorFlow 2.x работает в eager-режиме — тензоры можно вычислять и инспектировать как NumPy-массивы. Это главное отличие от TF 1.x.

import tensorflow as tf
import numpy as np

# Eager: операции выполняются немедленно
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
print(x.numpy())  # Доступно напрямую

# Внутри @tf.function: используйте tf.print, а не print()
@tf.function
def forward(x):
    y = x @ tf.transpose(x)
    tf.print("shape:", tf.shape(y), "min:", tf.reduce_min(y), "max:", tf.reduce_max(y))
    return y

forward(x)

2. tf.debugging — runtime assertions

@tf.function
def safe_log(x):
    tf.debugging.assert_positive(x, message="Входные значения должны быть > 0")
    tf.debugging.assert_all_finite(x, message="NaN или Inf во входе")
    return tf.math.log(x)

# Проверка shape
def model_call(inputs):
    tf.debugging.assert_rank(inputs, 2, message="Ожидается 2D тензор")
    tf.debugging.assert_shapes([(inputs, ('B', 10))],
                                message="Второе измерение должно быть 10")
    return inputs

3. TensorBoard — визуализация метрик и графа

import datetime

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,      # Гистограммы весов каждую эпоху
    write_graph=True,       # Граф модели
    write_images=True,      # Визуализация весов как изображения
    profile_batch='500,520' # Профилирование батчей 500-520
)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

X = np.random.randn(1000, 20).astype('float32')
y = np.random.randint(0, 10, size=1000)

model.fit(X, y, epochs=5, callbacks=[tensorboard_callback])
# Запустить: tensorboard --logdir logs/fit

4. Кастомный callback для отладки NaN

class NaNDetector(tf.keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        loss = logs.get('loss')
        if loss is not None and (np.isnan(loss) or np.isinf(loss)):
            print(f"\nNaN/Inf loss на батче {batch}: {loss}. Останавливаем обучение.")
            self.model.stop_training = True

# Встроенный аналог
model.fit(X, y, epochs=10, callbacks=[
    tf.keras.callbacks.TerminateOnNaN()
])

5. Профайлер TensorFlow

# Профилирование конкретного участка кода
with tf.profiler.experimental.Profile('logdir'):
    model(tf.random.normal([32, 20]), training=False)

# Через TensorBoard профайлер (более детально):
# tf.profiler.experimental.start('logdir')
# ... код ...
# tf.profiler.experimental.stop()

6. Отключение @tf.function для пошаговой отладки

# Временно отключить граф для обычного Python-дебаггера
tf.config.run_functions_eagerly(True)

# Теперь можно использовать breakpoint() внутри @tf.function
@tf.function
def buggy_fn(x):
    breakpoint()  # Работает только при run_functions_eagerly(True)
    return x * 2

buggy_fn(tf.constant([1.0, 2.0]))

# Вернуть режим графа
tf.config.run_functions_eagerly(False)

Подводные камни

  • print() внутри @tf.function: срабатывает только при трассировке графа, а не при каждом вызове. Используйте tf.print — он встраивается в граф и выполняется при каждом forward pass.
  • NaN на первой эпохе: чаще всего причина — слишком большой learning rate, плохая инициализация весов или логарифм от нуля. Проверьте tf.debugging.assert_all_finite на выходе каждого слоя.
  • Граф трассируется несколько раз: @tf.function создаёт новый граф при изменении shape или dtype входа. Добавьте input_signature, чтобы зафиксировать форму: @tf.function(input_signature=[tf.TensorSpec(shape=[None, 20], dtype=tf.float32)]).
  • TensorBoard не показывает гистограммы: забыли передать histogram_freq=1 в TensorBoard callback или не дождались конца первой эпохи.
  • Профайлер на GPU требует CUPTI: без установленного CUDA Profiling Tools Interface профайлер TF не сможет собирать GPU-метрики — будет собирать только CPU-события.
  • run_functions_eagerly в production: режим eager в 5-10 раз медленнее графового. Используйте только для отладки, не забудьте отключить.
  • assert_* внутри @tf.function с optimization: TensorFlow может оптимизировать граф и убрать «бесполезные» операции. Если assertion не используется дальше по графу, убедитесь, что результат передаётся вниз по вычислениям.
  • Несоответствие поведения в eager и graph режимах: Python-ветвления по значениям тензоров работают в eager, но не в graph (только при трассировке). Это приводит к трудно воспроизводимым багам при переходе к @tf.function.

Common mistakes

  • Объяснять debugging models только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать debugging models с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics