Transformers documentation

Ulysses sequence parallelism



Ulysses sequence parallelism (SP) trains on very long sequences by splitting them across multiple GPUs. To compute attention correctly, an all-to-all collective swaps the sharding dimension from sequence to attention heads. Each GPU then has the full sequence and computes attention locally over a subset of heads. A second all-to-all returns to the sequence-sharded layout so the rest of the forward pass continues locally on each chunk.

                        GPU 0                       GPU 1
                   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”          β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
  forward          β”‚ tokens 0..N/2 β”‚          β”‚ tokens N/2..N β”‚  ← each GPU holds half the sequence
  (seq-sharded)    β”‚  all H heads  β”‚          β”‚  all H heads  β”‚
                   β””β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜          β””β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜
                           └───────── all-to-all β”€β”€β”€β”€β”€β”€β”˜
                   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”          β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
  attention        β”‚ all N tokens  β”‚          β”‚ all N tokens  β”‚  ← now each GPU has the full sequence
  (head-sharded)   β”‚ heads 0..H/2  β”‚          β”‚ heads H/2..H  β”‚  ← but only half the heads
                   β””β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜          β””β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜
                           └───────── all-to-all β”€β”€β”€β”€β”€β”€β”˜
                   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”          β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
  forward          β”‚ tokens 0..N/2 β”‚          β”‚ tokens N/2..N β”‚  ← back to seq-sharded
  (seq-sharded)    β”‚  all H heads  β”‚          β”‚  all H heads  β”‚
                   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜          β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
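
The two reshards in the diagram can be sketched on a single process with NumPy, tracking shapes only. This is an illustration of the layout change, not the real implementation, which uses `torch.distributed` all-to-all collectives; the sizes (2 ranks, 8 tokens, 4 heads, head dim 3) are arbitrary.

```python
import numpy as np

# Toy single-process simulation of the two Ulysses all-to-all reshards.
# A Python list stands in for the SP process group.
world, N, H, d = 2, 8, 4, 3
x = np.arange(N * H * d).reshape(N, H, d)

# seq-sharded: rank r holds a contiguous token chunk, all H heads
seq_shards = [x[r * N // world:(r + 1) * N // world] for r in range(world)]
assert seq_shards[0].shape == (N // world, H, d)

def seq_to_heads(shards, world, H):
    """First all-to-all: every rank gathers its head slice from all ranks."""
    return [
        np.concatenate([s[:, r * H // world:(r + 1) * H // world] for s in shards], axis=0)
        for r in range(world)
    ]

head_shards = seq_to_heads(seq_shards, world, H)
assert head_shards[0].shape == (N, H // world, d)  # full sequence, H/world heads

# ... attention runs locally here over the full sequence ...

def heads_to_seq(shards, world, N):
    """Second all-to-all: every rank gathers its token chunk from all ranks."""
    return [
        np.concatenate([s[r * N // world:(r + 1) * N // world] for s in shards], axis=1)
        for r in range(world)
    ]

back = heads_to_seq(head_shards, world, N)
# the round trip restores the original seq-sharded layout
assert all(np.array_equal(b, s) for b, s in zip(back, seq_shards))
```

The round trip is lossless: the first all-to-all only permutes data so each rank sees every token for its subset of heads, and the second permutation is its exact inverse.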

This guide covers the Ulysses sequence parallelism component of ALST (Arctic Long Sequence Training). The full ALST system also includes TiledMLP and activation checkpoint offloading, which aren’t available in Transformers. See the DeepSpeed ALST tutorial for the complete system.

Configure

Sequence parallelism requires Accelerate v1.12.0 and at least 2 GPUs. Configure sequence parallelism in Accelerate’s ParallelismConfig and pass it to TrainingArguments.parallelism_config or an Accelerate config file.

from accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig
from transformers import TrainingArguments

parallelism_config = ParallelismConfig(
    sp_backend="deepspeed",
    sp_size=4,
    dp_replicate_size=1,
    sp_handler=DeepSpeedSequenceParallelConfig(
        sp_seq_length_is_variable=True,
        sp_attn_implementation="flash_attention_2",
    ),
)

training_args = TrainingArguments(
    ...,
    deepspeed="path/to/deepspeed_config.json",
    parallelism_config=parallelism_config,
)

Run accelerate launch with a Trainer-based script.

accelerate launch --num_processes 4 train.py \
    --output_dir output_dir \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1

The following fields are important for configuring sequence parallelism.

The Trainer automatically handles DataLoader sharding, position_ids generation, label shifting, and loss aggregation across SP ranks. If you’re writing a custom training loop, see the Accelerate Sequence Parallelism guide instead.

  • sp_backend must be set to "deepspeed" to use Ulysses sequence parallelism.

  • sp_size is the number of GPUs that process a single sequence in parallel. Each SP rank receives a unique data stream from the DataLoader, unlike tensor parallelism where all ranks receive identical data. The effective dp_world_size = world_size / sp_size, so with 4 GPUs and sp_size=4, dp_world_size=1 for batch size calculations. Sequences must also be padded to a multiple of sp_size. Set pad_to_multiple_of in your data collator accordingly.

    The number of attention heads must be divisible by sp_size. A model with 32 heads supports sp_size of 1, 2, 4, 8, 16, or 32.

    from transformers import DataCollatorForLanguageModeling
    
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=sp_size,
    )
  • sp_seq_length_is_variable controls variable sequence length handling. Set it to True (recommended) for varying lengths between batches. Set it to False when all sequences pad to a fixed length specified by sp_seq_length.

  • sp_attn_implementation sets the attention backend. Supported values are "sdpa", "flash_attention_2", or "flash_attention_3". FlashAttention is recommended, especially when packing multiple samples in a batch. SDPA can attend incorrectly across sample boundaries when samples are packed. Eager attention isn’t supported because its 4D attention_mask is discarded for memory and scaling reasons.

Combining with data parallelism

Sequence parallelism and data parallelism use the same GPUs, and SP doesn’t require additional hardware. To run both, set dp_replicate_size or dp_shard_size so that dp_replicate_size Γ— dp_shard_size Γ— sp_size equals your total GPU count.

For example, with 8 GPUs and sp_size=4, set dp_replicate_size=2 (2 Γ— 1 Γ— 4 = 8).

parallelism_config = ParallelismConfig(
    sp_backend="deepspeed",
    sp_size=4,
    dp_replicate_size=2,
    sp_handler=DeepSpeedSequenceParallelConfig(
        sp_seq_length_is_variable=True,
        sp_attn_implementation="flash_attention_2",
    ),
)
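
The factorization constraint can be checked with a short helper that enumerates every valid layout for a given GPU count. This is a hypothetical sketch, not an Accelerate API:

```python
# Hypothetical helper: all (dp_replicate_size, dp_shard_size) pairs satisfying
# dp_replicate_size * dp_shard_size * sp_size == world_size.
def dp_layouts(world_size: int, sp_size: int) -> list[tuple[int, int]]:
    if world_size % sp_size:
        return []  # sp_size must divide the total GPU count
    dp = world_size // sp_size
    return [(r, dp // r) for r in range(1, dp + 1) if dp % r == 0]

print(dp_layouts(world_size=8, sp_size=4))  # [(1, 2), (2, 1)]
```

With 8 GPUs and `sp_size=4`, the two options are pure sharding (`dp_shard_size=2`) or pure replication (`dp_replicate_size=2`, as in the example above).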
