import torch
import numpy as np
from torch.optim import AdamW

import socket
import argparse
import os
import json
import subprocess
from mpi4py import MPI
from hps import Hyperparams, parse_args_and_update_hparams, add_vae_arguments
from utils import (logger,
                   local_mpi_rank,
                   mpi_size,
                   maybe_download,
                   mpi_rank)
from data import mkdir_p
from contextlib import contextmanager
import torch.distributed as dist

from vae import VAE
from torch.nn.parallel.distributed import DistributedDataParallel


def update_ema(vae, ema_vae, ema_rate):
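    """Update the EMA copy of the model in place: ema = ema_rate * ema + (1 - ema_rate) * current."""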
    for p1, p2 in zip(vae.parameters(), ema_vae.parameters()):
        p2.data.mul_(ema_rate)
        p2.data.add_(p1.data * (1 - ema_rate))


def save_model(path, vae, ema_vae, optimizer, H):
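    """Save the model, EMA model, and optimizer state dicts under `path`, and copy the training log next to them."""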
    torch.save(vae.state_dict(), f'{path}-model.th')
    torch.save(ema_vae.state_dict(), f'{path}-model-ema.th')
    torch.save(optimizer.state_dict(), f'{path}-opt.th')
    from_log = os.path.join(H.save_dir, 'log.jsonl')
    to_log = f'{os.path.dirname(path)}/{os.path.basename(path)}-log.jsonl'
    subprocess.check_output(['cp', from_log, to_log])


def accumulate_stats(stats, frequency):
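    """Aggregate the last `frequency` iteration stats: NaN/skip counters are summed, grad_norm takes the max
    finite value, elbo is reported both raw and filtered to finite values, and all other keys are averaged."""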
    z = {}
    for k in stats[-1]:
        if k in ['distortion_nans', 'rate_nans', 'skipped_updates', 'gcskip']:
            z[k] = np.sum([a[k] for a in stats[-frequency:]])
        elif k == 'grad_norm':
            vals = [a[k] for a in stats[-frequency:]]
            finites = np.array(vals)[np.isfinite(vals)]
            if len(finites) == 0:
                z[k] = 0.0
            else:
                z[k] = np.max(finites)
        elif k == 'elbo':
            vals = [a[k] for a in stats[-frequency:]]
            finites = np.array(vals)[np.isfinite(vals)]
            z['elbo'] = np.mean(vals)
            z['elbo_filtered'] = np.mean(finites)
        elif k == 'iter_time':
            z[k] = stats[-1][k] if len(stats) < frequency else np.mean([a[k] for a in stats[-frequency:]])
        else:
            z[k] = np.mean([a[k] for a in stats[-frequency:]])
    return z


def linear_warmup(warmup_iters):
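    """Return an LR multiplier that ramps linearly from 0 to 1 over `warmup_iters` iterations, then stays at 1."""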
    def f(iteration):
        return 1.0 if iteration > warmup_iters else iteration / warmup_iters
    return f


def setup_mpi(H):
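    """Fill in the torch.distributed environment variables from MPI, broadcast rank 0's hostname as the master
    address, pin the local GPU, and initialize the NCCL process group."""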
    H.mpi_size = mpi_size()
    H.local_rank = local_mpi_rank()
    H.rank = mpi_rank()
    os.environ["RANK"] = str(H.rank)
    os.environ["WORLD_SIZE"] = str(H.mpi_size)
    os.environ["MASTER_PORT"] = str(H.port)
    os.environ["MASTER_ADDR"] = MPI.COMM_WORLD.bcast(socket.gethostname(), root=0)
    torch.cuda.set_device(H.local_rank)
    dist.init_process_group(backend='nccl', init_method="env://")


def distributed_maybe_download(path, local_rank, mpi_size):
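    """Return local paths unchanged; download gs:// paths to a local file, letting the first rank on each node
    download before the other ranks read it."""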
    if not path.startswith('gs://'):
        return path
    filename = path[5:].replace('/', '-')
    with first_rank_first(local_rank, mpi_size):
        fp = maybe_download(path, filename)
    return fp


@contextmanager
def first_rank_first(local_rank, mpi_size):
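    """Context manager that lets local rank 0 run the wrapped block first; the other ranks wait at a barrier
    that rank 0 only joins once it has finished."""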
    if mpi_size > 1 and local_rank > 0:
        dist.barrier()

    try:
        yield
    finally:
        if mpi_size > 1 and local_rank == 0:
            dist.barrier()


def setup_save_dirs(H):
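    """Create the run-specific save directory (save_dir/desc) and record the log directory on H."""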
    H.save_dir = os.path.join(H.save_dir, H.desc)
    mkdir_p(H.save_dir)
    H.logdir = os.path.join(H.save_dir, 'log')


def set_up_hyperparams(s=None):
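    """Parse command-line (or string `s`) arguments into an H hyperparameter object, set up MPI/distributed state
    and save directories, seed the RNGs, log all hyperparameters, and return H with a logging function."""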
    H = Hyperparams()
    parser = argparse.ArgumentParser()
    parser = add_vae_arguments(parser)
    parse_args_and_update_hparams(H, parser, s=s)
    setup_mpi(H)
    setup_save_dirs(H)
    logprint = logger(H.logdir)
    for i, k in enumerate(sorted(H)):
        logprint(type='hparam', key=k, value=H[k])
    np.random.seed(H.seed)
    torch.manual_seed(H.seed)
    torch.cuda.manual_seed(H.seed)
    logprint('training model', H.desc, 'on', H.dataset)
    return H, logprint


def restore_params(model, path, local_rank, mpi_size, map_ddp=True, map_cpu=False):
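    """Load a checkpoint (downloading it first if it lives on GCS) into `model`, stripping any 'module.' prefixes
    left by DistributedDataParallel."""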
    state_dict = torch.load(distributed_maybe_download(path, local_rank, mpi_size), map_location='cpu' if map_cpu else None)
    if map_ddp:
        new_state_dict = {}
        prefix_len = len('module.')
        for k in state_dict:
            if k.startswith('module.'):
                new_state_dict[k[prefix_len:]] = state_dict[k]
            else:
                new_state_dict[k] = state_dict[k]
        state_dict = new_state_dict
    model.load_state_dict(state_dict)


def restore_log(path, local_rank, mpi_size):
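    """Read a JSONL training log and recover the best eval ELBO, the last training step, and the last epoch, so a
    run can resume where it left off."""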
    with open(distributed_maybe_download(path, local_rank, mpi_size)) as f:
        loaded = [json.loads(line) for line in f]
    try:
        cur_eval_loss = min([z['elbo'] for z in loaded if 'type' in z and z['type'] == 'eval_loss'])
    except ValueError:
        cur_eval_loss = float('inf')
    starting_epoch = max([z['epoch'] for z in loaded if 'type' in z and z['type'] == 'train_loss'])
    iterate = max([z['step'] for z in loaded if 'type' in z and z['type'] == 'train_loss'])
    return cur_eval_loss, iterate, starting_epoch


def load_vaes(H, logprint):
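    """Build the VAE and its EMA copy (optionally restoring both from checkpoints), move them to the local GPU,
    wrap the trainable model in DistributedDataParallel, and log the total parameter count."""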
    vae = VAE(H)
    if H.restore_path:
        logprint(f'Restoring vae from {H.restore_path}')
        restore_params(vae, H.restore_path, map_cpu=True, local_rank=H.local_rank, mpi_size=H.mpi_size)

    ema_vae = VAE(H)
    if H.restore_ema_path:
        logprint(f'Restoring ema vae from {H.restore_ema_path}')
        restore_params(ema_vae, H.restore_ema_path, map_cpu=True, local_rank=H.local_rank, mpi_size=H.mpi_size)
    else:
        ema_vae.load_state_dict(vae.state_dict())
    ema_vae.requires_grad_(False)

    vae = vae.cuda(H.local_rank)
    ema_vae = ema_vae.cuda(H.local_rank)

    vae = DistributedDataParallel(vae, device_ids=[H.local_rank], output_device=H.local_rank)

    if len(list(vae.named_parameters())) != len(list(vae.parameters())):
        raise ValueError('Some params are not named. Please name all params.')
    total_params = 0
    for name, p in vae.named_parameters():
        total_params += np.prod(p.shape)
    logprint(total_params=total_params, readable=f'{total_params:,}')
    return vae, ema_vae


def load_opt(H, vae, logprint):
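    """Create the AdamW optimizer and linear-warmup LR scheduler, optionally restoring optimizer state and the
    previous log, and return them together with the resume point (eval loss, iterate, epoch)."""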
    optimizer = AdamW(vae.parameters(), weight_decay=H.wd, lr=H.lr, betas=(H.adam_beta1, H.adam_beta2))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_warmup(H.warmup_iters))

    if H.restore_optimizer_path:
        optimizer.load_state_dict(
            torch.load(distributed_maybe_download(H.restore_optimizer_path, H.local_rank, H.mpi_size), map_location='cpu'))
    if H.restore_log_path:
        cur_eval_loss, iterate, starting_epoch = restore_log(H.restore_log_path, H.local_rank, H.mpi_size)
    else:
        cur_eval_loss, iterate, starting_epoch = float('inf'), 0, 0
    logprint('starting at epoch', starting_epoch, 'iterate', iterate, 'eval loss', cur_eval_loss)
    return optimizer, scheduler, cur_eval_loss, iterate, starting_epoch