TensorFlowMiddleCoding
Как Functional API обрабатывает модели с несколькими входами и выходами в Keras?
Functional API позволяет создавать модели с несколькими tf.keras.Input и несколькими выходными головами, компилируемыми с отдельными loss и loss_weights. Входы и выходы передаются как dict с именами, совпадающими с name= в Input/Dense слоях.
Functional API vs Sequential
Sequential API строит линейный стек слоёв. Functional API позволяет строить DAG-граф тензоров: несколько входных tf.keras.Input, общие веса между ветвями, несколько голов с разными функциями потерь. Это единственный способ реализовать multi-task learning, siamese-сети, модели с боковыми входами (текст + метаданные) или multi-label классификаторы с разными loss-весами в Keras.
Архитектура с двумя входами и двумя выходами
import tensorflow as tf
from tensorflow import keras
# --- Два входа ---
image_input = keras.Input(shape=(224, 224, 3), name='image') # изображение
meta_input = keras.Input(shape=(16,), name='metadata') # числовые метаданные
# --- Ветвь изображения ---
x = keras.layers.Conv2D(32, 3, activation='relu', padding='same')(image_input)
x = keras.layers.GlobalAveragePooling2D()(x) # (batch, 32)
x = keras.layers.Dense(64, activation='relu')(x)
# --- Ветвь метаданных ---
m = keras.layers.Dense(32, activation='relu')(meta_input)
m = keras.layers.Dense(32, activation='relu')(m)
# --- Слияние ---
merged = keras.layers.Concatenate()([x, m]) # (batch, 96)
shared = keras.layers.Dense(64, activation='relu')(merged)
# --- Два выхода ---
class_output = keras.layers.Dense(10, activation='softmax', name='class_head')(shared)
price_output = keras.layers.Dense(1, activation='linear', name='price_head')(shared)
model = keras.Model(
inputs={'image': image_input, 'metadata': meta_input},
outputs={'class_head': class_output, 'price_head': price_output},
)
model.summary()
Компиляция с несколькими функциями потерь
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
loss={
'class_head': 'sparse_categorical_crossentropy',
'price_head': 'mse',
},
loss_weights={
'class_head': 1.0,
'price_head': 0.1, # масштабируем, чтобы MSE не доминировала
},
metrics={
'class_head': ['accuracy'],
'price_head': ['mae'],
},
)
Обучение с dict-входами и dict-выходами
import numpy as np
# Синтетические данные
N = 512
train_data = {
'image': np.random.rand(N, 224, 224, 3).astype(np.float32),
'metadata': np.random.rand(N, 16).astype(np.float32),
}
train_labels = {
'class_head': np.random.randint(0, 10, size=N),
'price_head': np.random.rand(N, 1).astype(np.float32),
}
history = model.fit(
train_data,
train_labels,
batch_size=32,
epochs=5,
validation_split=0.2,
)
# Предсказание возвращает dict с именами выходов
preds = model.predict({'image': train_data['image'][:4],
'metadata': train_data['metadata'][:4]})
print(preds['class_head'].shape) # (4, 10)
print(preds['price_head'].shape) # (4, 1)
Визуализация графа
keras.utils.plot_model(
model,
to_file='model_graph.png',
show_shapes=True,
show_layer_names=True,
expand_nested=True,
)
Подводные камни
- Имена входов в
keras.Input(name=...)должны совпадать с ключами dict при вызовеfit()иpredict()— несовпадение имён порождает молчаливое игнорирование или ошибку формы. loss_weightsне нормируются автоматически: если MSE на задаче регрессии имеет порядок 1e4, а crossentropy — 1.0, без веса MSE «поглотит» сигнал классификации.- При использовании
validation_splitданные разбиваются по первой оси до перемешивания — если датасет не перемешан заранее, в validation попадут только последние N%. - Shared-слои (один слой, подключённый к нескольким ветвям) используют одни веса, но градиенты суммируются — это намеренное поведение, но может удивить при дебаге.
- Сохранение через
model.save('model.keras')работает, ноtf.saved_model.saveтребует, чтобы все входные подписи были явными — анонимные input-слои ломают экспорт. model.summary()не отображает форму тензора в точках ветвления в старых версиях Keras — используйтеkeras.utils.plot_modelдля проверки корректности графа.- При передаче данных как numpy-массивов (не dict), Keras сопоставляет их по порядку, а не по имени — если порядок входов неожиданный, модель молча обучается на неправильных данных.
- Метрики каждого выхода пишутся в history как
class_head_accuracy,val_price_head_mae— при логировании в TensorBoard или MLflow нужно явно указывать эти ключи.
Common mistakes
- Объяснять
functional api multi ioтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
functional api multi ioс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.