# TD3B / distributed_utils.py
# Source: chq1155's upload of the TD3B code (inference, training, baselines),
# commit ee6da62 (verified).
"""Minimal distributed training utilities."""
import os
import torch
import torch.distributed as dist
def setup_distributed(rank: int, world_size: int, backend: str = "nccl") -> None:
    """Join the default torch.distributed process group.

    No-op when ``world_size`` is 1 or less (single-process runs). Supplies
    rendezvous defaults (localhost:29500) only when the launcher has not
    already exported ``MASTER_ADDR`` / ``MASTER_PORT``.

    Args:
        rank: This process's global rank.
        world_size: Total number of participating processes.
        backend: Communication backend passed to ``init_process_group``.
    """
    if world_size <= 1:
        return
    # Keep any launcher-provided rendezvous settings; fill in defaults otherwise.
    for var, fallback in (("MASTER_ADDR", "localhost"), ("MASTER_PORT", "29500")):
        os.environ.setdefault(var, fallback)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        # NOTE(review): uses the *global* rank as the CUDA device index, which
        # assumes a single node with one GPU per process — confirm for
        # multi-node launches.
        torch.cuda.set_device(rank)
def cleanup_distributed() -> None:
    """Tear down the default process group, if one is active.

    Safe to call unconditionally: does nothing when distributed training
    was never initialized.
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def is_main_process() -> bool:
    """Return True for rank 0, or for any run without distributed init.

    Non-distributed (uninitialized) runs are treated as the main process so
    logging/checkpointing guards work in both modes.
    """
    return dist.get_rank() == 0 if dist.is_initialized() else True