# An official reimplementation of the Marigold training script.
# Last modified: 2024-04-29
#
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import logging
import os
import shutil
from datetime import datetime
from typing import List, Union
import numpy as np
import torch
from omegaconf import OmegaConf
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
from depthmaster.depthmaster_pipeline import DepthMasterPipeline, DepthMasterDepthOutput
from src.util import metric
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dic_to_text
from src.util.loss import get_loss
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker
from src.util.alignment import (
align_depth_least_square,
depth2disparity,
disparity2depth,
align_depth_least_square_torch_mask,
align_depth_medium_mask
)
from src.util.seeding import generate_seed_sequence
import torch.nn.functional as F
class DepthMasterTrainerS2:
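    """Stage-2 (`S2`) trainer for DepthMaster: fine-tunes the pipeline's U-Net
    with a direct depth regression loss plus a multi-directional gradient
    loss, and handles validation, visualization, and resumable checkpointing.
    """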
def __init__(
self,
cfg: OmegaConf,
model: DepthMasterPipeline,
train_dataloader: DataLoader,
device,
base_ckpt_dir,
out_dir_ckpt,
out_dir_eval,
out_dir_vis,
accumulation_steps: int,
val_dataloaders: List[DataLoader] = None,
vis_dataloaders: List[DataLoader] = None,
):
self.cfg: OmegaConf = cfg
self.model: DepthMasterPipeline = model
self.device = device
self.seed: Union[int, None] = (
self.cfg.trainer.init_seed
) # used to generate seed sequence, set to `None` to train w/o seeding
self.out_dir_ckpt = out_dir_ckpt
self.out_dir_eval = out_dir_eval
self.out_dir_vis = out_dir_vis
self.train_loader: DataLoader = train_dataloader
self.val_loaders: List[DataLoader] = val_dataloaders
self.vis_loaders: List[DataLoader] = vis_dataloaders
self.accumulation_steps: int = accumulation_steps
# Encode empty text prompt
self.model.encode_empty_text()
self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
self.model.unet.enable_xformers_memory_efficient_attention()
# Trainability
self.model.vae.requires_grad_(False)
self.model.vae.decoder.requires_grad_(False)
self.model.text_encoder.requires_grad_(False)
self.model.unet.requires_grad_(True)
        # Optimizer (must be created after the set of trainable parameters is fixed)
lr = self.cfg.lr
self.optimizer = Adam(self.model.unet.parameters(), lr=lr)
# LR scheduler
lr_func = IterExponential(
total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
)
self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)
# Loss
self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
        self.grad_loss = get_loss(loss_name=self.cfg.grad_loss.name, **self.cfg.grad_loss.kwargs)
# Eval metrics
self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
self.train_metrics = MetricTracker(*["loss", "grad_loss"])
self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])
# main metric for best checkpoint saving
self.main_val_metric = cfg.validation.main_val_metric
self.main_val_metric_goal = cfg.validation.main_val_metric_goal
assert (
self.main_val_metric in cfg.eval.eval_metrics
), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8
# Settings
self.max_epoch = self.cfg.max_epoch
self.max_iter = self.cfg.max_iter
self.gradient_accumulation_steps = accumulation_steps
self.gt_depth_type = self.cfg.gt_depth_type
self.gt_mask_type = self.cfg.gt_mask_type
self.save_period = self.cfg.trainer.save_period
self.backup_period = self.cfg.trainer.backup_period
self.val_period = self.cfg.trainer.validation_period
self.vis_period = self.cfg.trainer.visualization_period
# Internal variables
self.epoch = 1
        self.n_batch_in_epoch = 0  # batch index within the epoch, used when resuming training
        self.effective_iter = 0  # number of times optimizer.step() has been called
        self.in_evaluation = False
        # consistent global seed sequence, used to seed the random generator
        # and keep runs reproducible when resuming
        self.global_seed_sequence: List = []
def grad(self, x):
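        """Finite-difference gradients of `x` ([N, C, H, W]) along four
        directions, concatenated along the channel dimension:
        returns [N, 4*C, H-1, W-1]."""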
        diff_x = x[..., 1:, 1:] - x[..., 1:, :-1]  # horizontal
        diff_y = x[..., 1:, 1:] - x[..., :-1, 1:]  # vertical
        diff_45 = x[..., :-1, 1:] - x[..., 1:, :-1]  # 45-degree diagonal
        diff_135 = x[..., 1:, 1:] - x[..., :-1, :-1]  # 135-degree diagonal
result = torch.cat([diff_x, diff_y, diff_45, diff_135], dim=1)
return result
def train(self, t_end=None):
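        """Main training loop: one deterministic depth prediction per batch,
        depth and gradient losses with gradient accumulation, and periodic
        checkpointing, validation, and visualization. Stops after `max_iter`
        effective iterations, or at wall-clock time `t_end` (saving a
        resumable checkpoint)."""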
logging.info("Start training")
device = self.device
self.model.to(device)
self.visualize()
if self.in_evaluation:
logging.info(
"Last evaluation was not finished, will do evaluation before continue training."
)
self.validate()
self.train_metrics.reset()
accumulated_step = 0
progress_bar = tqdm(
range(0, self.max_iter),
initial=self.effective_iter,
desc="iter"
)
for epoch in range(self.epoch, self.max_epoch + 1):
self.epoch = epoch
logging.debug(f"epoch: {self.epoch}")
            # Skip previously seen batches when resuming
for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
self.model.unet.train()
# >>> With gradient accumulation >>>
# Get data
rgb = batch["rgb_norm"].to(device)
depth_gt_for_latent = batch[self.gt_depth_type].to(device)
if self.gt_mask_type is not None:
valid_mask_for_latent = batch[self.gt_mask_type].to(device)
else:
raise NotImplementedError
batch_size = rgb.shape[0]
with torch.no_grad():
# Encode image
rgb_latent = self.model.encode_rgb(rgb) # [B, 4, h, w]
# Text embedding
text_embed = self.empty_text_embed.to(device).repeat(
(batch_size, 1, 1)
) # [B, 77, 1024]
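                # One deterministic U-Net forward pass at a fixed timestep (1);
                # no iterative denoising is performed in this stage.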
rgb_latent = self.model.unet(
rgb_latent, 1, text_embed
).sample # [B, 4, h, w]
depth_pred = self.model.decode_depth(rgb_latent)
depth_gt_for_loss = depth_gt_for_latent
aligned_pred = depth_pred
if self.gt_mask_type is not None:
loss = self.loss(aligned_pred[valid_mask_for_latent].float(), depth_gt_for_loss[valid_mask_for_latent].float()).mean()
else:
loss = self.loss(aligned_pred.float(), depth_gt_for_loss.float()).mean()
self.train_metrics.update("loss", loss.item())
# grad loss
depth_gt_for_loss[~valid_mask_for_latent] = 0
grad_gt = self.grad(depth_gt_for_loss)
aligned_pred[~valid_mask_for_latent] = 0
grad_pred = self.grad(aligned_pred)
grad_loss = self.grad_loss(grad_gt, grad_pred)
                self.train_metrics.update("grad_loss", grad_loss.item())
loss += self.cfg.grad_loss.lamda * grad_loss
loss = loss / self.gradient_accumulation_steps
loss.backward()
accumulated_step += 1
self.n_batch_in_epoch += 1
# Practical batch end
# Perform optimization step
if accumulated_step >= self.gradient_accumulation_steps:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
accumulated_step = 0
self.effective_iter += 1
progress_bar.update(1)
# Log to tensorboard
accumulated_loss = self.train_metrics.result()["loss"]
logs = {"loss": accumulated_loss}
progress_bar.set_postfix(**logs)
tb_logger.log_dic(
{
f"train/{k}": v
for k, v in self.train_metrics.result().items()
},
global_step=self.effective_iter,
)
tb_logger.writer.add_scalar(
"lr",
self.lr_scheduler.get_last_lr()[0],
global_step=self.effective_iter,
)
tb_logger.writer.add_scalar(
"n_batch_in_epoch",
self.n_batch_in_epoch,
global_step=self.effective_iter,
)
self.train_metrics.reset()
# Per-step callback
self._train_step_callback()
# End of training
if self.max_iter > 0 and self.effective_iter >= self.max_iter:
self.save_checkpoint(
ckpt_name=self._get_backup_ckpt_name(),
save_train_state=False,
)
logging.info("Training ended.")
return
# Time's up
elif t_end is not None and datetime.now() >= t_end:
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
logging.info("Time is up, training paused.")
return
torch.cuda.empty_cache()
# <<< Effective batch end <<<
# Epoch end
self.n_batch_in_epoch = 0
def encode_depth(self, depth_in):
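        """Encode a depth map into the VAE latent space by replicating it to
        3 channels and reusing the RGB encoder."""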
# stack depth into 3-channel
stacked = self.stack_depth_images(depth_in)
# encode using VAE encoder
depth_latent = self.model.encode_rgb(stacked)
return depth_latent
    @staticmethod
    def stack_depth_images(depth_in):
        """Replicate single-channel depth into 3 channels so that it can be
        encoded with the RGB VAE encoder."""
        if 4 == len(depth_in.shape):
            stacked = depth_in.repeat(1, 3, 1, 1)
        elif 3 == len(depth_in.shape):
            stacked = depth_in.unsqueeze(1)  # [N, H, W] -> [N, 1, H, W]
            stacked = stacked.repeat(1, 3, 1, 1)  # [N, 1, H, W] -> [N, 3, H, W]
        else:
            raise ValueError(f"Unexpected depth shape: {depth_in.shape}")
        return stacked
def _train_step_callback(self):
"""Executed after every iteration"""
# Save backup (with a larger interval, without training states)
if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
self.save_checkpoint(
ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
)
_is_latest_saved = False
# Validation
if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
_is_latest_saved = True
self.validate()
self.in_evaluation = False
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
# Save training checkpoint (can be resumed)
if (
self.save_period > 0
and 0 == self.effective_iter % self.save_period
and not _is_latest_saved
):
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
# Visualization
if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
self.visualize()
def validate(self):
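        """Run validation on every validation dataloader, log the metrics, and
        save a backup checkpoint when the main metric (tracked on the first
        dataloader) improves."""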
for i, val_loader in enumerate(self.val_loaders):
val_dataset_name = val_loader.dataset.disp_name
val_metric_dic = self.validate_single_dataset(
data_loader=val_loader, metric_tracker=self.val_metrics
)
logging.info(
f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
)
tb_logger.log_dic(
{f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
global_step=self.effective_iter,
)
# save to file
eval_text = eval_dic_to_text(
val_metrics=val_metric_dic,
dataset_name=val_dataset_name,
sample_list_path=val_loader.dataset.filename_ls_path,
)
_save_to = os.path.join(
self.out_dir_eval,
f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
)
with open(_save_to, "w+") as f:
f.write(eval_text)
# Update main eval metric
if 0 == i:
main_eval_metric = val_metric_dic[self.main_val_metric]
                if (
                    ("minimize" == self.main_val_metric_goal and main_eval_metric < self.best_metric)
                    or ("maximize" == self.main_val_metric_goal and main_eval_metric > self.best_metric)
                ):
self.best_metric = main_eval_metric
logging.info(
f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
)
# Save a checkpoint
self.save_checkpoint(
ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
)
def visualize(self):
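        """Run inference on all visualization dataloaders and save the
        predicted depth maps under `out_dir_vis`."""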
for val_loader in self.vis_loaders:
vis_dataset_name = val_loader.dataset.disp_name
vis_out_dir = os.path.join(
self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
)
os.makedirs(vis_out_dir, exist_ok=True)
_ = self.validate_single_dataset(
data_loader=val_loader,
metric_tracker=self.val_metrics,
save_to_dir=vis_out_dir,
)
@torch.no_grad()
def validate_single_dataset(
self,
data_loader: DataLoader,
metric_tracker: MetricTracker,
save_to_dir: str = None,
):
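        """Evaluate on one dataset: predict depth for each sample, align the
        prediction to the ground truth according to `cfg.eval.alignment`,
        accumulate the configured metrics, and optionally save the predictions
        as 16-bit PNGs to `save_to_dir`."""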
self.model.to(self.device)
metric_tracker.reset()
# Generate seed sequence for consistent evaluation
val_init_seed = self.cfg.validation.init_seed
val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))
for i, batch in enumerate(
tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
start=1,
):
assert 1 == data_loader.batch_size
# Read input image
rgb_int = batch["rgb_int"] # [3, H, W]
# GT depth
depth_raw_ts = batch["depth_raw_linear"].squeeze()
depth_raw = depth_raw_ts.numpy()
depth_raw_ts = depth_raw_ts.to(self.device)
valid_mask_ts = batch["valid_mask_raw"].squeeze()
valid_mask = valid_mask_ts.numpy()
valid_mask_ts = valid_mask_ts.to(self.device)
# Predict depth
pipe_out: DepthMasterDepthOutput = self.model(
rgb_int,
processing_res=self.cfg.validation.processing_res,
match_input_res=self.cfg.validation.match_input_res,
batch_size=1, # use batch size 1 to increase reproducibility
color_map=None,
show_progress_bar=False,
resample_method=self.cfg.validation.resample_method,
)
depth_pred: np.ndarray = pipe_out.depth_np.squeeze()
if "least_square" == self.cfg.eval.alignment:
depth_pred, scale, shift = align_depth_least_square(
gt_arr=depth_raw,
pred_arr=depth_pred,
valid_mask_arr=valid_mask,
return_scale_shift=True,
max_resolution=self.cfg.eval.align_max_res,
)
elif "least_square_disparity" == self.cfg.eval.alignment:
# gt_disparity = depth_raw
gt_disparity = depth2disparity(depth_raw)
gt_non_neg_mask = gt_disparity > 0
# LS alignment in disparity space
pred_non_neg_mask = depth_pred > 0
valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask
disparity_pred, scale, shift = align_depth_least_square(
gt_arr=gt_disparity,
pred_arr=depth_pred,
valid_mask_arr=valid_nonnegative_mask,
return_scale_shift=True,
)
# convert to depth
disparity_pred = np.clip(
disparity_pred, a_min=1e-3, a_max=None
) # avoid 0 disparity
depth_pred = disparity2depth(disparity_pred)
depth_raw_ts = disparity2depth(depth_raw_ts)
elif "least_square_sqrt_disp" == self.cfg.eval.alignment:
# gt_sqrt_disp = depth_raw
gt_sqrt_disp = np.sqrt(depth2disparity(depth_raw))
gt_non_neg_mask = gt_sqrt_disp > 0
# LS alignment in sqrt space
pred_non_neg_mask = depth_pred > 0
valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask
                depth_sqrt_disp_pred, scale, shift = align_depth_least_square(
                    gt_arr=gt_sqrt_disp,
                    pred_arr=depth_pred,
                    valid_mask_arr=valid_nonnegative_mask,  # exclude non-positive values, as in the disparity branch
                    return_scale_shift=True,
                )
                # square the aligned sqrt-disparity back to disparity
disparity_pred = depth_sqrt_disp_pred ** 2
depth_raw_ts = torch.pow(depth_raw_ts, 2)
# convert to depth
disparity_pred = np.clip(
disparity_pred, a_min=1e-3, a_max=None
) # avoid 0 disparity
depth_pred = disparity2depth(disparity_pred)
depth_raw_ts = disparity2depth(depth_raw_ts)
else:
raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")
# Clip to dataset min max
depth_pred = np.clip(
depth_pred,
a_min=data_loader.dataset.min_depth,
a_max=data_loader.dataset.max_depth,
)
# clip to d > 0 for evaluation
depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)
# Evaluate
sample_metric = []
depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)
for met_func in self.metric_funcs:
_metric_name = met_func.__name__
_metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
                sample_metric.append(str(_metric))
metric_tracker.update(_metric_name, _metric)
# Save as 16-bit uint png
if save_to_dir is not None:
img_name = batch["rgb_relative_path"][0].replace("/", "_")
png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")
return metric_tracker.result()
def _get_next_seed(self):
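        """Pop the next seed from the global seed sequence, generating the
        full sequence on first use so that resumed runs stay deterministic."""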
if 0 == len(self.global_seed_sequence):
self.global_seed_sequence = generate_seed_sequence(
initial_seed=self.seed,
length=self.max_iter * self.gradient_accumulation_steps,
)
logging.info(
f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
)
return self.global_seed_sequence.pop()
def save_checkpoint(self, ckpt_name, save_train_state):
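        """Save the U-Net (and, if `save_train_state`, the optimizer, LR
        scheduler, and trainer counters) to `out_dir_ckpt/ckpt_name`; the
        previous checkpoint is kept as a temporary backup until the new one
        is written."""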
ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
logging.info(f"Saving checkpoint to: {ckpt_dir}")
# Backup previous checkpoint
temp_ckpt_dir = None
if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
temp_ckpt_dir = os.path.join(
os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
)
if os.path.exists(temp_ckpt_dir):
shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
os.rename(ckpt_dir, temp_ckpt_dir)
logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")
# Save UNet
unet_path = os.path.join(ckpt_dir, "unet")
self.model.unet.save_pretrained(unet_path, safe_serialization=False)
logging.info(f"UNet is saved to: {unet_path}")
if save_train_state:
state = {
"optimizer": self.optimizer.state_dict(),
"lr_scheduler": self.lr_scheduler.state_dict(),
"config": self.cfg,
"effective_iter": self.effective_iter,
"epoch": self.epoch,
"n_batch_in_epoch": self.n_batch_in_epoch,
"best_metric": self.best_metric,
"in_evaluation": self.in_evaluation,
"global_seed_sequence": self.global_seed_sequence,
}
train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
torch.save(state, train_state_path)
            # iteration indicator: an empty file named after the current iteration
            with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
                pass
logging.info(f"Trainer state is saved to: {train_state_path}")
# Remove temp ckpt
if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
logging.debug("Old checkpoint backup is removed.")
def load_checkpoint(
self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
):
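        """Load U-Net weights from `ckpt_path`, and optionally restore the
        trainer state (optimizer, LR scheduler, and iteration counters)."""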
logging.info(f"Loading checkpoint from: {ckpt_path}")
# Load UNet
_model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
self.model.unet.load_state_dict(
torch.load(_model_path, map_location=self.device)
)
self.model.unet.to(self.device)
logging.info(f"UNet parameters are loaded from {_model_path}")
# Load training states
if load_trainer_state:
checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
self.effective_iter = checkpoint["effective_iter"]
self.epoch = checkpoint["epoch"]
self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
self.in_evaluation = checkpoint["in_evaluation"]
self.global_seed_sequence = checkpoint["global_seed_sequence"]
self.best_metric = checkpoint["best_metric"]
self.optimizer.load_state_dict(checkpoint["optimizer"])
logging.info(f"optimizer state is loaded from {ckpt_path}")
if resume_lr_scheduler:
self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
logging.info(f"LR scheduler state is loaded from {ckpt_path}")
logging.info(
f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
)
return
def _get_backup_ckpt_name(self):
return f"iter_{self.effective_iter:06d}"
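

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not the upstream entry point):
# shows how this trainer is typically wired together. The config path,
# `base_ckpt_dir`, output directories, and the dataloader variables below are
# hypothetical; the config must provide the fields this class reads
# (lr, loss, grad_loss, trainer.*, validation.*, eval.*, max_epoch, max_iter,
# gt_depth_type, gt_mask_type).
#
#   cfg = OmegaConf.load("config/train_s2.yaml")
#   model = DepthMasterPipeline.from_pretrained(base_ckpt_dir)
#   trainer = DepthMasterTrainerS2(
#       cfg=cfg,
#       model=model,
#       train_dataloader=train_loader,
#       device=torch.device("cuda"),
#       base_ckpt_dir=base_ckpt_dir,
#       out_dir_ckpt="out/checkpoint",
#       out_dir_eval="out/eval",
#       out_dir_vis="out/visualization",
#       accumulation_steps=accumulation_steps,
#       val_dataloaders=[val_loader],
#       vis_dataloaders=[vis_loader],
#   )
#   trainer.train(t_end=None)
# --------------------------------------------------------------------------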