# Last modified: 2025-07-13
#
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
#
# This file has been modified from the original version.
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
# --------------------------------------------------------------------------

import logging
import os
import random
import shutil
from datetime import datetime
from typing import List, Union

import numpy as np
import torch
from omegaconf import OmegaConf
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F

from depthmaster import DepthMasterPipeline, DepthMasterDepthOutput
from src.util import metric
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dic_to_text
from src.util.loss import get_loss, SSIM
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker
from src.util.alignment import (
    align_depth_least_square,
    depth2disparity,
    disparity2depth,
)
from src.util.seeding import generate_seed_sequence
from src.util.build_mlp import build_mlp_
from torchvision.transforms import Normalize
from external_encoder.dinov2.dinov2 import DINOv2

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


class DepthMasterTrainerS1:
    """Stage-1 trainer for DepthMaster.

    Fine-tunes the UNet of a ``DepthMasterPipeline`` so that a single UNet
    forward pass (timestep 1, empty text prompt) maps the VAE latent of an RGB
    image to the VAE latent of its GT depth. An intermediate UNet feature map
    (``feat_64`` on the UNet output) is additionally aligned with frozen
    DINOv2 patch tokens through a trainable MLP adapter and a KL loss.
    """

    def __init__(
        self,
        cfg: OmegaConf,
        model: DepthMasterPipeline,
        train_dataloader: DataLoader,
        device,
        base_ckpt_dir,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
    ):
        self.cfg: OmegaConf = cfg
        self.model: DepthMasterPipeline = model
        self.device = device
        # Used to generate the seed sequence; set to `None` to train w/o seeding.
        self.seed: Union[int, None] = self.cfg.trainer.init_seed
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        self.val_loaders: List[DataLoader] = val_dataloaders
        self.vis_loaders: List[DataLoader] = vis_dataloaders
        self.accumulation_steps: int = accumulation_steps

        # Encode empty text prompt
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)

        self.model.unet.enable_xformers_memory_efficient_attention()

        # Initialize DINOv2 (ViT-g) encoder from the Depth Anything V2 checkpoint
        self.dinov2_encoder = DINOv2(model_name="vitg")
        dinov2_encoder_dict = self.dinov2_encoder.state_dict()
        pretrained_ckpt_dict = torch.load(
            "checkpoints/depth_anything_v2_vitg.pth", map_location="cpu"
        )
        # Strip the `pretrained.` prefix and keep only keys the encoder actually has
        pretrained_dict = {
            k.replace("pretrained.", ""): v
            for k, v in pretrained_ckpt_dict.items()
            if k.replace("pretrained.", "") in dinov2_encoder_dict
        }
        self.dinov2_encoder.load_state_dict(pretrained_dict)
        del self.dinov2_encoder.head
        self.dinov2_encoder.head = torch.nn.Identity()
        self.dinov2_encoder.eval()
        # Built once here instead of on every training iteration
        self.dinov2_normalize = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

        # Initialize adapter to align the feat dimension of SD and DINOv2
        self.dinov2_adapter = build_mlp_(hidden_size=1280, projector_dim=1536, z_dim=1536)

        # Trainability
        self.dinov2_adapter.requires_grad_(True)
        self.dinov2_encoder.requires_grad_(False)
        self.model.vae.requires_grad_(False)
        self.model.text_encoder.requires_grad_(False)
        self.model.unet.requires_grad_(True)

        # Optimizer !should be defined after input layer is adapted
        lr = self.cfg.lr
        self.optimizer = Adam(
            [
                {"params": self.model.unet.parameters(), "lr": lr},
                {"params": self.dinov2_adapter.parameters(), "lr": lr},
            ]
        )

        # LR scheduler
        lr_func = IterExponential(
            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

        # Loss
        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)

        # Eval metrics
        self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
        self.train_metrics = MetricTracker(*["loss", "feat_align_loss"])
        self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])
        # main metric for best checkpoint saving
        self.main_val_metric = cfg.validation.main_val_metric
        self.main_val_metric_goal = cfg.validation.main_val_metric_goal
        assert (
            self.main_val_metric in cfg.eval.eval_metrics
        ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
        self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8

        # Settings
        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_depth_type = self.cfg.gt_depth_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period

        # Internal variables
        self.epoch = 1
        self.n_batch_in_epoch = 0  # batch index in the epoch, used when resume training
        self.effective_iter = 0  # how many times optimizer.step() is called
        self.in_evaluation = False
        # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming
        self.global_seed_sequence: List = []

    def train(self, t_end=None):
        """Run the training loop.

        Args:
            t_end: optional ``datetime``; once ``datetime.now()`` reaches it, a
                resumable ``latest`` checkpoint is saved and the method returns.
        """
        logging.info("Start training")

        device = self.device
        self.model.to(device)
        self.dinov2_encoder.to(device)
        self.dinov2_adapter.to(device)

        self.visualize()

        if self.in_evaluation:
            logging.info(
                "Last evaluation was not finished, will do evaluation before continue training."
            )
            self.validate()

        self.train_metrics.reset()
        accumulated_step = 0

        progress_bar = tqdm(
            range(0, self.max_iter), initial=self.effective_iter, desc="iter"
        )

        for epoch in range(self.epoch, self.max_epoch + 1):
            self.epoch = epoch
            logging.debug(f"epoch: {self.epoch}")

            # Skip previous batches when resume
            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
                self.model.unet.train()
                self.dinov2_adapter.train()

                # >>> With gradient accumulation >>>

                # Get data
                rgb = batch["rgb_norm"].to(device)
                depth_gt_for_latent = batch[self.gt_depth_type].to(device)

                if self.gt_mask_type is not None:
                    valid_mask_for_latent = batch[self.gt_mask_type].to(device)
                    invalid_mask = ~valid_mask_for_latent
                    # A latent pixel is valid only if its whole 8x8 pixel patch is valid
                    valid_mask_down = ~torch.max_pool2d(
                        invalid_mask.float(), 8, 8
                    ).bool()
                    valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))
                else:
                    raise NotImplementedError

                batch_size = rgb.shape[0]

                with torch.no_grad():
                    # Encode image
                    rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]
                    # Encode GT depth
                    gt_depth_latent = self.encode_depth(
                        depth_gt_for_latent
                    )  # [B, 4, h, w]
                    # DINOv2 feat
                    # NOTE(review): ImageNet normalization is applied to `rgb_norm`;
                    # confirm that this is the input range DINOv2 is meant to receive.
                    dinov2_input_rgb = self.dinov2_normalize(rgb)
                    dinov2_input_rgb = F.interpolate(
                        dinov2_input_rgb, scale_factor=0.875, mode="bicubic"
                    )
                    dinov2_z = self.dinov2_encoder.forward_features(dinov2_input_rgb)[
                        "x_norm_patchtokens"
                    ]

                # Text embedding
                text_embed = self.empty_text_embed.to(device).repeat(
                    (batch_size, 1, 1)
                )  # [B, 77, 1024]

                # Single UNet pass at timestep 1: RGB latent -> depth latent
                unet_out = self.model.unet(rgb_latent, 1, text_embed)
                feat_16 = unet_out.feat_64
                depth_latent_pred = unet_out.sample  # [B, 4, h, w]

                if torch.isnan(depth_latent_pred).any():
                    logging.warning("model_pred contains NaN.")

                # Latent regression loss on valid latent pixels only
                # (gt_mask_type is guaranteed non-None here; see the raise above)
                loss = self.loss(
                    depth_latent_pred[valid_mask_down].float(),
                    gt_depth_latent[valid_mask_down].float(),
                ).mean()
                self.train_metrics.update("loss", loss.item())

                # feat align loss
                b, c, h, w = feat_16.shape
                _, _, H, W = depth_latent_pred.shape
                # update dinov2_adapter: project UNet tokens [B, h*w, c] to DINOv2 width
                unet_16_feat_aligned = self.dinov2_adapter(
                    feat_16.permute(0, 2, 3, 1).reshape(batch_size, -1, c)
                )
                # DINOv2 tokens form an (H/2, W/2) grid; resample them onto the (h, w) UNet grid
                dinov2_z = dinov2_z.reshape(b, int(H / 2), int(W / 2), -1).permute(
                    0, 3, 1, 2
                )
                dinov2_z = (
                    F.interpolate(dinov2_z, size=(h, w), mode="bicubic")
                    .permute(0, 2, 3, 1)
                    .reshape(b, h * w, -1)
                )

                # kl loss: KL(softmax(dinov2) || softmax(unet)) over the feature dim.
                # log_softmax is the numerically stable form of softmax(...).log().
                unet_log_prob = F.log_softmax(unet_16_feat_aligned, dim=-1)
                dinov2_prob = F.softmax(dinov2_z, dim=-1)
                loss_feat_align = F.kl_div(unet_log_prob, dinov2_prob)
                # Store a Python float so the tracker does not keep the autograd graph alive
                self.train_metrics.update("feat_align_loss", loss_feat_align.item())
                loss = loss + self.cfg.loss_feat_align.lamda * loss_feat_align

                loss = loss / self.gradient_accumulation_steps
                loss.backward()
                accumulated_step += 1

                self.n_batch_in_epoch += 1
                # Practical batch end

                # Perform optimization step
                if accumulated_step >= self.gradient_accumulation_steps:
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                    accumulated_step = 0

                    self.effective_iter += 1
                    progress_bar.update(1)

                    # Log to tensorboard
                    accumulated_loss = self.train_metrics.result()["loss"]
                    logs = {"loss": accumulated_loss}
                    progress_bar.set_postfix(**logs)
                    tb_logger.log_dic(
                        {
                            f"train/{k}": v
                            for k, v in self.train_metrics.result().items()
                        },
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "lr",
                        self.lr_scheduler.get_last_lr()[0],
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "n_batch_in_epoch",
                        self.n_batch_in_epoch,
                        global_step=self.effective_iter,
                    )
                    self.train_metrics.reset()

                    # Per-step callback
                    self._train_step_callback()

                    # End of training
                    if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                        self.save_checkpoint(
                            ckpt_name=self._get_backup_ckpt_name(),
                            save_train_state=False,
                        )
                        logging.info("Training ended.")
                        return
                    # Time's up
                    elif t_end is not None and datetime.now() >= t_end:
                        self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                        logging.info("Time is up, training paused.")
                        return

                    torch.cuda.empty_cache()
                    # <<< Effective batch end <<<

            # Epoch end
            self.n_batch_in_epoch = 0

    def encode_depth(self, depth_in):
        """Encode a depth map into the VAE latent space.

        The depth is replicated to 3 channels so it can go through the
        pipeline's RGB encoder (``self.model.encode_rgb``).
        """
        # stack depth into 3-channel
        stacked = self.stack_depth_images(depth_in)
        # encode using VAE encoder
        depth_latent = self.model.encode_rgb(stacked)
        return depth_latent

    @staticmethod
    def stack_depth_images(depth_in):
        """Replicate a depth map along the channel dimension.

        A 4-D input is repeated 3x along dim 1 (presumably ``[B, 1, H, W]`` ->
        ``[B, 3, H, W]``); a 3-D ``[B, H, W]`` input first gets a channel dim.

        Raises:
            ValueError: if ``depth_in`` is neither 3-D nor 4-D.
        """
        if 4 == len(depth_in.shape):
            stacked = depth_in.repeat(1, 3, 1, 1)
        elif 3 == len(depth_in.shape):
            stacked = depth_in.unsqueeze(1)
            # Repeat the unsqueezed tensor (not `depth_in`) to get [B, 3, H, W]
            stacked = stacked.repeat(1, 3, 1, 1)
        else:
            raise ValueError(
                f"Expected a 3-D or 4-D depth tensor, got shape {tuple(depth_in.shape)}"
            )
        return stacked

    def _train_step_callback(self):
        """Executed after every iteration"""
        # Save backup (with a larger interval, without training states)
        if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
            self.save_checkpoint(
                ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
            )

        _is_latest_saved = False
        # Validation
        if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
            self.in_evaluation = True  # flag to do evaluation in resume run if validation is not finished
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)
            _is_latest_saved = True
            self.validate()
            self.in_evaluation = False
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)

        # Save training checkpoint (can be resumed)
        if (
            self.save_period > 0
            and 0 == self.effective_iter % self.save_period
            and not _is_latest_saved
        ):
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)

        # Visualization
        if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
            self.visualize()

    def validate(self):
        """Evaluate on every validation loader and write metrics to TensorBoard and text files.

        The ``main_val_metric`` of the first loader decides whether a new best
        checkpoint is saved. Does nothing when no validation loaders were given.
        """
        for i, val_loader in enumerate(self.val_loaders or []):
            val_dataset_name = val_loader.dataset.disp_name
            val_metric_dic = self.validate_single_dataset(
                data_loader=val_loader, metric_tracker=self.val_metrics
            )
            logging.info(
                f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
            )
            tb_logger.log_dic(
                {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
                global_step=self.effective_iter,
            )
            # save to file
            eval_text = eval_dic_to_text(
                val_metrics=val_metric_dic,
                dataset_name=val_dataset_name,
                sample_list_path=val_loader.dataset.filename_ls_path,
            )
            _save_to = os.path.join(
                self.out_dir_eval,
                f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
            )
            with open(_save_to, "w+") as f:
                f.write(eval_text)

            # Update main eval metric
            if 0 == i:
                main_eval_metric = val_metric_dic[self.main_val_metric]
                if (
                    "minimize" == self.main_val_metric_goal
                    and main_eval_metric < self.best_metric
                ) or (
                    "maximize" == self.main_val_metric_goal
                    and main_eval_metric > self.best_metric
                ):
                    self.best_metric = main_eval_metric
                    logging.info(
                        f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
                    )
                    # Save a checkpoint
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
                    )

    def visualize(self):
        """Run evaluation on each visualization loader and save predictions as PNGs.

        Does nothing when no visualization loaders were given.
        """
        for val_loader in self.vis_loaders or []:
            vis_dataset_name = val_loader.dataset.disp_name
            vis_out_dir = os.path.join(
                self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
            )
            os.makedirs(vis_out_dir, exist_ok=True)
            _ = self.validate_single_dataset(
                data_loader=val_loader,
                metric_tracker=self.val_metrics,
                save_to_dir=vis_out_dir,
            )

    @torch.no_grad()
    def validate_single_dataset(
        self,
        data_loader: DataLoader,
        metric_tracker: MetricTracker,
        save_to_dir: str = None,
    ):
        """Evaluate the pipeline on one dataset (loader batch size must be 1).

        Args:
            data_loader: loader yielding ``rgb_int``, ``depth_raw_linear`` and
                ``valid_mask_raw`` (plus ``rgb_relative_path`` when saving).
            metric_tracker: reset first, then updated with every sample's metrics.
            save_to_dir: if given, each raw prediction (``depth_np`` * 65535) is
                also written there as a 16-bit PNG.

        Returns:
            ``metric_tracker.result()`` after all samples.

        Raises:
            RuntimeError: on an unknown ``cfg.eval.alignment`` value.
        """
        self.model.to(self.device)
        metric_tracker.reset()

        # Generate seed sequence for consistent evaluation
        # NOTE(review): `val_seed_ls` is not consumed below; the call is kept in
        # case `generate_seed_sequence` affects global RNG state.
        val_init_seed = self.cfg.validation.init_seed
        val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))

        for batch in tqdm(
            data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"
        ):
            assert 1 == data_loader.batch_size
            # Read input image
            rgb_int = batch["rgb_int"]  # [1, 3, H, W] (batch size 1)
            # GT depth
            depth_raw_ts = batch["depth_raw_linear"].squeeze()
            depth_raw = depth_raw_ts.numpy()
            depth_raw_ts = depth_raw_ts.to(self.device)
            valid_mask_ts = batch["valid_mask_raw"].squeeze()
            valid_mask = valid_mask_ts.numpy()
            valid_mask_ts = valid_mask_ts.to(self.device)

            # Predict depth
            pipe_out: DepthMasterDepthOutput = self.model(
                rgb_int,
                processing_res=self.cfg.validation.processing_res,
                match_input_res=self.cfg.validation.match_input_res,
                batch_size=1,  # use batch size 1 to increase reproducibility
                color_map=None,
                show_progress_bar=False,
                resample_method=self.cfg.validation.resample_method,
            )

            depth_pred: np.ndarray = pipe_out.depth_np.squeeze()

            if "least_square" == self.cfg.eval.alignment:
                depth_pred, scale, shift = align_depth_least_square(
                    gt_arr=depth_raw,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_mask,
                    return_scale_shift=True,
                    max_resolution=self.cfg.eval.align_max_res,
                )
            elif "least_square_disparity" == self.cfg.eval.alignment:
                gt_disparity = depth_raw
                gt_non_neg_mask = gt_disparity > 0
                # LS alignment in disparity space
                pred_non_neg_mask = depth_pred > 0
                valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask

                disparity_pred, scale, shift = align_depth_least_square(
                    gt_arr=gt_disparity,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_nonnegative_mask,
                    return_scale_shift=True,
                )
                # convert to depth
                disparity_pred = np.clip(
                    disparity_pred, a_min=1e-3, a_max=None
                )  # avoid 0 disparity
                depth_pred = disparity2depth(disparity_pred)
                depth_raw_ts = disparity2depth(depth_raw_ts)
            elif "least_square_sqrt_disp" == self.cfg.eval.alignment:
                gt_sqrt_disp = depth_raw
                gt_non_neg_mask = gt_sqrt_disp > 0
                # LS alignment in sqrt space
                pred_non_neg_mask = depth_pred > 0
                valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask

                # Use the non-negative mask, consistent with the disparity branch
                depth_sqrt_disp_pred, scale, shift = align_depth_least_square(
                    gt_arr=gt_sqrt_disp,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_nonnegative_mask,
                    return_scale_shift=True,
                )
                # sqrt disparity -> disparity
                disparity_pred = depth_sqrt_disp_pred ** 2
                depth_raw_ts = torch.pow(depth_raw_ts, 2)
                # disparity -> depth
                disparity_pred = np.clip(
                    disparity_pred, a_min=1e-3, a_max=None
                )  # avoid 0 disparity
                depth_pred = disparity2depth(disparity_pred)
                depth_raw_ts = disparity2depth(depth_raw_ts)
            else:
                raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")

            # Clip to dataset min max
            depth_pred = np.clip(
                depth_pred,
                a_min=data_loader.dataset.min_depth,
                a_max=data_loader.dataset.max_depth,
            )

            # clip to d > 0 for evaluation
            depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)

            # Evaluate
            sample_metric = []
            depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)

            for met_func in self.metric_funcs:
                _metric_name = met_func.__name__
                _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
                sample_metric.append(_metric.__str__())
                metric_tracker.update(_metric_name, _metric)

            # Save as 16-bit uint png
            if save_to_dir is not None:
                img_name = batch["rgb_relative_path"][0].replace("/", "_")
                png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
                depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
                Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

        return metric_tracker.result()

    def _get_next_seed(self):
        """Pop the next seed, lazily generating the global sequence on first use."""
        if 0 == len(self.global_seed_sequence):
            self.global_seed_sequence = generate_seed_sequence(
                initial_seed=self.seed,
                length=self.max_iter * self.gradient_accumulation_steps,
            )
            logging.info(
                f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
            )
        return self.global_seed_sequence.pop()

    def save_checkpoint(self, ckpt_name, save_train_state):
        """Save the UNet, the DINOv2 adapter and optionally the trainer state.

        An existing checkpoint directory of the same name is first moved to
        ``_old_<name>`` and is removed only after the new one has been written.
        """
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        logging.info(f"Saving checkpoint to: {ckpt_dir}")
        # Backup previous checkpoint
        temp_ckpt_dir = None
        if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
            temp_ckpt_dir = os.path.join(
                os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
            )
            if os.path.exists(temp_ckpt_dir):
                shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            os.rename(ckpt_dir, temp_ckpt_dir)
            logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

        # Save UNet
        unet_path = os.path.join(ckpt_dir, "unet")
        self.model.unet.save_pretrained(unet_path, safe_serialization=False)
        logging.info(f"UNet is saved to: {unet_path}")

        # Save DINOv2_Adapter
        adapter_path = os.path.join(ckpt_dir, "dinov2_adapter.pth")
        state_dict = self.dinov2_adapter.state_dict()
        torch.save(state_dict, adapter_path)
        logging.info(f"dinov2_adapter is saved to: {adapter_path}")

        if save_train_state:
            state = {
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
                "config": self.cfg,
                "effective_iter": self.effective_iter,
                "epoch": self.epoch,
                "n_batch_in_epoch": self.n_batch_in_epoch,
                "best_metric": self.best_metric,
                "in_evaluation": self.in_evaluation,
                "global_seed_sequence": self.global_seed_sequence,
            }
            train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
            torch.save(state, train_state_path)
            # iteration indicator: an empty file named after the current iteration
            with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
                pass

            logging.info(f"Trainer state is saved to: {train_state_path}")

        # Remove temp ckpt
        if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            logging.debug("Old checkpoint backup is removed.")

    def load_checkpoint(
        self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
    ):
        """Load UNet and adapter weights and, optionally, the trainer state.

        Args:
            ckpt_path: checkpoint directory written by ``save_checkpoint``.
            load_trainer_state: also restore counters, seeds, best metric and optimizer.
            resume_lr_scheduler: also restore the LR scheduler (only with trainer state).
        """
        logging.info(f"Loading checkpoint from: {ckpt_path}")
        # Load UNet
        _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
        self.model.unet.load_state_dict(
            torch.load(_model_path, map_location=self.device)
        )
        self.model.unet.to(self.device)
        logging.info(f"UNet parameters are loaded from {_model_path}")

        # Load DINOv2_adapter
        _model_path = os.path.join(ckpt_path, "dinov2_adapter.pth")
        self.dinov2_adapter.load_state_dict(
            torch.load(_model_path, map_location=self.device)
        )
        self.dinov2_adapter.to(self.device)
        logging.info(f"dinov2_adapter parameters are loaded from {_model_path}")

        # Load training states
        if load_trainer_state:
            # NOTE(review): trainer.ckpt pickles the OmegaConf config; on torch
            # versions where `weights_only=True` is the default, this load may
            # need `weights_only=False` (trusted files only).
            checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
            self.effective_iter = checkpoint["effective_iter"]
            self.epoch = checkpoint["epoch"]
            self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
            self.in_evaluation = checkpoint["in_evaluation"]
            self.global_seed_sequence = checkpoint["global_seed_sequence"]

            self.best_metric = checkpoint["best_metric"]

            self.optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info(f"optimizer state is loaded from {ckpt_path}")

            if resume_lr_scheduler:
                self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                logging.info(f"LR scheduler state is loaded from {ckpt_path}")

        logging.info(
            f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
        )
        return

    def _get_backup_ckpt_name(self):
        """Return the checkpoint name for the current iteration, e.g. ``iter_000123``."""
        return f"iter_{self.effective_iter:06d}"