Что такое callbacks в Keras и как использовать ModelCheckpoint, EarlyStopping и TensorBoard?
Keras callbacks — это объекты, вызываемые в ключевые моменты обучения (конец эпохи, конец батча). ModelCheckpoint сохраняет лучшие веса, EarlyStopping останавливает при деградации метрики, TensorBoard логирует метрики и графики.
Что такое callbacks в Keras
Callback — это объект, наследующий tf.keras.callbacks.Callback, с методами-хуками: on_train_begin, on_epoch_end, on_batch_end и другими. Keras вызывает их в нужный момент обучения. Список callbacks передаётся в model.fit(callbacks=[...]). Три наиболее важных встроенных callback'а — ModelCheckpoint, EarlyStopping и TensorBoard.
ModelCheckpoint
Сохраняет веса (или всю модель) в конце каждой эпохи, опционально только когда отслеживаемая метрика улучшилась.
import tensorflow as tf
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
filepath="checkpoints/model_epoch{epoch:02d}_val{val_accuracy:.4f}.keras",
monitor="val_accuracy", # отслеживаемая метрика
save_best_only=True, # сохранять только лучшую
save_weights_only=False, # сохранять всю модель (архитектура + веса)
mode="max", # "max" для accuracy, "min" для loss
verbose=1,
)
Формат .keras (Keras 3 native) предпочтительнее .h5: он поддерживает кастомные слои без дополнительной регистрации и хранит optimizer state для resuming обучения.
EarlyStopping
Прерывает обучение, если метрика перестала улучшаться в течение patience эпох. Совместно с restore_best_weights=True гарантирует, что финальная модель — наилучшая, а не последняя.
early_stop_cb = tf.keras.callbacks.EarlyStopping(
monitor="val_loss",
patience=5, # ждём 5 эпох без улучшения
min_delta=1e-4, # минимальное значимое улучшение
restore_best_weights=True, # возвращаем веса лучшей эпохи
verbose=1,
)
TensorBoard
Логирует скалярные метрики, гистограммы весов, граф вычислений, профиль производительности pipeline и embedding-проекции.
tensorboard_cb = tf.keras.callbacks.TensorBoard(
log_dir="./logs/run_001",
histogram_freq=1, # гистограммы весов каждую эпоху
write_graph=True, # визуализация графа модели
write_images=False, # не логируем изображения весов
update_freq="epoch", # частота логирования: "epoch" или int (батчи)
profile_batch=(5, 15), # профилируем батчи 5–15
)
Запуск TensorBoard локально:
tensorboard --logdir ./logs --port 6006
# открываем http://localhost:6006
Полный пример с тремя callbacks
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(32, 32, 3)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation="relu"),
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(
optimizer="adam",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
"best_model.keras", monitor="val_accuracy",
save_best_only=True, mode="max",
),
tf.keras.callbacks.EarlyStopping(
monitor="val_loss", patience=5,
restore_best_weights=True,
),
tf.keras.callbacks.TensorBoard(
log_dir="./logs", histogram_freq=1,
),
]
model.fit(
x_train, y_train,
validation_split=0.1,
epochs=50, # EarlyStopping остановит раньше
batch_size=64,
callbacks=callbacks,
)
# Загружаем лучшую модель для финальной оценки
best_model = tf.keras.models.load_model("best_model.keras")
best_model.evaluate(x_test, y_test)
Кастомный callback
Для специфических нужд (логирование в Weights & Biases, отправка алертов при NaN loss) наследуем tf.keras.callbacks.Callback:
class NanLossAlert(tf.keras.callbacks.Callback):
def on_batch_end(self, batch, logs=None):
loss = (logs or {}).get("loss")
if loss is not None and tf.math.is_nan(loss):
print(f"\nNaN loss detected at batch {batch}, stopping.")
self.model.stop_training = True
Подводные камни
ModelCheckpointсsave_best_only=Trueиmode="auto"может неправильно определить направление для кастомных метрик — лучше явно указыватьmode="max"илиmode="min".EarlyStoppingбезrestore_best_weights=Trueвозвращает модель последней эпохи, а не лучшей — обычно это ошибка.- Если
validation_dataне передан вfit(), метрики с префиксомval_не вычисляются, и callback сmonitor="val_loss"выбросит исключение. TensorBoardсhistogram_freq=1значительно замедляет обучение на больших моделях — для prod-обучения используйтеhistogram_freq=0или раз в N эпох.- Несколько запусков TensorBoard в одну и ту же директорию мешают друг другу — используйте уникальные
log_dir(например, с timestamp черезdatetime.now().strftime). - Callbacks порядкозависимы:
EarlyStoppingдолжен стоять послеModelCheckpointв списке, иначе чекпоинт лучшей модели может не сохраниться до остановки. - Формат
.h5не поддерживает сохранение optimizer state приsave_weights_only=Falseдля некоторых оптимизаторов TF 2.x — используйте.kerasилиSavedModel. - Кастомный callback с тяжёлой логикой в
on_batch_end(например, API-запрос) блокирует обучение синхронно — выносите I/O в отдельный поток или используйтеon_epoch_end.
Common mistakes
- Объяснять
keras callbacksтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
keras callbacksс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.