# An official reimplemented version of Marigold training script.
# Last modified: 2024-04-29
#
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------

import logging
import os
import shutil
from datetime import datetime
from typing import List, Union

import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image

from depthmaster.depthmaster_pipeline import DepthMasterPipeline, DepthMasterDepthOutput
from src.util import metric
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dic_to_text
from src.util.loss import get_loss
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker
from src.util.alignment import (
    align_depth_least_square,
    depth2disparity,
    disparity2depth,
    align_depth_least_square_torch_mask,
    align_depth_medium_mask,
)
from src.util.seeding import generate_seed_sequence


class DepthMasterTrainerS2:
    """Trainer that fine-tunes the UNet of a DepthMasterPipeline.

    Each training batch is encoded with ``model.encode_rgb`` (no grad), passed
    once through the UNet with the empty-text embedding, and decoded with
    ``model.decode_depth``. The prediction is supervised by:

    * ``self.loss`` on pixels selected by the GT valid mask; and
    * ``self.grad_loss`` on finite differences (see ``grad``) of the masked
      prediction and GT, weighted by ``cfg.grad_loss.lamda``.

    The class also handles gradient accumulation, periodic validation,
    visualization and checkpointing.
    """

    def __init__(
        self,
        cfg: OmegaConf,
        model: DepthMasterPipeline,
        train_dataloader: DataLoader,
        device,
        base_ckpt_dir,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
    ):
        """Build optimizer, LR scheduler, losses and metric trackers.

        Args:
            cfg: Training config; the fields read here include ``lr``,
                ``lr_scheduler.kwargs``, ``loss``, ``grad_loss``, ``eval``,
                ``validation``, ``trainer`` and the ``max_*`` / ``gt_*`` keys.
            model: Pipeline whose ``unet`` is trained; ``vae`` and
                ``text_encoder`` are frozen.
            train_dataloader: Loader yielding training batches.
            device: Device used for training and evaluation.
            base_ckpt_dir: Accepted for interface compatibility; not used by
                this class.
            out_dir_ckpt: Directory for checkpoints.
            out_dir_eval: Directory for evaluation text files.
            out_dir_vis: Directory for visualization PNGs.
            accumulation_steps: Number of batches per optimizer step.
            val_dataloaders: Validation loaders; ``None`` means no validation.
            vis_dataloaders: Visualization loaders; ``None`` means none.
        """
        self.cfg: OmegaConf = cfg
        self.model: DepthMasterPipeline = model
        self.device = device
        # Used to generate the seed sequence; set to `None` to train w/o seeding.
        self.seed: Union[int, None] = self.cfg.trainer.init_seed
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        # `train()` always calls `visualize()` and `validate()` iterates these
        # lists, so normalise the `None` defaults to empty lists.
        self.val_loaders: List[DataLoader] = (
            list(val_dataloaders) if val_dataloaders is not None else []
        )
        self.vis_loaders: List[DataLoader] = (
            list(vis_dataloaders) if vis_dataloaders is not None else []
        )
        self.accumulation_steps: int = accumulation_steps

        # Encode empty text prompt once; it is reused for every batch.
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)

        self.model.unet.enable_xformers_memory_efficient_attention()

        # Trainability: only the UNet is optimized.
        self.model.vae.requires_grad_(False)
        self.model.vae.decoder.requires_grad_(False)
        self.model.text_encoder.requires_grad_(False)
        self.model.unet.requires_grad_(True)

        # Optimizer !should be defined after input layer is adapted
        lr = self.cfg.lr
        self.optimizer = Adam(self.model.unet.parameters(), lr=lr)

        # LR scheduler
        lr_func = IterExponential(
            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

        # Loss
        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
        self.grad_loss = get_loss(
            loss_name=self.cfg.grad_loss.name, **self.cfg.grad_loss.kwargs
        )

        # Eval metrics
        self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
        self.train_metrics = MetricTracker("loss", "grad_loss")
        self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])

        # Main metric for best checkpoint saving
        self.main_val_metric = cfg.validation.main_val_metric
        self.main_val_metric_goal = cfg.validation.main_val_metric_goal
        assert (
            self.main_val_metric in cfg.eval.eval_metrics
        ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
        self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8

        # Settings
        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_depth_type = self.cfg.gt_depth_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period

        # Internal variables
        self.epoch = 1
        # Batch index in the epoch, used when resuming training.
        self.n_batch_in_epoch = 0
        # How many times optimizer.step() has been called.
        self.effective_iter = 0
        self.in_evaluation = False
        # Consistent global seed sequence, used to seed the random generator so
        # that resumed runs stay consistent.
        self.global_seed_sequence: List = []

    def grad(self, x):
        """Return 4-direction finite differences of ``x``.

        Assumes ``x`` is laid out as (N, C, H, W). The four differences are
        horizontal, vertical and the two diagonals. Each is computed over an
        (H-1, W-1) window; they are concatenated along dim 1, giving a tensor
        of shape (N, 4*C, H-1, W-1).
        """
        diff_x = x[..., 1:, 1:] - x[..., 1:, :-1]
        diff_y = x[..., 1:, 1:] - x[..., :-1, 1:]
        diff_45 = x[..., :-1, 1:] - x[..., 1:, :-1]
        diff_135 = x[..., 1:, 1:] - x[..., :-1, :-1]
        return torch.cat([diff_x, diff_y, diff_45, diff_135], dim=1)

    def train(self, t_end=None):
        """Run the training loop, resuming from the current epoch/batch state.

        Args:
            t_end: Optional ``datetime``. After an optimizer step, once
                ``datetime.now() >= t_end``, a resumable "latest" checkpoint is
                saved and the method returns.

        Training also returns after saving a backup checkpoint once
        ``effective_iter`` reaches ``max_iter`` (when ``max_iter > 0``).

        Raises:
            NotImplementedError: if ``cfg.gt_mask_type`` is ``None``.
        """
        logging.info("Start training")

        device = self.device
        self.model.to(device)

        self.visualize()

        if self.in_evaluation:
            logging.info(
                "Last evaluation was not finished, will do evaluation before continue training."
            )
            self.validate()

        self.train_metrics.reset()
        accumulated_step = 0
        progress_bar = tqdm(
            range(0, self.max_iter), initial=self.effective_iter, desc="iter"
        )

        for epoch in range(self.epoch, self.max_epoch + 1):
            self.epoch = epoch
            logging.debug(f"epoch: {self.epoch}")

            # Skip previous batches when resume
            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
                self.model.unet.train()

                # >>> With gradient accumulation >>>

                # Get data
                rgb = batch["rgb_norm"].to(device)
                depth_gt_for_latent = batch[self.gt_depth_type].to(device)

                if self.gt_mask_type is None:
                    raise NotImplementedError(
                        "Training without a ground-truth valid mask (gt_mask_type=None) is not supported."
                    )
                valid_mask_for_latent = batch[self.gt_mask_type].to(device)

                batch_size = rgb.shape[0]

                with torch.no_grad():
                    # Encode image (frozen VAE encoder)
                    rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]

                # Text embedding
                text_embed = self.empty_text_embed.to(device).repeat(
                    (batch_size, 1, 1)
                )  # [B, 77, 1024]

                # Single UNet pass at timestep 1; this is the trainable part.
                rgb_latent = self.model.unet(
                    rgb_latent, 1, text_embed
                ).sample  # [B, 4, h, w]
                depth_pred = self.model.decode_depth(rgb_latent)

                depth_gt_for_loss = depth_gt_for_latent
                aligned_pred = depth_pred

                # Pixel loss on valid pixels only
                loss = self.loss(
                    aligned_pred[valid_mask_for_latent].float(),
                    depth_gt_for_loss[valid_mask_for_latent].float(),
                ).mean()
                self.train_metrics.update("loss", loss.item())

                # Gradient loss: zero out invalid pixels before taking finite
                # differences. Out-of-place `masked_fill` gives the same values
                # and gradients as the former in-place masked assignment, but it
                # neither mutates the autograd output `depth_pred` nor the
                # batch's GT tensor.
                invalid_mask = ~valid_mask_for_latent
                grad_gt = self.grad(depth_gt_for_loss.masked_fill(invalid_mask, 0))
                grad_pred = self.grad(aligned_pred.masked_fill(invalid_mask, 0))
                grad_loss = self.grad_loss(grad_gt, grad_pred)
                self.train_metrics.update("grad_loss", grad_loss.item())

                loss = loss + self.cfg.grad_loss.lamda * grad_loss
                loss = loss / self.gradient_accumulation_steps
                loss.backward()
                accumulated_step += 1

                self.n_batch_in_epoch += 1
                # Practical batch end

                # Perform optimization step
                if accumulated_step >= self.gradient_accumulation_steps:
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                    accumulated_step = 0

                    self.effective_iter += 1
                    progress_bar.update(1)

                    # Log to tensorboard
                    accumulated_loss = self.train_metrics.result()["loss"]
                    logs = {"loss": accumulated_loss}
                    progress_bar.set_postfix(**logs)
                    tb_logger.log_dic(
                        {
                            f"train/{k}": v
                            for k, v in self.train_metrics.result().items()
                        },
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "lr",
                        self.lr_scheduler.get_last_lr()[0],
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "n_batch_in_epoch",
                        self.n_batch_in_epoch,
                        global_step=self.effective_iter,
                    )
                    self.train_metrics.reset()

                    # Per-step callback
                    self._train_step_callback()

                    # End of training
                    if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                        self.save_checkpoint(
                            ckpt_name=self._get_backup_ckpt_name(),
                            save_train_state=False,
                        )
                        logging.info("Training ended.")
                        return
                    # Time's up
                    elif t_end is not None and datetime.now() >= t_end:
                        self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                        logging.info("Time is up, training paused.")
                        return

                    torch.cuda.empty_cache()
                    # <<< Effective batch end <<<

            # Epoch end
            self.n_batch_in_epoch = 0

    def encode_depth(self, depth_in):
        """Encode a depth map with the pipeline's RGB (VAE) encoder.

        The depth is first replicated to 3 channels by ``stack_depth_images``.
        """
        stacked = self.stack_depth_images(depth_in)
        depth_latent = self.model.encode_rgb(stacked)
        return depth_latent

    @staticmethod
    def stack_depth_images(depth_in):
        """Replicate a single-channel depth tensor into 3 channels.

        Args:
            depth_in: Either a 4-D tensor, which is repeated along dim 1, or a
                3-D tensor, which first gets a channel dim inserted at dim 1.

        Returns:
            A 4-D tensor with dim 1 repeated 3 times.

        Raises:
            ValueError: if ``depth_in`` is neither 3-D nor 4-D.
        """
        ndim = len(depth_in.shape)
        if 4 == ndim:
            return depth_in.repeat(1, 3, 1, 1)
        if 3 == ndim:
            # Bug fix: the original repeated `depth_in` (still 3-D) instead of
            # the unsqueezed tensor, which produced a wrongly shaped result.
            return depth_in.unsqueeze(1).repeat(1, 3, 1, 1)
        raise ValueError(
            f"Expected a 3-D or 4-D depth tensor, got shape {tuple(depth_in.shape)}"
        )

    def _train_step_callback(self):
        """Executed after every iteration (optimizer step).

        Handles periodic backups, validation, resumable checkpoints and
        visualization according to the ``trainer.*_period`` settings.
        """
        # Save backup (with a larger interval, without training states)
        if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
            self.save_checkpoint(
                ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
            )

        _is_latest_saved = False
        # Validation
        if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
            # Flag so a resumed run redoes evaluation if it did not finish.
            self.in_evaluation = True
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)
            _is_latest_saved = True
            self.validate()
            self.in_evaluation = False
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)

        # Save training checkpoint (can be resumed)
        if (
            self.save_period > 0
            and 0 == self.effective_iter % self.save_period
            and not _is_latest_saved
        ):
            self.save_checkpoint(ckpt_name="latest", save_train_state=True)

        # Visualization
        if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
            self.visualize()

    def validate(self):
        """Evaluate on every validation loader, then log and save the results.

        Metrics are logged to tensorboard and written to a text file in
        ``out_dir_eval``. Only the first loader drives best-checkpoint
        selection via ``main_val_metric``.
        """
        for i, val_loader in enumerate(self.val_loaders):
            val_dataset_name = val_loader.dataset.disp_name
            val_metric_dic = self.validate_single_dataset(
                data_loader=val_loader, metric_tracker=self.val_metrics
            )
            logging.info(
                f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
            )
            tb_logger.log_dic(
                {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
                global_step=self.effective_iter,
            )
            # save to file
            eval_text = eval_dic_to_text(
                val_metrics=val_metric_dic,
                dataset_name=val_dataset_name,
                sample_list_path=val_loader.dataset.filename_ls_path,
            )
            _save_to = os.path.join(
                self.out_dir_eval,
                f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
            )
            with open(_save_to, "w+") as f:
                f.write(eval_text)

            # Update main eval metric
            if 0 == i:
                main_eval_metric = val_metric_dic[self.main_val_metric]
                if (
                    "minimize" == self.main_val_metric_goal
                    and main_eval_metric < self.best_metric
                    or "maximize" == self.main_val_metric_goal
                    and main_eval_metric > self.best_metric
                ):
                    self.best_metric = main_eval_metric
                    logging.info(
                        f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
                    )
                    # Save a checkpoint
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
                    )

    def visualize(self):
        """Run every visualization loader, saving predictions as 16-bit PNGs.

        Output goes to ``out_dir_vis/<iter_XXXXXX>/<dataset>``.
        """
        for val_loader in self.vis_loaders:
            vis_dataset_name = val_loader.dataset.disp_name
            vis_out_dir = os.path.join(
                self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
            )
            os.makedirs(vis_out_dir, exist_ok=True)
            _ = self.validate_single_dataset(
                data_loader=val_loader,
                metric_tracker=self.val_metrics,
                save_to_dir=vis_out_dir,
            )

    @torch.no_grad()
    def validate_single_dataset(
        self,
        data_loader: DataLoader,
        metric_tracker: MetricTracker,
        save_to_dir: str = None,
    ):
        """Evaluate the pipeline on one dataset and return averaged metrics.

        For each sample (batch size must be 1), the method:

        1. Runs the pipeline.
        2. Aligns the prediction to GT according to ``cfg.eval.alignment``.
        3. Clips the prediction to the dataset's ``[min_depth, max_depth]``.
        4. Accumulates ``cfg.eval.eval_metrics`` into ``metric_tracker``,
           which is reset first.

        If ``save_to_dir`` is given, the unaligned pipeline output
        (``depth_np * 65535``) is also saved there as a 16-bit PNG.

        Returns:
            ``metric_tracker.result()``.

        Raises:
            RuntimeError: on an unknown ``cfg.eval.alignment``.
        """
        self.model.to(self.device)
        metric_tracker.reset()

        # The returned sequence is not consumed below; the call is kept so that
        # any side effects of generate_seed_sequence (e.g. on RNG state) are
        # preserved for consistent evaluation.
        val_init_seed = self.cfg.validation.init_seed
        generate_seed_sequence(val_init_seed, len(data_loader))

        assert 1 == data_loader.batch_size

        for batch in tqdm(
            data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"
        ):
            # Read input image
            rgb_int = batch["rgb_int"]  # [3, H, W]
            # GT depth
            depth_raw_ts = batch["depth_raw_linear"].squeeze()
            depth_raw = depth_raw_ts.numpy()
            depth_raw_ts = depth_raw_ts.to(self.device)
            valid_mask_ts = batch["valid_mask_raw"].squeeze()
            valid_mask = valid_mask_ts.numpy()
            valid_mask_ts = valid_mask_ts.to(self.device)

            # Predict depth
            pipe_out: DepthMasterDepthOutput = self.model(
                rgb_int,
                processing_res=self.cfg.validation.processing_res,
                match_input_res=self.cfg.validation.match_input_res,
                batch_size=1,  # use batch size 1 to increase reproducibility
                color_map=None,
                show_progress_bar=False,
                resample_method=self.cfg.validation.resample_method,
            )

            depth_pred: np.ndarray = pipe_out.depth_np.squeeze()

            alignment = self.cfg.eval.alignment
            if "least_square" == alignment:
                depth_pred, scale, shift = align_depth_least_square(
                    gt_arr=depth_raw,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_mask,
                    return_scale_shift=True,
                    max_resolution=self.cfg.eval.align_max_res,
                )
            elif "least_square_disparity" == alignment:
                gt_disparity = depth2disparity(depth_raw)
                gt_non_neg_mask = gt_disparity > 0
                # LS alignment in disparity space
                pred_non_neg_mask = depth_pred > 0
                valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask

                disparity_pred, scale, shift = align_depth_least_square(
                    gt_arr=gt_disparity,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_nonnegative_mask,
                    return_scale_shift=True,
                )
                # avoid 0 disparity, then convert to depth
                disparity_pred = np.clip(disparity_pred, a_min=1e-3, a_max=None)
                depth_pred = disparity2depth(disparity_pred)
                # NOTE(review): `depth_raw` is treated as depth above
                # (depth2disparity), but the GT tensor used for metrics is
                # converted here with disparity2depth, i.e. treated as
                # disparity. An earlier variant used `gt_disparity = depth_raw`
                # directly. TODO confirm which space `depth_raw_linear` is in;
                # behaviour is kept unchanged.
                depth_raw_ts = disparity2depth(depth_raw_ts)
            elif "least_square_sqrt_disp" == alignment:
                gt_sqrt_disp = np.sqrt(depth2disparity(depth_raw))
                gt_non_neg_mask = gt_sqrt_disp > 0
                # LS alignment in sqrt-disparity space
                pred_non_neg_mask = depth_pred > 0
                valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask

                depth_sqrt_disp_pred, scale, shift = align_depth_least_square(
                    gt_arr=gt_sqrt_disp,
                    pred_arr=depth_pred,
                    # Bug fix: the non-negative mask was computed but the plain
                    # valid mask was passed; match the disparity branch.
                    valid_mask_arr=valid_nonnegative_mask,
                    return_scale_shift=True,
                )
                # sqrt-disparity -> disparity
                disparity_pred = depth_sqrt_disp_pred ** 2
                # NOTE(review): same inconsistency as the disparity branch. The
                # GT tensor is treated as sqrt-disparity here, while `depth_raw`
                # above is treated as depth. TODO confirm; behaviour is kept
                # unchanged.
                depth_raw_ts = torch.pow(depth_raw_ts, 2)
                # avoid 0 disparity, then convert to depth
                disparity_pred = np.clip(disparity_pred, a_min=1e-3, a_max=None)
                depth_pred = disparity2depth(disparity_pred)
                depth_raw_ts = disparity2depth(depth_raw_ts)
            else:
                raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")

            # Clip to dataset min max
            depth_pred = np.clip(
                depth_pred,
                a_min=data_loader.dataset.min_depth,
                a_max=data_loader.dataset.max_depth,
            )

            # clip to d > 0 for evaluation
            depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)

            # Evaluate
            depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)
            for met_func in self.metric_funcs:
                _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
                metric_tracker.update(met_func.__name__, _metric)

            # Save as 16-bit uint png
            if save_to_dir is not None:
                img_name = batch["rgb_relative_path"][0].replace("/", "_")
                png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
                depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
                Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

        return metric_tracker.result()

    def _get_next_seed(self):
        """Pop the next seed, lazily generating the global sequence first.

        The sequence has length ``max_iter * gradient_accumulation_steps``;
        seeds are taken from its end.
        """
        if 0 == len(self.global_seed_sequence):
            self.global_seed_sequence = generate_seed_sequence(
                initial_seed=self.seed,
                length=self.max_iter * self.gradient_accumulation_steps,
            )
            logging.info(
                f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
            )
        return self.global_seed_sequence.pop()

    def save_checkpoint(self, ckpt_name, save_train_state):
        """Save the UNet (and optionally the trainer state) to ``out_dir_ckpt/ckpt_name``.

        An existing checkpoint directory of the same name is first renamed to
        ``_old_<name>``. That backup is removed after the new save completes.

        Args:
            ckpt_name: Sub-directory name, e.g. "latest" or "iter_000100".
            save_train_state: If True, also write ``trainer.ckpt`` (optimizer,
                scheduler, counters, seeds) and an empty ``iter_XXXXXX`` marker
                file.
        """
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        logging.info(f"Saving checkpoint to: {ckpt_dir}")
        # Backup previous checkpoint
        temp_ckpt_dir = None
        if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
            temp_ckpt_dir = os.path.join(
                os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
            )
            if os.path.exists(temp_ckpt_dir):
                shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            os.rename(ckpt_dir, temp_ckpt_dir)
            logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

        # Save UNet
        unet_path = os.path.join(ckpt_dir, "unet")
        self.model.unet.save_pretrained(unet_path, safe_serialization=False)
        logging.info(f"UNet is saved to: {unet_path}")

        if save_train_state:
            state = {
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
                "config": self.cfg,
                "effective_iter": self.effective_iter,
                "epoch": self.epoch,
                "n_batch_in_epoch": self.n_batch_in_epoch,
                "best_metric": self.best_metric,
                "in_evaluation": self.in_evaluation,
                "global_seed_sequence": self.global_seed_sequence,
            }
            train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
            torch.save(state, train_state_path)
            # Iteration indicator: an empty file named after the current iter.
            with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
                pass

            logging.info(f"Trainer state is saved to: {train_state_path}")

        # Remove temp ckpt
        if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            logging.debug("Old checkpoint backup is removed.")

    def load_checkpoint(
        self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
    ):
        """Load UNet weights and, optionally, trainer state from ``ckpt_path``.

        Args:
            ckpt_path: Directory written by ``save_checkpoint``.
            load_trainer_state: If True, restore counters, seeds, best metric
                and optimizer state from ``trainer.ckpt``.
            resume_lr_scheduler: If True (and ``load_trainer_state``), also
                restore the LR scheduler state.
        """
        logging.info(f"Loading checkpoint from: {ckpt_path}")
        # Load UNet
        _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
        self.model.unet.load_state_dict(
            torch.load(_model_path, map_location=self.device)
        )
        self.model.unet.to(self.device)
        logging.info(f"UNet parameters are loaded from {_model_path}")

        # Load training states
        if load_trainer_state:
            # NOTE(review): trainer.ckpt stores an OmegaConf config object. On
            # torch versions where `torch.load` defaults to `weights_only=True`,
            # this load may need `weights_only=False` (only for trusted files).
            checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
            self.effective_iter = checkpoint["effective_iter"]
            self.epoch = checkpoint["epoch"]
            self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
            self.in_evaluation = checkpoint["in_evaluation"]
            self.global_seed_sequence = checkpoint["global_seed_sequence"]

            self.best_metric = checkpoint["best_metric"]

            self.optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info(f"optimizer state is loaded from {ckpt_path}")

            if resume_lr_scheduler:
                self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                logging.info(f"LR scheduler state is loaded from {ckpt_path}")

        logging.info(
            f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
        )
        return

    def _get_backup_ckpt_name(self):
        """Return the checkpoint/marker name for the current iteration, e.g. ``iter_000100``."""
        return f"iter_{self.effective_iter:06d}"