Transformers documentation
Gradient checkpointing
The forward pass typically caches all intermediate activations so the backward pass can reuse them. However, activation memory grows with batch size and sequence length, and can dominate GPU memory. Gradient checkpointing saves only a subset of activations (the checkpoints) and discards the rest, forcing the backward pass to recompute the missing activations on the fly as they're needed.
```
Normal training:
Forward:  [L1]→[L2]→[L3]→[L4]  (save ALL activations)
Backward: ←uses cached activations everywhere

Gradient checkpointing:
Forward:  [L1]→[L2]→[L3]→[L4]  (save only at checkpoints, discard the rest)
Backward: ←reaches L2, recomputes L2→L3 from scratch, uses it, discards it
```
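The schedule above can be sketched in plain Python. This is illustrative only; in practice PyTorch handles this internally (for example via `torch.utils.checkpoint`). The forward pass stores only every `every`-th layer input, and the backward pass rebuilds a missing activation from the nearest earlier checkpoint:

```python
# Minimal pure-Python sketch of activation checkpointing on a chain of layers.
# Function names and the `every` parameter are hypothetical, for illustration.

def forward_checkpointed(layers, x, every=2):
    """Run the chain, saving only every `every`-th layer input as a checkpoint."""
    checkpoints = {}  # layer index -> saved input activation
    for i, layer in enumerate(layers):
        if i % every == 0:
            checkpoints[i] = x  # keep this activation
        x = layer(x)            # all other intermediates are discarded
    return x, checkpoints

def recompute_activation(layers, checkpoints, i, every=2):
    """Recompute layer i's input from the nearest earlier checkpoint."""
    start = (i // every) * every
    x = checkpoints[start]
    for j in range(start, i):   # redo the forward work between checkpoints
        x = layers[j](x)
    return x

layers = [lambda v: v + 1, lambda v: v * 2, lambda v: v + 3, lambda v: v * 4]
out, ckpts = forward_checkpointed(layers, 1, every=2)
assert out == 28                 # full forward result is unchanged
assert set(ckpts) == {0, 2}      # only 2 of 4 activations were stored
assert recompute_activation(layers, ckpts, 3) == 7  # rebuilt, not cached
```

The memory/compute trade is visible here: half the activations are stored, and the backward pass pays for the missing ones by re-running the forward work between checkpoints.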
Training is typically ~20% slower because some activations are recomputed, but activation memory drops substantially, often enough to fit larger batches or longer sequences.
Set gradient_checkpointing=True in TrainingArguments to enable it.
Use with gradient accumulation to further reduce memory usage.
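A sketch of combining the two in TrainingArguments (the output path, accumulation steps, and batch size are illustrative values, not recommendations):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                # hypothetical output path
    gradient_checkpointing=True,     # recompute activations in the backward pass
    gradient_accumulation_steps=4,   # accumulate gradients over 4 micro-batches
    per_device_train_batch_size=2,   # smaller micro-batch fits in memory
)
```

The effective batch size stays at 8 (4 × 2) while both the stored activations and the per-step batch footprint shrink.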
```py
from transformers import TrainingArguments

args = TrainingArguments(
    ...,
    gradient_checkpointing=True,
)
```

Next steps
- Read the GPU memory usage doc to understand what is driving memory usage on the GPU during training.
- See the Mixed precision training guide to learn how to use lower precision data types to further reduce memory and speed up training.
- See the Kernels guide to learn how to speed up training with custom fused kernels.