PyTorchMiddleCoding
Как реализовать кастомную функцию потерь в PyTorch?
Кастомная loss — это функция или nn.Module, принимающая logits и targets и возвращающая скалярный тензор; главные требования — численная стабильность (log-space операции) и корректный reduction для .backward().
Кастомная функция потерь в PyTorch
Кастомная loss function в PyTorch реализуется двумя способами: как обычная функция Python или как подкласс nn.Module. Второй вариант предпочтителен, когда у функции потерь есть обучаемые параметры или нужно управлять состоянием.
Вариант 1: функция (без параметров)
import torch
import torch.nn.functional as F
def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0, alpha: float = 0.25) -> torch.Tensor:
"""Focal Loss для несбалансированных классов (Lin et al. 2017)."""
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
probs = torch.sigmoid(logits)
pt = torch.where(targets == 1, probs, 1 - probs)
focal_weight = alpha * (1 - pt) ** gamma
return (focal_weight * bce).mean()
# Пример использования
logits = torch.randn(32, 1) # выход модели
targets = torch.randint(0, 2, (32, 1)).float()
loss = focal_loss(logits, targets, gamma=2.0)
loss.backward() # градиенты вычисляются автоматически
Вариант 2: nn.Module (с параметрами или состоянием)
import torch
import torch.nn as nn
class WeightedMSELoss(nn.Module):
"""MSE с обучаемым вектором весов для каждого выхода."""
def __init__(self, num_outputs: int):
super().__init__()
# nn.Parameter автоматически регистрируется в model.parameters()
self.weights = nn.Parameter(torch.ones(num_outputs))
def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
# preds, targets: (batch, num_outputs)
per_output_mse = ((preds - targets) ** 2).mean(dim=0) # (num_outputs,)
# softmax чтобы веса суммировались в 1 и были положительными
w = torch.softmax(self.weights, dim=0)
return (w * per_output_mse).sum()
loss_fn = WeightedMSELoss(num_outputs=3)
optimizer = torch.optim.Adam(
list(model.parameters()) + list(loss_fn.parameters()),
lr=1e-3
)
preds = torch.randn(16, 3)
targets = torch.randn(16, 3)
loss = loss_fn(preds, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
Числовая стабильность
Нельзя вручную считать log(p) от вероятностей — это вызывает log(0) = -inf и NaN в градиентах. Вместо этого используют log-space операции:
import torch.nn.functional as F
# Плохо:
loss = -(targets * torch.log(probs)).mean() # NaN при probs→0
# Хорошо:
loss = F.nll_loss(torch.log(probs + 1e-8), targets) # добавить epsilon
# Лучше: принять logits и использовать встроенный log_softmax:
loss = F.cross_entropy(logits, targets) # численно стабильно
Проверка корректности градиентов
from torch.autograd import gradcheck
# gradcheck требует float64 и requires_grad=True
logits = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
targets = torch.randint(0, 3, (4,))
def loss_fn_for_check(x):
return F.cross_entropy(x, targets)
assert gradcheck(loss_fn_for_check, (logits,), eps=1e-6, atol=1e-4)
Подводные камни
- reduction='none' vs 'mean': если loss возвращает тензор вместо скаляра, вызов
.backward()упадёт с ошибкой; нужен явный.mean()или.sum(). - Gradient через targets: если targets — тензор с
requires_grad=True, PyTorch будет считать градиент и по ним, что замедляет обучение; передавайте targets как.detach(). - NaN при log(0): всегда используйте встроенные функции (
F.cross_entropy,F.binary_cross_entropy_with_logits) вместо ручного логарифма. - Mixed precision (AMP): с
torch.cuda.amp.autocastоперации выполняются в float16; если loss содержит операции, нечувствительные к fp16 (exp, log), оберните вwith torch.cuda.amp.autocast(enabled=False). - Параметры loss_fn не в optimizer: при использовании
nn.Module-loss нужно добавитьloss_fn.parameters()в оптимизатор, иначе обучаемые веса не обновляются. - Несовпадение shape: бинарная классификация — targets shape
(N,), logits(N, 1); нуженlogits.squeeze(1)передF.binary_cross_entropy_with_logits. - Дисбаланс классов:
F.cross_entropyпринимает параметрweight(тензор весов классов); его стоит считать из частот в тренировочной выборке, а не задавать вручную. - gradcheck в float32: численное дифференцирование неточно в float32, gradcheck нужно запускать в float64.
Common mistakes
- Объяснять
custom loss functionтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
custom loss functionс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.