scikit-learnJuniorCoding

Что такое интерфейс fit(), transform() и fit_transform() в scikit-learn и почему такой дизайн API важен?

fit() запоминает статистики по train-данным, transform() применяет их к любым данным, fit_transform() — удобный shortcut только для train. Pipeline гарантирует, что препроцессор никогда не видит тестовые данные при кросс-валидации.

Контракт API: fit, transform, fit_transform

В scikit-learn все трансформеры реализуют три метода с чётким разделением ответственности:

  • fit(X, y=None) — вычисляет и запоминает статистики из обучающих данных (среднее, дисперсию, словарь категорий). Возвращает self для chaining. Не меняет X.
  • transform(X) — применяет уже вычисленные параметры к новым данным. Можно вызывать многократно: на train, val, test, production-данных.
  • fit_transform(X, y=None) — эквивалент fit(X).transform(X), но часто реализован эффективнее. Используется только на train-выборке.

Пример: StandardScaler

import numpy as np
from sklearn.preprocessing import StandardScaler

X_train = np.array([[1.0, 200.0], [2.0, 400.0], [3.0, 600.0]])
X_test  = np.array([[4.0, 800.0], [0.5, 100.0]])

scaler = StandardScaler()

# fit вычисляет mean и std ТОЛЬКО по train
scaler.fit(X_train)
print(scaler.mean_)   # [2.  400.]
print(scaler.scale_)  # [0.81649658 163.2993...]

# transform применяет (X - mean) / std
X_train_scaled = scaler.transform(X_train)
X_test_scaled  = scaler.transform(X_test)   # используем те же mean/std!

print(X_train_scaled)
# [[-1.22  -1.22]
#  [ 0.     0.  ]
#  [ 1.22   1.22]]

Почему нельзя fit на тесте

# НЕВЕРНО: scaler видит тестовые значения
scaler_bad = StandardScaler()
X_test_wrong = scaler_bad.fit_transform(X_test)  # mean теперь другой!
# Модель получит данные в другом масштабе — предсказания будут неверными

# ВЕРНО: fit только на train, transform на всём остальном
scaler_good = StandardScaler()
X_train_scaled = scaler_good.fit_transform(X_train)
X_test_scaled  = scaler_good.transform(X_test)

Pipeline автоматически соблюдает контракт

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import pandas as pd

X, y = make_classification(n_samples=500, n_features=4, random_state=0)
df = pd.DataFrame(X, columns=["age", "income", "score", "flag"])
df["country"] = ["RU", "US", "DE"][0:1] * 167 + ["US"] * 166

num_cols = ["age", "income", "score", "flag"]
cat_cols = ["country"]

preprocessor = ColumnTransformer([
    ("num", StandardScaler(), num_cols),
    ("cat", OneHotEncoder(handle_unknown="ignore", sparse_output=False), cat_cols),
])

pipeline = Pipeline([
    ("prep", preprocessor),
    ("clf",  LogisticRegression(max_iter=300, random_state=0)),
])

X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)

# Pipeline.fit вызывает prep.fit_transform(X_train) + clf.fit(...)
# Pipeline.predict вызывает prep.transform(X_test)  + clf.predict(...)
pipeline.fit(X_train, y_train)
print("Test accuracy:", pipeline.score(X_test, y_test))

Зачем такой дизайн API

Единый интерфейс fit / transform позволяет:

  • Компоновать любые трансформеры в Pipeline и ColumnTransformer без изменения кода.
  • Безопасно использовать cross_validate — каждый fold независимо fit-ит препроцессор, исключая leakage.
  • Сериализовать весь pipeline одним joblib.dump(pipeline, "model.pkl") и деплоить без повторного fit-а.
  • Реализовывать кастомные трансформеры через BaseEstimator + TransformerMixin, получая fit_transform бесплатно.

Подводные камни

  • Вызов transform до fit — бросает sklearn.exceptions.NotFittedError. Проверяйте через check_is_fitted(estimator).
  • fit_transform на тесте — самая частая ошибка новичков: статистики пересчитываются, масштаб не совпадает с train.
  • Мутация входного X — некоторые трансформеры (например, SimpleImputer с copy=False) изменяют массив на месте. По умолчанию copy=True.
  • ColumnTransformer отбрасывает неперечисленные столбцы — используйте remainder="passthrough", если нужно сохранить остальные признаки.
  • sparse_output=True у OneHotEncoder — по умолчанию возвращает разреженную матрицу; StandardScaler её не принимает. Используйте sparse_output=False или MaxAbsScaler.
  • set_output API (sklearn 1.2+)pipeline.set_output(transform="pandas") возвращает DataFrame с именами столбцов; без него именованный доступ к признакам теряется.

Common mistakes

  • Объяснять fit transform api только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать fit transform api с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics