Gradient checkpointing

The forward pass typically caches all intermediate activations so the backward pass can reuse them. However, activation memory grows with batch size and sequence length. Gradient checkpointing saves only a subset of activations (the checkpoints) and discards the rest. The backward pass then recomputes the discarded activations on the fly as it needs them.

Normal training:
  Forward:   [L1]→[L2]→[L3]→[L4]   (save ALL activations)
  Backward:  ←uses cached activations everywhere

Gradient checkpointing:
  Forward:   [L1]→[L2]→[L3]→[L4]   (save only at checkpoints, discard the rest)
  Backward:  ←reaches L2, recomputes L2→L3 from the saved checkpoint, uses the result, discards it
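
The following is a toy, dependency-free sketch in plain Python that makes the trade-off in the diagram concrete. Each "layer" is a scalar function y = tanh(w * x), and the backward pass is written by hand. The helper names (`grads_normal`, `grads_checkpointed`), the `every` segment size, and the "stored" count are hypothetical illustrations, not how Transformers or PyTorch implement checkpointing internally. The sketch shows that recomputing activations from saved checkpoints produces the same gradients while holding fewer activations at once, at the cost of extra forward computation.

```python
import math


def layer(w, x):
    """Toy scalar 'layer' (hypothetical): y = tanh(w * x)."""
    return math.tanh(w * x)


def backward_step(w, x, y, grad_y):
    """Backprop through y = tanh(w * x); returns (grad_x, grad_w)."""
    local = 1.0 - y * y  # derivative of tanh at w * x, written via its output y
    return grad_y * w * local, grad_y * x * local


def grads_normal(weights, x0):
    """Standard training: cache every activation during the forward pass."""
    acts = [x0]
    for w in weights:
        acts.append(layer(w, acts[-1]))
    grad, grad_w = 1.0, [0.0] * len(weights)  # the loss is the final output
    for i in reversed(range(len(weights))):
        grad, grad_w[i] = backward_step(weights[i], acts[i], acts[i + 1], grad)
    # "stored" = number of activation values cached for the backward pass
    return grad_w, {"stored": len(acts), "forward_calls": len(weights)}


def grads_checkpointed(weights, x0, every):
    """Keep only each segment's input; recompute the segment during backward."""
    n = len(weights)
    checkpoints, x, forward_calls = {}, x0, 0
    for i, w in enumerate(weights):
        if i % every == 0:
            checkpoints[i] = x  # save the input of this segment
        x = layer(w, x)  # other intermediate outputs are not kept
        forward_calls += 1
    grad, grad_w, peak = 1.0, [0.0] * n, 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, n)
        seg = [checkpoints[start]]  # recompute activations for this segment only
        for i in range(start, end):
            seg.append(layer(weights[i], seg[-1]))
            forward_calls += 1
        # rough peak: all checkpoints plus the recomputed values of one segment
        peak = max(peak, len(checkpoints) + len(seg) - 1)
        for i in reversed(range(start, end)):
            j = i - start
            grad, grad_w[i] = backward_step(weights[i], seg[j], seg[j + 1], grad)
        # seg is rebuilt on the next iteration, so these activations are dropped
    return grad_w, {"stored": peak, "forward_calls": forward_calls}


weights = [0.9, -1.1, 0.7, 1.3, -0.8, 1.05, 0.6, -1.2]
g_full, stats_full = grads_normal(weights, 0.5)
g_ckpt, stats_ckpt = grads_checkpointed(weights, 0.5, every=4)
print(stats_full)  # {'stored': 9, 'forward_calls': 8}
print(stats_ckpt)  # {'stored': 6, 'forward_calls': 16}
print(g_full == g_ckpt)  # True
```

Because the recomputation repeats the same floating-point operations in the same order, the gradients match exactly. The cost appears as a second forward evaluation of every layer. This sketch recomputes every segment, including the last one, for simplicity.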

Training is roughly 20% slower because some activations are recomputed during the backward pass, but activation memory is reduced.
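
How much memory is saved depends on where the checkpoints go. A common back-of-the-envelope model splits n layers into segments of k layers. This model is an assumption used here for illustration and does not describe where Transformers places its checkpoints. Under it, you keep about ceil(n/k) segment inputs for the whole step, plus up to k activations while one segment is recomputed. Peak activations are therefore roughly ceil(n/k) + k, which is smallest near k ≈ √n. The function names below are hypothetical.

```python
import math


def peak_activations(n_layers: int, every: int) -> int:
    """Rough peak count of activations held under segment checkpointing (toy model)."""
    checkpoints = math.ceil(n_layers / every)  # segment inputs kept for the whole step
    return checkpoints + every                 # plus one segment recomputed at a time


def best_segment_size(n_layers: int) -> int:
    """Brute-force the segment size that minimizes the estimate above."""
    return min(range(1, n_layers + 1), key=lambda k: peak_activations(n_layers, k))


n = 16
k = best_segment_size(n)
print(k, peak_activations(n, k), "vs", n, "without checkpointing")
# 4 8 vs 16 without checkpointing
```

Under this model, the extra compute is about one additional forward pass regardless of k, while memory drops from n activations to about 2√n.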

To enable gradient checkpointing, set gradient_checkpointing=True in TrainingArguments.

Combine it with gradient accumulation to reduce memory usage further.

from transformers import TrainingArguments

args = TrainingArguments(
    ...,
    gradient_checkpointing=True,
)
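
The arguments below sketch how gradient checkpointing and gradient accumulation can be used together. The output directory and batch-size values are placeholders chosen for illustration. `gradient_checkpointing_kwargs` is forwarded to PyTorch's `torch.utils.checkpoint`, and `use_reentrant=False` selects PyTorch's non-reentrant implementation, which the PyTorch documentation recommends. Check that your installed version supports these fields.

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="checkpointing-demo",        # placeholder directory name
    per_device_train_batch_size=4,          # small micro-batch keeps live activations small
    gradient_accumulation_steps=8,          # effective batch = 4 * 8 examples per device
    gradient_checkpointing=True,            # recompute activations during the backward pass
    gradient_checkpointing_kwargs={"use_reentrant": False},  # passed to torch.utils.checkpoint
)
```

With these values, each optimizer step sees 32 examples per device, but only one micro-batch of 4 examples has activations in memory at a time.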

Next steps

  • Read the GPU memory usage doc to understand what is driving memory usage on the GPU during training.
  • See the Mixed precision training guide to learn how to use lower precision data types to further reduce memory and speed up training.
  • See the Kernels guide to learn how to speed up training with custom fused kernels.