| import copy |
| |
| import functools |
| import json |
| import os |
| from pathlib import Path |
| from pdb import set_trace as st |
| from einops import rearrange |
| import webdataset as wds |
|
|
| import traceback |
| import blobfile as bf |
| import imageio |
| import numpy as np |
| |
| import torch as th |
| import torch.distributed as dist |
| import torchvision |
| from PIL import Image |
| from torch.nn.parallel.distributed import DistributedDataParallel as DDP |
| from torch.optim import AdamW |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import tqdm |
|
|
| from guided_diffusion import dist_util, logger |
| from guided_diffusion.fp16_util import MixedPrecisionTrainer |
| from guided_diffusion.nn import update_ema |
| from guided_diffusion.resample import LossAwareSampler, UniformSampler |
| from guided_diffusion.train_util import (calc_average_loss, |
| find_ema_checkpoint, |
| find_resume_checkpoint, |
| get_blob_logdir, log_rec3d_loss_dict, |
| parse_resume_step_from_filename) |
|
|
| from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics |
|
|
| from .train_util import TrainLoop3DRec |
|
|
|
|
class TrainLoop3DRecNV(TrainLoop3DRec):
    """Reconstruction training loop supervised from a novel viewpoint.

    The encoder sees ``img_to_encoder``; the decoded triplane is rendered at
    the novel camera ``nv_c`` and compared against the ``nv_*`` targets. When
    ``self.rec_cano`` is true (set in ``__init__``), the same latent is also
    rendered at the input camera ``c`` and supervised with a 2D
    reconstruction loss against the un-prefixed targets.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        # Also supervise the render at the input (canonical) camera.
        self.rec_cano = True

    def forward_backward(self, batch, *args, **kwargs):
        """Accumulate reconstruction gradients over ``batch`` in microbatches.

        The optimizer step itself is left to the caller; this only zeroes and
        then fills the gradients held by ``self.mp_trainer_rec``.

        Args:
            batch: dict of tensors. Keys prefixed with ``nv_`` are novel-view
                targets; for every ``nv_<key>`` the batch must also contain
                ``<key>`` (the matching input-view target). ``img_to_encoder``
                and ``nv_c`` are always read; ``c`` is read when
                ``self.rec_cano`` is true.
        """
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }

            # Stays None when no canonical render is made, so logging can
            # tell the two cases apart.
            pred_cano = None

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                target_nvs = {}
                target_cano = {}

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred = self.rec_model(latent=latent,
                                      c=micro['nv_c'],
                                      behaviour='triplane_dec')

                # Split into novel-view targets (``nv_<key>``) and the
                # matching input-view targets (``<key>``).
                for k, v in micro.items():
                    if k[:2] == 'nv':
                        orig_key = k.replace('nv_', '')
                        target_nvs[orig_key] = v
                        target_cano[orig_key] = micro[orig_key]

                with self.rec_model.no_sync():
                    loss, loss_dict, fg_mask = self.loss_class(
                        pred,
                        target_nvs,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:
                    pred_cano = self.rec_model(latent=latent,
                                               c=micro['c'],
                                               behaviour='triplane_dec')

                    with self.rec_model.no_sync():
                        # Expand the 1-channel depth mask to RGB so it can
                        # weight the 3-channel image loss.
                        fg_mask = target_cano['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            target_cano['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                    loss = loss + loss_cano

                    log_rec3d_loss_dict({
                        f'cano_{k}': v
                        for k, v in loss_cano_dict.items()
                    })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self.log_img(micro, pred, pred_cano)

    @th.inference_mode()
    def log_img(self, micro, pred, pred_cano):
        """Save a visualisation grid of targets and predictions.

        The rows are stacked vertically:
          * ground truth: image | depth | zeros
          * ``pred``: image | depth | mask
          * ``pred_cano``: image | depth | mask (only when it is not None)

        The grid is written to ``<logdir>/<step>.jpg``.

        Args:
            micro: dict with at least ``img`` and ``depth``. It also needs
                ``img_sr`` when ``pred`` contains ``image_sr``.
            pred: render dict with ``image_raw`` and ``image_mask``.
                ``image_depth`` and ``image_sr`` are optional.
            pred_cano: render dict with the same keys, or ``None`` when no
                canonical render was produced.
        """

        def norm_depth(pred_depth):
            # Min-max normalise to [-1, 1], with the smallest value mapped to
            # +1. Clamping the range avoids 0/0 (a NaN image) when the depth
            # map is constant.
            depth_range = (pred_depth.max() - pred_depth.min()).clamp(min=1e-6)
            pred_depth = (pred_depth - pred_depth.min()) / depth_range
            return -(pred_depth * 2 - 1)

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1
        has_depth = 'image_depth' in pred
        if has_depth:
            pred_depth = norm_depth(pred['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)

        # Build the canonical-row inputs before gt_depth is pooled below, so
        # the zero placeholder keeps the un-pooled resolution.
        if pred_cano is not None:
            input_fg_mask = pred_cano['image_mask'] * 2 - 1
            if has_depth:
                pred_cano_depth = norm_depth(pred_cano['image_depth'])
            else:
                pred_cano_depth = th.zeros_like(gt_depth)

        if 'image_sr' in pred:
            if pred['image_sr'].shape[-1] == 512:
                pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
                                dim=-1)
                pred_depth = self.pool_512(pred_depth)
                gt_depth = self.pool_512(gt_depth)

            elif pred['image_sr'].shape[-1] == 256:
                pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
                                dim=-1)
                pred_depth = self.pool_256(pred_depth)
                gt_depth = self.pool_256(gt_depth)

            else:
                pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
                                dim=-1)
                gt_depth = self.pool_128(gt_depth)
                pred_depth = self.pool_128(pred_depth)
        else:
            gt_img = self.pool_64(gt_img)
            gt_depth = self.pool_64(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)

        if pred_cano is not None:
            pred_vis_cano = th.cat([
                pred_cano['image_raw'],
                pred_cano_depth.repeat_interleave(3, dim=1),
                input_fg_mask.repeat_interleave(3, dim=1),
            ],
                                   dim=-1)
            pred_vis = th.cat([pred_vis, pred_vis_cano], dim=-2)

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img)
        ],
                        dim=-1)

        if 'conf_sigma' in pred:
            # NOTE(review): fg_mask has 1 channel while gt_vis has 3, and the
            # extra column also widens gt_vis past pred_vis. This branch looks
            # like it would fail in th.cat; confirm before relying on it.
            gt_vis = th.cat([gt_vis, fg_mask], dim=-1)

        vis = th.cat([gt_vis, pred_vis], dim=-2)

        save_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
        vis_tensor = torchvision.utils.make_grid(vis,
                                                 nrow=vis.shape[-1] // 64)
        torchvision.utils.save_image(vis_tensor,
                                     save_path,
                                     value_range=(-1, 1),
                                     normalize=True)

        logger.log('log vis to: ', save_path)
|
| |
| |
| |
| |
|
|
|
|
| |
|
|
|
|
class TrainLoop3DRecNVPatch(TrainLoop3DRecNV):
    """Novel-view reconstruction loop that renders and supervises patches.

    Rays are sampled by the triplane decoder (``sample_ray_only=True``) inside
    each sample's foreground box. The targets are cropped to the returned
    ``ray_bboxes``, so only a patch of each view is rendered and compared.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

        # Direct handle on the triplane decoder, used for ray sampling. This
        # assumes rec_model is wrapped (it exposes ``.module``).
        self.eg3d_model = self.rec_model.module.decoder.triplane_decoder

        self.rec_cano = True

    def forward_backward(self, batch, *args, **kwargs):
        """Accumulate patch-based reconstruction gradients over ``batch``.

        For each microbatch:
          1. sample novel-view rays and crop ``img``, ``depth_mask`` and
             ``depth`` from the ``nv_*`` targets to the sampled patch;
          2. render the patch and apply ``self.loss_class``;
          3. when ``self.rec_cano`` is true, repeat the process at the input
             camera and add a 2D reconstruction loss.
        """
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }

            # Novel-view ray sampling; also yields the per-sample patch box.
            target = {
                **self.eg3d_model(
                    c=micro['nv_c'],
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['nv_bbox']),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']
            # Allocate patch-sized buffers for spatial targets; listed keys
            # are passed through untouched.
            cropped_target = {
                k: th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c'
                ] else v
                for k, v in micro.items()
            }

            for j in range(micro['img'].shape[0]):
                top, left, height, width = target['ray_bboxes'][j]
                for key in ('img', 'depth_mask', 'depth'):
                    cropped_target[key][
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'nv_{key}'][j:j + 1], top, left, height,
                            width)

            # Stays None when rec_cano is off; previously this name was left
            # unbound and the logging call below raised NameError.
            pred_cano = None

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred_nv = self.rec_model(
                    latent=latent,
                    c=micro['nv_c'],
                    behaviour='triplane_dec',
                    ray_origins=target['ray_origins'],
                    ray_directions=target['ray_directions'],
                )

                with self.rec_model.no_sync():
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv,
                        cropped_target,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:
                    # Same patch procedure at the input camera.
                    cano_target = {
                        **self.eg3d_model(
                            c=micro['c'],
                            ws=None,
                            planes=None,
                            sample_ray_only=True,
                            fg_bbox=micro['bbox']),
                    }

                    cano_cropped_target = {
                        k: th.empty_like(v)
                        for k, v in cropped_target.items()
                    }

                    for j in range(micro['img'].shape[0]):
                        top, left, height, width = cano_target['ray_bboxes'][j]
                        for key in ('img', 'depth_mask', 'depth'):
                            cano_cropped_target[key][
                                j:j + 1] = torchvision.transforms.functional.crop(
                                    micro[key][j:j + 1], top, left, height,
                                    width)

                    pred_cano = self.rec_model(
                        latent=latent,
                        c=micro['c'],
                        behaviour='triplane_dec',
                        ray_origins=cano_target['ray_origins'],
                        ray_directions=cano_target['ray_directions'],
                    )

                    with self.rec_model.no_sync():
                        # Expand the 1-channel mask to RGB for the image loss.
                        fg_mask = cano_cropped_target['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            cano_cropped_target['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                    loss = loss + loss_cano

                    log_rec3d_loss_dict({
                        f'cano_{k}': v
                        for k, v in loss_cano_dict.items()
                    })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self.log_patch_img(cropped_target, pred_nv, pred_cano)

    @th.inference_mode()
    def log_patch_img(self, micro, pred, pred_cano):
        """Save a grid of the cropped targets and the rendered patches.

        Rows, stacked vertically: ground truth (image | depth | zeros), then
        ``pred`` (image | depth | mask), then ``pred_cano`` when it is not
        None. The grid is written to ``<logdir>/<step>.jpg``.

        Args:
            micro: cropped target dict with ``img`` and ``depth``.
            pred: patch render with ``image_raw`` and ``image_mask``
                (``image_depth`` is optional).
            pred_cano: canonical-view patch render with the same keys, or
                ``None``.
        """

        def norm_depth(pred_depth):
            # Min-max normalise to [-1, 1] (smallest value -> +1). The clamp
            # avoids 0/0 on a constant depth map.
            depth_range = (pred_depth.max() - pred_depth.min()).clamp(min=1e-6)
            pred_depth = (pred_depth - pred_depth.min()) / depth_range
            return -(pred_depth * 2 - 1)

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1
        has_depth = 'image_depth' in pred
        if has_depth:
            pred_depth = norm_depth(pred['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)

        if pred_cano is not None:
            input_fg_mask = pred_cano['image_mask'] * 2 - 1
            if has_depth:
                pred_cano_depth = norm_depth(pred_cano['image_depth'])
            else:
                pred_cano_depth = th.zeros_like(gt_depth)

            pred_vis_cano = th.cat([
                pred_cano['image_raw'],
                pred_cano_depth.repeat_interleave(3, dim=1),
                input_fg_mask.repeat_interleave(3, dim=1),
            ],
                                   dim=-1)
            pred_vis = th.cat([pred_vis, pred_vis_cano], dim=-2)

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img)
        ],
                        dim=-1)

        vis = th.cat([gt_vis, pred_vis], dim=-2)

        save_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
        vis_tensor = torchvision.utils.make_grid(vis,
                                                 nrow=vis.shape[-1] // 64)
        torchvision.utils.save_image(vis_tensor,
                                     save_path,
                                     value_range=(-1, 1),
                                     normalize=True)

        logger.log('log vis to: ', save_path)
|
|
| |
| |
| |
| |
|
|
|
|
class TrainLoop3DRecNVPatchSingleForward(TrainLoop3DRecNVPatch):
    """Patch-based loop that decodes all rolled views in one forward pass.

    Instead of separate novel and canonical renders, the novel-view cameras,
    rays and cropped targets are rolled along the batch dimension three
    times and concatenated. The shared latent is then decoded once for all
    three shifts.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

    def forward_backward(self, batch, *args, **kwargs):
        """Accumulate gradients for one batch using a single decoder pass.

        ``batch`` is mutated: the non-tensor ``caption`` and ``ins`` entries
        are removed if present.
        """
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        # Drop non-tensor metadata. Use a default so batches without these
        # keys do not raise KeyError.
        batch.pop('caption', None)
        batch.pop('ins', None)

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                if isinstance(v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }
            micro_bs = micro['img_to_encoder'].shape[0]

            target = {
                **self.eg3d_model(
                    c=micro['nv_c'],
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['nv_bbox']),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']
            cropped_target = {
                k: th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c', 'caption', 'nv_caption'
                ] else v
                for k, v in micro.items()
            }

            for j in range(micro['img'].shape[0]):
                top, left, height, width = target['ray_bboxes'][j]
                for key in ('img', 'depth_mask', 'depth'):
                    cropped_target[key][
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'nv_{key}'][j:j + 1], top, left, height,
                            width)

            # NOTE(review): cano_target is not used below. The call is kept in
            # case ray sampling consumes RNG state; confirm before removing.
            cano_target = {
                **self.eg3d_model(
                    c=micro['c'],
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['bbox']),
            }

            latent = self.rec_model(img=micro['img_to_encoder'],
                                    behaviour='enc_dec_wo_triplane')

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                # Roll by quarters of the *microbatch*; the rolled tensors
                # all have the microbatch size, not the full batch size.
                instance_mv_num = micro_bs // 4
                shifts = [instance_mv_num * s for s in range(1, 4)]

                c = th.cat([micro['nv_c'].roll(s, dims=0) for s in shifts])

                ray_origins = th.cat(
                    [target['ray_origins'].roll(s, dims=0) for s in shifts], 0)

                ray_directions = th.cat(
                    [target['ray_directions'].roll(s, dims=0) for s in shifts])

                # One decoder pass over the three rolled copies.
                pred_nv_cano = self.rec_model(
                    latent={
                        'latent_after_vit':
                        latent['latent_after_vit'].repeat(3, 1, 1, 1)
                    },
                    c=c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )

                pred_nv_cano.update(latent)

                gt = {
                    k: th.cat([v.roll(s, dims=0) for s in shifts], 0)
                    for k, v in cropped_target.items()
                }

                with self.rec_model.no_sync():
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self.log_patch_img(
                    cropped_target,
                    {
                        k: pred_nv_cano[k][-micro_bs:]
                        for k in ['image_raw', 'image_depth', 'image_mask']
                    },
                    {
                        k: pred_nv_cano[k][:micro_bs]
                        for k in ['image_raw', 'image_depth', 'image_mask']
                    },
                )

    def eval_loop(self):
        return super().eval_loop()

    @th.inference_mode()
    def eval_novelview_loop_old(self, camera=None):
        """Legacy novel-view evaluation that writes videos, images and meshes.

        The ``pdb`` breakpoints that used to sit on the single-view and
        per-frame paths have been removed. They blocked every
        non-interactive run.
        """
        all_loss_dict = []
        novel_view_micro = {}

        export_mesh = True
        if export_mesh:
            Path(f'{logger.get_dir()}/FID_Cals/').mkdir(parents=True,
                                                       exist_ok=True)

        for eval_idx, render_reference in enumerate(tqdm(self.eval_data)):

            if eval_idx > 500:
                break

            video_out = imageio.get_writer(
                f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_{eval_idx}.mp4',
                mode='I',
                fps=25,
                codec='libx264')

            with open(
                    f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{eval_idx}_caption.txt',
                    'w') as f:
                f.write(render_reference['caption'])

            for key in ['ins', 'bbox', 'caption']:
                if key in render_reference:
                    render_reference.pop(key)

            real_flag = False
            mv_flag = False
            if render_reference['c'].shape[:2] == (1, 40):
                real_flag = True
                # Split the 40-view sample into per-view dicts.
                render_reference = [{
                    k: v[0][idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

            elif render_reference['c'].shape[0] == 8:
                mv_flag = True
                render_reference = {
                    k: v[:4]
                    for k, v in render_reference.items()
                }

                # Index the dict by key before slicing views. The old form
                # ``render_reference[0:4]['img']`` raised TypeError.
                torchvision.utils.save_image(
                    render_reference['img'][0:4],
                    logger.get_dir() + '/FID_Cals/{}_inp.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1),
                )
                # NOTE(review): the loop below iterates render_reference as a
                # list of per-view dicts, and novel_view_micro is None on this
                # path. The multi-view branch therefore still looks
                # incomplete.

            else:
                render_reference = [{
                    k: v[idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

                # Encode view 14 (used as the input view) in place of view 0.
                render_reference[0]['img_to_encoder'] = render_reference[14][
                    'img_to_encoder']
                render_reference[0]['img'] = render_reference[14]['img']

                torchvision.utils.save_image(
                    render_reference[0]['img'],
                    logger.get_dir() + '/FID_Cals/{}_gt.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1))

            for i, batch in enumerate(render_reference):
                micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

                if i == 0:
                    if mv_flag:
                        novel_view_micro = None
                    else:
                        novel_view_micro = {
                            k: v[0:1].to(dist_util.dev()).repeat_interleave(
                                micro['img'].shape[0], 0)
                            if isinstance(v, th.Tensor) else v[0:1]
                            for k, v in batch.items()
                        }

                else:
                    if i == 1:
                        # Export one mesh per instance, using the triplane
                        # predicted at i == 0.
                        if export_mesh:
                            mesh_size = 384
                            mesh_thres = 10
                            import mcubes
                            import trimesh
                            dump_path = f'{logger.get_dir()}/mesh/'

                            os.makedirs(dump_path, exist_ok=True)

                            grid_out = self.rec_model(
                                latent=pred,
                                grid_size=mesh_size,
                                behaviour='triplane_decode_grid',
                            )

                            vtx, faces = mcubes.marching_cubes(
                                grid_out['sigma'].squeeze(0).squeeze(
                                    -1).cpu().numpy(), mesh_thres)
                            # Map grid indices to [-1, 1].
                            vtx = vtx / (mesh_size - 1) * 2 - 1

                            mesh = trimesh.Trimesh(
                                vertices=vtx,
                                faces=faces,
                            )

                            mesh_dump_path = os.path.join(
                                dump_path, f'{eval_idx}.ply')
                            mesh.export(mesh_dump_path, 'ply')

                            print(f"Mesh dumped to {dump_path}")
                            del grid_out, mesh
                            th.cuda.empty_cache()

                    novel_view_micro = {
                        k: v[0:1].to(dist_util.dev()).repeat_interleave(
                            micro['img'].shape[0], 0)
                        for k, v in novel_view_micro.items()
                    }

                pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
                                      c=micro['c'])

                if not real_flag:
                    _, loss_dict = self.loss_class(pred, micro, test_mode=True)
                    all_loss_dict.append(loss_dict)

                pred_depth = pred['image_depth']
                pred_depth = (pred_depth - pred_depth.min()) / (
                    pred_depth.max() - pred_depth.min())
                # pooled_depth is assigned on every branch. It used to exist
                # only on the no-SR branch, so mesh export hit a NameError
                # whenever image_sr was present.
                if 'image_sr' in pred:
                    if pred['image_sr'].shape[-1] == 512:
                        pooled_depth = self.pool_512(
                            pred_depth).repeat_interleave(3, dim=1)
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_512(pred['image_raw']), pred['image_sr'],
                            pooled_depth
                        ],
                                          dim=-1)

                    elif pred['image_sr'].shape[-1] == 256:
                        pooled_depth = self.pool_256(
                            pred_depth).repeat_interleave(3, dim=1)
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_256(pred['image_raw']), pred['image_sr'],
                            pooled_depth
                        ],
                                          dim=-1)

                    else:
                        pooled_depth = self.pool_128(
                            pred_depth).repeat_interleave(3, dim=1)
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_128(pred['image_raw']),
                            self.pool_128(pred['image_sr']),
                            pooled_depth
                        ],
                                          dim=-1)

                else:
                    pooled_depth = self.pool_128(pred_depth).repeat_interleave(
                        3, dim=1)
                    pred_vis = th.cat([
                        self.pool_128(novel_view_micro['img']),
                        self.pool_128(pred['image_raw']),
                        pooled_depth,
                    ],
                                      dim=-1)

                # [-1, 1] -> uint8 HWC frames for the video writer.
                vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
                vis = vis * 127.5 + 127.5
                vis = vis.clip(0, 255).astype(np.uint8)

                if export_mesh:
                    torchvision.utils.save_image(
                        pred['image_raw'],
                        logger.get_dir() +
                        '/FID_Cals/{}_{}.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(-1, 1))

                    torchvision.utils.save_image(
                        pooled_depth,
                        logger.get_dir() +
                        '/FID_Cals/{}_{}_dpeth.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(0, 1))

                for j in range(vis.shape[0]):
                    video_out.append_data(vis[j])

            video_out.close()

        # real_flag / mv_flag refer to the last evaluated sample.
        if not real_flag or mv_flag:
            val_scores_for_logging = calc_average_loss(all_loss_dict)
            with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
                      'a') as f:
                json.dump({'step': self.step, **val_scores_for_logging}, f)

            for k, v in val_scores_for_logging.items():
                self.writer.add_scalar(f'Eval/NovelView/{k}', v,
                                       self.step + self.resume_step)

        del video_out

        th.cuda.empty_cache()

    @th.inference_mode()
    def eval_novelview_loop(self, camera=None, save_latent=False):
        """Encode each eval sample and render videos/meshes for the first 50.

        Args:
            camera: cameras passed to ``render_video_given_triplane`` as
                ``render_reference={'c': camera}``.
            save_latent: when true, save ``latent[self.latent_name][0]`` to
                ``<logdir>/latent_dir/<ins>/latent.npy``.
        """
        if save_latent:
            latent_dir = Path(f'{logger.get_dir()}/latent_dir')
            latent_dir.mkdir(exist_ok=True, parents=True)

        for eval_idx, micro in enumerate(tqdm(self.eval_data)):

            # Only the first four views are fed to the encoder.
            latent = self.rec_model(img=micro['img_to_encoder'][:4],
                                    behaviour='encoder_vae')

            if save_latent:
                # Check that all views belong to one instance *before* writing
                # a latent under that instance's name.
                assert all([
                    micro['ins'][0] == micro['ins'][i]
                    for i in range(micro['c'].shape[0])
                ])

                latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}'
                Path(latent_save_dir).mkdir(parents=True, exist_ok=True)

                np.save(f'{latent_save_dir}/latent.npy',
                        latent[self.latent_name][0].cpu().numpy())

            if eval_idx < 50:
                self.render_video_given_triplane(
                    latent[self.latent_name],
                    self.rec_model,
                    name_prefix=f'{self.step + self.resume_step}_{eval_idx}',
                    save_img=False,
                    render_reference={'c': camera},
                    save_mesh=True)
|
|
|
|
class TrainLoop3DRecNVPatchSingleForwardMV(TrainLoop3DRecNVPatchSingleForward):
    """Multi-view variant: render novel and input views in one decoder pass.

    Cameras and crop boxes for the novel views (``nv_*``) and the input views
    are concatenated, in that order. Both halves are then supervised by a
    single ``self.loss_class`` call, which receives ``behaviour`` so it can
    distinguish generator and discriminator steps.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

    def forward_backward(self, batch, behaviour='g_step', *args, **kwargs):
        """Accumulate gradients for ``batch``.

        Args:
            batch: dict of tensors, mutated in place. The ``caption``,
                ``ins`` and ``__key__`` entries are removed if present.
            behaviour: forwarded to ``self.loss_class`` (e.g. ``'g_step'`` or
                ``'d_step'``).
        """
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        # Drop non-tensor metadata; defaults avoid KeyError when absent.
        for meta_key in ('caption', 'ins', '__key__'):
            batch.pop(meta_key, None)

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                if isinstance(v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }
            # Actual size of this microbatch. The last one may be shorter
            # than self.microbatch, so index by this instead.
            micro_bs = micro['img'].shape[0]

            # First half: novel views; second half: input views.
            nv_c = th.cat([micro['nv_c'], micro['c']])

            target = {
                **self.eg3d_model(
                    c=nv_c,
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=th.cat([micro['nv_bbox'], micro['bbox']])),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']
            # Spatial targets get a doubled, patch-sized buffer, one row per
            # rendered view.
            cropped_target = {
                k: th.empty_like(v).repeat_interleave(2, 0)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c', 'caption', 'nv_caption'
                ] else v
                for k, v in micro.items()
            }

            for j in range(2 * micro_bs):
                top, left, height, width = target['ray_bboxes'][j]
                for key in ('img', 'depth_mask', 'depth'):
                    if j < micro_bs:
                        src = micro[f'nv_{key}'][j:j + 1]
                    else:
                        src = micro[key][j - micro_bs:j - micro_bs + 1]
                    cropped_target[key][
                        j:j + 1] = torchvision.transforms.functional.crop(
                            src, top, left, height, width)

            latent = self.rec_model(img=micro['img_to_encoder'],
                                    behaviour='enc_dec_wo_triplane')

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                ray_origins = target['ray_origins']
                ray_directions = target['ray_directions']

                # NOTE(review): repeat_interleave(4) assumes the encoder
                # emits one latent per group of 4 input views; the repeat(2)
                # covers the novel + input halves. Confirm against the encoder.
                pred_nv_cano = self.rec_model(
                    latent={
                        'latent_after_vit':
                        latent['latent_after_vit'].repeat_interleave(
                            4, dim=0).repeat(2, 1, 1, 1)
                    },
                    c=nv_c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )

                pred_nv_cano.update(latent)
                gt = cropped_target

                with self.rec_model.no_sync():
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        behaviour=behaviour,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0:
                # Best-effort visualisation; a failure must not stop training.
                try:
                    save_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
                    torchvision.utils.save_image(
                        th.cat([cropped_target['img'],
                                pred_nv_cano['image_raw']]),
                        save_path,
                        normalize=True)

                    logger.log('log vis to: ', save_path)
                except Exception as e:
                    logger.log(e)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
|
|
class TrainLoop3DRecNVPatchSingleForwardMVAdvLoss(
        TrainLoop3DRecNVPatchSingleForwardMV):
    """Multi-view reconstruction loop with an adversarial loss.

    Generator (``'g_step'``) and discriminator (``'d_step'``) updates
    alternate, each on a fresh batch. The discriminator is
    ``self.loss_class.discriminator``, with its own ``MixedPrecisionTrainer``
    and ``AdamW`` optimizer.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

        disc_params = self.loss_class.get_trainable_parameters()

        self.mp_trainer_disc = MixedPrecisionTrainer(
            model=self.loss_class.discriminator,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            model_name='disc',
            use_amp=use_amp,
            model_params=disc_params)

        # betas=(0, 0.999): no first-moment momentum for the discriminator.
        self.opt_disc = AdamW(self.mp_trainer_disc.master_params,
                              lr=self.lr,
                              betas=(0, 0.999),
                              eps=1e-8)

        if self.use_ddp:
            self.ddp_disc = DDP(
                self.loss_class.discriminator,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            self.ddp_disc = self.loss_class.discriminator

    def save(self, mp_trainer=None, model_name='rec'):
        """Save ``mp_trainer``'s master params (default: the rec trainer).

        The file is written only on rank 0, as
        ``model_<model_name><step:07d>.pt`` in the blob log dir. All ranks
        then meet at a barrier.
        """
        if mp_trainer is None:
            mp_trainer = self.mp_trainer_rec

        def save_checkpoint(rate, params):
            state_dict = mp_trainer.master_params_to_state_dict(params)
            if dist_util.get_rank() == 0:
                logger.log(f"saving model {model_name} {rate}...")
                if not rate:
                    filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
                else:
                    filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename),
                                 "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, mp_trainer.master_params)

        dist.barrier()

    def _save_all(self):
        """Save both the reconstruction model and the discriminator."""
        self.save()
        self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)

    def run_step(self, batch, step='g_step'):
        """Run one generator (``'g_step'``) or discriminator (``'d_step'``) update."""
        if step == 'g_step':
            self.forward_backward(batch, behaviour='g_step')
            took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)

            if took_step_g_rec:
                self._update_ema()

        elif step == 'd_step':
            # forward_backward zeroes only the reconstruction trainer's
            # gradients. Without this reset, discriminator gradients (from
            # earlier d_steps, or deposited by the generator loss if it flows
            # through the discriminator) would accumulate into this update.
            self.mp_trainer_disc.zero_grad()
            self.forward_backward(batch, behaviour='d_step')
            _ = self.mp_trainer_disc.optimize(self.opt_disc)

        self._anneal_lr()
        self.log_step()

    def run_loop(self, batch=None):
        """Alternate generator and discriminator steps until a stop condition.

        The loop stops when ``lr_anneal_steps`` is reached or when
        ``self.step`` exceeds ``self.iterations`` (the latter calls
        ``exit()``). Both models are saved on every exit path.
        """
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):

            batch = next(self.data)
            self.run_step(batch, 'g_step')

            batch = next(self.data)
            self.run_step(batch, 'd_step')

            if self.step % 1000 == 0:
                dist_util.synchronize()
                if self.step % 10000 == 0:
                    th.cuda.empty_cache()

            if self.step % self.log_interval == 0 and dist_util.get_rank(
            ) == 0:
                out = logger.dumpkvs()
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.eval_interval == 0 and self.step != 0:
                if dist_util.get_rank() == 0:
                    # Evaluation is best-effort, but keep the traceback so
                    # failures can be diagnosed.
                    try:
                        self.eval_loop()
                    except Exception as e:
                        logger.log(e)
                        logger.log(traceback.format_exc())
                dist_util.synchronize()

            if self.step % self.save_interval == 0:
                self._save_all()
                dist_util.synchronize()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST",
                                  "") and self.step > 0:
                    return

            self.step += 1

            if self.step > self.iterations:
                logger.log('reached maximum iterations, exiting')

                # Save unless this step was already saved above. The
                # discriminator is saved too, matching the other save paths.
                if (self.step -
                        1) % self.save_interval != 0 and self.step != 1:
                    self._save_all()

                exit()

        # Final save when the anneal-step limit ends the loop.
        if (self.step - 1) % self.save_interval != 0:
            self._save_all()
|
|