scikit-learnJuniorCoding
Что такое интерфейс fit(), transform() и fit_transform() в scikit-learn и почему такой дизайн API важен?
fit() запоминает статистики по train-данным, transform() применяет их к любым данным, fit_transform() — удобный shortcut только для train. Pipeline гарантирует, что препроцессор никогда не видит тестовые данные при кросс-валидации.
Контракт API: fit, transform, fit_transform
В scikit-learn все трансформеры реализуют три метода с чётким разделением ответственности:
fit(X, y=None)— вычисляет и запоминает статистики из обучающих данных (среднее, дисперсию, словарь категорий). Возвращаетselfдля chaining. Не меняет X.transform(X)— применяет уже вычисленные параметры к новым данным. Можно вызывать многократно: на train, val, test, production-данных.fit_transform(X, y=None)— эквивалентfit(X).transform(X), но часто реализован эффективнее. Используется только на train-выборке.
Пример: StandardScaler
import numpy as np
from sklearn.preprocessing import StandardScaler
X_train = np.array([[1.0, 200.0], [2.0, 400.0], [3.0, 600.0]])
X_test = np.array([[4.0, 800.0], [0.5, 100.0]])
scaler = StandardScaler()
# fit вычисляет mean и std ТОЛЬКО по train
scaler.fit(X_train)
print(scaler.mean_) # [2. 400.]
print(scaler.scale_) # [0.81649658 163.2993...]
# transform применяет (X - mean) / std
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test) # используем те же mean/std!
print(X_train_scaled)
# [[-1.22 -1.22]
# [ 0. 0. ]
# [ 1.22 1.22]]
Почему нельзя fit на тесте
# НЕВЕРНО: scaler видит тестовые значения
scaler_bad = StandardScaler()
X_test_wrong = scaler_bad.fit_transform(X_test) # mean теперь другой!
# Модель получит данные в другом масштабе — предсказания будут неверными
# ВЕРНО: fit только на train, transform на всём остальном
scaler_good = StandardScaler()
X_train_scaled = scaler_good.fit_transform(X_train)
X_test_scaled = scaler_good.transform(X_test)
Pipeline автоматически соблюдает контракт
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import pandas as pd
X, y = make_classification(n_samples=500, n_features=4, random_state=0)
df = pd.DataFrame(X, columns=["age", "income", "score", "flag"])
df["country"] = ["RU", "US", "DE"][0:1] * 167 + ["US"] * 166
num_cols = ["age", "income", "score", "flag"]
cat_cols = ["country"]
preprocessor = ColumnTransformer([
("num", StandardScaler(), num_cols),
("cat", OneHotEncoder(handle_unknown="ignore", sparse_output=False), cat_cols),
])
pipeline = Pipeline([
("prep", preprocessor),
("clf", LogisticRegression(max_iter=300, random_state=0)),
])
X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)
# Pipeline.fit вызывает prep.fit_transform(X_train) + clf.fit(...)
# Pipeline.predict вызывает prep.transform(X_test) + clf.predict(...)
pipeline.fit(X_train, y_train)
print("Test accuracy:", pipeline.score(X_test, y_test))
Зачем такой дизайн API
Единый интерфейс fit / transform позволяет:
- Компоновать любые трансформеры в
PipelineиColumnTransformerбез изменения кода. - Безопасно использовать
cross_validate— каждый fold независимо fit-ит препроцессор, исключая leakage. - Сериализовать весь pipeline одним
joblib.dump(pipeline, "model.pkl")и деплоить без повторного fit-а. - Реализовывать кастомные трансформеры через
BaseEstimator + TransformerMixin, получаяfit_transformбесплатно.
Подводные камни
- Вызов transform до fit — бросает
sklearn.exceptions.NotFittedError. Проверяйте черезcheck_is_fitted(estimator). - fit_transform на тесте — самая частая ошибка новичков: статистики пересчитываются, масштаб не совпадает с train.
- Мутация входного X — некоторые трансформеры (например,
SimpleImputerсcopy=False) изменяют массив на месте. По умолчаниюcopy=True. - ColumnTransformer отбрасывает неперечисленные столбцы — используйте
remainder="passthrough", если нужно сохранить остальные признаки. - sparse_output=True у OneHotEncoder — по умолчанию возвращает разреженную матрицу; StandardScaler её не принимает. Используйте
sparse_output=FalseилиMaxAbsScaler. - set_output API (sklearn 1.2+) —
pipeline.set_output(transform="pandas")возвращает DataFrame с именами столбцов; без него именованный доступ к признакам теряется.
Common mistakes
- Объяснять
fit transform apiтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
fit transform apiс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.