Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # Last modified: 2025-07-13 | |
| # | |
| # Copyright 2025 Ziyang Song, USTC. All rights reserved. | |
| # | |
| # This file has been modified from the original version. | |
| # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # -------------------------------------------------------------------------- | |
| # If you find this code useful, we kindly ask you to cite our paper in your work. | |
| # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation | |
| # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page | |
| # -------------------------------------------------------------------------- | |
| import logging | |
| import os | |
| import random | |
| import shutil | |
| from datetime import datetime | |
| from typing import List, Union | |
| import numpy as np | |
| import torch | |
| from omegaconf import OmegaConf | |
| from torch.optim import Adam | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| from depthmaster import DepthMasterPipeline, DepthMasterDepthOutput | |
| from src.util import metric | |
| from src.util.data_loader import skip_first_batches | |
| from src.util.logging_util import tb_logger, eval_dic_to_text | |
| from src.util.loss import get_loss, SSIM | |
| from src.util.lr_scheduler import IterExponential | |
| from src.util.metric import MetricTracker | |
| from src.util.alignment import ( | |
| align_depth_least_square, | |
| depth2disparity, | |
| disparity2depth, | |
| ) | |
| from src.util.seeding import generate_seed_sequence | |
| from src.util.build_mlp import build_mlp_ | |
| from torchvision.transforms import Normalize | |
| from external_encoder.dinov2.dinov2 import DINOv2 | |
# Channel-wise ImageNet normalization statistics, used to preprocess the
# RGB input before it is fed to the DINOv2 encoder.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
class DepthMasterTrainerS1:
    """Stage-1 trainer for DepthMaster.

    Fine-tunes the UNet of a ``DepthMasterPipeline`` for single-step depth
    prediction in latent space, while aligning an intermediate UNet feature
    map with frozen DINOv2 features through a small MLP adapter
    (KL-divergence feature-alignment loss).
    """

    def __init__(
        self,
        cfg: OmegaConf,
        model: DepthMasterPipeline,
        train_dataloader: DataLoader,
        device,
        base_ckpt_dir,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
    ):
        """Set up model components, optimizer, scheduler, losses and metrics.

        Args:
            cfg: full training configuration (OmegaConf tree).
            model: DepthMaster pipeline whose UNet will be trained.
            train_dataloader: dataloader providing training batches.
            device: torch device for training.
            base_ckpt_dir: base checkpoint directory (NOTE(review): accepted
                but not used in this constructor — confirm intended).
            out_dir_ckpt: output directory for checkpoints.
            out_dir_eval: output directory for evaluation text files.
            out_dir_vis: output directory for visualization PNGs.
            accumulation_steps: gradient accumulation steps per optimizer step.
            val_dataloaders: validation dataloaders (first one drives the
                "best checkpoint" logic).
            vis_dataloaders: dataloaders used only for visualization dumps.
        """
        self.cfg: OmegaConf = cfg
        self.model: DepthMasterPipeline = model
        self.device = device
        self.seed: Union[int, None] = (
            self.cfg.trainer.init_seed
        )  # used to generate seed sequence, set to `None` to train w/o seeding
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        self.val_loaders: List[DataLoader] = val_dataloaders
        self.vis_loaders: List[DataLoader] = vis_dataloaders
        self.accumulation_steps: int = accumulation_steps
        # Encode empty text prompt once; the cached embedding is reused for
        # every training batch.
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
        self.model.unet.enable_xformers_memory_efficient_attention()
        # Initialize DINOv2 encoder (frozen feature extractor used as the
        # alignment target).
        self.dinov2_encoder = DINOv2(model_name='vitg')
        dinov2_encoder_dict = self.dinov2_encoder.state_dict()
        # NOTE(review): checkpoint path is hard-coded relative to the CWD.
        pretrained_ckpt_dict = torch.load(f'checkpoints/depth_anything_v2_vitg.pth', map_location='cpu')
        # Keep only the "pretrained." (encoder) weights that match our encoder.
        pretrained_dict = {k.replace('pretrained.', ''): v for k, v in pretrained_ckpt_dict.items() if k.replace('pretrained.', '') in dinov2_encoder_dict}
        self.dinov2_encoder.load_state_dict(pretrained_dict)
        # Drop the classification head; only patch features are needed.
        del self.dinov2_encoder.head
        self.dinov2_encoder.head = torch.nn.Identity()
        self.dinov2_encoder.eval()
        # Initialize adapter to align the feat dimension of SD and DINOv2
        self.dinov2_adapter = build_mlp_(hidden_size=1280, projector_dim=1536, z_dim=1536)
        # Trainability: only the UNet and the adapter receive gradients.
        self.dinov2_adapter.requires_grad_(True)
        self.dinov2_encoder.requires_grad_(False)
        self.model.vae.requires_grad_(False)
        self.model.text_encoder.requires_grad_(False)
        self.model.unet.requires_grad_(True)
        # Optimizer !should be defined after input layer is adapted
        lr = self.cfg.lr
        self.optimizer = Adam([
            {'params': self.model.unet.parameters(), 'lr': lr},
            {'params': self.dinov2_adapter.parameters(), 'lr': lr}
        ])
        # LR scheduler: exponential decay with warmup, stepped once per
        # effective (optimizer) iteration.
        lr_func = IterExponential(
            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)
        # Loss
        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
        # Eval metrics: resolve metric callables by name from src.util.metric.
        self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
        self.train_metrics = MetricTracker(*["loss", "feat_align_loss"])
        self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])
        # main metric for best checkpoint saving
        self.main_val_metric = cfg.validation.main_val_metric
        self.main_val_metric_goal = cfg.validation.main_val_metric_goal
        assert (
            self.main_val_metric in cfg.eval.eval_metrics
        ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
        self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8
        # Settings
        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter  # counted in effective optimizer steps
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_depth_type = self.cfg.gt_depth_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period
        # Internal variables
        self.epoch = 1
        self.n_batch_in_epoch = 0  # batch index in the epoch, used when resume training
        self.effective_iter = 0  # how many times optimizer.step() is called
        self.in_evaluation = False
        self.global_seed_sequence: List = []  # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming
def train(self, t_end=None):
    """Main training loop with gradient accumulation.

    Optimizes the UNet with a masked latent-regression loss plus a
    DINOv2 feature-alignment KL loss, and handles resuming, periodic
    validation/visualization, checkpointing and tensorboard logging.

    Args:
        t_end: optional wall-clock deadline (datetime); when reached the
            trainer saves a resumable "latest" checkpoint and returns.
    """
    logging.info("Start training")
    device = self.device
    self.model.to(device)
    self.dinov2_encoder.to(device)
    self.dinov2_adapter.to(device)
    self.visualize()
    if self.in_evaluation:
        logging.info(
            "Last evaluation was not finished, will do evaluation before continue training."
        )
        self.validate()
    self.train_metrics.reset()
    accumulated_step = 0
    progress_bar = tqdm(
        range(0, self.max_iter),
        initial=self.effective_iter,
        desc="iter"
    )
    for epoch in range(self.epoch, self.max_epoch + 1):
        self.epoch = epoch
        logging.debug(f"epoch: {self.epoch}")
        # Skip previous batches when resume
        for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
            self.model.unet.train()
            self.dinov2_adapter.train()
            # >>> With gradient accumulation >>>
            # Get data
            rgb = batch["rgb_norm"].to(device)
            depth_gt_for_latent = batch[self.gt_depth_type].to(device)
            if self.gt_mask_type is not None:
                valid_mask_for_latent = batch[self.gt_mask_type].to(device)
                invalid_mask = ~valid_mask_for_latent
                # Downsample the valid mask by 8 (latent stride) via
                # max-pooling the *invalid* mask, so a latent location is
                # valid only if all covered pixels are valid.
                valid_mask_down = ~torch.max_pool2d(
                    invalid_mask.float(), 8, 8
                ).bool()
                # Broadcast to the 4 latent channels.
                valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))
            else:
                raise NotImplementedError
            batch_size = rgb.shape[0]
            with torch.no_grad():
                # Encode image
                rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]
                # Encode GT depth
                gt_depth_latent = self.encode_depth(
                    depth_gt_for_latent
                )  # [B, 4, h, w]
                # DINOv2 feat (frozen alignment target)
                dinov2_input_rgb = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(rgb)
                dinov2_input_rgb = F.interpolate(dinov2_input_rgb, scale_factor=0.875, mode='bicubic')
                dinov2_z = self.dinov2_encoder.forward_features(dinov2_input_rgb)['x_norm_patchtokens']
            # Text embedding
            text_embed = self.empty_text_embed.to(device).repeat(
                (batch_size, 1, 1)
            )  # [B, 77, 1024]
            # Single forward pass through the UNet (second argument 1 is
            # presumably a fixed timestep — single-step prediction).
            rgb_latent = self.model.unet(
                rgb_latent, 1, text_embed
            )  # [B, 4, h, w]
            # Auxiliary feature map exposed by the modified UNet output
            # (attribute `feat_64`), used for DINOv2 alignment.
            feat_16 = rgb_latent.feat_64
            rgb_latent = rgb_latent.sample
            if self.gt_mask_type is not None:
                loss = self.loss(
                    rgb_latent[valid_mask_down].float(),
                    gt_depth_latent[valid_mask_down].float(),
                ).mean()
            else:
                loss = self.loss(rgb_latent.float(), gt_depth_latent.float()).mean()
            self.train_metrics.update("loss", loss.item())
            # feat align loss
            b, c, h, w = feat_16.shape
            _, _, H, W = rgb_latent.shape
            # update dinov2_adapter: project UNet features to DINOv2 dim
            unet_16_feat_aligned = self.dinov2_adapter(feat_16.permute(0, 2, 3, 1).reshape(batch_size, -1, c))
            if torch.isnan(rgb_latent).any():
                logging.warning("model_pred contains NaN.")
            # Resize DINOv2 patch tokens to the UNet feature resolution.
            dinov2_z = dinov2_z.reshape(b, int(H/2), int(W/2), -1).permute(0, 3, 1, 2)
            dinov2_z = F.interpolate(dinov2_z, size=(h, w), mode='bicubic').permute(0, 2, 3, 1).reshape(b, h*w, -1)
            # kl loss between softmax-normalized feature distributions
            # (kl_div expects log-probabilities as first argument).
            unet_16_feat_aligned = F.softmax(unet_16_feat_aligned, dim=-1)
            dinov2_z = F.softmax(dinov2_z, dim=-1)
            loss_feat_align = F.kl_div(unet_16_feat_aligned.log(), dinov2_z)
            self.train_metrics.update("feat_align_loss", loss_feat_align)
            loss += self.cfg.loss_feat_align.lamda * loss_feat_align
            # Scale loss for gradient accumulation.
            loss = loss / self.gradient_accumulation_steps
            loss.backward()
            accumulated_step += 1
            self.n_batch_in_epoch += 1
            # Practical batch end
            # Perform optimization step
            if accumulated_step >= self.gradient_accumulation_steps:
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                accumulated_step = 0
                self.effective_iter += 1
                progress_bar.update(1)
                # Log to tensorboard
                accumulated_loss = self.train_metrics.result()["loss"]
                logs = {"loss": accumulated_loss}
                progress_bar.set_postfix(**logs)
                tb_logger.log_dic(
                    {
                        f"train/{k}": v
                        for k, v in self.train_metrics.result().items()
                    },
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "lr",
                    self.lr_scheduler.get_last_lr()[0],
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "n_batch_in_epoch",
                    self.n_batch_in_epoch,
                    global_step=self.effective_iter,
                )
                self.train_metrics.reset()
                # Per-step callback
                self._train_step_callback()
                # End of training
                if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(),
                        save_train_state=False,
                    )
                    logging.info("Training ended.")
                    return
                # Time's up
                elif t_end is not None and datetime.now() >= t_end:
                    self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                    logging.info("Time is up, training paused.")
                    return
                torch.cuda.empty_cache()
                # <<< Effective batch end <<<
        # Epoch end
        self.n_batch_in_epoch = 0
def encode_depth(self, depth_in):
    """Encode a single-channel depth map into VAE latent space.

    The depth map is first replicated to three channels so that it can be
    pushed through the pipeline's RGB VAE encoder.
    """
    three_channel = self.stack_depth_images(depth_in)
    return self.model.encode_rgb(three_channel)
| def stack_depth_images(depth_in): | |
| if 4 == len(depth_in.shape): | |
| stacked = depth_in.repeat(1, 3, 1, 1) | |
| elif 3 == len(depth_in.shape): | |
| stacked = depth_in.unsqueeze(1) | |
| stacked = depth_in.repeat(1, 3, 1, 1) | |
| return stacked | |
| def _train_step_callback(self): | |
| """Executed after every iteration""" | |
| # Save backup (with a larger interval, without training states) | |
| if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: | |
| self.save_checkpoint( | |
| ckpt_name=self._get_backup_ckpt_name(), save_train_state=False | |
| ) | |
| _is_latest_saved = False | |
| # Validation | |
| if self.val_period > 0 and 0 == self.effective_iter % self.val_period: | |
| self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished | |
| self.save_checkpoint(ckpt_name="latest", save_train_state=True) | |
| _is_latest_saved = True | |
| self.validate() | |
| self.in_evaluation = False | |
| self.save_checkpoint(ckpt_name="latest", save_train_state=True) | |
| # Save training checkpoint (can be resumed) | |
| if ( | |
| self.save_period > 0 | |
| and 0 == self.effective_iter % self.save_period | |
| and not _is_latest_saved | |
| ): | |
| self.save_checkpoint(ckpt_name="latest", save_train_state=True) | |
| # Visualization | |
| if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period: | |
| self.visualize() | |
def validate(self):
    """Run validation on every validation dataloader.

    For each dataset: computes metrics, logs them to tensorboard, and
    writes a per-dataset eval text file. Only the FIRST dataloader drives
    the best-checkpoint logic: when its main metric improves, a backup
    checkpoint (weights only) is saved.
    """
    for i, val_loader in enumerate(self.val_loaders):
        val_dataset_name = val_loader.dataset.disp_name
        val_metric_dic = self.validate_single_dataset(
            data_loader=val_loader, metric_tracker=self.val_metrics
        )
        logging.info(
            f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
        )
        tb_logger.log_dic(
            {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
            global_step=self.effective_iter,
        )
        # save to file
        eval_text = eval_dic_to_text(
            val_metrics=val_metric_dic,
            dataset_name=val_dataset_name,
            sample_list_path=val_loader.dataset.filename_ls_path,
        )
        _save_to = os.path.join(
            self.out_dir_eval,
            f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
        )
        with open(_save_to, "w+") as f:
            f.write(eval_text)
        # Update main eval metric (first dataset only)
        if 0 == i:
            main_eval_metric = val_metric_dic[self.main_val_metric]
            # `and` binds tighter than `or`: improvement in the configured
            # direction (minimize or maximize).
            if (
                "minimize" == self.main_val_metric_goal
                and main_eval_metric < self.best_metric
                or "maximize" == self.main_val_metric_goal
                and main_eval_metric > self.best_metric
            ):
                self.best_metric = main_eval_metric
                logging.info(
                    f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
                )
                # Save a checkpoint
                self.save_checkpoint(
                    ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
                )
def visualize(self):
    """Run inference on each visualization dataloader and dump predictions.

    Output PNGs are written under
    ``out_dir_vis/<backup_ckpt_name>/<dataset_name>/``.
    """
    for loader in self.vis_loaders:
        dataset_name = loader.dataset.disp_name
        out_dir = os.path.join(
            self.out_dir_vis, self._get_backup_ckpt_name(), dataset_name
        )
        os.makedirs(out_dir, exist_ok=True)
        # Metric results are discarded here; this call is for the PNG dump.
        self.validate_single_dataset(
            data_loader=loader,
            metric_tracker=self.val_metrics,
            save_to_dir=out_dir,
        )
def validate_single_dataset(
    self,
    data_loader: DataLoader,
    metric_tracker: MetricTracker,
    save_to_dir: str = None,
):
    """Evaluate depth predictions on a single dataset.

    For each sample, predicts depth with the pipeline, aligns the
    prediction to ground truth with the configured least-squares scheme,
    clips to the dataset depth range, and accumulates the configured
    metrics. Optionally dumps predictions as 16-bit PNGs.

    Bug fix vs. original: in the "least_square_sqrt_disp" branch,
    `valid_nonnegative_mask` was computed but `valid_mask` was passed to
    the aligner; it now uses the filtered mask, mirroring the
    "least_square_disparity" branch.

    Args:
        data_loader: dataloader with batch_size == 1.
        metric_tracker: tracker that accumulates per-sample metrics.
        save_to_dir: when given, predictions are saved there as PNGs.

    Returns:
        Dict of averaged metric values from ``metric_tracker``.
    """
    self.model.to(self.device)
    metric_tracker.reset()
    # Generate seed sequence for consistent evaluation
    # NOTE(review): the seed list is generated but never consumed below —
    # confirm whether per-sample seeding was intended.
    val_init_seed = self.cfg.validation.init_seed
    val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))
    for i, batch in enumerate(
        tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
        start=1,
    ):
        assert 1 == data_loader.batch_size
        # Read input image
        rgb_int = batch["rgb_int"]  # [3, H, W]
        # GT depth
        depth_raw_ts = batch["depth_raw_linear"].squeeze()
        depth_raw = depth_raw_ts.numpy()
        depth_raw_ts = depth_raw_ts.to(self.device)
        valid_mask_ts = batch["valid_mask_raw"].squeeze()
        valid_mask = valid_mask_ts.numpy()
        valid_mask_ts = valid_mask_ts.to(self.device)
        # Predict depth
        pipe_out: DepthMasterDepthOutput = self.model(
            rgb_int,
            processing_res=self.cfg.validation.processing_res,
            match_input_res=self.cfg.validation.match_input_res,
            batch_size=1,  # use batch size 1 to increase reproducibility
            color_map=None,
            show_progress_bar=False,
            resample_method=self.cfg.validation.resample_method,
        )
        depth_pred: np.ndarray = pipe_out.depth_np.squeeze()
        if "least_square" == self.cfg.eval.alignment:
            # Scale/shift alignment directly in depth space.
            depth_pred, scale, shift = align_depth_least_square(
                gt_arr=depth_raw,
                pred_arr=depth_pred,
                valid_mask_arr=valid_mask,
                return_scale_shift=True,
                max_resolution=self.cfg.eval.align_max_res,
            )
        elif "least_square_disparity" == self.cfg.eval.alignment:
            # GT is stored as disparity here.
            gt_disparity = depth_raw
            gt_non_neg_mask = gt_disparity > 0
            # LS alignment in disparity space
            pred_non_neg_mask = depth_pred > 0
            valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask
            disparity_pred, scale, shift = align_depth_least_square(
                gt_arr=gt_disparity,
                pred_arr=depth_pred,
                valid_mask_arr=valid_nonnegative_mask,
                return_scale_shift=True,
            )
            # convert to depth
            disparity_pred = np.clip(
                disparity_pred, a_min=1e-3, a_max=None
            )  # avoid 0 disparity
            depth_pred = disparity2depth(disparity_pred)
            depth_raw_ts = disparity2depth(depth_raw_ts)
        elif "least_square_sqrt_disp" == self.cfg.eval.alignment:
            # GT is stored as sqrt-disparity here.
            gt_sqrt_disp = depth_raw
            gt_non_neg_mask = gt_sqrt_disp > 0
            # LS alignment in sqrt space
            pred_non_neg_mask = depth_pred > 0
            valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask
            depth_sqrt_disp_pred, scale, shift = align_depth_least_square(
                gt_arr=gt_sqrt_disp,
                pred_arr=depth_pred,
                valid_mask_arr=valid_nonnegative_mask,  # fixed: was valid_mask
                return_scale_shift=True,
            )
            # Square back to disparity space.
            disparity_pred = depth_sqrt_disp_pred ** 2
            depth_raw_ts = torch.pow(depth_raw_ts, 2)
            # convert to depth
            disparity_pred = np.clip(
                disparity_pred, a_min=1e-3, a_max=None
            )  # avoid 0 disparity
            depth_pred = disparity2depth(disparity_pred)
            depth_raw_ts = disparity2depth(depth_raw_ts)
        else:
            raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")
        # Clip to dataset min max
        depth_pred = np.clip(
            depth_pred,
            a_min=data_loader.dataset.min_depth,
            a_max=data_loader.dataset.max_depth,
        )
        # clip to d > 0 for evaluation
        depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)
        # Evaluate
        sample_metric = []
        depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)
        for met_func in self.metric_funcs:
            _metric_name = met_func.__name__
            _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
            sample_metric.append(str(_metric))
            metric_tracker.update(_metric_name, _metric)
        # Save as 16-bit uint png
        if save_to_dir is not None:
            img_name = batch["rgb_relative_path"][0].replace("/", "_")
            png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
            depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
            Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")
    return metric_tracker.result()
