"""Minimal distributed training utilities."""

import os

import torch
import torch.distributed as dist


def setup_distributed(rank: int, world_size: int, backend: str = "nccl") -> None:
    """Initialize distributed process group."""
    if world_size <= 1:
        # Single-process run: nothing to initialize.
        return
    # Single-node defaults; multi-node launchers should set these explicitly.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        # Bind this process to its GPU so collectives target the right device.
        torch.cuda.set_device(rank)


def cleanup_distributed() -> None:
    """Destroy distributed process group."""
    if dist.is_initialized():
        dist.destroy_process_group()


def is_main_process() -> bool:
    """Check if this is the main (rank 0) process."""
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0
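

# Example usage: a minimal sketch, not part of the original module. Spawns one
# worker per available GPU (or a single CPU worker) via torch.multiprocessing;
# the training body is a placeholder, only the setup/cleanup calls are real.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    def _worker(rank: int, world_size: int) -> None:
        setup_distributed(rank, world_size)
        try:
            if is_main_process():
                print(f"running with world_size={world_size}")
            # ... training loop would go here ...
        finally:
            cleanup_distributed()

    nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
    mp.spawn(_worker, args=(nprocs,), nprocs=nprocs)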