import datetime
import functools
import math
import os
import sys
from typing import List
from typing import Union

import pytz
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp
from einops import repeat

__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__rank_str_zfill = '0'
__initialized = False


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
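
# Usage sketch for timestep_embedding (illustrative values only):
#
#     t = torch.tensor([0.0, 10.0, 250.0])     # one (possibly fractional) timestep per sample
#     emb = timestep_embedding(t, dim=128)      # -> [3, 128] tensor of [cos | sin] features
#
# With repeat_only=True the raw timesteps are simply repeated across the feature dimension
# to shape [N, dim] instead of being mapped to sinusoidal features.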
def initialized():
    return __initialized


def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return
    # then 'RANK' must exist
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    """
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    """
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))
    
    global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
    __device = torch.device(local_rank)
    __initialized = True
    
    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')


def get_rank():
    return __rank


def get_rank_given_group(group: tdist.ProcessGroup):
    return tdist.get_rank(group=group)


def get_rank_str_zfill():
    return __rank_str_zfill


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def set_gpu_id(gpu_id: int):
    if gpu_id is None:
        return
    global __device
    if isinstance(gpu_id, (str, int)):
        torch.cuda.set_device(int(gpu_id))
        __device = torch.empty(1).cuda().device
    else:
        raise NotImplementedError


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


def is_visualizer():
    return __rank == 0
    # return __rank == max(__world_size - 8, 0)


def parallelize(net, syncbn=False):
    if syncbn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    return net


def new_group(ranks: List[int]):
    if __initialized:
        return tdist.new_group(ranks=ranks)
    return None


def new_local_machine_group():
    if __initialized:
        cur_subgroup, subgroups = tdist.new_subgroups()
        return cur_subgroup
    return None


def barrier():
    if __initialized:
        tdist.barrier()


def allreduce(t: torch.Tensor, async_op=False):
    if __initialized:
        if not t.is_cuda:
            # NCCL only reduces GPU tensors, so stage CPU tensors through CUDA and copy back
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        
        # gather every rank's shape first, since all_gather requires identical shapes
        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)
        
        # pad the local tensor along dim 0 up to the largest batch size across ranks
        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            pad_size = (pad, *t.size()[1:])
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)
        
        # gather the padded tensors, then trim each one back to its true size
        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def broadcast(t: torch.Tensor, src_rank) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)


def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]
    
    # each rank writes its value into its own slot; the all-reduce (sum) fills in the rest
    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]


def master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


def local_master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper
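
# Usage sketch for the decorators above (illustrative only; `save_ckpt` is a hypothetical
# caller, not part of this module):
#
#     @master_only
#     def save_ckpt(state, path):
#         torch.save(state, path)
#
# Non-master ranks skip the wrapped body but still reach the trailing barrier(), so all ranks
# stay in lockstep; calling the wrapped function with force=True runs the body regardless of rank.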
def for_visualize(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_visualizer():
            # with torch.no_grad():
            ret = func(*args, **kwargs)
        else:
            ret = None
        return ret
    return wrapper


def finalize():
    if __initialized:
        tdist.destroy_process_group()


def init_distributed_mode(local_out_path, fork=False, only_sync_master=False, timeout_minutes=30):
    try:
        __initialize(fork=fork, timeout_minutes=timeout_minutes)
        barrier()
    except RuntimeError as e:
        print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
        raise e
    
    if local_out_path is not None:
        os.makedirs(local_out_path, exist_ok=True)
    _change_builtin_print(is_local_master())
    if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
        sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)


def _change_builtin_print(is_master):
    # patch builtins.print so that only the master rank (or a call with force=True) prints,
    # prefixed with a timestamp and the caller's file/line
    import builtins as __builtin__
    builtin_print = __builtin__.print
    if type(builtin_print) != type(open):
        # print is no longer the C builtin, i.e. it has already been patched
        return
    
    def prt(*args, **kwargs):
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)
        if is_master or force:
            if not clean:
                f_back = sys._getframe().f_back
                if deeper and f_back.f_back is not None:
                    f_back = f_back.f_back
                file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
                time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
                builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
            else:
                builtin_print(*args, **kwargs)
    
    __builtin__.print = prt


class BackupStreamToFile(object):
    """Tee stdout or stderr to a backup text file under local_output_dir."""
    def __init__(self, local_output_dir, for_stdout=True):
        self.for_stdout = for_stdout
        self.terminal_stream = sys.stdout if for_stdout else sys.stderr
        fname = os.path.join(local_output_dir, 'b1_stdout.txt' if for_stdout else 'b2_stderr.txt')
        existing = os.path.exists(fname)
        self.file_stream = open(fname, 'a')
        if existing:
            time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
            self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
        self.file_stream.flush()
        os.system(f'ln -s {fname} /opt/tiger/run_trial/ >/dev/null 2>&1')
        self.enabled = True
    
    def write(self, message):
        self.terminal_stream.write(message)
        self.file_stream.write(message)
    
    def flush(self):
        self.terminal_stream.flush()
        self.file_stream.flush()
    
    def isatty(self):
        return True
    
    def close(self):
        if not self.enabled:
            return
        self.enabled = False
        self.file_stream.flush()
        self.file_stream.close()
        if self.for_stdout:
            sys.stdout = self.terminal_stream
            sys.stdout.flush()
        else:
            sys.stderr = self.terminal_stream
            sys.stderr.flush()
    
    def __del__(self):
        self.close()
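
if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the library API: assumes launch via
    # `torchrun --nproc_per_node=N <this file>` (torchrun sets RANK / WORLD_SIZE /
    # MASTER_ADDR / MASTER_PORT); '/tmp/dist_demo' is an arbitrary placeholder output dir.
    init_distributed_mode(local_out_path='/tmp/dist_demo')
    print(f'world_size={get_world_size()}, device={get_device()}', force=True)
    # each rank contributes its own rank id; every rank sees all values after the all-reduce
    print('per-rank values:', dist_fmt_vals(float(get_rank()), fmt='%.1f'), force=True)
    finalize()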