Какие риски есть у PyTorch: data leakage, неверная оценка качества, memory usage, latency, cost, drift или observability?
Ключевые риски PyTorch в production: data leakage через неправильный split, memory leaks из-за удержания графов, drift без мониторинга, и скрытые проблемы воспроизводимости. Каждый риск требует явной митигации.
Риски PyTorch в реальных системах
PyTorch — мощный инструмент, но переход от прототипа к production обнажает целый класс проблем, которые не видны в notebook-разработке. Ниже — систематический разбор по категориям.
Data Leakage
Самый опасный риск: модель случайно "видит" тестовые данные во время обучения, и метрики выглядят отлично, но в production — провал.
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.preprocessing import StandardScaler
import numpy as np
# НЕПРАВИЛЬНО: scaler обучен на всём датасете, включая тест
X = np.random.randn(1000, 20)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X) # Leakage!
# ПРАВИЛЬНО: scaler обучен только на train
train_size = 800
X_train, X_test = X[:train_size], X[train_size:]
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test) # Только transform!
# Time-series: обязательно temporal split
# Никогда не используйте random_split для временных рядов
class TemporalDataset(Dataset):
def __init__(self, data, seq_len=30):
self.data = torch.tensor(data, dtype=torch.float32)
self.seq_len = seq_len
def __len__(self):
return len(self.data) - self.seq_len
def __getitem__(self, idx):
return self.data[idx:idx+self.seq_len], self.data[idx+self.seq_len]
# Train/val/test split по времени, не случайно
n = len(X)
train_end = int(n * 0.7)
val_end = int(n * 0.85)
Неверная оценка качества
- Мониторинг только accuracy: при дисбалансе классов accuracy вводит в заблуждение — логируйте F1, PR-AUC, confusion matrix.
- Evaluation на неправильном сплите: validation loss в обучающем цикле vs. held-out test set — разные вещи.
- Label leakage в feature engineering: включение в признаки информации о целевой переменной (например, лаговые таргеты).
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import f1_score, roc_auc_score
writer = SummaryWriter(log_dir="runs/experiment_1")
def evaluate(model, loader, device):
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
for X, y in loader:
X, y = X.to(device), y.to(device)
logits = model(X)
preds = torch.sigmoid(logits).cpu().numpy()
all_preds.extend(preds)
all_labels.extend(y.cpu().numpy())
f1 = f1_score(all_labels, [p > 0.5 for p in all_preds])
auc = roc_auc_score(all_labels, all_preds)
return {"f1": f1, "auc": auc}
Memory Usage и утечки
PyTorch удерживает вычислительный граф в памяти, если не вызвать .detach() или loss.item().
import torch
# НЕПРАВИЛЬНО: накапливает граф в losses
losses = []
for batch in dataloader:
loss = model(batch)
losses.append(loss) # Граф не освобождается!
# ПРАВИЛЬНО: отцепляем от графа
losses = []
for batch in dataloader:
loss = model(batch)
losses.append(loss.item()) # Только скаляр Python
# Явная очистка кэша GPU
torch.cuda.empty_cache()
# Мониторинг памяти
def log_gpu_memory(step: int):
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
print(f"Step {step}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
Latency в Production
- Inference без
torch.no_grad()строит граф и тратит вдвое больше памяти. - Первый вызов
torch.compileкомпилирует граф — warmup необходим. - Batch size 1 для real-time serving vs. батчированный inference для throughput — разные оптимизации.
# Production inference с оптимизациями
import torch
from contextlib import contextmanager
@contextmanager
def optimized_inference(model):
model.eval()
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
yield model
# Warmup для torch.compile
compiled_model = torch.compile(model)
warmup_input = torch.randn(1, 768).cuda()
for _ in range(3): # Warmup runs
_ = compiled_model(warmup_input)
Cost
GPU inference в облаке дорог. Квантизация снижает стоимость в 2–4 раза:
from torch.ao.quantization import quantize_dynamic
import torch.nn as nn
# Dynamic quantization: INT8 для линейных слоёв
quantized_model = quantize_dynamic(
model,
{nn.Linear},
dtype=torch.qint8,
)
# Размер модели уменьшается ~4x, скорость на CPU растёт 2-3x
Drift и Observability
Model drift — постепенная деградация качества из-за изменения входных данных. Без мониторинга неотличим от нормы.
- Логируйте распределение входных признаков: среднее, std, min/max, % нулей.
- Мониторьте prediction distribution: если модель всё чаще выдаёт граничные значения — это сигнал.
- Используйте Evidently AI или NannyML для автоматического drift detection.
Подводные камни
- GroupKFold вместо KFold: при наличии групп (один пользователь/объект в нескольких строках) стандартный KFold даёт data leakage между фолдами — используйте
sklearn.model_selection.GroupKFold. - model.train() vs model.eval(): Dropout и BatchNorm ведут себя по-разному — забыть переключить режим приводит к нестабильным предсказаниям в inference.
- GPU OOM без контекстного менеджера:
torch.cuda.OutOfMemoryErrorне освобождает память автоматически — нужен явныйtorch.cuda.empty_cache()в except блоке. - AMP и NaN градиенты:
torch.autocastс float16 может давать NaN/Inf безGradScaler— всегда используйте их в паре. - DataLoader и fork: на macOS/Windows
num_workers > 0требуетif __name__ == "__main__"guard илиmultiprocessing_context="spawn". - Нет baseline для latency: без профилирования (
torch.profiler.profile) нельзя понять, что тормозит — препроцессинг, forward pass или постпроцессинг. - Отсутствие version pinning: PyTorch обновления ломают ONNX opset compatibility и behavior
torch.compile— фиксируйте версии вrequirements.txtи тестируйте при апгрейде. - Молчаливые ошибки в distributed training: при сбое одного worker'а остальные могут зависнуть в barrier без timeout — устанавливайте
NCCL_TIMEOUTи используйтеtorch.distributed.init_process_group(timeout=datetime.timedelta(minutes=30)).
What hurts your answer
- Говорить только о запуске PyTorch, но не об эксплуатации
- Не упоминать observability, обновления, безопасность и rollback
- Описывать риски абстрактно, без способов их снижать
What they're listening for
- Видит production-риски PyTorch
- Говорит про monitoring, rollout, rollback и безопасность
- Умеет ранжировать риски по вероятности и влиянию