PyTorchSeniorExperience

Какие риски есть у PyTorch: data leakage, неверная оценка качества, memory usage, latency, cost, drift или observability?

Ключевые риски PyTorch в production: data leakage через неправильный split, memory leaks из-за удержания графов, drift без мониторинга, и скрытые проблемы воспроизводимости. Каждый риск требует явной митигации.

Риски PyTorch в реальных системах

PyTorch — мощный инструмент, но переход от прототипа к production обнажает целый класс проблем, которые не видны в notebook-разработке. Ниже — систематический разбор по категориям.

Data Leakage

Самый опасный риск: модель случайно "видит" тестовые данные во время обучения, и метрики выглядят отлично, но в production — провал.

import torch
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.preprocessing import StandardScaler
import numpy as np

# НЕПРАВИЛЬНО: scaler обучен на всём датасете, включая тест
X = np.random.randn(1000, 20)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)  # Leakage!

# ПРАВИЛЬНО: scaler обучен только на train
train_size = 800
X_train, X_test = X[:train_size], X[train_size:]
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)  # Только transform!

# Time-series: обязательно temporal split
# Никогда не используйте random_split для временных рядов
class TemporalDataset(Dataset):
    def __init__(self, data, seq_len=30):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        return self.data[idx:idx+self.seq_len], self.data[idx+self.seq_len]

# Train/val/test split по времени, не случайно
n = len(X)
train_end = int(n * 0.7)
val_end = int(n * 0.85)

Неверная оценка качества

  • Мониторинг только accuracy: при дисбалансе классов accuracy вводит в заблуждение — логируйте F1, PR-AUC, confusion matrix.
  • Evaluation на неправильном сплите: validation loss в обучающем цикле vs. held-out test set — разные вещи.
  • Label leakage в feature engineering: включение в признаки информации о целевой переменной (например, лаговые таргеты).
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import f1_score, roc_auc_score

writer = SummaryWriter(log_dir="runs/experiment_1")

def evaluate(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            preds = torch.sigmoid(logits).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(y.cpu().numpy())
    f1 = f1_score(all_labels, [p > 0.5 for p in all_preds])
    auc = roc_auc_score(all_labels, all_preds)
    return {"f1": f1, "auc": auc}

Memory Usage и утечки

PyTorch удерживает вычислительный граф в памяти, если не вызвать .detach() или loss.item().

import torch

# НЕПРАВИЛЬНО: накапливает граф в losses
losses = []
for batch in dataloader:
    loss = model(batch)
    losses.append(loss)  # Граф не освобождается!

# ПРАВИЛЬНО: отцепляем от графа
losses = []
for batch in dataloader:
    loss = model(batch)
    losses.append(loss.item())  # Только скаляр Python

# Явная очистка кэша GPU
torch.cuda.empty_cache()

# Мониторинг памяти
def log_gpu_memory(step: int):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"Step {step}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

Latency в Production

  • Inference без torch.no_grad() строит граф и тратит вдвое больше памяти.
  • Первый вызов torch.compile компилирует граф — warmup необходим.
  • Batch size 1 для real-time serving vs. батчированный inference для throughput — разные оптимизации.
# Production inference с оптимизациями
import torch
from contextlib import contextmanager

@contextmanager
def optimized_inference(model):
    model.eval()
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        yield model

# Warmup для torch.compile
compiled_model = torch.compile(model)
warmup_input = torch.randn(1, 768).cuda()
for _ in range(3):  # Warmup runs
    _ = compiled_model(warmup_input)

Cost

GPU inference в облаке дорог. Квантизация снижает стоимость в 2–4 раза:

from torch.ao.quantization import quantize_dynamic
import torch.nn as nn

# Dynamic quantization: INT8 для линейных слоёв
quantized_model = quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8,
)
# Размер модели уменьшается ~4x, скорость на CPU растёт 2-3x

Drift и Observability

Model drift — постепенная деградация качества из-за изменения входных данных. Без мониторинга неотличим от нормы.

  • Логируйте распределение входных признаков: среднее, std, min/max, % нулей.
  • Мониторьте prediction distribution: если модель всё чаще выдаёт граничные значения — это сигнал.
  • Используйте Evidently AI или NannyML для автоматического drift detection.

Подводные камни

  • GroupKFold вместо KFold: при наличии групп (один пользователь/объект в нескольких строках) стандартный KFold даёт data leakage между фолдами — используйте sklearn.model_selection.GroupKFold.
  • model.train() vs model.eval(): Dropout и BatchNorm ведут себя по-разному — забыть переключить режим приводит к нестабильным предсказаниям в inference.
  • GPU OOM без контекстного менеджера: torch.cuda.OutOfMemoryError не освобождает память автоматически — нужен явный torch.cuda.empty_cache() в except блоке.
  • AMP и NaN градиенты: torch.autocast с float16 может давать NaN/Inf без GradScaler — всегда используйте их в паре.
  • DataLoader и fork: на macOS/Windows num_workers > 0 требует if __name__ == "__main__" guard или multiprocessing_context="spawn".
  • Нет baseline для latency: без профилирования (torch.profiler.profile) нельзя понять, что тормозит — препроцессинг, forward pass или постпроцессинг.
  • Отсутствие version pinning: PyTorch обновления ломают ONNX opset compatibility и behavior torch.compile — фиксируйте версии в requirements.txt и тестируйте при апгрейде.
  • Молчаливые ошибки в distributed training: при сбое одного worker'а остальные могут зависнуть в barrier без timeout — устанавливайте NCCL_TIMEOUT и используйте torch.distributed.init_process_group(timeout=datetime.timedelta(minutes=30)).

What hurts your answer

  • Говорить только о запуске PyTorch, но не об эксплуатации
  • Не упоминать observability, обновления, безопасность и rollback
  • Описывать риски абстрактно, без способов их снижать

What they're listening for

  • Видит production-риски PyTorch
  • Говорит про monitoring, rollout, rollback и безопасность
  • Умеет ранжировать риски по вероятности и влиянию

Related topics