TensorFlowJuniorCoding

Что такое callbacks в Keras и как использовать ModelCheckpoint, EarlyStopping и TensorBoard?

Keras callbacks — это объекты, вызываемые в ключевые моменты обучения (конец эпохи, конец батча). ModelCheckpoint сохраняет лучшие веса, EarlyStopping останавливает при деградации метрики, TensorBoard логирует метрики и графики.

Что такое callbacks в Keras

Callback — это объект, наследующий tf.keras.callbacks.Callback, с методами-хуками: on_train_begin, on_epoch_end, on_batch_end и другими. Keras вызывает их в нужный момент обучения. Список callbacks передаётся в model.fit(callbacks=[...]). Три наиболее важных встроенных callback'а — ModelCheckpoint, EarlyStopping и TensorBoard.

ModelCheckpoint

Сохраняет веса (или всю модель) в конце каждой эпохи, опционально только когда отслеживаемая метрика улучшилась.

import tensorflow as tf

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/model_epoch{epoch:02d}_val{val_accuracy:.4f}.keras",
    monitor="val_accuracy",     # отслеживаемая метрика
    save_best_only=True,        # сохранять только лучшую
    save_weights_only=False,    # сохранять всю модель (архитектура + веса)
    mode="max",                 # "max" для accuracy, "min" для loss
    verbose=1,
)

Формат .keras (Keras 3 native) предпочтительнее .h5: он поддерживает кастомные слои без дополнительной регистрации и хранит optimizer state для resuming обучения.

EarlyStopping

Прерывает обучение, если метрика перестала улучшаться в течение patience эпох. Совместно с restore_best_weights=True гарантирует, что финальная модель — наилучшая, а не последняя.

early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,                   # ждём 5 эпох без улучшения
    min_delta=1e-4,               # минимальное значимое улучшение
    restore_best_weights=True,    # возвращаем веса лучшей эпохи
    verbose=1,
)

TensorBoard

Логирует скалярные метрики, гистограммы весов, граф вычислений, профиль производительности pipeline и embedding-проекции.

tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir="./logs/run_001",
    histogram_freq=1,          # гистограммы весов каждую эпоху
    write_graph=True,          # визуализация графа модели
    write_images=False,        # не логируем изображения весов
    update_freq="epoch",       # частота логирования: "epoch" или int (батчи)
    profile_batch=(5, 15),     # профилируем батчи 5–15
)

Запуск TensorBoard локально:

tensorboard --logdir ./logs --port 6006
# открываем http://localhost:6006

Полный пример с тремя callbacks

import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        "best_model.keras", monitor="val_accuracy",
        save_best_only=True, mode="max",
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=5,
        restore_best_weights=True,
    ),
    tf.keras.callbacks.TensorBoard(
        log_dir="./logs", histogram_freq=1,
    ),
]

model.fit(
    x_train, y_train,
    validation_split=0.1,
    epochs=50,         # EarlyStopping остановит раньше
    batch_size=64,
    callbacks=callbacks,
)

# Загружаем лучшую модель для финальной оценки
best_model = tf.keras.models.load_model("best_model.keras")
best_model.evaluate(x_test, y_test)

Кастомный callback

Для специфических нужд (логирование в Weights & Biases, отправка алертов при NaN loss) наследуем tf.keras.callbacks.Callback:

class NanLossAlert(tf.keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and tf.math.is_nan(loss):
            print(f"\nNaN loss detected at batch {batch}, stopping.")
            self.model.stop_training = True

Подводные камни

  • ModelCheckpoint с save_best_only=True и mode="auto" может неправильно определить направление для кастомных метрик — лучше явно указывать mode="max" или mode="min".
  • EarlyStopping без restore_best_weights=True возвращает модель последней эпохи, а не лучшей — обычно это ошибка.
  • Если validation_data не передан в fit(), метрики с префиксом val_ не вычисляются, и callback с monitor="val_loss" выбросит исключение.
  • TensorBoard с histogram_freq=1 значительно замедляет обучение на больших моделях — для prod-обучения используйте histogram_freq=0 или раз в N эпох.
  • Несколько запусков TensorBoard в одну и ту же директорию мешают друг другу — используйте уникальные log_dir (например, с timestamp через datetime.now().strftime).
  • Callbacks порядкозависимы: EarlyStopping должен стоять после ModelCheckpoint в списке, иначе чекпоинт лучшей модели может не сохраниться до остановки.
  • Формат .h5 не поддерживает сохранение optimizer state при save_weights_only=False для некоторых оптимизаторов TF 2.x — используйте .keras или SavedModel.
  • Кастомный callback с тяжёлой логикой в on_batch_end (например, API-запрос) блокирует обучение синхронно — выносите I/O в отдельный поток или используйте on_epoch_end.

Common mistakes

  • Объяснять keras callbacks только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать keras callbacks с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics