"""Minimal distributed training utilities."""

import os

import torch
import torch.distributed as dist


def _local_device_index(rank: int) -> int:
    """Return the CUDA device index this process should bind to.

    Prefers the ``LOCAL_RANK`` environment variable (set by launchers such as
    ``torchrun``). Otherwise the global *rank* is wrapped onto the locally
    visible devices, so multi-node ranks never index past the device count.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        return int(local_rank)
    device_count = torch.cuda.device_count()
    return rank % device_count if device_count > 0 else rank


def setup_distributed(rank: int, world_size: int, backend: str = "nccl") -> None:
    """Initialize the default distributed process group.

    Does nothing for single-process runs (``world_size <= 1``). ``MASTER_ADDR``
    and ``MASTER_PORT`` default to ``localhost`` / ``29500`` only when they are
    not already set in the environment. When CUDA is available, the process is
    bound to its local GPU before the process group is created.

    Args:
        rank: Global rank of this process.
        world_size: Total number of participating processes.
        backend: ``torch.distributed`` backend name (e.g. ``"nccl"``, ``"gloo"``).
    """
    if world_size <= 1:
        return
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # Bind the GPU *before* creating the process group so NCCL work uses this
    # process's device. Use the local (per-node) index, not the global rank,
    # which can exceed the number of GPUs on a node in multi-node runs.
    if torch.cuda.is_available():
        torch.cuda.set_device(_local_device_index(rank))
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)


def cleanup_distributed() -> None:
    """Destroy the default process group if one was initialized; else no-op."""
    # is_available() guards torch builds compiled without distributed support.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def is_main_process() -> bool:
    """Return True if this is the rank-0 process.

    A process without an initialized process group (single-process run, or a
    torch build without distributed support) is treated as the main process.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0