| def _get_next_seed(self): | |
| if 0 == len(self.global_seed_sequence): | |
| self.global_seed_sequence = generate_seed_sequence( | |
| initial_seed=self.seed, | |
| length=self.max_iter * self.gradient_accumulation_steps, | |
| ) | |
| logging.info( | |
| f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" | |
| ) | |
| return self.global_seed_sequence.pop() | |
def save_checkpoint(self, ckpt_name, save_train_state):
    """Save UNet weights, adapter weights, and optionally trainer state.

    An existing checkpoint directory of the same name is first renamed to
    an ``_old_*`` backup and only removed after the new save has finished,
    so an interrupted save cannot destroy the previous checkpoint.

    Args:
        ckpt_name: checkpoint directory name under ``self.out_dir_ckpt``.
        save_train_state: when True, also saves optimizer/scheduler state
            and resume bookkeeping to ``trainer.ckpt`` plus an empty
            iteration-marker file.
    """
    ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
    logging.info(f"Saving checkpoint to: {ckpt_dir}")
    # Backup previous checkpoint
    temp_ckpt_dir = None
    if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
        temp_ckpt_dir = os.path.join(
            os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
        )
        if os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        os.rename(ckpt_dir, temp_ckpt_dir)
        logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")
    # Save UNet (save_pretrained creates ckpt_dir as needed)
    unet_path = os.path.join(ckpt_dir, "unet")
    self.model.unet.save_pretrained(unet_path, safe_serialization=False)
    logging.info(f"UNet is saved to: {unet_path}")
    # Save DINOv2_Adapter
    adapter_path = os.path.join(ckpt_dir, "dinov2_adapter.pth")
    state_dict = self.dinov2_adapter.state_dict()
    torch.save(state_dict, adapter_path)
    logging.info(f"dinov2_adapter is saved to: {adapter_path}")
    if save_train_state:
        state = {
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "config": self.cfg,
            "effective_iter": self.effective_iter,
            "epoch": self.epoch,
            "n_batch_in_epoch": self.n_batch_in_epoch,
            "best_metric": self.best_metric,
            "in_evaluation": self.in_evaluation,
            "global_seed_sequence": self.global_seed_sequence,
        }
        train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
        torch.save(state, train_state_path)
        # Empty iteration-indicator file (fixed: context manager instead of
        # a dangling open()/close() pair).
        with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
            pass
        logging.info(f"Trainer state is saved to: {train_state_path}")
    # Remove temp ckpt
    if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
        shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        logging.debug("Old checkpoint backup is removed.")
