# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # -------------------------------------------------------------------------- # More information about Marigold: # https://marigoldmonodepth.github.io # https://marigoldcomputervision.github.io # Efficient inference pipelines are now part of diffusers: # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage # https://huggingface.co/docs/diffusers/api/pipelines/marigold # Examples of trained models and live demos: # https://huggingface.co/prs-eth # Related projects: # https://rollingdepth.github.io/ # https://marigolddepthcompletion.github.io/ # Citation (BibTeX): # https://github.com/prs-eth/Marigold#-citation # If you find Marigold useful, we kindly ask you to cite our papers. 
# --------------------------------------------------------------------------
"""Marigold surface-normals training entry point.

Loads (or resumes) a run configuration, prepares the run's output
directories and logging, optionally stages data to local scratch on Slurm,
builds the train / validation / visualization data loaders and the
MarigoldNormalsPipeline, and hands everything to the trainer class selected
by ``cfg.trainer.name``.
"""
import sys
import os

# Put the repository root (two directories above this script's folder) first
# on sys.path so the `marigold` and `src` packages below can be imported when
# the script is launched directly.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

import argparse
import logging
import shlex
import shutil
import torch
from datetime import datetime, timedelta
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
from typing import List, Union

from marigold import MarigoldNormalsPipeline
from src.dataset import BaseNormalsDataset, DatasetMode, get_dataset
from src.dataset.mixed_sampler import MixedBatchSampler
from src.trainer import get_trainer_cls
from src.util.config_util import (
    find_value_in_omegaconf,
    recursive_load_config,
)
from src.util.logging_util import (
    config_logging,
    init_wandb,
    load_wandb_job_id,
    log_slurm_job_id,
    save_wandb_job_id,
    tb_logger,
)
from src.util.slurm_util import get_local_scratch_dir, is_on_slurm


if "__main__" == __name__:
    t_start = datetime.now()
    # NOTE(review): config_logging() is only called further below, so this
    # message goes through Python's default root-logger setup and may not be
    # shown.
    logging.info(f"Started at {t_start}")

    # -------------------- Arguments --------------------
    parser = argparse.ArgumentParser(
        description="Marigold : Surface Normals Estimation : Training"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config/train_marigold_normals.yaml",
        help="Path to config file.",
    )
    parser.add_argument(
        "--resume_run",
        action="store",
        default=None,
        help="Path of checkpoint to be resumed. If given, will ignore --config, and checkpoint in the config.",
    )
    parser.add_argument(
        "--output_dir", type=str, default=None, help="Directory to save checkpoints."
    )
    parser.add_argument("--no_cuda", action="store_true", help="Do not use cuda.")
    parser.add_argument(
        "--exit_after",
        type=int,
        default=-1,
        help="Save checkpoint and exit after X minutes.",
    )
    parser.add_argument(
        "--no_wandb",
        action="store_true",
        help="Run without Weights and Biases logging.",
    )
    parser.add_argument(
        "--do_not_copy_data",
        action="store_true",
        help="On Slurm cluster, do not copy data to the local scratch.",
    )
    parser.add_argument(
        "--base_data_dir", type=str, default=None, help="Base path to the datasets."
    )
    parser.add_argument(
        "--base_ckpt_dir",
        type=str,
        default=None,
        help="Base path to the pretrained checkpoints.",
    )
    parser.add_argument(
        "--add_datetime_prefix",
        action="store_true",
        help="Add datetime to the output folder name.",
    )

    args = parser.parse_args()
    resume_run = args.resume_run
    output_dir = args.output_dir

    # CLI flags take precedence over the BASE_DATA_DIR / BASE_CKPT_DIR
    # environment variables; fail with a clear usage error if neither is set
    # (previously this surfaced as a bare KeyError).
    base_data_dir = (
        args.base_data_dir
        if args.base_data_dir is not None
        else os.environ.get("BASE_DATA_DIR")
    )
    if base_data_dir is None:
        parser.error("--base_data_dir not given and BASE_DATA_DIR is not set.")
    base_ckpt_dir = (
        args.base_ckpt_dir
        if args.base_ckpt_dir is not None
        else os.environ.get("BASE_CKPT_DIR")
    )
    if base_ckpt_dir is None:
        parser.error("--base_ckpt_dir not given and BASE_CKPT_DIR is not set.")

    # -------------------- Initialization --------------------
    # Resume previous run
    if resume_run is not None:
        logging.info(f"Resuming run: {resume_run}")
        # The run directory is two levels above the checkpoint path
        # (<run>/checkpoint/<ckpt>); its config.yaml replaces --config.
        out_dir_run = os.path.dirname(os.path.dirname(resume_run))
        job_name = os.path.basename(out_dir_run)
        # Resume config file
        cfg = OmegaConf.load(os.path.join(out_dir_run, "config.yaml"))
    else:
        # Run from start
        cfg = recursive_load_config(args.config)
        # Full job name
        pure_job_name = os.path.basename(args.config).split(".")[0]
        # Add time prefix
        if args.add_datetime_prefix:
            job_name = f"{t_start.strftime('%y_%m_%d-%H_%M_%S')}-{pure_job_name}"
        else:
            job_name = pure_job_name

        # Output dir
        if output_dir is not None:
            out_dir_run = os.path.join(output_dir, job_name)
        else:
            out_dir_run = os.path.join("./output", job_name)
        # exist_ok=False: refuse to silently reuse/overwrite an existing run.
        os.makedirs(out_dir_run, exist_ok=False)

    cfg_data = cfg.dataset

    # Other directories (may already exist when resuming)
    out_dir_ckpt = os.path.join(out_dir_run, "checkpoint")
    out_dir_tb = os.path.join(out_dir_run, "tensorboard")
    out_dir_eval = os.path.join(out_dir_run, "evaluation")
    out_dir_vis = os.path.join(out_dir_run, "visualization")
    for _sub_dir in (out_dir_ckpt, out_dir_tb, out_dir_eval, out_dir_vis):
        os.makedirs(_sub_dir, exist_ok=True)

    # -------------------- Logging settings --------------------
    config_logging(cfg.logging, out_dir=out_dir_run)
    logging.debug(f"config: {cfg}")

    # Initialize wandb
    if not args.no_wandb:
        if resume_run is not None:
            # Reattach to the wandb run recorded in the run directory.
            wandb_id = load_wandb_job_id(out_dir_run)
            wandb_cfg_dict = {
                "id": wandb_id,
                "resume": "must",
                **cfg.wandb,
            }
        else:
            wandb_cfg_dict = {
                "config": dict(cfg),
                "name": job_name,
                "mode": "online",
                **cfg.wandb,
            }
        wandb_cfg_dict.update({"dir": out_dir_run})
        wandb_run = init_wandb(enable=True, **wandb_cfg_dict)
        save_wandb_job_id(wandb_run, out_dir_run)
    else:
        init_wandb(enable=False)

    # Tensorboard (should be initialized after wandb)
    tb_logger.set_dir(out_dir_tb)

    log_slurm_job_id(step=0)

    # -------------------- Device --------------------
    cuda_avail = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if cuda_avail else "cpu")
    logging.info(f"device = {device}")

    # -------------------- Snapshot of code and config --------------------
    if resume_run is None:
        _output_path = os.path.join(out_dir_run, "config.yaml")
        with open(_output_path, "w") as f:
            OmegaConf.save(config=cfg, f=f)
        logging.info(f"Config saved to {_output_path}")
        # Copy and tar code on the first run (best-effort: failures are
        # reported but do not abort training). Paths are shell-quoted so
        # spaces or metacharacters in the output dir cannot break the commands.
        _temp_code_dir = os.path.join(out_dir_run, "code_tar")
        _code_snapshot_path = os.path.join(out_dir_run, "code_snapshot.tar")
        _rsync_status = os.system(
            "rsync --relative -arhvz --quiet --filter=':- .gitignore' --exclude '.git' . "
            f"{shlex.quote(_temp_code_dir)}"
        )
        _tar_status = os.system(
            f"tar -cf {shlex.quote(_code_snapshot_path)} {shlex.quote(_temp_code_dir)}"
        )
        shutil.rmtree(_temp_code_dir, ignore_errors=True)
        if _rsync_status != 0 or _tar_status != 0:
            logging.warning(
                f"Code snapshot may be incomplete (rsync status {_rsync_status}, "
                f"tar status {_tar_status})"
            )
        logging.info(f"Code snapshot saved to: {_code_snapshot_path}")

    # -------------------- Copy data to local scratch (Slurm) --------------------
    if is_on_slurm() and (not args.do_not_copy_data):
        # local scratch dir
        original_data_dir = base_data_dir
        base_data_dir = os.path.join(get_local_scratch_dir(), "Marigold_data")
        # Every "dir" value in the dataset config is a path relative to the
        # data root; deduplicate before copying.
        required_data_list = list(set(find_value_in_omegaconf("dir", cfg_data)))
        logging.info(f"Required_data_list: {required_data_list}")
        for d in tqdm(required_data_list, desc="Copy data to local scratch"):
            ori_dir = os.path.join(original_data_dir, d)
            dst_dir = os.path.join(base_data_dir, d)
            os.makedirs(os.path.dirname(dst_dir), exist_ok=True)
            if os.path.isfile(ori_dir):
                shutil.copyfile(ori_dir, dst_dir)
            elif os.path.isdir(ori_dir):
                # dirs_exist_ok: tolerate data left by an earlier job on the
                # same node instead of crashing with FileExistsError.
                shutil.copytree(ori_dir, dst_dir, dirs_exist_ok=True)
            else:
                logging.warning(f"Required data not found, not copied: {ori_dir}")
        logging.info(f"Data copied to: {base_data_dir}")

    # -------------------- Gradient accumulation steps --------------------
    eff_bs = cfg.dataloader.effective_batch_size
    max_train_bs = cfg.dataloader.max_train_batch_size
    # Explicit check instead of `assert` (which is stripped under `python -O`).
    # Short-circuit order avoids a modulo by zero.
    if max_train_bs <= 0 or eff_bs <= 0 or eff_bs % max_train_bs != 0:
        raise ValueError(
            f"effective_batch_size ({eff_bs}) must be a positive multiple of "
            f"max_train_batch_size ({max_train_bs})"
        )
    accumulation_steps = int(eff_bs // max_train_bs)
    logging.info(
        f"Effective batch size: {eff_bs}, accumulation steps: {accumulation_steps}"
    )

    # -------------------- Data --------------------
    loader_seed = cfg.dataloader.seed
    if loader_seed is None:
        loader_generator = None
    else:
        loader_generator = torch.Generator().manual_seed(loader_seed)

    # Training dataset
    train_dataset: Union[BaseNormalsDataset, List[BaseNormalsDataset]] = get_dataset(
        cfg_data.train,
        base_data_dir=base_data_dir,
        mode=DatasetMode.TRAIN,
        augmentation_args=cfg.augmentation,
    )
    # Fixed: the original passed cfg.augmentation as a %-format argument with
    # no placeholder, which makes logging report a formatting error.
    logging.debug(f"Augmentation: {cfg.augmentation}")
    if "mixed" == cfg_data.train.name:
        # For a "mixed" config, get_dataset is used here as returning a list
        # of datasets, sampled per batch with probabilities `prob_ls`.
        dataset_ls = train_dataset
        if len(cfg_data.train.prob_ls) != len(dataset_ls):
            raise ValueError("Lengths don't match: `prob_ls` and `dataset_list`")
        concat_dataset = ConcatDataset(dataset_ls)
        mixed_sampler = MixedBatchSampler(
            src_dataset_ls=dataset_ls,
            batch_size=cfg.dataloader.max_train_batch_size,
            drop_last=True,
            prob=cfg_data.train.prob_ls,
            shuffle=True,
            generator=loader_generator,
        )
        train_loader = DataLoader(
            concat_dataset,
            batch_sampler=mixed_sampler,
            num_workers=cfg.dataloader.num_workers,
        )
    else:
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=cfg.dataloader.max_train_batch_size,
            num_workers=cfg.dataloader.num_workers,
            shuffle=True,
            generator=loader_generator,
        )

    # Validation dataset
    val_loaders: List[DataLoader] = []
    for _val_dict in cfg_data.val:
        _val_dataset = get_dataset(
            _val_dict,
            base_data_dir=base_data_dir,
            mode=DatasetMode.EVAL,
        )
        _val_loader = DataLoader(
            dataset=_val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=cfg.dataloader.num_workers,
        )
        val_loaders.append(_val_loader)

    # Visualization dataset
    vis_loaders: List[DataLoader] = []
    for _vis_dict in cfg_data.vis:
        _vis_dataset = get_dataset(
            _vis_dict,
            base_data_dir=base_data_dir,
            mode=DatasetMode.EVAL,
        )
        _vis_loader = DataLoader(
            dataset=_vis_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=cfg.dataloader.num_workers,
        )
        vis_loaders.append(_vis_loader)

    # -------------------- Model --------------------
    _pipeline_kwargs = cfg.pipeline.kwargs if cfg.pipeline.kwargs is not None else {}
    model = MarigoldNormalsPipeline.from_pretrained(
        os.path.join(base_ckpt_dir, cfg.model.pretrained_path), **_pipeline_kwargs
    )

    # -------------------- Trainer --------------------
    # Exit time (None means no time limit)
    if args.exit_after > 0:
        t_end = t_start + timedelta(minutes=args.exit_after)
        logging.info(f"Will exit at {t_end}")
    else:
        t_end = None

    trainer_cls = get_trainer_cls(cfg.trainer.name)
    logging.debug(f"Trainer: {trainer_cls}")
    trainer = trainer_cls(
        cfg=cfg,
        model=model,
        train_dataloader=train_loader,
        device=device,
        out_dir_ckpt=out_dir_ckpt,
        out_dir_eval=out_dir_eval,
        out_dir_vis=out_dir_vis,
        accumulation_steps=accumulation_steps,
        val_dataloaders=val_loaders,
        vis_dataloaders=vis_loaders,
    )

    # -------------------- Checkpoint --------------------
    if resume_run is not None:
        trainer.load_checkpoint(
            resume_run, load_trainer_state=True, resume_lr_scheduler=True
        )

    # -------------------- Training & Evaluation Loop --------------------
    try:
        trainer.train(t_end=t_end)
    except Exception as e:
        # Record the traceback in the configured logs, then exit non-zero so
        # schedulers (e.g. Slurm) see the failure; previously the exception
        # was swallowed and the process exited with status 0.
        logging.exception(e)
        sys.exit(1)