PyTorchMiddleTechnical

В чём разница между динамическим графом вычислений PyTorch и статическим графом TensorFlow?

PyTorch строит граф вычислений динамически на каждом forward pass (define-by-run), что упрощает отладку и поддержку динамических структур; TensorFlow 1.x компилировал статический граф заранее (define-and-run), давая лучшую оптимизацию но усложняя отладку. TF 2.x и torch.compile сближают подходы.

Два подхода к построению графа

Статический граф (define-and-run) — граф описывается декларативно до запуска, компилируется один раз и затем выполняется многократно. TensorFlow 1.x, Theano, MXNet использовали этот подход. Компилятор знает граф целиком и может применять глобальные оптимизации: fusion операций, constant folding, параллелизм.

Динамический граф (define-by-run) — граф строится прямо во время выполнения Python-кода forward pass. PyTorch использует этот подход с самого начала. Каждый вызов forward() создаёт новый граф; Python-отладчик видит всё как обычный код.

Как работает autograd в PyTorch

При каждом вычислении с тензорами, у которых requires_grad=True, PyTorch записывает операцию в DAG (directed acyclic graph) через объект grad_fn. После вызова loss.backward() движок обходит граф в обратном порядке и накапливает градиенты.

import torch

# Динамический граф: условие меняет топологию на каждом шаге
def dynamic_forward(x: torch.Tensor, use_skip: bool) -> torch.Tensor:
    h = torch.relu(x @ torch.randn(x.shape[-1], 64, requires_grad=True))
    if use_skip:          # топология зависит от runtime-условия
        h = h + x[:, :64]
    return h.sum()

x = torch.randn(4, 64)
loss = dynamic_forward(x, use_skip=True)
loss.backward()           # работает без предварительной компиляции
print(loss.item())

Сравнение по ключевым осям

  • Отладка: PyTorch — обычный pdb/debugpy, видны значения тензоров; TF 1.x — нужен tf.Session, значения доступны только через sess.run().
  • Динамические структуры: переменная длина последовательностей, рекурсивные сети, tree-RNN — тривиальны в PyTorch, требуют TF Dynamic RNN или padding в TF 1.x.
  • Производительность: статический граф позволяет компилятору делать kernel fusion и убирать мёртвый код. TF XLA и torch.compile (TorchDynamo) дают похожие оптимизации в современных версиях.
  • Деплой: статический граф сериализуется как протобаф и запускается без Python. PyTorch закрывает этот gap через TorchScript и ONNX.

TF 2.x и torch.compile сближают подходы

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(128, 64)

    def forward(self, x):
        return torch.relu(self.fc(x))

model = Encoder()
# torch.compile — JIT-компиляция динамического графа в оптимизированный код
compiled = torch.compile(model)   # требует PyTorch >= 2.0
out = compiled(torch.randn(16, 128))
print(out.shape)   # torch.Size([16, 64])

TensorFlow 2.x по умолчанию работает в eager mode (аналог динамического графа), а @tf.function трассирует функцию в статический граф при первом вызове.

Когда что выбирать

  • PyTorch: исследования, NLP с переменной длиной, RL, прототипирование — динамический граф упрощает эксперименты.
  • TF/JAX: production-сервинг на TPU, мобильные устройства (TFLite), экосистема TFX — статический граф даёт лучший деплой-путь.
  • В 2024–2025 граница размылась: torch.compile, ONNX Runtime, TensorRT принимают оба формата.

Подводные камни

  • Забыть освободить граф после backward: без loss.backward() с retain_graph=False (по умолчанию) граф удаляется; повторный вызов backward упадёт с ошибкой.
  • Накопление градиентов между шагами: не вызвать optimizer.zero_grad() перед backward — градиенты суммируются.
  • Python-условия внутри @torch.compile / TorchScript могут потребовать torch.jit.is_scripting() guard для корректной трассировки.
  • Перепутать TF eager mode и graph mode: некоторые операции ведут себя по-разному в @tf.function (Python side-effects выполняются один раз при трассировке).
  • torch.compile не поддерживает все Python-конструкции — некоторые графы падают обратно в eager (graph breaks), что снижает производительность.
  • Отсутствие детерминизма: динамический граф и CuDNN могут давать разные результаты при разных размерах батча, если не зафиксировать torch.backends.cudnn.deterministic = True.
  • Сериализация TorchScript требует аннотаций типов — динамический Python-код с *args/**kwargs не трассируется автоматически.
  • XLA на TPU требует статических форм — динамические размеры батча приводят к перекомпиляции на каждом шаге.

Common mistakes

  • Объяснять dynamic vs static graph только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать dynamic vs static graph с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics