import datetime
import os
import random

import numpy as np
import torch
import torch.distributed as dist
|
|
def setup_distributed():
    """Initialize torch.distributed for multi-GPU training.

    Rank / world-size / local-GPU information is read from either
    torchrun-style environment variables (RANK, WORLD_SIZE, LOCAL_RANK)
    or SLURM variables (SLURM_PROCID, SLURM_NTASKS).

    Returns:
        tuple: ``(is_distributed, rank, world_size, gpu)``. When no
        distributed environment is detected — or no CUDA device is
        available for the nccl backend — returns ``(False, 0, 1, 0)``
        and leaves the process in ordinary single-process mode.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched via torchrun / torch.distributed.launch.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # Launched via SLURM (srun).
        # NOTE(review): init_method='env://' below still requires
        # MASTER_ADDR / MASTER_PORT to be exported by the job script —
        # SLURM does not set them automatically; confirm against the
        # launch scripts.
        rank = int(os.environ['SLURM_PROCID'])
        # Guard against ZeroDivisionError on CPU-only nodes; the nccl
        # backend needs a GPU anyway, so fall back to single-process mode.
        if torch.cuda.device_count() == 0:
            print('Not using distributed mode')
            return False, 0, 1, 0
        gpu = rank % torch.cuda.device_count()
        world_size = int(os.environ['SLURM_NTASKS'])
    else:
        print('Not using distributed mode')
        return False, 0, 1, 0

    torch.cuda.set_device(gpu)
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank,
        # Generous timeout so slow rank start-up (e.g. dataset download
        # on rank 0) does not kill the rendezvous.
        timeout=datetime.timedelta(minutes=30),
    )
    # Synchronize all ranks before returning so no worker races ahead
    # while others are still initializing.
    dist.barrier()
    return True, rank, world_size, gpu
|
|
|
|
def cleanup_distributed():
    """Tear down the distributed process group, if one was created.

    Safe to call unconditionally: a no-op when torch.distributed was
    never initialized.
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
|
|
|
|
def set_seed(seed, rank=0):
    """Seed all RNGs (Python, NumPy, PyTorch) for reproducibility.

    The rank is added to the base seed so each distributed worker draws
    a different-but-deterministic random stream (e.g. for data
    augmentation), while full runs stay reproducible end to end.

    Args:
        seed (int): Base random seed shared across the job.
        rank (int): This worker's rank; defaults to 0 for
            single-process runs.
    """
    seed = seed + rank
    # Previously the stdlib RNG was left unseeded, so anything using
    # `random` (shuffles, augmentations) was nondeterministic.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # Seed every visible device, not just the current one.
        torch.cuda.manual_seed_all(seed)
|
|