| import datetime |
| import functools |
| import math |
| import os |
| import random |
| import subprocess |
| import sys |
| import threading |
| import time |
| from collections import defaultdict, deque |
| from typing import Iterator, List, Tuple |
|
|
| import numpy as np |
| import pytz |
| import torch |
| import torch.distributed as tdist |
| import torch.nn.functional as F |
|
|
| import infinity.utils.dist as dist |
|
|
os_system = functools.partial(subprocess.call, shell=True)


def echo(info):
    """Shell-echo *info*, prefixed with a timestamp and the caller's file/line.

    The prefix is rendered by the shell's own ``date`` command, so the message
    shows wall-clock time at the moment the subprocess runs.
    """
    caller = sys._getframe().f_back
    src_file = os.path.basename(caller.f_code.co_filename)
    os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({src_file}, line{caller.f_lineno})=> {info}"')
def os_system_get_stdout(cmd):
    """Run *cmd* through the shell and return its captured stdout as UTF-8 text."""
    completed = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE)
    return completed.stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
    """Run *cmd* through the shell and return ``(stdout, stderr)`` as UTF-8 text.

    Each attempt is capped at 30 seconds; a timed-out attempt is reported and
    the command is retried indefinitely until it completes.
    """
    timeouts = 0
    completed = None
    while completed is None:
        try:
            completed = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
        except subprocess.TimeoutExpired:
            timeouts += 1
            print(f'[fetch free_port file] timeout cnt={timeouts}')
    return completed.stdout.decode('utf-8'), completed.stderr.decode('utf-8')
|
|
|
|
def is_pow2n(x):
    """Return True iff *x* is a positive integer power of two."""
    if x <= 0:
        return False
    # A power of two has exactly one set bit, so x & (x-1) clears it to zero.
    return (x & (x - 1)) == 0
|
|
|
|
def time_str(fmt='[%m-%d %H:%M:%S]'):
    """Return the current Asia/Shanghai wall-clock time formatted by *fmt*.

    Uses the stdlib ``zoneinfo`` (Python 3.9+) instead of the third-party
    ``pytz`` module; for a fixed named zone passed to ``datetime.now(tz=...)``
    the two produce identical results, so this drops a dependency at no cost.
    """
    from zoneinfo import ZoneInfo  # local import: file-level import block untouched
    return datetime.datetime.now(tz=ZoneInfo('Asia/Shanghai')).strftime(fmt)
|
|
|
|
class DistLogger(object):
    """Proxy that forwards attribute access to a wrapped logger, or silently
    no-ops every call when the wrapped logger is None (e.g. non-master ranks)."""

    @staticmethod
    def do_nothing(*args, **kwargs):
        """Swallow any call made on a rank that has no real logger."""
        pass

    def __init__(self, lg):
        self._lg = lg

    def __getattr__(self, attr: str):
        # Only reached for names not found on the proxy itself.
        if self._lg is None:
            return DistLogger.do_nothing
        return getattr(self._lg, attr)
|
|
class TensorboardLogger(object):
    """Wrapper around torch's SummaryWriter that keeps a global step counter
    and throttles output: records only at step 0 and on every 500th step."""

    def __init__(self, log_dir, filename_suffix):
        # tensorflow_io is optional (it registers extra filesystem schemes for
        # remote log dirs); its absence is silently tolerated.
        try:
            import tensorflow_io as tfio
        except:
            pass
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
        self.step = 0

    def set_step(self, step=None):
        """Set the global step explicitly, or advance by one when *step* is None."""
        self.step = step if step is not None else self.step + 1

    def loggable(self):
        """Whether the current step is one of the throttled recording points."""
        return self.step == 0 or (self.step + 1) % 500 == 0

    def update(self, head='scalar', step=None, **kwargs):
        """Record every non-None kwarg as a scalar tagged '<head>/<name>'."""
        step = self.step if step is None else step
        if not self.loggable():
            return
        for name, val in kwargs.items():
            if val is None:
                continue
            if hasattr(val, 'item'):
                val = val.item()  # unwrap 0-dim tensors / numpy scalars
            self.writer.add_scalar(f'{head}/{name}', val, step)

    def log_tensor_as_distri(self, tag, tensor1d, step=None):
        """Record *tensor1d* as a histogram; failures are printed, never raised."""
        step = self.step if step is None else step
        if not self.loggable():
            return
        try:
            self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
        except Exception as e:
            print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')

    def log_image(self, tag, img_chw, step=None):
        """Record a channel-first (CHW) image under *tag*."""
        step = self.step if step is None else step
        if not self.loggable():
            return
        self.writer.add_image(tag, img_chw, step, dataformats='CHW')

    def flush(self):
        self.writer.flush()

    def close(self):
        self.writer.close()
|
|
|
|
class Low_GPU_usage(object):
    """No-op placeholder context manager.

    The GPU-throttling logic has been stripped; the class is kept so that call
    sites can still write ``with Low_GPU_usage(...)`` blocks unchanged.
    """

    def __init__(self, files, sleep_secs, verbose):
        # All constructor arguments are accepted and ignored.
        pass

    def early_stop(self):
        """Intentionally does nothing."""
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception propagate normally.
        pass
|
|
class TouchingDaemonDontForgetToStartMe(threading.Thread):
    """Daemon thread that periodically "touches" *files* (utime + open-append)
    so they look fresh, until ``finishing()`` is called. Must be started
    explicitly via ``.start()`` — the name is the reminder."""

    def __init__(self, files: List[str], sleep_secs: int, verbose=False):
        super().__init__(daemon=True)
        self.files = tuple(files)
        self.sleep_secs = sleep_secs
        self.is_finished = False
        self.verbose = verbose

        # Remember where the daemon was constructed, for log prefixes.
        caller = sys._getframe().f_back
        file_desc = f'{caller.f_code.co_filename:24s}'[-24:]  # clip/pad to 24 chars
        self.print_prefix = f' ({file_desc}, line{caller.f_lineno:-4d}) @daemon@ '

    def finishing(self):
        """Ask the run loop to stop after its current sleep."""
        self.is_finished = True

    def run(self) -> None:
        # NOTE(review): 'clean' is not a builtin print kwarg — this presumably
        # relies on the project's dist setup monkey-patching print; confirm.
        extra_kw = {}
        if tdist.is_initialized(): extra_kw['clean'] = True

        started_at = time.time()
        if self.verbose:
            print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] start touching {self.files} per {self.sleep_secs}s ...', **extra_kw)
        while not self.is_finished:
            for fname in self.files:
                if os.path.exists(fname):
                    try:
                        # Refresh mtime, then open-append as a fallback touch.
                        os.utime(fname)
                        open(fname, 'a').close()
                    except: pass
            time.sleep(self.sleep_secs)

        if self.verbose:
            print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] finish touching after {time.time()-started_at:.1f} secs {self.files} per {self.sleep_secs}s. ', **extra_kw)
|
|
|
|
class SmoothedValue(object):
    """Track a stream of scalar values, exposing window-smoothed statistics
    (median/avg/max/last over the most recent `window_size` values) alongside
    the exact global average over everything ever recorded."""

    def __init__(self, window_size=30, fmt=None):
        self.deque = deque(maxlen=window_size)  # most recent raw values
        self.total = 0.0                        # running sum of all values (weighted)
        self.count = 0                          # total weight of all updates
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record *value*, counting it *n* times toward the global statistics."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        tdist.barrier()
        tdist.all_reduce(stats)
        count_sum, total_sum = stats.tolist()
        self.count = int(count_sum)
        self.total = total_sum

    @property
    def median(self):
        return np.median(self.deque) if self.deque else 0

    @property
    def avg(self):
        denom = len(self.deque) or 1  # guard against an empty window
        return sum(self.deque) / denom

    @property
    def global_avg(self):
        denom = self.count or 1  # guard against no updates yet
        return self.total / denom

    @property
    def max(self):
        return max(self.deque) if self.deque else 0

    @property
    def value(self):
        return self.deque[-1] if self.deque else 0

    def time_preds(self, counts) -> Tuple[float, str, str]:
        """Estimate time for *counts* more iterations from the median duration.

        Returns (remaining_seconds, human-readable duration, projected finish
        timestamp formatted as 'YYYY-MM-DD HH:MM' in local time).
        """
        remain_secs = counts * self.median
        eta_stamp = time.localtime(time.time() + remain_secs)
        return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", eta_stamp)

    def __str__(self):
        return self.fmt.format(median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value)
|
|
|
|
class MetricLogger(object):
    """Aggregates named SmoothedValue meters and paces training-loop console
    logging via the `log_every` generator (progress line, ETA, iter/data time)."""

    def __init__(self):
        self.meters = defaultdict(SmoothedValue)  # metric name -> SmoothedValue
        self.iter_end_t = time.time()             # end timestamp of previous iteration
        self.log_iters = set()                    # iteration indices at which to print
        self.log_every_iter = False               # when True, print on every iteration

    def update(self, **kwargs):
        """Feed scalar values into the named meters.

        None values are skipped; anything with .item() (tensors, numpy
        scalars) is unwrapped before being recorded.
        """
        for k, v in kwargs.items():
            if v is None: continue
            if hasattr(v, 'item'): v = v.item()

            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Enables `logger.loss`-style access to meters; only reached when the
        # attribute is not found through normal lookup.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        # Only meters that have received at least one value are rendered.
        loss_str = []
        for name, meter in self.meters.items():
            if len(meter.deque):
                loss_str.append(
                    "{}: {}".format(name, str(meter))
                )
        return ' '.join(loss_str)

    def synchronize_between_processes(self):
        """All-reduce each meter's count/total across ranks (deques stay local)."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a pre-configured meter (e.g. with a custom fmt) under *name*."""
        self.meters[name] = meter

    def log_every(self, start_it, max_iters, itrt, log_freq, log_every_iter=False, header=''):
        """Yield (iteration, item) over *itrt* while timing and printing progress.

        Prints at every `log_freq`-th iteration (plus first/last), or at every
        iteration when *log_every_iter* is set. `start_it` is taken modulo
        `max_iters` so a global resume step maps into this epoch.
        """
        start_it = start_it % max_iters
        self.log_iters = set(range(start_it, max_iters, log_freq))
        self.log_iters.add(start_it)
        self.log_iters.add(max_iters-1)
        self.log_iters.add(max_iters)
        self.log_every_iter = log_every_iter
        self.iter_end_t = time.time()
        self.iter_time = SmoothedValue(fmt='{value:.4f}')  # full iteration duration
        self.data_time = SmoothedValue(fmt='{value:.3f}')  # data-loading portion
        header_fmt = header + ': [{0:' + str(len(str(max_iters))) + 'd}/{1}]'

        start_time = time.time()
        # Plain iterators (no prefetcher 'preload', no sampler 'set_epoch') are
        # advanced manually with next(); others go through enumerate below.
        if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
            for it in range(start_it, max_iters):
                obj = next(itrt)
                # NOTE(review): unreachable — range starts at start_it, so
                # `it < start_it` can never hold here.
                if it < start_it: continue
                self.data_time.update(time.time() - self.iter_end_t)
                yield it, obj
                self.iter_time.update(time.time() - self.iter_end_t)
                if self.log_every_iter or it in self.log_iters:
                    eta_seconds = self.iter_time.avg * (max_iters - it)
                    print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)

                self.iter_end_t = time.time()
        else:
            if isinstance(itrt, int): itrt = range(itrt)
            for it, obj in enumerate(itrt):
                # Skip already-completed iterations when resuming mid-epoch.
                if it < start_it:
                    self.iter_end_t = time.time()
                    continue
                self.data_time.update(time.time() - self.iter_end_t)
                yield it, obj
                self.iter_time.update(time.time() - self.iter_end_t)
                if self.log_every_iter or it in self.log_iters:
                    eta_seconds = self.iter_time.avg * (max_iters - it)
                    print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)
                self.iter_end_t = time.time()
        cost = time.time() - start_time
        cost_str = str(datetime.timedelta(seconds=int(cost)))
        print(f'{header} Cost of this ep: {cost_str} ({cost / (max_iters-start_it):.3f} s / it)', flush=True)
|
|
|
|
class NullDDP(torch.nn.Module):
    """Single-process stand-in for DistributedDataParallel: wraps *module* and
    forwards calls straight through, exposing the DDP attribute that training
    code toggles during gradient accumulation."""

    def __init__(self, module, *args, **kwargs):
        # Extra positional/keyword args mimic the DDP signature and are ignored.
        super().__init__()
        self.module = module
        self.require_backward_grad_sync = False

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
|
|
|
|
def build_2d_sincos_position_embedding(h, w, embed_dim, temperature=10000., sc=0, verbose=True):
    """Build a fixed 2D sin-cos positional embedding of shape (1, h*w, embed_dim).

    Each position's vector is the concatenation of [sin(w-coord * freqs),
    cos(w-coord * freqs), sin(h-coord * freqs), cos(h-coord * freqs)], each
    chunk of size embed_dim // 4, with a geometric frequency progression.

    Args:
        h, w: grid height and width.
        embed_dim: embedding channels; must be divisible by 4.
        temperature: base of the geometric frequency progression.
        sc: coordinate scaling mode — 0: raw indices, 1: 2*pi/w, else 1/w.
        verbose: print a one-line summary of the configuration.

    Returns:
        Float32 tensor of shape (1, h*w, embed_dim).
    """
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid([grid_w, grid_h], indexing='ij')
    if sc == 0:
        scale = 1
    elif sc == 1:
        scale = math.pi * 2 / w
    else:
        scale = 1 / w
    grid_w = scale * grid_w.reshape(h*w, 1)
    grid_h = scale * grid_h.reshape(h*w, 1)

    assert embed_dim % 4 == 0, f'Embed dimension ({embed_dim}) must be divisible by 4 for 2D sin-cos position embedding!'
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = (-math.log(temperature) * omega).exp()

    out_w = grid_w * omega.view(1, pos_dim)
    out_h = grid_h * omega.view(1, pos_dim)
    pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
    # BUGFIX: this log line referenced the undefined name `hw`, raising a
    # NameError whenever verbose=True (the default); report actual h x w.
    if verbose: print(f'[build_2d_sincos_position_embedding @ {h} x {w}] scale_type={sc}, temperature={temperature:g}, shape={pos_emb.shape}')
    return pos_emb
|
|
|
|
if __name__ == '__main__':
    # Visual sanity check: render the sin-cos embeddings for several scale
    # modes / temperatures and measure cross-resolution interpolation error.
    import seaborn as sns
    import matplotlib.pyplot as plt
    cmap_div = sns.color_palette('icefire', as_cmap=True)

    scs = [0, 1, 2]
    temps = [20, 50, 100, 1000]
    reso = 3.0
    RR, CC = len(scs), len(temps)
    plt.figure(figsize=(CC * reso, RR * reso))
    for row, sc in enumerate(scs):
        for col, temp in enumerate(temps):
            name = f'sc={sc}, T={temp}'
            hw, C = 16, 512
            N = hw*hw
            # BUGFIX: build_2d_sincos_position_embedding takes (h, w, embed_dim, ...);
            # the old call passed only (hw, C), raising TypeError for the missing
            # embed_dim. Pass the square grid explicitly as (hw, hw, C).
            pe = build_2d_sincos_position_embedding(hw, hw, C, temperature=temp, sc=sc, verbose=False)[0]

            hw2 = 16
            N2 = hw2*hw2
            pe2 = build_2d_sincos_position_embedding(hw2, hw2, C, temperature=temp, sc=sc, verbose=False)[0]

            # Compare the two embeddings after channel-wise L2 normalization:
            # how well does interpolation map one grid onto the other?
            bchw, bchw2 = F.normalize(pe.view(hw, hw, C).permute(2, 0, 1).unsqueeze(0), dim=1), F.normalize(pe2.view(hw2, hw2, C).permute(2, 0, 1).unsqueeze(0), dim=1)
            dis = [
                f'{F.mse_loss(bchw, F.interpolate(bchw2, size=bchw.shape[-2], mode=inter)).item():.3f}'
                for inter in ('bilinear', 'bicubic', 'nearest')
            ]
            dis += [
                f'{F.mse_loss(F.interpolate(bchw, size=bchw2.shape[-2], mode=inter), bchw2).item():.3f}'
                for inter in ('area', 'nearest')
            ]
            print(f'[{name:^20s}] dis: {dis}')
            """
            [      sc=0, T=20      ] dis: ['0.010', '0.011', '0.011', '0.009', '0.010']
            [     sc=0, T=100      ] dis: ['0.007', '0.007', '0.007', '0.006', '0.007']
            [     sc=0, T=1000     ] dis: ['0.005', '0.005', '0.005', '0.004', '0.005']
            [    sc=0, T=10000     ] dis: ['0.004', '0.004', '0.004', '0.003', '0.004']
            [      sc=1, T=20      ] dis: ['0.007', '0.008', '0.008', '0.007', '0.008']
            [     sc=1, T=100      ] dis: ['0.005', '0.005', '0.005', '0.005', '0.005']
            [     sc=1, T=1000     ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
            [    sc=1, T=10000     ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
            [      sc=2, T=20      ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
            [     sc=2, T=100      ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
            [     sc=2, T=1000     ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
            [    sc=2, T=10000     ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
            Process finished with exit code 0
            """

            # Render the (C x N) embedding matrix as a color image, nearest-
            # upsampled to roughly 1024x1024 for visibility.
            pe = torch.from_numpy(cmap_div(pe.T.numpy())[:, :, :3])
            tar_h, tar_w = 1024, 1024
            pe = pe.repeat_interleave(tar_w//pe.shape[0], dim=0).repeat_interleave(tar_h//pe.shape[1], dim=1)
            plt.subplot(RR, CC, 1+row*CC+col)
            plt.title(name)
            plt.xlabel('hxw'), plt.ylabel('C')
            plt.xticks([]), plt.yticks([])
            plt.imshow(pe.mul(255).round().clamp(0, 255).byte().numpy())
    plt.tight_layout(h_pad=0.02)
    plt.show()
|
|
|
|
def check_randomness(args):
    """Assert that all ranks draw identical pseudo-random numbers.

    Samples one value each from Python's `random`, numpy, `torch.randint`, and
    an in-place `Tensor.random_`, all-reduces the per-rank samples, and asserts
    adjacent ranks agree — verifying that seeding is synchronized.
    """
    U = 16384
    rank = dist.get_rank()
    samples = torch.zeros(dist.get_world_size(), 4, dtype=torch.float32, device=args.device)
    inplace_draw = torch.zeros(1, dtype=torch.float32, device=args.device).random_(U)
    samples[rank, 0] = float(random.randrange(U))
    samples[rank, 1] = float(np.random.randint(U))
    samples[rank, 2] = float(torch.randint(0, U, (1,))[0])
    samples[rank, 3] = float(inplace_draw[0])
    dist.allreduce(samples)
    for rk in range(1, dist.get_world_size()):
        assert torch.allclose(samples[rk - 1], samples[rk]), f't={samples}'
    del inplace_draw, samples, U
|
|