PyTorchMiddleCoding
Что такое transfer learning в PyTorch и как заморозить слои предобученной модели?
Transfer learning — дообучение предобученной модели на новой задаче. Замораживайте слои через requires_grad=False, передавая в оптимизатор только параметры с requires_grad=True.
Transfer Learning и заморозка слоёв в PyTorch
Концепция
Transfer learning использует веса, обученные на большом датасете (ImageNet, C4, etc.), как стартовую точку для новой задачи. Типичный сценарий: заморозить feature extractor (backbone), дообучить только classification head. Это позволяет обучаться на малых датасетах и сходиться быстрее.
Базовый пример: ResNet для классификации
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
# Загрузка предобученной модели
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Шаг 1: заморозить весь backbone
for param in model.parameters():
param.requires_grad = False
# Шаг 2: заменить и разморозить classification head
num_classes = 10
model.fc = nn.Sequential(
nn.Linear(model.fc.in_features, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, num_classes),
)
# Новые слои автоматически имеют requires_grad=True
# Передать в оптимизатор только trainable параметры
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=1e-3,
)
Частичная заморозка (fine-tuning последних блоков)
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Заморозить всё
for param in model.parameters():
param.requires_grad = False
# Разморозить layer4 и fc
for param in model.layer4.parameters():
param.requires_grad = True
for param in model.fc.parameters():
param.requires_grad = True
# Разные learning rate для разных частей
optimizer = torch.optim.SGD([
{"params": model.layer4.parameters(), "lr": 1e-4},
{"params": model.fc.parameters(), "lr": 1e-3},
], momentum=0.9, weight_decay=1e-4)
Проверка состояния заморозки
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
Transfer learning для NLP (HuggingFace + PyTorch)
from transformers import AutoModel
import torch.nn as nn
encoder = AutoModel.from_pretrained("bert-base-uncased")
# Заморозить первые 8 слоёв трансформера из 12
for i, layer in enumerate(encoder.encoder.layer):
if i < 8:
for param in layer.parameters():
param.requires_grad = False
classifier = nn.Linear(encoder.config.hidden_size, 3)
optimizer = torch.optim.AdamW([
{"params": encoder.encoder.layer[8:].parameters(), "lr": 2e-5},
{"params": classifier.parameters(), "lr": 1e-4},
])
Стратегии fine-tuning
- Feature extraction: весь backbone заморожен, обучается только head; подходит при малом датасете и сходстве домена.
- Fine-tuning: разморожены все или последние слои; подходит при достаточном датасете.
- Gradual unfreezing: постепенно размораживаем слои от head к backbone — стабилизирует обучение.
- Discriminative learning rates: ранние слои обучаются с очень маленьким lr (1e-5), поздние — с большим (1e-3).
Подводные камни
- BatchNorm в frozen backbone: слои с
requires_grad=Falseвсё ещё обновляютrunning_mean/running_varесли модель в режимеtrain(); нужно явно вызыватьbackbone.eval()для полной заморозки. - Создание оптимизатора до заморозки: если передать
model.parameters()в оптимизатор, а потом заморозить слои, оптимизатор всё равно хранит состояние для замороженных параметров (память тратится впустую); создавайте оптимизатор после заморозки. - Слишком высокий lr при fine-tuning предобученных слоёв разрушает выученные представления — используйте lr на порядок меньше, чем для head.
- При использовании
model.half()или смешанной точности заморозка работает корректно, ноrequires_grad=Falseпараметры всё равно кастятся — убедитесь в совместимости типов. - Некоторые архитектуры (ViT, CLIP) имеют специфичные нормализации (LayerNorm с learnable params); замораживание всего backbone может непредсказуемо повлиять на них.
- Gradual unfreezing реализуется вручную — PyTorch не имеет встроенного планировщика заморозки; забытый вызов
param.requires_grad = Trueв нужную эпоху оставит слои замороженными навсегда.
Common mistakes
- Объяснять
transfer learning freezingтолько синтаксисом без shape, dtype, состояния или режима выполнения. - Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
- Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.
What the interviewer is testing
- Может ли связать
transfer learning freezingс реальным контрактом входов и выходов. - Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
- Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.