Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
# An official reimplemented version of Marigold training script.
# Last modified: 2024-04-29
#
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
| import logging | |
| import os | |
| import shutil | |
| from datetime import datetime | |
| from typing import List, Union | |
| import numpy as np | |
| import torch | |
| # from diffusers import DDPMScheduler | |
| from omegaconf import OmegaConf | |
| # from torch.nn import Conv2d | |
| # from torch.nn.parameter import Parameter | |
| from torch.optim import Adam | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from PIL import Image | |
| from depthmaster.depthmaster_pipeline import DepthMasterPipeline, DepthMasterDepthOutput | |
| from src.util import metric | |
| from src.util.data_loader import skip_first_batches | |
| from src.util.logging_util import tb_logger, eval_dic_to_text | |
| from src.util.loss import get_loss | |
| from src.util.lr_scheduler import IterExponential | |
| from src.util.metric import MetricTracker | |
| from src.util.alignment import ( | |
| align_depth_least_square, | |
| depth2disparity, | |
| disparity2depth, | |
| align_depth_least_square_torch_mask, | |
| align_depth_medium_mask | |
| ) | |
| # from src.util.alignment import align_depth_least_square | |
| # from src.util.alignment import align_depth_least_square | |
| from src.util.seeding import generate_seed_sequence | |
| import torch.nn.functional as F | |
class DepthMasterTrainerS2:
    """Stage-2 trainer for a DepthMasterPipeline.

    Fine-tunes only the UNet (VAE and text encoder are frozen) with a
    pixel-space depth loss plus a weighted gradient-matching loss, using
    manual gradient accumulation. Checkpointing, validation, and
    visualization are driven by the periods configured in `cfg.trainer`.
    """

    def __init__(
        self,
        cfg: OmegaConf,
        model: DepthMasterPipeline,
        train_dataloader: DataLoader,
        device,
        base_ckpt_dir,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
    ):
        """
        Args:
            cfg: full training configuration (OmegaConf tree).
            model: pipeline whose UNet will be trained.
            train_dataloader: loader over the training set.
            device: torch device used for training.
            base_ckpt_dir: base checkpoint directory (accepted but not
                referenced in this constructor).
            out_dir_ckpt: directory where checkpoints are written.
            out_dir_eval: directory where evaluation text files are written.
            out_dir_vis: directory where visualization outputs are written.
            accumulation_steps: number of batches per optimizer step.
            val_dataloaders: loaders used by `validate()`.
            vis_dataloaders: loaders used by `visualize()`.
        """
        self.cfg: OmegaConf = cfg
        self.model: DepthMasterPipeline = model
        self.device = device
        self.seed: Union[int, None] = (
            self.cfg.trainer.init_seed
        )  # used to generate seed sequence, set to `None` to train w/o seeding
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        self.val_loaders: List[DataLoader] = val_dataloaders
        self.vis_loaders: List[DataLoader] = vis_dataloaders
        self.accumulation_steps: int = accumulation_steps

        # Encode empty text prompt once; the detached copy is reused for
        # every training batch.
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)

        self.model.unet.enable_xformers_memory_efficient_attention()

        # Trainability: only the UNet receives gradients.
        self.model.vae.requires_grad_(False)
        self.model.vae.decoder.requires_grad_(False)
        self.model.text_encoder.requires_grad_(False)
        self.model.unet.requires_grad_(True)

        # Optimizer !should be defined after input layer is adapted
        lr = self.cfg.lr
        self.optimizer = Adam(self.model.unet.parameters(), lr=lr)

        # LR scheduler: per-iteration exponential schedule with warmup.
        lr_func = IterExponential(
            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

        # Losses: main depth loss and auxiliary gradient loss, both resolved
        # by name through the project loss factory.
        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
        self.grad_loss = get_loss(loss_name=self.cfg.grad_loss.name, **self.cfg.grad_loss.kwargs)

        # Eval metrics: resolved by name from src.util.metric.
        self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]

        self.train_metrics = MetricTracker(*["loss", "grad_loss"])
        self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])

        # main metric for best checkpoint saving
        self.main_val_metric = cfg.validation.main_val_metric
        self.main_val_metric_goal = cfg.validation.main_val_metric_goal
        assert (
            self.main_val_metric in cfg.eval.eval_metrics
        ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
        # Initialized so that the first validation result always improves it.
        self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8

        # Settings
        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_depth_type = self.cfg.gt_depth_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period

        # Internal variables
        self.epoch = 1
        self.n_batch_in_epoch = 0  # batch index in the epoch, used when resume training
        self.effective_iter = 0  # how many times optimizer.step() is called
        self.in_evaluation = False
        self.global_seed_sequence: List = []  # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming
| def grad(self, x): | |
| # x.shape : n, c, h, w | |
| diff_x = x[..., 1:, 1:] - x[..., 1:, :-1] | |
| diff_y = x[..., 1:, 1:] - x[..., :-1, 1:] | |
| diff_45 = x[..., :-1, 1:] - x[..., 1:, :-1] | |
| diff_135 = x[..., 1:, 1:] - x[..., :-1, :-1] | |
| # mag = diff_x**2 + diff_y**2 | |
| # # angle_ratio | |
| # angle = torch.atan(diff_y / (diff_x + 1e-10)) | |
| # result = torch.cat([mag, angle], dim=1) | |
| result = torch.cat([diff_x, diff_y, diff_45, diff_135], dim=1) | |
| return result | |
    def train(self, t_end=None):
        """Main training loop.

        Iterates epochs over the training loader with manual gradient
        accumulation; every `gradient_accumulation_steps` batches one
        optimizer/scheduler step (an "effective iteration") is taken,
        metrics are logged to tensorboard, and the per-step callback runs.
        Stops with a backup checkpoint when `max_iter` effective iterations
        are reached, or pauses with a resumable "latest" checkpoint when
        wall-clock `t_end` is passed.

        Args:
            t_end: optional datetime; when `datetime.now()` reaches it,
                training pauses and state is saved.
        """
        logging.info("Start training")

        device = self.device
        self.model.to(device)

        self.visualize()

        if self.in_evaluation:
            logging.info(
                "Last evaluation was not finished, will do evaluation before continue training."
            )
            self.validate()

        self.train_metrics.reset()
        accumulated_step = 0

        progress_bar = tqdm(
            range(0, self.max_iter),
            initial=self.effective_iter,
            desc="iter"
        )
        for epoch in range(self.epoch, self.max_epoch + 1):
            self.epoch = epoch
            logging.debug(f"epoch: {self.epoch}")

            # Skip previous batches when resume
            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
                self.model.unet.train()

                # >>> With gradient accumulation >>>

                # Get data
                rgb = batch["rgb_norm"].to(device)
                depth_gt_for_latent = batch[self.gt_depth_type].to(device)

                if self.gt_mask_type is not None:
                    valid_mask_for_latent = batch[self.gt_mask_type].to(device)
                else:
                    # Training without a validity mask is not supported.
                    raise NotImplementedError

                batch_size = rgb.shape[0]

                with torch.no_grad():
                    # Encode image
                    rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]
                    # Text embedding
                    text_embed = self.empty_text_embed.to(device).repeat(
                        (batch_size, 1, 1)
                    )  # [B, 77, 1024]

                # Single UNet forward pass with the literal timestep 1;
                # gradients flow from here on.
                rgb_latent = self.model.unet(
                    rgb_latent, 1, text_embed
                ).sample  # [B, 4, h, w]

                depth_pred = self.model.decode_depth(rgb_latent)

                depth_gt_for_loss = depth_gt_for_latent
                aligned_pred = depth_pred
                if self.gt_mask_type is not None:
                    loss = self.loss(aligned_pred[valid_mask_for_latent].float(), depth_gt_for_loss[valid_mask_for_latent].float()).mean()
                else:
                    loss = self.loss(aligned_pred.float(), depth_gt_for_loss.float()).mean()
                self.train_metrics.update("loss", loss.item())

                # grad loss: zero out invalid pixels before differencing.
                # NOTE(review): these in-place writes also mutate
                # depth_gt_for_latent / depth_pred, since depth_gt_for_loss
                # and aligned_pred are aliases, not copies.
                depth_gt_for_loss[~valid_mask_for_latent] = 0
                grad_gt = self.grad(depth_gt_for_loss)
                aligned_pred[~valid_mask_for_latent] = 0
                grad_pred = self.grad(aligned_pred)
                grad_loss = self.grad_loss(grad_gt, grad_pred)
                self.train_metrics.update(f"grad_loss", grad_loss.item())
                loss += self.cfg.grad_loss.lamda * grad_loss

                # Scale so the accumulated gradient matches one full-batch step.
                loss = loss / self.gradient_accumulation_steps
                loss.backward()
                accumulated_step += 1

                self.n_batch_in_epoch += 1
                # Practical batch end

                # Perform optimization step
                if accumulated_step >= self.gradient_accumulation_steps:
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                    accumulated_step = 0

                    self.effective_iter += 1
                    progress_bar.update(1)

                    # Log to tensorboard
                    accumulated_loss = self.train_metrics.result()["loss"]
                    logs = {"loss": accumulated_loss}
                    progress_bar.set_postfix(**logs)
                    tb_logger.log_dic(
                        {
                            f"train/{k}": v
                            for k, v in self.train_metrics.result().items()
                        },
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "lr",
                        self.lr_scheduler.get_last_lr()[0],
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "n_batch_in_epoch",
                        self.n_batch_in_epoch,
                        global_step=self.effective_iter,
                    )
                    self.train_metrics.reset()

                    # Per-step callback
                    self._train_step_callback()

                    # End of training
                    if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                        self.save_checkpoint(
                            ckpt_name=self._get_backup_ckpt_name(),
                            save_train_state=False,
                        )
                        logging.info("Training ended.")
                        return
                    # Time's up
                    elif t_end is not None and datetime.now() >= t_end:
                        self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                        logging.info("Time is up, training paused.")
                        return

                    torch.cuda.empty_cache()
                    # <<< Effective batch end <<<

            # Epoch end
            self.n_batch_in_epoch = 0
    def encode_depth(self, depth_in):
        """Encode a depth map into latent space.

        The single-channel depth is replicated to 3 channels so it can be
        passed through the pipeline's RGB VAE encoder.

        Args:
            depth_in: depth tensor accepted by `stack_depth_images`.

        Returns:
            Latent tensor produced by `self.model.encode_rgb`.
        """
        # stack depth into 3-channel
        stacked = self.stack_depth_images(depth_in)
        # encode using VAE encoder
        depth_latent = self.model.encode_rgb(stacked)
        return depth_latent
| def stack_depth_images(depth_in): | |
| if 4 == len(depth_in.shape): | |
| stacked = depth_in.repeat(1, 3, 1, 1) | |
| elif 3 == len(depth_in.shape): | |
| stacked = depth_in.unsqueeze(1) | |
| stacked = depth_in.repeat(1, 3, 1, 1) | |
| return stacked | |
| def _train_step_callback(self): | |
| """Executed after every iteration""" | |
| # Save backup (with a larger interval, without training states) | |
| if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: | |
| self.save_checkpoint( | |
| ckpt_name=self._get_backup_ckpt_name(), save_train_state=False | |
| ) | |
| _is_latest_saved = False | |
| # Validation | |
| if self.val_period > 0 and 0 == self.effective_iter % self.val_period: | |
| self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished | |
| self.save_checkpoint(ckpt_name="latest", save_train_state=True) | |
| _is_latest_saved = True | |
| self.validate() | |
| self.in_evaluation = False | |
| self.save_checkpoint(ckpt_name="latest", save_train_state=True) | |
| # Save training checkpoint (can be resumed) | |
| if ( | |
| self.save_period > 0 | |
| and 0 == self.effective_iter % self.save_period | |
| and not _is_latest_saved | |
| ): | |
| self.save_checkpoint(ckpt_name="latest", save_train_state=True) | |
| # Visualization | |
| if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period: | |
| self.visualize() | |
    def validate(self):
        """Evaluate on every validation loader and persist the results.

        For each loader: runs `validate_single_dataset`, logs metrics to
        the standard logger and tensorboard, and writes them to a text
        file under `out_dir_eval`. Only the first loader (i == 0) drives
        best-metric tracking; when the main metric improves, a backup
        checkpoint (without training state) is saved.
        """
        for i, val_loader in enumerate(self.val_loaders):
            val_dataset_name = val_loader.dataset.disp_name
            val_metric_dic = self.validate_single_dataset(
                data_loader=val_loader, metric_tracker=self.val_metrics
            )
            logging.info(
                f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
            )
            tb_logger.log_dic(
                {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
                global_step=self.effective_iter,
            )
            # save to file
            eval_text = eval_dic_to_text(
                val_metrics=val_metric_dic,
                dataset_name=val_dataset_name,
                sample_list_path=val_loader.dataset.filename_ls_path,
            )
            _save_to = os.path.join(
                self.out_dir_eval,
                f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
            )
            with open(_save_to, "w+") as f:
                f.write(eval_text)

            # Update main eval metric — first loader only.
            if 0 == i:
                main_eval_metric = val_metric_dic[self.main_val_metric]
                if (
                    "minimize" == self.main_val_metric_goal
                    and main_eval_metric < self.best_metric
                    or "maximize" == self.main_val_metric_goal
                    and main_eval_metric > self.best_metric
                ):
                    self.best_metric = main_eval_metric
                    logging.info(
                        f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
                    )
                    # Save a checkpoint
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
                    )
