| |
| import os |
| import json |
| import tempfile |
| import numpy as np |
| import torch |
| import time |
| import subprocess |
| import torch.distributed as dist |
|
|
|
|
def allreduce(x, average):
    """Sum-reduce tensor `x` in place across all ranks.

    When `average` is true the summed tensor is divided by the world size,
    yielding the mean over ranks; otherwise the raw sum is returned.
    No-op (beyond the optional divide) in single-process runs.
    """
    world = mpi_size()
    if world > 1:
        dist.all_reduce(x, dist.ReduceOp.SUM)
    if average:
        return x / world
    return x
|
|
|
|
def get_cpu_stats_over_ranks(stat_dict):
    """Average a dict of scalar stats across ranks.

    Values are moved to GPU, stacked in sorted-key order, mean-reduced via
    allreduce, and returned as a plain {key: float} dict on CPU.
    """
    keys = sorted(stat_dict)
    stacked = torch.stack(
        [torch.as_tensor(stat_dict[key]).detach().cuda().float() for key in keys])
    averaged = allreduce(stacked, average=True).cpu()
    return {key: averaged[idx].item() for idx, key in enumerate(keys)}
|
|
|
|
class Hyperparams(dict):
    """A dict whose entries are also readable/writable as attributes.

    Missing attributes resolve to None instead of raising, so optional
    hyperparameters can be probed with `h.some_flag` directly.
    """

    def __getattr__(self, attr):
        # dict.get already returns None for absent keys, matching the
        # original try/except-KeyError behavior.
        return self.get(attr)

    def __setattr__(self, attr, value):
        self[attr] = value
|
|
|
|
def logger(log_prefix):
    """Build a rank-0-only log function that mirrors each call to stdout,
    to `<log_prefix>.txt`, and (as one JSON object per line) to
    `<log_prefix>.jsonl`."""
    jsonl_path = f'{log_prefix}.jsonl'
    txt_path = f'{log_prefix}.txt'

    def _to_builtin(value):
        # json.dumps cannot serialize numpy types; coerce them to builtins.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            return float(value)
        return value

    def log(*args, pprint=False, **kwargs):
        # Only the lead rank writes; other ranks return immediately.
        if mpi_rank() != 0:
            return
        argdict = {'time': time.ctime()}
        if args:
            argdict['message'] = ' '.join([str(x) for x in args])
        argdict.update(kwargs)

        pieces = []
        key_order = sorted(argdict) if pprint else list(argdict)
        for key in key_order:
            value = _to_builtin(argdict[key])
            argdict[key] = value
            shown = f'{value:.5f}' if isinstance(value, float) else value
            pieces.append(f'{key}: {shown}')
        txt_str = ', '.join(pieces)

        if pprint:
            # Pretty mode: both outputs are sorted JSON; txt gets indentation.
            json_str = json.dumps(argdict, sort_keys=True)
            txt_str = json.dumps(argdict, sort_keys=True, indent=4)
        else:
            json_str = json.dumps(argdict)

        print(txt_str, flush=True)

        with open(txt_path, "a+") as f:
            print(txt_str, file=f, flush=True)
        with open(jsonl_path, "a+") as f:
            print(json_str, file=f, flush=True)

    return log
|
|
|
|
def maybe_download(path, filename=None):
    '''If a path is a gsutil path, download it and return the local link,
    otherwise return link.

    path: local path or a `gs://` URL.
    filename: when given, the download is cached at /tmp/<filename> and
        reused on later calls; otherwise a fresh temp file is used.
    Raises subprocess.CalledProcessError if gsutil fails.
    '''
    if not path.startswith('gs://'):
        return path
    if filename:
        # BUG FIX: this previously interpolated the literal placeholder
        # '(unknown)' instead of the requested filename.
        out_path = f'/tmp/{filename}'
        if os.path.isfile(out_path):
            return out_path  # cached from an earlier call
        subprocess.check_output(['gsutil', '-m', 'cp', '-R', path, out_path])
        return out_path
    else:
        # mkstemp returns an OPEN file descriptor; close it so it doesn't leak
        # while gsutil overwrites the file.
        fd, local_dest = tempfile.mkstemp()
        os.close(fd)
        subprocess.check_output(['gsutil', '-m', 'cp', path, local_dest])
        return local_dest
|
|
|
|
def tile_images(images, d1=4, d2=4, border=1):
    """Arrange exactly d1*d2 same-shaped images into one grid image.

    images: sequence of d1*d2 uint8 arrays, each of shape (h, w, c).
    d1, d2: grid rows and columns. border: white-border width (pixels)
    around every cell.
    Returns a single uint8 array; border pixels are 255.
    Raises ValueError when len(images) != d1 * d2.
    """
    # Validate BEFORE touching images[0] or allocating: the original raised
    # IndexError on an empty list and allocated the canvas before checking.
    if len(images) != d1 * d2:
        raise ValueError('Wrong num of images')
    id1, id2, c = images[0].shape
    # White canvas: grid cells plus a border around each cell.
    out = np.full([d1 * id1 + border * (d1 + 1),
                   d2 * id2 + border * (d2 + 1),
                   c], 255, dtype=np.uint8)
    for imgnum, im in enumerate(images):
        num_d1, num_d2 = divmod(imgnum, d2)
        start_d1 = num_d1 * id1 + border * (num_d1 + 1)
        start_d2 = num_d2 * id2 + border * (num_d2 + 1)
        out[start_d1:start_d1 + id1, start_d2:start_d2 + id2, :] = im
    return out
|
|
|
|
def mpi_size():
    """Return the number of participating processes (world size).

    BUG FIX: the original body called `MPI.COMM_WORLD.Get_size()`, but `MPI`
    is never imported in this file, so every call raised NameError. Use
    torch.distributed instead — it is already imported here and is what
    allreduce() uses — falling back to 1 when no process group is running.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
|
|
|
def mpi_rank():
    """Return this process's global rank.

    BUG FIX: the original body called `MPI.COMM_WORLD.Get_rank()`, but `MPI`
    is never imported in this file, so every call raised NameError. Use
    torch.distributed (already imported and used by this module), falling
    back to rank 0 when no process group is running.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
|
|
def num_nodes():
    """Number of 8-GPU nodes implied by the world size (ceiling division)."""
    # -(-n // 8) is ceil(n / 8) for the positive world sizes seen here,
    # matching the original if/else on n % 8.
    return -(-mpi_size() // 8)
|
|
|
|
def gpus_per_node():
    """GPUs used on each node; 1 for single-process runs."""
    world = mpi_size()
    if world <= 1:
        return 1
    return max(world // num_nodes(), 1)
|
|
|
|
def local_mpi_rank():
    """This process's index within its own node (0-based)."""
    global_rank = mpi_rank()
    return global_rank % gpus_per_node()
|
|