def load_checkpoint(
    self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
):
    """Load model weights (and optionally trainer state) from a checkpoint.

    Args:
        ckpt_path: checkpoint directory created by ``save_checkpoint``.
        load_trainer_state: also restore optimizer state and resume
            bookkeeping (iteration, epoch, best metric, seed sequence).
        resume_lr_scheduler: also restore the LR scheduler state (only
            checked when ``load_trainer_state`` is True).
    """
    logging.info(f"Loading checkpoint from: {ckpt_path}")
    # Load UNet
    _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
    self.model.unet.load_state_dict(
        torch.load(_model_path, map_location=self.device)
    )
    self.model.unet.to(self.device)
    logging.info(f"UNet parameters are loaded from {_model_path}")
    # Load DINOv2_adapter
    _model_path = os.path.join(ckpt_path, "dinov2_adapter.pth")
    self.dinov2_adapter.load_state_dict(
        torch.load(_model_path, map_location=self.device)
    )
    self.dinov2_adapter.to(self.device)
    logging.info(f"dinov2_adapter parameters are loaded from {_model_path}")
    # Load training states
    if load_trainer_state:
        checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
        self.effective_iter = checkpoint["effective_iter"]
        self.epoch = checkpoint["epoch"]
        self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
        self.in_evaluation = checkpoint["in_evaluation"]
        self.global_seed_sequence = checkpoint["global_seed_sequence"]
        self.best_metric = checkpoint["best_metric"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        logging.info(f"optimizer state is loaded from {ckpt_path}")
        if resume_lr_scheduler:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            logging.info(f"LR scheduler state is loaded from {ckpt_path}")
    logging.info(
        f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
    )
    return
| def _get_backup_ckpt_name(self): | |
| return f"iter_{self.effective_iter:06d}" |