| """Minimal distributed training utilities.""" | |
| import os | |
| import torch | |
| import torch.distributed as dist | |


def setup_distributed(rank: int, world_size: int, backend: str = "nccl") -> None:
    """Initialize the distributed process group.

    NCCL requires CUDA; pass backend="gloo" for CPU-only runs.
    """
    if world_size <= 1:
        # Single-process run: no process group needed.
        return
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        # Assumes one process per GPU on a single node, where the global
        # rank doubles as the local device index; multi-node setups should
        # bind to the local rank instead.
        torch.cuda.set_device(rank)


def cleanup_distributed() -> None:
    """Destroy the distributed process group if one is initialized."""
    if dist.is_initialized():
        dist.destroy_process_group()


def is_main_process() -> bool:
    """Check whether this is the main (rank 0) process."""
    if not dist.is_initialized():
        # Outside a process group (e.g. single-process runs), treat the
        # sole process as main.
        return True
    return dist.get_rank() == 0
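

# Usage sketch (illustrative, not part of this module's API): one way to
# launch the helpers above with torch.multiprocessing. `_demo_worker` is a
# hypothetical worker; real training code would build a model and wrap it
# in DistributedDataParallel after setup_distributed().
def _demo_worker(rank: int, world_size: int) -> None:
    # NCCL requires CUDA, so fall back to gloo on CPU-only machines.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    setup_distributed(rank, world_size, backend=backend)
    try:
        if is_main_process():
            print(f"initialized {world_size} process(es)")
    finally:
        cleanup_distributed()


if __name__ == "__main__":
    import torch.multiprocessing as mp

    # One process per GPU, or a single CPU process when no GPU is present.
    n = max(torch.cuda.device_count(), 1)
    mp.spawn(_demo_worker, args=(n,), nprocs=n, join=True)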