scikit-learnJuniorTechnical

В чём разница между predict() и predict_proba() и когда важен последний?

predict() возвращает метку класса (жёсткое решение с порогом 0.5), predict_proba() — вероятности принадлежности к каждому классу. predict_proba нужен для ROC AUC, настройки порога, ранжирования и ансамблирования.

predict() и predict_proba(): в чём разница

predict(X) возвращает конкретную метку класса для каждого объекта — это финальное решение классификатора. predict_proba(X) возвращает массив вероятностей принадлежности к каждому классу: строка соответствует объекту, столбец — классу из clf.classes_. Сумма вероятностей по строке всегда равна 1.0.

Пример с LogisticRegression

from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

scaler = StandardScaler()
X_train_s = scaler.fit_transform(X_train)
X_test_s = scaler.transform(X_test)

clf = LogisticRegression(max_iter=1000).fit(X_train_s, y_train)

# Жёсткие предсказания
labels = clf.predict(X_test_s)
print(labels[:5])          # [1 0 0 1 1]
print(clf.classes_)        # [0 1]

# Вероятности: столбец 0 — P(класс=0), столбец 1 — P(класс=1)
proba = clf.predict_proba(X_test_s)
print(proba[:5])
# [[0.02 0.98]
#  [0.97 0.03]
#  ...]

# Вероятность позитивного класса
pos_proba = proba[:, 1]

Когда нужен predict_proba()

  • ROC AUC: метрика roc_auc_score требует вероятностей, а не меток.
  • Настройка порога: по умолчанию порог 0.5, но при дисбалансе классов или разной стоимости ошибок оптимальный порог может быть 0.3 или 0.7.
  • Ранжирование: в информационном поиске или рекомендательных системах нужен рейтинг объектов по вероятности.
  • Калибровка: анализ calibration curve требует вероятностей.
  • Ансамблирование: soft voting усредняет вероятности моделей.

Настройка порога решения

from sklearn.metrics import f1_score, roc_auc_score
import matplotlib.pyplot as plt

pos_proba = clf.predict_proba(X_test_s)[:, 1]

# ROC AUC работает только с вероятностями
print("ROC AUC:", roc_auc_score(y_test, pos_proba))

# Настройка порога для максимизации F1
thresholds = np.linspace(0.1, 0.9, 81)
f1_scores = [f1_score(y_test, pos_proba >= t) for t in thresholds]
best_threshold = thresholds[np.argmax(f1_scores)]
print(f"Лучший порог: {best_threshold:.2f}, F1: {max(f1_scores):.3f}")

# Применение нестандартного порога
custom_labels = (pos_proba >= best_threshold).astype(int)

Калибровка вероятностей

from sklearn.calibration import CalibratedClassifierCV
from sklearn.svm import SVC

# SVC не имеет predict_proba по умолчанию (только decision_function)
# CalibratedClassifierCV добавляет откалиброванные вероятности
svc = SVC(kernel='rbf')
calibrated = CalibratedClassifierCV(svc, method='isotonic', cv=5)
calibrated.fit(X_train_s, y_train)
print(calibrated.predict_proba(X_test_s)[:3])

Модели без predict_proba

  • SVC(probability=False) — нет predict_proba, есть decision_function.
  • SGDClassifier с некоторыми loss — только decision_function.
  • Для получения вероятностей используйте CalibratedClassifierCV.

Подводные камни

  • predict_proba доступен не у всех классификаторов: SVC(probability=False) вызовет AttributeError; включение probability=True существенно замедляет обучение SVC.
  • Вероятности RandomForest и деревьев не откалиброваны: они часто смещены к 0 и 1, calibration curve показывает отклонение от диагонали.
  • predict() использует порог 0.5 по умолчанию: это неоптимально при дисбалансе классов; всегда проверяйте порог явно.
  • Многоклассовый случай: predict_proba возвращает матрицу (n_samples, n_classes); для бинарной метрики типа ROC AUC нужно выбрать столбец нужного класса.
  • roc_auc_score и predict(): передача меток вместо вероятностей в roc_auc_score даст значение 0.5 или ошибку — частая ошибка новичков.
  • VotingClassifier: voting='hard' использует predict(), voting='soft' требует predict_proba у всех базовых моделей.
  • Pipeline и predict_proba: метод проксируется к финальному шагу — если он не поддерживает predict_proba, pipeline тоже не поддерживает.

Common mistakes

  • Объяснять predict vs predict proba только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать predict vs predict proba с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics