# Last modified: 2025-07-13
#
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
#
# This file has been modified from the original version.
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
# --------------------------------------------------------------------------
import logging
import os
import random
import shutil
from datetime import datetime
from typing import List, Union
import numpy as np
import torch
from omegaconf import OmegaConf
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F
from depthmaster import DepthMasterPipeline, DepthMasterDepthOutput
from src.util import metric
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dic_to_text
from src.util.loss import get_loss, SSIM
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker
from src.util.alignment import (
align_depth_least_square,
depth2disparity,
disparity2depth,
)
from src.util.seeding import generate_seed_sequence
from src.util.build_mlp import build_mlp_
from torchvision.transforms import Normalize
from external_encoder.dinov2.dinov2 import DINOv2
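# Standard ImageNet normalization statistics, used to prepare DINOv2 inputs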
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
class DepthMasterTrainerS1:
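    """Stage-1 trainer of DepthMaster.

    Fine-tunes the diffusion UNet for deterministic single-step depth
    prediction in latent space, with an auxiliary KL loss that aligns
    intermediate UNet features to a frozen DINOv2 encoder through a
    trainable MLP adapter.
    """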
def __init__(
self,
cfg: OmegaConf,
model: DepthMasterPipeline,
train_dataloader: DataLoader,
device,
base_ckpt_dir,
out_dir_ckpt,
out_dir_eval,
out_dir_vis,
accumulation_steps: int,
val_dataloaders: List[DataLoader] = None,
vis_dataloaders: List[DataLoader] = None,
):
self.cfg: OmegaConf = cfg
self.model: DepthMasterPipeline = model
self.device = device
self.seed: Union[int, None] = (
self.cfg.trainer.init_seed
) # used to generate seed sequence, set to `None` to train w/o seeding
self.out_dir_ckpt = out_dir_ckpt
self.out_dir_eval = out_dir_eval
self.out_dir_vis = out_dir_vis
self.train_loader: DataLoader = train_dataloader
self.val_loaders: List[DataLoader] = val_dataloaders
self.vis_loaders: List[DataLoader] = vis_dataloaders
self.accumulation_steps: int = accumulation_steps
# Encode empty text prompt
self.model.encode_empty_text()
self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
self.model.unet.enable_xformers_memory_efficient_attention()
# Initialize DINOv2 encoder
self.dinov2_encoder = DINOv2(model_name='vitg')
dinov2_encoder_dict = self.dinov2_encoder.state_dict()
        pretrained_ckpt_dict = torch.load(
            'checkpoints/depth_anything_v2_vitg.pth', map_location='cpu'
        )
        pretrained_dict = {
            k.replace('pretrained.', ''): v
            for k, v in pretrained_ckpt_dict.items()
            if k.replace('pretrained.', '') in dinov2_encoder_dict
        }
self.dinov2_encoder.load_state_dict(pretrained_dict)
        self.dinov2_encoder.head = torch.nn.Identity()  # drop the pretraining head
self.dinov2_encoder.eval()
# Initialize adapter to align the feat dimension of SD and DINOv2
self.dinov2_adapter = build_mlp_(hidden_size=1280, projector_dim=1536, z_dim=1536)
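        # hidden_size=1280 presumably matches the width of the tapped UNet
        # feature map; 1536 is the embedding dimension of DINOv2 ViT-g.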
# Trainability
self.dinov2_adapter.requires_grad_(True)
self.dinov2_encoder.requires_grad_(False)
self.model.vae.requires_grad_(False)
self.model.text_encoder.requires_grad_(False)
self.model.unet.requires_grad_(True)
        # Optimizer (must be created after any input-layer adaptation)
lr = self.cfg.lr
self.optimizer = Adam([
{'params': self.model.unet.parameters(), 'lr': lr},
{'params': self.dinov2_adapter.parameters(), 'lr': lr}
])
# LR scheduler
lr_func = IterExponential(
total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
)
self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)
# Loss
self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
# Eval metrics
self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
self.train_metrics = MetricTracker(*["loss", "feat_align_loss"])
self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])
# main metric for best checkpoint saving
self.main_val_metric = cfg.validation.main_val_metric
self.main_val_metric_goal = cfg.validation.main_val_metric_goal
assert (
self.main_val_metric in cfg.eval.eval_metrics
), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8
# Settings
self.max_epoch = self.cfg.max_epoch
self.max_iter = self.cfg.max_iter
self.gradient_accumulation_steps = accumulation_steps
self.gt_depth_type = self.cfg.gt_depth_type
self.gt_mask_type = self.cfg.gt_mask_type
self.save_period = self.cfg.trainer.save_period
self.backup_period = self.cfg.trainer.backup_period
self.val_period = self.cfg.trainer.validation_period
self.vis_period = self.cfg.trainer.visualization_period
# Internal variables
self.epoch = 1
self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training
self.effective_iter = 0 # how many times optimizer.step() is called
self.in_evaluation = False
self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming
def train(self, t_end=None):
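        """Run the training loop until `max_iter` is reached or `t_end` passes,
        handling gradient accumulation, logging, validation, visualization,
        and checkpointing."""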
logging.info("Start training")
device = self.device
self.model.to(device)
self.dinov2_encoder.to(device)
self.dinov2_adapter.to(device)
self.visualize()
if self.in_evaluation:
logging.info(
"Last evaluation was not finished, will do evaluation before continue training."
)
self.validate()
self.train_metrics.reset()
accumulated_step = 0
progress_bar = tqdm(
range(0, self.max_iter),
initial=self.effective_iter,
desc="iter"
)
for epoch in range(self.epoch, self.max_epoch + 1):
self.epoch = epoch
logging.debug(f"epoch: {self.epoch}")
            # Skip already-seen batches when resuming
for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
self.model.unet.train()
self.dinov2_adapter.train()
# >>> With gradient accumulation >>>
# Get data
rgb = batch["rgb_norm"].to(device)
depth_gt_for_latent = batch[self.gt_depth_type].to(device)
if self.gt_mask_type is not None:
valid_mask_for_latent = batch[self.gt_mask_type].to(device)
invalid_mask = ~valid_mask_for_latent
valid_mask_down = ~torch.max_pool2d(
invalid_mask.float(), 8, 8
).bool()
valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))
else:
raise NotImplementedError
batch_size = rgb.shape[0]
with torch.no_grad():
# Encode image
rgb_latent = self.model.encode_rgb(rgb) # [B, 4, h, w]
# Encode GT depth
gt_depth_latent = self.encode_depth(
depth_gt_for_latent
) # [B, 4, h, w]
# DINOv2 feat
dinov2_input_rgb = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(rgb)
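                    # The 0.875 downscale presumably keeps the input divisible by
                    # DINOv2's 14-px patch size (e.g. 768 * 0.875 = 672 = 48 * 14).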
dinov2_input_rgb = F.interpolate(dinov2_input_rgb, scale_factor=0.875, mode='bicubic')
dinov2_z = self.dinov2_encoder.forward_features(dinov2_input_rgb)['x_norm_patchtokens']
# Text embedding
text_embed = self.empty_text_embed.to(device).repeat(
(batch_size, 1, 1)
) # [B, 77, 1024]
                # Single-step UNet forward pass: predicts the depth latent directly
rgb_latent = self.model.unet(
rgb_latent, 1, text_embed
) # [B, 4, h, w]
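                # `feat_64` is the intermediate feature map exposed by the modified
                # UNet; it is aligned with DINOv2 patch features below.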
feat_16 = rgb_latent.feat_64
rgb_latent = rgb_latent.sample
if self.gt_mask_type is not None:
loss = self.loss(
rgb_latent[valid_mask_down].float(),
gt_depth_latent[valid_mask_down].float(),
).mean()
else:
loss = self.loss(rgb_latent.float(), gt_depth_latent.float()).mean()
self.train_metrics.update("loss", loss.item())
# feat align loss
b, c, h, w = feat_16.shape
_, _, H, W = rgb_latent.shape
# update dinov2_adapter
unet_16_feat_aligned = self.dinov2_adapter(feat_16.permute(0, 2, 3, 1).reshape(batch_size, -1, c))
if torch.isnan(rgb_latent).any():
logging.warning("model_pred contains NaN.")
                dinov2_z = dinov2_z.reshape(b, H // 2, W // 2, -1).permute(0, 3, 1, 2)
                dinov2_z = F.interpolate(dinov2_z, size=(h, w), mode='bicubic').permute(0, 2, 3, 1).reshape(b, h * w, -1)
                # KL alignment loss over the channel dimension. log_softmax is
                # numerically safer than softmax().log(), which can underflow to
                # -inf; F.kl_div expects log-probabilities as its first argument.
                unet_16_log_probs = F.log_softmax(unet_16_feat_aligned, dim=-1)
                dinov2_probs = F.softmax(dinov2_z, dim=-1)
                loss_feat_align = F.kl_div(unet_16_log_probs, dinov2_probs)
                self.train_metrics.update("feat_align_loss", loss_feat_align.item())
loss += self.cfg.loss_feat_align.lamda * loss_feat_align
loss = loss / self.gradient_accumulation_steps
loss.backward()
accumulated_step += 1
self.n_batch_in_epoch += 1
# Practical batch end
# Perform optimization step
if accumulated_step >= self.gradient_accumulation_steps:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
accumulated_step = 0
self.effective_iter += 1
progress_bar.update(1)
# Log to tensorboard
accumulated_loss = self.train_metrics.result()["loss"]
logs = {"loss": accumulated_loss}
progress_bar.set_postfix(**logs)
tb_logger.log_dic(
{
f"train/{k}": v
for k, v in self.train_metrics.result().items()
},
global_step=self.effective_iter,
)
tb_logger.writer.add_scalar(
"lr",
self.lr_scheduler.get_last_lr()[0],
global_step=self.effective_iter,
)
tb_logger.writer.add_scalar(
"n_batch_in_epoch",
self.n_batch_in_epoch,
global_step=self.effective_iter,
)
self.train_metrics.reset()
# Per-step callback
self._train_step_callback()
# End of training
if self.max_iter > 0 and self.effective_iter >= self.max_iter:
self.save_checkpoint(
ckpt_name=self._get_backup_ckpt_name(),
save_train_state=False,
)
logging.info("Training ended.")
return
# Time's up
elif t_end is not None and datetime.now() >= t_end:
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
logging.info("Time is up, training paused.")
return
torch.cuda.empty_cache()
# <<< Effective batch end <<<
# Epoch end
self.n_batch_in_epoch = 0
def encode_depth(self, depth_in):
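        """Encode a depth map into latent space by replicating it to three
        channels and reusing the RGB VAE encoder."""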
# stack depth into 3-channel
stacked = self.stack_depth_images(depth_in)
# encode using VAE encoder
depth_latent = self.model.encode_rgb(stacked)
return depth_latent
@staticmethod
def stack_depth_images(depth_in):
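        """Replicate single-channel depth to three channels for the VAE
        encoder. Accepts [B, 1, H, W] or [B, H, W] input."""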
        if 4 == len(depth_in.shape):
            stacked = depth_in.repeat(1, 3, 1, 1)
        elif 3 == len(depth_in.shape):
            stacked = depth_in.unsqueeze(1).repeat(1, 3, 1, 1)
        else:
            raise ValueError(f"Unexpected depth shape: {depth_in.shape}")
        return stacked
def _train_step_callback(self):
"""Executed after every iteration"""
# Save backup (with a larger interval, without training states)
if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
self.save_checkpoint(
ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
)
_is_latest_saved = False
# Validation
if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
            self.in_evaluation = True  # so that a resumed run re-runs evaluation if it was interrupted
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
_is_latest_saved = True
self.validate()
self.in_evaluation = False
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
# Save training checkpoint (can be resumed)
if (
self.save_period > 0
and 0 == self.effective_iter % self.save_period
and not _is_latest_saved
):
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
# Visualization
if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
self.visualize()
def validate(self):
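        """Evaluate on all validation dataloaders, log and persist the metrics,
        and save a backup checkpoint whenever the main metric (tracked on the
        first dataloader) improves."""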
for i, val_loader in enumerate(self.val_loaders):
val_dataset_name = val_loader.dataset.disp_name
val_metric_dic = self.validate_single_dataset(
data_loader=val_loader, metric_tracker=self.val_metrics
)
logging.info(
f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
)
tb_logger.log_dic(
{f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
global_step=self.effective_iter,
)
# save to file
eval_text = eval_dic_to_text(
val_metrics=val_metric_dic,
dataset_name=val_dataset_name,
sample_list_path=val_loader.dataset.filename_ls_path,
)
_save_to = os.path.join(
self.out_dir_eval,
f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
)
with open(_save_to, "w+") as f:
f.write(eval_text)
# Update main eval metric
if 0 == i:
main_eval_metric = val_metric_dic[self.main_val_metric]
                if (
                    "minimize" == self.main_val_metric_goal
                    and main_eval_metric < self.best_metric
                ) or (
                    "maximize" == self.main_val_metric_goal
                    and main_eval_metric > self.best_metric
                ):
self.best_metric = main_eval_metric
logging.info(
f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
)
# Save a checkpoint
self.save_checkpoint(
ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
)
def visualize(self):
for val_loader in self.vis_loaders:
vis_dataset_name = val_loader.dataset.disp_name
vis_out_dir = os.path.join(
self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
)
os.makedirs(vis_out_dir, exist_ok=True)
_ = self.validate_single_dataset(
data_loader=val_loader,
metric_tracker=self.val_metrics,
save_to_dir=vis_out_dir,
)
@torch.no_grad()
def validate_single_dataset(
self,
data_loader: DataLoader,
metric_tracker: MetricTracker,
save_to_dir: str = None,
):
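        """Evaluate on a single dataset; when `save_to_dir` is given, also save
        predictions as 16-bit PNGs."""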
self.model.to(self.device)
metric_tracker.reset()
# Generate seed sequence for consistent evaluation
val_init_seed = self.cfg.validation.init_seed
val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))
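        # NOTE: the seed sequence is generated but not consumed below,
        # presumably because the single-step prediction is deterministic.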
for i, batch in enumerate(
tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
start=1,
):
assert 1 == data_loader.batch_size
# Read input image
rgb_int = batch["rgb_int"] # [3, H, W]
# GT depth
depth_raw_ts = batch["depth_raw_linear"].squeeze()
depth_raw = depth_raw_ts.numpy()
depth_raw_ts = depth_raw_ts.to(self.device)
valid_mask_ts = batch["valid_mask_raw"].squeeze()
valid_mask = valid_mask_ts.numpy()
valid_mask_ts = valid_mask_ts.to(self.device)
# Predict depth
pipe_out: DepthMasterDepthOutput = self.model(
rgb_int,
processing_res=self.cfg.validation.processing_res,
match_input_res=self.cfg.validation.match_input_res,
batch_size=1, # use batch size 1 to increase reproducibility
color_map=None,
show_progress_bar=False,
resample_method=self.cfg.validation.resample_method,
)
depth_pred: np.ndarray = pipe_out.depth_np.squeeze()
if "least_square" == self.cfg.eval.alignment:
depth_pred, scale, shift = align_depth_least_square(
gt_arr=depth_raw,
pred_arr=depth_pred,
valid_mask_arr=valid_mask,
return_scale_shift=True,
max_resolution=self.cfg.eval.align_max_res,
)
elif "least_square_disparity" == self.cfg.eval.alignment:
gt_disparity = depth_raw
gt_non_neg_mask = gt_disparity > 0
# LS alignment in disparity space
pred_non_neg_mask = depth_pred > 0
valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask
disparity_pred, scale, shift = align_depth_least_square(
gt_arr=gt_disparity,
pred_arr=depth_pred,
valid_mask_arr=valid_nonnegative_mask,
return_scale_shift=True,
)
# convert to depth
disparity_pred = np.clip(
disparity_pred, a_min=1e-3, a_max=None
) # avoid 0 disparity
depth_pred = disparity2depth(disparity_pred)
depth_raw_ts = disparity2depth(depth_raw_ts)
elif "least_square_sqrt_disp" == self.cfg.eval.alignment:
gt_sqrt_disp = depth_raw
gt_non_neg_mask = gt_sqrt_disp > 0
# LS alignment in sqrt space
pred_non_neg_mask = depth_pred > 0
valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask
                depth_sqrt_disp_pred, scale, shift = align_depth_least_square(
                    gt_arr=gt_sqrt_disp,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_nonnegative_mask,
                    return_scale_shift=True,
                )
                # square to recover disparity
                disparity_pred = depth_sqrt_disp_pred ** 2
                depth_raw_ts = torch.pow(depth_raw_ts, 2)
# convert to depth
disparity_pred = np.clip(
disparity_pred, a_min=1e-3, a_max=None
) # avoid 0 disparity
depth_pred = disparity2depth(disparity_pred)
depth_raw_ts = disparity2depth(depth_raw_ts)
else:
raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")
# Clip to dataset min max
depth_pred = np.clip(
depth_pred,
a_min=data_loader.dataset.min_depth,
a_max=data_loader.dataset.max_depth,
)
# clip to d > 0 for evaluation
depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)
# Evaluate
sample_metric = []
depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)
for met_func in self.metric_funcs:
_metric_name = met_func.__name__
_metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
                sample_metric.append(str(_metric))
metric_tracker.update(_metric_name, _metric)
# Save as 16-bit uint png
if save_to_dir is not None:
img_name = batch["rgb_relative_path"][0].replace("/", "_")
png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")
return metric_tracker.result()
def _get_next_seed(self):
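        """Return the next seed from a lazily generated global sequence so that
        resumed runs reproduce the same randomness."""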
if 0 == len(self.global_seed_sequence):
self.global_seed_sequence = generate_seed_sequence(
initial_seed=self.seed,
length=self.max_iter * self.gradient_accumulation_steps,
)
logging.info(
f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
)
return self.global_seed_sequence.pop()
def save_checkpoint(self, ckpt_name, save_train_state):
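        """Save UNet and DINOv2-adapter weights under `out_dir_ckpt/ckpt_name`;
        when `save_train_state` is True, also store optimizer, LR scheduler,
        and resume bookkeeping in `trainer.ckpt`."""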
ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
logging.info(f"Saving checkpoint to: {ckpt_dir}")
# Backup previous checkpoint
temp_ckpt_dir = None
if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
temp_ckpt_dir = os.path.join(
os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
)
if os.path.exists(temp_ckpt_dir):
shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
os.rename(ckpt_dir, temp_ckpt_dir)
logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")
# Save UNet
unet_path = os.path.join(ckpt_dir, "unet")
self.model.unet.save_pretrained(unet_path, safe_serialization=False)
logging.info(f"UNet is saved to: {unet_path}")
# Save DINOv2_Adapter
adapter_path = os.path.join(ckpt_dir, "dinov2_adapter.pth")
state_dict = self.dinov2_adapter.state_dict()
torch.save(state_dict, adapter_path)
logging.info(f"dinov2_adapter is saved to: {adapter_path}")
if save_train_state:
state = {
"optimizer": self.optimizer.state_dict(),
"lr_scheduler": self.lr_scheduler.state_dict(),
"config": self.cfg,
"effective_iter": self.effective_iter,
"epoch": self.epoch,
"n_batch_in_epoch": self.n_batch_in_epoch,
"best_metric": self.best_metric,
"in_evaluation": self.in_evaluation,
"global_seed_sequence": self.global_seed_sequence,
}
train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
torch.save(state, train_state_path)
# iteration indicator
            with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
                pass
logging.info(f"Trainer state is saved to: {train_state_path}")
# Remove temp ckpt
if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
logging.debug("Old checkpoint backup is removed.")
def load_checkpoint(
self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
):
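        """Restore UNet and adapter weights and, optionally, the optimizer,
        LR scheduler, and training-progress state from a checkpoint directory."""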
logging.info(f"Loading checkpoint from: {ckpt_path}")
# Load UNet
_model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
self.model.unet.load_state_dict(
torch.load(_model_path, map_location=self.device)
)
self.model.unet.to(self.device)
logging.info(f"UNet parameters are loaded from {_model_path}")
# Load DINOv2_adapter
_model_path = os.path.join(ckpt_path, "dinov2_adapter.pth")
self.dinov2_adapter.load_state_dict(
torch.load(_model_path, map_location=self.device)
)
self.dinov2_adapter.to(self.device)
logging.info(f"dinov2_adapter parameters are loaded from {_model_path}")
# Load training states
if load_trainer_state:
checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
self.effective_iter = checkpoint["effective_iter"]
self.epoch = checkpoint["epoch"]
self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
self.in_evaluation = checkpoint["in_evaluation"]
self.global_seed_sequence = checkpoint["global_seed_sequence"]
self.best_metric = checkpoint["best_metric"]
self.optimizer.load_state_dict(checkpoint["optimizer"])
logging.info(f"optimizer state is loaded from {ckpt_path}")
if resume_lr_scheduler:
self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
logging.info(f"LR scheduler state is loaded from {ckpt_path}")
logging.info(
f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
)
return
def _get_backup_ckpt_name(self):
return f"iter_{self.effective_iter:06d}"