TensorFlowMiddleCoding

Что такое TensorFlow Hub и как использовать предобученные модели из него?

TensorFlow Hub — репозиторий предобученных моделей. Загрузите модуль через hub.load() или hub.KerasLayer(), затем дообучите (fine-tune) верхние слои на своих данных.

Что такое TensorFlow Hub

TensorFlow Hub (tfhub.dev) — репозиторий переиспользуемых модулей машинного обучения: эмбеддинги текста, классификаторы изображений, детекторы объектов и многое другое. Модули публикуются в формате SavedModel и загружаются одной строкой кода прямо из URL.

Установка

pip install tensorflow-hub

Два способа использования

  • hub.load(url) — возвращает SavedModel-объект; подходит для инференса без Keras.
  • hub.KerasLayer(url, trainable=...) — оборачивает модуль в слой Keras; удобно для fine-tuning.

Пример: классификация изображений с MobileNetV2

import tensorflow as tf
import tensorflow_hub as hub

MODEL_URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
IMG_SIZE   = 224
NUM_CLASSES = 5  # например, цветы (roses, daisy, ...)

# Входной пайплайн
def preprocess(image_path, label):
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
    img = img / 255.0  # нормализация
    return img, label

# Построение модели
base_model = hub.KerasLayer(
    MODEL_URL,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    trainable=False,  # заморозить на первом этапе
)

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()

# --- Fine-tuning: разморозить весь base_model ---
base_model.trainable = True
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # маленький LR!
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# model.fit(train_ds, epochs=5, validation_data=val_ds)

Текстовые эмбеддинги

import tensorflow_hub as hub
import numpy as np

embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
result = embed(["Hello world", "TensorFlow Hub is awesome"])
print(result.shape)  # (2, 512)

Подводные камни

  • Размер модуля: MobileNetV2 весит ~14 MB, BERT — сотни MB. При первом запуске файлы кешируются в ~/.cache/tfhub_modules — убедитесь, что хватает места на диске.
  • Несовместимость версий TF: многие старые модули написаны для TF1 и требуют hub.load с tf.compat.v1 обёртками.
  • Разморозка всех слоёв сразу при высоком LR разрушает предобученные веса (catastrophic forgetting). Используйте двухэтапное обучение: сначала обучите «голову», потом fine-tune с LR ≤ 1e-5.
  • Несоответствие нормализации: некоторые модули ожидают пиксели в [0, 1], другие — в [-1, 1] или без нормализации. Читайте документацию конкретного модуля.
  • hub.KerasLayer не сохраняется стандартным model.save("model.h5") — используйте SavedModel-формат: model.save("model_dir").
  • Переменная окружения TFHUB_CACHE_DIR не установлена: при каждом старте контейнера модели скачиваются заново.
  • Смешение trainable=True/False и BatchNormalization: BN-слои внутри hub-модуля могут обновлять running stats в нежелательный момент, если не передать training=False явно.

Common mistakes

  • Объяснять tensorflow hub только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать tensorflow hub с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics