Transformers documentation
Tensor parallelism
Tensor parallelism (TP) splits weight matrices column-wise or row-wise across GPUs. Each GPU holds a shard, computes a partial result, and synchronizes with a collective (an all-gather for column shards, an all-reduce for row shards) to produce the full output.
TP relies on frequent cross-GPU communication. It works best on hardware with fast intra-node links such as NVLink.
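The two sharding schemes can be simulated on a single device with plain matrix slices. This is a sketch of the math only: real TP places each shard on its own GPU and replaces the concatenation and sum below with all-gather and all-reduce collectives.

```python
import torch

torch.manual_seed(0)
X = torch.randn(2, 6)  # input activation, replicated on every "GPU"
W = torch.randn(6, 4)  # full weight matrix

# Column-wise sharding: each shard holds a slice of W's columns.
# Partial outputs are concatenated (all-gather) to form Y.
col_shards = torch.chunk(W, 3, dim=1)
Y_col = torch.cat([X @ w for w in col_shards], dim=1)

# Row-wise sharding: each shard holds a slice of W's rows and consumes
# the matching slice of X. Partial outputs are summed (all-reduce).
row_shards = torch.chunk(W, 3, dim=0)
x_shards = torch.chunk(X, 3, dim=1)
Y_row = sum(x @ w for x, w in zip(x_shards, row_shards))

# Both schemes reproduce the unsharded matmul.
assert torch.allclose(Y_col, X @ W, atol=1e-5)
assert torch.allclose(Y_row, X @ W, atol=1e-5)
```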
┌─────────────────────────────────┐
│         X (replicated)          │
└─────┬───────────┬───────────┬───┘
      │           │           │
┌─────▼───┐ ┌─────▼───┐ ┌─────▼───┐
│   W₀    │ │   W₁    │ │   W₂    │
│  X@W₀   │ │  X@W₁   │ │  X@W₂   │
└─────┬───┘ └─────┬───┘ └─────┬───┘
      └───────────┼───────────┘
             Y₀+Y₁+Y₂
┌─────────────────────────────────┐
│            Y (full)             │
└─────────────────────────────────┘

Transformers supports TP for architectures whose config defines base_model_tp_plan. Check that field first to see whether a model supports native TP.
from transformers import AutoConfig
config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
print(config.base_model_tp_plan is not None)
print(config.base_model_tp_plan)

If a model supports TP, set tp_plan="auto" in from_pretrained(). Transformers initializes the device mesh and shards the supported layers for you.
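For reference, a TP plan is a dict mapping glob-style parameter-name patterns to partitioning strategies such as "colwise" and "rowwise". The excerpt and matcher below are an illustrative sketch, not the library's internal implementation:

```python
import re

# Hypothetical excerpt of a TP plan; real plans live in each model's config.
tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
}

def lookup_strategy(param_name, plan):
    """Return the strategy whose glob pattern matches param_name, else None."""
    for pattern, strategy in plan.items():
        # Escape literal dots, then let "*" stand for a layer index.
        regex = "^" + pattern.replace(".", r"\.").replace("*", r"\d+") + "$"
        if re.match(regex, param_name):
            return strategy
    return None  # unmatched parameters stay replicated

print(lookup_strategy("layers.3.self_attn.q_proj", tp_plan))  # colwise
```

Parameters with no matching pattern are simply replicated on every GPU, which is why only some layers of a model are sharded.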
Don't use device_map with tp_plan. The two conflict at the weight-loading level: device_map places whole modules on specific GPUs, while tp_plan shards those same parameters across all GPUs.
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B",
dtype=torch.bfloat16,
tp_plan="auto",
)

Trainer detects tp_plan, reads tp_size from the model, and creates a ParallelismConfig automatically.
Launch training on one node with 4 GPUs.
torchrun --nproc-per-node 4 train_tp.py
ParallelismConfig
Pass ParallelismConfig explicitly when combining TP with other parallelism techniques like FSDP.
import torch
from accelerate import ParallelismConfig
from transformers import AutoModelForCausalLM, TrainingArguments
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B",
dtype=torch.bfloat16,
tp_plan="auto",
)
parallelism_config = ParallelismConfig(tp_size=4)
args = TrainingArguments(
...,
parallelism_config=parallelism_config,
)

Next steps
- Read the Tensor Parallelism chapter from The Ultra-Scale Playbook for more details about how it works.
- Read the tensor parallelism inference guide to learn more about partitioning strategies, manual TP plans, and implementation details.