PyTorchMiddleCoding

Что такое transfer learning в PyTorch и как заморозить слои предобученной модели?

Transfer learning — дообучение предобученной модели на новой задаче. Замораживайте слои через requires_grad=False, передавая в оптимизатор только параметры с requires_grad=True.

Transfer Learning и заморозка слоёв в PyTorch

Концепция

Transfer learning использует веса, обученные на большом датасете (ImageNet, C4, etc.), как стартовую точку для новой задачи. Типичный сценарий: заморозить feature extractor (backbone), дообучить только classification head. Это позволяет обучаться на малых датасетах и сходиться быстрее.

Базовый пример: ResNet для классификации

import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

# Загрузка предобученной модели
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Шаг 1: заморозить весь backbone
for param in model.parameters():
    param.requires_grad = False

# Шаг 2: заменить и разморозить classification head
num_classes = 10
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes),
)
# Новые слои автоматически имеют requires_grad=True

# Передать в оптимизатор только trainable параметры
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,
)

Частичная заморозка (fine-tuning последних блоков)

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Заморозить всё
for param in model.parameters():
    param.requires_grad = False

# Разморозить layer4 и fc
for param in model.layer4.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True

# Разные learning rate для разных частей
optimizer = torch.optim.SGD([
    {"params": model.layer4.parameters(), "lr": 1e-4},
    {"params": model.fc.parameters(), "lr": 1e-3},
], momentum=0.9, weight_decay=1e-4)

Проверка состояния заморозки

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")

Transfer learning для NLP (HuggingFace + PyTorch)

from transformers import AutoModel
import torch.nn as nn

encoder = AutoModel.from_pretrained("bert-base-uncased")

# Заморозить первые 8 слоёв трансформера из 12
for i, layer in enumerate(encoder.encoder.layer):
    if i < 8:
        for param in layer.parameters():
            param.requires_grad = False

classifier = nn.Linear(encoder.config.hidden_size, 3)

optimizer = torch.optim.AdamW([
    {"params": encoder.encoder.layer[8:].parameters(), "lr": 2e-5},
    {"params": classifier.parameters(), "lr": 1e-4},
])

Стратегии fine-tuning

  • Feature extraction: весь backbone заморожен, обучается только head; подходит при малом датасете и сходстве домена.
  • Fine-tuning: разморожены все или последние слои; подходит при достаточном датасете.
  • Gradual unfreezing: постепенно размораживаем слои от head к backbone — стабилизирует обучение.
  • Discriminative learning rates: ранние слои обучаются с очень маленьким lr (1e-5), поздние — с большим (1e-3).

Подводные камни

  • BatchNorm в frozen backbone: слои с requires_grad=False всё ещё обновляют running_mean/running_var если модель в режиме train(); нужно явно вызывать backbone.eval() для полной заморозки.
  • Создание оптимизатора до заморозки: если передать model.parameters() в оптимизатор, а потом заморозить слои, оптимизатор всё равно хранит состояние для замороженных параметров (память тратится впустую); создавайте оптимизатор после заморозки.
  • Слишком высокий lr при fine-tuning предобученных слоёв разрушает выученные представления — используйте lr на порядок меньше, чем для head.
  • При использовании model.half() или смешанной точности заморозка работает корректно, но requires_grad=False параметры всё равно кастятся — убедитесь в совместимости типов.
  • Некоторые архитектуры (ViT, CLIP) имеют специфичные нормализации (LayerNorm с learnable params); замораживание всего backbone может непредсказуемо повлиять на них.
  • Gradual unfreezing реализуется вручную — PyTorch не имеет встроенного планировщика заморозки; забытый вызов param.requires_grad = True в нужную эпоху оставит слои замороженными навсегда.

Common mistakes

  • Объяснять transfer learning freezing только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать transfer learning freezing с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics