PyTorchMiddleTechnical
Что такое quantization и как она влияет на inference latency/accuracy?
Quantization снижает точность весов/активаций (fp32→int8/int4), уменьшая память и ускоряя inference за счёт целочисленных инструкций; accuracy обычно теряется менее чем на 1% при PTQ и ещё меньше при QAT.
Quantization в PyTorch
Квантизация заменяет числа с плавающей точкой (fp32/fp16) на целочисленные представления (int8, int4, int2). Это уменьшает размер модели в 2–8 раз, снижает потребление памяти и ускоряет вычисления за счёт оптимизированных SIMD/VNNI/Tensor Core инструкций.
Три подхода
- PTQ (Post-Training Quantization) — квантизация уже обученной модели; быстро, но accuracy может пострадать на выбросах.
- QAT (Quantization-Aware Training) — fake quantization во время обучения; точнее, но требует дообучения.
- Dynamic Quantization — веса квантизируются статически, активации — на лету; минимальные усилия.
Dynamic Quantization (быстрый старт)
import torch
from torch.quantization import quantize_dynamic
model = MyLSTMModel()
model.load_state_dict(torch.load("model.pt"))
model.eval()
quantized = quantize_dynamic(
model,
qconfig_spec={torch.nn.Linear, torch.nn.LSTM},
dtype=torch.qint8,
)
# Сравнение размера
import os
torch.save(model.state_dict(), "/tmp/fp32.pt")
torch.save(quantized.state_dict(), "/tmp/int8.pt")
print(f"FP32: {os.path.getsize('/tmp/fp32.pt') / 1e6:.1f} MB")
print(f"INT8: {os.path.getsize('/tmp/int8.pt') / 1e6:.1f} MB")
Static PTQ с калибровкой
from torch.quantization import prepare, convert, get_default_qconfig
model.eval()
model.qconfig = get_default_qconfig("fbgemm") # x86; "qnnpack" для ARM
prepare(model, inplace=True) # вставляет наблюдателей
# Калибровка на репрезентативном датасете
with torch.no_grad():
for batch in calib_loader:
model(batch)
convert(model, inplace=True) # заменяет операции на int8
Влияние на latency и accuracy
- INT8 linear на CPU (x86 с VNNI): ускорение 2–4× по сравнению с fp32.
- INT4 weight-only на GPU (bitsandbytes): экономия VRAM 4×, latency улучшается при memory-bound inference.
- PTQ accuracy drop: обычно <0.5% top-1 для классификаторов; трансформеры с выбросами активаций (LLaMA) могут терять 2–5% без специальных методов (AWQ, GPTQ).
- QAT восстанавливает практически всю accuracy, но требует 10–20% эпох дообучения.
Современные форматы для LLM
# bitsandbytes — 4-bit NF4
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
quantization_config=bnb_config,
device_map="auto",
)
Подводные камни
- PTQ на моделях с выбросами активаций (BERT, LLaMA) даёт сильную деградацию без специальной обработки: нужны SmoothQuant или AWQ.
fbgemmbackend работает только на x86 с AVX2+; на ARM и мобильных устройствах нуженqnnpack.- Квантизированные модели нельзя напрямую экспортировать в ONNX стандартными средствами — требуется
torch.onnx.exportсopset_version>=13. - GPU INT8 (через
torch.ao.quantization) поддерживается хуже, чем CPU; на CUDA эффективнее использовать TensorRT или bitsandbytes. - Калибровочный датасет должен покрывать весь диапазон входов; на плохой калибровке наблюдатель установит неверные масштабы и точность упадёт.
- QAT несовместим с
torch.compileиз-за fake-quantization операторов — нужно компилировать уже сконвертированную модель. - Не все слои квантизируются эффективно: первый и последний слои часто оставляют в fp32 для сохранения точности.
Common mistakes
- Отвечать определением без production-сценария.
- Не называть runtime boundary, security boundary или failure mode.
- Игнорировать версию API, observability и тестовую проверку.
What the interviewer is testing
- Объясняет механизм своими словами и без выдуманных API.
- Называет реальные риски, диагностику и критерий корректности.
- Связывает ответ с текущей документацией и миграционными ограничениями.