TensorFlowMiddleCoding

Как Functional API обрабатывает модели с несколькими входами и выходами в Keras?

Functional API позволяет создавать модели с несколькими tf.keras.Input и несколькими выходными головами, компилируемыми с отдельными loss и loss_weights. Входы и выходы передаются как dict с именами, совпадающими с name= в Input/Dense слоях.

Functional API vs Sequential

Sequential API строит линейный стек слоёв. Functional API позволяет строить DAG-граф тензоров: несколько входных tf.keras.Input, общие веса между ветвями, несколько голов с разными функциями потерь. Это единственный способ реализовать multi-task learning, siamese-сети, модели с боковыми входами (текст + метаданные) или multi-label классификаторы с разными loss-весами в Keras.

Архитектура с двумя входами и двумя выходами

import tensorflow as tf
from tensorflow import keras

# --- Два входа ---
image_input = keras.Input(shape=(224, 224, 3), name='image')       # изображение
meta_input  = keras.Input(shape=(16,),          name='metadata')   # числовые метаданные

# --- Ветвь изображения ---
x = keras.layers.Conv2D(32, 3, activation='relu', padding='same')(image_input)
x = keras.layers.GlobalAveragePooling2D()(x)   # (batch, 32)
x = keras.layers.Dense(64, activation='relu')(x)

# --- Ветвь метаданных ---
m = keras.layers.Dense(32, activation='relu')(meta_input)
m = keras.layers.Dense(32, activation='relu')(m)

# --- Слияние ---
merged = keras.layers.Concatenate()([x, m])    # (batch, 96)
shared = keras.layers.Dense(64, activation='relu')(merged)

# --- Два выхода ---
class_output = keras.layers.Dense(10, activation='softmax', name='class_head')(shared)
price_output = keras.layers.Dense(1,  activation='linear',  name='price_head')(shared)

model = keras.Model(
    inputs={'image': image_input, 'metadata': meta_input},
    outputs={'class_head': class_output, 'price_head': price_output},
)
model.summary()

Компиляция с несколькими функциями потерь

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss={
        'class_head': 'sparse_categorical_crossentropy',
        'price_head': 'mse',
    },
    loss_weights={
        'class_head': 1.0,
        'price_head': 0.1,   # масштабируем, чтобы MSE не доминировала
    },
    metrics={
        'class_head': ['accuracy'],
        'price_head': ['mae'],
    },
)

Обучение с dict-входами и dict-выходами

import numpy as np

# Синтетические данные
N = 512
train_data = {
    'image':    np.random.rand(N, 224, 224, 3).astype(np.float32),
    'metadata': np.random.rand(N, 16).astype(np.float32),
}
train_labels = {
    'class_head': np.random.randint(0, 10, size=N),
    'price_head': np.random.rand(N, 1).astype(np.float32),
}

history = model.fit(
    train_data,
    train_labels,
    batch_size=32,
    epochs=5,
    validation_split=0.2,
)

# Предсказание возвращает dict с именами выходов
preds = model.predict({'image': train_data['image'][:4],
                       'metadata': train_data['metadata'][:4]})
print(preds['class_head'].shape)  # (4, 10)
print(preds['price_head'].shape)  # (4, 1)

Визуализация графа

keras.utils.plot_model(
    model,
    to_file='model_graph.png',
    show_shapes=True,
    show_layer_names=True,
    expand_nested=True,
)

Подводные камни

  • Имена входов в keras.Input(name=...) должны совпадать с ключами dict при вызове fit() и predict() — несовпадение имён порождает молчаливое игнорирование или ошибку формы.
  • loss_weights не нормируются автоматически: если MSE на задаче регрессии имеет порядок 1e4, а crossentropy — 1.0, без веса MSE «поглотит» сигнал классификации.
  • При использовании validation_split данные разбиваются по первой оси до перемешивания — если датасет не перемешан заранее, в validation попадут только последние N%.
  • Shared-слои (один слой, подключённый к нескольким ветвям) используют одни веса, но градиенты суммируются — это намеренное поведение, но может удивить при дебаге.
  • Сохранение через model.save('model.keras') работает, но tf.saved_model.save требует, чтобы все входные подписи были явными — анонимные input-слои ломают экспорт.
  • model.summary() не отображает форму тензора в точках ветвления в старых версиях Keras — используйте keras.utils.plot_model для проверки корректности графа.
  • При передаче данных как numpy-массивов (не dict), Keras сопоставляет их по порядку, а не по имени — если порядок входов неожиданный, модель молча обучается на неправильных данных.
  • Метрики каждого выхода пишутся в history как class_head_accuracy, val_price_head_mae — при логировании в TensorBoard или MLflow нужно явно указывать эти ключи.

Common mistakes

  • Объяснять functional api multi io только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать functional api multi io с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics