Diffusers
Safetensors
zeyuren2002's picture
Add files using upload-large-folder tool
ecd43ed verified
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
import argparse
import logging
import os
import shutil
import torch
from datetime import datetime, timedelta
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
from typing import List, Union
from marigold import MarigoldNormalsPipeline
from src.dataset import BaseNormalsDataset, DatasetMode, get_dataset
from src.dataset.mixed_sampler import MixedBatchSampler
from src.trainer import get_trainer_cls
from src.util.config_util import (
find_value_in_omegaconf,
recursive_load_config,
)
from src.util.logging_util import (
config_logging,
init_wandb,
load_wandb_job_id,
log_slurm_job_id,
save_wandb_job_id,
tb_logger,
)
from src.util.slurm_util import get_local_scratch_dir, is_on_slurm
if "__main__" == __name__:
    # Entry point of the Marigold surface-normals training script. In order, it:
    #   1. parses the command line,
    #   2. resolves (or resumes) the run directory and configuration,
    #   3. sets up logging, wandb and tensorboard,
    #   4. snapshots the config and code on a fresh run,
    #   5. optionally copies the datasets to the Slurm local scratch,
    #   6. builds the train / validation / visualization data loaders,
    #   7. loads the pretrained pipeline and constructs the trainer,
    #   8. optionally restores a checkpoint, and runs training.

    # Local import: only needed for the code snapshot below.
    import subprocess

    t_start = datetime.now()
    # NOTE(review): logging is configured further down (config_logging); with
    # the stdlib default root level this INFO record is likely not emitted.
    logging.info(f"Started at {t_start}")

    # -------------------- Arguments --------------------
    parser = argparse.ArgumentParser(
        description="Marigold : Surface Normals Estimation : Training"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config/train_marigold_normals.yaml",
        help="Path to config file.",
    )
    parser.add_argument(
        "--resume_run",
        action="store",
        default=None,
        help="Path of checkpoint to be resumed. If given, will ignore --config, and checkpoint in the config.",
    )
    parser.add_argument(
        "--output_dir", type=str, default=None, help="Directory to save checkpoints."
    )
    parser.add_argument("--no_cuda", action="store_true", help="Do not use cuda.")
    parser.add_argument(
        "--exit_after",
        type=int,
        default=-1,
        help="Save checkpoint and exit after X minutes.",
    )
    parser.add_argument(
        "--no_wandb",
        action="store_true",
        help="Run without Weights and Biases logging.",
    )
    parser.add_argument(
        "--do_not_copy_data",
        action="store_true",
        help="On Slurm cluster, do not copy data to the local scratch.",
    )
    parser.add_argument(
        "--base_data_dir", type=str, default=None, help="Base path to the datasets."
    )
    parser.add_argument(
        "--base_ckpt_dir",
        type=str,
        default=None,
        help="Base path to the pretrained checkpoints.",
    )
    parser.add_argument(
        "--add_datetime_prefix",
        action="store_true",
        help="Add datetime to the output folder name.",
    )

    args = parser.parse_args()
    resume_run = args.resume_run
    output_dir = args.output_dir

    # The command-line value wins; otherwise fall back to the environment
    # variable. Fail with a readable message instead of a bare KeyError.
    base_data_dir = (
        args.base_data_dir
        if args.base_data_dir is not None
        else os.environ.get("BASE_DATA_DIR")
    )
    if base_data_dir is None:
        parser.error(
            "Base data directory is not set: pass --base_data_dir "
            "or set the BASE_DATA_DIR environment variable."
        )
    base_ckpt_dir = (
        args.base_ckpt_dir
        if args.base_ckpt_dir is not None
        else os.environ.get("BASE_CKPT_DIR")
    )
    if base_ckpt_dir is None:
        parser.error(
            "Base checkpoint directory is not set: pass --base_ckpt_dir "
            "or set the BASE_CKPT_DIR environment variable."
        )

    # -------------------- Initialization --------------------
    if resume_run is not None:
        # Resume previous run: the run directory is two levels above the
        # checkpoint path, and its saved config replaces --config.
        logging.info(f"Resuming run: {resume_run}")
        out_dir_run = os.path.dirname(os.path.dirname(resume_run))
        job_name = os.path.basename(out_dir_run)
        # Resume config file
        cfg = OmegaConf.load(os.path.join(out_dir_run, "config.yaml"))
    else:
        # Run from start
        cfg = recursive_load_config(args.config)
        # Full job name: config file name up to the first dot
        pure_job_name = os.path.basename(args.config).split(".")[0]
        # Add time prefix
        if args.add_datetime_prefix:
            job_name = f"{t_start.strftime('%y_%m_%d-%H_%M_%S')}-{pure_job_name}"
        else:
            job_name = pure_job_name

        # Output dir
        if output_dir is not None:
            out_dir_run = os.path.join(output_dir, job_name)
        else:
            out_dir_run = os.path.join("./output", job_name)
        # exist_ok=False on purpose: a fresh run must not reuse an existing
        # run directory.
        os.makedirs(out_dir_run, exist_ok=False)

    cfg_data = cfg.dataset

    # Other directories (they may already exist when resuming)
    out_dir_ckpt = os.path.join(out_dir_run, "checkpoint")
    out_dir_tb = os.path.join(out_dir_run, "tensorboard")
    out_dir_eval = os.path.join(out_dir_run, "evaluation")
    out_dir_vis = os.path.join(out_dir_run, "visualization")
    for _sub_dir in (out_dir_ckpt, out_dir_tb, out_dir_eval, out_dir_vis):
        os.makedirs(_sub_dir, exist_ok=True)

    # -------------------- Logging settings --------------------
    config_logging(cfg.logging, out_dir=out_dir_run)
    logging.debug(f"config: {cfg}")

    # Initialize wandb
    if not args.no_wandb:
        if resume_run is not None:
            # Re-attach to the wandb run whose id was stored in the run dir
            wandb_id = load_wandb_job_id(out_dir_run)
            wandb_cfg_dict = {
                "id": wandb_id,
                "resume": "must",
                **cfg.wandb,
            }
        else:
            # Entries of cfg.wandb come last and override the defaults above
            wandb_cfg_dict = {
                "config": dict(cfg),
                "name": job_name,
                "mode": "online",
                **cfg.wandb,
            }
        wandb_cfg_dict.update({"dir": out_dir_run})
        wandb_run = init_wandb(enable=True, **wandb_cfg_dict)
        save_wandb_job_id(wandb_run, out_dir_run)
    else:
        init_wandb(enable=False)

    # Tensorboard (should be initialized after wandb)
    tb_logger.set_dir(out_dir_tb)

    log_slurm_job_id(step=0)

    # -------------------- Device --------------------
    cuda_avail = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if cuda_avail else "cpu")
    logging.info(f"device = {device}")

    # -------------------- Snapshot of code and config --------------------
    if resume_run is None:
        _output_path = os.path.join(out_dir_run, "config.yaml")
        with open(_output_path, "w+") as f:
            OmegaConf.save(config=cfg, f=f)
        logging.info(f"Config saved to {_output_path}")

        # Copy and tar code on the first run. Commands are passed as argument
        # lists (no shell), so paths containing spaces or shell metacharacters
        # cannot be misinterpreted. The snapshot stays best-effort: a failure
        # is reported but does not abort training.
        _temp_code_dir = os.path.join(out_dir_run, "code_tar")
        _code_snapshot_path = os.path.join(out_dir_run, "code_snapshot.tar")
        try:
            _rsync_result = subprocess.run(
                [
                    "rsync",
                    "--relative",
                    "-arhvz",
                    "--quiet",
                    "--filter=:- .gitignore",
                    "--exclude",
                    ".git",
                    ".",
                    _temp_code_dir,
                ],
                check=False,
            )
            _tar_result = subprocess.run(
                ["tar", "-cf", _code_snapshot_path, _temp_code_dir],
                check=False,
            )
        except OSError as e:
            # e.g. rsync or tar is not installed on this machine
            logging.warning(f"Code snapshot could not be created: {e}")
        else:
            if 0 != _rsync_result.returncode or 0 != _tar_result.returncode:
                logging.warning(
                    "Code snapshot may be incomplete "
                    f"(rsync exit code: {_rsync_result.returncode}, "
                    f"tar exit code: {_tar_result.returncode})"
                )
            else:
                logging.info(f"Code snapshot saved to: {_code_snapshot_path}")
        finally:
            # Always remove the temporary copy, whether or not the archive
            # was produced.
            shutil.rmtree(_temp_code_dir, ignore_errors=True)

    # -------------------- Copy data to local scratch (Slurm) --------------------
    if is_on_slurm() and (not args.do_not_copy_data):
        # local scratch dir
        original_data_dir = base_data_dir
        base_data_dir = os.path.join(get_local_scratch_dir(), "Marigold_data")
        # copy data
        required_data_list = find_value_in_omegaconf("dir", cfg_data)
        # if cfg_train.visualize.init_latent_path is not None:
        #     required_data_list.append(cfg_train.visualize.init_latent_path)
        # De-duplicate so that every path is copied only once
        required_data_list = list(set(required_data_list))
        logging.info(f"Required_data_list: {required_data_list}")
        for d in tqdm(required_data_list, desc="Copy data to local scratch"):
            ori_dir = os.path.join(original_data_dir, d)
            dst_dir = os.path.join(base_data_dir, d)
            os.makedirs(os.path.dirname(dst_dir), exist_ok=True)
            if os.path.isfile(ori_dir):
                shutil.copyfile(ori_dir, dst_dir)
            elif os.path.isdir(ori_dir):
                shutil.copytree(ori_dir, dst_dir)
            else:
                # Surface the problem here rather than only when a dataset
                # later fails to find its files on the scratch disk.
                logging.warning(f"Required data not found, nothing copied: {ori_dir}")
        logging.info(f"Data copied to: {base_data_dir}")

    # -------------------- Gradient accumulation steps --------------------
    eff_bs = cfg.dataloader.effective_batch_size
    accumulation_steps = eff_bs / cfg.dataloader.max_train_batch_size
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if int(accumulation_steps) != accumulation_steps:
        raise ValueError(
            f"Effective batch size ({eff_bs}) must be a multiple of the max "
            f"train batch size ({cfg.dataloader.max_train_batch_size})"
        )
    accumulation_steps = int(accumulation_steps)
    logging.info(
        f"Effective batch size: {eff_bs}, accumulation steps: {accumulation_steps}"
    )

    # -------------------- Data --------------------
    # A seeded generator is only created when a seed is configured
    loader_seed = cfg.dataloader.seed
    if loader_seed is None:
        loader_generator = None
    else:
        loader_generator = torch.Generator().manual_seed(loader_seed)

    # Training dataset
    train_dataset: Union[BaseNormalsDataset, List[BaseNormalsDataset]] = get_dataset(
        cfg_data.train,
        base_data_dir=base_data_dir,
        mode=DatasetMode.TRAIN,
        augmentation_args=cfg.augmentation,
    )
    # The value is interpolated into the message; passing it as a separate
    # positional argument without a placeholder makes the logging module
    # report a formatting error instead of emitting the record.
    logging.debug(f"Augmentation: {cfg.augmentation}")
    if "mixed" == cfg_data.train.name:
        # Mixed training: several datasets are concatenated and batches are
        # drawn through MixedBatchSampler using `prob_ls`.
        dataset_ls = train_dataset
        if len(cfg_data.train.prob_ls) != len(dataset_ls):
            raise ValueError("Lengths don't match: `prob_ls` and `dataset_list`")
        concat_dataset = ConcatDataset(dataset_ls)
        mixed_sampler = MixedBatchSampler(
            src_dataset_ls=dataset_ls,
            batch_size=cfg.dataloader.max_train_batch_size,
            drop_last=True,
            prob=cfg_data.train.prob_ls,
            shuffle=True,
            generator=loader_generator,
        )
        train_loader = DataLoader(
            concat_dataset,
            batch_sampler=mixed_sampler,
            num_workers=cfg.dataloader.num_workers,
        )
    else:
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=cfg.dataloader.max_train_batch_size,
            num_workers=cfg.dataloader.num_workers,
            shuffle=True,
            generator=loader_generator,
        )

    # Validation dataset: one loader per configured entry, batch size 1,
    # fixed order
    val_loaders: List[DataLoader] = []
    for _val_dict in cfg_data.val:
        _val_dataset = get_dataset(
            _val_dict,
            base_data_dir=base_data_dir,
            mode=DatasetMode.EVAL,
        )
        _val_loader = DataLoader(
            dataset=_val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=cfg.dataloader.num_workers,
        )
        val_loaders.append(_val_loader)

    # Visualization dataset: built the same way as the validation loaders
    vis_loaders: List[DataLoader] = []
    for _vis_dict in cfg_data.vis:
        _vis_dataset = get_dataset(
            _vis_dict,
            base_data_dir=base_data_dir,
            mode=DatasetMode.EVAL,
        )
        _vis_loader = DataLoader(
            dataset=_vis_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=cfg.dataloader.num_workers,
        )
        vis_loaders.append(_vis_loader)

    # -------------------- Model --------------------
    _pipeline_kwargs = cfg.pipeline.kwargs if cfg.pipeline.kwargs is not None else {}
    model = MarigoldNormalsPipeline.from_pretrained(
        os.path.join(base_ckpt_dir, cfg.model.pretrained_path), **_pipeline_kwargs
    )

    # -------------------- Trainer --------------------
    # Exit time: only set when --exit_after is positive
    if args.exit_after > 0:
        t_end = t_start + timedelta(minutes=args.exit_after)
        logging.info(f"Will exit at {t_end}")
    else:
        t_end = None

    trainer_cls = get_trainer_cls(cfg.trainer.name)
    logging.debug(f"Trainer: {trainer_cls}")
    trainer = trainer_cls(
        cfg=cfg,
        model=model,
        train_dataloader=train_loader,
        device=device,
        out_dir_ckpt=out_dir_ckpt,
        out_dir_eval=out_dir_eval,
        out_dir_vis=out_dir_vis,
        accumulation_steps=accumulation_steps,
        val_dataloaders=val_loaders,
        vis_dataloaders=vis_loaders,
    )

    # -------------------- Checkpoint --------------------
    if resume_run is not None:
        trainer.load_checkpoint(
            resume_run, load_trainer_state=True, resume_lr_scheduler=True
        )

    # -------------------- Training & Evaluation Loop --------------------
    try:
        trainer.train(t_end=t_end)
    except Exception as e:
        # Record the traceback through the configured log handlers, then
        # re-raise so the process exits with a non-zero status instead of
        # reporting success for a failed run.
        logging.exception(e)
        raise