| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from diffusers import StableDiffusionInpaintPipeline |
| import logging |
| import os |
| import pdb |
| import cv2 |
| import shutil |
| import json |
| from pycocotools import mask as coco_mask |
| from datetime import datetime |
| from typing import List, Union |
| import random |
| import safetensors |
| import numpy as np |
| import torch |
| from diffusers import DDPMScheduler |
| from omegaconf import OmegaConf |
| from torch.nn import Conv2d |
| from torch.nn.parameter import Parameter |
| from torch.optim import Adam |
| from torch.optim.lr_scheduler import LambdaLR |
| from torch.utils.data import DataLoader, Dataset |
| from tqdm import tqdm |
| from PIL import Image |
| |
|
|
| from diffusers.schedulers import PNDMScheduler |
| from torchvision.transforms.functional import pil_to_tensor |
| from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput |
| from src.util import metric |
| from src.util.data_loader import skip_first_batches |
| from src.util.logging_util import tb_logger, eval_dic_to_text |
| from src.util.loss import get_loss |
| from src.util.lr_scheduler import IterExponential |
| from src.util.metric import MetricTracker |
| from src.util.multi_res_noise import multi_res_noise_like |
| from src.util.alignment import align_depth_least_square, depth2disparity, disparity2depth |
| from src.util.seeding import generate_seed_sequence |
| from accelerate import Accelerator |
| import os |
| from torchvision.transforms import InterpolationMode, Resize, CenterCrop |
| import torchvision.transforms as transforms |
| |
|
|
class MarigoldInpaintTrainer:
    """Trainer for joint RGB + depth inpainting with a Marigold-style UNet.

    Each micro-step encodes the RGB image and a min-max normalised disparity
    map (predicted by ``depth_model``) into latents, noises both with a shared
    timestep, and trains the UNet to predict the RGB target in output channels
    ``0:4`` and the depth target in channels ``4:``. The UNet is conditioned on
    the inpainting masks, the masked RGB/depth latents and the prompt
    embedding. Training is driven by a HuggingFace ``Accelerator``
    (see :meth:`parallel_train`).
    """

    # Fixed samples / prompts used by ``_inpaint_rgbd`` for TensorBoard previews.
    _VIS_IMAGE_PATHS = (
        '/dataset/~sa-1b/data/sa_001000/sa_10000335.jpg',
        '/dataset/~sa-1b/data/sa_000357/sa_3572319.jpg',
        '/dataset/~sa-1b/data/sa_000045/sa_457934.jpg',
    )
    _VIS_PROMPTS = (
        'A white car is parked in front of the factory',
        'church with cemetery next to it',
        'A house with a red brick roof',
    )

    def __init__(
        self,
        cfg: "OmegaConf",
        model: "MarigoldPipeline",
        train_dataloader: DataLoader,
        device,
        base_ckpt_dir,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        depth_model=None,
        separate_list: List = None,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
        train_dataset: Dataset = None,
        timestep_method: str = 'unidiffuser',
        connection: bool = False,
    ):
        """Set up optimizer, LR schedule, loss and noise schedulers.

        Args:
            cfg: Training configuration (reads ``lr``, ``lr_scheduler.kwargs``,
                ``loss``, ``trainer.*``, ``max_epoch``, ``max_iter``,
                ``gt_depth_type``, ``gt_mask_type`` and ``multi_res_noise``).
            model: Pipeline providing ``unet``, ``encode_rgb``, ``tokenizer``,
                ``text_encoder``, ``mask_processor`` and ``depth_scheduler``.
            train_dataloader: Loader yielding dicts with ``rgb_norm``,
                ``rgb_int``, ``mask`` and ``text`` entries.
            device: Device for the cached empty-text embedding and for
                :meth:`load_checkpoint`.
            base_ckpt_dir: Not used by this class.
            out_dir_ckpt: Root directory for checkpoints.
            out_dir_eval: Stored only.
            out_dir_vis: Root directory used by :meth:`visualize`.
            accumulation_steps: Number of gradient-sync steps that make up one
                effective iteration.
            depth_model: Callable producing disparity maps; during training it
                is called as ``depth_model(uint8_rgb_array, 518, device=...)``.
            separate_list: Stored only.
            val_dataloaders: Stored only.
            vis_dataloaders: Loaders iterated by :meth:`visualize`.
            train_dataset: Stored only.
            timestep_method: ``'joint'`` or ``'partition'`` (see
                ``_select_masks``); other values raise ``ValueError`` in the
                training step.
            connection: Forwarded to the UNet as ``controlnet_connection``.
        """
        self.cfg: OmegaConf = cfg
        self.model: MarigoldPipeline = model
        self.depth_model = depth_model
        self.device = device
        self.seed: Union[int, None] = self.cfg.trainer.init_seed
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        self.val_loaders: List[DataLoader] = val_dataloaders
        self.vis_loaders: List[DataLoader] = vis_dataloaders
        self.accumulation_steps: int = accumulation_steps
        self.separate_list = separate_list
        self.timestep_method = timestep_method
        self.train_dataset = train_dataset
        self.connection = connection

        self.train_metrics = MetricTracker("loss", "rgb_loss", "depth_loss")

        # Cache the empty-prompt embedding on the training device.
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)

        self.model.unet.enable_xformers_memory_efficient_attention()

        # Only UNet parameters are optimised; the text encoder is frozen.
        self.model.text_encoder.requires_grad_(False)
        trainable = [p for p in self.model.unet.parameters() if p.requires_grad]
        self.optimizer = Adam(trainable, lr=self.cfg.lr)

        total_params = sum(p.numel() for p in self.model.unet.parameters())
        print(f"Total parameters: {total_params / 1_000_000:.2f}M")
        trainable_params = sum(p.numel() for p in trainable)
        print(f"Trainable parameters: {trainable_params / 1_000_000:.2f}M")

        lr_kwargs = self.cfg.lr_scheduler.kwargs
        lr_func = IterExponential(
            total_iter_length=lr_kwargs.total_iter,
            final_ratio=lr_kwargs.final_ratio,
            warmup_steps=lr_kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)

        # NOTE(review): both training schedulers are loaded from the *depth*
        # scheduler config path; confirm an RGB-specific path is not intended.
        scheduler_path = cfg.trainer.depth_training_noise_scheduler.pretrained_path
        self.rgb_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
            scheduler_path, subfolder="scheduler")
        self.depth_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
            scheduler_path, subfolder="scheduler")

        self.rgb_prediction_type = self.rgb_training_noise_scheduler.config.prediction_type
        self.depth_prediction_type = self.depth_training_noise_scheduler.config.prediction_type
        # Training targets must match what the inference scheduler expects.
        assert (
            self.depth_prediction_type == self.model.depth_scheduler.config.prediction_type
        ), "Different prediction types"
        self.scheduler_timesteps = self.rgb_training_noise_scheduler.config.num_train_timesteps

        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_depth_type = self.cfg.gt_depth_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period

        # NOTE(review): these multi-resolution noise settings are stored but
        # parallel_train samples plain Gaussian noise with torch.randn.
        self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
        if self.apply_multi_res_noise:
            self.mr_noise_strength = self.cfg.multi_res_noise.strength
            self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
            self.mr_noise_downscale_strategy = self.cfg.multi_res_noise.downscale_strategy

        # Training progress counters (persisted by save_miscs / save_checkpoint).
        self.epoch = 0
        self.n_batch_in_epoch = 0
        self.effective_iter = 0
        self.in_evaluation = False
        self.global_seed_sequence: List = []

    def _replace_unet_conv_in(self):
        """Double the UNet's input channels, zero-initialising the new ones.

        The pretrained ``conv_in`` weights are kept for the first half of the
        input channels and the appended channels start at zero, so initially
        the extra inputs do not change the output. Kernel size, stride and
        padding are copied from the existing layer, and ``unet.config`` is
        updated with the new channel count.
        """
        old_conv_in = self.model.unet.conv_in
        _weight = old_conv_in.weight.clone()
        _bias = old_conv_in.bias.clone()
        new_in_channels = 2 * _weight.shape[1]
        # zeros_like keeps the dtype/device of the pretrained weights.
        _weight = torch.cat([_weight, torch.zeros_like(_weight)], dim=1)

        _new_conv_in = Conv2d(
            new_in_channels,
            old_conv_in.out_channels,
            kernel_size=old_conv_in.kernel_size,
            stride=old_conv_in.stride,
            padding=old_conv_in.padding,
        )
        _new_conv_in.weight = Parameter(_weight)
        _new_conv_in.bias = Parameter(_bias)
        self.model.unet.conv_in = _new_conv_in
        logging.info("Unet conv_in layer is replaced")

        self.model.unet.config["in_channels"] = new_in_channels
        logging.info("Unet config is updated")
        return

    def parallel_train(self, t_end=None, accelerator=None):
        """Run the training loop under a HuggingFace ``Accelerator``.

        Resumes from ``<out_dir_ckpt>/latest`` when it exists, then iterates
        over epochs and batches until the epochs are exhausted or ``max_iter``
        effective iterations have been completed.

        Args:
            t_end: Unused; kept for interface compatibility.
            accelerator: ``accelerate.Accelerator`` used to prepare the model,
                optimizer, dataloader and LR scheduler and to drive gradient
                accumulation, checkpointing and process synchronisation.
        """
        logging.info("Start training")
        self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare(
            self.model, self.optimizer, self.train_loader, self.lr_scheduler
        )
        self.depth_model = accelerator.prepare(self.depth_model)
        self.accelerator = accelerator

        latest_dir = os.path.join(self.out_dir_ckpt, 'latest')
        if os.path.exists(latest_dir):
            accelerator.load_state(latest_dir)
            self.load_miscs(latest_dir)

        self.train_metrics.reset()
        accumulated_step = 0
        # NOTE(review): epochs start at 0 and the upper bound is inclusive, so
        # up to max_epoch + 1 epochs run unless max_iter stops training first.
        for epoch in range(self.epoch, self.max_epoch + 1):
            self.epoch = epoch
            logging.debug(f"epoch: {self.epoch}")

            # After a resume, skip the batches already consumed in this epoch.
            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
                self.model.unet.train()
                self._train_step(batch, accelerator)

                self.n_batch_in_epoch += 1
                # The explicit argument makes the LR follow lr_func(effective_iter).
                self.lr_scheduler.step(self.effective_iter)

                # NOTE(review): if the Accelerator itself was created with
                # gradient_accumulation_steps > 1, sync_gradients is already True
                # only once per accumulation window and this counter accumulates
                # a second time — confirm the Accelerator configuration.
                if accelerator.sync_gradients:
                    accumulated_step += 1

                if accumulated_step >= self.gradient_accumulation_steps:
                    accumulated_step = 0
                    self.effective_iter += 1
                    if self._after_effective_iter(accelerator, epoch):
                        return

                torch.cuda.empty_cache()

            self.n_batch_in_epoch = 0

    def _train_step(self, batch, accelerator):
        """Run one forward/backward/optimizer micro-step on ``batch`` and record its losses."""
        device = self.model.device

        if self.seed is not None:
            rand_num_generator = torch.Generator(device=device)
            rand_num_generator.manual_seed(self._get_next_seed())
        else:
            rand_num_generator = None

        rgb = batch["rgb_norm"].to(device)
        # Pseudo ground-truth disparity comes from depth_model, without gradients.
        with torch.no_grad():
            disparities = self.depth_model(
                batch["rgb_int"].numpy().astype(np.uint8), 518, device=device)
        if len(disparities.shape) == 2:
            disparities = disparities.unsqueeze(0)
        depth_gt_for_latent = self._normalize_disparity(disparities)

        batch_size = rgb.shape[0]
        mask = self.model.mask_processor.preprocess(batch['mask'] * 255).to(device)

        # RGB and depth share the same timesteps.
        rgb_timesteps = torch.randint(
            0,
            self.scheduler_timesteps,
            (batch_size,),
            device=device,
            generator=rand_num_generator,
        ).long()
        depth_timesteps = rgb_timesteps

        rgb_mask, depth_mask, train_rgb, train_depth = self._select_masks(mask)

        # Blank out the regions to inpaint (mask values >= 0.5).
        masked_rgb = rgb * (rgb_mask < 0.5)
        masked_depth = depth_gt_for_latent * (depth_mask.squeeze() < 0.5)

        with torch.no_grad():
            rgb_latent = self.model.encode_rgb(rgb)
            mask_rgb_latent = self.model.encode_rgb(masked_rgb)
            # If every sample has t == 0, the depth target is the masked depth.
            if depth_timesteps.sum() == 0:
                gt_depth_latent = self.encode_depth(masked_depth)
            else:
                gt_depth_latent = self.encode_depth(depth_gt_for_latent)
            mask_depth_latent = self.encode_depth(masked_depth)

        # Bring the masks down to latent resolution.
        rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
        depth_mask = torch.nn.functional.interpolate(depth_mask, size=gt_depth_latent.shape[-2:])

        rgb_noise = torch.randn(rgb_latent.shape, device=device, generator=rand_num_generator)
        depth_noise = torch.randn(gt_depth_latent.shape, device=device, generator=rand_num_generator)

        if rgb_timesteps.sum() == 0:
            noisy_rgb_latents = rgb_latent
        else:
            noisy_rgb_latents = self.rgb_training_noise_scheduler.add_noise(
                rgb_latent, rgb_noise, rgb_timesteps)
        if depth_timesteps.sum() == 0:
            noisy_depth_latents = gt_depth_latent
        else:
            noisy_depth_latents = self.depth_training_noise_scheduler.add_noise(
                gt_depth_latent, depth_noise, depth_timesteps)

        noisy_latents = torch.cat(
            [noisy_rgb_latents, rgb_mask, mask_rgb_latent, mask_depth_latent,
             noisy_depth_latents, depth_mask, mask_rgb_latent, mask_depth_latent],
            dim=1,
        ).float()

        # The text encoder is frozen, so no graph is needed for the prompt.
        with torch.no_grad():
            input_ids = self.model.tokenizer(
                batch['text'],
                padding="max_length",
                max_length=self.model.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = {k: v.to(device) for k, v in input_ids.items()}
            text_embed = self.model.text_encoder(**input_ids)[0]

        # The forward pass must run inside accumulate() so gradient sync is
        # skipped correctly on non-boundary micro-steps.
        with accelerator.accumulate(self.model):
            model_pred = self.model.unet(
                noisy_latents, rgb_timesteps, depth_timesteps, text_embed,
                controlnet_connection=self.connection,
            ).sample
            if torch.isnan(model_pred).any():
                logging.warning("model_pred contains NaN.")

            rgb_target = self._prediction_target(
                self.rgb_prediction_type, self.rgb_training_noise_scheduler,
                rgb_latent, rgb_noise, rgb_timesteps, "rgb")
            depth_target = self._prediction_target(
                self.depth_prediction_type, self.depth_training_noise_scheduler,
                gt_depth_latent, depth_noise, depth_timesteps, "depth")

            # Output channels 0:4 predict RGB, the remainder predicts depth.
            rgb_loss = self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float())
            depth_loss = self.loss(model_pred[:, 4:, :, :].float(), depth_target.float())

            if not train_rgb:
                loss = depth_loss
            elif not train_depth:
                loss = rgb_loss
            else:
                depth_factor = self.cfg.loss.depth_factor
                loss = depth_factor * depth_loss + (1 - depth_factor) * rgb_loss

            self.train_metrics.update("loss", loss.item())
            self.train_metrics.update("rgb_loss", rgb_loss.item())
            self.train_metrics.update("depth_loss", depth_loss.item())

            accelerator.backward(loss)
            self.optimizer.step()
            self.optimizer.zero_grad()

    @staticmethod
    def _normalize_disparity(disparities, eps=1e-6):
        """Min-max normalise each sample (first dimension) of ``disparities`` to ``[-1, 1]``.

        The range is clamped to at least ``eps`` so a constant map yields -1
        instead of NaN.
        """
        reduce_dims = tuple(range(1, disparities.dim()))
        d_min = disparities.amin(dim=reduce_dims, keepdim=True)
        d_max = disparities.amax(dim=reduce_dims, keepdim=True)
        return ((disparities - d_min) / (d_max - d_min).clamp(min=eps)) * 2 - 1

    def _select_masks(self, mask):
        """Choose the RGB and depth inpainting masks for this step.

        Returns:
            ``(rgb_mask, depth_mask, train_rgb, train_depth)``; the booleans say
            whether the RGB / depth loss contributes to the total loss.

        Raises:
            ValueError: if ``timestep_method`` is not ``'joint'`` or ``'partition'``.
        """
        if self.timestep_method == 'joint':
            return mask, mask, True, True
        if self.timestep_method == 'partition':
            rand_num = random.random()
            if rand_num < 0.5:
                # Inpaint and supervise both modalities.
                return mask, mask, True, True
            if rand_num < 0.75:
                # RGB is fully visible; only depth is inpainted and supervised.
                return torch.zeros_like(mask), mask, False, True
            # Only the RGB loss is used; depth is either fully visible or masked.
            depth_mask = torch.zeros_like(mask) if random.random() < 0.5 else mask
            return mask, depth_mask, True, False
        raise ValueError(
            f"Unsupported timestep_method {self.timestep_method!r}; "
            "expected 'joint' or 'partition'"
        )

    @staticmethod
    def _prediction_target(prediction_type, noise_scheduler, clean_latent, noise, timesteps, name):
        """Return the UNet regression target for one modality.

        ``"sample"`` regresses the clean latent, ``"epsilon"`` the added noise and
        ``"v_prediction"`` the scheduler's velocity.

        Raises:
            ValueError: for any other ``prediction_type``.
        """
        if prediction_type == "sample":
            return clean_latent
        if prediction_type == "epsilon":
            return noise
        if prediction_type == "v_prediction":
            return noise_scheduler.get_velocity(clean_latent, noise, timesteps)
        raise ValueError(f"Unknown {name} prediction type {prediction_type}")

    def _after_effective_iter(self, accelerator, epoch):
        """Log, checkpoint and check for termination after an effective iteration.

        Returns:
            True when ``max_iter`` has been reached and training should stop.
        """
        if accelerator.is_main_process:
            if self.effective_iter == 1:
                self._inpaint_rgbd()
            self._log_train_metrics(epoch)
        # Reset so the next logged values cover only the following effective
        # iteration (assumes result() aggregates since the last reset()).
        self.train_metrics.reset()
        accelerator.wait_for_everyone()

        latest_dir = os.path.join(self.out_dir_ckpt, 'latest')
        if self.save_period > 0 and self.effective_iter % self.save_period == 0:
            # save_state runs on every process; the extra files on rank 0 only.
            accelerator.save_state(output_dir=latest_dir)
            unwrapped_model = accelerator.unwrap_model(self.model)
            if accelerator.is_main_process:
                accelerator.save_model(unwrapped_model.unet, latest_dir, safe_serialization=False)
                self.save_miscs('latest')
                self._inpaint_rgbd()
            accelerator.wait_for_everyone()

        if self.backup_period > 0 and self.effective_iter % self.backup_period == 0:
            unwrapped_model = accelerator.unwrap_model(self.model)
            if accelerator.is_main_process:
                accelerator.save_model(
                    unwrapped_model.unet,
                    os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()),
                    safe_serialization=False,
                )
            accelerator.wait_for_everyone()

        if self.max_iter > 0 and self.effective_iter >= self.max_iter:
            unwrapped_model = accelerator.unwrap_model(self.model)
            if accelerator.is_main_process:
                unwrapped_model.unet.save_pretrained(
                    os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()))
            accelerator.wait_for_everyone()
            return True
        return False

    def _log_train_metrics(self, epoch):
        """Write the training metrics, LR and batch counter to TensorBoard and the log."""
        results = self.train_metrics.result()
        tb_logger.log_dic(
            {f"train/{k}": v for k, v in results.items()},
            global_step=self.effective_iter,
        )
        tb_logger.writer.add_scalar(
            "lr", self.lr_scheduler.get_last_lr()[0], global_step=self.effective_iter)
        tb_logger.writer.add_scalar(
            "n_batch_in_epoch", self.n_batch_in_epoch, global_step=self.effective_iter)
        logging.info(
            f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={results['loss']:.5f}, "
            f"rgb_loss={results['rgb_loss']:.5f}, depth_loss={results['depth_loss']:.5f}"
        )

    def _inpaint_rgbd(self, image_paths=None, prompts=None):
        """Log joint RGB-D inpainting previews to TensorBoard.

        For each image, a disparity map from ``self.depth_model`` and a random
        multi-object mask (from the ``.json`` file next to the image) are
        resized (shorter side 512, nearest) and centre-cropped to 512x512,
        then passed to ``self.model._rgbd_inpaint(..., mode='joint_inpaint')``.
        The result is logged under the prompt text at ``self.effective_iter``.

        Args:
            image_paths: Image files to inpaint; defaults to ``_VIS_IMAGE_PATHS``.
            prompts: One prompt per image; defaults to ``_VIS_PROMPTS``.
        """
        if image_paths is None:
            image_paths = self._VIS_IMAGE_PATHS
        if prompts is None:
            prompts = self._VIS_PROMPTS

        resize_transform = transforms.Compose([
            Resize(size=512, interpolation=InterpolationMode.NEAREST_EXACT),
            CenterCrop(size=[512, 512]),
        ])

        with torch.no_grad():
            imgs = []
            for path in image_paths:
                with Image.open(path) as pil_img:
                    imgs.append(pil_to_tensor(pil_img))
            depth_imgs = [self.depth_model(img.unsqueeze(0).cpu().numpy()) for img in imgs]
            masks = [self._random_object_mask(path) for path in image_paths]

            imgs = [resize_transform(img) for img in imgs]
            depth_imgs = [resize_transform(depth_img.unsqueeze(0)) for depth_img in depth_imgs]
            masks = [resize_transform(mask.unsqueeze(0)) for mask in masks]

            for img, depth_img, mask, prompt in zip(imgs, depth_imgs, masks, prompts):
                output_image = self.model._rgbd_inpaint(
                    img, depth_img, mask, [prompt], processing_res=512, mode='joint_inpaint')
                tb_logger.writer.add_image(f'{prompt}', pil_to_tensor(output_image), self.effective_iter)

    @staticmethod
    def _random_object_mask(image_path):
        """Return the union of 5-10 randomly chosen annotated objects as a 0/1 uint8 tensor.

        Annotations are read from the ``.json`` file next to the ``.jpg`` and
        decoded with ``pycocotools.mask.decode``.
        """
        with open(image_path.replace('.jpg', '.json')) as f:
            anno = json.load(f)['annotations']
        random.shuffle(anno)
        object_num = random.randint(5, 10)
        # Logical OR keeps the mask binary; summing double-counted anno[0]
        # and overlapping objects.
        mask = np.array(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8)
        for single_anno in anno[1:object_num]:
            mask |= np.array(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8)
        return torch.from_numpy(mask)

    def encode_depth(self, depth_in):
        """Encode a depth map by replicating it to 3 channels and calling ``model.encode_rgb``."""
        return self.model.encode_rgb(self.stack_depth_images(depth_in))

    @staticmethod
    def stack_depth_images(depth_in):
        """Return ``depth_in`` as a 4-D tensor with its channel dimension repeated 3x.

        ``(H, W)`` becomes ``(1, 3, H, W)`` and ``(B, H, W)`` becomes
        ``(B, 3, H, W)``. 4-D input is repeated along dim 1 (it is expected to
        have a single channel).

        Raises:
            ValueError: for tensors that are not 2-, 3- or 4-D.
        """
        ndim = len(depth_in.shape)
        if ndim == 2:
            depth_in = depth_in.unsqueeze(0).unsqueeze(0)
        elif ndim == 3:
            depth_in = depth_in.unsqueeze(1)
        elif ndim != 4:
            raise ValueError(
                f"Expected a 2-, 3- or 4-D depth tensor, got shape {tuple(depth_in.shape)}")
        return depth_in.repeat(1, 3, 1, 1)

    def visualize(self):
        """Run ``validate_single_dataset`` on each visualisation loader, saving outputs.

        Outputs go to ``<out_dir_vis>/<iter_name>/<dataset disp_name>``.
        NOTE(review): this relies on ``self.val_metrics`` and
        ``self.validate_single_dataset``, which this class does not define as
        shown — confirm they are provided elsewhere before calling.
        """
        for val_loader in self.vis_loaders or ():
            vis_dataset_name = val_loader.dataset.disp_name
            vis_out_dir = os.path.join(
                self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
            )
            os.makedirs(vis_out_dir, exist_ok=True)
            _ = self.validate_single_dataset(
                data_loader=val_loader,
                metric_tracker=self.val_metrics,
                save_to_dir=vis_out_dir,
            )

    def _get_next_seed(self):
        """Pop the next per-step seed, (re)generating the sequence when it is empty.

        The sequence length is ``max_iter * gradient_accumulation_steps``.
        NOTE(review): with ``max_iter <= 0`` that length is not positive, so the
        sequence would presumably be empty and ``pop`` would raise — confirm
        ``max_iter`` is set whenever ``init_seed`` is.
        """
        if not self.global_seed_sequence:
            self.global_seed_sequence = generate_seed_sequence(
                initial_seed=self.seed,
                length=self.max_iter * self.gradient_accumulation_steps,
            )
            logging.info(
                f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
            )
        return self.global_seed_sequence.pop()

    @staticmethod
    def _load_trainer_state(path):
        """Load a ``trainer.ckpt`` written by this class.

        The state holds non-tensor objects (e.g. the OmegaConf config), which
        torch>=2.6 refuses to unpickle with its default ``weights_only=True``.
        These files are produced by this trainer; never point this at
        untrusted files, since full unpickling can execute code.
        """
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older torch versions do not accept ``weights_only``.
            return torch.load(path)

    def save_miscs(self, ckpt_name):
        """Save config, progress counters and the seed sequence to ``<ckpt_name>/trainer.ckpt``."""
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        os.makedirs(ckpt_dir, exist_ok=True)
        state = {
            "config": self.cfg,
            "effective_iter": self.effective_iter,
            "epoch": self.epoch,
            "n_batch_in_epoch": self.n_batch_in_epoch,
            "global_seed_sequence": self.global_seed_sequence,
        }
        train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
        torch.save(state, train_state_path)
        logging.info(f"Misc state is saved to: {train_state_path}")

    def load_miscs(self, ckpt_path):
        """Restore progress counters and the seed sequence saved by :meth:`save_miscs`."""
        checkpoint = self._load_trainer_state(os.path.join(ckpt_path, "trainer.ckpt"))
        self.effective_iter = checkpoint["effective_iter"]
        self.epoch = checkpoint["epoch"]
        self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
        self.global_seed_sequence = checkpoint["global_seed_sequence"]
        logging.info(f"Misc state is loaded from {ckpt_path}")

    def save_checkpoint(self, ckpt_name, save_train_state):
        """Save the UNet and optionally the full trainer state to ``out_dir_ckpt/ckpt_name``.

        An existing directory of the same name is first renamed to
        ``_old_<name>`` and removed only after the new checkpoint is written.
        The trainer state includes the optimizer and LR-scheduler state
        required by :meth:`load_checkpoint`.
        """
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        logging.info(f"Saving checkpoint to: {ckpt_dir}")

        temp_ckpt_dir = None
        if os.path.isdir(ckpt_dir):
            temp_ckpt_dir = os.path.join(
                os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
            )
            if os.path.exists(temp_ckpt_dir):
                shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            os.rename(ckpt_dir, temp_ckpt_dir)
            logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

        unet_path = os.path.join(ckpt_dir, "unet")
        self.model.unet.save_pretrained(unet_path, safe_serialization=False)
        logging.info(f"UNet is saved to: {unet_path}")

        if save_train_state:
            state = {
                "config": self.cfg,
                "effective_iter": self.effective_iter,
                "epoch": self.epoch,
                "n_batch_in_epoch": self.n_batch_in_epoch,
                # best_metric is never initialised in __init__; default to None.
                "best_metric": getattr(self, "best_metric", None),
                "in_evaluation": self.in_evaluation,
                "global_seed_sequence": self.global_seed_sequence,
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
            }
            train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
            torch.save(state, train_state_path)
            # Empty marker file named after the checkpoint's iteration.
            with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
                pass
            logging.info(f"Trainer state is saved to: {train_state_path}")

        if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            logging.debug("Old checkpoint backup is removed.")

    def load_checkpoint(
        self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
    ):
        """Load UNet weights and optionally trainer, optimizer and LR-scheduler state.

        Args:
            ckpt_path: Directory written by :meth:`save_checkpoint`.
            load_trainer_state: Also restore counters, seeds and optimizer state.
            resume_lr_scheduler: Also restore the LR scheduler (only when
                ``load_trainer_state`` is true).
        """
        logging.info(f"Loading checkpoint from: {ckpt_path}")
        _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
        self.model.unet.load_state_dict(
            torch.load(_model_path, map_location=self.device)
        )
        self.model.unet.to(self.device)
        logging.info(f"UNet parameters are loaded from {_model_path}")

        if load_trainer_state:
            checkpoint = self._load_trainer_state(os.path.join(ckpt_path, "trainer.ckpt"))
            self.effective_iter = checkpoint["effective_iter"]
            self.epoch = checkpoint["epoch"]
            self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
            self.in_evaluation = checkpoint["in_evaluation"]
            self.global_seed_sequence = checkpoint["global_seed_sequence"]
            self.best_metric = checkpoint.get("best_metric")

            self.optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info(f"optimizer state is loaded from {ckpt_path}")

            if resume_lr_scheduler:
                self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                logging.info(f"LR scheduler state is loaded from {ckpt_path}")

        logging.info(
            f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
        )
        return

    def _get_backup_ckpt_name(self):
        """Return the checkpoint name for the current iteration, e.g. ``iter_000042``."""
        return f"iter_{self.effective_iter:06d}"