Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
Add files using upload-large-folder tool
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- DepthMaster/ckpt/eval/.gitattributes +35 -0
- DepthMaster/ckpt/eval/README.md +97 -0
- DepthMaster/ckpt/eval/model_index.json +28 -0
- DepthMaster/ckpt/eval/text_encoder/config.json +25 -0
- DepthMaster/ckpt/eval/tokenizer/merges.txt +0 -0
- DepthMaster/ckpt/eval/tokenizer/special_tokens_map.json +24 -0
- DepthMaster/ckpt/eval/tokenizer/tokenizer_config.json +34 -0
- DepthMaster/ckpt/eval/tokenizer/vocab.json +0 -0
- DepthMaster/ckpt/eval/unet/config.json +73 -0
- DepthMaster/ckpt/eval/vae/config.json +30 -0
- DepthMaster/data_split/kitti/eigen_train_files_with_gt.txt +0 -0
- DepthMaster/depthmaster/modules/__pycache__/unet_2d_blocks.cpython-310.pyc +0 -0
- DepthMaster/depthmaster/modules/__pycache__/unet_2d_condition_s2.cpython-310.pyc +0 -0
- DepthMaster/external_encoder/dinov2/dinov2_layers/attention.py +83 -0
- DepthMaster/external_encoder/dinov2/dinov2_layers/block.py +252 -0
- DepthMaster/external_encoder/dinov2/dinov2_layers/drop_path.py +35 -0
- DepthMaster/external_encoder/dinov2/dinov2_layers/layer_scale.py +28 -0
- DepthMaster/external_encoder/dinov2/dinov2_layers/mlp.py +41 -0
- DepthMaster/external_encoder/dinov2/dinov2_layers/patch_embed.py +89 -0
- DepthMaster/external_encoder/dinov2/dinov2_layers/swiglu_ffn.py +63 -0
- DepthMaster/external_encoder/dinov2/util/transform.py +160 -0
- DepthMaster/in_the_wild_example/input/06.jpg +0 -0
- DepthMaster/scripts/eval_diode.sh +13 -0
- DepthMaster/scripts/eval_eth3d.sh +13 -0
- DepthMaster/scripts/eval_hypersim.sh +13 -0
- DepthMaster/scripts/eval_kitti.sh +13 -0
- DepthMaster/scripts/eval_nyu.sh +13 -0
- DepthMaster/scripts/eval_scannet.sh +13 -0
- DepthMaster/scripts/infer.sh +10 -0
- DepthMaster/scripts/train_s1.sh +9 -0
- DepthMaster/scripts/train_s2.sh +9 -0
- DepthMaster/src/dataset/__init__.py +71 -0
- DepthMaster/src/dataset/base_depth_dataset.py +303 -0
- DepthMaster/src/dataset/diode_dataset.py +94 -0
- DepthMaster/src/dataset/eth3d_dataset.py +68 -0
- DepthMaster/src/dataset/hypersim_dataset.py +48 -0
- DepthMaster/src/dataset/kitti_dataset.py +127 -0
- DepthMaster/src/dataset/mixed_sampler.py +151 -0
- DepthMaster/src/dataset/nyu_dataset.py +64 -0
- DepthMaster/src/dataset/scannet_dataset.py +47 -0
- DepthMaster/src/dataset/vkitti_dataset.py +100 -0
- DepthMaster/src/trainer/__init__.py +15 -0
- DepthMaster/src/trainer/trainer_s1.py +671 -0
- DepthMaster/src/trainer/trainer_s2.py +630 -0
- DepthMaster/src/util/alignment.py +180 -0
- DepthMaster/src/util/boundary_metrics.py +332 -0
- DepthMaster/src/util/build_mlp.py +10 -0
- DepthMaster/src/util/config_util.py +70 -0
- DepthMaster/src/util/data_loader.py +111 -0
- DepthMaster/src/util/depth_transform.py +124 -0
DepthMaster/ckpt/eval/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
DepthMaster/ckpt/eval/README.md
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
base_model:
|
| 6 |
+
- stabilityai/stable-diffusion-2
|
| 7 |
+
pipeline_tag: depth-estimation
|
| 8 |
+
---
|
| 9 |
+
<!-- # DepthMaster: Taming Diffusion Models for Monocular Depth Estimation
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
This repository represents the official implementation of the paper titled "DepthMaster: Taming Diffusion Models for Monocular Depth Estimation". -->
|
| 13 |
+
|
| 14 |
+
<!-- [](https://marigoldmonodepth.github.io)
|
| 15 |
+
[](https://arxiv.org/abs/2312.02145) -->
|
| 16 |
+
|
| 17 |
+
<!-- [](https://www.apache.org/licenses/LICENSE-2.0) -->
|
| 18 |
+
|
| 19 |
+
<h1 align="center"><strong>DepthMaster: Taming Diffusion Models for Monocular Depth Estimation</strong></h1>
|
| 20 |
+
<p align="center">
|
| 21 |
+
<a href="https://indu1ge.github.io/ziyangsong">Ziyang Song*</a>,
|
| 22 |
+
<a href="https://orcid.org/0009-0001-6677-0572">Zerong Wang*</a>,
|
| 23 |
+
<a href="https://orcid.org/0000-0001-7817-0665">Bo Li</a>,
|
| 24 |
+
<a href="https://orcid.org/0009-0007-1175-5918">Hao Zhang</a>,
|
| 25 |
+
<a href="https://ruijiezhu94.github.io/ruijiezhu/">Ruijie Zhu</a>,
|
| 26 |
+
<a href="https://orcid.org/0009-0004-3280-8490">Li Liu</a>,
|
| 27 |
+
<a href="https://pengtaojiang.github.io/">Peng-Tao Jiang†</a>,
|
| 28 |
+
<a href="http://staff.ustc.edu.cn/~tzzhang/">Tianzhu Zhang†</a>,
|
| 29 |
+
<br>
|
| 30 |
+
*Equal Contribution, †Corresponding Author
|
| 31 |
+
<br>
|
| 32 |
+
University of Science and Technology of China, vivo Mobile Communication Co., Ltd.
|
| 33 |
+
<br>
|
| 34 |
+
<b>Arxiv 2025</b>
|
| 35 |
+
</p>
|
| 36 |
+
<!-- [Ziyang Song*](https://indu1ge.github.io/ziyangsong),
|
| 37 |
+
[Zerong Wang*](),
|
| 38 |
+
[Bo Li](https://orcid.org/0000-0001-7817-0665),
|
| 39 |
+
[Hao Zhang](https://orcid.org/0009-0007-1175-5918),
|
| 40 |
+
[Ruijie Zhu](https://ruijiezhu94.github.io/ruijiezhu/),
|
| 41 |
+
[Li Liu](https://orcid.org/0009-0004-3280-8490)
|
| 42 |
+
[Tianzhu Zhang](http://staff.ustc.edu.cn/~tzzhang/)
|
| 43 |
+
[Peng-Tao Jiang](https://pengtaojiang.github.io/) -->
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
<div align="center">
|
| 48 |
+
<a href='https://arxiv.org/abs/2501.02576'>
|
| 49 |
+
<img src='https://img.shields.io/badge/Paper-arXiv-red'>
|
| 50 |
+
</a>
|
| 51 |
+
<a href='https://indu1ge.github.io/DepthMaster_page/'>
|
| 52 |
+
<img src='https://img.shields.io/badge/Project-Page-Green'>
|
| 53 |
+
</a>
|
| 54 |
+
<a href='https://github.com/indu1ge/DepthMaster'>
|
| 55 |
+
<img src='https://img.shields.io/badge/GitHub-Repository-blue?logo=github'>
|
| 56 |
+
</a>
|
| 57 |
+
<a href='https://www.apache.org/licenses/LICENSE-2.0'>
|
| 58 |
+
<img src='https://img.shields.io/badge/License-Apache--2.0-929292'>
|
| 59 |
+
</a>
|
| 60 |
+
</div>
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
<!-- We present Marigold, a diffusion model, and associated fine-tuning protocol for monocular depth estimation. Its core principle is to leverage the rich visual knowledge stored in modern generative image models. Our model, derived from Stable Diffusion and fine-tuned with synthetic data, can zero-shot transfer to unseen data, offering state-of-the-art monocular depth estimation results. -->
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+

|
| 68 |
+
|
| 69 |
+
<!-- >We present DepthMaster, a tamed single-step diffusion model designed to enhance the generalization and detail preservation abilities of depth estimation models. Through feature alignment, we effectively prevent the overfitting to texture details. By adaptively enhance -->
|
| 70 |
+
>We present DepthMaster, a tamed single-step diffusion model that customizes generative features in diffusion models to suit the discriminative depth estimation task. We introduce a Feature Alignment module to mitigate overfitting to texture and a Fourier Enhancement module to refine fine-grained details. DepthMaster exhibits state-of-the-art zero-shot performance and superior detail preservation ability, surpassing
|
| 71 |
+
other diffusion-based methods across various datasets.
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
## 🎓 Citation
|
| 75 |
+
|
| 76 |
+
Please cite our paper:
|
| 77 |
+
|
| 78 |
+
```bibtex
|
| 79 |
+
@article{song2025depthmaster,
|
| 80 |
+
title={DepthMaster: Taming Diffusion Models for Monocular Depth Estimation},
|
| 81 |
+
author={Song, Ziyang and Wang, Zerong and Li, Bo and Zhang, Hao and Zhu, Ruijie and Liu, Li and Jiang, Peng-Tao and Zhang, Tianzhu},
|
| 82 |
+
journal={arXiv preprint arXiv:2501.02576},
|
| 83 |
+
year={2025}
|
| 84 |
+
}
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
## Acknowledgements
|
| 88 |
+
|
| 89 |
+
The code is based on [Marigold](https://github.com/prs-eth/Marigold).
|
| 90 |
+
|
| 91 |
+
## 🎫 License
|
| 92 |
+
|
| 93 |
+
This work is licensed under the Apache License, Version 2.0 (as defined in the [LICENSE](LICENSE.txt)).
|
| 94 |
+
|
| 95 |
+
By downloading and using the code and model you agree to the terms in the [LICENSE](LICENSE.txt).
|
| 96 |
+
|
| 97 |
+
[![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0)
|
DepthMaster/ckpt/eval/model_index.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name":"MarigoldPipeline",
|
| 3 |
+
"_diffusers_version":"0.24.0",
|
| 4 |
+
"scale_invariant": true,
|
| 5 |
+
"shift_invariant": true,
|
| 6 |
+
"default_denoising_steps": 10,
|
| 7 |
+
"default_processing_resolution": 768,
|
| 8 |
+
"unet":[
|
| 9 |
+
"diffusers",
|
| 10 |
+
"UNet2DConditionModel"
|
| 11 |
+
],
|
| 12 |
+
"vae":[
|
| 13 |
+
"diffusers",
|
| 14 |
+
"AutoencoderKL"
|
| 15 |
+
],
|
| 16 |
+
"scheduler":[
|
| 17 |
+
"diffusers",
|
| 18 |
+
"DDIMScheduler"
|
| 19 |
+
],
|
| 20 |
+
"text_encoder":[
|
| 21 |
+
"transformers",
|
| 22 |
+
"CLIPTextModel"
|
| 23 |
+
],
|
| 24 |
+
"tokenizer":[
|
| 25 |
+
"transformers",
|
| 26 |
+
"CLIPTokenizer"
|
| 27 |
+
]
|
| 28 |
+
}
|
DepthMaster/ckpt/eval/text_encoder/config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "hf-models/stable-diffusion-v2-768x768/text_encoder",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CLIPTextModel"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"eos_token_id": 2,
|
| 10 |
+
"hidden_act": "gelu",
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_factor": 1.0,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 4096,
|
| 15 |
+
"layer_norm_eps": 1e-05,
|
| 16 |
+
"max_position_embeddings": 77,
|
| 17 |
+
"model_type": "clip_text_model",
|
| 18 |
+
"num_attention_heads": 16,
|
| 19 |
+
"num_hidden_layers": 23,
|
| 20 |
+
"pad_token_id": 1,
|
| 21 |
+
"projection_dim": 512,
|
| 22 |
+
"torch_dtype": "float32",
|
| 23 |
+
"transformers_version": "4.25.0.dev0",
|
| 24 |
+
"vocab_size": 49408
|
| 25 |
+
}
|
DepthMaster/ckpt/eval/tokenizer/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
DepthMaster/ckpt/eval/tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|startoftext|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": true,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|endoftext|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": true,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": "!",
|
| 17 |
+
"unk_token": {
|
| 18 |
+
"content": "<|endoftext|>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": true,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
}
|
| 24 |
+
}
|
DepthMaster/ckpt/eval/tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"bos_token": {
|
| 4 |
+
"__type": "AddedToken",
|
| 5 |
+
"content": "<|startoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false
|
| 10 |
+
},
|
| 11 |
+
"do_lower_case": true,
|
| 12 |
+
"eos_token": {
|
| 13 |
+
"__type": "AddedToken",
|
| 14 |
+
"content": "<|endoftext|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": true,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false
|
| 19 |
+
},
|
| 20 |
+
"errors": "replace",
|
| 21 |
+
"model_max_length": 77,
|
| 22 |
+
"name_or_path": "hf-models/stable-diffusion-v2-768x768/tokenizer",
|
| 23 |
+
"pad_token": "<|endoftext|>",
|
| 24 |
+
"special_tokens_map_file": "./special_tokens_map.json",
|
| 25 |
+
"tokenizer_class": "CLIPTokenizer",
|
| 26 |
+
"unk_token": {
|
| 27 |
+
"__type": "AddedToken",
|
| 28 |
+
"content": "<|endoftext|>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": true,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false
|
| 33 |
+
}
|
| 34 |
+
}
|
DepthMaster/ckpt/eval/tokenizer/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
DepthMaster/ckpt/eval/unet/config.json
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "UNet2DConditionModel",
|
| 3 |
+
"_diffusers_version": "0.31.0",
|
| 4 |
+
"_name_or_path": "/data/vjuicefs_ai_camera_jgroup/11169299/Marigold_rgb2d/log/depth_preprocess/rgb2disp_bs4_sqrt_disp_cos1e-3_0.85/checkpoint/iter_014000/unet",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"addition_embed_type": null,
|
| 7 |
+
"addition_embed_type_num_heads": 64,
|
| 8 |
+
"addition_time_embed_dim": null,
|
| 9 |
+
"attention_head_dim": [
|
| 10 |
+
5,
|
| 11 |
+
10,
|
| 12 |
+
20,
|
| 13 |
+
20
|
| 14 |
+
],
|
| 15 |
+
"attention_type": "default",
|
| 16 |
+
"block_out_channels": [
|
| 17 |
+
320,
|
| 18 |
+
640,
|
| 19 |
+
1280,
|
| 20 |
+
1280
|
| 21 |
+
],
|
| 22 |
+
"center_input_sample": false,
|
| 23 |
+
"class_embed_type": null,
|
| 24 |
+
"class_embeddings_concat": false,
|
| 25 |
+
"conv_in_kernel": 3,
|
| 26 |
+
"conv_out_kernel": 3,
|
| 27 |
+
"cross_attention_dim": 1024,
|
| 28 |
+
"cross_attention_norm": null,
|
| 29 |
+
"down_block_types": [
|
| 30 |
+
"CrossAttnDownBlock2D",
|
| 31 |
+
"CrossAttnDownBlock2D",
|
| 32 |
+
"CrossAttnDownBlock2D",
|
| 33 |
+
"DownBlock2D"
|
| 34 |
+
],
|
| 35 |
+
"downsample_padding": 1,
|
| 36 |
+
"dropout": 0.0,
|
| 37 |
+
"dual_cross_attention": false,
|
| 38 |
+
"encoder_hid_dim": null,
|
| 39 |
+
"encoder_hid_dim_type": null,
|
| 40 |
+
"flip_sin_to_cos": true,
|
| 41 |
+
"freq_shift": 0,
|
| 42 |
+
"in_channels": 4,
|
| 43 |
+
"layers_per_block": 2,
|
| 44 |
+
"mid_block_only_cross_attention": null,
|
| 45 |
+
"mid_block_scale_factor": 1,
|
| 46 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
| 47 |
+
"norm_eps": 1e-05,
|
| 48 |
+
"norm_num_groups": 32,
|
| 49 |
+
"num_attention_heads": null,
|
| 50 |
+
"num_class_embeds": null,
|
| 51 |
+
"only_cross_attention": false,
|
| 52 |
+
"out_channels": 4,
|
| 53 |
+
"projection_class_embeddings_input_dim": null,
|
| 54 |
+
"resnet_out_scale_factor": 1.0,
|
| 55 |
+
"resnet_skip_time_act": false,
|
| 56 |
+
"resnet_time_scale_shift": "default",
|
| 57 |
+
"reverse_transformer_layers_per_block": null,
|
| 58 |
+
"sample_size": 96,
|
| 59 |
+
"time_cond_proj_dim": null,
|
| 60 |
+
"time_embedding_act_fn": null,
|
| 61 |
+
"time_embedding_dim": null,
|
| 62 |
+
"time_embedding_type": "positional",
|
| 63 |
+
"timestep_post_act": null,
|
| 64 |
+
"transformer_layers_per_block": 1,
|
| 65 |
+
"up_block_types": [
|
| 66 |
+
"UpBlock2D",
|
| 67 |
+
"CrossAttnUpBlock2D",
|
| 68 |
+
"CrossAttnUpBlock2D",
|
| 69 |
+
"CrossAttnUpBlock2D"
|
| 70 |
+
],
|
| 71 |
+
"upcast_attention": false,
|
| 72 |
+
"use_linear_projection": true
|
| 73 |
+
}
|
DepthMaster/ckpt/eval/vae/config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.8.0",
|
| 4 |
+
"_name_or_path": "hf-models/stable-diffusion-v2-768x768/vae",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"in_channels": 3,
|
| 19 |
+
"latent_channels": 4,
|
| 20 |
+
"layers_per_block": 2,
|
| 21 |
+
"norm_num_groups": 32,
|
| 22 |
+
"out_channels": 3,
|
| 23 |
+
"sample_size": 768,
|
| 24 |
+
"up_block_types": [
|
| 25 |
+
"UpDecoderBlock2D",
|
| 26 |
+
"UpDecoderBlock2D",
|
| 27 |
+
"UpDecoderBlock2D",
|
| 28 |
+
"UpDecoderBlock2D"
|
| 29 |
+
]
|
| 30 |
+
}
|
DepthMaster/data_split/kitti/eigen_train_files_with_gt.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
DepthMaster/depthmaster/modules/__pycache__/unet_2d_blocks.cpython-310.pyc
ADDED
|
Binary file (67.3 kB). View file
|
|
|
DepthMaster/depthmaster/modules/__pycache__/unet_2d_condition_s2.cpython-310.pyc
ADDED
|
Binary file (40.9 kB). View file
|
|
|
DepthMaster/external_encoder/dinov2/dinov2_layers/attention.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("dinov2")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from xformers.ops import memory_efficient_attention, unbind, fmha
|
| 22 |
+
|
| 23 |
+
XFORMERS_AVAILABLE = True
|
| 24 |
+
except ImportError:
|
| 25 |
+
logger.warning("xFormers not available")
|
| 26 |
+
XFORMERS_AVAILABLE = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token tensor.

    One linear layer produces queries, keys and values for all heads; the
    softmax-weighted values are merged back to ``C`` channels and passed
    through an output projection. Dropout is applied to the attention
    weights and to the projected output.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Queries are multiplied by head_dim ** -0.5 in forward().
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Return the attended tokens, same shape as ``x``."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        packed = packed.permute(2, 0, 3, 1, 4)
        queries = packed[0] * self.scale
        keys = packed[1]
        values = packed[2]

        weights = (queries @ keys.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        # (B, heads, N, head_dim) -> (B, N, C)
        merged = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class MemEffAttention(Attention):
    """Attention that calls xFormers' ``memory_efficient_attention`` when it is importable.

    Without xFormers the call falls back to ``Attention.forward``; in that
    case ``attn_bias`` must be ``None``.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)

            # Split along the axis of size 3 into queries, keys and values.
            queries, keys, values = unbind(packed, 2)

            attended = memory_efficient_attention(queries, keys, values, attn_bias=attn_bias)
            attended = attended.reshape([batch, tokens, channels])

            return self.proj_drop(self.proj(attended))

        # Fallback path: the parent implementation takes no attention bias.
        assert attn_bias is None, "xFormers is required for nested tensors usage"
        return super().forward(x)
|
| 82 |
+
|
| 83 |
+
|
DepthMaster/external_encoder/dinov2/dinov2_layers/block.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch import nn, Tensor
|
| 16 |
+
|
| 17 |
+
from .attention import Attention, MemEffAttention
|
| 18 |
+
from .drop_path import DropPath
|
| 19 |
+
from .layer_scale import LayerScale
|
| 20 |
+
from .mlp import Mlp
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("dinov2")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from xformers.ops import fmha
|
| 28 |
+
from xformers.ops import scaled_index_add, index_select_cat
|
| 29 |
+
|
| 30 |
+
XFORMERS_AVAILABLE = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
logger.warning("xFormers not available")
|
| 33 |
+
XFORMERS_AVAILABLE = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Block(nn.Module):
    """Transformer block with an attention branch and a feed-forward branch.

    Each branch is computed as ``layer_scale(sublayer(norm(x)))`` and added
    to its input as a residual.

    Args:
        dim: Channel size passed to the norm, attention and FFN layers.
        num_heads: Number of attention heads passed to ``attn_class``.
        mlp_ratio: FFN hidden size is ``int(dim * mlp_ratio)``.
        qkv_bias: Forwarded to ``attn_class``.
        proj_bias: Forwarded to ``attn_class``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Dropout rate forwarded to the attention projection and the FFN.
        attn_drop: Dropout rate forwarded to ``attn_class``.
        init_values: When truthy, each branch output goes through a
            ``LayerScale`` built with this value; otherwise ``nn.Identity``.
        drop_path: Drop rate for the residual branches; also stored as
            ``sample_drop_ratio``.
        act_layer: Activation factory forwarded to ``ffn_layer``.
        norm_layer: Factory for the two normalisation layers.
        attn_class: Factory for the attention module.
        ffn_layer: Factory for the feed-forward module.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention and feed-forward residual branches to ``x``."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own drop-path module; the FFN branch
            # previously reused drop_path1, leaving drop_path2 unused.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch.

    Args:
        x: 3-D tensor whose first axis is the batch.
        residual_func: Called on the selected rows of ``x``.
        sample_drop_ratio: Fraction of batch rows left out of the subset;
            at least one row is always kept.

    Returns:
        A tensor shaped like ``x`` in which the selected rows have the
        residual added, scaled by ``batch_size / subset_size``.
    """
    batch_size, _, _ = x.shape

    # Choose which batch rows get a residual.
    keep_count = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch_size, device=x.device)[:keep_count]

    # Evaluate the branch on the chosen rows only.
    branch_out = residual_func(x[kept_rows]).flatten(1)

    # Scatter-add the rescaled residual back into the flattened input.
    rescale = batch_size / keep_count
    updated = torch.index_add(
        x.flatten(1), 0, kept_rows, branch_out.to(dtype=x.dtype), alpha=rescale
    )
    return updated.view_as(x)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Return random batch indices to keep and the matching residual scale.

    ``x`` must be 3-D. The number of kept indices is
    ``max(int(b * (1 - sample_drop_ratio)), 1)`` and the scale factor is
    ``b`` divided by that number, where ``b`` is the size of the first axis.
    """
    batch_size, _, _ = x.shape
    keep_count = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    permutation = torch.randperm(batch_size, device=x.device)
    return permutation[:keep_count], batch_size / keep_count
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` into the rows of ``x`` selected by ``brange``.

    The residual is cast to ``x``'s dtype and weighted by
    ``residual_scale_factor``. With a ``scaling_vector`` the work is
    delegated to ``scaled_index_add``; without one, ``x`` and ``residual``
    are flattened from axis 1 onward and the flattened (2-D) result is
    returned.
    """
    if scaling_vector is not None:
        return scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )

    flat_residual = residual.flatten(1).to(dtype=x.dtype)
    return torch.index_add(x.flatten(1), 0, brange, flat_residual, alpha=residual_scale_factor)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# Module-level cache of attention-bias objects, filled by get_attn_bias_and_cat.
# Keys are tuples of (batch_size, sequence_length) pairs, one pair per input tensor.
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_attn_bias_and_cat(x_list, branges=None):
    """
    Concatenate a list of tensors into one batch-of-one sequence and return it with its attention bias.

    The bias is an fmha.BlockDiagonalMask built from one sequence length per (selected)
    sample. It is stored in attn_bias_cache under the tuple of (batch_size, seq_len)
    pairs, so the same combination of shapes reuses the same object.
    When branges is given, only the indexed samples of each tensor are gathered
    (through index_select_cat); otherwise every sample of every tensor is used.

    Returns (attn_bias, cat_tensors), where cat_tensors has a leading dimension of 1.
    """
    if branges is None:
        batch_sizes = [t.shape[0] for t in x_list]
    else:
        batch_sizes = [idx.shape[0] for idx in branges]

    all_shapes = tuple((bs, t.shape[1]) for bs, t in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache:
        # One entry per sample: each tensor contributes its sequence length `bs` times.
        seqlens = [t.shape[1] for bs, t in zip(batch_sizes, x_list) for _ in range(bs)]
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is None:
        tensors_bs1 = tuple(t.reshape([1, -1, *t.shape[2:]]) for t in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)
    else:
        flattened = [t.flatten(1) for t in x_list]
        cat_tensors = index_select_cat(flattened, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[all_shapes], cat_tensors
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """
    List version of drop_add_residual_stochastic_depth: one residual_func call for all tensors.

    For every tensor in x_list a random subset of samples is drawn. The selected samples of
    all tensors are concatenated, passed through residual_func once (together with the
    matching attention bias), split back per tensor and added onto the corresponding rows.

    Args:
        x_list: input tensors; each must be 3-dimensional.
        residual_func: called as residual_func(x_cat, attn_bias=attn_bias).
        sample_drop_ratio: fraction of samples to leave out of the residual branch.
        scaling_vector: optional vector forwarded to add_residual.

    Returns:
        A list with one tensor per input, each shaped like its input.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [brange for brange, _ in branges_scales]
    residual_scale_factors = [scale for _, scale in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) add each residual onto its source tensor and restore the original shape
    return [
        add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)
        for x, brange, residual, residual_scale_factor in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class NestedTensorBlock(Block):
    """
    Block that accepts either a single tensor or a list of tensors.

    A single tensor is handled by the parent class. A list is processed in one pass by
    forward_nested, which requires xFormers and a MemEffAttention attention module.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Stochastic-depth path: layer scale is not applied inside the residual
            # functions; its gamma is passed on as scaling_vector instead.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check ls2 itself: it is ls2.gamma that is read here.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """
        Dispatch on the input type: Tensor -> parent forward, list -> forward_nested.

        Raises AssertionError for any other input type, or for a list when xFormers is
        not available.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(f"Unsupported input type: {type(x_or_x_list)}")
|
DepthMaster/external_encoder/dinov2/dinov2_layers/drop_path.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Randomly zero whole samples of x and rescale the ones that are kept.

    Returns x itself, untouched, when drop_prob is 0.0 or training is False. Otherwise each
    index along the first dimension is kept with probability 1 - drop_prob; kept samples
    are divided by that keep probability (the division is skipped when it is 0) and
    dropped samples become zero.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions
    # (work with diff dim tensors, not just 2D ConvNets).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    drop_prob may be left as None (the default); that is treated as "never drop".
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # With the default drop_prob=None, drop_path would evaluate `1 - None` in
        # training mode and raise TypeError, so return the input unchanged instead.
        if self.drop_prob is None:
            return x
        return drop_path(x, self.drop_prob, self.training)
|
DepthMaster/external_encoder/dinov2/dinov2_layers/layer_scale.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 8 |
+
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class LayerScale(nn.Module):
    """
    Multiply the input by a learnable vector `gamma` of length dim.

    gamma is initialised to init_values * ones(dim). With inplace=True the input tensor
    itself is scaled and returned; otherwise a new tensor is produced.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        initial_gamma = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
DepthMaster/external_encoder/dinov2/dinov2_layers/mlp.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from typing import Callable, Optional
|
| 13 |
+
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Mlp(nn.Module):
    """
    Two linear layers with an activation in between.

    The same dropout module is applied after the activation and again after the second
    linear layer. hidden_features and out_features fall back to in_features when they
    are not given (or are falsy).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        # fc1 -> activation -> dropout -> fc2 -> dropout
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
|
DepthMaster/external_encoder/dinov2/dinov2_layers/patch_embed.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def make_2tuple(x):
    """
    Return x as a pair: an int n becomes (n, n), a 2-tuple is returned as is.

    Raises AssertionError for a tuple whose length is not 2 and for any non-tuple
    value that is not an int.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If False, forward returns (B,H,W,D) instead of (B,N,D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """
        Embed an image batch. Height and width must be multiples of the patch size
        (checked with assertions); they do not have to equal img_size.
        """
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """
        Return an operation count based on the patch grid derived from img_size.

        The projection term is always included; the norm term only when a real
        normalization layer is configured.
        """
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # self.norm is nn.Identity (never None) when no norm_layer was given, so a plain
        # `is not None` test would count a normalization that does not exist.
        if self.norm is not None and not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
DepthMaster/external_encoder/dinov2/dinov2_layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SwiGLUFFN(nn.Module):
    """
    Gated feed-forward layer: w3(silu(x1) * x2), where x1 and x2 are the two halves of w12(x).

    hidden_features and out_features fall back to in_features when not given (or falsy).
    act_layer and drop are accepted but not used by this implementation.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # A single projection produces both halves, hence twice the hidden width.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from xformers.ops import SwiGLU
|
| 38 |
+
|
| 39 |
+
XFORMERS_AVAILABLE = True
|
| 40 |
+
except ImportError:
|
| 41 |
+
SwiGLU = SwiGLUFFN
|
| 42 |
+
XFORMERS_AVAILABLE = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SwiGLUFFNFused(SwiGLU):
    """
    SwiGLU variant whose hidden width is 2/3 of the requested one, rounded up to a multiple of 8.

    The base class is whatever the module-level name SwiGLU refers to.
    act_layer and drop are accepted but not forwarded to the base class.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # Take 2/3 of the requested width, then round up to the next multiple of 8.
        two_thirds = int(hidden_features * 2 / 3)
        hidden_features = (two_thirds + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
|
DepthMaster/external_encoder/dinov2/util/transform.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
            image_interpolation_method: stored on the instance; __call__ as written
                here does not read it.
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Round x to a multiple of ensure_multiple_of.

        Rounds to the nearest multiple first; rounds down instead if that exceeds
        max_val, and rounds up instead if the result is below min_val.
        """
        multiple = self.__multiple_of
        ratio = x / multiple
        y = (np.round(ratio) * multiple).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(ratio) * multiple).astype(int)

        if y < min_val:
            y = (np.ceil(ratio) * multiple).astype(int)

        return y

    def get_size(self, width, height):
        """Return the output size as (new_width, new_height) for an input of the given size.

        Raises ValueError for an unknown resize_method.
        """
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width
        method = self.__resize_method

        if self.__keep_aspect_ratio:
            # Decide which axis dictates the common scale factor.
            if method == "lower_bound":
                # scale such that output size is lower bound
                fit_width = scale_width > scale_height
            elif method == "upper_bound":
                # scale such that output size is upper bound
                fit_width = scale_width < scale_height
            elif method == "minimal":
                # scale as little as possible
                fit_width = abs(1 - scale_width) < abs(1 - scale_height)
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")

            if fit_width:
                scale_height = scale_width
            else:
                scale_width = scale_height

        # The method also decides which bound is enforced while rounding.
        if method == "lower_bound":
            height_limits = {"min_val": self.__height}
            width_limits = {"min_val": self.__width}
        elif method == "upper_bound":
            height_limits = {"max_val": self.__height}
            width_limits = {"max_val": self.__width}
        elif method == "minimal":
            height_limits = {}
            width_limits = {}
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        new_height = self.constrain_to_multiple_of(scale_height * height, **height_limits)
        new_width = self.constrain_to_multiple_of(scale_width * width, **width_limits)

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize sample in place and return (sample, (height, width)).

        The image is resized with F.interpolate (bilinear), reading its width and
        height from the last two dimensions. "depth" and "mask", when present and
        resize_target is set, are resized with cv2.resize using nearest-neighbour
        interpolation.
        """
        image = sample["image"]
        width, height = self.get_size(image.shape[-1], image.shape[-2])
        output_hw = (height, width)

        # resize sample
        sample["image"] = F.interpolate(image, size=(height, width), mode='bilinear', align_corners=False)

        if not self.__resize_target:
            return sample, output_hw

        if "depth" in sample:
            sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)

        if "mask" in sample:
            mask = sample["mask"].astype(np.float32)
            sample["mask"] = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

        return sample, output_hw
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class NormalizeImage(object):
    """Normalize image by given mean and std.
    """

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        """Replace sample["image"] with (image - mean) / std and return the same dict."""
        centered = sample["image"] - self.__mean
        sample["image"] = centered / self.__std
        return sample
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class PrepareForNet(object):
    """Prepare sample for usage as network input.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        """Convert the sample's arrays to contiguous float32 and return the same dict.

        The image axes are reordered with the permutation (2, 0, 1); "depth" and
        "mask" are converted only when present.
        """
        channels_first = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(channels_first).astype(np.float32)

        if "depth" in sample:
            sample["depth"] = np.ascontiguousarray(sample["depth"].astype(np.float32))

        if "mask" in sample:
            sample["mask"] = np.ascontiguousarray(sample["mask"].astype(np.float32))

        return sample
|
DepthMaster/in_the_wild_example/input/06.jpg
ADDED
|
DepthMaster/scripts/eval_diode.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run evaluate.py on the DIODE dataset config with the checkpoint in ckpt/eval.
# NOTE(review): path/to/basedata looks like a placeholder for the dataset root - set it before running.
set -e
set -x

export CUDA_VISIBLE_DEVICES=5
# The last argument carries no trailing backslash, so lines appended later
# are not swallowed into this command.
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_diode_all.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/diode/final \
    --checkpoint ckpt/eval \
    --processing_res 640 \
    --seed 1234
|
DepthMaster/scripts/eval_eth3d.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run evaluate.py on the ETH3D dataset config with the checkpoint in ckpt/eval.
# NOTE(review): path/to/basedata looks like a placeholder for the dataset root - set it before running.
set -e
set -x

export CUDA_VISIBLE_DEVICES=5
# The last argument carries no trailing backslash, so lines appended later
# are not swallowed into this command.
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_eth3d.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/eth3d/final \
    --checkpoint ckpt/eval \
    --processing_res 756 \
    --seed 1234
|
DepthMaster/scripts/eval_hypersim.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run evaluate.py on the Hypersim test dataset config with the checkpoint in ckpt/eval.
# NOTE(review): path/to/basedata looks like a placeholder for the dataset root - set it before running.
set -e
set -x

export CUDA_VISIBLE_DEVICES=5
# The last argument carries no trailing backslash, so lines appended later
# are not swallowed into this command.
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_hypersim_test.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/hypersim/final \
    --checkpoint ckpt/eval \
    --processing_res 0 \
    --seed 1234
|
DepthMaster/scripts/eval_kitti.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run evaluate.py on the KITTI Eigen test dataset config with the checkpoint in ckpt/eval.
# NOTE(review): path/to/basedata looks like a placeholder for the dataset root - set it before running.
set -e
set -x

export CUDA_VISIBLE_DEVICES=5
# The last argument carries no trailing backslash, so lines appended later
# are not swallowed into this command.
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_kitti_eigen_test.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/kitti/final \
    --checkpoint ckpt/eval \
    --processing_res 0 \
    --seed 1234
|
DepthMaster/scripts/eval_nyu.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run evaluate.py on the NYU test dataset config with the checkpoint in ckpt/eval.
# NOTE(review): path/to/basedata looks like a placeholder for the dataset root - set it before running.
set -e
set -x

export CUDA_VISIBLE_DEVICES=5
# The last argument carries no trailing backslash, so lines appended later
# are not swallowed into this command.
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_nyu_test.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/nyu/final1 \
    --checkpoint ckpt/eval \
    --processing_res 0 \
    --seed 1234
|
DepthMaster/scripts/eval_scannet.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run evaluate.py on the ScanNet val dataset config with the checkpoint in ckpt/eval.
# NOTE(review): path/to/basedata looks like a placeholder for the dataset root - set it before running.
set -e
set -x

export CUDA_VISIBLE_DEVICES=5
# The last argument carries no trailing backslash, so lines appended later
# are not swallowed into this command.
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_scannet_val.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/scannet/final \
    --checkpoint ckpt/eval \
    --processing_res 0 \
    --seed 1234
|
DepthMaster/scripts/infer.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run run.py on the images in in_the_wild_example/input with the checkpoint in ckpt/eval.
set -e
set -x

export CUDA_VISIBLE_DEVICES=5
# The last argument carries no trailing backslash, so lines appended later
# are not swallowed into this command.
python run.py \
    --checkpoint ckpt/eval \
    --processing_res 768 \
    --input_rgb_dir in_the_wild_example/input \
    --output_dir in_the_wild_example/output/final
|
DepthMaster/scripts/train_s1.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Launch train_s1.py with config/train_s1.yaml.
# NOTE(review): both paths below look like placeholders - set them before running.
BASE_DATA_DIR="path/to/basedata"
BASE_CKPT_DIR="path/to/sd2_ckpt"

export CUDA_VISIBLE_DEVICES=3
# Expansions are quoted so directories containing spaces stay a single argument;
# the last argument carries no trailing backslash.
python train_s1.py --config config/train_s1.yaml \
    --base_data_dir "$BASE_DATA_DIR" \
    --base_ckpt_dir "$BASE_CKPT_DIR" \
    --output_dir log/stage1_bs8 \
    --no_wandb
|
DepthMaster/scripts/train_s2.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Launch train_s2.py with config/train_s2.yaml.
# NOTE(review): /zhdd/dataset looks like a machine-specific path - adjust it to the local dataset root.
BASE_DATA_DIR="/zhdd/dataset"
BASE_CKPT_DIR="ori_ckpt"

export CUDA_VISIBLE_DEVICES=2
# Expansions are quoted so directories containing spaces stay a single argument;
# the last argument carries no trailing backslash.
python train_s2.py --config config/train_s2.yaml \
    --base_data_dir "$BASE_DATA_DIR" \
    --base_ckpt_dir "$BASE_CKPT_DIR" \
    --output_dir log/stage2 \
    --no_wandb
|
DepthMaster/src/dataset/__init__.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
|
| 27 |
+
from .base_depth_dataset import BaseDepthDataset, get_pred_name, DatasetMode # noqa: F401
|
| 28 |
+
from .diode_dataset import DIODEDataset
|
| 29 |
+
from .eth3d_dataset import ETH3DDataset
|
| 30 |
+
from .hypersim_dataset import HypersimDataset
|
| 31 |
+
from .kitti_dataset import KITTIDataset
|
| 32 |
+
from .nyu_dataset import NYUDataset
|
| 33 |
+
from .scannet_dataset import ScanNetDataset
|
| 34 |
+
from .vkitti_dataset import VirtualKITTIDataset
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Maps the `name` field of a dataset config to the dataset class get_dataset instantiates.
dataset_name_class_dict = {
    "hypersim": HypersimDataset,
    "vkitti": VirtualKITTIDataset,
    "nyu_v2": NYUDataset,
    "kitti": KITTIDataset,
    "eth3d": ETH3DDataset,
    "diode": DIODEDataset,
    "scannet": ScanNetDataset,
}


def get_dataset(
    cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs
) -> BaseDepthDataset:
    """
    Build the dataset described by cfg_data_split.

    Args:
        cfg_data_split: config object with at least a `name`. For a single dataset it
            must also provide `filenames` and `dir` and support `**` unpacking; all of
            its entries are forwarded to the dataset class as keyword arguments.
            For name "mixed" it must provide `dataset_list`, a sequence of such configs.
        base_data_dir: root directory that each config's `dir` is joined onto.
        mode: dataset mode; "mixed" is only accepted with DatasetMode.TRAIN.
        **kwargs: extra keyword arguments forwarded to every dataset class.

    Returns:
        A dataset instance, or, for name "mixed", a list with one dataset per
        entry of `dataset_list` (despite the return annotation).

    Raises:
        AssertionError: "mixed" was requested in a mode other than TRAIN.
        NotImplementedError: the name is neither "mixed" nor a registered dataset.
    """
    name = cfg_data_split.name

    if "mixed" == name:
        assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets."
        return [
            get_dataset(_cfg, base_data_dir, mode, **kwargs)
            for _cfg in cfg_data_split.dataset_list
        ]

    if name not in dataset_name_class_dict:
        raise NotImplementedError(
            f"Unknown dataset name {name!r}; expected 'mixed' or one of "
            f"{sorted(dataset_name_class_dict)}"
        )

    dataset_class = dataset_name_class_dict[name]
    return dataset_class(
        mode=mode,
        filename_ls_path=cfg_data_split.filenames,
        dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir),
        # dataset_tom_dir=os.path.join(base_data_dir, cfg_data_split.tom_dir),
        **cfg_data_split,
        **kwargs,
    )
|
DepthMaster/src/dataset/base_depth_dataset.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import io
|
| 26 |
+
import os
|
| 27 |
+
import random
|
| 28 |
+
import tarfile
|
| 29 |
+
from enum import Enum
|
| 30 |
+
from typing import Union
|
| 31 |
+
|
| 32 |
+
import numpy as np
|
| 33 |
+
import torch
|
| 34 |
+
from PIL import Image
|
| 35 |
+
from torch.utils.data import Dataset
|
| 36 |
+
from torchvision.transforms import InterpolationMode, Resize, RandomResizedCrop
|
| 37 |
+
|
| 38 |
+
from src.util.depth_transform import DepthNormalizerBase
|
| 39 |
+
from src.util.alignment import depth2disparity
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DatasetMode(Enum):
    """Operating modes of a depth dataset."""

    RGB_ONLY = "rgb_only"  # depth loading is skipped in this mode
    EVAL = "evaluate"
    TRAIN = "train"  # the mode in which training preprocessing is applied
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DepthFileNameMode(Enum):
    """Prediction file naming modes"""

    id = 1  # id.png
    rgb_id = 2  # rgb_id.png
    i_d_rgb = 3  # i_d_1_rgb.png
    rgb_i_d = 4  # NOTE(review): presumably rgb_i_d.png, by analogy with the members above -- confirm
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def read_image_from_tar(tar_obj, img_rel_path):
    """Read an image stored as a member of an opened tar archive.

    Args:
        tar_obj: Opened tar archive exposing `extractfile`.
        img_rel_path: Member path without the leading "./"; the prefix is
            added here before the lookup.

    Returns:
        The PIL image opened from the member's bytes.

    Raises:
        FileNotFoundError: If `extractfile` returns None for the member.
    """
    member = tar_obj.extractfile("./" + img_rel_path)
    if member is None:
        raise FileNotFoundError(
            f"Tar member is not a readable file: ./{img_rel_path}"
        )
    # Read the whole member into memory so the image does not depend on the
    # archive's file position afterwards.
    image = Image.open(io.BytesIO(member.read()))
    # The original built the image but never returned it.
    return image
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class BaseDepthDataset(Dataset):
    """Base dataset yielding RGB rasters and, unless RGB-only, depth rasters.

    Samples are listed in a text file with one whitespace-separated entry per
    line (RGB path, depth path, optionally a filled-depth path). Data is read
    either from a directory or from a tar archive, depending on what
    `dataset_dir` points to. Subclasses typically override `_read_depth_file`
    to decode their depth format.
    """

    def __init__(
        self,
        mode: DatasetMode,
        filename_ls_path: str,
        dataset_dir: str,
        disp_name: str,
        min_depth: float,
        max_depth: float,
        has_filled_depth: bool,
        has_egde_mask: bool,
        name_mode: DepthFileNameMode,
        depth_transform: Union[DepthNormalizerBase, None] = None,
        augmentation_args: dict = None,
        resize_to_hw=None,
        move_invalid_to_far_plane: bool = True,
        rgb_transform=lambda x: x / 255.0 * 2 - 1,  #  [0, 255] -> [-1, 1],
        **kwargs,
    ) -> None:
        """Store the configuration and load the list of sample paths.

        Args:
            mode: Selects which rasters are loaded and whether training
                preprocessing runs.
            filename_ls_path: Text file listing the samples, one per line.
            dataset_dir: Dataset directory or tar archive; must exist.
            disp_name: Stored as `self.disp_name`; not used in this class.
            min_depth: Exclusive lower bound of valid depth values.
            max_depth: Exclusive upper bound of valid depth values.
            has_filled_depth: Whether each line has a third, filled-depth path.
            has_egde_mask: Stored as `self.has_egde_mask`; not used in this
                class. The spelling is part of the keyword interface.
            name_mode: Prediction file naming mode; stored only.
            depth_transform: Normalizer called during training preprocessing.
            augmentation_args: Augmentation settings; `lr_flip_p` is read from
                it. None disables augmentation.
            resize_to_hw: Target size for the training-time resize, or None.
            move_invalid_to_far_plane: Whether invalid pixels of the filled
                normalized depth are overwritten with the far-plane value.
            rgb_transform: Stored as `self.rgb_transform`.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.mode = mode
        # dataset info
        self.filename_ls_path = filename_ls_path
        self.dataset_dir = dataset_dir
        assert os.path.exists(
            self.dataset_dir
        ), f"Dataset does not exist at: {self.dataset_dir}"
        self.disp_name = disp_name
        self.has_filled_depth = has_filled_depth
        self.has_egde_mask = has_egde_mask
        self.name_mode: DepthFileNameMode = name_mode
        self.min_depth = min_depth
        self.max_depth = max_depth

        # training arguments
        self.depth_transform: DepthNormalizerBase = depth_transform
        self.augm_args = augmentation_args
        self.resize_to_hw = resize_to_hw
        # NOTE(review): stored but not used by `_load_rgb_data`, which applies
        # its own hard-coded [0, 255] -> [-1, 1] normalization.
        self.rgb_transform = rgb_transform
        self.move_invalid_to_far_plane = move_invalid_to_far_plane

        # Load filenames: [['rgb.png', 'depth.tif'], [], ...]
        # Blank lines are skipped; they would otherwise become empty entries
        # that fail with IndexError when a sample is accessed.
        with open(self.filename_ls_path, "r") as f:
            self.filenames = [s.split() for s in f if s.strip()]

        # Tar dataset: the archive is opened lazily on first read.
        self.tar_obj = None
        self.is_tar = os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Return one sample: rasters merged with the index and RGB path."""
        rasters, other = self._get_data_item(index)
        if DatasetMode.TRAIN == self.mode:
            rasters = self._training_preprocess(rasters)
        # merge
        outputs = rasters
        outputs.update(other)
        return outputs

    def _get_data_item(self, index):
        """Load the rasters of sample `index`.

        Returns:
            `(rasters, other)`: `rasters` maps names to tensors; `other` holds
            the sample index and the relative RGB path.
        """
        rgb_rel_path, depth_rel_path, filled_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Depth data
        if DatasetMode.RGB_ONLY != self.mode:
            # load data
            depth_data = self._load_depth_data(
                depth_rel_path=depth_rel_path, filled_rel_path=filled_rel_path
            )
            rasters.update(depth_data)
            # valid masks are computed on the depth values as loaded, before
            # the training-only conversions below
            rasters["valid_mask_raw"] = self._get_valid_mask(
                rasters["depth_raw_linear"]
            ).clone()
            rasters["valid_mask_filled"] = self._get_valid_mask(
                rasters["depth_filled_linear"]
            ).clone()

            if DatasetMode.TRAIN == self.mode:
                # NOTE(review): when `has_filled_depth` is False,
                # "depth_filled_linear" is left untouched by both steps below
                # while "depth_raw_linear" is converted -- confirm intended.
                # depth2disp
                rasters["depth_raw_linear"] = depth2disparity(
                    rasters["depth_raw_linear"]
                ).clone()
                if self.has_filled_depth:
                    rasters["depth_filled_linear"] = depth2disparity(
                        rasters["depth_filled_linear"]
                    ).clone()

                # sqrt(x)
                rasters["depth_raw_linear"] = torch.sqrt(
                    rasters["depth_raw_linear"]
                ).clone()
                if self.has_filled_depth:
                    rasters["depth_filled_linear"] = torch.sqrt(
                        rasters["depth_filled_linear"]
                    ).clone()

        other = {"index": index, "rgb_relative_path": rgb_rel_path}

        return rasters, other

    def _load_rgb_data(self, rgb_rel_path):
        """Read an RGB image and return its integer and normalized tensors."""
        rgb = self._read_rgb_file(rgb_rel_path)
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  #  [0, 255] -> [-1, 1]

        outputs = {
            "rgb_int": torch.from_numpy(rgb).int(),
            "rgb_norm": torch.from_numpy(rgb_norm).float(),
        }
        return outputs

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        """Read the raw depth and, if available, the filled depth.

        Without a filled depth, "depth_filled_linear" is a copy of the raw
        depth so that both keys are always present.
        """
        outputs = {}
        depth_raw = self._read_depth_file(depth_rel_path).squeeze()
        depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0)  # [1, H, W]
        outputs["depth_raw_linear"] = depth_raw_linear.clone()

        if self.has_filled_depth:
            depth_filled = self._read_depth_file(filled_rel_path).squeeze()
            depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0)
            outputs["depth_filled_linear"] = depth_filled_linear
        else:
            outputs["depth_filled_linear"] = depth_raw_linear.clone()

        return outputs

    def _get_data_path(self, index):
        """Return `(rgb, depth, filled)` relative paths of sample `index`.

        `depth` is None in RGB-only mode; `filled` is None unless the dataset
        has filled depth.
        """
        filename_line = self.filenames[index]

        # Get data path
        rgb_rel_path = filename_line[0]

        depth_rel_path, filled_rel_path = None, None
        if DatasetMode.RGB_ONLY != self.mode:
            depth_rel_path = filename_line[1]
            if self.has_filled_depth:
                filled_rel_path = filename_line[2]
        return rgb_rel_path, depth_rel_path, filled_rel_path

    def _read_image(self, img_rel_path) -> np.ndarray:
        """Read an image from the dataset directory or tar archive.

        Raises:
            FileNotFoundError: In tar mode, if `extractfile` returns None for
                the member.
        """
        if self.is_tar:
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            member = self.tar_obj.extractfile("./" + img_rel_path)
            if member is None:
                raise FileNotFoundError(
                    f"Tar member is not a readable file: ./{img_rel_path}"
                )
            image_to_read = io.BytesIO(member.read())
        else:
            image_to_read = os.path.join(self.dataset_dir, img_rel_path)
        # Convert inside the `with` block so the image is closed
        # deterministically once its pixels are in the array.
        with Image.open(image_to_read) as image:  # [H, W, rgb]
            return np.asarray(image)

    def _read_rgb_file(self, rel_path) -> np.ndarray:
        """Read an RGB image as an integer array with channels first."""
        rgb = self._read_image(rel_path)
        rgb = np.transpose(rgb, (2, 0, 1)).astype(int)  # [rgb, H, W]
        return rgb

    def _read_depth_file(self, rel_path):
        """Read a depth file; the base implementation applies no decoding."""
        depth_in = self._read_image(rel_path)
        # Replace code below to decode depth according to dataset definition
        depth_decoded = depth_in

        return depth_decoded

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return a mask of pixels strictly between `min_depth` and `max_depth`."""
        return torch.logical_and(
            (depth > self.min_depth), (depth < self.max_depth)
        )

    def _training_preprocess(self, rasters):
        """Augment, normalize and optionally resize the rasters for training."""
        # Augmentation
        if self.augm_args is not None:
            rasters = self._augment_data(rasters)

        # Normalization
        rasters["depth_raw_norm"] = self.depth_transform(
            rasters["depth_raw_linear"], rasters["valid_mask_raw"]
        ).clone()
        rasters["depth_filled_norm"] = self.depth_transform(
            rasters["depth_filled_linear"], rasters["valid_mask_filled"]
        ).clone()

        # Set invalid pixel to far plane
        if self.move_invalid_to_far_plane:
            far_value = (
                self.depth_transform.norm_max
                if self.depth_transform.far_plane_at_max
                else self.depth_transform.norm_min
            )
            rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = far_value

        # Resize: the same transform is applied to every raster.
        if self.resize_to_hw is not None:
            resize_transform = Resize(
                size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT
            )
            rasters = {k: resize_transform(v) for k, v in rasters.items()}

        return rasters

    def _augment_data(self, rasters_dict):
        """Flip all rasters along the last axis with probability `lr_flip_p`."""
        # lr flipping
        lr_flip_p = self.augm_args.lr_flip_p
        if random.random() < lr_flip_p:
            rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}

        return rasters_dict

    def __del__(self):
        """Close the tar archive if one was opened."""
        # `hasattr` guards against a failed `__init__` that never set tar_obj.
        if hasattr(self, "tar_obj") and self.tar_obj is not None:
            self.tar_obj.close()
            self.tar_obj = None
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def get_pred_name(rgb_basename, name_mode, suffix=".png"):
    """Derive the prediction file name that corresponds to an RGB file name.

    Args:
        rgb_basename: Base name of the RGB file.
        name_mode: A `DepthFileNameMode` member selecting the naming scheme.
        suffix: Extension given to the returned name.

    Returns:
        The prediction base name with its extension replaced by `suffix`.

    Raises:
        NotImplementedError: If `name_mode` is not a handled naming mode.
    """
    if DepthFileNameMode.rgb_id == name_mode:
        # keep only the second underscore-separated token
        renamed = "pred_" + rgb_basename.split("_")[1]
    elif DepthFileNameMode.i_d_rgb == name_mode:
        # swap the "_rgb." marker for "_pred."
        renamed = rgb_basename.replace("_rgb.", "_pred.")
    elif DepthFileNameMode.id == name_mode:
        # plain prefixing
        renamed = "pred_" + rgb_basename
    elif DepthFileNameMode.rgb_i_d == name_mode:
        # drop the first underscore-separated token, keep the rest
        trailing_tokens = rgb_basename.split("_")[1:]
        renamed = "pred_" + "_".join(trailing_tokens)
    else:
        raise NotImplementedError

    # change suffix
    stem, _old_ext = os.path.splitext(renamed)
    return stem + suffix
|
DepthMaster/src/dataset/diode_dataset.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import tarfile
|
| 27 |
+
from io import BytesIO
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import torch
|
| 31 |
+
|
| 32 |
+
from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode, DatasetMode
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class DIODEDataset(BaseDepthDataset):
    """DIODE dataset: depth and validity mask are both read from `.npy` files."""

    def __init__(
        self,
        **kwargs,
    ) -> None:
        # Fixed DIODE parameters; everything else is supplied by the caller.
        diode_params = dict(
            min_depth=0.6,
            max_depth=350,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
        )
        super().__init__(**diode_params, **kwargs)

    def _read_npy_file(self, rel_path):
        """Load a `.npy` array and return it with a single leading axis."""
        if not self.is_tar:
            source = os.path.join(self.dataset_dir, rel_path)
        else:
            # open the archive lazily on first access
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            member = self.tar_obj.extractfile("./" + rel_path)
            source = BytesIO(member.read())
        array = np.load(source)
        return array.squeeze()[np.newaxis, :, :]

    def _read_depth_file(self, rel_path):
        """Read the depth map stored as a `.npy` file."""
        return self._read_npy_file(rel_path)

    def _get_data_path(self, index):
        """Return the raw filename entry; it is unpacked as (rgb, depth, mask)."""
        return self.filenames[index]

    def _get_data_item(self, index):
        """Load one sample; the valid mask comes from the data, not from depth bounds."""
        rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index)

        # RGB data
        rasters = dict(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Depth data
        if DatasetMode.RGB_ONLY != self.mode:
            rasters.update(
                self._load_depth_data(
                    depth_rel_path=depth_rel_path, filled_rel_path=None
                )
            )

            # valid mask: the same stored mask serves both raw and filled depth
            mask_array = self._read_npy_file(mask_rel_path).astype(bool)
            mask_tensor = torch.from_numpy(mask_array).bool()
            for mask_key in ("valid_mask_raw", "valid_mask_filled"):
                rasters[mask_key] = mask_tensor.clone()

        other = {"index": index, "rgb_relative_path": rgb_rel_path}
        return rasters, other
|
DepthMaster/src/dataset/eth3d_dataset.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import tarfile
|
| 27 |
+
import os
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ETH3DDataset(BaseDepthDataset):
    """ETH3D dataset: depth is stored as raw binary 32-bit floats."""

    # Size the flat depth buffer is reshaped to.
    HEIGHT, WIDTH = 4032, 6048

    def __init__(
        self,
        **kwargs,
    ) -> None:
        # Fixed ETH3D parameters; everything else is supplied by the caller.
        eth3d_params = dict(
            min_depth=1e-5,
            max_depth=torch.inf,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
        )
        super().__init__(**eth3d_params, **kwargs)

    def _read_depth_file(self, rel_path):
        """Decode a binary depth file into a (HEIGHT, WIDTH) float32 array.

        Format reference:
        https://www.eth3d.net/documentation#format-of-multi-view-data-image-formats
        """
        if self.is_tar:
            # open the archive lazily on first access
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            raw_bytes = self.tar_obj.extractfile("./" + rel_path).read()
        else:
            with open(os.path.join(self.dataset_dir, rel_path), "rb") as handle:
                raw_bytes = handle.read()

        # Interpret the bytes as 32-bit floats; copy to get a writable array.
        depth_flat = np.frombuffer(raw_bytes, dtype=np.float32).copy()

        # Positive-infinity entries are replaced by zero.
        inf_entries = depth_flat == torch.inf
        depth_flat[inf_entries] = 0.0

        return depth_flat.reshape((self.HEIGHT, self.WIDTH))
|
DepthMaster/src/dataset/hypersim_dataset.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class HypersimDataset(BaseDepthDataset):
    """Hypersim dataset: depth images are divided by 1000 when decoded."""

    def __init__(
        self,
        **kwargs,
    ) -> None:
        # Fixed Hypersim parameters; everything else is supplied by the caller.
        hypersim_params = dict(
            min_depth=1e-5,
            max_depth=65.0,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.rgb_i_d,
        )
        super().__init__(**hypersim_params, **kwargs)

    def _read_depth_file(self, rel_path):
        """Read a depth image and decode it by dividing by 1000."""
        # NOTE(review): presumably millimetres -> metres; confirm against the
        # dataset preparation script.
        encoded = self._read_image(rel_path)
        return encoded / 1000.0
|
DepthMaster/src/dataset/kitti_dataset.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class KITTIDataset(BaseDepthDataset):
    """KITTI dataset with optional benchmark crop and evaluation mask."""

    def __init__(
        self,
        kitti_bm_crop,  # Crop to KITTI benchmark size
        valid_mask_crop,  # Evaluation mask. [None, garg or eigen]
        **kwargs,
    ) -> None:
        """Configure the KITTI dataset.

        Args:
            kitti_bm_crop: If truthy, RGB and depth rasters are cropped to the
                KITTI benchmark size when loaded.
            valid_mask_crop: Evaluation mask type: None, "garg" or "eigen".
            **kwargs: Forwarded to `BaseDepthDataset.__init__`.

        Raises:
            AssertionError: If `valid_mask_crop` is not a supported value.
        """
        super().__init__(
            # KITTI data parameter
            min_depth=1e-5,
            max_depth=80,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )
        self.kitti_bm_crop = kitti_bm_crop
        self.valid_mask_crop = valid_mask_crop
        assert self.valid_mask_crop in [
            None,
            "garg",  # set evaluation mask according to Garg  ECCV16
            "eigen",  # set evaluation mask according to Eigen NIPS14
        ], f"Unknown crop type: {self.valid_mask_crop}"

        # Filter out empty depth: entries whose depth path is the literal "None"
        self.filenames = [f for f in self.filenames if "None" != f[1]]

    def _read_depth_file(self, rel_path):
        """Read a depth image and decode it by dividing by 256."""
        depth_in = self._read_image(rel_path)
        # Decode KITTI depth
        depth_decoded = depth_in / 256.0
        return depth_decoded

    def _load_rgb_data(self, rgb_rel_path):
        """Load RGB tensors, applying the benchmark crop when enabled."""
        rgb_data = super()._load_rgb_data(rgb_rel_path)
        if self.kitti_bm_crop:
            rgb_data = {k: self.kitti_benchmark_crop(v) for k, v in rgb_data.items()}
        return rgb_data

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        """Load depth tensors, applying the benchmark crop when enabled."""
        depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path)
        if self.kitti_bm_crop:
            depth_data = {
                k: self.kitti_benchmark_crop(v) for k, v in depth_data.items()
            }
        return depth_data

    @staticmethod
    def kitti_benchmark_crop(input_img):
        """
        Crop images to KITTI benchmark size

        The crop window is bottom-aligned and horizontally centred, and is
        taken over the last two dimensions whatever the number of leading
        dimensions (the original handled only 2-D and 3-D inputs and failed
        with an unbound local for any other rank).

        Args:
            `input_img` (torch.Tensor): Input image to be cropped.

        Returns:
            torch.Tensor:Cropped image.
        """
        KB_CROP_HEIGHT = 352
        KB_CROP_WIDTH = 1216

        height, width = input_img.shape[-2:]
        top_margin = int(height - KB_CROP_HEIGHT)
        left_margin = int((width - KB_CROP_WIDTH) / 2)
        # `...` keeps every leading dimension untouched.
        return input_img[
            ...,
            top_margin : top_margin + KB_CROP_HEIGHT,
            left_margin : left_margin + KB_CROP_WIDTH,
        ]

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return the depth-range mask, restricted to the evaluation crop if set."""
        # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py
        valid_mask = super()._get_valid_mask(depth)  # [1, H, W]

        if self.valid_mask_crop is not None:
            eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
            gt_height, gt_width = eval_mask.shape

            if "garg" == self.valid_mask_crop:
                eval_mask[
                    int(0.40810811 * gt_height) : int(0.99189189 * gt_height),
                    int(0.03594771 * gt_width) : int(0.96405229 * gt_width),
                ] = 1
            elif "eigen" == self.valid_mask_crop:
                eval_mask[
                    int(0.3324324 * gt_height) : int(0.91351351 * gt_height),
                    int(0.0359477 * gt_width) : int(0.96405229 * gt_width),
                ] = 1

            # `reshape` returns a new tensor; the original discarded the
            # result, so the statement had no effect.
            eval_mask = eval_mask.reshape(valid_mask.shape)
            valid_mask = torch.logical_and(valid_mask, eval_mask)
        return valid_mask
|
DepthMaster/src/dataset/mixed_sampler.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from torch.utils.data import (
|
| 27 |
+
BatchSampler,
|
| 28 |
+
RandomSampler,
|
| 29 |
+
SequentialSampler,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class MixedBatchSampler(BatchSampler):
    """Sample one batch from a selected dataset with given probability.

    Compatible with datasets at different resolution: every yielded batch is
    drawn from a single source dataset, and its indices are shifted so that
    they address the concatenation of `src_dataset_ls`.

    Args:
        src_dataset_ls: Source datasets, in concatenation order.
        batch_size: Number of samples per batch.
        drop_last: Whether each per-dataset batch sampler drops its
            incomplete final batch.
        shuffle: Use a `RandomSampler` (without replacement) per dataset when
            true, otherwise a `SequentialSampler`.
        prob: Optional per-dataset sampling weights (any real-valued
            sequence or tensor). When omitted, each dataset is weighted by
            its share of the total number of batches.
        generator: Optional `torch.Generator` shared by the random samplers
            and the dataset selection.
    """

    def __init__(
        self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None
    ):
        self.base_sampler = None
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.generator = generator

        self.src_dataset_ls = src_dataset_ls
        self.n_dataset = len(self.src_dataset_ls)

        # Dataset length
        self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
        self.cum_dataset_length = [
            sum(self.dataset_length[:i]) for i in range(self.n_dataset)
        ]  # cumulative dataset length

        # BatchSamplers for each source dataset
        if self.shuffle:
            self.src_batch_samplers = [
                BatchSampler(
                    sampler=RandomSampler(
                        ds, replacement=False, generator=self.generator
                    ),
                    batch_size=self.batch_size,
                    drop_last=self.drop_last,
                )
                for ds in self.src_dataset_ls
            ]
        else:
            self.src_batch_samplers = [
                BatchSampler(
                    sampler=SequentialSampler(ds),
                    batch_size=self.batch_size,
                    drop_last=self.drop_last,
                )
                for ds in self.src_dataset_ls
            ]
        self.raw_batches = [
            list(bs) for bs in self.src_batch_samplers
        ]  # index in original dataset
        self.n_batches = [len(b) for b in self.raw_batches]
        self.n_total_batch = sum(self.n_batches)

        # sampling probability
        if prob is None:
            # if not given, decide by dataset length
            self.prob = torch.tensor(self.n_batches) / self.n_total_batch
        else:
            prob_tensor = torch.as_tensor(prob)
            # torch.multinomial rejects integer inputs, so integer weights
            # such as [1, 2, 3] are promoted; float inputs are left as given.
            if not prob_tensor.is_floating_point():
                prob_tensor = prob_tensor.float()
            self.prob = prob_tensor

    def __iter__(self):
        """Yield `len(self)` batches, each taken from one randomly chosen dataset.

        Yields:
            list(int): a batch of indices, corresponding to ConcatDataset of src_dataset_ls
        """
        for _ in range(self.n_total_batch):
            idx_ds = torch.multinomial(
                self.prob, 1, replacement=True, generator=self.generator
            ).item()
            # if batch list is empty, generate new list
            if 0 == len(self.raw_batches[idx_ds]):
                self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])
            # get a batch from list (taken from the end of the cached list)
            batch_raw = self.raw_batches[idx_ds].pop()
            # shift by cumulative dataset length
            shift = self.cum_dataset_length[idx_ds]
            batch = [n + shift for n in batch_raw]

            yield batch

    def __len__(self):
        """Return the total number of batches over all source datasets."""
        return self.n_total_batch
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Unit test
|
| 116 |
+
if "__main__" == __name__:
    from torch.utils.data import ConcatDataset, DataLoader, Dataset

    class SimpleDataset(Dataset):
        """Toy dataset whose item at `index` is `start + index`."""

        def __init__(self, start, length) -> None:
            super().__init__()
            self.start = start
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, index):
            return self.start + index

    # Three sources with disjoint value ranges, so the origin of every
    # printed batch is recognisable from its values.
    source_datasets = [
        SimpleDataset(0, 10),
        SimpleDataset(200, 20),
        SimpleDataset(1000, 50),
    ]

    # will directly concatenate
    concat_dataset = ConcatDataset(source_datasets)

    mixed_sampler = MixedBatchSampler(
        src_dataset_ls=source_datasets,
        batch_size=4,
        drop_last=True,
        shuffle=False,
        prob=[0.6, 0.3, 0.1],
        generator=torch.Generator().manual_seed(0),
    )

    loader = DataLoader(concat_dataset, batch_sampler=mixed_sampler)

    for d in loader:
        print(d)
|
DepthMaster/src/dataset/nyu_dataset.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class NYUDataset(BaseDepthDataset):
    """NYUv2 depth dataset with an optional fixed evaluation crop.

    Args:
        eigen_valid_mask (bool): When true, the validity mask is additionally
            restricted to rows 45:471 and columns 41:601.
        **kwargs: Forwarded to `BaseDepthDataset`.
    """

    def __init__(
        self,
        eigen_valid_mask: bool,
        **kwargs,
    ) -> None:
        super().__init__(
            # NYUv2 dataset parameter
            min_depth=1e-3,
            max_depth=10.0,
            has_filled_depth=True,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.rgb_id,
            **kwargs,
        )

        self.eigen_valid_mask = eigen_valid_mask

    def _read_depth_file(self, rel_path):
        """Read the depth image at `rel_path` and divide its values by 1000."""
        depth_in = self._read_image(rel_path)
        # Decode NYU depth
        depth_decoded = depth_in / 1000.0
        return depth_decoded

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return the base validity mask, optionally limited to the Eigen crop.

        Args:
            depth (torch.Tensor): Depth map forwarded to the parent implementation.

        Returns:
            torch.Tensor: Boolean mask with the same shape as the parent's mask.
        """
        valid_mask = super()._get_valid_mask(depth)

        # Eigen crop for evaluation
        if self.eigen_valid_mask:
            eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
            eval_mask[45:471, 41:601] = 1
            # `reshape` is not in-place: keep its result so the crop mask
            # really has the parent mask's shape before combining.
            eval_mask = eval_mask.reshape(valid_mask.shape)
            valid_mask = torch.logical_and(valid_mask, eval_mask)

        return valid_mask
|
DepthMaster/src/dataset/scannet_dataset.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ScanNetDataset(BaseDepthDataset):
    """ScanNet depth dataset built on `BaseDepthDataset` with fixed parameters."""

    def __init__(
        self,
        **kwargs,
    ) -> None:
        # ScanNet data parameter
        scannet_params = dict(
            min_depth=1e-3,
            max_depth=10,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
        )
        super().__init__(**scannet_params, **kwargs)

    def _read_depth_file(self, rel_path):
        """Read the depth image at `rel_path` and divide its values by 1000."""
        # Decode ScanNet depth
        return self._read_image(rel_path) / 1000.0
|
DepthMaster/src/dataset/vkitti_dataset.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
|
| 27 |
+
from .kitti_dataset import KITTIDataset
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class VirtualKITTIDataset(BaseDepthDataset):
    """Virtual KITTI depth dataset with optional KITTI-style cropping and masks.

    Args:
        kitti_bm_crop: When truthy, RGB and depth data are cropped with
            `KITTIDataset.kitti_benchmark_crop`.
        valid_mask_crop: Evaluation mask, one of `None`, "garg" or "eigen".
        **kwargs: Forwarded to `BaseDepthDataset`.
    """

    def __init__(
        self,
        kitti_bm_crop,  # Crop to KITTI benchmark size
        valid_mask_crop,  # Evaluation mask. [None, garg or eigen]
        **kwargs,
    ) -> None:
        super().__init__(
            # virtual KITTI data parameter
            min_depth=1e-5,
            max_depth=80,  # 655.35
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )
        self.kitti_bm_crop = kitti_bm_crop
        self.valid_mask_crop = valid_mask_crop
        assert self.valid_mask_crop in [
            None,
            "garg",  # set evaluation mask according to Garg  ECCV16
            "eigen",  # set evaluation mask according to Eigen NIPS14
        ], f"Unknown crop type: {self.valid_mask_crop}"

        # Filter out empty depth
        self.filenames = [f for f in self.filenames if "None" != f[1]]

    def _read_depth_file(self, rel_path):
        """Read the depth image at `rel_path` and divide its values by 100."""
        depth_in = self._read_image(rel_path)
        # Decode vKITTI depth
        depth_decoded = depth_in / 100.0
        return depth_decoded

    def _load_rgb_data(self, rgb_rel_path):
        """Load RGB data via the parent and apply the benchmark crop if enabled."""
        rgb_data = super()._load_rgb_data(rgb_rel_path)
        if self.kitti_bm_crop:
            rgb_data = {
                k: KITTIDataset.kitti_benchmark_crop(v) for k, v in rgb_data.items()
            }
        return rgb_data

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        """Load depth data via the parent and apply the benchmark crop if enabled."""
        depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path)
        if self.kitti_bm_crop:
            depth_data = {
                k: KITTIDataset.kitti_benchmark_crop(v) for k, v in depth_data.items()
            }
        return depth_data

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return the base validity mask, optionally restricted to an evaluation crop.

        Args:
            depth (torch.Tensor): Depth map forwarded to the parent implementation.

        Returns:
            torch.Tensor: Boolean mask with the same shape as the parent's mask.
        """
        # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py
        valid_mask = super()._get_valid_mask(depth)  # [1, H, W]

        if self.valid_mask_crop is not None:
            eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
            gt_height, gt_width = eval_mask.shape

            if "garg" == self.valid_mask_crop:
                eval_mask[
                    int(0.40810811 * gt_height) : int(0.99189189 * gt_height),
                    int(0.03594771 * gt_width) : int(0.96405229 * gt_width),
                ] = 1
            elif "eigen" == self.valid_mask_crop:
                eval_mask[
                    int(0.3324324 * gt_height) : int(0.91351351 * gt_height),
                    int(0.0359477 * gt_width) : int(0.96405229 * gt_width),
                ] = 1

            # `reshape` returns a new tensor; the result must be kept,
            # otherwise the statement has no effect.
            eval_mask = eval_mask.reshape(valid_mask.shape)
            valid_mask = torch.logical_and(valid_mask, eval_mask)
        return valid_mask
|
DepthMaster/src/trainer/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Author: Bingxin Ke
|
| 2 |
+
# Last modified: 2024-05-17
|
| 3 |
+
|
| 4 |
+
from .trainer_s1 import DepthMasterTrainerS1
|
| 5 |
+
from .trainer_s2 import DepthMasterTrainerS2
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Registry mapping a trainer name to its trainer class.
trainer_cls_name_dict = dict(
    DepthMasterTrainerS1=DepthMasterTrainerS1,
    DepthMasterTrainerS2=DepthMasterTrainerS2,
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_trainer_cls(trainer_name):
    """Look up a trainer class by its registered name.

    Args:
        trainer_name: Key into `trainer_cls_name_dict`.

    Returns:
        The trainer class registered under `trainer_name`.

    Raises:
        KeyError: If `trainer_name` is not registered; the message lists the
            available names.
    """
    try:
        return trainer_cls_name_dict[trainer_name]
    except KeyError:
        known = ", ".join(sorted(trainer_cls_name_dict))
        raise KeyError(
            f"Unknown trainer `{trainer_name}`. Available trainers: {known}"
        ) from None
|
DepthMaster/src/trainer/trainer_s1.py
ADDED
|
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-07-13
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
import logging
|
| 25 |
+
import os
|
| 26 |
+
import random
|
| 27 |
+
import shutil
|
| 28 |
+
from datetime import datetime
|
| 29 |
+
from typing import List, Union
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
from omegaconf import OmegaConf
|
| 34 |
+
from torch.optim import Adam
|
| 35 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 36 |
+
from torch.utils.data import DataLoader
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
from PIL import Image
|
| 39 |
+
import torch.nn.functional as F
|
| 40 |
+
|
| 41 |
+
from depthmaster import DepthMasterPipeline, DepthMasterDepthOutput
|
| 42 |
+
from src.util import metric
|
| 43 |
+
from src.util.data_loader import skip_first_batches
|
| 44 |
+
from src.util.logging_util import tb_logger, eval_dic_to_text
|
| 45 |
+
from src.util.loss import get_loss, SSIM
|
| 46 |
+
from src.util.lr_scheduler import IterExponential
|
| 47 |
+
from src.util.metric import MetricTracker
|
| 48 |
+
from src.util.alignment import (
|
| 49 |
+
align_depth_least_square,
|
| 50 |
+
depth2disparity,
|
| 51 |
+
disparity2depth,
|
| 52 |
+
)
|
| 53 |
+
from src.util.seeding import generate_seed_sequence
|
| 54 |
+
from src.util.build_mlp import build_mlp_
|
| 55 |
+
from torchvision.transforms import Normalize
|
| 56 |
+
from external_encoder.dinov2.dinov2 import DINOv2
|
| 57 |
+
|
| 58 |
+
# ImageNet default normalisation statistics, one value per image channel.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
| 60 |
+
|
| 61 |
+
class DepthMasterTrainerS1:
|
| 62 |
+
def __init__(
    self,
    cfg: OmegaConf,
    model: DepthMasterPipeline,
    train_dataloader: DataLoader,
    device,
    base_ckpt_dir,
    out_dir_ckpt,
    out_dir_eval,
    out_dir_vis,
    accumulation_steps: int,
    val_dataloaders: List[DataLoader] = None,
    vis_dataloaders: List[DataLoader] = None,
):
    """Set up models, optimizer, LR schedule, loss, metrics and bookkeeping.

    Args:
        cfg: Training configuration; read for seeding, learning rate,
            scheduler, loss, evaluation and period settings.
        model: Pipeline whose UNet is trained; its VAE and text encoder are
            frozen here.
        train_dataloader: Loader for training batches.
        device: Device the empty-text embedding is moved to.
        base_ckpt_dir: Accepted for interface compatibility; not used in
            this method.
        out_dir_ckpt: Output directory for checkpoints.
        out_dir_eval: Output directory for evaluation results.
        out_dir_vis: Output directory for visualizations.
        accumulation_steps: Gradient accumulation steps.
        val_dataloaders: Optional validation loaders.
        vis_dataloaders: Optional visualization loaders.
    """
    self.cfg: OmegaConf = cfg
    self.model: DepthMasterPipeline = model
    self.device = device
    # used to generate seed sequence, set to `None` to train w/o seeding
    self.seed: Union[int, None] = self.cfg.trainer.init_seed
    self.out_dir_ckpt = out_dir_ckpt
    self.out_dir_eval = out_dir_eval
    self.out_dir_vis = out_dir_vis
    self.train_loader: DataLoader = train_dataloader
    self.val_loaders: List[DataLoader] = val_dataloaders
    self.vis_loaders: List[DataLoader] = vis_dataloaders
    self.accumulation_steps: int = accumulation_steps

    # Encode empty text prompt
    self.model.encode_empty_text()
    self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
    self.model.unet.enable_xformers_memory_efficient_attention()

    # Initialize DINOv2 encoder
    self.dinov2_encoder = DINOv2(model_name="vitg")
    encoder_state = self.dinov2_encoder.state_dict()
    checkpoint_state = torch.load(
        "checkpoints/depth_anything_v2_vitg.pth", map_location="cpu"
    )
    # Keep only checkpoint entries whose key, with "pretrained." removed,
    # names a parameter of the encoder.
    matched_weights = {}
    for ckpt_key, weight in checkpoint_state.items():
        encoder_key = ckpt_key.replace("pretrained.", "")
        if encoder_key in encoder_state:
            matched_weights[encoder_key] = weight
    self.dinov2_encoder.load_state_dict(matched_weights)
    # Swap the encoder head for a pass-through module.
    del self.dinov2_encoder.head
    self.dinov2_encoder.head = torch.nn.Identity()
    self.dinov2_encoder.eval()

    # Initialize adapter to align the feat dimension of SD and DINOv2
    self.dinov2_adapter = build_mlp_(
        hidden_size=1280, projector_dim=1536, z_dim=1536
    )

    # Trainability
    self.dinov2_adapter.requires_grad_(True)
    self.dinov2_encoder.requires_grad_(False)
    self.model.vae.requires_grad_(False)
    self.model.text_encoder.requires_grad_(False)
    self.model.unet.requires_grad_(True)

    # Optimizer !should be defined after input layer is adapted
    lr = self.cfg.lr
    param_groups = [
        {"params": self.model.unet.parameters(), "lr": lr},
        {"params": self.dinov2_adapter.parameters(), "lr": lr},
    ]
    self.optimizer = Adam(param_groups)

    # LR scheduler
    scheduler_kwargs = self.cfg.lr_scheduler.kwargs
    lr_func = IterExponential(
        total_iter_length=scheduler_kwargs.total_iter,
        final_ratio=scheduler_kwargs.final_ratio,
        warmup_steps=scheduler_kwargs.warmup_steps,
    )
    self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

    # Loss
    self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)

    # Eval metrics
    self.metric_funcs = [getattr(metric, name) for name in cfg.eval.eval_metrics]
    self.train_metrics = MetricTracker("loss", "feat_align_loss")
    self.val_metrics = MetricTracker(*(fn.__name__ for fn in self.metric_funcs))
    # main metric for best checkpoint saving
    self.main_val_metric = cfg.validation.main_val_metric
    self.main_val_metric_goal = cfg.validation.main_val_metric_goal
    assert (
        self.main_val_metric in cfg.eval.eval_metrics
    ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
    # Start from the worst possible value for the chosen optimisation goal.
    if "minimize" == self.main_val_metric_goal:
        self.best_metric = 1e8
    else:
        self.best_metric = -1e8

    # Settings
    self.max_epoch = self.cfg.max_epoch
    self.max_iter = self.cfg.max_iter
    self.gradient_accumulation_steps = accumulation_steps
    self.gt_depth_type = self.cfg.gt_depth_type
    self.gt_mask_type = self.cfg.gt_mask_type
    self.save_period = self.cfg.trainer.save_period
    self.backup_period = self.cfg.trainer.backup_period
    self.val_period = self.cfg.trainer.validation_period
    self.vis_period = self.cfg.trainer.visualization_period

    # Internal variables
    self.epoch = 1
    self.n_batch_in_epoch = 0  # batch index in the epoch, used when resume training
    self.effective_iter = 0  # how many times optimizer.step() is called
    self.in_evaluation = False
    # consistent global seed sequence, used to seed random generator, to
    # ensure consistency when resuming
    self.global_seed_sequence: List = []
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def train(self, t_end=None):
    """Run the training loop with gradient accumulation.

    Every batch is encoded to latents, passed through the UNet once and
    scored with the configured latent loss plus a KL feature-alignment term
    between the adapted UNet features and DINOv2 patch tokens. An optimizer
    step is taken once ``self.gradient_accumulation_steps`` batches have
    been accumulated.

    Args:
        t_end: Optional ``datetime`` deadline. When it has passed after an
            optimizer step, a resumable ``latest`` checkpoint is saved and
            the method returns.

    Raises:
        NotImplementedError: If ``self.gt_mask_type`` is ``None``.
    """
    logging.info("Start training")

    device = self.device
    self.model.to(device)
    self.dinov2_encoder.to(device)
    self.dinov2_adapter.to(device)
    self.visualize()

    if self.in_evaluation:
        logging.info(
            "Last evaluation was not finished, will do evaluation before continue training."
        )
        self.validate()

    self.train_metrics.reset()
    accumulated_step = 0

    progress_bar = tqdm(
        range(0, self.max_iter),
        initial=self.effective_iter,
        desc="iter"
    )

    for epoch in range(self.epoch, self.max_epoch + 1):
        self.epoch = epoch
        logging.debug(f"epoch: {self.epoch}")

        # Skip previous batches when resume
        for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
            self.model.unet.train()
            self.dinov2_adapter.train()

            # >>> With gradient accumulation >>>

            # Get data
            rgb = batch["rgb_norm"].to(device)
            depth_gt_for_latent = batch[self.gt_depth_type].to(device)

            if self.gt_mask_type is None:
                raise NotImplementedError

            # Downsample the validity mask by 8x8 max-pooling the *invalid*
            # mask, so a latent cell is valid only if all its pixels are valid.
            valid_mask_for_latent = batch[self.gt_mask_type].to(device)
            invalid_mask = ~valid_mask_for_latent
            valid_mask_down = ~torch.max_pool2d(
                invalid_mask.float(), 8, 8
            ).bool()
            valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))

            batch_size = rgb.shape[0]

            with torch.no_grad():
                # Encode image
                rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]
                # Encode GT depth
                gt_depth_latent = self.encode_depth(
                    depth_gt_for_latent
                )  # [B, 4, h, w]
                # DINOv2 feat
                dinov2_input_rgb = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(rgb)
                dinov2_input_rgb = F.interpolate(dinov2_input_rgb, scale_factor=0.875, mode='bicubic')
                dinov2_z = self.dinov2_encoder.forward_features(dinov2_input_rgb)['x_norm_patchtokens']

            # Text embedding
            text_embed = self.empty_text_embed.to(device).repeat(
                (batch_size, 1, 1)
            )  # [B, 77, 1024]

            # Single forward pass of the UNet (fixed timestep argument 1)
            unet_out = self.model.unet(
                rgb_latent, 1, text_embed
            )  # [B, 4, h, w]

            feat_16 = unet_out.feat_64
            rgb_latent = unet_out.sample

            # The mask is guaranteed to exist here (checked above), so only
            # the masked loss is needed.
            loss = self.loss(
                rgb_latent[valid_mask_down].float(),
                gt_depth_latent[valid_mask_down].float(),
            ).mean()

            self.train_metrics.update("loss", loss.item())

            # feat align loss
            b, c, h, w = feat_16.shape
            _, _, H, W = rgb_latent.shape
            # update dinov2_adapter
            unet_16_feat_aligned = self.dinov2_adapter(
                feat_16.permute(0, 2, 3, 1).reshape(batch_size, -1, c)
            )
            if torch.isnan(rgb_latent).any():
                logging.warning("model_pred contains NaN.")

            # Bring the DINOv2 tokens onto the UNet feature grid.
            dinov2_z = dinov2_z.reshape(b, H // 2, W // 2, -1).permute(0, 3, 1, 2)
            dinov2_z = (
                F.interpolate(dinov2_z, size=(h, w), mode='bicubic')
                .permute(0, 2, 3, 1)
                .reshape(b, h * w, -1)
            )

            # kl loss: F.kl_div expects log-probabilities as its input;
            # log_softmax equals softmax(...).log() but cannot produce -inf
            # when a probability underflows to zero.
            unet_16_log_prob = F.log_softmax(unet_16_feat_aligned, dim=-1)
            dinov2_prob = F.softmax(dinov2_z, dim=-1)

            loss_feat_align = F.kl_div(unet_16_log_prob, dinov2_prob)

            # Track a plain float so the metric tracker does not hold on to
            # the autograd graph (mirrors how "loss" is tracked above).
            self.train_metrics.update("feat_align_loss", loss_feat_align.item())

            loss = loss + self.cfg.loss_feat_align.lamda * loss_feat_align

            loss = loss / self.gradient_accumulation_steps
            loss.backward()
            accumulated_step += 1

            self.n_batch_in_epoch += 1
            # Practical batch end

            # Perform optimization step
            if accumulated_step >= self.gradient_accumulation_steps:
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                accumulated_step = 0

                self.effective_iter += 1
                progress_bar.update(1)

                # Log to tensorboard
                metric_results = self.train_metrics.result()
                logs = {"loss": metric_results["loss"]}
                progress_bar.set_postfix(**logs)
                tb_logger.log_dic(
                    {f"train/{k}": v for k, v in metric_results.items()},
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "lr",
                    self.lr_scheduler.get_last_lr()[0],
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "n_batch_in_epoch",
                    self.n_batch_in_epoch,
                    global_step=self.effective_iter,
                )
                self.train_metrics.reset()

                # Per-step callback
                self._train_step_callback()

                # End of training
                if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(),
                        save_train_state=False,
                    )
                    logging.info("Training ended.")
                    return
                # Time's up
                elif t_end is not None and datetime.now() >= t_end:
                    self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                    logging.info("Time is up, training paused.")
                    return

                torch.cuda.empty_cache()
                # <<< Effective batch end <<<

        # Epoch end
        self.n_batch_in_epoch = 0
|
| 338 |
+
|
| 339 |
+
def encode_depth(self, depth_in):
    """Encode a depth map into latent space.

    The input is first expanded to three channels with
    ``self.stack_depth_images`` and the result is passed through
    ``self.model.encode_rgb``.

    Args:
        depth_in: Depth input accepted by ``stack_depth_images``.

    Returns:
        Whatever ``self.model.encode_rgb`` returns for the stacked depth.
    """
    three_channel_depth = self.stack_depth_images(depth_in)
    return self.model.encode_rgb(three_channel_depth)
|
| 345 |
+
|
| 346 |
+
@staticmethod
def stack_depth_images(depth_in):
    """Replicate a single-channel depth map into three channels.

    Args:
        depth_in: Tensor with 4 dimensions (the channel axis at index 1 is
            tiled three times) or 3 dimensions (a channel axis is inserted
            at index 1 before tiling).

    Returns:
        A 4-dimensional tensor whose axis 1 is three times the input's
        channel axis.

    Raises:
        ValueError: If ``depth_in`` is neither 3- nor 4-dimensional.
    """
    n_dims = len(depth_in.shape)
    if 3 == n_dims:
        # Insert the channel axis first; repeating the un-expanded tensor
        # would tile the batch axis instead of creating three channels.
        depth_in = depth_in.unsqueeze(1)
    elif 4 != n_dims:
        raise ValueError(
            f"Expected a 3- or 4-dimensional depth tensor, got {n_dims} dimensions."
        )
    return depth_in.repeat(1, 3, 1, 1)
|
| 354 |
+
|
| 355 |
+
def _train_step_callback(self):
|
| 356 |
+
"""Executed after every iteration"""
|
| 357 |
+
# Save backup (with a larger interval, without training states)
|
| 358 |
+
if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
|
| 359 |
+
self.save_checkpoint(
|
| 360 |
+
ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
_is_latest_saved = False
|
| 364 |
+
# Validation
|
| 365 |
+
if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
|
| 366 |
+
self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished
|
| 367 |
+
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
|
| 368 |
+
_is_latest_saved = True
|
| 369 |
+
self.validate()
|
| 370 |
+
self.in_evaluation = False
|
| 371 |
+
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
|
| 372 |
+
|
| 373 |
+
# Save training checkpoint (can be resumed)
|
| 374 |
+
if (
|
| 375 |
+
self.save_period > 0
|
| 376 |
+
and 0 == self.effective_iter % self.save_period
|
| 377 |
+
and not _is_latest_saved
|
| 378 |
+
):
|
| 379 |
+
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
|
| 380 |
+
|
| 381 |
+
# Visualization
|
| 382 |
+
if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
|
| 383 |
+
self.visualize()
|
| 384 |
+
|
| 385 |
+
def validate(self):
    """Evaluate on every validation loader and track the best main metric.

    For each loader the metrics are logged, sent to the tensorboard logger
    and written to a text file under ``self.out_dir_eval``. Only the first
    loader's main metric is compared with ``self.best_metric``; when it
    improves, a backup checkpoint without training state is saved.
    """
    for loader_idx, val_loader in enumerate(self.val_loaders):
        val_dataset_name = val_loader.dataset.disp_name
        val_metric_dic = self.validate_single_dataset(
            data_loader=val_loader, metric_tracker=self.val_metrics
        )
        logging.info(
            f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
        )

        tb_scalars = {
            f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()
        }
        tb_logger.log_dic(tb_scalars, global_step=self.effective_iter)

        # Persist the metrics as a text report
        eval_text = eval_dic_to_text(
            val_metrics=val_metric_dic,
            dataset_name=val_dataset_name,
            sample_list_path=val_loader.dataset.filename_ls_path,
        )
        report_name = f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt"
        with open(os.path.join(self.out_dir_eval, report_name), "w+") as f:
            f.write(eval_text)

        # Only the first validation set drives best-checkpoint selection
        if 0 != loader_idx:
            continue

        main_eval_metric = val_metric_dic[self.main_val_metric]
        goal = self.main_val_metric_goal
        improved = (
            "minimize" == goal and main_eval_metric < self.best_metric
        ) or (
            "maximize" == goal and main_eval_metric > self.best_metric
        )
        if improved:
            self.best_metric = main_eval_metric
            logging.info(
                f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
            )
            # Save a checkpoint
            self.save_checkpoint(
                ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
            )
|
| 428 |
+
|
| 429 |
+
def visualize(self):
    """Run every visualization loader and save its predictions.

    Output goes to ``<out_dir_vis>/<backup ckpt name>/<dataset name>``,
    which is created if missing. The metrics returned by
    ``validate_single_dataset`` are discarded.
    """
    for vis_loader in self.vis_loaders:
        dataset_name = vis_loader.dataset.disp_name
        target_dir = os.path.join(
            self.out_dir_vis, self._get_backup_ckpt_name(), dataset_name
        )
        os.makedirs(target_dir, exist_ok=True)
        self.validate_single_dataset(
            data_loader=vis_loader,
            metric_tracker=self.val_metrics,
            save_to_dir=target_dir,
        )
|
| 441 |
+
|
| 442 |
+
@torch.no_grad()
def validate_single_dataset(
    self,
    data_loader: DataLoader,
    metric_tracker: MetricTracker,
    save_to_dir: str = None,
):
    """Evaluate the pipeline on one data loader.

    Each sample is predicted with ``self.model``, aligned to the ground
    truth according to ``self.cfg.eval.alignment``, clipped to the dataset
    depth range and scored with every function in ``self.metric_funcs``.

    Args:
        data_loader: Loader with ``batch_size == 1`` (asserted per batch).
        metric_tracker: Tracker that is reset first and then updated with
            every per-sample metric.
        save_to_dir: Optional directory; when given, the raw prediction
            ``pipe_out.depth_np`` is saved there as a 16-bit PNG.

    Returns:
        ``metric_tracker.result()`` after the whole loader was processed.

    Raises:
        RuntimeError: If ``self.cfg.eval.alignment`` is not a known type.
    """
    self.model.to(self.device)
    metric_tracker.reset()

    # Generate seed sequence for consistent evaluation
    # NOTE(review): the sequence is not consumed anywhere in this method.
    val_init_seed = self.cfg.validation.init_seed
    val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))

    for batch in tqdm(
        data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"
    ):
        assert 1 == data_loader.batch_size
        # Read input image
        rgb_int = batch["rgb_int"]  # [3, H, W]
        # GT depth
        depth_raw_ts = batch["depth_raw_linear"].squeeze()
        depth_raw = depth_raw_ts.numpy()
        depth_raw_ts = depth_raw_ts.to(self.device)
        valid_mask_ts = batch["valid_mask_raw"].squeeze()
        valid_mask = valid_mask_ts.numpy()
        valid_mask_ts = valid_mask_ts.to(self.device)

        # Predict depth
        pipe_out: DepthMasterDepthOutput = self.model(
            rgb_int,
            processing_res=self.cfg.validation.processing_res,
            match_input_res=self.cfg.validation.match_input_res,
            batch_size=1,  # use batch size 1 to increase reproducibility
            color_map=None,
            show_progress_bar=False,
            resample_method=self.cfg.validation.resample_method,
        )

        depth_pred: np.ndarray = pipe_out.depth_np.squeeze()

        alignment = self.cfg.eval.alignment
        if "least_square" == alignment:
            depth_pred, _, _ = align_depth_least_square(
                gt_arr=depth_raw,
                pred_arr=depth_pred,
                valid_mask_arr=valid_mask,
                return_scale_shift=True,
                max_resolution=self.cfg.eval.align_max_res,
            )
        elif "least_square_disparity" == alignment:
            # In this mode the loaded ground truth is treated as disparity.
            gt_disparity = depth_raw
            valid_nonnegative_mask = (
                valid_mask & (gt_disparity > 0) & (depth_pred > 0)
            )

            # LS alignment in disparity space
            disparity_pred, _, _ = align_depth_least_square(
                gt_arr=gt_disparity,
                pred_arr=depth_pred,
                valid_mask_arr=valid_nonnegative_mask,
                return_scale_shift=True,
            )
            # convert to depth
            disparity_pred = np.clip(
                disparity_pred, a_min=1e-3, a_max=None
            )  # avoid 0 disparity
            depth_pred = disparity2depth(disparity_pred)
            depth_raw_ts = disparity2depth(depth_raw_ts)
        elif "least_square_sqrt_disp" == alignment:
            # In this mode the loaded ground truth is treated as sqrt-disparity.
            gt_sqrt_disp = depth_raw
            valid_nonnegative_mask = (
                valid_mask & (gt_sqrt_disp > 0) & (depth_pred > 0)
            )

            # LS alignment in sqrt space, restricted to positive pixels like
            # the disparity branch (the mask used to be computed but ignored).
            depth_sqrt_disp_pred, _, _ = align_depth_least_square(
                gt_arr=gt_sqrt_disp,
                pred_arr=depth_pred,
                valid_mask_arr=valid_nonnegative_mask,
                return_scale_shift=True,
            )
            # sqrt-disparity -> disparity
            disparity_pred = depth_sqrt_disp_pred ** 2
            depth_raw_ts = torch.pow(depth_raw_ts, 2)
            # disparity -> depth
            disparity_pred = np.clip(
                disparity_pred, a_min=1e-3, a_max=None
            )  # avoid 0 disparity
            depth_pred = disparity2depth(disparity_pred)
            depth_raw_ts = disparity2depth(depth_raw_ts)
        else:
            raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")

        # Clip to dataset min max
        depth_pred = np.clip(
            depth_pred,
            a_min=data_loader.dataset.min_depth,
            a_max=data_loader.dataset.max_depth,
        )

        # clip to d > 0 for evaluation
        depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)

        # Evaluate
        depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)

        for met_func in self.metric_funcs:
            _metric_name = met_func.__name__
            _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
            metric_tracker.update(_metric_name, _metric)

        # Save as 16-bit uint png
        if save_to_dir is not None:
            img_name = batch["rgb_relative_path"][0].replace("/", "_")
            png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
            depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
            Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

    return metric_tracker.result()
|
| 565 |
+
|
| 566 |
+
def _get_next_seed(self):
|
| 567 |
+
if 0 == len(self.global_seed_sequence):
|
| 568 |
+
self.global_seed_sequence = generate_seed_sequence(
|
| 569 |
+
initial_seed=self.seed,
|
| 570 |
+
length=self.max_iter * self.gradient_accumulation_steps,
|
| 571 |
+
)
|
| 572 |
+
logging.info(
|
| 573 |
+
f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
|
| 574 |
+
)
|
| 575 |
+
return self.global_seed_sequence.pop()
|
| 576 |
+
|
| 577 |
+
def save_checkpoint(self, ckpt_name, save_train_state):
    """Write a checkpoint directory under ``self.out_dir_ckpt``.

    An existing directory of the same name is renamed to ``_old_<name>``
    while the new checkpoint is written and removed afterwards.

    Args:
        ckpt_name: Name of the checkpoint directory.
        save_train_state: When true, also store the optimizer, LR scheduler
            and trainer bookkeeping in ``trainer.ckpt`` and create an empty
            marker file named after the current iteration.
    """
    ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
    logging.info(f"Saving checkpoint to: {ckpt_dir}")

    # Move a previous checkpoint out of the way
    temp_ckpt_dir = None
    if os.path.isdir(ckpt_dir):
        temp_ckpt_dir = os.path.join(
            os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
        )
        if os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        os.rename(ckpt_dir, temp_ckpt_dir)
        logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

    # UNet weights
    unet_path = os.path.join(ckpt_dir, "unet")
    self.model.unet.save_pretrained(unet_path, safe_serialization=False)
    logging.info(f"UNet is saved to: {unet_path}")

    # DINOv2 adapter weights
    adapter_path = os.path.join(ckpt_dir, "dinov2_adapter.pth")
    torch.save(self.dinov2_adapter.state_dict(), adapter_path)
    logging.info(f"dinov2_adapter is saved to: {adapter_path}")

    if save_train_state:
        train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
        torch.save(
            {
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
                "config": self.cfg,
                "effective_iter": self.effective_iter,
                "epoch": self.epoch,
                "n_batch_in_epoch": self.n_batch_in_epoch,
                "best_metric": self.best_metric,
                "in_evaluation": self.in_evaluation,
                "global_seed_sequence": self.global_seed_sequence,
            },
            train_state_path,
        )
        # Empty marker file whose name records the iteration
        with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
            pass

        logging.info(f"Trainer state is saved to: {train_state_path}")

    # Drop the temporary copy of the previous checkpoint
    if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
        shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        logging.debug("Old checkpoint backup is removed.")
|
| 626 |
+
|
| 627 |
+
def load_checkpoint(
    self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
):
    """Restore model weights and, optionally, the trainer state.

    Args:
        ckpt_path: Checkpoint directory containing
            ``unet/diffusion_pytorch_model.bin`` and ``dinov2_adapter.pth``
            (plus ``trainer.ckpt`` when the trainer state is loaded).
        load_trainer_state: When true, restore the iteration counters, the
            evaluation flag, the seed sequence, the best metric and the
            optimizer state from ``trainer.ckpt``.
        resume_lr_scheduler: When true (and the trainer state is loaded),
            also restore the LR scheduler state.
    """
    logging.info(f"Loading checkpoint from: {ckpt_path}")

    # UNet weights
    _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
    unet_weights = torch.load(_model_path, map_location=self.device)
    self.model.unet.load_state_dict(unet_weights)
    self.model.unet.to(self.device)
    logging.info(f"UNet parameters are loaded from {_model_path}")

    # DINOv2 adapter weights
    _model_path = os.path.join(ckpt_path, "dinov2_adapter.pth")
    adapter_weights = torch.load(_model_path, map_location=self.device)
    self.dinov2_adapter.load_state_dict(adapter_weights)
    self.dinov2_adapter.to(self.device)
    logging.info(f"dinov2_adapter parameters are loaded from {_model_path}")

    # Trainer bookkeeping, optimizer and (optionally) LR scheduler
    if load_trainer_state:
        checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
        restored_fields = (
            "effective_iter",
            "epoch",
            "n_batch_in_epoch",
            "in_evaluation",
            "global_seed_sequence",
            "best_metric",
        )
        for field in restored_fields:
            setattr(self, field, checkpoint[field])

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        logging.info(f"optimizer state is loaded from {ckpt_path}")

        if resume_lr_scheduler:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            logging.info(f"LR scheduler state is loaded from {ckpt_path}")

    logging.info(
        f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
    )
|
| 669 |
+
|
| 670 |
+
def _get_backup_ckpt_name(self):
|
| 671 |
+
return f"iter_{self.effective_iter:06d}"
|
DepthMaster/src/trainer/trainer_s2.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# An official reimplemented version of Marigold training script.
|
| 2 |
+
# Last modified: 2024-04-29
|
| 3 |
+
#
|
| 4 |
+
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
# --------------------------------------------------------------------------
|
| 18 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 19 |
+
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
| 20 |
+
# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
|
| 21 |
+
# More information about the method can be found at https://marigoldmonodepth.github.io
|
| 22 |
+
# --------------------------------------------------------------------------
|
| 23 |
+
import logging
|
| 24 |
+
import os
|
| 25 |
+
import shutil
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
from typing import List, Union
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import torch
|
| 31 |
+
# from diffusers import DDPMScheduler
|
| 32 |
+
from omegaconf import OmegaConf
|
| 33 |
+
# from torch.nn import Conv2d
|
| 34 |
+
# from torch.nn.parameter import Parameter
|
| 35 |
+
from torch.optim import Adam
|
| 36 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 37 |
+
from torch.utils.data import DataLoader
|
| 38 |
+
from tqdm import tqdm
|
| 39 |
+
from PIL import Image
|
| 40 |
+
|
| 41 |
+
from depthmaster.depthmaster_pipeline import DepthMasterPipeline, DepthMasterDepthOutput
|
| 42 |
+
from src.util import metric
|
| 43 |
+
from src.util.data_loader import skip_first_batches
|
| 44 |
+
from src.util.logging_util import tb_logger, eval_dic_to_text
|
| 45 |
+
from src.util.loss import get_loss
|
| 46 |
+
from src.util.lr_scheduler import IterExponential
|
| 47 |
+
from src.util.metric import MetricTracker
|
| 48 |
+
from src.util.alignment import (
|
| 49 |
+
align_depth_least_square,
|
| 50 |
+
depth2disparity,
|
| 51 |
+
disparity2depth,
|
| 52 |
+
align_depth_least_square_torch_mask,
|
| 53 |
+
align_depth_medium_mask
|
| 54 |
+
)
|
| 55 |
+
# from src.util.alignment import align_depth_least_square
|
| 56 |
+
# from src.util.alignment import align_depth_least_square
|
| 57 |
+
from src.util.seeding import generate_seed_sequence
|
| 58 |
+
import torch.nn.functional as F
|
| 59 |
+
|
| 60 |
+
class DepthMasterTrainerS2:
|
| 61 |
+
def __init__(
    self,
    cfg: OmegaConf,
    model: DepthMasterPipeline,
    train_dataloader: DataLoader,
    device,
    base_ckpt_dir,
    out_dir_ckpt,
    out_dir_eval,
    out_dir_vis,
    accumulation_steps: int,
    val_dataloaders: List[DataLoader] = None,
    vis_dataloaders: List[DataLoader] = None,
):
    """Set up the stage-2 trainer.

    Args:
        cfg: Training configuration.
        model: Pipeline whose UNet is trained; the VAE and the text encoder
            are frozen here.
        train_dataloader: Loader for the training batches.
        device: Device the empty text embedding is moved to and that is
            stored for later use.
        base_ckpt_dir: Accepted but not used in this constructor.
        out_dir_ckpt: Output directory for checkpoints.
        out_dir_eval: Output directory for evaluation reports.
        out_dir_vis: Output directory for visualizations.
        accumulation_steps: Number of batches accumulated per optimizer step.
        val_dataloaders: Optional loaders used for validation.
        vis_dataloaders: Optional loaders used for visualization.

    Raises:
        AssertionError: If the main validation metric is not listed in
            ``cfg.eval.eval_metrics``.
    """
    self.cfg: OmegaConf = cfg
    self.model: DepthMasterPipeline = model
    self.device = device
    # used to generate seed sequence, set to `None` to train w/o seeding
    self.seed: Union[int, None] = self.cfg.trainer.init_seed
    self.out_dir_ckpt = out_dir_ckpt
    self.out_dir_eval = out_dir_eval
    self.out_dir_vis = out_dir_vis
    self.train_loader: DataLoader = train_dataloader
    self.val_loaders: List[DataLoader] = val_dataloaders
    self.vis_loaders: List[DataLoader] = vis_dataloaders
    self.accumulation_steps: int = accumulation_steps

    # Encode the empty text prompt once and keep a detached copy
    self.model.encode_empty_text()
    empty_embed = self.model.empty_text_embed
    self.empty_text_embed = empty_embed.detach().clone().to(device)

    self.model.unet.enable_xformers_memory_efficient_attention()

    # Trainability: only the UNet receives gradients
    self.model.vae.requires_grad_(False)
    self.model.vae.decoder.requires_grad_(False)
    self.model.text_encoder.requires_grad_(False)
    self.model.unet.requires_grad_(True)

    # Optimizer !should be defined after input layer is adapted
    base_lr = self.cfg.lr
    self.optimizer = Adam(self.model.unet.parameters(), lr=base_lr)

    # LR scheduler
    scheduler_kwargs = self.cfg.lr_scheduler.kwargs
    lr_lambda = IterExponential(
        total_iter_length=scheduler_kwargs.total_iter,
        final_ratio=scheduler_kwargs.final_ratio,
        warmup_steps=scheduler_kwargs.warmup_steps,
    )
    self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_lambda)

    # Losses
    self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
    self.grad_loss = get_loss(
        loss_name=self.cfg.grad_loss.name, **self.cfg.grad_loss.kwargs
    )

    # Eval metrics
    self.metric_funcs = [
        getattr(metric, metric_name) for metric_name in cfg.eval.eval_metrics
    ]
    self.train_metrics = MetricTracker("loss", "grad_loss")
    self.val_metrics = MetricTracker(*[fn.__name__ for fn in self.metric_funcs])
    # main metric for best checkpoint saving
    self.main_val_metric = cfg.validation.main_val_metric
    self.main_val_metric_goal = cfg.validation.main_val_metric_goal
    assert (
        self.main_val_metric in cfg.eval.eval_metrics
    ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
    # Start from the worst possible value for the chosen goal
    if "minimize" == self.main_val_metric_goal:
        self.best_metric = 1e8
    else:
        self.best_metric = -1e8

    # Settings
    self.max_epoch = self.cfg.max_epoch
    self.max_iter = self.cfg.max_iter
    self.gradient_accumulation_steps = accumulation_steps
    self.gt_depth_type = self.cfg.gt_depth_type
    self.gt_mask_type = self.cfg.gt_mask_type
    trainer_cfg = self.cfg.trainer
    self.save_period = trainer_cfg.save_period
    self.backup_period = trainer_cfg.backup_period
    self.val_period = trainer_cfg.validation_period
    self.vis_period = trainer_cfg.visualization_period

    # Internal variables
    self.epoch = 1
    self.n_batch_in_epoch = 0  # batch index in the epoch, used when resume training
    self.effective_iter = 0  # how many times optimizer.step() is called
    self.in_evaluation = False
    # consistent global seed sequence, used to seed random generator,
    # to ensure consistency when resuming
    self.global_seed_sequence: List = []
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def grad(self, x):
    """Stack four first-order finite differences of ``x`` along dim 1.

    The differences are taken between neighbouring entries of the last two
    dimensions: horizontal, vertical, anti-diagonal and main-diagonal. Each
    difference map is one element shorter than ``x`` in both of those
    dimensions.

    Args:
        x: Tensor indexed as ``[..., h, w]`` (the original comment states
            ``n, c, h, w``).

    Returns:
        The four difference maps concatenated along ``dim=1``.
    """
    # x.shape : n, c, h, w
    lower_right = x[..., 1:, 1:]
    lower_left = x[..., 1:, :-1]
    upper_right = x[..., :-1, 1:]
    upper_left = x[..., :-1, :-1]

    directional_diffs = [
        lower_right - lower_left,  # horizontal step
        lower_right - upper_right,  # vertical step
        upper_right - lower_left,  # anti-diagonal step
        lower_right - upper_left,  # main-diagonal step
    ]
    return torch.cat(directional_diffs, dim=1)
|
| 164 |
+
|
| 165 |
+
def train(self, t_end=None):
    """Run the training loop until ``max_iter`` optimizer steps or ``t_end``.

    Before the loop the model is moved to ``self.device``, ``self.visualize()``
    is called once, and an unfinished validation (``self.in_evaluation``) is
    re-run. For every batch the RGB input is encoded without gradients, passed
    once through the UNet, decoded to a depth prediction and compared with the
    ground truth: ``self.loss`` on the valid pixels plus
    ``cfg.grad_loss.lamda * self.grad_loss`` on the ``self.grad`` features.
    Gradients are accumulated over ``self.gradient_accumulation_steps``
    batches before each optimizer / LR-scheduler step.

    Args:
        t_end: Optional deadline compared against ``datetime.now()``. When it
            is reached after an optimizer step, a resumable ``"latest"``
            checkpoint is saved and the method returns.

    Raises:
        NotImplementedError: If ``self.gt_mask_type`` is ``None``.
    """
    logging.info("Start training")

    device = self.device
    self.model.to(device)

    # One visualization pass before any training step.
    self.visualize()

    if self.in_evaluation:
        logging.info(
            "Last evaluation was not finished, will do evaluation before continue training."
        )
        self.validate()

    self.train_metrics.reset()
    # Number of batches back-propagated since the last optimizer step.
    accumulated_step = 0

    progress_bar = tqdm(
        range(0, self.max_iter),
        initial=self.effective_iter,
        desc="iter"
    )

    for epoch in range(self.epoch, self.max_epoch + 1):
        self.epoch = epoch
        logging.debug(f"epoch: {self.epoch}")

        # Skip previous batches when resume
        for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
            self.model.unet.train()

            # >>> With gradient accumulation >>>

            # Get data
            rgb = batch["rgb_norm"].to(device)
            depth_gt_for_latent = batch[self.gt_depth_type].to(device)

            if self.gt_mask_type is not None:
                valid_mask_for_latent = batch[self.gt_mask_type].to(device)
            else:
                raise NotImplementedError

            batch_size = rgb.shape[0]

            with torch.no_grad():
                # Encode image
                rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]

            # Text embedding
            text_embed = self.empty_text_embed.to(device).repeat(
                (batch_size, 1, 1)
            )  # [B, 77, 1024]

            # Single UNet forward pass; the second positional argument is the
            # constant 1 (presumably the timestep -- confirm against the UNet signature).
            rgb_latent = self.model.unet(
                rgb_latent, 1, text_embed
            ).sample  # [B, 4, h, w]

            depth_pred = self.model.decode_depth(rgb_latent)
            # Plain aliases: no copy is made, so the in-place masking below
            # also writes into `depth_gt_for_latent` / `depth_pred`.
            depth_gt_for_loss = depth_gt_for_latent

            aligned_pred = depth_pred


            if self.gt_mask_type is not None:
                loss = self.loss(aligned_pred[valid_mask_for_latent].float(), depth_gt_for_loss[valid_mask_for_latent].float()).mean()
            else:
                # NOTE(review): unreachable -- a None mask type already raised
                # NotImplementedError above.
                loss = self.loss(aligned_pred.float(), depth_gt_for_loss.float()).mean()

            self.train_metrics.update("loss", loss.item())

            # grad loss
            # Invalid pixels are zeroed in both maps before differencing.
            # NOTE(review): both assignments are in place; when `.to(device)`
            # returns the same tensor this also modifies the batch's GT.
            depth_gt_for_loss[~valid_mask_for_latent] = 0
            grad_gt = self.grad(depth_gt_for_loss)
            aligned_pred[~valid_mask_for_latent] = 0
            grad_pred = self.grad(aligned_pred)
            grad_loss = self.grad_loss(grad_gt, grad_pred)
            self.train_metrics.update(f"grad_loss", grad_loss.item())
            loss += self.cfg.grad_loss.lamda * grad_loss


            # Scale so the accumulated gradient is the sum of per-batch losses
            # divided by the number of accumulation steps.
            loss = loss / self.gradient_accumulation_steps
            loss.backward()
            accumulated_step += 1

            self.n_batch_in_epoch += 1
            # Practical batch end

            # Perform optimization step
            if accumulated_step >= self.gradient_accumulation_steps:
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                accumulated_step = 0

                self.effective_iter += 1
                progress_bar.update(1)

                # Log to tensorboard
                accumulated_loss = self.train_metrics.result()["loss"]
                logs = {"loss": accumulated_loss}
                progress_bar.set_postfix(**logs)
                tb_logger.log_dic(
                    {
                        f"train/{k}": v
                        for k, v in self.train_metrics.result().items()
                    },
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "lr",
                    self.lr_scheduler.get_last_lr()[0],
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "n_batch_in_epoch",
                    self.n_batch_in_epoch,
                    global_step=self.effective_iter,
                )
                # Metrics are reported per optimizer step, so start fresh.
                self.train_metrics.reset()

                # Per-step callback
                self._train_step_callback()

                # End of training
                if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(),
                        save_train_state=False,
                    )
                    logging.info("Training ended.")
                    return
                # Time's up
                elif t_end is not None and datetime.now() >= t_end:
                    self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                    logging.info("Time is up, training paused.")
                    return

                torch.cuda.empty_cache()
                # <<< Effective batch end <<<

        # Epoch end
        # The next epoch starts from its first batch (nothing to skip).
        self.n_batch_in_epoch = 0
|
| 307 |
+
|
| 308 |
+
def encode_depth(self, depth_in):
    """Encode a depth map with the model's RGB encoder.

    The depth input is first expanded by ``self.stack_depth_images`` and the
    result is passed to ``self.model.encode_rgb``.

    Args:
        depth_in: Depth input accepted by ``stack_depth_images``.

    Returns:
        Whatever ``self.model.encode_rgb`` returns for the stacked depth.
    """
    # stack depth into 3-channel, then encode using VAE encoder
    return self.model.encode_rgb(self.stack_depth_images(depth_in))
|
| 314 |
+
|
| 315 |
+
@staticmethod
def stack_depth_images(depth_in):
    """Replicate a depth tensor into three identical channels.

    Args:
        depth_in: Tensor of shape ``[B, C, H, W]`` (channel dimension is
            tiled three times) or ``[B, H, W]`` (a channel dimension is
            inserted first).

    Returns:
        Tensor of shape ``[B, 3 * C, H, W]`` for 4-D input, or
        ``[B, 3, H, W]`` for 3-D input.

    Raises:
        ValueError: If ``depth_in`` is neither 3-D nor 4-D.
    """
    n_dims = len(depth_in.shape)
    if 4 == n_dims:
        stacked = depth_in
    elif 3 == n_dims:
        # Insert the channel axis and keep it: the tiling below must operate
        # on the unsqueezed tensor, not on the original 3-D input.
        stacked = depth_in.unsqueeze(1)
    else:
        raise ValueError(
            f"Expected a 3D or 4D depth tensor, got shape {tuple(depth_in.shape)}"
        )
    return stacked.repeat(1, 3, 1, 1)
|
| 323 |
+
|
| 324 |
+
def _train_step_callback(self):
|
| 325 |
+
"""Executed after every iteration"""
|
| 326 |
+
# Save backup (with a larger interval, without training states)
|
| 327 |
+
if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
|
| 328 |
+
self.save_checkpoint(
|
| 329 |
+
ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
_is_latest_saved = False
|
| 333 |
+
# Validation
|
| 334 |
+
if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
|
| 335 |
+
self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished
|
| 336 |
+
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
|
| 337 |
+
_is_latest_saved = True
|
| 338 |
+
self.validate()
|
| 339 |
+
self.in_evaluation = False
|
| 340 |
+
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
|
| 341 |
+
|
| 342 |
+
# Save training checkpoint (can be resumed)
|
| 343 |
+
if (
|
| 344 |
+
self.save_period > 0
|
| 345 |
+
and 0 == self.effective_iter % self.save_period
|
| 346 |
+
and not _is_latest_saved
|
| 347 |
+
):
|
| 348 |
+
self.save_checkpoint(ckpt_name="latest", save_train_state=True)
|
| 349 |
+
|
| 350 |
+
# Visualization
|
| 351 |
+
if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
|
| 352 |
+
self.visualize()
|
| 353 |
+
|
| 354 |
+
def validate(self):
    """Evaluate on every validation loader and track the best main metric.

    For each loader the metrics are logged, sent to tensorboard and written
    to ``eval-<dataset>-iter<NNNNNN>.txt`` inside ``self.out_dir_eval``. Only
    the first loader decides whether ``self.best_metric`` improved; when it
    did, a weights-only backup checkpoint is saved.
    """
    for loader_idx, val_loader in enumerate(self.val_loaders):
        dataset = val_loader.dataset
        val_dataset_name = dataset.disp_name
        val_metric_dic = self.validate_single_dataset(
            data_loader=val_loader, metric_tracker=self.val_metrics
        )
        logging.info(
            f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
        )

        tb_scalars = {}
        for k, v in val_metric_dic.items():
            tb_scalars[f"val/{val_dataset_name}/{k}"] = v
        tb_logger.log_dic(tb_scalars, global_step=self.effective_iter)

        # save to file
        eval_text = eval_dic_to_text(
            val_metrics=val_metric_dic,
            dataset_name=val_dataset_name,
            sample_list_path=dataset.filename_ls_path,
        )
        report_name = f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt"
        with open(os.path.join(self.out_dir_eval, report_name), "w+") as report:
            report.write(eval_text)

        # Update main eval metric (first validation set only)
        if 0 != loader_idx:
            continue

        main_eval_metric = val_metric_dic[self.main_val_metric]
        goal = self.main_val_metric_goal
        lower_is_better = "minimize" == goal and main_eval_metric < self.best_metric
        higher_is_better = "maximize" == goal and main_eval_metric > self.best_metric
        if not (lower_is_better or higher_is_better):
            continue

        self.best_metric = main_eval_metric
        logging.info(
            f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
        )
        # Save a checkpoint
        self.save_checkpoint(
            ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
        )
|
| 397 |
+
|
| 398 |
+
def visualize(self):
    """Run every visualization loader and save its predictions to disk.

    Output goes to ``<out_dir_vis>/<backup ckpt name>/<dataset disp_name>``;
    the directory is created if needed and handed to
    ``validate_single_dataset`` as ``save_to_dir``. The returned metrics are
    discarded.
    """
    for vis_loader in self.vis_loaders:
        target_dir = os.path.join(
            self.out_dir_vis,
            self._get_backup_ckpt_name(),
            vis_loader.dataset.disp_name,
        )
        os.makedirs(target_dir, exist_ok=True)
        self.validate_single_dataset(
            data_loader=vis_loader,
            metric_tracker=self.val_metrics,
            save_to_dir=target_dir,
        )
|
| 410 |
+
|
| 411 |
+
@torch.no_grad()
def validate_single_dataset(
    self,
    data_loader: DataLoader,
    metric_tracker: MetricTracker,
    save_to_dir: str = None,
):
    """Evaluate the model on one data loader and return the averaged metrics.

    For every batch (the loader must use batch size 1) the pipeline is run on
    ``batch["rgb_int"]``, the prediction is aligned to the ground truth
    according to ``cfg.eval.alignment``, clipped to the dataset depth range,
    and scored with every function in ``self.metric_funcs``.

    Args:
        data_loader: Loader whose dataset exposes ``disp_name``, ``min_depth``
            and ``max_depth``.
        metric_tracker: Tracker that is reset at the start and updated with
            every per-sample metric.
        save_to_dir: If given, the raw (un-aligned) prediction of each sample
            is written there as a 16-bit PNG.

    Returns:
        ``metric_tracker.result()`` after all batches were processed.

    Raises:
        RuntimeError: If ``cfg.eval.alignment`` is not one of the handled
            alignment names.
    """
    self.model.to(self.device)
    metric_tracker.reset()

    # Generate seed sequence for consistent evaluation
    # NOTE(review): `val_seed_ls` is not read anywhere else in this method.
    val_init_seed = self.cfg.validation.init_seed
    val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))

    for i, batch in enumerate(
        tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
        start=1,
    ):
        assert 1 == data_loader.batch_size
        # Read input image
        rgb_int = batch["rgb_int"]  # [3, H, W]
        # GT depth
        # Both a NumPy copy (for alignment) and a device tensor (for metrics) are kept.
        depth_raw_ts = batch["depth_raw_linear"].squeeze()
        depth_raw = depth_raw_ts.numpy()
        depth_raw_ts = depth_raw_ts.to(self.device)
        valid_mask_ts = batch["valid_mask_raw"].squeeze()
        valid_mask = valid_mask_ts.numpy()
        valid_mask_ts = valid_mask_ts.to(self.device)

        # Predict depth
        pipe_out: DepthMasterDepthOutput = self.model(
            rgb_int,
            processing_res=self.cfg.validation.processing_res,
            match_input_res=self.cfg.validation.match_input_res,
            batch_size=1,  # use batch size 1 to increase reproducibility
            color_map=None,
            show_progress_bar=False,
            resample_method=self.cfg.validation.resample_method,
        )

        depth_pred: np.ndarray = pipe_out.depth_np.squeeze()

        if "least_square" == self.cfg.eval.alignment:
            # Scale/shift alignment directly in depth space.
            depth_pred, scale, shift = align_depth_least_square(
                gt_arr=depth_raw,
                pred_arr=depth_pred,
                valid_mask_arr=valid_mask,
                return_scale_shift=True,
                max_resolution=self.cfg.eval.align_max_res,
            )
        elif "least_square_disparity" == self.cfg.eval.alignment:
            # gt_disparity = depth_raw
            gt_disparity = depth2disparity(depth_raw)
            gt_non_neg_mask = gt_disparity > 0

            # LS alignment in disparity space
            pred_non_neg_mask = depth_pred > 0
            valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask

            disparity_pred, scale, shift = align_depth_least_square(
                gt_arr=gt_disparity,
                pred_arr=depth_pred,
                valid_mask_arr=valid_nonnegative_mask,
                return_scale_shift=True,
            )
            # convert to depth
            disparity_pred = np.clip(
                disparity_pred, a_min=1e-3, a_max=None
            )  # avoid 0 disparity
            depth_pred = disparity2depth(disparity_pred)
            # NOTE(review): this also inverts the GT tensor the metrics are
            # computed against -- confirm that is intended for this dataset.
            depth_raw_ts = disparity2depth(depth_raw_ts)
        elif "least_square_sqrt_disp" == self.cfg.eval.alignment:
            # gt_sqrt_disp = depth_raw
            gt_sqrt_disp = np.sqrt(depth2disparity(depth_raw))
            gt_non_neg_mask = gt_sqrt_disp > 0

            # LS alignment in sqrt space
            pred_non_neg_mask = depth_pred > 0
            valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask
            # NOTE(review): `valid_nonnegative_mask` is computed above but the
            # plain `valid_mask` is passed here, unlike the disparity branch.
            depth_sqrt_disp_pred, scale, shift = align_depth_least_square(
                gt_arr=gt_sqrt_disp,
                pred_arr=depth_pred,
                valid_mask_arr=valid_mask,
                return_scale_shift=True,
            )
            # undo the square root: sqrt-disparity -> disparity
            disparity_pred = depth_sqrt_disp_pred ** 2
            # NOTE(review): the GT tensor is squared here and inverted below;
            # confirm this matches the representation the dataset provides.
            depth_raw_ts = torch.pow(depth_raw_ts, 2)
            # convert to depth
            disparity_pred = np.clip(
                disparity_pred, a_min=1e-3, a_max=None
            )  # avoid 0 disparity
            depth_pred = disparity2depth(disparity_pred)
            depth_raw_ts = disparity2depth(depth_raw_ts)
        else:
            raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")

        # Clip to dataset min max
        depth_pred = np.clip(
            depth_pred,
            a_min=data_loader.dataset.min_depth,
            a_max=data_loader.dataset.max_depth,
        )

        # clip to d > 0 for evaluation
        depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)

        # Evaluate
        # NOTE(review): `sample_metric` is filled but not read afterwards.
        sample_metric = []
        depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)

        for met_func in self.metric_funcs:
            _metric_name = met_func.__name__
            _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
            sample_metric.append(_metric.__str__())
            metric_tracker.update(_metric_name, _metric)

        # Save as 16-bit uint png
        if save_to_dir is not None:
            # Path separators are replaced so each image maps to one flat file name.
            img_name = batch["rgb_relative_path"][0].replace("/", "_")
            png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
            # The raw pipeline output is saved, not the aligned / clipped depth
            # (assumes `depth_np` lies in [0, 1] -- TODO confirm).
            depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
            Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

    return metric_tracker.result()
|
| 536 |
+
|
| 537 |
+
def _get_next_seed(self):
    """Pop and return the next seed from the global seed sequence.

    When the sequence is empty it is regenerated from ``self.seed`` with one
    entry per batch (``max_iter * gradient_accumulation_steps``) before the
    last element is popped.
    """
    sequence_exhausted = len(self.global_seed_sequence) == 0
    if sequence_exhausted:
        n_seeds = self.max_iter * self.gradient_accumulation_steps
        self.global_seed_sequence = generate_seed_sequence(
            initial_seed=self.seed,
            length=n_seeds,
        )
        logging.info(
            f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
        )
    next_seed = self.global_seed_sequence.pop()
    return next_seed
|
| 547 |
+
|
| 548 |
+
def save_checkpoint(self, ckpt_name, save_train_state):
    """Save the UNet (and optionally the trainer state) under ``ckpt_name``.

    An existing checkpoint directory of the same name is first renamed to
    ``_old_<name>`` and removed only after the new checkpoint was written.

    Args:
        ckpt_name: Directory name below ``self.out_dir_ckpt``.
        save_train_state: When true, also write ``trainer.ckpt`` (optimizer,
            LR scheduler, config and progress counters) plus an empty marker
            file named after the current iteration, so the run can be resumed.
    """
    ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
    logging.info(f"Saving checkpoint to: {ckpt_dir}")
    # Backup previous checkpoint
    temp_ckpt_dir = None
    if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
        temp_ckpt_dir = os.path.join(
            os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
        )
        # A stale backup from an earlier interrupted save would block the rename.
        if os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        os.rename(ckpt_dir, temp_ckpt_dir)
        logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

    # Save UNet
    unet_path = os.path.join(ckpt_dir, "unet")
    self.model.unet.save_pretrained(unet_path, safe_serialization=False)
    logging.info(f"UNet is saved to: {unet_path}")

    if save_train_state:
        state = {
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "config": self.cfg,
            "effective_iter": self.effective_iter,
            "epoch": self.epoch,
            "n_batch_in_epoch": self.n_batch_in_epoch,
            "best_metric": self.best_metric,
            "in_evaluation": self.in_evaluation,
            "global_seed_sequence": self.global_seed_sequence,
        }
        train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
        torch.save(state, train_state_path)
        # iteration indicator: an empty file whose name encodes the iteration.
        # The context manager guarantees the handle is closed.
        with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
            pass

        logging.info(f"Trainer state is saved to: {train_state_path}")

    # Remove temp ckpt
    if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
        shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        logging.debug("Old checkpoint backup is removed.")
|
| 592 |
+
|
| 593 |
+
def load_checkpoint(
    self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
):
    """Restore UNet weights and, optionally, the trainer state.

    Args:
        ckpt_path: Checkpoint directory containing
            ``unet/diffusion_pytorch_model.bin`` and, when the trainer state
            is requested, ``trainer.ckpt``.
        load_trainer_state: Also restore progress counters, best metric,
            seed sequence and the optimizer state.
        resume_lr_scheduler: Additionally restore the LR scheduler state
            (only consulted when ``load_trainer_state`` is true).
    """
    logging.info(f"Loading checkpoint from: {ckpt_path}")

    # Load UNet
    _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
    unet_weights = torch.load(_model_path, map_location=self.device)
    self.model.unet.load_state_dict(unet_weights)
    self.model.unet.to(self.device)
    logging.info(f"UNet parameters are loaded from {_model_path}")

    # Load training states
    if load_trainer_state:
        checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))

        # Plain attributes are copied over one by one, in a fixed order.
        restored_fields = (
            "effective_iter",
            "epoch",
            "n_batch_in_epoch",
            "in_evaluation",
            "global_seed_sequence",
            "best_metric",
        )
        for field_name in restored_fields:
            setattr(self, field_name, checkpoint[field_name])

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        logging.info(f"optimizer state is loaded from {ckpt_path}")

        if resume_lr_scheduler:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            logging.info(f"LR scheduler state is loaded from {ckpt_path}")

    logging.info(
        f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
    )
    return
|
| 628 |
+
|
| 629 |
+
def _get_backup_ckpt_name(self):
    """Return the checkpoint name for the current iteration, zero-padded to six digits."""
    iteration_tag = format(self.effective_iter, "06d")
    return "iter_" + iteration_tag
|
DepthMaster/src/util/alignment.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
def align_depth_medium_mask(
    gt: torch.Tensor,
    valid_mask: torch.Tensor,
    max_resolution=None,
):
    """Compute a per-sample median shift and mean-absolute-deviation scale.

    For every sample ``i`` the valid ground-truth values are selected with
    ``valid_mask[i]``; ``shift`` is their median and ``scale`` is the mean
    absolute deviation from that median.

    Args:
        gt: Ground-truth tensor whose first dimension is the batch and whose
            last two dimensions are spatial.
        valid_mask: Mask with the same shape as ``gt``; must be boolean when
            no downscaling takes place because it is used for indexing.
        max_resolution: If given and smaller than the spatial size, ``gt`` and
            the mask are downscaled with nearest-neighbour sampling first.

    Returns:
        ``(scale, shift)``, each of shape ``[B, 1, 1, 1]``.
    """
    ori_shape = gt.shape[-2:]  # input shape
    batch_size = gt.shape[0]

    # Downsample
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
            gt = downscaler(gt)
            # Resample the mask as float and convert back afterwards, the same
            # way align_depth_least_square does; a boolean mask is not passed
            # to the interpolation kernel directly.
            valid_mask = downscaler(valid_mask.float()).bool()

    scale_ls = []
    shift_ls = []

    for i in range(batch_size):
        gt_masked = gt[i][valid_mask[i]]
        shift = torch.median(gt_masked).unsqueeze(0)
        scale = torch.mean(torch.abs(gt_masked - shift)).unsqueeze(0)

        scale_ls.append(scale)
        shift_ls.append(shift)

    # [B] -> [B, 1, 1, 1] so the results broadcast against [B, C, H, W] maps.
    scale = torch.concat(scale_ls, dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    shift = torch.concat(shift_ls, dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    return scale, shift
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Align ``pred_arr`` to ``gt_arr`` with a least-squares scale and shift.

    Solves ``min ||scale * pred + shift - gt||`` over the pixels selected by
    ``valid_mask_arr`` and applies the solution to the full-resolution
    prediction.

    Args:
        gt_arr: Ground-truth array.
        pred_arr: Prediction array; the result has its shape.
        valid_mask_arr: Boolean mask selecting the pixels used in the fit.
        return_scale_shift: Also return the fitted ``scale`` and ``shift``.
        max_resolution: If given and smaller than the input size, the fit is
            computed on nearest-neighbour downsampled copies.

    Returns:
        ``aligned_pred`` or ``(aligned_pred, scale, shift)``; ``scale`` and
        ``shift`` are one-element arrays taken from the solver output.
    """
    ori_shape = pred_arr.shape  # input shape

    # Work on squeezed views, e.g. [H, W]
    gt, pred, valid_mask = (
        arr.squeeze() for arr in (gt_arr, pred_arr, valid_mask_arr)
    )

    # Downsample
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")

            def shrink(arr):
                # NOTE(review): only one leading axis is added, so a [H, W]
                # input reaches Upsample as a 3-D tensor -- confirm that the
                # resulting resampling is the intended one.
                return downscaler(torch.as_tensor(arr).unsqueeze(0))

            gt = shrink(gt).numpy()
            pred = shrink(pred).numpy()
            mask_as_float = torch.as_tensor(valid_mask).unsqueeze(0).float()
            valid_mask = downscaler(mask_as_float).bool().numpy()

    assert (
        gt.shape == pred.shape == valid_mask.shape
    ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    target = gt[valid_mask].reshape((-1, 1))
    source = pred[valid_mask].reshape((-1, 1))

    # numpy solver: columns of the design matrix are [pred, 1]
    design = np.hstack([source, np.ones_like(source)])
    solution = np.linalg.lstsq(design, target, rcond=None)[0]
    scale, shift = solution

    # apply to the original prediction and restore dimensions
    aligned_pred = (pred_arr * scale + shift).reshape(ori_shape)

    if not return_scale_shift:
        return aligned_pred
    return aligned_pred, scale, shift
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ******************** disparity space ********************
|
| 118 |
+
def depth2disparity(depth, return_mask=False):
    """Convert depth to disparity via the element-wise reciprocal.

    Entries that are not strictly positive are mapped to 0 instead of being
    inverted.

    Args:
        depth: ``torch.Tensor`` or ``np.ndarray`` of depth values.
        return_mask: Also return the boolean mask of strictly positive inputs.

    Returns:
        ``disparity`` with the shape and dtype of ``depth``, or
        ``(disparity, mask)`` when ``return_mask`` is true.

    Raises:
        TypeError: If ``depth`` is neither a tensor nor a NumPy array.
    """
    if isinstance(depth, torch.Tensor):
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        disparity = np.zeros_like(depth)
    else:
        # Fail with a clear message instead of an unbound-local error below.
        raise TypeError(
            f"depth must be a torch.Tensor or np.ndarray, got {type(depth).__name__}"
        )
    positive_mask = depth > 0
    disparity[positive_mask] = 1.0 / depth[positive_mask]
    if return_mask:
        return disparity, positive_mask
    else:
        return disparity


def disparity2depth(disparity, **kwargs):
    """Convert disparity back to depth.

    The reciprocal is its own inverse, so this forwards to
    ``depth2disparity`` with the same keyword arguments.
    """
    return depth2disparity(disparity, **kwargs)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def align_depth_least_square_torch_mask(
    gt: torch.Tensor,
    pred: torch.Tensor,
    valid_mask: torch.Tensor,
    max_resolution=None,
):
    """Fit a per-sample least-squares scale and shift from ``pred`` to ``gt``.

    For every sample ``i`` the system ``scale * pred + shift = gt`` is solved
    over the pixels selected by ``valid_mask[i]``. The solutions are detached
    from the autograd graph.

    Args:
        gt: Ground-truth tensor; first dimension is the batch, last two are
            spatial.
        pred: Prediction tensor with the same shape as ``gt``.
        valid_mask: Mask with the same shape; must be boolean when no
            downscaling takes place because it is used for indexing.
        max_resolution: If given and smaller than the spatial size, all three
            inputs are downscaled with nearest-neighbour sampling first.

    Returns:
        ``(scale, shift)``, each of shape ``[B, 1, 1, 1]``.
    """
    ori_shape = pred.shape[-2:]  # input shape
    batch_size = gt.shape[0]

    # Downsample
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
            gt = downscaler(gt)
            pred = downscaler(pred)
            # Resample the mask as float and convert back afterwards, the same
            # way align_depth_least_square does; a boolean mask is not passed
            # to the interpolation kernel directly.
            valid_mask = downscaler(valid_mask.float()).bool()

    assert (
        gt.shape == pred.shape
    ), f"{gt.shape}, {pred.shape}"

    scale_ls = []
    shift_ls = []

    for i in range(batch_size):
        gt_masked = gt[i][valid_mask[i]].view(-1, 1)
        pred_masked = pred[i][valid_mask[i]].view(-1, 1)

        # torch solver: columns of the design matrix are [pred, 1]
        ones = torch.ones_like(pred_masked)
        A = torch.cat([pred_masked, ones], dim=-1)
        X, *_ = torch.linalg.lstsq(A, gt_masked)

        scale, shift = X[0, :].detach(), X[1, :].detach()
        scale_ls.append(scale)
        shift_ls.append(shift)

    # [B] -> [B, 1, 1, 1] so the results broadcast against [B, C, H, W] maps.
    scale = torch.concat(scale_ls, dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    shift = torch.concat(shift_ls, dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    return scale, shift
|
DepthMaster/src/util/boundary_metrics.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def connected_component(r: np.ndarray, c: np.ndarray) -> List[List[int]]:  # type: ignore
    """Group positions into horizontally contiguous runs.

    Consecutive entries belong to the same run when they share a row and the
    column increases by exactly one. The inputs are expected in the order
    produced by ``np.nonzero`` (row by row, columns ascending).

    Args:
    ----
        r (np.ndarray): Row indices.
        c (np.ndarray): Column indices.

    Yields:
    ------
        List[int]: Positions into ``r`` / ``c`` that form one run. Nothing is
        yielded for empty input.

    """
    if r.size == 0:
        # No positions at all: do not emit a run that points at index 0.
        return
    indices = [0]
    for i in range(1, r.size):
        if r[i] == r[indices[-1]] and c[i] == c[indices[-1]] + 1:
            indices.append(i)
        else:
            yield indices
            indices = [i]
    yield indices
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def nms_horizontal(ratio: np.ndarray, threshold: float) -> np.ndarray:
    """Keep only the strongest response of every horizontal run above ``threshold``.

    Args:
    ----
        ratio (np.ndarray): Input ratio matrix.
        threshold (float): Entries must exceed this value to be considered.

    Returns:
    -------
        np.ndarray: Boolean mask with one ``True`` per contiguous run, placed
        at the first maximum of that run.

    """
    keep = np.zeros_like(ratio, dtype=bool)
    rows, cols = np.nonzero(ratio > threshold)
    if rows.size == 0:
        return keep

    for run in connected_component(rows, cols):
        # `max` returns the first maximal element, matching argmax semantics.
        winner = max(run, key=lambda pos: ratio[rows[pos], cols[pos]])
        keep[rows[winner], cols[winner]] = True
    return keep
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def nms_vertical(ratio: np.ndarray, threshold: float) -> np.ndarray:
    """Keep only the strongest response of every vertical run above ``threshold``.

    Implemented by transposing, running the horizontal suppression, and
    transposing the resulting mask back.

    Args:
    ----
        ratio (np.ndarray): Input ratio matrix.
        threshold (float): Threshold for NMS.

    Returns:
    -------
        np.ndarray: Binary mask after applying NMS.

    """
    transposed_ratio = np.transpose(ratio)
    transposed_mask = nms_horizontal(transposed_ratio, threshold)
    return np.transpose(transposed_mask)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def fgbg_depth(
    d: np.ndarray, t: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compare every pixel with its horizontal and vertical neighbour.

    For each pair of adjacent pixels the ratio of their values is taken in
    both directions and tested against ``t``.

    Args:
    ----
        d (np.ndarray): Depth matrix; the last two axes are sliced as rows and
            columns.
        t (float): A relation holds where the ratio is strictly greater than
            this value.

    Returns:
    -------
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Boolean
        matrices in the order left, top, right, bottom. The left/right
        matrices have one column fewer than ``d``; the top/bottom matrices
        have one row fewer.

    """
    west, east = d[..., :, :-1], d[..., :, 1:]
    north, south = d[..., :-1, :], d[..., 1:, :]

    right_exceeds = (east / west) > t
    left_exceeds = (west / east) > t
    bottom_exceeds = (south / north) > t
    top_exceeds = (north / south) > t
    return left_exceeds, top_exceeds, right_exceeds, bottom_exceeds
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def fgbg_depth_thinned(
    d: np.ndarray, t: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compare neighbouring pixels and thin the result with NMS.

    The same four neighbour ratios as in ``fgbg_depth`` are computed, but
    instead of a plain threshold each ratio map is passed through
    ``nms_horizontal`` (left/right) or ``nms_vertical`` (top/bottom).

    Args:
    ----
        d (np.ndarray): Depth matrix.
        t (float): Threshold for NMS.

    Returns:
    -------
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Boolean
        matrices in the order left, top, right, bottom, with NMS applied.

    """
    west, east = d[..., :, :-1], d[..., :, 1:]
    north, south = d[..., :-1, :], d[..., 1:, :]

    right_thinned = nms_horizontal(east / west, t)
    left_thinned = nms_horizontal(west / east, t)
    bottom_thinned = nms_vertical(south / north, t)
    top_thinned = nms_vertical(north / south, t)
    return left_thinned, top_thinned, right_thinned, bottom_thinned
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def fgbg_binary_mask(
    d: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Find transitions between neighbouring pixels of a boolean mask.

    A relation holds for a pair of adjacent pixels when one of them is True
    and the other is False; each of the four outputs covers one direction.

    Args:
    ----
        d (np.ndarray): Boolean matrix (asserted to have dtype ``bool``).

    Returns:
    -------
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Boolean
        matrices in the order left, top, right, bottom. The left/right
        matrices have one column fewer than ``d``; the top/bottom matrices
        have one row fewer.

    """
    assert d.dtype == bool
    west, east = d[..., :, :-1], d[..., :, 1:]
    north, south = d[..., :-1, :], d[..., 1:, :]

    right_set = east & ~west
    left_set = west & ~east
    bottom_set = south & ~north
    top_set = north & ~south
    return left_set, top_set, right_set, bottom_set
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def edge_recall_matting(pr: np.ndarray, gt: np.ndarray, t: float) -> float:
    """Calculate edge recall of a prediction against a boolean mask.

    Predicted relations come from ``fgbg_depth_thinned(pr, t)`` and reference
    relations from ``fgbg_binary_mask(gt)``. For each of the four directions
    the fraction of reference relations also present in the prediction is
    computed, and the four fractions are averaged.

    Args:
    ----
        pr (np.ndarray): Predicted depth matrix.
        gt (np.ndarray): Ground truth mask (asserted to have dtype ``bool``).
        t (float): Threshold for NMS.

    Returns:
    -------
        float: Edge recall value.

    """
    assert gt.dtype == bool
    predicted = fgbg_depth_thinned(pr, t)
    reference = fgbg_binary_mask(gt)

    total = 0.0
    for pred_dir, ref_dir in zip(predicted, reference):
        matched = np.count_nonzero(pred_dir & ref_dir)
        # max(..., 1) keeps a direction without reference edges at 0 instead
        # of dividing by zero.
        total += matched / max(np.count_nonzero(ref_dir), 1)
    return 0.25 * total
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def boundary_f1(
    pr: np.ndarray,
    gt: np.ndarray,
    t: float,
    return_p: bool = False,
    return_r: bool = False,
) -> float:
    """Calculate the boundary F1 score between two depth matrices.

    Both inputs are converted to four directional relation maps with
    ``fgbg_depth``. Recall and precision are each the mean, over the four
    directions, of the overlap count divided by the reference count (recall)
    or by the prediction count (precision).

    Args:
    ----
        pr (np.ndarray): Predicted depth matrix.
        gt (np.ndarray): Ground truth depth matrix.
        t (float): Threshold for comparison.
        return_p (bool, optional): If True, return precision. Takes priority
            over ``return_r``. Defaults to False.
        return_r (bool, optional): If True, return recall. Defaults to False.

    Returns:
    -------
        float: Boundary F1 score, or precision, or recall depending on the
        flags. ``0.0`` is returned whenever recall and precision sum to zero.

    """
    predicted = fgbg_depth(pr, t)
    reference = fgbg_depth(gt, t)

    recall_total = 0.0
    precision_total = 0.0
    for pred_dir, ref_dir in zip(predicted, reference):
        overlap = np.count_nonzero(pred_dir & ref_dir)
        # max(..., 1) avoids dividing by zero for directions without edges.
        recall_total += overlap / max(np.count_nonzero(ref_dir), 1)
        precision_total += overlap / max(np.count_nonzero(pred_dir), 1)
    r = 0.25 * recall_total
    p = 0.25 * precision_total

    if r + p == 0:
        return 0.0
    if return_p:
        return p
    if return_r:
        return r
    return 2 * (r * p) / (r + p)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_thresholds_and_weights(
    t_min: float, t_max: float, N: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Build evenly spaced thresholds and weights proportional to them.

    Args:
    ----
        t_min (float): Minimum threshold.
        t_max (float): Maximum threshold.
        N (int): Number of thresholds.

    Returns:
    -------
        Tuple[np.ndarray, np.ndarray]: ``N`` thresholds from ``np.linspace``
        and, for each, the threshold divided by the sum of all thresholds.

    """
    thresholds = np.linspace(t_min, t_max, N)
    total = thresholds.sum()
    return thresholds, thresholds / total
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def invert_depth(depth: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Return the element-wise reciprocal of a depth map.

    Args:
    ----
        depth (np.ndarray): Depth map to be inverted.
        eps (float): Values below this are raised to it before inversion so
            the division never hits zero (default is 1e-6).

    Returns:
    -------
        np.ndarray: Inverted depth map.

    """
    floored = np.clip(depth, eps, None)
    return 1.0 / floored
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def SI_boundary_F1(
    predicted_depth: np.ndarray,
    target_depth: np.ndarray,
    t_min: float = 1.05,
    t_max: float = 1.25,
    N: int = 10,
) -> float:
    """Calculate Scale-Invariant Boundary F1 Score for depth-based ground-truth.

    ``boundary_f1`` is evaluated on the inverted depth maps at ``N``
    thresholds between ``t_min`` and ``t_max``; the scores are combined with
    the weights from ``get_thresholds_and_weights``.

    Args:
    ----
        predicted_depth (np.ndarray): Predicted depth matrix (2-D).
        target_depth (np.ndarray): Ground truth depth matrix (2-D).
        t_min (float, optional): Minimum threshold. Defaults to 1.05.
        t_max (float, optional): Maximum threshold. Defaults to 1.25.
        N (int, optional): Number of thresholds. Defaults to 10.

    Returns:
    -------
        float: Scale-Invariant Boundary F1 Score.

    """
    assert predicted_depth.ndim == target_depth.ndim == 2
    thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)

    # The inverted maps do not depend on the threshold, so compute them once
    # instead of once per threshold.
    inverse_predicted = invert_depth(predicted_depth)
    inverse_target = invert_depth(target_depth)

    f1_scores = np.array(
        [boundary_f1(inverse_predicted, inverse_target, t) for t in thresholds]
    )
    return np.sum(f1_scores * weights)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def SI_boundary_Recall(
    predicted_depth: np.ndarray,
    target_mask: np.ndarray,
    t_min: float = 1.05,
    t_max: float = 1.25,
    N: int = 10,
    alpha_threshold: float = 0.1,
) -> float:
    """Calculate Scale-Invariant Boundary Recall Score for mask-based ground-truth.

    The target is binarised with ``target_mask > alpha_threshold``, then
    ``edge_recall_matting`` is evaluated on the inverted prediction at ``N``
    thresholds between ``t_min`` and ``t_max``; the recalls are combined with
    the weights from ``get_thresholds_and_weights``.

    Args:
    ----
        predicted_depth (np.ndarray): Predicted depth matrix (2-D).
        target_mask (np.ndarray): Ground truth mask (2-D).
        t_min (float, optional): Minimum threshold. Defaults to 1.05.
        t_max (float, optional): Maximum threshold. Defaults to 1.25.
        N (int, optional): Number of thresholds. Defaults to 10.
        alpha_threshold (float, optional): Threshold for alpha masking. Defaults to 0.1.

    Returns:
    -------
        float: Scale-Invariant Boundary Recall Score.

    """
    assert predicted_depth.ndim == target_mask.ndim == 2
    thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
    thresholded_target = target_mask > alpha_threshold

    # The inverted prediction does not depend on the threshold, so compute it
    # once instead of once per threshold.
    inverse_predicted = invert_depth(predicted_depth)

    recall_scores = np.array(
        [
            edge_recall_matting(inverse_predicted, thresholded_target, t=float(t))
            for t in thresholds
        ]
    )
    weighted_recall = np.sum(recall_scores * weights)
    return weighted_recall
|
DepthMaster/src/util/build_mlp.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
def build_mlp_(hidden_size=640, projector_dim=1024, z_dim=768):
    """Build a three-layer MLP with SiLU activations.

    Args:
        hidden_size: Input feature size of the first linear layer.
        projector_dim: Width of the two hidden linear layers.
        z_dim: Output feature size of the last linear layer.

    Returns:
        An ``nn.Sequential`` of Linear, SiLU, Linear, SiLU, Linear.
    """
    # Layers are created in forward order so parameter initialisation draws
    # from the random generator in the same sequence as before.
    layers = [
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    ]
    return nn.Sequential(*layers)
|
DepthMaster/src/util/config_util.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import omegaconf
|
| 26 |
+
from omegaconf import OmegaConf
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def recursive_load_config(config_path: str) -> OmegaConf:
    """Load a config file and merge it on top of its ``base_config`` parents.

    Each path listed under the ``base_config`` key is loaded recursively and
    merged in list order, so later parents override earlier ones. The config
    at ``config_path`` itself is merged last and overrides all parents.

    Args:
        config_path: Path of the config file to load.

    Returns:
        The merged configuration.
    """
    current = OmegaConf.load(config_path)
    merged = OmegaConf.create({})

    parent_paths = current.get("base_config", default_value=None)
    if parent_paths is not None:
        assert isinstance(parent_paths, omegaconf.listconfig.ListConfig)
        for parent_path in parent_paths:
            # Only a direct self-reference is detected here.
            assert (
                parent_path != config_path
            ), "Circulate merging, base_config should not include itself."
            parent_conf = recursive_load_config(parent_path)
            merged = OmegaConf.merge(merged, parent_conf)

    # The file's own values win over everything inherited.
    return OmegaConf.merge(merged, current)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def find_value_in_omegaconf(search_key, config):
    """Collect every value stored under ``search_key`` in a nested config.

    Dict configs are scanned key by key: a matching key contributes its value
    (without descending into it), and any other dict/list value is searched
    recursively. List configs are searched through their dict/list items.

    Args:
        search_key: Key to look for.
        config: An omegaconf ``DictConfig`` or ``ListConfig``; any other
            object yields an empty result.

    Returns:
        A list of all matching values, in traversal order.
    """
    nested_types = (omegaconf.DictConfig, omegaconf.ListConfig)

    if isinstance(config, omegaconf.DictConfig):
        matches = []
        for key, value in config.items():
            if key == search_key:
                matches.append(value)
            elif isinstance(value, nested_types):
                matches.extend(find_value_in_omegaconf(search_key, value))
        return matches

    if isinstance(config, omegaconf.ListConfig):
        matches = []
        for item in config:
            if isinstance(item, nested_types):
                matches.extend(find_value_in_omegaconf(search_key, item))
        return matches

    return []
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
    # Manual check: load the base training config and print the merged YAML.
    conf = recursive_load_config("config/train_base.yaml")
    print(OmegaConf.to_yaml(conf))
|
DepthMaster/src/util/data_loader.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py
|
| 2 |
+
|
| 3 |
+
from torch.utils.data import BatchSampler, DataLoader, IterableDataset
|
| 4 |
+
|
| 5 |
+
# kwargs of the DataLoader in min version 1.4.0.
|
| 6 |
+
# Maps each DataLoader keyword argument to the fallback value used when the
# attribute cannot be read from an existing loader.
_PYTORCH_DATALOADER_KWARGS = dict(
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
    generator=None,
    prefetch_factor=2,
    persistent_workers=False,
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SkipBatchSampler(BatchSampler):
    """
    Wraps another batch sampler and drops its first `skip_batches` batches.

    The wrapped sampler is still iterated from the start; the leading batches
    are consumed and discarded rather than yielded.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        # Only the wrapped sampler and the skip count are stored; the parent
        # initializer is not invoked.
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for position, batch in enumerate(self.batch_sampler):
            if position < self.skip_batches:
                continue
            yield batch

    @property
    def total_length(self):
        """Number of batches in the wrapped sampler, including skipped ones."""
        return len(self.batch_sampler)

    def __len__(self):
        return self.total_length - self.skip_batches
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SkipDataLoader(DataLoader):
    """
    A PyTorch `DataLoader` that discards its first `skip_batches` batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __iter__(self):
        # The leading batches are still produced by the parent iterator; they
        # are simply not passed on.
        for position, batch in enumerate(super().__iter__()):
            if position < self.skip_batches:
                continue
            yield batch
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Adapted from https://github.com/huggingface/accelerate
|
| 70 |
+
def skip_first_batches(dataloader, num_batches=0):
    """
    Build a new `torch.utils.data.DataLoader` that omits the first `num_batches` batches of `dataloader`.

    For an `IterableDataset` a `SkipDataLoader` is returned, which drops the
    batches while iterating. Otherwise the existing batch sampler is wrapped
    in a `SkipBatchSampler` and handed to a regular `DataLoader`.
    """
    dataset = dataloader.dataset

    # These arguments are covered by the batch sampler, so they are not copied
    # over from the original loader.
    sampler_owned = ("batch_size", "shuffle", "sampler", "batch_sampler", "drop_last")
    loader_kwargs = {
        name: getattr(dataloader, name, fallback)
        for name, fallback in _PYTORCH_DATALOADER_KWARGS.items()
        if name not in sampler_owned
    }

    if isinstance(dataset, IterableDataset):
        # No batch sampler is built in this case, so batching has to be
        # configured explicitly and the skipping done by the loader itself.
        loader_kwargs["drop_last"] = dataloader.drop_last
        loader_kwargs["batch_size"] = dataloader.batch_size
        return SkipDataLoader(dataset, skip_batches=num_batches, **loader_kwargs)

    if isinstance(dataloader.sampler, BatchSampler):
        wrapped = dataloader.sampler
    else:
        wrapped = dataloader.batch_sampler
    skipping_sampler = SkipBatchSampler(wrapped, skip_batches=num_batches)
    return DataLoader(dataset, batch_sampler=skipping_sampler, **loader_kwargs)
|
DepthMaster/src/util/depth_transform.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Last modified: 2025-01-14
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2025 Ziyang Song, USTC. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified from the original version.
|
| 6 |
+
# Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
# --------------------------------------------------------------------------
|
| 20 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 21 |
+
# Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
|
| 22 |
+
# More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
|
| 23 |
+
# --------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import logging
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_depth_normalizer(cfg_normalizer):
    """Create the depth transform described by a normalizer config.

    Args:
        cfg_normalizer: ``None`` for no normalization, or a config object
            whose ``type`` attribute selects the normalizer. For
            ``"scale_shift_depth"`` the attributes ``norm_min``, ``norm_max``,
            ``min_max_quantile`` and ``clip`` are read as well.

    Returns:
        A callable: the identity function when ``cfg_normalizer`` is ``None``,
        otherwise a ``ScaleShiftDepthNormalizer``.

    Raises:
        NotImplementedError: If ``cfg_normalizer.type`` is not supported.
    """
    if cfg_normalizer is None:

        def identical(x):
            return x

        return identical

    if cfg_normalizer.type != "scale_shift_depth":
        raise NotImplementedError

    return ScaleShiftDepthNormalizer(
        norm_min=cfg_normalizer.norm_min,
        norm_max=cfg_normalizer.norm_max,
        min_max_quantile=cfg_normalizer.min_max_quantile,
        clip=cfg_normalizer.clip,
    )
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DepthNormalizerBase:
    """Interface for depth normalizers; every method raises NotImplementedError."""

    # Both flags are left unset here; subclasses assign concrete values.
    is_absolute = None
    far_plane_at_max = None

    def __init__(
        self,
        norm_min=-1.0,
        norm_max=1.0,
    ) -> None:
        self.norm_min, self.norm_max = norm_min, norm_max
        # The base class is not meant to be instantiated directly.
        raise NotImplementedError

    def __call__(self, depth, valid_mask=None, clip=None):
        raise NotImplementedError

    def denormalize(self, depth_norm, **kwargs):
        # For metric depth: convert prediction back to metric depth
        # For relative depth: convert prediction to [0, 1]
        raise NotImplementedError
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ScaleShiftDepthNormalizer(DepthNormalizerBase):
    """
    Use near and far plane to linearly normalize depth,
        i.e. d' = d * s + t,
        where near plane is mapped to `norm_min`, and far plane is mapped to `norm_max`
    Near and far planes are determined by taking quantile values.
    """

    is_absolute = False
    far_plane_at_max = True

    def __init__(
        self, norm_min=-1.0, norm_max=1.0, min_max_quantile=0.02, clip=True
    ) -> None:
        """Store the target range and the quantiles used as near/far planes.

        Args:
            norm_min: Value the near plane is mapped to.
            norm_max: Value the far plane is mapped to.
            min_max_quantile: Lower quantile used as near plane; the far plane
                uses ``1.0 - min_max_quantile``.
            clip: Default for whether normalized values are clipped to
                ``[norm_min, norm_max]``.
        """
        # The base-class initializer is deliberately not called: it raises
        # NotImplementedError.
        self.norm_min = norm_min
        self.norm_max = norm_max
        self.norm_range = self.norm_max - self.norm_min
        self.min_quantile = min_max_quantile
        self.max_quantile = 1.0 - self.min_quantile
        self.clip = clip

    def __call__(self, depth_linear, valid_mask=None, clip=None):
        """Normalize a depth tensor to ``[norm_min, norm_max]``.

        Args:
            depth_linear: Depth tensor to normalize.
            valid_mask: Optional boolean mask of pixels used to estimate the
                quantiles; non-positive depths are always excluded.
            clip: Overrides the instance-level ``clip`` setting when not None.

        Returns:
            The scaled and shifted depth tensor, clipped if requested.
        """
        clip = clip if clip is not None else self.clip

        if valid_mask is None:
            valid_mask = torch.ones_like(depth_linear).bool()
        valid_mask = valid_mask & (depth_linear > 0)

        # torch.quantile needs q to share dtype and device with its input, so
        # build q from the depth tensor instead of using the CPU defaults.
        quantiles = torch.tensor(
            [self.min_quantile, self.max_quantile],
            dtype=depth_linear.dtype,
            device=depth_linear.device,
        )

        # Take quantiles as min and max
        _min, _max = torch.quantile(depth_linear[valid_mask], quantiles)

        # scale and shift
        depth_norm_linear = (depth_linear - _min) / (
            _max - _min
        ) * self.norm_range + self.norm_min

        if clip:
            depth_norm_linear = torch.clip(
                depth_norm_linear, self.norm_min, self.norm_max
            )

        return depth_norm_linear

    def scale_back(self, depth_norm):
        """Map normalized depth from ``[norm_min, norm_max]`` to ``[0, 1]``."""
        # scale to [0, 1]
        depth_linear = (depth_norm - self.norm_min) / self.norm_range
        return depth_linear

    def denormalize(self, depth_norm, **kwargs):
        """Return ``scale_back(depth_norm)`` after logging a warning.

        The per-image near/far planes are not stored, so only the rescaling to
        ``[0, 1]`` is undone; extra keyword arguments are ignored.
        """
        logging.warning(f"{self.__class__} is not revertible without GT")
        return self.scale_back(depth_norm=depth_norm)
|