PyTorchMiddleCoding

Как реализовать кастомную функцию потерь в PyTorch?

Кастомная loss — это функция или nn.Module, принимающая logits и targets и возвращающая скалярный тензор; главные требования — численная стабильность (log-space операции) и корректный reduction для .backward().

Кастомная функция потерь в PyTorch

Кастомная loss function в PyTorch реализуется двумя способами: как обычная функция Python или как подкласс nn.Module. Второй вариант предпочтителен, когда у функции потерь есть обучаемые параметры или нужно управлять состоянием.

Вариант 1: функция (без параметров)

import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0, alpha: float = 0.25) -> torch.Tensor:
    """Focal Loss для несбалансированных классов (Lin et al. 2017)."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    probs = torch.sigmoid(logits)
    pt = torch.where(targets == 1, probs, 1 - probs)
    focal_weight = alpha * (1 - pt) ** gamma
    return (focal_weight * bce).mean()

# Пример использования
logits = torch.randn(32, 1)      # выход модели
targets = torch.randint(0, 2, (32, 1)).float()
loss = focal_loss(logits, targets, gamma=2.0)
loss.backward()  # градиенты вычисляются автоматически

Вариант 2: nn.Module (с параметрами или состоянием)

import torch
import torch.nn as nn

class WeightedMSELoss(nn.Module):
    """MSE с обучаемым вектором весов для каждого выхода."""

    def __init__(self, num_outputs: int):
        super().__init__()
        # nn.Parameter автоматически регистрируется в model.parameters()
        self.weights = nn.Parameter(torch.ones(num_outputs))

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # preds, targets: (batch, num_outputs)
        per_output_mse = ((preds - targets) ** 2).mean(dim=0)  # (num_outputs,)
        # softmax чтобы веса суммировались в 1 и были положительными
        w = torch.softmax(self.weights, dim=0)
        return (w * per_output_mse).sum()

loss_fn = WeightedMSELoss(num_outputs=3)
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(loss_fn.parameters()),
    lr=1e-3
)

preds = torch.randn(16, 3)
targets = torch.randn(16, 3)
loss = loss_fn(preds, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

Числовая стабильность

Нельзя вручную считать log(p) от вероятностей — это вызывает log(0) = -inf и NaN в градиентах. Вместо этого используют log-space операции:

import torch.nn.functional as F

# Плохо:
loss = -(targets * torch.log(probs)).mean()        # NaN при probs→0

# Хорошо:
loss = F.nll_loss(torch.log(probs + 1e-8), targets)  # добавить epsilon
# Лучше: принять logits и использовать встроенный log_softmax:
loss = F.cross_entropy(logits, targets)              # численно стабильно

Проверка корректности градиентов

from torch.autograd import gradcheck

# gradcheck требует float64 и requires_grad=True
logits = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
targets = torch.randint(0, 3, (4,))

def loss_fn_for_check(x):
    return F.cross_entropy(x, targets)

assert gradcheck(loss_fn_for_check, (logits,), eps=1e-6, atol=1e-4)

Подводные камни

  • reduction='none' vs 'mean': если loss возвращает тензор вместо скаляра, вызов .backward() упадёт с ошибкой; нужен явный .mean() или .sum().
  • Gradient через targets: если targets — тензор с requires_grad=True, PyTorch будет считать градиент и по ним, что замедляет обучение; передавайте targets как .detach().
  • NaN при log(0): всегда используйте встроенные функции (F.cross_entropy, F.binary_cross_entropy_with_logits) вместо ручного логарифма.
  • Mixed precision (AMP): с torch.cuda.amp.autocast операции выполняются в float16; если loss содержит операции, нечувствительные к fp16 (exp, log), оберните в with torch.cuda.amp.autocast(enabled=False).
  • Параметры loss_fn не в optimizer: при использовании nn.Module-loss нужно добавить loss_fn.parameters() в оптимизатор, иначе обучаемые веса не обновляются.
  • Несовпадение shape: бинарная классификация — targets shape (N,), logits (N, 1); нужен logits.squeeze(1) перед F.binary_cross_entropy_with_logits.
  • Дисбаланс классов: F.cross_entropy принимает параметр weight (тензор весов классов); его стоит считать из частот в тренировочной выборке, а не задавать вручную.
  • gradcheck в float32: численное дифференцирование неточно в float32, gradcheck нужно запускать в float64.

Common mistakes

  • Объяснять custom loss function только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать custom loss function с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics