TensorFlowMiddleCoding

Что такое gradient clipping в TensorFlow и как его применять?

Gradient clipping ограничивает норму или значение градиентов перед шагом оптимизатора, предотвращая exploding gradients в RNN и трансформерах. В Keras передаётся через clipnorm/clipvalue в оптимизатор.

Зачем нужен gradient clipping

В глубоких сетях, особенно рекуррентных (LSTM, GRU) и трансформерах, градиенты могут экспоненциально расти при backpropagation through time — это exploding gradients. Модель начинает делать огромные шаги, loss становится NaN. Gradient clipping ограничивает величину градиентов до безопасного порога, не позволяя оптимизатору «улетать».

Два вида clipping

  • clip by value (clipvalue) — каждый элемент градиента обрезается до диапазона [-v, v]. Прост, но меняет направление вектора градиента.
  • clip by norm (clipnorm) — если L2-норма вектора градиентов превышает порог, весь вектор масштабируется так, чтобы его норма равнялась порогу. Направление сохраняется.

Вариант 1: через аргументы оптимизатора (рекомендуемый)

import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(
    learning_rate=1e-3,
    clipnorm=1.0,          # clip by global norm
    # clipvalue=0.5,       # альтернатива: clip by value
)

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=10000, output_dim=64),
    tf.keras.layers.LSTM(128, return_sequences=True),
    tf.keras.layers.LSTM(64),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

model.compile(
    optimizer=optimizer,
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

model.fit(train_ds, validation_data=val_ds, epochs=10)

Параметр clipnorm=1.0 говорит Adam применять tf.clip_by_global_norm к списку градиентов перед каждым шагом. Это наиболее распространённый подход в production.

Вариант 2: вручную через GradientTape (для кастомных циклов)

import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.BinaryCrossentropy()

@tf.function
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)
        loss = loss_fn(y_batch, logits)

    gradients = tape.gradient(loss, model.trainable_variables)

    # Clip by global norm
    gradients, global_norm = tf.clip_by_global_norm(gradients, clip_norm=1.0)

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss, global_norm

for x_batch, y_batch in train_ds:
    loss, norm = train_step(x_batch, y_batch)
    tf.summary.scalar("train/loss", loss, step=optimizer.iterations)
    tf.summary.scalar("train/global_norm", norm, step=optimizer.iterations)

tf.clip_by_global_norm также возвращает исходную глобальную норму до клиппинга — это ценная диагностическая метрика. Если норма постоянно намного выше порога, стоит снизить learning rate или пересмотреть инициализацию весов.

Когда применять

  • RNN/LSTM/GRU с длинными последовательностями (BPTT через 50+ шагов).
  • Трансформеры при обучении с нуля — особенно до warmup фазы scheduler.
  • Когда loss становится NaN или inf в первых эпохах.
  • Fine-tuning предобученных моделей с большим learning rate.

Подводные камни

  • Слишком маленький clipnorm (например, 0.01) замедляет сходимость не хуже, чем очень малый learning rate — модель «идёт» но очень медленно.
  • clipvalue меняет направление градиентного вектора (обрезает компоненты независимо), что может нарушить баланс обновлений в многозадачных архитектурах.
  • При использовании tf.keras.mixed_precision градиенты масштабируются loss scaler'ом перед клиппингом — нужно клиппить после unscale, иначе реальная норма будет на порядки больше.
  • Параметр global_clipnorm в оптимизаторе TF2.x применяет глобальную норму по всем переменным; старый clipnorm применяет норму к каждому градиенту по отдельности — семантика разная, легко перепутать.
  • Мониторинг gradient norm обязателен: если норма постоянно ниже порога, clipping не работает — проблема может быть в vanishing gradients, а не exploding.
  • В распределённом обучении (tf.distribute.MirroredStrategy) градиенты суммируются по репликам перед клиппингом — порог нужно подбирать с учётом числа GPU.
  • Забытый @tf.function на train_step с ручным клиппингом делает каждый шаг на 5–10x медленнее из-за Python overhead.
  • Gradient clipping не решает проблему плохой инициализации весов — если начальные веса дают огромный loss, клиппинг лишь замедляет катастрофу, а не устраняет причину.

Common mistakes

  • Объяснять gradient clipping только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать gradient clipping с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics