scikit-learnJuniorTechnical
В чём разница между predict() и predict_proba() и когда важен последний?
predict() возвращает метку класса (жёсткое решение с порогом 0.5), predict_proba() — вероятности принадлежности к каждому классу. predict_proba нужен для ROC AUC, настройки порога, ранжирования и ансамблирования.
predict() и predict_proba(): в чём разница
predict(X) возвращает конкретную метку класса для каждого объекта — это финальное решение классификатора. predict_proba(X) возвращает массив вероятностей принадлежности к каждому классу: строка соответствует объекту, столбец — классу из clf.classes_. Сумма вероятностей по строке всегда равна 1.0.
Пример с LogisticRegression
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
scaler = StandardScaler()
X_train_s = scaler.fit_transform(X_train)
X_test_s = scaler.transform(X_test)
clf = LogisticRegression(max_iter=1000).fit(X_train_s, y_train)
# Жёсткие предсказания
labels = clf.predict(X_test_s)
print(labels[:5]) # [1 0 0 1 1]
print(clf.classes_) # [0 1]
# Вероятности: столбец 0 — P(класс=0), столбец 1 — P(класс=1)
proba = clf.predict_proba(X_test_s)
print(proba[:5])
# [[0.02 0.98]
# [0.97 0.03]
# ...]
# Вероятность позитивного класса
pos_proba = proba[:, 1]
Когда нужен predict_proba()
- ROC AUC: метрика
roc_auc_scoreтребует вероятностей, а не меток. - Настройка порога: по умолчанию порог 0.5, но при дисбалансе классов или разной стоимости ошибок оптимальный порог может быть 0.3 или 0.7.
- Ранжирование: в информационном поиске или рекомендательных системах нужен рейтинг объектов по вероятности.
- Калибровка: анализ calibration curve требует вероятностей.
- Ансамблирование: soft voting усредняет вероятности моделей.
Настройка порога решения
from sklearn.metrics import f1_score, roc_auc_score
import matplotlib.pyplot as plt
pos_proba = clf.predict_proba(X_test_s)[:, 1]
# ROC AUC работает только с вероятностями
print("ROC AUC:", roc_auc_score(y_test, pos_proba))
# Настройка порога для максимизации F1
thresholds = np.linspace(0.1, 0.9, 81)
f1_scores = [f1_score(y_test, pos_proba >= t) for t in thresholds]
best_threshold = thresholds[np.argmax(f1_scores)]
print(f"Лучший порог: {best_threshold:.2f}, F1: {max(f1_scores):.3f}")
# Применение нестандартного порога
custom_labels = (pos_proba >= best_threshold).astype(int)
Калибровка вероятностей
from sklearn.calibration import CalibratedClassifierCV
from sklearn.svm import SVC
# SVC не имеет predict_proba по умолчанию (только decision_function)
# CalibratedClassifierCV добавляет откалиброванные вероятности
svc = SVC(kernel='rbf')
calibrated = CalibratedClassifierCV(svc, method='isotonic', cv=5)
calibrated.fit(X_train_s, y_train)
print(calibrated.predict_proba(X_test_s)[:3])
Модели без predict_proba
SVC(probability=False)— нетpredict_proba, естьdecision_function.SGDClassifierс некоторыми loss — толькоdecision_function.- Для получения вероятностей используйте
CalibratedClassifierCV.
Подводные камни
- predict_proba доступен не у всех классификаторов:
SVC(probability=False)вызовет AttributeError; включениеprobability=Trueсущественно замедляет обучение SVC. - Вероятности RandomForest и деревьев не откалиброваны: они часто смещены к 0 и 1, calibration curve показывает отклонение от диагонали.
- predict() использует порог 0.5 по умолчанию: это неоптимально при дисбалансе классов; всегда проверяйте порог явно.
- Многоклассовый случай:
predict_probaвозвращает матрицу (n_samples, n_classes); для бинарной метрики типа ROC AUC нужно выбрать столбец нужного класса. - roc_auc_score и predict(): передача меток вместо вероятностей в roc_auc_score даст значение 0.5 или ошибку — частая ошибка новичков.
- VotingClassifier:
voting='hard'использует predict(),voting='soft'требует predict_proba у всех базовых моделей. - Pipeline и predict_proba: метод проксируется к финальному шагу — если он не поддерживает predict_proba, pipeline тоже не поддерживает.
Common mistakes
- Объяснять
predict vs predict probaтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
predict vs predict probaс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.