PyTorchSeniorExperience

Представьте, pipeline на PyTorch даёт хорошие результаты в dev, но плохо работает на новых данных. Как вы будете разбираться?

Проверяйте последовательно: data drift (KS-test), data leakage, соответствие параметров нормализации train/inference, режим model.eval(), корректность загрузки checkpoint и анализ ошибочных примеров по паттернам.

Диагностика: pipeline хорошо работает на dev, плохо — на новых данных

Расхождение между dev и продакшн метриками — один из самых частых и опасных сценариев. Причины делятся на несколько групп: проблемы с данными, утечки, особенности обучения, deployment. Ниже — систематический процесс разбора.

Шаг 1. Проверить сдвиг распределения (data drift)

Первый вопрос: похожи ли новые данные на обучающие? Сравните базовые статистики входных признаков.

import torch
import numpy as np
from scipy import stats

def check_drift(train_tensor, new_tensor, feature_names):
    """KS-test для каждого признака."""
    train_np = train_tensor.numpy()
    new_np = new_tensor.numpy()
    for i, name in enumerate(feature_names):
        stat, p = stats.ks_2samp(train_np[:, i], new_np[:, i])
        if p < 0.05:
            print(f'DRIFT detected: {name}, KS={stat:.3f}, p={p:.4f}')

Если p-value ниже порога для многих признаков — данные изменились, нужно дообучение или ретренинг.

Шаг 2. Проверить утечку данных (data leakage)

Если в dev метрики были завышены из-за утечки, на новых данных они упадут. Проверьте:

  • Нормализация (mean/std) вычислялась по всему датасету до сплита?
  • Временные признаки (future data) случайно попали в обучение?
  • Одни и те же примеры оказались и в train, и в val?
# Проверка пересечения train/val id
train_ids = set(train_dataset.ids)
val_ids = set(val_dataset.ids)
overlap = train_ids & val_ids
if overlap:
    print(f'Leakage! {len(overlap)} samples in both splits')

Шаг 3. Проверить preprocessing pipeline

Самая частая причина: параметры нормализации (mean/std, scaler) вычислены на train и не применяются одинаково к новым данным. Проверьте, что при инференсе используются те же самые значения.

# Сохраняем статистики при обучении
normalization_params = {
    'mean': train_mean.tolist(),
    'std': train_std.tolist()
}
import json
with open('norm_params.json', 'w') as f:
    json.dump(normalization_params, f)

# При инференсе загружаем их же
with open('norm_params.json') as f:
    params = json.load(f)
mean = torch.tensor(params['mean'])
std = torch.tensor(params['std'])
new_data_normalized = (new_data - mean) / std

Шаг 4. Проверить режим модели при инференсе

model.eval()  # обязательно!
with torch.no_grad():
    predictions = model(new_data_normalized)

BatchNorm в режиме train использует статистики текущего батча, что даёт другие результаты на малых батчах или одиночных примерах.

Шаг 5. Анализ ошибок на новых данных

Посмотрите на примеры, где модель ошибается наиболее сильно. Ищите паттерны: определённые категории, диапазоны значений, временные интервалы.

model.eval()
errors = []
with torch.no_grad():
    for x, y, meta in new_loader:
        pred = model(x.to(device))
        err = (pred.cpu() - y).abs()
        for i in range(len(err)):
            if err[i] > threshold:
                errors.append({'meta': meta[i], 'error': err[i].item()})

# Группируем по категории
from collections import Counter
category_errors = Counter(e['meta']['category'] for e in errors)
print(category_errors.most_common(10))

Шаг 6. Проверить версионирование модели и артефактов

Убедитесь, что для инференса загружается именно та версия модели и тех же самых весов, которые показали хорошие результаты на dev. Случайная инициализация части весов из-за несовпадения checkpoint — частая ошибка.

checkpoint = torch.load('model_best.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print('Loaded epoch:', checkpoint['epoch'])
print('Val loss:', checkpoint['val_loss'])

Подводные камни

  • Разные параметры нормализации в train и inference — самая частая причина деградации. Всегда сохраняйте scaler/params как артефакт рядом с моделью.
  • Забытый model.eval(): Dropout и BatchNorm в режиме train дают стохастические результаты — каждый инференс будет другим.
  • Temporal leakage: при работе с временными рядами случайный сплит вместо временного разделения даёт оптимистичный val, но плохой прод.
  • Концептуальный сдвиг (concept drift): даже при правильном preprocessing поведение пользователей или бизнес-правила меняются со временем — модель устаревает.
  • Разные версии зависимостей: разные версии torchvision transforms или других пакетов могут давать разные результаты трансформаций.
  • Half-precision ошибки: fp16/bf16 в продакшне при обучении в fp32 может добавлять погрешность, особенно для чувствительных к численной точности операций.
  • Маленькие тестовые батчи при использовании BatchNorm: с batch_size=1 running_mean/var не обновляются правильно, если model не переключена в eval.

What hurts your answer

  • Сразу обвинять PyTorch, не проверив соседние слои системы
  • Чинить симптом без минимального воспроизведения и evidence
  • Не учитывать версии, конфигурацию, окружение и recent changes

What they're listening for

  • Умеет локализовать проблему вокруг PyTorch
  • Двигается от симптома к гипотезам и проверкам
  • Отличает баг инструмента от ошибки использования или окружения

Related topics