import copy
import json
import os
import re
from contextlib import nullcontext

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch import nn
from torch.cuda import amp
from torch.distributions import Categorical
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import ConcatDataset, TensorDataset
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from config import *
from dataset import *
from dist_utils import *
from lmdb_writer import *
from metrics import *
from renderer import *


class LitModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig):
        super().__init__()
        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())

        self.conf = conf

        self.model = conf.make_model_conf().make_model()
        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        model_size = 0
        for param in self.model.parameters():
            model_size += param.data.nelement()
        # report the parameter count in millions
        print('Model params: %.2f M' % (model_size / 1e6))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        # this T sampler is shared by both the image diffusion and the latent net
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf(
            ).make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
            ).make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        # fixed noise for consistent sampling across evaluations
        self.register_buffer(
            'x_T',
            torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size))

        if conf.pretrain is not None:
            print(f'loading pretrain ... {conf.pretrain.name}')
            state = torch.load(conf.pretrain.path, map_location='cpu')
            print('step:', state['global_step'])
            self.load_state_dict(state['state_dict'], strict=False)

        if conf.latent_infer_path is not None:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
            self.register_buffer('conds_mean', state['conds_mean'][None, :])
            self.register_buffer('conds_std', state['conds_std'][None, :])
        else:
            # `conds` is inferred lazily in train_dataloader when needed
            self.conds = None
            self.conds_mean = None
            self.conds_std = None
|
    def normalize(self, cond):
        # z-normalize the semantic latent with the dataset statistics
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
            self.device)
        return cond

    def denormalize(self, cond):
        # invert `normalize`
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
            self.device)
        return cond
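
    # Round-trip sketch (assumes conds_mean/conds_std were loaded from
    # `latent_infer_path`; `z` here stands for a batch of semantic latents):
    #
    #   z_norm = self.normalize(z)   # zero mean, unit std per dimension
    #   assert torch.allclose(self.denormalize(z_norm), z, atol=1e-5)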
|
|
    def sample(self, N, device, T=None, T_latent=None):
        if T is None:
            sampler = self.eval_sampler
            latent_sampler = self.latent_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
            # fall back to the default latent sampler when T_latent is not given
            if T_latent is not None:
                latent_sampler = self.conf._make_latent_diffusion_conf(
                    T_latent).make_sampler()
            else:
                latent_sampler = self.latent_sampler

        noise = torch.randn(N,
                            3,
                            self.conf.img_size,
                            self.conf.img_size,
                            device=device)
        pred_img = render_uncondition(
            self.conf,
            self.ema_model,
            noise,
            sampler=sampler,
            latent_sampler=latent_sampler,
            conds_mean=self.conds_mean,
            conds_std=self.conds_std,
        )
        # map from [-1, 1] to [0, 1]
        pred_img = (pred_img + 1) / 2
        return pred_img
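
    # Unconditional sampling sketch (assumes a trained checkpoint; T and
    # T_latent are DDIM step counts, smaller is faster but coarser):
    #
    #   imgs = model.sample(N=16, device='cuda', T=20, T_latent=20)
    #   save_image(imgs, 'samples.png')  # imgs are already in [0, 1]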
|
|
    def render(self, noise, cond=None, T=None):
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        if cond is not None:
            pred_img = render_condition(self.conf,
                                        self.ema_model,
                                        noise,
                                        sampler=sampler,
                                        cond=cond)
        else:
            pred_img = render_uncondition(self.conf,
                                          self.ema_model,
                                          noise,
                                          sampler=sampler,
                                          latent_sampler=None)
        # map from [-1, 1] to [0, 1]
        pred_img = (pred_img + 1) / 2
        return pred_img
|
|
    def encode(self, x):
        # deterministic semantic code from the EMA encoder
        assert self.conf.model_type.has_autoenc()
        cond = self.ema_model.encoder.forward(x)
        return cond

    def encode_stochastic(self, x, cond, T=None):
        # DDIM-invert x to obtain the stochastic code x_T, given the semantic cond
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
        out = sampler.ddim_reverse_sample_loop(self.ema_model,
                                               x,
                                               model_kwargs={'cond': cond})
        return out['sample']
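
    # Autoencoding round-trip sketch (assumes an autoencoder checkpoint;
    # `batch` is a float tensor in [-1, 1] of shape (n, 3, img_size, img_size)):
    #
    #   cond = model.encode(batch)                          # semantic code
    #   x_T = model.encode_stochastic(batch, cond, T=250)   # stochastic code
    #   recon = model.render(x_T, cond=cond, T=20)          # reconstruction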
|
|
    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        with amp.autocast(False):
            if ema_model:
                model = self.ema_model
            else:
                model = self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
        return gen
|
|
    def setup(self, stage=None) -> None:
        """
        Make the datasets and seed each worker separately.
        """
        if self.conf.seed is not None:
            # give each worker a different but deterministic seed
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print('local seed:', seed)

        self.train_data = self.conf.make_dataset()
        print('train data:', len(self.train_data))
        self.val_data = self.train_data
        print('val data:', len(self.val_data))
|
|
    def _train_dataloader(self, drop_last=True):
        """
        Actually build the dataloader.
        """
        # conf.batch_size is global; use the per-worker fraction here
        conf = self.conf.clone()
        conf.batch_size = self.batch_size

        dataloader = conf.make_loader(self.train_data,
                                      shuffle=True,
                                      drop_last=drop_last)
        return dataloader
|
|
    def train_dataloader(self):
        """
        Return the dataloader: in diffusion mode, the image dataset;
        in latent mode, the inferred latent dataset.
        """
        print('on train dataloader start ...')
        if self.conf.train_mode.require_dataset_infer():
            if self.conds is None:
                # usually self.conds is loaded from a file so we don't need to
                # re-infer it every run
                self.conds = self.infer_whole_dataset()
                # need float32, otherwise the mean & std will be off!
                self.conds_mean.data = self.conds.float().mean(dim=0,
                                                               keepdim=True)
                self.conds_std.data = self.conds.float().std(dim=0,
                                                             keepdim=True)
            print('mean:', self.conds_mean.mean(), 'std:',
                  self.conds_std.mean())

            # return the dataset of pre-calculated conds
            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        else:
            return self._train_dataloader()
|
|
    @property
    def batch_size(self):
        """
        Local batch size for each worker.
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

    @property
    def num_samples(self):
        """
        (global) batch size * iterations
        """
        # initially 0; recomputed on the fly from the global step
        return self.global_step * self.conf.batch_size_effective

    def is_last_accum(self, batch_idx):
        """
        Is this the last gradient-accumulation step, i.e. will the optimizer
        perform a "step" on this iteration? Relevant when accum_batches > 1.
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0
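
    # Worked example: with accum_batches = 4, batch_idx 0..7 yields
    # [False, False, False, True, False, False, False, True], so the EMA and
    # logging in on_train_batch_end fire only after an actual optimizer step.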
|
|
    def infer_whole_dataset(self,
                            with_render=False,
                            T_render=None,
                            render_save_path=None):
        """
        Predict the latents for the whole dataset using the encoder.

        Args:
            with_render: whether to also render the images from those latents
            T_render: number of DDIM steps used for rendering
            render_save_path: lmdb output path for the rendered images
        """
        data = self.conf.make_dataset()
        if isinstance(data, CelebAlmdb) and data.crop_d2c:
            # special case where the d2c crop is needed
            data.transform = make_transform(self.conf.img_size,
                                            flip_prob=0,
                                            crop_d2c=True)
        else:
            data.transform = make_transform(self.conf.img_size, flip_prob=0)

        loader = self.conf.make_loader(
            data,
            shuffle=False,
            drop_last=False,
            batch_size=self.conf.batch_size_eval,
            parallel=True,
        )
        model = self.ema_model
        model.eval()
        conds = []

        if with_render:
            sampler = self.conf._make_diffusion_conf(
                T=T_render or self.conf.T_eval).make_sampler()

            if self.global_rank == 0:
                writer = LMDBImageWriter(render_save_path,
                                         format='webp',
                                         quality=100)
            else:
                writer = nullcontext()
        else:
            writer = nullcontext()

        with writer:
            for batch in tqdm(loader, total=len(loader), desc='infer'):
                with torch.no_grad():
                    # predict the latents
                    cond = model.encoder(batch['img'].to(self.device))

                    # gather the indices of this batch from all workers
                    idx = batch['index']
                    idx = self.all_gather(idx)
                    if idx.dim() == 2:
                        idx = idx.flatten(0, 1)
                    argsort = idx.argsort()

                    if with_render:
                        noise = torch.randn(len(cond),
                                            3,
                                            self.conf.img_size,
                                            self.conf.img_size,
                                            device=self.device)
                        render = sampler.sample(model, noise=noise, cond=cond)
                        # map from [-1, 1] to [0, 1]
                        render = (render + 1) / 2
                        # gather from all workers
                        render = self.all_gather(render)
                        if render.dim() == 5:
                            # (k, n, c, h, w) => (k*n, c, h, w)
                            render = render.flatten(0, 1)

                        # only the first rank writes to the lmdb
                        if self.global_rank == 0:
                            writer.put_images(render[argsort])

                    # gather the latents from all workers
                    cond = self.all_gather(cond)

                    if cond.dim() == 3:
                        # (k, n, c) => (k*n, c)
                        cond = cond.flatten(0, 1)

                    # keep the dataset order
                    conds.append(cond[argsort].cpu())
        model.train()

        conds = torch.cat(conds).float()
        return conds
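
    # Offline inference sketch (mirrors the 'infer' eval program in test_step):
    #
    #   conds = model.infer_whole_dataset()   # (N, style_ch), in dataset order
    #   torch.save({'conds': conds,
    #               'conds_mean': conds.mean(dim=0),
    #               'conds_std': conds.std(dim=0)}, 'latent.pkl')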
|
|
    def training_step(self, batch, batch_idx):
        """
        Given an input, calculate the loss function.
        No optimization happens at this stage.
        """
        with amp.autocast(False):
            # forward pass
            if self.conf.train_mode.require_dataset_infer():
                # this mode uses pre-calculated latents
                cond = batch[0]
                if self.conf.latent_znormalize:
                    cond = (cond - self.conds_mean.to(
                        self.device)) / self.conds_std.to(self.device)
            else:
                imgs, idxs = batch['img'], batch['index']
                x_start = imgs

            if self.conf.train_mode == TrainMode.diffusion:
                """
                main training mode!!!
                """
                t, weight = self.T_sampler.sample(len(x_start), x_start.device)
                losses = self.sampler.training_losses(model=self.model,
                                                      x_start=x_start,
                                                      t=t)
            elif self.conf.train_mode.is_latent_diffusion():
                """
                training the latent variables!
                """
                t, weight = self.T_sampler.sample(len(cond), cond.device)
                latent_losses = self.latent_sampler.training_losses(
                    model=self.model.latent_net, x_start=cond, t=t)
                # train only the latent diffusion
                losses = {
                    'latent': latent_losses['loss'],
                    'loss': latent_losses['loss']
                }
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            # gather the losses from all workers for logging
            for key in ['loss', 'vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
                if key in losses:
                    losses[key] = self.all_gather(losses[key]).mean()

            if self.global_rank == 0:
                self.logger.experiment.add_scalar('loss', losses['loss'],
                                                  self.num_samples)
                for key in ['vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
                    if key in losses:
                        self.logger.experiment.add_scalar(
                            f'loss/{key}', losses[key], self.num_samples)

        return {'loss': loss}
|
|
    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        After each training step: update the EMA model and log samples/metrics.
        """
        if self.is_last_accum(batch_idx):
            # only apply EMA on the last accumulated step
            if self.conf.train_mode == TrainMode.latent_diffusion:
                # only the latent net is trained in this mode
                ema(self.model.latent_net, self.ema_model.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.model, self.ema_model, self.conf.ema_decay)

            # logging
            if self.conf.train_mode.require_dataset_infer():
                imgs = None
            else:
                imgs = batch['img']
            self.log_sample(x_start=imgs)
            self.evaluate_scores()
|
|
    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        # gradient clipping is done here instead of passing gradient_clip_val
        # to the Trainer (see train())
        if self.conf.grad_clip > 0:
            # collect all parameters managed by the optimizer
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params,
                                           max_norm=self.conf.grad_clip)
|
    def log_sample(self, x_start):
        """
        Put images to the tensorboard.
        """
        def do(model,
               postfix,
               use_xstart,
               save_real=False,
               no_latent_diff=False,
               interpolate=False):
            model.eval()
            with torch.no_grad():
                all_x_T = self.split_tensor(self.x_T)
                batch_size = min(len(all_x_T), self.conf.batch_size_eval)
                # allow for very large models that cannot sample in one batch
                loader = DataLoader(all_x_T, batch_size=batch_size)

                Gen = []
                for x_T in loader:
                    if use_xstart:
                        _xstart = x_start[:len(x_T)]
                    else:
                        _xstart = None

                    if self.conf.train_mode.is_latent_diffusion(
                    ) and not use_xstart:
                        # diffusion of the latent first
                        gen = render_uncondition(
                            conf=self.conf,
                            model=model,
                            x_T=x_T,
                            sampler=self.eval_sampler,
                            latent_sampler=self.eval_latent_sampler,
                            conds_mean=self.conds_mean,
                            conds_std=self.conds_std)
                    else:
                        if not use_xstart and self.conf.model_type.has_noise_to_cond(
                        ):
                            model: BeatGANsAutoencModel
                            # sample the semantic code from noise
                            cond = torch.randn(len(x_T),
                                               self.conf.style_ch,
                                               device=self.device)
                            cond = model.noise_to_cond(cond)
                        else:
                            if interpolate:
                                with amp.autocast(self.conf.fp16):
                                    cond = model.encoder(_xstart)
                                    i = torch.randperm(len(cond))
                                    cond = (cond + cond[i]) / 2
                            else:
                                cond = None
                        gen = self.eval_sampler.sample(model=model,
                                                       noise=x_T,
                                                       cond=cond,
                                                       x_start=_xstart)
                    Gen.append(gen)

                gen = torch.cat(Gen)
                gen = self.all_gather(gen)
                if gen.dim() == 5:
                    # (k, n, c, h, w) => (k*n, c, h, w)
                    gen = gen.flatten(0, 1)

                if save_real and use_xstart:
                    # save the real (input) images to the tensorboard
                    real = self.all_gather(_xstart)
                    if real.dim() == 5:
                        real = real.flatten(0, 1)

                    if self.global_rank == 0:
                        grid_real = (make_grid(real) + 1) / 2
                        self.logger.experiment.add_image(
                            f'sample{postfix}/real', grid_real,
                            self.num_samples)

                if self.global_rank == 0:
                    # save samples to the tensorboard and to disk
                    grid = (make_grid(gen) + 1) / 2
                    sample_dir = os.path.join(self.conf.logdir,
                                              f'sample{postfix}')
                    if not os.path.exists(sample_dir):
                        os.makedirs(sample_dir)
                    path = os.path.join(sample_dir,
                                        '%d.png' % self.num_samples)
                    save_image(grid, path)
                    self.logger.experiment.add_image(f'sample{postfix}', grid,
                                                     self.num_samples)
            model.train()

        if self.conf.sample_every_samples > 0 and is_time(
                self.num_samples, self.conf.sample_every_samples,
                self.conf.batch_size_effective):

            if self.conf.train_mode.require_dataset_infer():
                do(self.model, '', use_xstart=False)
                do(self.ema_model, '_ema', use_xstart=False)
            else:
                if self.conf.model_type.has_autoenc(
                ) and self.conf.model_type.can_sample():
                    do(self.model, '', use_xstart=False)
                    do(self.ema_model, '_ema', use_xstart=False)
                    # autoencoding mode
                    do(self.model, '_enc', use_xstart=True, save_real=True)
                    do(self.ema_model,
                       '_enc_ema',
                       use_xstart=True,
                       save_real=True)
                elif self.conf.train_mode.use_latent_net():
                    do(self.model, '', use_xstart=False)
                    do(self.ema_model, '_ema', use_xstart=False)
                    # autoencoding mode
                    do(self.model, '_enc', use_xstart=True, save_real=True)
                    do(self.model,
                       '_enc_nodiff',
                       use_xstart=True,
                       save_real=True,
                       no_latent_diff=True)
                    do(self.ema_model,
                       '_enc_ema',
                       use_xstart=True,
                       save_real=True)
                else:
                    do(self.model, '', use_xstart=True, save_real=True)
                    do(self.ema_model, '_ema', use_xstart=True, save_real=True)
|
|
    def evaluate_scores(self):
        """
        Evaluate FID and other scores during training (on the tensorboard).
        For FID, this is a fast version with 5k images (the gold standard uses
        50k). Don't use these numbers in the paper!
        """
        def fid(model, postfix):
            score = evaluate_fid(self.eval_sampler,
                                 model,
                                 self.conf,
                                 device=self.device,
                                 train_data=self.train_data,
                                 val_data=self.val_data,
                                 latent_sampler=self.eval_latent_sampler,
                                 conds_mean=self.conds_mean,
                                 conds_std=self.conds_std)
            if self.global_rank == 0:
                self.logger.experiment.add_scalar(f'FID{postfix}', score,
                                                  self.num_samples)
                if not os.path.exists(self.conf.logdir):
                    os.makedirs(self.conf.logdir)
                with open(os.path.join(self.conf.logdir, 'eval.txt'),
                          'a') as f:
                    metrics = {
                        f'FID{postfix}': score,
                        'num_samples': self.num_samples,
                    }
                    f.write(json.dumps(metrics) + "\n")

        def lpips(model, postfix):
            if self.conf.model_type.has_autoenc(
            ) and self.conf.train_mode.is_autoenc():
                # returns a dict of metrics (e.g. lpips/ssim/mse)
                score = evaluate_lpips(self.eval_sampler,
                                       model,
                                       self.conf,
                                       device=self.device,
                                       val_data=self.val_data,
                                       latent_sampler=self.eval_latent_sampler)

                if self.global_rank == 0:
                    for key, val in score.items():
                        self.logger.experiment.add_scalar(
                            f'{key}{postfix}', val, self.num_samples)

        if self.conf.eval_every_samples > 0 and self.num_samples > 0 and is_time(
                self.num_samples, self.conf.eval_every_samples,
                self.conf.batch_size_effective):
            print(f'eval fid @ {self.num_samples}')
            lpips(self.model, '')
            fid(self.model, '')

        if self.conf.eval_ema_every_samples > 0 and self.num_samples > 0 and is_time(
                self.num_samples, self.conf.eval_ema_every_samples,
                self.conf.batch_size_effective):
            print(f'eval fid ema @ {self.num_samples}')
            fid(self.ema_model, '_ema')
            # note: lpips is not evaluated for the EMA model here
|
|
    def configure_optimizers(self):
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(
                                                          self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out
|
|
    def split_tensor(self, x):
        """
        Extract the slice of the batch dimension belonging to this worker.

        Args:
            x: (n, c)

        Returns: x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        # each worker takes a contiguous chunk
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]
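
    # Worked example: n = 8, world_size = 2 => per_rank = 4;
    # rank 0 gets x[0:4], rank 1 gets x[4:8].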
|
|
    def test_step(self, batch, *args, **kwargs):
        """
        For the "eval" mode.
        We select what to run according to "conf.eval_programs".
        test_step is made to run for only "one iteration" (it's a hack!);
        we just want the multi-gpu support.
        """
        self.setup()

        # the checkpoint has already been loaded by the caller (see train())
        print('global step:', self.global_step)
        """
        "infer" = predict the latent variables using the encoder on the whole dataset
        """
        if 'infer' in self.conf.eval_programs:
            print('infer ...')
            conds = self.infer_whole_dataset().float()
            # save to the checkpoint dir so it can be used as latent_infer_path
            save_path = f'checkpoints/{self.conf.name}/latent.pkl'

            if self.global_rank == 0:
                conds_mean = conds.mean(dim=0)
                conds_std = conds.std(dim=0)
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                torch.save(
                    {
                        'conds': conds,
                        'conds_mean': conds_mean,
                        'conds_std': conds_std,
                    }, save_path)
| """ |
| "infer+render" = predict the latent variables using the encoder on the whole dataset |
| THIS ALSO GENERATE CORRESPONDING IMAGES |
| """ |
| |
| for each in self.conf.eval_programs: |
| if each.startswith('infer+render'): |
| m = re.match(r'infer\+render([0-9]+)', each) |
| if m is not None: |
| T = int(m[1]) |
| self.setup() |
| print(f'infer + reconstruction T{T} ...') |
| conds = self.infer_whole_dataset( |
| with_render=True, |
| T_render=T, |
| render_save_path= |
| f'latent_infer_render{T}/{self.conf.name}.lmdb', |
| ) |
| save_path = f'latent_infer_render{T}/{self.conf.name}.pkl' |
| conds_mean = conds.mean(dim=0) |
| conds_std = conds.std(dim=0) |
| if not os.path.exists(os.path.dirname(save_path)): |
| os.makedirs(os.path.dirname(save_path)) |
| torch.save( |
| { |
| 'conds': conds, |
| 'conds_mean': conds_mean, |
| 'conds_std': conds_std, |
| }, save_path) |
|
|
        """
        "fid<T>" = unconditional generation (conf.train_mode = diffusion).
            Note: diff. autoenc will still receive real images in this mode.
        "fid(<T>,<T_latent>)" = unconditional generation for latent models
            (conf.train_mode = latent_diffusion).
            Note: diff. autoenc will NOT receive real images in this mode;
            make sure the train_mode is latent_diffusion.
        """
        for each in self.conf.eval_programs:
            if each.startswith('fid'):
                m = re.match(r'fid\(([0-9]+),([0-9]+)\)', each)
                clip_latent_noise = False
                if m is not None:
                    # fid(T1,T2)
                    T = int(m[1])
                    T_latent = int(m[2])
                    print(f'evaluating FID T = {T}... latent T = {T_latent}')
                else:
                    m = re.match(r'fidclip\(([0-9]+),([0-9]+)\)', each)
                    if m is not None:
                        # fidclip(T1,T2): clip the latent noise
                        T = int(m[1])
                        T_latent = int(m[2])
                        clip_latent_noise = True
                        print(
                            f'evaluating FID (clip latent noise) T = {T}... latent T = {T_latent}'
                        )
                    else:
                        # fid(T)
                        _, T = each.split('fid')
                        T = int(T)
                        T_latent = None
                        print(f'evaluating FID T = {T}...')

                self.train_dataloader()
                sampler = self.conf._make_diffusion_conf(T=T).make_sampler()
                if T_latent is not None:
                    latent_sampler = self.conf._make_latent_diffusion_conf(
                        T=T_latent).make_sampler()
                else:
                    latent_sampler = None

                conf = self.conf.clone()
                # the full FID evaluation uses 50k images
                conf.eval_num_images = 50_000
                score = evaluate_fid(
                    sampler,
                    self.ema_model,
                    conf,
                    device=self.device,
                    train_data=self.train_data,
                    val_data=self.val_data,
                    latent_sampler=latent_sampler,
                    conds_mean=self.conds_mean,
                    conds_std=self.conds_std,
                    remove_cache=False,
                    clip_latent_noise=clip_latent_noise,
                )
                if T_latent is None:
                    self.log(f'fid_ema_T{T}', score)
                else:
                    name = 'fid'
                    if clip_latent_noise:
                        name += '_clip'
                    name += f'_ema_T{T}_Tlatent{T_latent}'
                    self.log(name, score)
| """ |
| "recon<T>" = reconstruction & autoencoding (without noise inversion) |
| """ |
| for each in self.conf.eval_programs: |
| if each.startswith('recon'): |
| self.model: BeatGANsAutoencModel |
| _, T = each.split('recon') |
| T = int(T) |
| print(f'evaluating reconstruction T = {T}...') |
|
|
| sampler = self.conf._make_diffusion_conf(T=T).make_sampler() |
|
|
| conf = self.conf.clone() |
| |
| conf.eval_num_images = len(self.val_data) |
| |
| score = evaluate_lpips(sampler, |
| self.ema_model, |
| conf, |
| device=self.device, |
| val_data=self.val_data, |
| latent_sampler=None) |
| for k, v in score.items(): |
| self.log(f'{k}_ema_T{T}', v) |
| """ |
| "inv<T>" = reconstruction with noise inversion |
| """ |
| for each in self.conf.eval_programs: |
| if each.startswith('inv'): |
| self.model: BeatGANsAutoencModel |
| _, T = each.split('inv') |
| T = int(T) |
| print( |
| f'evaluating reconstruction with noise inversion T = {T}...' |
| ) |
|
|
| sampler = self.conf._make_diffusion_conf(T=T).make_sampler() |
|
|
| conf = self.conf.clone() |
| |
| conf.eval_num_images = len(self.val_data) |
| |
| score = evaluate_lpips(sampler, |
| self.ema_model, |
| conf, |
| device=self.device, |
| val_data=self.val_data, |
| latent_sampler=None, |
| use_inverted_noise=True) |
| for k, v in score.items(): |
| self.log(f'{k}_inv_ema_T{T}', v) |
|
|
|
|
def ema(source, target, decay):
    # in-place EMA update: target = decay * target + (1 - decay) * source
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))
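
# Worked example: with decay = 0.9999, each update moves the EMA weights by
# 0.01% toward the current model, i.e. averaging over roughly the last
# 1 / (1 - decay) = 10k steps.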
|
|
|
|
class WarmupLR:
    def __init__(self, warmup) -> None:
        self.warmup = warmup

    def __call__(self, step):
        # linear warmup from 0 to 1 over `warmup` steps, then constant 1
        return min(step, self.warmup) / self.warmup
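
# Usage sketch (as wired up in configure_optimizers): LambdaLR multiplies the
# base lr by this factor at every step, e.g. with warmup=5000:
#
#   sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=WarmupLR(5000))
#   # step 0 => lr * 0, step 2500 => lr * 0.5, step >= 5000 => lr * 1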
|
|
|
|
def is_time(num_samples, every, step_size):
    # True once per `every` samples, robust to `num_samples` advancing by
    # `step_size` at a time
    closest = (num_samples // every) * every
    return num_samples - closest < step_size
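
# Worked example: every=1000, step_size=64. At num_samples=1024 the closest
# multiple is 1000 and 24 < 64 => True; at 960 it is 0 and 960 >= 64 => False.
# So the event fires exactly once per 1000 samples.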
|
|
|
|
def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'):
    print('conf:', conf.name)
    model = LitModel(conf)

    if not os.path.exists(conf.logdir):
        os.makedirs(conf.logdir)
    checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}',
                                 save_last=True,
                                 save_top_k=1,
                                 every_n_train_steps=conf.save_every_samples //
                                 conf.batch_size_effective)
    checkpoint_path = f'{conf.logdir}/last.ckpt'
    print('ckpt path:', checkpoint_path)
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        print('resume!')
    else:
        if conf.continue_from is not None:
            # continue from a checkpoint
            resume = conf.continue_from.path
        else:
            resume = None

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
                                             name=None,
                                             version='')

    plugins = []
    if len(gpus) == 1 and nodes == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        # important for compatibility with gradient checkpointing
        plugins.append(DDPPlugin(find_unused_parameters=False))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        # gradient clipping is handled in LitModel.on_before_optimizer_step
        replace_sampler_ddp=True,
        logger=tb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )

    if mode == 'train':
        trainer.fit(model)
    elif mode == 'eval':
        # evaluation runs through test_step; feed a dummy loader to trigger it
        dummy = DataLoader(TensorDataset(torch.tensor([0.] * conf.batch_size)),
                           batch_size=conf.batch_size)
        eval_path = conf.eval_path or checkpoint_path
        # load the checkpoint manually (resume_from_checkpoint is not used here)
        print('loading from:', eval_path)
        state = torch.load(eval_path, map_location='cpu')
        print('step:', state['global_step'])
        model.load_state_dict(state['state_dict'])
        # trainer.test calls test_step, which dispatches conf.eval_programs
        out = trainer.test(model, dataloaders=dummy)
        # the result of the first (and only) loader
        out = out[0]
        print(out)

        if get_rank() == 0:
            # report to the tensorboard
            for k, v in out.items():
                tb_logger.experiment.add_scalar(
                    k, v, state['global_step'] * conf.batch_size_effective)

            # also write to a file
            tgt = f'evals/{conf.name}.txt'
            dirname = os.path.dirname(tgt)
            if not os.path.exists(dirname):
                os.makedirs(dirname)
            with open(tgt, 'a') as f:
                f.write(json.dumps(out) + "\n")
    else:
        raise NotImplementedError()
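

# Launch sketch (`make_conf` is a hypothetical placeholder; real configs are
# TrainConfig instances built from config.py, imported above):
#
#   conf = make_conf()                  # hypothetical TrainConfig factory
#   train(conf, gpus=[0, 1])            # train on 2 GPUs (DDP)
#   train(conf, gpus=[0], mode='eval')  # run conf.eval_programs on a checkpoint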
|
|