| def visualize(self): | |
| for val_loader in self.vis_loaders: | |
| vis_dataset_name = val_loader.dataset.disp_name | |
| vis_out_dir = os.path.join( | |
| self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name | |
| ) | |
| os.makedirs(vis_out_dir, exist_ok=True) | |
| _ = self.validate_single_dataset( | |
| data_loader=val_loader, | |
| metric_tracker=self.val_metrics, | |
| save_to_dir=vis_out_dir, | |
| ) | |
    def validate_single_dataset(
        self,
        data_loader: DataLoader,
        metric_tracker: MetricTracker,
        save_to_dir: str = None,
    ):
        """Run inference over one dataset, align predictions to GT, and
        accumulate per-sample metrics.

        Args:
            data_loader: loader with batch_size 1 (asserted below).
            metric_tracker: tracker updated with every per-sample metric.
            save_to_dir: if given, also save each prediction as a 16-bit PNG.

        Returns:
            Dict of averaged metric values from `metric_tracker.result()`.
        """
        self.model.to(self.device)
        metric_tracker.reset()

        # Generate seed sequence for consistent evaluation
        val_init_seed = self.cfg.validation.init_seed
        # NOTE(review): val_seed_ls is generated but never consumed in this
        # method — confirm whether per-sample seeding was intended.
        val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))

        for i, batch in enumerate(
            tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
            start=1,
        ):
            assert 1 == data_loader.batch_size
            # Read input image
            rgb_int = batch["rgb_int"]  # [3, H, W]
            # GT depth
            depth_raw_ts = batch["depth_raw_linear"].squeeze()
            depth_raw = depth_raw_ts.numpy()
            depth_raw_ts = depth_raw_ts.to(self.device)
            valid_mask_ts = batch["valid_mask_raw"].squeeze()
            valid_mask = valid_mask_ts.numpy()
            valid_mask_ts = valid_mask_ts.to(self.device)

            # Predict depth
            pipe_out: DepthMasterDepthOutput = self.model(
                rgb_int,
                processing_res=self.cfg.validation.processing_res,
                match_input_res=self.cfg.validation.match_input_res,
                batch_size=1,  # use batch size 1 to increase reproducibility
                color_map=None,
                show_progress_bar=False,
                resample_method=self.cfg.validation.resample_method,
            )

            depth_pred: np.ndarray = pipe_out.depth_np.squeeze()

            # Align the affine-invariant prediction to GT (scale & shift),
            # according to the configured alignment strategy.
            if "least_square" == self.cfg.eval.alignment:
                # Least-squares alignment directly in depth space.
                depth_pred, scale, shift = align_depth_least_square(
                    gt_arr=depth_raw,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_mask,
                    return_scale_shift=True,
                    max_resolution=self.cfg.eval.align_max_res,
                )
            elif "least_square_disparity" == self.cfg.eval.alignment:
                gt_disparity = depth2disparity(depth_raw)
                gt_non_neg_mask = gt_disparity > 0
                # LS alignment in disparity space
                pred_non_neg_mask = depth_pred > 0
                valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask

                disparity_pred, scale, shift = align_depth_least_square(
                    gt_arr=gt_disparity,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_nonnegative_mask,
                    return_scale_shift=True,
                )
                # convert to depth
                disparity_pred = np.clip(
                    disparity_pred, a_min=1e-3, a_max=None
                )  # avoid 0 disparity
                depth_pred = disparity2depth(disparity_pred)
                # NOTE(review): the GT tensor is also passed through
                # disparity2depth here — verify the intended evaluation space.
                depth_raw_ts = disparity2depth(depth_raw_ts)
            elif "least_square_sqrt_disp" == self.cfg.eval.alignment:
                # Least-squares alignment in sqrt-disparity space.
                gt_sqrt_disp = np.sqrt(depth2disparity(depth_raw))
                gt_non_neg_mask = gt_sqrt_disp > 0
                # LS alignment in sqrt space
                pred_non_neg_mask = depth_pred > 0
                # NOTE(review): valid_nonnegative_mask is computed but the
                # alignment call below uses valid_mask instead — confirm
                # whether that is intended.
                valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask

                depth_sqrt_disp_pred, scale, shift = align_depth_least_square(
                    gt_arr=gt_sqrt_disp,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_mask,
                    return_scale_shift=True,
                )
                # convert to depth
                disparity_pred = depth_sqrt_disp_pred ** 2
                # NOTE(review): squares the GT depth tensor before the
                # reciprocal conversion below — verify the intended space.
                depth_raw_ts = torch.pow(depth_raw_ts, 2)
                # convert to depth
                disparity_pred = np.clip(
                    disparity_pred, a_min=1e-3, a_max=None
                )  # avoid 0 disparity
                depth_pred = disparity2depth(disparity_pred)
                depth_raw_ts = disparity2depth(depth_raw_ts)
            else:
                raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")

            # Clip to dataset min max
            depth_pred = np.clip(
                depth_pred,
                a_min=data_loader.dataset.min_depth,
                a_max=data_loader.dataset.max_depth,
            )

            # clip to d > 0 for evaluation
            depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)

            # Evaluate
            sample_metric = []
            depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)

            for met_func in self.metric_funcs:
                _metric_name = met_func.__name__
                _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
                sample_metric.append(_metric.__str__())
                metric_tracker.update(_metric_name, _metric)

            # Save as 16-bit uint png
            if save_to_dir is not None:
                img_name = batch["rgb_relative_path"][0].replace("/", "_")
                png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
                depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
                Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

        return metric_tracker.result()
| def _get_next_seed(self): | |
| if 0 == len(self.global_seed_sequence): | |
| self.global_seed_sequence = generate_seed_sequence( | |
| initial_seed=self.seed, | |
| length=self.max_iter * self.gradient_accumulation_steps, | |
| ) | |
| logging.info( | |
| f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" | |
| ) | |
| return self.global_seed_sequence.pop() | |
    def save_checkpoint(self, ckpt_name, save_train_state):
        """Save the UNet (and optionally trainer state) under
        `out_dir_ckpt/ckpt_name`, safely replacing any previous checkpoint.

        An existing directory is first renamed to `_old_<name>` and only
        deleted after the new checkpoint is fully written, so a crash
        mid-save never leaves the run without a usable checkpoint.

        Args:
            ckpt_name: name of the checkpoint directory to create.
            save_train_state: if True, also save optimizer/scheduler state
                and resume counters to `trainer.ckpt`.
        """
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        logging.info(f"Saving checkpoint to: {ckpt_dir}")
        # Backup previous checkpoint
        temp_ckpt_dir = None
        if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
            temp_ckpt_dir = os.path.join(
                os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
            )
            if os.path.exists(temp_ckpt_dir):
                shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            os.rename(ckpt_dir, temp_ckpt_dir)
            logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

        # Save UNet
        unet_path = os.path.join(ckpt_dir, "unet")
        self.model.unet.save_pretrained(unet_path, safe_serialization=False)
        logging.info(f"UNet is saved to: {unet_path}")

        if save_train_state:
            state = {
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
                "config": self.cfg,
                "effective_iter": self.effective_iter,
                "epoch": self.epoch,
                "n_batch_in_epoch": self.n_batch_in_epoch,
                "best_metric": self.best_metric,
                "in_evaluation": self.in_evaluation,
                "global_seed_sequence": self.global_seed_sequence,
            }
            train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
            torch.save(state, train_state_path)
            # iteration indicator: an empty marker file named after the
            # current effective iteration
            f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w")
            f.close()

            logging.info(f"Trainer state is saved to: {train_state_path}")

        # Remove temp ckpt
        if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            logging.debug("Old checkpoint backup is removed.")
    def load_checkpoint(
        self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
    ):
        """Load UNet weights (and optionally trainer state) from a
        checkpoint directory produced by `save_checkpoint`.

        Args:
            ckpt_path: checkpoint directory containing `unet/` and,
                if resuming, `trainer.ckpt`.
            load_trainer_state: restore optimizer state and resume counters.
            resume_lr_scheduler: also restore the LR scheduler state
                (only consulted when `load_trainer_state` is True).
        """
        logging.info(f"Loading checkpoint from: {ckpt_path}")
        # Load UNet
        _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
        self.model.unet.load_state_dict(
            torch.load(_model_path, map_location=self.device)
        )
        self.model.unet.to(self.device)
        logging.info(f"UNet parameters are loaded from {_model_path}")

        # Load training states
        if load_trainer_state:
            checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
            self.effective_iter = checkpoint["effective_iter"]
            self.epoch = checkpoint["epoch"]
            self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
            self.in_evaluation = checkpoint["in_evaluation"]
            self.global_seed_sequence = checkpoint["global_seed_sequence"]

            self.best_metric = checkpoint["best_metric"]

            self.optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info(f"optimizer state is loaded from {ckpt_path}")

            if resume_lr_scheduler:
                self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                logging.info(f"LR scheduler state is loaded from {ckpt_path}")

        logging.info(
            f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
        )
        return
| def _get_backup_ckpt_name(self): | |
| return f"iter_{self.effective_iter:06d}" | |