Diffusers
Safetensors
zeyuren2002 commited on
Commit
4b7b610
·
verified ·
1 Parent(s): aca0d59

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. DepthMaster/ckpt/eval/.gitattributes +35 -0
  2. DepthMaster/ckpt/eval/README.md +97 -0
  3. DepthMaster/ckpt/eval/model_index.json +28 -0
  4. DepthMaster/ckpt/eval/text_encoder/config.json +25 -0
  5. DepthMaster/ckpt/eval/tokenizer/merges.txt +0 -0
  6. DepthMaster/ckpt/eval/tokenizer/special_tokens_map.json +24 -0
  7. DepthMaster/ckpt/eval/tokenizer/tokenizer_config.json +34 -0
  8. DepthMaster/ckpt/eval/tokenizer/vocab.json +0 -0
  9. DepthMaster/ckpt/eval/unet/config.json +73 -0
  10. DepthMaster/ckpt/eval/vae/config.json +30 -0
  11. DepthMaster/data_split/kitti/eigen_train_files_with_gt.txt +0 -0
  12. DepthMaster/depthmaster/modules/__pycache__/unet_2d_blocks.cpython-310.pyc +0 -0
  13. DepthMaster/depthmaster/modules/__pycache__/unet_2d_condition_s2.cpython-310.pyc +0 -0
  14. DepthMaster/external_encoder/dinov2/dinov2_layers/attention.py +83 -0
  15. DepthMaster/external_encoder/dinov2/dinov2_layers/block.py +252 -0
  16. DepthMaster/external_encoder/dinov2/dinov2_layers/drop_path.py +35 -0
  17. DepthMaster/external_encoder/dinov2/dinov2_layers/layer_scale.py +28 -0
  18. DepthMaster/external_encoder/dinov2/dinov2_layers/mlp.py +41 -0
  19. DepthMaster/external_encoder/dinov2/dinov2_layers/patch_embed.py +89 -0
  20. DepthMaster/external_encoder/dinov2/dinov2_layers/swiglu_ffn.py +63 -0
  21. DepthMaster/external_encoder/dinov2/util/transform.py +160 -0
  22. DepthMaster/in_the_wild_example/input/06.jpg +0 -0
  23. DepthMaster/scripts/eval_diode.sh +13 -0
  24. DepthMaster/scripts/eval_eth3d.sh +13 -0
  25. DepthMaster/scripts/eval_hypersim.sh +13 -0
  26. DepthMaster/scripts/eval_kitti.sh +13 -0
  27. DepthMaster/scripts/eval_nyu.sh +13 -0
  28. DepthMaster/scripts/eval_scannet.sh +13 -0
  29. DepthMaster/scripts/infer.sh +10 -0
  30. DepthMaster/scripts/train_s1.sh +9 -0
  31. DepthMaster/scripts/train_s2.sh +9 -0
  32. DepthMaster/src/dataset/__init__.py +71 -0
  33. DepthMaster/src/dataset/base_depth_dataset.py +303 -0
  34. DepthMaster/src/dataset/diode_dataset.py +94 -0
  35. DepthMaster/src/dataset/eth3d_dataset.py +68 -0
  36. DepthMaster/src/dataset/hypersim_dataset.py +48 -0
  37. DepthMaster/src/dataset/kitti_dataset.py +127 -0
  38. DepthMaster/src/dataset/mixed_sampler.py +151 -0
  39. DepthMaster/src/dataset/nyu_dataset.py +64 -0
  40. DepthMaster/src/dataset/scannet_dataset.py +47 -0
  41. DepthMaster/src/dataset/vkitti_dataset.py +100 -0
  42. DepthMaster/src/trainer/__init__.py +15 -0
  43. DepthMaster/src/trainer/trainer_s1.py +671 -0
  44. DepthMaster/src/trainer/trainer_s2.py +630 -0
  45. DepthMaster/src/util/alignment.py +180 -0
  46. DepthMaster/src/util/boundary_metrics.py +332 -0
  47. DepthMaster/src/util/build_mlp.py +10 -0
  48. DepthMaster/src/util/config_util.py +70 -0
  49. DepthMaster/src/util/data_loader.py +111 -0
  50. DepthMaster/src/util/depth_transform.py +124 -0
DepthMaster/ckpt/eval/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
DepthMaster/ckpt/eval/README.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ base_model:
6
+ - stabilityai/stable-diffusion-2
7
+ pipeline_tag: depth-estimation
8
+ ---
9
+ <!-- # DepthMaster: Taming Diffusion Models for Monocular Depth Estimation
10
+
11
+
12
+ This repository represents the official implementation of the paper titled "DepthMaster: Taming Diffusion Models for Monocular Depth Estimation". -->
13
+
14
+ <!-- [![Website](doc/badges/badge-website.svg)](https://marigoldmonodepth.github.io)
15
+ [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2312.02145) -->
16
+
17
+ <!-- [![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0) -->
18
+
19
+ <h1 align="center"><strong>DepthMaster: Taming Diffusion Models for Monocular Depth Estimation</strong></h1>
20
+ <p align="center">
21
+ <a href="https://indu1ge.github.io/ziyangsong">Ziyang Song*</a>,
22
+ <a href="https://orcid.org/0009-0001-6677-0572">Zerong Wang*</a>,
23
+ <a href="https://orcid.org/0000-0001-7817-0665">Bo Li</a>,
24
+ <a href="https://orcid.org/0009-0007-1175-5918">Hao Zhang</a>,
25
+ <a href="https://ruijiezhu94.github.io/ruijiezhu/">Ruijie Zhu</a>,
26
+ <a href="https://orcid.org/0009-0004-3280-8490">Li Liu</a>,
27
+ <a href="https://pengtaojiang.github.io/">Peng-Tao Jiang†</a>,
28
+ <a href="http://staff.ustc.edu.cn/~tzzhang/">Tianzhu Zhang†</a>,
29
+ <br>
30
+ *Equal Contribution, †Corresponding Author
31
+ <br>
32
+ University of Science and Technology of China, vivo Mobile Communication Co., Ltd.
33
+ <br>
34
+ <b>arXiv 2025</b>
35
+ </p>
36
+ <!-- [Ziyang Song*](https://indu1ge.github.io/ziyangsong),
37
+ [Zerong Wang*](),
38
+ [Bo Li](https://orcid.org/0000-0001-7817-0665),
39
+ [Hao Zhang](https://orcid.org/0009-0007-1175-5918),
40
+ [Ruijie Zhu](https://ruijiezhu94.github.io/ruijiezhu/),
41
+ [Li Liu](https://orcid.org/0009-0004-3280-8490)
42
+ [Tianzhu Zhang](http://staff.ustc.edu.cn/~tzzhang/)
43
+ [Peng-Tao Jiang](https://pengtaojiang.github.io/) -->
44
+
45
+
46
+
47
+ <div align="center">
48
+ <a href='https://arxiv.org/abs/2501.02576'>
49
+ <img src='https://img.shields.io/badge/Paper-arXiv-red'>
50
+ </a>
51
+ <a href='https://indu1ge.github.io/DepthMaster_page/'>
52
+ <img src='https://img.shields.io/badge/Project-Page-Green'>
53
+ </a>
54
+ <a href='https://github.com/indu1ge/DepthMaster'>
55
+ <img src='https://img.shields.io/badge/GitHub-Repository-blue?logo=github'>
56
+ </a>
57
+ <a href='https://www.apache.org/licenses/LICENSE-2.0'>
58
+ <img src='https://img.shields.io/badge/License-Apache--2.0-929292'>
59
+ </a>
60
+ </div>
61
+
62
+
63
+
64
+ <!-- We present Marigold, a diffusion model, and associated fine-tuning protocol for monocular depth estimation. Its core principle is to leverage the rich visual knowledge stored in modern generative image models. Our model, derived from Stable Diffusion and fine-tuned with synthetic data, can zero-shot transfer to unseen data, offering state-of-the-art monocular depth estimation results. -->
65
+
66
+
67
+ ![teaser](assets/framework.png)
68
+
69
+ <!-- >We present DepthMaster, a tamed single-step diffusion model designed to enhance the generalization and detail preservation abilities of depth estimation models. Through feature alignment, we effectively prevent the overfitting to texture details. By adaptively enhance -->
70
+ >We present DepthMaster, a tamed single-step diffusion model that customizes generative features in diffusion models to suit the discriminative depth estimation task. We introduce a Feature Alignment module to mitigate overfitting to texture and a Fourier Enhancement module to refine fine-grained details. DepthMaster exhibits state-of-the-art zero-shot performance and superior detail preservation ability, surpassing
71
+ other diffusion-based methods across various datasets.
72
+
73
+
74
+ ## 🎓 Citation
75
+
76
+ Please cite our paper:
77
+
78
+ ```bibtex
79
+ @article{song2025depthmaster,
80
+ title={DepthMaster: Taming Diffusion Models for Monocular Depth Estimation},
81
+ author={Song, Ziyang and Wang, Zerong and Li, Bo and Zhang, Hao and Zhu, Ruijie and Liu, Li and Jiang, Peng-Tao and Zhang, Tianzhu},
82
+ journal={arXiv preprint arXiv:2501.02576},
83
+ year={2025}
84
+ }
85
+ ```
86
+
87
+ ## Acknowledgements
88
+
89
+ The code is based on [Marigold](https://github.com/prs-eth/Marigold).
90
+
91
+ ## 🎫 License
92
+
93
+ This work is licensed under the Apache License, Version 2.0 (as defined in the [LICENSE](LICENSE.txt)).
94
+
95
+ By downloading and using the code and model you agree to the terms in the [LICENSE](LICENSE.txt).
96
+
97
+ [![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0)
DepthMaster/ckpt/eval/model_index.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name":"MarigoldPipeline",
3
+ "_diffusers_version":"0.24.0",
4
+ "scale_invariant": true,
5
+ "shift_invariant": true,
6
+ "default_denoising_steps": 10,
7
+ "default_processing_resolution": 768,
8
+ "unet":[
9
+ "diffusers",
10
+ "UNet2DConditionModel"
11
+ ],
12
+ "vae":[
13
+ "diffusers",
14
+ "AutoencoderKL"
15
+ ],
16
+ "scheduler":[
17
+ "diffusers",
18
+ "DDIMScheduler"
19
+ ],
20
+ "text_encoder":[
21
+ "transformers",
22
+ "CLIPTextModel"
23
+ ],
24
+ "tokenizer":[
25
+ "transformers",
26
+ "CLIPTokenizer"
27
+ ]
28
+ }
DepthMaster/ckpt/eval/text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "hf-models/stable-diffusion-v2-768x768/text_encoder",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1024,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 23,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.25.0.dev0",
24
+ "vocab_size": 49408
25
+ }
DepthMaster/ckpt/eval/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
DepthMaster/ckpt/eval/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
DepthMaster/ckpt/eval/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "do_lower_case": true,
12
+ "eos_token": {
13
+ "__type": "AddedToken",
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": true,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "errors": "replace",
21
+ "model_max_length": 77,
22
+ "name_or_path": "hf-models/stable-diffusion-v2-768x768/tokenizer",
23
+ "pad_token": "<|endoftext|>",
24
+ "special_tokens_map_file": "./special_tokens_map.json",
25
+ "tokenizer_class": "CLIPTokenizer",
26
+ "unk_token": {
27
+ "__type": "AddedToken",
28
+ "content": "<|endoftext|>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
DepthMaster/ckpt/eval/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
DepthMaster/ckpt/eval/unet/config.json ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.31.0",
4
+ "_name_or_path": "/data/vjuicefs_ai_camera_jgroup/11169299/Marigold_rgb2d/log/depth_preprocess/rgb2disp_bs4_sqrt_disp_cos1e-3_0.85/checkpoint/iter_014000/unet",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": [
10
+ 5,
11
+ 10,
12
+ 20,
13
+ 20
14
+ ],
15
+ "attention_type": "default",
16
+ "block_out_channels": [
17
+ 320,
18
+ 640,
19
+ 1280,
20
+ 1280
21
+ ],
22
+ "center_input_sample": false,
23
+ "class_embed_type": null,
24
+ "class_embeddings_concat": false,
25
+ "conv_in_kernel": 3,
26
+ "conv_out_kernel": 3,
27
+ "cross_attention_dim": 1024,
28
+ "cross_attention_norm": null,
29
+ "down_block_types": [
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D",
32
+ "CrossAttnDownBlock2D",
33
+ "DownBlock2D"
34
+ ],
35
+ "downsample_padding": 1,
36
+ "dropout": 0.0,
37
+ "dual_cross_attention": false,
38
+ "encoder_hid_dim": null,
39
+ "encoder_hid_dim_type": null,
40
+ "flip_sin_to_cos": true,
41
+ "freq_shift": 0,
42
+ "in_channels": 4,
43
+ "layers_per_block": 2,
44
+ "mid_block_only_cross_attention": null,
45
+ "mid_block_scale_factor": 1,
46
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
47
+ "norm_eps": 1e-05,
48
+ "norm_num_groups": 32,
49
+ "num_attention_heads": null,
50
+ "num_class_embeds": null,
51
+ "only_cross_attention": false,
52
+ "out_channels": 4,
53
+ "projection_class_embeddings_input_dim": null,
54
+ "resnet_out_scale_factor": 1.0,
55
+ "resnet_skip_time_act": false,
56
+ "resnet_time_scale_shift": "default",
57
+ "reverse_transformer_layers_per_block": null,
58
+ "sample_size": 96,
59
+ "time_cond_proj_dim": null,
60
+ "time_embedding_act_fn": null,
61
+ "time_embedding_dim": null,
62
+ "time_embedding_type": "positional",
63
+ "timestep_post_act": null,
64
+ "transformer_layers_per_block": 1,
65
+ "up_block_types": [
66
+ "UpBlock2D",
67
+ "CrossAttnUpBlock2D",
68
+ "CrossAttnUpBlock2D",
69
+ "CrossAttnUpBlock2D"
70
+ ],
71
+ "upcast_attention": false,
72
+ "use_linear_projection": true
73
+ }
DepthMaster/ckpt/eval/vae/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.8.0",
4
+ "_name_or_path": "hf-models/stable-diffusion-v2-768x768/vae",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 768,
24
+ "up_block_types": [
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D"
29
+ ]
30
+ }
DepthMaster/data_split/kitti/eigen_train_files_with_gt.txt ADDED
The diff for this file is too large to render. See raw diff
 
DepthMaster/depthmaster/modules/__pycache__/unet_2d_blocks.cpython-310.pyc ADDED
Binary file (67.3 kB). View file
 
DepthMaster/depthmaster/modules/__pycache__/unet_2d_condition_s2.cpython-310.pyc ADDED
Binary file (40.9 kB). View file
 
DepthMaster/external_encoder/dinov2/dinov2_layers/attention.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
class Attention(nn.Module):
    """Multi-head self-attention computed with explicit matrix products.

    Args:
        dim: Channel size of the incoming tokens (and of the output).
        num_heads: Number of attention heads the channels are split across.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        proj_bias: Whether the output projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Queries are multiplied by (per-head width) ** -0.5 before the dot product.
        self.scale = (dim // num_heads) ** -0.5

        # A single linear layer emits q, k and v together, hence ``dim * 3``.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Attend over the tokens of a 3-D input ``(B, N, C)``; the output has the same shape."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # Reorder to (3, B, heads, N, per_head) so q/k/v can be indexed on axis 0.
        fused = fused.permute(2, 0, 3, 1, 4)
        query = fused[0] * self.scale
        key = fused[1]
        value = fused[2]

        weights = (query @ key.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        # Put the head axis back next to the channel axis and merge them.
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
63
+
64
+
65
class MemEffAttention(Attention):
    """Attention that uses xFormers' ``memory_efficient_attention`` when it is importable."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Run attention on ``x``, optionally restricted by ``attn_bias``.

        When xFormers is unavailable, the parent implementation is used and
        ``attn_bias`` must be ``None`` (an ``AssertionError`` is raised otherwise).
        """
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
            query, key, value = unbind(fused, 2)

            # Note: ``self.attn_drop`` is not applied on this path.
            attended = memory_efficient_attention(query, key, value, attn_bias=attn_bias)
            attended = attended.reshape([batch, tokens, channels])
            return self.proj_drop(self.proj(attended))

        # Fallback: the plain implementation has no way to honour an attention bias.
        assert attn_bias is None, "xFormers is required for nested tensors usage"
        return super().forward(x)
82
+
83
+
DepthMaster/external_encoder/dinov2/dinov2_layers/block.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
class Block(nn.Module):
    """Pre-norm transformer block with an attention branch and a feed-forward branch.

    Each branch is ``norm -> sub-layer -> optional LayerScale`` and is added back
    to its input as a residual.

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads passed to ``attn_class``.
        mlp_ratio: Hidden width of the feed-forward layer as a multiple of ``dim``.
        qkv_bias: Forwarded to ``attn_class`` as ``qkv_bias``.
        proj_bias: Forwarded to ``attn_class`` as ``proj_bias``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Dropout probability used for both the attention projection and the FFN.
        attn_drop: Dropout probability forwarded to ``attn_class``.
        init_values: Initial LayerScale value; a falsy value disables LayerScale.
        drop_path: Stochastic-depth rate applied to both residual branches.
        act_layer: Activation factory forwarded to ``ffn_layer``.
        norm_layer: Normalisation factory, called with ``dim``.
        attn_class: Attention module factory.
        ffn_layer: Feed-forward module factory.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply both residual branches to ``x`` and return the result.

        In training mode the stochastic-depth strategy depends on the rate:
        above 0.1 a subset of the batch is processed per branch, between 0 and
        0.1 a per-sample ``DropPath`` mask is used, otherwise plain residuals.
        """

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own DropPath module (previously drop_path1 was
            # reused here; both are built with the same rate).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
108
+
109
+
110
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch.

    Args:
        x: 3-D input; the first axis is treated as the batch axis.
        residual_func: Called on the selected samples to produce their residual.
        sample_drop_ratio: Fraction of the batch to leave out; at least one
            sample is always kept.

    Returns:
        A tensor shaped like ``x`` where the kept samples received their
        residual scaled by ``batch / kept`` and the rest are unchanged.
    """
    batch_size, _, _ = x.shape

    # Pick which samples go through the residual branch.
    kept = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch_size, device=x.device)[:kept]

    # Evaluate the branch only on those samples.
    branch_out = residual_func(x[kept_idx]).flatten(1)

    # Scale up so the summed contribution over the batch is compensated for the skipped samples.
    rescale = batch_size / kept
    summed = torch.index_add(x.flatten(1), 0, kept_idx, branch_out.to(dtype=x.dtype), alpha=rescale)
    return summed.view_as(x)
132
+
133
+
134
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Draw the random sample indices to keep and the matching residual scale.

    Returns a tuple ``(indices, scale)`` where ``indices`` is a random selection
    of batch positions of ``x`` (at least one) and ``scale`` is
    ``batch / len(indices)``.
    """
    batch_size, _, _ = x.shape
    kept = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    shuffled = torch.randperm(batch_size, device=x.device)
    return shuffled[:kept], batch_size / kept
140
+
141
+
142
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` (scaled) into the rows of ``x`` selected by ``brange``.

    Without ``scaling_vector`` the addition is done with ``torch.index_add`` on
    inputs flattened from axis 1, so the result is 2-D. With ``scaling_vector``
    the work is delegated to xFormers' ``scaled_index_add``.
    """
    if scaling_vector is not None:
        return scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )

    flat_residual = residual.flatten(1).to(dtype=x.dtype)
    return torch.index_add(x.flatten(1), 0, brange, flat_residual, alpha=residual_scale_factor)
152
+
153
+
154
# Block-diagonal attention masks, keyed by the (batch size, sequence length)
# pairs of the tensors they were built for.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    Concatenate the tensors of ``x_list`` into one batch-of-one sequence and
    return it together with the matching (cached) block-diagonal attention bias.

    When ``branges`` is given, only the indexed samples of each tensor are
    concatenated (via ``index_select_cat``) and the batch sizes are taken from
    the index tensors instead of from ``x_list``.
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [b.shape[0] for b in branges]

    cache_key = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if cache_key not in attn_bias_cache:
        # One sequence-length entry per sample of every tensor.
        seqlens = [x.shape[1] for b, x in zip(batch_sizes, x_list) for _ in range(b)]
        mask = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        mask._batch_sizes = batch_sizes
        attn_bias_cache[cache_key] = mask

    if branges is None:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)
    else:
        flattened = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flattened, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[cache_key], cat_tensors
179
+
180
+
181
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Apply sample-level stochastic depth to several tensors in one pass.

    Args:
        x_list: Tensors to process; each gets its own random sample subset.
        residual_func: Called once on the concatenated subsets with the
            keyword argument ``attn_bias``.
        sample_drop_ratio: Fraction of each batch to leave out.
        scaling_vector: Optional per-channel scale forwarded to ``add_residual``.

    Returns:
        One output tensor per entry of ``x_list``, each shaped like its input.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) add each residual back into its source tensor
    return [
        add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)
        for x, brange, residual, residual_scale_factor in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
202
+
203
+
204
class NestedTensorBlock(Block):
    """Block that accepts either a single tensor or a list of tensors.

    A list is processed jointly: the tensors are concatenated and attention is
    restricted with a block-diagonal bias, which requires xFormers and a
    ``MemEffAttention`` attention module.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is not applied inside these residual functions; its
            # gamma is handed over as ``scaling_vector`` instead.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Gate on ls2 itself: it is ls2.gamma that is read here.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on the input type: a tensor uses ``Block.forward``, a list uses ``forward_nested``.

        Raises:
            AssertionError: If a list is given without xFormers, or the input
                is neither a tensor nor a list.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
DepthMaster/external_encoder/dinov2/dinov2_layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ from torch import nn
13
+
14
+
15
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    ``x`` is returned untouched when not training or when ``drop_prob`` is 0.
    Otherwise each entry along the first axis is kept with probability
    ``1 - drop_prob`` and, if that probability is positive, divided by it.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcast over every remaining axis.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
25
+
26
+
27
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth for residual branches)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability handed to ``drop_path`` on every forward call.
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` using this module's rate and current training flag."""
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
DepthMaster/external_encoder/dinov2/dinov2_layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import Tensor
13
+ from torch import nn
14
+
15
+
16
class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of length ``dim``.

    Args:
        dim: Length of ``gamma``.
        init_values: Initial value(s) that ``gamma`` is filled with.
        inplace: If true, the input tensor itself is scaled and returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma`` (broadcast against ``x``)."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
DepthMaster/external_encoder/dinov2/dinov2_layers/mlp.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+
14
+ from torch import Tensor, nn
15
+
16
+
17
class Mlp(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> dropout -> linear -> dropout.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of the hidden layer; falls back to ``in_features``.
        out_features: Size of the output's last dimension; falls back to ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability, shared by both dropout applications.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
DepthMaster/external_encoder/dinov2/dinov2_layers/patch_embed.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
def make_2tuple(x):
    """Return ``x`` as a pair: 2-tuples pass through, an int ``n`` becomes ``(n, n)``.

    Anything else (or a tuple of another length) trips an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
24
+
25
+
26
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If true (default) return (B,N,D); otherwise the
            tokens are reshaped to (B,H',W',D) on the patch grid.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel and stride both equal the patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Project an image batch to patch tokens; H and W must be multiples of the patch size."""
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)  # patch-grid size of the actual input
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Estimate the cost for the configured ``img_size`` (projection plus norm, if any)."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # ``self.norm`` is nn.Identity when no norm layer was requested, so a
        # plain ``is not None`` test would always charge for a normalization.
        if self.norm is not None and not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
DepthMaster/external_encoder/dinov2/dinov2_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ from torch import Tensor, nn
10
+ import torch.nn.functional as F
11
+
12
+
13
class SwiGLUFFN(nn.Module):
    """Gated feed-forward block: ``w3(silu(x1) * x2)`` where ``x1, x2`` are the two halves of ``w12(x)``.

    ``act_layer`` and ``drop`` are accepted but not used by this implementation.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        # A single projection produces both the gate half and the value half.
        self.w12 = nn.Linear(in_features, 2 * hidden_dim, bias=bias)
        self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
34
+
35
+
36
# Use the xformers SwiGLU when it can be imported; otherwise fall back to the
# local SwiGLUFFN so that subclasses of ``SwiGLU`` still work.
try:
    from xformers.ops import SwiGLU
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False
else:
    XFORMERS_AVAILABLE = True
43
+
44
+
45
class SwiGLUFFNFused(SwiGLU):
    """``SwiGLU`` subclass that shrinks the requested hidden width before construction.

    ``act_layer`` and ``drop`` are accepted but not forwarded to the base class.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        requested_hidden = hidden_features or in_features
        # Two thirds of the requested width (truncated), rounded up to a multiple of 8.
        fused_hidden = (int(requested_hidden * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=fused_hidden,
            out_features=out_features or in_features,
            bias=bias,
        )
DepthMaster/external_encoder/dinov2/util/transform.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import torch.nn.functional as F
4
+
5
+
6
class Resize(object):
    """Resize a sample dict to a target (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Store the resize configuration.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height are constrained to be a multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
            image_interpolation_method: Stored on the instance; ``__call__``
                does not read it.
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Snap ``x`` to a multiple of the configured step, respecting optional bounds."""
        step = self.__multiple_of
        ratio = x / step

        snapped = (np.round(ratio) * step).astype(int)

        # Rounding overshot the upper bound: round down instead.
        if max_val is not None and snapped > max_val:
            snapped = (np.floor(ratio) * step).astype(int)

        # Result fell below the lower bound: round up instead.
        if snapped < min_val:
            snapped = (np.ceil(ratio) * step).astype(int)

        return snapped

    def get_size(self, width, height):
        """Return the ``(new_width, new_height)`` for an input of the given size."""
        scale_height = self.__height / height
        scale_width = self.__width / width

        method = self.__resize_method
        if method not in ("lower_bound", "upper_bound", "minimal"):
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        if self.__keep_aspect_ratio:
            # Use one common factor for both axes so the aspect ratio is kept.
            if method == "lower_bound":
                # the larger factor keeps both sides at or above the target
                shared = max(scale_width, scale_height)
            elif method == "upper_bound":
                # the smaller factor keeps both sides at or below the target
                shared = min(scale_width, scale_height)
            else:
                # "minimal": the factor closest to 1 (height wins ties)
                shared = (
                    scale_width
                    if abs(1 - scale_width) < abs(1 - scale_height)
                    else scale_height
                )
            scale_width = scale_height = shared

        if method == "lower_bound":
            height_limits = {"min_val": self.__height}
            width_limits = {"min_val": self.__width}
        elif method == "upper_bound":
            height_limits = {"max_val": self.__height}
            width_limits = {"max_val": self.__width}
        else:
            height_limits, width_limits = {}, {}

        new_height = self.constrain_to_multiple_of(scale_height * height, **height_limits)
        new_width = self.constrain_to_multiple_of(scale_width * width, **width_limits)

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize ``sample`` in place and return ``(sample, (height, width))``.

        The image goes through ``F.interpolate`` (bilinear), so it is
        presumably a batched torch tensor with H and W as the last two
        dimensions — TODO confirm against callers. "depth" and "mask", when
        present and ``resize_target`` is set, are resized with
        ``cv2.resize`` using nearest-neighbour interpolation.
        """
        image = sample["image"]
        width, height = self.get_size(image.shape[-1], image.shape[-2])

        sample["image"] = F.interpolate(image, size=(height, width), mode='bilinear', align_corners=False)

        if self.__resize_target:
            if "depth" in sample:
                sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)

            if "mask" in sample:
                sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)

        return sample, (height, width)
125
+
126
+
127
class NormalizeImage(object):
    """Normalize the sample's image with a given mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        # Only the "image" entry is touched; the same dict is returned.
        centered = sample["image"] - self.__mean
        sample["image"] = centered / self.__std
        return sample
139
+
140
+
141
class PrepareForNet(object):
    """Convert the arrays of a sample into contiguous float32 network inputs."""

    def __init__(self):
        """No configuration is needed."""

    def __call__(self, sample):
        # Move the image's last axis to the front (axes reordered as (2, 0, 1)).
        channels_first = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(channels_first).astype(np.float32)

        # Optional entries are converted only when present.
        if "depth" in sample:
            sample["depth"] = np.ascontiguousarray(sample["depth"].astype(np.float32))

        if "mask" in sample:
            sample["mask"] = np.ascontiguousarray(sample["mask"].astype(np.float32))

        return sample
DepthMaster/in_the_wild_example/input/06.jpg ADDED
DepthMaster/scripts/eval_diode.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run evaluate.py with the DIODE dataset config and the checkpoint in ckpt/eval.
# NOTE: path/to/basedata looks like a placeholder; point it at the dataset root.
set -e  # stop on the first failing command
set -x  # echo each command before running it

# Restrict the run to GPU index 5.
export CUDA_VISIBLE_DEVICES=5
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_diode_all.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/diode/final \
    --checkpoint ckpt/eval \
    --processing_res 640 \
    --seed 1234
DepthMaster/scripts/eval_eth3d.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run evaluate.py with the ETH3D dataset config and the checkpoint in ckpt/eval.
# NOTE: path/to/basedata looks like a placeholder; point it at the dataset root.
set -e  # stop on the first failing command
set -x  # echo each command before running it

# Restrict the run to GPU index 5.
export CUDA_VISIBLE_DEVICES=5
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_eth3d.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/eth3d/final \
    --checkpoint ckpt/eval \
    --processing_res 756 \
    --seed 1234
DepthMaster/scripts/eval_hypersim.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run evaluate.py with the Hypersim test config and the checkpoint in ckpt/eval.
# NOTE: path/to/basedata looks like a placeholder; point it at the dataset root.
set -e  # stop on the first failing command
set -x  # echo each command before running it

# Restrict the run to GPU index 5.
export CUDA_VISIBLE_DEVICES=5
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_hypersim_test.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/hypersim/final \
    --checkpoint ckpt/eval \
    --processing_res 0 \
    --seed 1234
DepthMaster/scripts/eval_kitti.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run evaluate.py with the KITTI Eigen test config and the checkpoint in ckpt/eval.
# NOTE: path/to/basedata looks like a placeholder; point it at the dataset root.
set -e  # stop on the first failing command
set -x  # echo each command before running it

# Restrict the run to GPU index 5.
export CUDA_VISIBLE_DEVICES=5
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_kitti_eigen_test.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/kitti/final \
    --checkpoint ckpt/eval \
    --processing_res 0 \
    --seed 1234
DepthMaster/scripts/eval_nyu.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run evaluate.py with the NYU test config and the checkpoint in ckpt/eval.
# NOTE: path/to/basedata looks like a placeholder; point it at the dataset root.
set -e  # stop on the first failing command
set -x  # echo each command before running it

# Restrict the run to GPU index 5.
export CUDA_VISIBLE_DEVICES=5
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_nyu_test.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/nyu/final1 \
    --checkpoint ckpt/eval \
    --processing_res 0 \
    --seed 1234
DepthMaster/scripts/eval_scannet.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run evaluate.py with the ScanNet val config and the checkpoint in ckpt/eval.
# NOTE: path/to/basedata looks like a placeholder; point it at the dataset root.
set -e  # stop on the first failing command
set -x  # echo each command before running it

# Restrict the run to GPU index 5.
export CUDA_VISIBLE_DEVICES=5
python evaluate.py \
    --base_data_dir path/to/basedata \
    --dataset_config config/dataset/data_scannet_val.yaml \
    --alignment least_square_sqrt_disp \
    --output_dir output/scannet/final \
    --checkpoint ckpt/eval \
    --processing_res 0 \
    --seed 1234
DepthMaster/scripts/infer.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run run.py on the images in in_the_wild_example/input using the checkpoint
# in ckpt/eval; results go to in_the_wild_example/output/final.
set -e  # stop on the first failing command
set -x  # echo each command before running it

# Restrict the run to GPU index 5.
export CUDA_VISIBLE_DEVICES=5
python run.py \
    --checkpoint ckpt/eval \
    --processing_res 768 \
    --input_rgb_dir in_the_wild_example/input \
    --output_dir in_the_wild_example/output/final
DepthMaster/scripts/train_s1.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch train_s1.py with config/train_s1.yaml; logs go to log/stage1_bs8.
set -e  # stop on the first failing command

# NOTE: both values look like placeholders; set them to the real locations.
BASE_DATA_DIR="path/to/basedata"
BASE_CKPT_DIR="path/to/sd2_ckpt"

# Restrict the run to GPU index 3.
export CUDA_VISIBLE_DEVICES=3
# Variables are quoted so paths containing spaces stay a single argument.
python train_s1.py --config config/train_s1.yaml \
    --base_data_dir "$BASE_DATA_DIR" \
    --base_ckpt_dir "$BASE_CKPT_DIR" \
    --output_dir log/stage1_bs8 \
    --no_wandb
DepthMaster/scripts/train_s2.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch train_s2.py with config/train_s2.yaml; logs go to log/stage2.
set -e  # stop on the first failing command

# NOTE(review): /zhdd/dataset looks machine-specific; adjust for your setup.
BASE_DATA_DIR="/zhdd/dataset"
BASE_CKPT_DIR="ori_ckpt"

# Restrict the run to GPU index 2.
export CUDA_VISIBLE_DEVICES=2
# Variables are quoted so paths containing spaces stay a single argument.
python train_s2.py --config config/train_s2.yaml \
    --base_data_dir "$BASE_DATA_DIR" \
    --base_ckpt_dir "$BASE_CKPT_DIR" \
    --output_dir log/stage2 \
    --no_wandb
DepthMaster/src/dataset/__init__.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import os
26
+
27
+ from .base_depth_dataset import BaseDepthDataset, get_pred_name, DatasetMode # noqa: F401
28
+ from .diode_dataset import DIODEDataset
29
+ from .eth3d_dataset import ETH3DDataset
30
+ from .hypersim_dataset import HypersimDataset
31
+ from .kitti_dataset import KITTIDataset
32
+ from .nyu_dataset import NYUDataset
33
+ from .scannet_dataset import ScanNetDataset
34
+ from .vkitti_dataset import VirtualKITTIDataset
35
+
36
+
37
# Maps the ``name`` field of a dataset config to the dataset class built for it.
dataset_name_class_dict = dict(
    hypersim=HypersimDataset,
    vkitti=VirtualKITTIDataset,
    nyu_v2=NYUDataset,
    kitti=KITTIDataset,
    eth3d=ETH3DDataset,
    diode=DIODEDataset,
    scannet=ScanNetDataset,
)
46
+
47
+
48
def get_dataset(
    cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs
) -> "BaseDepthDataset | list[BaseDepthDataset]":
    """Build the dataset (or list of datasets) described by ``cfg_data_split``.

    Args:
        cfg_data_split: Config object with a ``name`` attribute. For the name
            "mixed" it must provide ``dataset_list`` (sub-configs); otherwise
            it must provide ``filenames`` and ``dir`` and support ``**``
            unpacking, because all of its entries are forwarded to the
            dataset constructor.
        base_data_dir: Root directory that ``cfg_data_split.dir`` is joined onto.
        mode: Dataset mode handed to the dataset class.
        **kwargs: Extra keyword arguments forwarded to the dataset constructor.

    Returns:
        A single dataset, or a list of datasets when the config name is "mixed".

    Raises:
        NotImplementedError: If the config name is not a registered dataset.
    """
    if "mixed" == cfg_data_split.name:
        assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets."
        # Each sub-config is resolved recursively with the same arguments.
        return [
            get_dataset(_cfg, base_data_dir, mode, **kwargs)
            for _cfg in cfg_data_split.dataset_list
        ]

    dataset_class = dataset_name_class_dict.get(cfg_data_split.name)
    if dataset_class is None:
        raise NotImplementedError(
            f"Unknown dataset name {cfg_data_split.name!r}; "
            f"expected 'mixed' or one of {sorted(dataset_name_class_dict)}"
        )

    return dataset_class(
        mode=mode,
        filename_ls_path=cfg_data_split.filenames,
        dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir),
        **cfg_data_split,
        **kwargs,
    )
DepthMaster/src/dataset/base_depth_dataset.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import io
26
+ import os
27
+ import random
28
+ import tarfile
29
+ from enum import Enum
30
+ from typing import Union
31
+
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+ from torch.utils.data import Dataset
36
+ from torchvision.transforms import InterpolationMode, Resize, RandomResizedCrop
37
+
38
+ from src.util.depth_transform import DepthNormalizerBase
39
+ from src.util.alignment import depth2disparity
40
+
41
+
42
class DatasetMode(Enum):
    """Selects how much data a depth dataset loads and how it is processed."""

    # Only RGB rasters are loaded; depth loading is skipped.
    RGB_ONLY = "rgb_only"
    # RGB and depth are loaded; no training preprocessing is applied.
    EVAL = "evaluate"
    # RGB and depth are loaded and the training preprocessing is applied.
    TRAIN = "train"
46
+
47
+
48
class DepthFileNameMode(Enum):
    """Naming schemes of the RGB files, used to derive prediction file names."""

    id = 1  # id.png
    rgb_id = 2  # rgb_id.png
    i_d_rgb = 3  # i_d_1_rgb.png
    rgb_i_d = 4  # presumably rgb_i_d.png — TODO confirm
55
+
56
+
57
def read_image_from_tar(tar_obj, img_rel_path):
    """Open an image stored inside an already opened tar archive.

    Args:
        tar_obj: Open ``tarfile.TarFile`` to read from.
        img_rel_path: Path of the image relative to the archive root; it is
            looked up as ``"./" + img_rel_path``.

    Returns:
        The PIL image, backed by an in-memory copy of the member's bytes.
    """
    member = tar_obj.extractfile("./" + img_rel_path)
    payload = member.read()
    # The image was previously built but never returned, so callers got None.
    return Image.open(io.BytesIO(payload))
61
+
62
+
63
class BaseDepthDataset(Dataset):
    """Base class for RGB + depth datasets listed in a filename text file.

    Each line of the filename list is split on whitespace into
    ``[rgb_path, depth_path, (filled_depth_path)]``, all relative to
    ``dataset_dir``, which is either a directory or a tar archive.
    Subclasses are expected to override ``_read_depth_file`` to decode depth
    according to the dataset definition.
    """

    def __init__(
        self,
        mode: DatasetMode,
        filename_ls_path: str,
        dataset_dir: str,
        disp_name: str,
        min_depth: float,
        max_depth: float,
        has_filled_depth: bool,
        has_egde_mask: bool,
        name_mode: DepthFileNameMode,
        depth_transform: Union[DepthNormalizerBase, None] = None,
        augmentation_args: dict = None,
        resize_to_hw=None,
        move_invalid_to_far_plane: bool = True,
        rgb_transform=lambda x: x / 255.0 * 2 - 1,  #  [0, 255] -> [-1, 1],
        **kwargs,
    ) -> None:
        """Store the configuration and read the filename list.

        Args:
            mode: Controls whether depth is loaded and whether training
                preprocessing runs.
            filename_ls_path: Text file with one sample per line.
            dataset_dir: Dataset directory or tar archive; must exist.
            disp_name: Stored on the instance; not read inside this class.
            min_depth: Depth values must be strictly above this to be valid.
            max_depth: Depth values must be strictly below this to be valid.
            has_filled_depth: Whether each line carries a third, filled-depth path.
            has_egde_mask: Stored on the instance; not read inside this class.
            name_mode: Stored on the instance; not read inside this class.
            depth_transform: Callable taking ``(depth, valid_mask)``; also
                provides ``far_plane_at_max``, ``norm_max`` and ``norm_min``.
                Needed in training mode.
            augmentation_args: Object exposing ``lr_flip_p``; ``None``
                disables augmentation.
            resize_to_hw: Target size for training-time resizing, or ``None``.
            move_invalid_to_far_plane: Overwrite invalid pixels of the filled,
                normalized depth with the far-plane value.
            rgb_transform: Stored on the instance.
                NOTE(review): ``_load_rgb_data`` applies its own hard-coded
                [0, 255] -> [-1, 1] mapping and never calls this.
            **kwargs: Ignored.
        """
        super().__init__()
        self.mode = mode
        # dataset info
        self.filename_ls_path = filename_ls_path
        self.dataset_dir = dataset_dir
        assert os.path.exists(
            self.dataset_dir
        ), f"Dataset does not exist at: {self.dataset_dir}"
        self.disp_name = disp_name
        self.has_filled_depth = has_filled_depth
        self.has_egde_mask = has_egde_mask
        self.name_mode: DepthFileNameMode = name_mode
        self.min_depth = min_depth
        self.max_depth = max_depth

        # training arguments
        self.depth_transform: DepthNormalizerBase = depth_transform
        self.augm_args = augmentation_args
        self.resize_to_hw = resize_to_hw
        self.rgb_transform = rgb_transform
        self.move_invalid_to_far_plane = move_invalid_to_far_plane

        # Load filenames: [['rgb.png', 'depth.tif'], [], ...]
        # Blank lines (e.g. a trailing newline) are skipped; they would
        # otherwise become empty entries that fail with IndexError on access.
        with open(self.filename_ls_path, "r") as f:
            self.filenames = [line.split() for line in f if line.strip()]

        # Tar dataset: the archive is opened lazily on the first read.
        self.tar_obj = None
        self.is_tar = os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Return one sample dict: rasters plus ``index`` and ``rgb_relative_path``."""
        rasters, other = self._get_data_item(index)
        if DatasetMode.TRAIN == self.mode:
            rasters = self._training_preprocess(rasters)
        # merge
        outputs = rasters
        outputs.update(other)
        return outputs

    def _get_data_item(self, index):
        """Load the rasters of one sample and its bookkeeping info."""
        rgb_rel_path, depth_rel_path, filled_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Depth data
        if DatasetMode.RGB_ONLY != self.mode:
            # load data
            depth_data = self._load_depth_data(
                depth_rel_path=depth_rel_path, filled_rel_path=filled_rel_path
            )
            rasters.update(depth_data)
            # valid mask (computed on the loaded depth, before any conversion below)
            rasters["valid_mask_raw"] = self._get_valid_mask(
                rasters["depth_raw_linear"]
            ).clone()
            rasters["valid_mask_filled"] = self._get_valid_mask(
                rasters["depth_filled_linear"]
            ).clone()

        if DatasetMode.TRAIN == self.mode:
            # In training the "*_linear" entries are replaced by
            # sqrt(depth2disparity(depth)); the keys keep their names.
            # NOTE(review): without filled depth, "depth_filled_linear" stays
            # the unconverted copy of the raw depth — confirm this is intended.
            rasters["depth_raw_linear"] = depth2disparity(rasters["depth_raw_linear"]).clone()
            if self.has_filled_depth:
                rasters["depth_filled_linear"] = depth2disparity(rasters["depth_filled_linear"]).clone()

            # sqrt(x)
            rasters["depth_raw_linear"] = torch.sqrt(rasters["depth_raw_linear"]).clone()
            if self.has_filled_depth:
                rasters["depth_filled_linear"] = torch.sqrt(rasters["depth_filled_linear"]).clone()

        other = {"index": index, "rgb_relative_path": rgb_rel_path}

        return rasters, other

    def _load_rgb_data(self, rgb_rel_path):
        """Read the RGB image and return it as int ("rgb_int") and [-1, 1] float ("rgb_norm") tensors."""
        rgb = self._read_rgb_file(rgb_rel_path)
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  #  [0, 255] -> [-1, 1]

        outputs = {
            "rgb_int": torch.from_numpy(rgb).int(),
            "rgb_norm": torch.from_numpy(rgb_norm).float(),
        }
        return outputs

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        """Read raw (and, if available, filled) depth as float tensors with a leading channel axis."""
        outputs = {}
        depth_raw = self._read_depth_file(depth_rel_path).squeeze()
        depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0)  # [1, H, W]
        outputs["depth_raw_linear"] = depth_raw_linear.clone()

        if self.has_filled_depth:
            depth_filled = self._read_depth_file(filled_rel_path).squeeze()
            depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0)
            outputs["depth_filled_linear"] = depth_filled_linear
        else:
            # No filled depth available: reuse a copy of the raw depth.
            outputs["depth_filled_linear"] = depth_raw_linear.clone()

        return outputs

    def _get_data_path(self, index):
        """Return ``(rgb, depth, filled_depth)`` relative paths; unused ones are ``None``."""
        filename_line = self.filenames[index]

        # Get data path
        rgb_rel_path = filename_line[0]

        depth_rel_path, filled_rel_path = None, None
        if DatasetMode.RGB_ONLY != self.mode:
            depth_rel_path = filename_line[1]
            if self.has_filled_depth:
                filled_rel_path = filename_line[2]
        return rgb_rel_path, depth_rel_path, filled_rel_path

    def _read_image(self, img_rel_path) -> np.ndarray:
        """Read an image from the dataset directory or tar archive as a numpy array."""
        if self.is_tar:
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            image_to_read = self.tar_obj.extractfile("./" + img_rel_path)
            image_to_read = image_to_read.read()
            image_to_read = io.BytesIO(image_to_read)
        else:
            image_to_read = os.path.join(self.dataset_dir, img_rel_path)
        image = Image.open(image_to_read)  # [H, W, rgb]
        image = np.asarray(image)
        return image

    def _read_rgb_file(self, rel_path) -> np.ndarray:
        """Read an RGB image as an int array laid out as [rgb, H, W]."""
        rgb = self._read_image(rel_path)
        rgb = np.transpose(rgb, (2, 0, 1)).astype(int)  # [rgb, H, W]
        return rgb

    def _read_depth_file(self, rel_path):
        """Read a depth file; the base implementation returns the image array undecoded."""
        depth_in = self._read_image(rel_path)
        # Replace code below to decode depth according to dataset definition
        depth_decoded = depth_in

        return depth_decoded

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return a boolean mask of pixels strictly between ``min_depth`` and ``max_depth``."""
        valid_mask = torch.logical_and(
            (depth > self.min_depth), (depth < self.max_depth)
        ).bool()
        return valid_mask

    def _training_preprocess(self, rasters):
        """Augment, normalize and optionally resize the rasters of a training sample."""
        # Augmentation
        if self.augm_args is not None:
            rasters = self._augment_data(rasters)

        # Normalization
        rasters["depth_raw_norm"] = self.depth_transform(
            rasters["depth_raw_linear"], rasters["valid_mask_raw"]
        ).clone()
        rasters["depth_filled_norm"] = self.depth_transform(
            rasters["depth_filled_linear"], rasters["valid_mask_filled"]
        ).clone()

        # Set invalid pixel to far plane
        if self.move_invalid_to_far_plane:
            if self.depth_transform.far_plane_at_max:
                rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
                    self.depth_transform.norm_max
                )
            else:
                rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
                    self.depth_transform.norm_min
                )

        # Resize: the same nearest-exact resize is applied to every raster.
        if self.resize_to_hw is not None:
            resize_transform = Resize(
                size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT
            )
            rasters = {k: resize_transform(v) for k, v in rasters.items()}

        return rasters

    def _augment_data(self, rasters_dict):
        """Flip every raster left-right together with probability ``augm_args.lr_flip_p``."""
        lr_flip_p = self.augm_args.lr_flip_p
        if random.random() < lr_flip_p:
            rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}

        return rasters_dict

    def __del__(self):
        # ``tar_obj`` may be missing if ``__init__`` failed early.
        if hasattr(self, "tar_obj") and self.tar_obj is not None:
            self.tar_obj.close()
            self.tar_obj = None
287
+
288
+
289
def get_pred_name(rgb_basename, name_mode, suffix=".png"):
    """Derive the prediction file name from an RGB file name.

    Args:
        rgb_basename: Base name of the RGB file.
        name_mode: ``DepthFileNameMode`` describing how the RGB file is named.
        suffix: Extension that replaces the original one.

    Raises:
        NotImplementedError: For an unsupported ``name_mode``.
    """

    def _swap_extension(name):
        # change suffix
        return os.path.splitext(name)[0] + suffix

    if DepthFileNameMode.rgb_id == name_mode:
        # keep only the second "_"-separated field
        return _swap_extension("pred_" + rgb_basename.split("_")[1])
    if DepthFileNameMode.i_d_rgb == name_mode:
        return _swap_extension(rgb_basename.replace("_rgb.", "_pred."))
    if DepthFileNameMode.id == name_mode:
        return _swap_extension("pred_" + rgb_basename)
    if DepthFileNameMode.rgb_i_d == name_mode:
        # drop the first "_"-separated field, keep the rest
        return _swap_extension("pred_" + "_".join(rgb_basename.split("_")[1:]))
    raise NotImplementedError
DepthMaster/src/dataset/diode_dataset.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import os
26
+ import tarfile
27
+ from io import BytesIO
28
+
29
+ import numpy as np
30
+ import torch
31
+
32
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode, DatasetMode
33
+
34
+
35
class DIODEDataset(BaseDepthDataset):
    """DIODE depth dataset; depth maps and validity masks are ``.npy`` files."""

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward ``kwargs`` to the base class together with DIODE settings."""
        super().__init__(
            # DIODE data parameter
            min_depth=0.6,
            max_depth=350,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )

    def _read_npy_file(self, rel_path):
        """Load a ``.npy`` array and return it with a leading singleton axis.

        The array is read from the tar archive (opened lazily and cached on
        ``self.tar_obj``) when ``self.is_tar`` is set, otherwise from
        ``self.dataset_dir``. Singleton axes are squeezed out before the new
        leading axis is added, so the squeezed array must be 2-D.
        """
        if not self.is_tar:
            source = os.path.join(self.dataset_dir, rel_path)
        else:
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            member = self.tar_obj.extractfile("./" + rel_path)
            # np.load needs a seekable file-like object, hence the copy.
            source = BytesIO(member.read())

        array = np.load(source).squeeze()
        return array[np.newaxis, :, :]

    def _read_depth_file(self, rel_path):
        """Read a depth map; DIODE depth needs no decoding beyond loading."""
        return self._read_npy_file(rel_path)

    def _get_data_path(self, index):
        """Return the file-list entry for ``index``."""
        return self.filenames[index]

    def _get_data_item(self, index):
        """Assemble the rasters and metadata of one sample.

        Special: the validity mask is read from its own file instead of being
        derived from the depth values.

        Returns:
            ``(rasters, other)`` where ``other`` holds the sample index and
            the relative RGB path.
        """
        rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index)

        # RGB data
        rasters = {}
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Depth data and mask are skipped entirely in RGB-only mode.
        if self.mode != DatasetMode.RGB_ONLY:
            rasters.update(
                self._load_depth_data(
                    depth_rel_path=depth_rel_path, filled_rel_path=None
                )
            )

            # valid mask
            mask_array = self._read_npy_file(mask_rel_path).astype(bool)
            mask_tensor = torch.from_numpy(mask_array).bool()
            rasters["valid_mask_raw"] = mask_tensor.clone()
            rasters["valid_mask_filled"] = mask_tensor.clone()

        other = {"index": index, "rgb_relative_path": rgb_rel_path}
        return rasters, other
DepthMaster/src/dataset/eth3d_dataset.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import torch
26
+ import tarfile
27
+ import os
28
+ import numpy as np
29
+
30
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
31
+
32
+
33
class ETH3DDataset(BaseDepthDataset):
    """ETH3D depth dataset; depth maps are raw float32 binary dumps."""

    # Size used to reshape the flat depth buffer.
    HEIGHT = 4032
    WIDTH = 6048

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward ``kwargs`` to the base class together with ETH3D settings."""
        super().__init__(
            # ETH3D data parameter
            min_depth=1e-5,
            max_depth=torch.inf,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )

    def _read_depth_file(self, rel_path):
        """Read one depth map as a ``(HEIGHT, WIDTH)`` float32 array.

        Read special binary data:
        https://www.eth3d.net/documentation#format-of-multi-view-data-image-formats

        The bytes come from the tar archive (opened lazily and cached on
        ``self.tar_obj``) when ``self.is_tar`` is set, otherwise from
        ``self.dataset_dir``. Infinite values are replaced by ``0.0``.
        """
        if self.is_tar:
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            raw_bytes = self.tar_obj.extractfile("./" + rel_path).read()
        else:
            with open(os.path.join(self.dataset_dir, rel_path), "rb") as stream:
                raw_bytes = stream.read()

        # frombuffer returns a read-only view; copy so it can be modified.
        depth = np.frombuffer(raw_bytes, dtype=np.float32).copy()
        depth[depth == torch.inf] = 0.0
        return depth.reshape((self.HEIGHT, self.WIDTH))
DepthMaster/src/dataset/hypersim_dataset.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+
26
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
27
+
28
+
29
class HypersimDataset(BaseDepthDataset):
    """Hypersim depth dataset."""

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward ``kwargs`` to the base class together with Hypersim settings."""
        super().__init__(
            # Hypersim data parameter
            min_depth=1e-5,
            max_depth=65.0,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.rgb_i_d,
            **kwargs,
        )

    def _read_depth_file(self, rel_path):
        """Read a depth image and decode it by dividing the raw values by 1000.

        NOTE(review): presumably converts millimetres to metres -- confirm
        against the dataset preparation script.
        """
        encoded = self._read_image(rel_path)
        return encoded / 1000.0
DepthMaster/src/dataset/kitti_dataset.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import torch
26
+
27
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
28
+
29
+
30
class KITTIDataset(BaseDepthDataset):
    """KITTI depth dataset with optional benchmark crop and evaluation mask."""

    def __init__(
        self,
        kitti_bm_crop,  # Crop to KITTI benchmark size
        valid_mask_crop,  # Evaluation mask. [None, garg or eigen]
        **kwargs,
    ) -> None:
        """Configure the dataset.

        Args:
            kitti_bm_crop: When truthy, RGB and depth rasters are cropped with
                ``kitti_benchmark_crop``.
            valid_mask_crop: ``None``, ``"garg"`` or ``"eigen"``; selects the
                evaluation region combined into the valid mask.
            **kwargs: Forwarded to ``BaseDepthDataset``.
        """
        super().__init__(
            # KITTI data parameter
            min_depth=1e-5,
            max_depth=80,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )
        self.kitti_bm_crop = kitti_bm_crop
        self.valid_mask_crop = valid_mask_crop
        assert self.valid_mask_crop in [
            None,
            "garg",  # set evaluation mask according to Garg ECCV16
            "eigen",  # set evaluation mask according to Eigen NIPS14
        ], f"Unknown crop type: {self.valid_mask_crop}"

        # Filter out empty depth (entries whose second field is the text "None")
        self.filenames = [f for f in self.filenames if "None" != f[1]]

    def _read_depth_file(self, rel_path):
        """Read a depth image and decode it by dividing the raw values by 256."""
        depth_in = self._read_image(rel_path)
        # Decode KITTI depth
        depth_decoded = depth_in / 256.0
        return depth_decoded

    def _load_rgb_data(self, rgb_rel_path):
        """Load RGB rasters, applying the benchmark crop when enabled."""
        rgb_data = super()._load_rgb_data(rgb_rel_path)
        if self.kitti_bm_crop:
            rgb_data = {k: self.kitti_benchmark_crop(v) for k, v in rgb_data.items()}
        return rgb_data

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        """Load depth rasters, applying the benchmark crop when enabled."""
        depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path)
        if self.kitti_bm_crop:
            depth_data = {
                k: self.kitti_benchmark_crop(v) for k, v in depth_data.items()
            }
        return depth_data

    @staticmethod
    def kitti_benchmark_crop(input_img):
        """
        Crop images to KITTI benchmark size (352 x 1216).

        The crop is anchored to the bottom edge and centred horizontally. It
        acts on the last two dimensions, so any number of leading dimensions
        (none, channel, batch, ...) is accepted.

        Args:
            `input_img` (torch.Tensor): Input image to be cropped, shaped
                [..., H, W].

        Returns:
            torch.Tensor: Cropped image.
        """
        KB_CROP_HEIGHT = 352
        KB_CROP_WIDTH = 1216

        height, width = input_img.shape[-2:]
        top_margin = int(height - KB_CROP_HEIGHT)
        left_margin = int((width - KB_CROP_WIDTH) / 2)
        # Ellipsis indexing covers the former 2-D and 3-D branches and avoids
        # the unbound result that any other rank used to produce.
        return input_img[
            ...,
            top_margin : top_margin + KB_CROP_HEIGHT,
            left_margin : left_margin + KB_CROP_WIDTH,
        ]

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return the base valid mask restricted to the evaluation region.

        When ``valid_mask_crop`` is ``None`` the base-class mask is returned
        unchanged; otherwise only pixels inside the Garg or Eigen rectangle
        stay valid.
        """
        # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py
        valid_mask = super()._get_valid_mask(depth)  # [1, H, W]

        if self.valid_mask_crop is not None:
            eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
            gt_height, gt_width = eval_mask.shape

            if "garg" == self.valid_mask_crop:
                eval_mask[
                    int(0.40810811 * gt_height) : int(0.99189189 * gt_height),
                    int(0.03594771 * gt_width) : int(0.96405229 * gt_width),
                ] = 1
            elif "eigen" == self.valid_mask_crop:
                eval_mask[
                    int(0.3324324 * gt_height) : int(0.91351351 * gt_height),
                    int(0.0359477 * gt_width) : int(0.96405229 * gt_width),
                ] = 1

            # reshape() returns a new tensor; keep it so both masks have the
            # same shape instead of relying on broadcasting.
            eval_mask = eval_mask.reshape(valid_mask.shape)
            valid_mask = torch.logical_and(valid_mask, eval_mask)
        return valid_mask
DepthMaster/src/dataset/mixed_sampler.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import torch
26
+ from torch.utils.data import (
27
+ BatchSampler,
28
+ RandomSampler,
29
+ SequentialSampler,
30
+ )
31
+
32
+
33
class MixedBatchSampler(BatchSampler):
    """Sample one batch from a selected dataset with given probability.
    Compatible with datasets at different resolution

    Every yielded batch contains indices of a single source dataset, shifted
    so that they address the concatenation of ``src_dataset_ls``.
    """

    def __init__(
        self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None
    ):
        """Build one batch sampler per source dataset.

        Args:
            src_dataset_ls: Sized datasets, in concatenation order.
            batch_size: Number of indices per batch.
            drop_last: Whether each dataset drops its last incomplete batch.
            shuffle: Use random (without replacement) instead of sequential
                sampling inside each dataset.
            prob: Optional per-dataset sampling weights. Integer weights are
                converted to float. When omitted, each dataset is weighted by
                its number of batches.
            generator: Optional ``torch.Generator`` used for shuffling and for
                choosing the dataset of each batch.
        """
        self.base_sampler = None
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.generator = generator

        self.src_dataset_ls = src_dataset_ls
        self.n_dataset = len(self.src_dataset_ls)

        # Dataset length
        self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
        self.cum_dataset_length = [
            sum(self.dataset_length[:i]) for i in range(self.n_dataset)
        ]  # cumulative dataset length

        # BatchSamplers for each source dataset
        self.src_batch_samplers = [
            BatchSampler(
                sampler=self._make_index_sampler(ds),
                batch_size=self.batch_size,
                drop_last=self.drop_last,
            )
            for ds in self.src_dataset_ls
        ]
        self.raw_batches = [
            list(bs) for bs in self.src_batch_samplers
        ]  # index in original dataset
        self.n_batches = [len(b) for b in self.raw_batches]
        self.n_total_batch = sum(self.n_batches)

        # sampling probability
        if prob is None:
            # if not given, decide by dataset length
            self.prob = torch.tensor(self.n_batches) / self.n_total_batch
        else:
            prob = torch.as_tensor(prob)
            # torch.multinomial rejects integer tensors, so weights such as
            # [1, 3] must be converted; float inputs are left untouched.
            if not prob.is_floating_point():
                prob = prob.float()
            self.prob = prob

    def _make_index_sampler(self, dataset):
        """Return the per-dataset index sampler selected by ``self.shuffle``."""
        if self.shuffle:
            return RandomSampler(dataset, replacement=False, generator=self.generator)
        return SequentialSampler(dataset)

    def __iter__(self):
        """Iterate over ``len(self)`` batches of concatenated-dataset indices.

        For each batch a source dataset is drawn according to ``self.prob``.
        Its remaining batches are consumed from the end of the list, and the
        list is regenerated once it runs empty.

        Yields:
            list(int): a batch of indices, corresponding to ConcatDataset of src_dataset_ls
        """
        for _ in range(self.n_total_batch):
            idx_ds = torch.multinomial(
                self.prob, 1, replacement=True, generator=self.generator
            ).item()
            # if batch list is empty, generate new list
            if 0 == len(self.raw_batches[idx_ds]):
                self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])
            # get a batch from list
            batch_raw = self.raw_batches[idx_ds].pop()
            # shift by cumulative dataset length
            shift = self.cum_dataset_length[idx_ds]
            batch = [n + shift for n in batch_raw]

            yield batch

    def __len__(self):
        """Return the total number of batches over all source datasets."""
        return self.n_total_batch
113
+
114
+
115
+ # Unit test
116
if __name__ == "__main__":
    # Manual smoke test: print the batches drawn from three toy datasets.
    from torch.utils.data import ConcatDataset, DataLoader, Dataset

    class SimpleDataset(Dataset):
        """Dataset of ``length`` consecutive integers beginning at ``start``."""

        def __init__(self, start, length) -> None:
            super().__init__()
            self.start = start
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, index):
            return self.start + index

    sources = [
        SimpleDataset(0, 10),
        SimpleDataset(200, 20),
        SimpleDataset(1000, 50),
    ]

    # will directly concatenate
    concat_dataset = ConcatDataset(sources)

    mixed_sampler = MixedBatchSampler(
        src_dataset_ls=sources,
        batch_size=4,
        drop_last=True,
        shuffle=False,
        prob=[0.6, 0.3, 0.1],
        generator=torch.Generator().manual_seed(0),
    )

    for batch in DataLoader(concat_dataset, batch_sampler=mixed_sampler):
        print(batch)
DepthMaster/src/dataset/nyu_dataset.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import torch
26
+
27
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
28
+
29
+
30
class NYUDataset(BaseDepthDataset):
    """NYUv2 depth dataset with an optional Eigen evaluation crop."""

    def __init__(
        self,
        eigen_valid_mask: bool,
        **kwargs,
    ) -> None:
        """Configure the dataset.

        Args:
            eigen_valid_mask: When true, the valid mask is restricted to the
                Eigen evaluation rectangle (rows 45-470, columns 41-600).
            **kwargs: Forwarded to ``BaseDepthDataset``.
        """
        super().__init__(
            # NYUv2 dataset parameter
            min_depth=1e-3,
            max_depth=10.0,
            has_filled_depth=True,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.rgb_id,
            **kwargs,
        )

        self.eigen_valid_mask = eigen_valid_mask

    def _read_depth_file(self, rel_path):
        """Read a depth image and decode it by dividing the raw values by 1000."""
        depth_in = self._read_image(rel_path)
        # Decode NYU depth
        depth_decoded = depth_in / 1000.0
        return depth_decoded

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return the base valid mask, optionally limited to the Eigen crop."""
        valid_mask = super()._get_valid_mask(depth)

        # Eigen crop for evaluation
        if self.eigen_valid_mask:
            eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
            eval_mask[45:471, 41:601] = 1
            # reshape() returns a new tensor; keep it so both masks have the
            # same shape instead of relying on broadcasting.
            eval_mask = eval_mask.reshape(valid_mask.shape)
            valid_mask = torch.logical_and(valid_mask, eval_mask)

        return valid_mask
DepthMaster/src/dataset/scannet_dataset.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
26
+
27
+
28
class ScanNetDataset(BaseDepthDataset):
    """ScanNet depth dataset."""

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Forward ``kwargs`` to the base class together with ScanNet settings."""
        super().__init__(
            # ScanNet data parameter
            min_depth=1e-3,
            max_depth=10,
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )

    def _read_depth_file(self, rel_path):
        """Read a depth image and decode it by dividing the raw values by 1000.

        NOTE(review): presumably converts millimetres to metres -- confirm
        against the dataset preparation script.
        """
        encoded = self._read_image(rel_path)
        return encoded / 1000.0
DepthMaster/src/dataset/vkitti_dataset.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+ import torch
25
+
26
+ from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode
27
+ from .kitti_dataset import KITTIDataset
28
+
29
+
30
class VirtualKITTIDataset(BaseDepthDataset):
    """Virtual KITTI depth dataset; mirrors ``KITTIDataset`` cropping and masks."""

    def __init__(
        self,
        kitti_bm_crop,  # Crop to KITTI benchmark size
        valid_mask_crop,  # Evaluation mask. [None, garg or eigen]
        **kwargs,
    ) -> None:
        """Configure the dataset.

        Args:
            kitti_bm_crop: When truthy, RGB and depth rasters are cropped with
                ``KITTIDataset.kitti_benchmark_crop``.
            valid_mask_crop: ``None``, ``"garg"`` or ``"eigen"``; selects the
                evaluation region combined into the valid mask.
            **kwargs: Forwarded to ``BaseDepthDataset``.
        """
        super().__init__(
            # virtual KITTI data parameter
            min_depth=1e-5,
            max_depth=80,  # 655.35
            has_filled_depth=False,
            has_egde_mask=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )
        self.kitti_bm_crop = kitti_bm_crop
        self.valid_mask_crop = valid_mask_crop
        assert self.valid_mask_crop in [
            None,
            "garg",  # set evaluation mask according to Garg ECCV16
            "eigen",  # set evaluation mask according to Eigen NIPS14
        ], f"Unknown crop type: {self.valid_mask_crop}"

        # Filter out empty depth (entries whose second field is the text "None")
        self.filenames = [f for f in self.filenames if "None" != f[1]]

    def _read_depth_file(self, rel_path):
        """Read a depth image and decode it by dividing the raw values by 100."""
        depth_in = self._read_image(rel_path)
        # Decode vKITTI depth
        depth_decoded = depth_in / 100.0
        return depth_decoded

    def _load_rgb_data(self, rgb_rel_path):
        """Load RGB rasters, applying the benchmark crop when enabled."""
        rgb_data = super()._load_rgb_data(rgb_rel_path)
        if self.kitti_bm_crop:
            rgb_data = {
                k: KITTIDataset.kitti_benchmark_crop(v) for k, v in rgb_data.items()
            }
        return rgb_data

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        """Load depth rasters, applying the benchmark crop when enabled."""
        depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path)
        if self.kitti_bm_crop:
            depth_data = {
                k: KITTIDataset.kitti_benchmark_crop(v) for k, v in depth_data.items()
            }
        return depth_data

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return the base valid mask restricted to the evaluation region.

        When ``valid_mask_crop`` is ``None`` the base-class mask is returned
        unchanged; otherwise only pixels inside the Garg or Eigen rectangle
        stay valid.
        """
        # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py
        valid_mask = super()._get_valid_mask(depth)  # [1, H, W]

        if self.valid_mask_crop is not None:
            eval_mask = torch.zeros_like(valid_mask.squeeze()).bool()
            gt_height, gt_width = eval_mask.shape

            if "garg" == self.valid_mask_crop:
                eval_mask[
                    int(0.40810811 * gt_height) : int(0.99189189 * gt_height),
                    int(0.03594771 * gt_width) : int(0.96405229 * gt_width),
                ] = 1
            elif "eigen" == self.valid_mask_crop:
                eval_mask[
                    int(0.3324324 * gt_height) : int(0.91351351 * gt_height),
                    int(0.0359477 * gt_width) : int(0.96405229 * gt_width),
                ] = 1

            # reshape() returns a new tensor; keep it so both masks have the
            # same shape instead of relying on broadcasting.
            eval_mask = eval_mask.reshape(valid_mask.shape)
            valid_mask = torch.logical_and(valid_mask, eval_mask)
        return valid_mask
DepthMaster/src/trainer/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Bingxin Ke
2
+ # Last modified: 2024-05-17
3
+
4
+ from .trainer_s1 import DepthMasterTrainerS1
5
+ from .trainer_s2 import DepthMasterTrainerS2
6
+
7
+
8
# Registry mapping the trainer name used in configs to its class.
trainer_cls_name_dict = {
    "DepthMasterTrainerS1": DepthMasterTrainerS1,
    "DepthMasterTrainerS2": DepthMasterTrainerS2,
}


def get_trainer_cls(trainer_name):
    """Look up a trainer class by its registered name.

    Args:
        trainer_name: Key of ``trainer_cls_name_dict``.

    Returns:
        The trainer class registered under ``trainer_name``.

    Raises:
        KeyError: If ``trainer_name`` is not registered; the message lists the
            available names.
    """
    try:
        return trainer_cls_name_dict[trainer_name]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            f"Unknown trainer '{trainer_name}'. "
            f"Available trainers: {sorted(trainer_cls_name_dict)}"
        ) from None
DepthMaster/src/trainer/trainer_s1.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-07-13
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+ import logging
25
+ import os
26
+ import random
27
+ import shutil
28
+ from datetime import datetime
29
+ from typing import List, Union
30
+
31
+ import numpy as np
32
+ import torch
33
+ from omegaconf import OmegaConf
34
+ from torch.optim import Adam
35
+ from torch.optim.lr_scheduler import LambdaLR
36
+ from torch.utils.data import DataLoader
37
+ from tqdm import tqdm
38
+ from PIL import Image
39
+ import torch.nn.functional as F
40
+
41
+ from depthmaster import DepthMasterPipeline, DepthMasterDepthOutput
42
+ from src.util import metric
43
+ from src.util.data_loader import skip_first_batches
44
+ from src.util.logging_util import tb_logger, eval_dic_to_text
45
+ from src.util.loss import get_loss, SSIM
46
+ from src.util.lr_scheduler import IterExponential
47
+ from src.util.metric import MetricTracker
48
+ from src.util.alignment import (
49
+ align_depth_least_square,
50
+ depth2disparity,
51
+ disparity2depth,
52
+ )
53
+ from src.util.seeding import generate_seed_sequence
54
+ from src.util.build_mlp import build_mlp_
55
+ from torchvision.transforms import Normalize
56
+ from external_encoder.dinov2.dinov2 import DINOv2
57
+
58
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
59
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
60
+
61
+ class DepthMasterTrainerS1:
62
def __init__(
    self,
    cfg: OmegaConf,
    model: DepthMasterPipeline,
    train_dataloader: DataLoader,
    device,
    base_ckpt_dir,
    out_dir_ckpt,
    out_dir_eval,
    out_dir_vis,
    accumulation_steps: int,
    val_dataloaders: List[DataLoader] = None,
    vis_dataloaders: List[DataLoader] = None,
):
    """Build the stage-1 trainer: models, optimizer, LR schedule, loss and counters.

    Args:
        cfg: Configuration object; the fields read here are `trainer.*`, `lr`,
            `lr_scheduler.kwargs.*`, `loss.*`, `eval.eval_metrics`,
            `validation.*`, `max_epoch`, `max_iter`, `gt_depth_type` and
            `gt_mask_type`.
        model: Pipeline whose UNet is trained; VAE and text encoder are frozen.
        train_dataloader: Loader providing the training batches.
        device: Device the empty-text embedding is moved to.
        base_ckpt_dir: Accepted for interface compatibility; not used here.
        out_dir_ckpt: Directory where checkpoints are written.
        out_dir_eval: Directory where evaluation text files are written.
        out_dir_vis: Directory where visualisation outputs are written.
        accumulation_steps: Number of batches accumulated per optimizer step.
        val_dataloaders: Loaders used by `validate`.
        vis_dataloaders: Loaders used by `visualize`.
    """
    # --- References handed in by the caller ---
    self.cfg: OmegaConf = cfg
    self.model: DepthMasterPipeline = model
    self.device = device
    # Used to generate the seed sequence; `None` means training w/o seeding.
    self.seed: Union[int, None] = self.cfg.trainer.init_seed
    self.out_dir_ckpt = out_dir_ckpt
    self.out_dir_eval = out_dir_eval
    self.out_dir_vis = out_dir_vis
    self.train_loader: DataLoader = train_dataloader
    self.val_loaders: List[DataLoader] = val_dataloaders
    self.vis_loaders: List[DataLoader] = vis_dataloaders
    self.accumulation_steps: int = accumulation_steps

    # --- Empty text prompt embedding, cached once ---
    self.model.encode_empty_text()
    self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
    self.model.unet.enable_xformers_memory_efficient_attention()

    # --- DINOv2 encoder, initialised from the checkpoint on disk ---
    self.dinov2_encoder = DINOv2(model_name='vitg')
    encoder_keys = self.dinov2_encoder.state_dict()
    checkpoint_weights = torch.load(
        'checkpoints/depth_anything_v2_vitg.pth', map_location='cpu'
    )
    # Strip the 'pretrained.' prefix and keep only entries the encoder knows.
    matched_weights = {}
    for raw_name, tensor in checkpoint_weights.items():
        stripped_name = raw_name.replace('pretrained.', '')
        if stripped_name in encoder_keys:
            matched_weights[stripped_name] = tensor
    self.dinov2_encoder.load_state_dict(matched_weights)
    # Replace the head with a pass-through module.
    del self.dinov2_encoder.head
    self.dinov2_encoder.head = torch.nn.Identity()
    self.dinov2_encoder.eval()

    # --- Adapter aligning the feature dimension of SD and DINOv2 ---
    self.dinov2_adapter = build_mlp_(hidden_size=1280, projector_dim=1536, z_dim=1536)

    # --- Trainability: only the UNet and the adapter receive gradients ---
    self.dinov2_adapter.requires_grad_(True)
    self.dinov2_encoder.requires_grad_(False)
    self.model.vae.requires_grad_(False)
    self.model.text_encoder.requires_grad_(False)
    self.model.unet.requires_grad_(True)

    # --- Optimizer (must be defined after the input layer is adapted) ---
    base_lr = self.cfg.lr
    param_groups = [
        {'params': self.model.unet.parameters(), 'lr': base_lr},
        {'params': self.dinov2_adapter.parameters(), 'lr': base_lr},
    ]
    self.optimizer = Adam(param_groups)

    # --- LR scheduler ---
    scheduler_kwargs = self.cfg.lr_scheduler.kwargs
    lr_func = IterExponential(
        total_iter_length=scheduler_kwargs.total_iter,
        final_ratio=scheduler_kwargs.final_ratio,
        warmup_steps=scheduler_kwargs.warmup_steps,
    )
    self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

    # --- Loss ---
    self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)

    # --- Metrics ---
    self.metric_funcs = [getattr(metric, name) for name in cfg.eval.eval_metrics]
    self.train_metrics = MetricTracker("loss", "feat_align_loss")
    self.val_metrics = MetricTracker(*[func.__name__ for func in self.metric_funcs])
    # Main metric that decides when a "best" checkpoint is saved.
    self.main_val_metric = cfg.validation.main_val_metric
    self.main_val_metric_goal = cfg.validation.main_val_metric_goal
    assert (
        self.main_val_metric in cfg.eval.eval_metrics
    ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
    if "minimize" == self.main_val_metric_goal:
        self.best_metric = 1e8
    else:
        self.best_metric = -1e8

    # --- Settings copied out of the config ---
    self.max_epoch = self.cfg.max_epoch
    self.max_iter = self.cfg.max_iter
    self.gradient_accumulation_steps = accumulation_steps
    self.gt_depth_type = self.cfg.gt_depth_type
    self.gt_mask_type = self.cfg.gt_mask_type
    trainer_cfg = self.cfg.trainer
    self.save_period = trainer_cfg.save_period
    self.backup_period = trainer_cfg.backup_period
    self.val_period = trainer_cfg.validation_period
    self.vis_period = trainer_cfg.visualization_period

    # --- Internal counters and flags ---
    self.epoch = 1
    self.n_batch_in_epoch = 0  # batch index in the epoch, used when resuming
    self.effective_iter = 0  # how many times optimizer.step() was called
    self.in_evaluation = False
    # Consistent global seed sequence, kept so a resumed run stays consistent.
    self.global_seed_sequence: List = []
166
+
167
+
168
def train(self, t_end=None):
    """Run the stage-1 training loop.

    Every batch contributes (1) the configured loss between the UNet output
    latent and the VAE-encoded ground-truth depth latent, restricted to the
    down-sampled valid mask, and (2) a feature-alignment loss: the KL
    divergence between the softmax-normalised adapter output and the
    softmax-normalised DINOv2 patch tokens, weighted by
    `cfg.loss_feat_align.lamda`. Gradients are accumulated over
    `gradient_accumulation_steps` batches before each optimizer step.

    The loop returns when `effective_iter` reaches `max_iter` (a backup
    checkpoint is saved), when `t_end` has passed (a resumable "latest"
    checkpoint is saved), or when all epochs are exhausted.

    Args:
        t_end: Optional `datetime`; compared against `datetime.now()` after
            every optimizer step.

    Raises:
        NotImplementedError: If `gt_mask_type` is `None`.
    """
    logging.info("Start training")

    device = self.device
    self.model.to(device)
    self.dinov2_encoder.to(device)
    self.dinov2_adapter.to(device)
    self.visualize()

    if self.in_evaluation:
        logging.info(
            "Last evaluation was not finished, will do evaluation before continue training."
        )
        self.validate()

    self.train_metrics.reset()
    accumulated_step = 0

    progress_bar = tqdm(
        range(0, self.max_iter),
        initial=self.effective_iter,
        desc="iter"
    )

    # Loop-invariant transform, built once instead of once per batch.
    imagenet_normalize = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    for epoch in range(self.epoch, self.max_epoch + 1):
        self.epoch = epoch
        logging.debug(f"epoch: {self.epoch}")

        # Skip previous batches when resume
        for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
            # Re-enter train mode every batch.
            self.model.unet.train()
            self.dinov2_adapter.train()

            # >>> With gradient accumulation >>>

            # Get data
            rgb = batch["rgb_norm"].to(device)
            depth_gt_for_latent = batch[self.gt_depth_type].to(device)

            if self.gt_mask_type is not None:
                valid_mask_for_latent = batch[self.gt_mask_type].to(device)
                # A latent cell is valid only if every pixel of its 8x8
                # window is valid: max-pool the *invalid* mask, then invert.
                invalid_mask = ~valid_mask_for_latent
                valid_mask_down = ~torch.max_pool2d(
                    invalid_mask.float(), 8, 8
                ).bool()
                valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))
            else:
                raise NotImplementedError

            batch_size = rgb.shape[0]

            with torch.no_grad():
                # Encode image
                rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]
                # Encode GT depth
                gt_depth_latent = self.encode_depth(
                    depth_gt_for_latent
                )  # [B, 4, h, w]
                # DINOv2 feat
                dinov2_input_rgb = imagenet_normalize(rgb)
                # NOTE(review): 0.875 presumably makes the DINOv2 token grid
                # half the latent resolution (see the reshape to H/2 x W/2
                # below) -- confirm against the encoder's patch size.
                dinov2_input_rgb = F.interpolate(dinov2_input_rgb, scale_factor=0.875, mode='bicubic')
                dinov2_z = self.dinov2_encoder.forward_features(dinov2_input_rgb)['x_norm_patchtokens']

            # Text embedding
            text_embed = self.empty_text_embed.to(device).repeat(
                (batch_size, 1, 1)
            )  # [B, 77, 1024]

            # Single UNet forward pass with a fixed timestep argument of 1
            rgb_latent = self.model.unet(
                rgb_latent, 1, text_embed
            )  # [B, 4, h, w]

            # Intermediate UNet feature used for the alignment loss.
            feat_16 = rgb_latent.feat_64
            rgb_latent = rgb_latent.sample

            if self.gt_mask_type is not None:
                loss = self.loss(
                    rgb_latent[valid_mask_down].float(),
                    gt_depth_latent[valid_mask_down].float(),
                ).mean()
            else:
                loss = self.loss(rgb_latent.float(), gt_depth_latent.float()).mean()

            self.train_metrics.update("loss", loss.item())

            # feat align loss
            b, c, h, w = feat_16.shape
            _, _, H, W = rgb_latent.shape
            # Project UNet features (channels-last, flattened) through the adapter
            unet_16_feat_aligned = self.dinov2_adapter(feat_16.permute(0, 2, 3, 1).reshape(batch_size, -1, c))
            if torch.isnan(rgb_latent).any():
                logging.warning("model_pred contains NaN.")

            # Token sequence -> 2-D grid -> resize to the UNet feature size -> sequence
            dinov2_z = dinov2_z.reshape(b, int(H/2), int(W/2), -1).permute(0, 3, 1, 2)
            dinov2_z = F.interpolate(dinov2_z, size=(h, w), mode='bicubic').permute(0, 2, 3, 1).reshape(b, h*w, -1)

            # kl loss
            unet_16_feat_aligned = F.softmax(unet_16_feat_aligned, dim=-1)
            dinov2_z = F.softmax(dinov2_z, dim=-1)

            loss_feat_align = F.kl_div(unet_16_feat_aligned.log(), dinov2_z)

            # Log a Python float, like the main loss: handing the tracker the
            # live tensor would keep its autograd graph referenced.
            self.train_metrics.update("feat_align_loss", loss_feat_align.item())

            loss += self.cfg.loss_feat_align.lamda * loss_feat_align

            loss = loss / self.gradient_accumulation_steps
            loss.backward()
            accumulated_step += 1

            self.n_batch_in_epoch += 1
            # Practical batch end

            # Perform optimization step
            if accumulated_step >= self.gradient_accumulation_steps:
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                accumulated_step = 0

                self.effective_iter += 1
                progress_bar.update(1)

                # Log to tensorboard
                accumulated_loss = self.train_metrics.result()["loss"]
                logs = {"loss": accumulated_loss}
                progress_bar.set_postfix(**logs)
                tb_logger.log_dic(
                    {
                        f"train/{k}": v
                        for k, v in self.train_metrics.result().items()
                    },
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "lr",
                    self.lr_scheduler.get_last_lr()[0],
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "n_batch_in_epoch",
                    self.n_batch_in_epoch,
                    global_step=self.effective_iter,
                )
                self.train_metrics.reset()

                # Per-step callback
                self._train_step_callback()

                # End of training
                if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(),
                        save_train_state=False,
                    )
                    logging.info("Training ended.")
                    return
                # Time's up
                elif t_end is not None and datetime.now() >= t_end:
                    self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                    logging.info("Time is up, training paused.")
                    return

                torch.cuda.empty_cache()
                # <<< Effective batch end <<<

        # Epoch end
        self.n_batch_in_epoch = 0
338
+
339
def encode_depth(self, depth_in):
    """Encode a depth map into a latent by reusing the pipeline's RGB encoder.

    The depth input is first expanded to three channels by
    `stack_depth_images`, then passed to `self.model.encode_rgb`.

    Args:
        depth_in: Depth tensor accepted by `stack_depth_images`.

    Returns:
        Whatever `self.model.encode_rgb` returns for the stacked input.
    """
    three_channel_depth = self.stack_depth_images(depth_in)
    return self.model.encode_rgb(three_channel_depth)
345
+
346
@staticmethod
def stack_depth_images(depth_in):
    """Tile a depth tensor three times along dim 1 to form a 3-channel image.

    Args:
        depth_in: Either a 4-D tensor (tiled along dim 1 as is) or a 3-D
            tensor, which first gets a new dim 1 inserted.

    Returns:
        A 4-D tensor: `[B, 3, H, W]` for `[B, 1, H, W]` or `[B, H, W]` input.

    Raises:
        ValueError: If `depth_in` is neither 3-D nor 4-D.
    """
    n_dims = len(depth_in.shape)
    if 3 == n_dims:
        # Insert the channel dim first; the repeat must act on the result of
        # the unsqueeze, otherwise [B, H, W] would become [1, 3B, H, W].
        depth_in = depth_in.unsqueeze(1)
    elif 4 != n_dims:
        raise ValueError(
            f"Expected a 3-D or 4-D depth tensor, got shape {tuple(depth_in.shape)}"
        )
    return depth_in.repeat(1, 3, 1, 1)
354
+
355
def _train_step_callback(self):
    """Run the periodic actions that may be due after an optimizer step.

    In order: backup checkpoint (no training state), validation (bracketed
    by two resumable "latest" checkpoints), a plain resumable "latest"
    checkpoint if validation did not already write one, and visualisation.
    A period of 0 or less disables the corresponding action.
    """
    current_iter = self.effective_iter

    def is_due(period):
        # Checking `period > 0` first also avoids a modulo by zero.
        return period > 0 and 0 == current_iter % period

    # Backup: larger interval, saved without training state.
    if is_due(self.backup_period):
        self.save_checkpoint(
            ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
        )

    latest_already_saved = False
    if is_due(self.val_period):
        # The flag is stored in the checkpoint so that a resumed run can
        # redo an evaluation that was interrupted.
        self.in_evaluation = True
        self.save_checkpoint(ckpt_name="latest", save_train_state=True)
        latest_already_saved = True
        self.validate()
        self.in_evaluation = False
        self.save_checkpoint(ckpt_name="latest", save_train_state=True)

    # Resumable checkpoint, skipped when validation just wrote one.
    if not latest_already_saved and is_due(self.save_period):
        self.save_checkpoint(ckpt_name="latest", save_train_state=True)

    if is_due(self.vis_period):
        self.visualize()
384
+
385
def validate(self):
    """Evaluate on every validation loader, log results and track the best metric.

    For each loader the metrics are logged, sent to tensorboard and written
    to a text file in `out_dir_eval`. Only the first loader's
    `main_val_metric` is compared against `best_metric`; on improvement a
    backup checkpoint (without training state) is saved.

    Does nothing when no validation loaders were supplied.
    """
    # `val_dataloaders` defaults to None in the constructor; treat that as
    # "no validation sets" instead of failing on iteration.
    for i, val_loader in enumerate(self.val_loaders or []):
        val_dataset_name = val_loader.dataset.disp_name
        val_metric_dic = self.validate_single_dataset(
            data_loader=val_loader, metric_tracker=self.val_metrics
        )
        logging.info(
            f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
        )
        tb_logger.log_dic(
            {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
            global_step=self.effective_iter,
        )
        # save to file
        eval_text = eval_dic_to_text(
            val_metrics=val_metric_dic,
            dataset_name=val_dataset_name,
            sample_list_path=val_loader.dataset.filename_ls_path,
        )
        _save_to = os.path.join(
            self.out_dir_eval,
            f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
        )
        with open(_save_to, "w+") as f:
            f.write(eval_text)

        # Update main eval metric (first validation set only)
        if 0 == i:
            main_eval_metric = val_metric_dic[self.main_val_metric]
            if (
                "minimize" == self.main_val_metric_goal
                and main_eval_metric < self.best_metric
                or "maximize" == self.main_val_metric_goal
                and main_eval_metric > self.best_metric
            ):
                self.best_metric = main_eval_metric
                logging.info(
                    f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
                )
                # Save a checkpoint
                self.save_checkpoint(
                    ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
                )
428
+
429
def visualize(self):
    """Run inference on every visualisation loader and save the predictions.

    Output goes to `<out_dir_vis>/<iter name>/<dataset name>/`, with the
    iteration name from `_get_backup_ckpt_name`. The metrics returned by
    `validate_single_dataset` are discarded.

    Does nothing when no visualisation loaders were supplied.
    """
    # `vis_dataloaders` defaults to None in the constructor, and `train`
    # calls this method unconditionally; treat None as "nothing to do".
    for val_loader in self.vis_loaders or []:
        vis_dataset_name = val_loader.dataset.disp_name
        vis_out_dir = os.path.join(
            self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
        )
        os.makedirs(vis_out_dir, exist_ok=True)
        _ = self.validate_single_dataset(
            data_loader=val_loader,
            metric_tracker=self.val_metrics,
            save_to_dir=vis_out_dir,
        )
441
+
442
@torch.no_grad()
def validate_single_dataset(
    self,
    data_loader: DataLoader,
    metric_tracker: MetricTracker,
    save_to_dir: str = None,
):
    """Predict depth for every sample of one dataset and accumulate metrics.

    Each prediction is aligned to the ground truth according to
    `cfg.eval.alignment`, clipped to the dataset's depth range, scored with
    all `self.metric_funcs`, and optionally written as a 16-bit PNG.

    Args:
        data_loader: Loader with batch size 1 (asserted for every batch).
        metric_tracker: Tracker that is reset first and then updated.
        save_to_dir: If given, directory receiving one PNG per sample.

    Returns:
        `metric_tracker.result()` after all samples were processed.

    Raises:
        RuntimeError: If `cfg.eval.alignment` is not a known alignment type.
    """
    self.model.to(self.device)
    metric_tracker.reset()

    # Generate seed sequence for consistent evaluation.
    # NOTE(review): the sequence is not consumed anywhere below.
    val_init_seed = self.cfg.validation.init_seed
    val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))

    val_cfg = self.cfg.validation
    alignment = self.cfg.eval.alignment
    dataset = data_loader.dataset

    for batch in tqdm(data_loader, desc=f"evaluating on {dataset.disp_name}"):
        assert 1 == data_loader.batch_size

        # Input image
        rgb_int = batch["rgb_int"]  # [3, H, W]

        # Ground truth: numpy copy for alignment, device tensor for metrics
        gt_ts = batch["depth_raw_linear"].squeeze()
        gt_np = gt_ts.numpy()
        gt_ts = gt_ts.to(self.device)

        mask_ts = batch["valid_mask_raw"].squeeze()
        mask_np = mask_ts.numpy()
        mask_ts = mask_ts.to(self.device)

        # Predict depth
        pipe_out: DepthMasterDepthOutput = self.model(
            rgb_int,
            processing_res=val_cfg.processing_res,
            match_input_res=val_cfg.match_input_res,
            batch_size=1,  # use batch size 1 to increase reproducibility
            color_map=None,
            show_progress_bar=False,
            resample_method=val_cfg.resample_method,
        )

        pred_np: np.ndarray = pipe_out.depth_np.squeeze()

        if "least_square" == alignment:
            pred_np, scale, shift = align_depth_least_square(
                gt_arr=gt_np,
                pred_arr=pred_np,
                valid_mask_arr=mask_np,
                return_scale_shift=True,
                max_resolution=self.cfg.eval.align_max_res,
            )
        elif "least_square_disparity" == alignment:
            # The ground truth is used directly as the alignment target in
            # disparity space; only strictly positive values take part.
            positive_mask = mask_np & (gt_np > 0) & (pred_np > 0)
            aligned_disp, scale, shift = align_depth_least_square(
                gt_arr=gt_np,
                pred_arr=pred_np,
                valid_mask_arr=positive_mask,
                return_scale_shift=True,
            )
            # avoid 0 disparity, then convert to depth
            aligned_disp = np.clip(aligned_disp, a_min=1e-3, a_max=None)
            pred_np = disparity2depth(aligned_disp)
            gt_ts = disparity2depth(gt_ts)
        elif "least_square_sqrt_disp" == alignment:
            # NOTE(review): this mask is computed but the alignment below is
            # given the plain valid mask, unlike the disparity branch --
            # confirm whether that is intended.
            positive_mask = mask_np & (gt_np > 0) & (pred_np > 0)
            aligned_sqrt_disp, scale, shift = align_depth_least_square(
                gt_arr=gt_np,
                pred_arr=pred_np,
                valid_mask_arr=mask_np,
                return_scale_shift=True,
            )
            # square both sides to get disparity
            aligned_disp = aligned_sqrt_disp ** 2
            gt_ts = torch.pow(gt_ts, 2)
            # avoid 0 disparity, then convert to depth
            aligned_disp = np.clip(aligned_disp, a_min=1e-3, a_max=None)
            pred_np = disparity2depth(aligned_disp)
            gt_ts = disparity2depth(gt_ts)
        else:
            raise RuntimeError(f"Unknown alignment type: {alignment}")

        # Clip to dataset min max
        pred_np = np.clip(
            pred_np,
            a_min=dataset.min_depth,
            a_max=dataset.max_depth,
        )
        # clip to d > 0 for evaluation
        pred_np = np.clip(pred_np, a_min=1e-6, a_max=None)

        # Evaluate
        pred_ts = torch.from_numpy(pred_np).to(self.device)
        for met_func in self.metric_funcs:
            score = met_func(pred_ts, gt_ts, mask_ts).item()
            metric_tracker.update(met_func.__name__, score)

        # Save the raw (un-aligned) pipeline output as 16-bit uint png
        if save_to_dir is not None:
            img_name = batch["rgb_relative_path"][0].replace("/", "_")
            png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
            depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
            Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

    return metric_tracker.result()
565
+
566
def _get_next_seed(self):
    """Pop and return the next seed from the global seed sequence.

    When the sequence is empty it is first regenerated from `self.seed`
    with length `max_iter * gradient_accumulation_steps`.
    """
    if not len(self.global_seed_sequence):
        n_seeds = self.max_iter * self.gradient_accumulation_steps
        self.global_seed_sequence = generate_seed_sequence(
            initial_seed=self.seed,
            length=n_seeds,
        )
        logging.info(
            f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
        )
    # Seeds are consumed from the end of the sequence.
    next_seed = self.global_seed_sequence.pop()
    return next_seed
576
+
577
def save_checkpoint(self, ckpt_name, save_train_state):
    """Write the UNet, the DINOv2 adapter and optionally the trainer state.

    An existing checkpoint directory of the same name is first renamed to
    `_old_<name>` and removed only after the new checkpoint was written.

    Args:
        ckpt_name: Sub-directory name inside `out_dir_ckpt`.
        save_train_state: If true, also save `trainer.ckpt` (optimizer, LR
            scheduler, config, counters, seed sequence) and an empty marker
            file named after the current iteration.
    """
    ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
    logging.info(f"Saving checkpoint to: {ckpt_dir}")
    # Backup previous checkpoint
    temp_ckpt_dir = None
    if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
        temp_ckpt_dir = os.path.join(
            os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
        )
        if os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        os.rename(ckpt_dir, temp_ckpt_dir)
        logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

    # Create the directory explicitly so the `torch.save` calls below do not
    # depend on `save_pretrained` having created it as a side effect.
    os.makedirs(ckpt_dir, exist_ok=True)

    # Save UNet
    unet_path = os.path.join(ckpt_dir, "unet")
    self.model.unet.save_pretrained(unet_path, safe_serialization=False)
    logging.info(f"UNet is saved to: {unet_path}")

    # Save DINOv2_Adapter
    adapter_path = os.path.join(ckpt_dir, "dinov2_adapter.pth")
    state_dict = self.dinov2_adapter.state_dict()
    torch.save(state_dict, adapter_path)
    logging.info(f"dinov2_adapter is saved to: {adapter_path}")

    if save_train_state:
        state = {
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "config": self.cfg,
            "effective_iter": self.effective_iter,
            "epoch": self.epoch,
            "n_batch_in_epoch": self.n_batch_in_epoch,
            "best_metric": self.best_metric,
            "in_evaluation": self.in_evaluation,
            "global_seed_sequence": self.global_seed_sequence,
        }
        train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
        torch.save(state, train_state_path)
        # Iteration indicator: an empty file whose name encodes the iteration.
        with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
            pass

        logging.info(f"Trainer state is saved to: {train_state_path}")

    # Remove temp ckpt
    if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
        shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        logging.debug("Old checkpoint backup is removed.")
626
+
627
def load_checkpoint(
    self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
):
    """Restore the UNet, the DINOv2 adapter and optionally the trainer state.

    Args:
        ckpt_path: Checkpoint directory as written by `save_checkpoint`.
        load_trainer_state: If true, also restore counters, best metric,
            seed sequence and optimizer state from `trainer.ckpt`.
        resume_lr_scheduler: If true (and trainer state is loaded), restore
            the LR scheduler state as well.
    """
    logging.info(f"Loading checkpoint from: {ckpt_path}")
    # Load UNet
    _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
    self.model.unet.load_state_dict(
        torch.load(_model_path, map_location=self.device)
    )
    self.model.unet.to(self.device)
    logging.info(f"UNet parameters are loaded from {_model_path}")

    # Load DINOv2_adapter
    _model_path = os.path.join(ckpt_path, "dinov2_adapter.pth")
    self.dinov2_adapter.load_state_dict(
        torch.load(_model_path, map_location=self.device)
    )
    self.dinov2_adapter.to(self.device)
    logging.info(f"dinov2_adapter parameters are loaded from {_model_path}")

    # Load training states
    if load_trainer_state:
        # Remap storages to this trainer's device, like the two loads above;
        # without it torch.load tries to restore tensors onto the device
        # they were saved from, which fails if that device is unavailable.
        # NOTE(review): the saved state also holds the config object, so
        # loading may need `weights_only=False` on newer torch -- confirm.
        checkpoint = torch.load(
            os.path.join(ckpt_path, "trainer.ckpt"), map_location=self.device
        )
        self.effective_iter = checkpoint["effective_iter"]
        self.epoch = checkpoint["epoch"]
        self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
        self.in_evaluation = checkpoint["in_evaluation"]
        self.global_seed_sequence = checkpoint["global_seed_sequence"]

        self.best_metric = checkpoint["best_metric"]

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        logging.info(f"optimizer state is loaded from {ckpt_path}")

        if resume_lr_scheduler:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            logging.info(f"LR scheduler state is loaded from {ckpt_path}")

    logging.info(
        f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
    )
    return
669
+
670
def _get_backup_ckpt_name(self):
    """Return the name for the current iteration, e.g. `iter_000042`."""
    # Zero-padded to at least six digits.
    return "iter_{:06d}".format(self.effective_iter)
DepthMaster/src/trainer/trainer_s2.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # An official reimplemented version of Marigold training script.
2
+ # Last modified: 2024-04-29
3
+ #
4
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # --------------------------------------------------------------------------
18
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
19
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
20
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
21
+ # More information about the method can be found at https://marigoldmonodepth.github.io
22
+ # --------------------------------------------------------------------------
23
+ import logging
24
+ import os
25
+ import shutil
26
+ from datetime import datetime
27
+ from typing import List, Union
28
+
29
+ import numpy as np
30
+ import torch
31
+ # from diffusers import DDPMScheduler
32
+ from omegaconf import OmegaConf
33
+ # from torch.nn import Conv2d
34
+ # from torch.nn.parameter import Parameter
35
+ from torch.optim import Adam
36
+ from torch.optim.lr_scheduler import LambdaLR
37
+ from torch.utils.data import DataLoader
38
+ from tqdm import tqdm
39
+ from PIL import Image
40
+
41
+ from depthmaster.depthmaster_pipeline import DepthMasterPipeline, DepthMasterDepthOutput
42
+ from src.util import metric
43
+ from src.util.data_loader import skip_first_batches
44
+ from src.util.logging_util import tb_logger, eval_dic_to_text
45
+ from src.util.loss import get_loss
46
+ from src.util.lr_scheduler import IterExponential
47
+ from src.util.metric import MetricTracker
48
+ from src.util.alignment import (
49
+ align_depth_least_square,
50
+ depth2disparity,
51
+ disparity2depth,
52
+ align_depth_least_square_torch_mask,
53
+ align_depth_medium_mask
54
+ )
55
+ # from src.util.alignment import align_depth_least_square
56
+ # from src.util.alignment import align_depth_least_square
57
+ from src.util.seeding import generate_seed_sequence
58
+ import torch.nn.functional as F
59
+
60
+ class DepthMasterTrainerS2:
61
def __init__(
    self,
    cfg: OmegaConf,
    model: DepthMasterPipeline,
    train_dataloader: DataLoader,
    device,
    base_ckpt_dir,
    out_dir_ckpt,
    out_dir_eval,
    out_dir_vis,
    accumulation_steps: int,
    val_dataloaders: List[DataLoader] = None,
    vis_dataloaders: List[DataLoader] = None,
):
    """Build the stage-2 trainer: optimizer, LR schedule, losses and counters.

    Args:
        cfg: Configuration object; the fields read here are `trainer.*`, `lr`,
            `lr_scheduler.kwargs.*`, `loss.*`, `grad_loss.*`,
            `eval.eval_metrics`, `validation.*`, `max_epoch`, `max_iter`,
            `gt_depth_type` and `gt_mask_type`.
        model: Pipeline whose UNet is trained; VAE and text encoder are frozen.
        train_dataloader: Loader providing the training batches.
        device: Device the empty-text embedding is moved to.
        base_ckpt_dir: Accepted for interface compatibility; not used here.
        out_dir_ckpt: Stored as the checkpoint output directory.
        out_dir_eval: Stored as the evaluation output directory.
        out_dir_vis: Stored as the visualisation output directory.
        accumulation_steps: Number of batches accumulated per optimizer step.
        val_dataloaders: Stored as the validation loaders.
        vis_dataloaders: Stored as the visualisation loaders.
    """
    # --- References handed in by the caller ---
    self.cfg: OmegaConf = cfg
    self.model: DepthMasterPipeline = model
    self.device = device
    # Used to generate the seed sequence; `None` means training w/o seeding.
    self.seed: Union[int, None] = self.cfg.trainer.init_seed
    self.out_dir_ckpt = out_dir_ckpt
    self.out_dir_eval = out_dir_eval
    self.out_dir_vis = out_dir_vis
    self.train_loader: DataLoader = train_dataloader
    self.val_loaders: List[DataLoader] = val_dataloaders
    self.vis_loaders: List[DataLoader] = vis_dataloaders
    self.accumulation_steps: int = accumulation_steps

    # --- Empty text prompt embedding, cached once ---
    self.model.encode_empty_text()
    self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)

    self.model.unet.enable_xformers_memory_efficient_attention()

    # --- Trainability: only the UNet receives gradients ---
    self.model.vae.requires_grad_(False)
    self.model.vae.decoder.requires_grad_(False)
    self.model.text_encoder.requires_grad_(False)
    self.model.unet.requires_grad_(True)

    # --- Optimizer (must be defined after the input layer is adapted) ---
    self.optimizer = Adam(self.model.unet.parameters(), lr=self.cfg.lr)

    # --- LR scheduler ---
    scheduler_kwargs = self.cfg.lr_scheduler.kwargs
    lr_func = IterExponential(
        total_iter_length=scheduler_kwargs.total_iter,
        final_ratio=scheduler_kwargs.final_ratio,
        warmup_steps=scheduler_kwargs.warmup_steps,
    )
    self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

    # --- Losses: main loss plus a separately configured gradient loss ---
    self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
    self.grad_loss = get_loss(
        loss_name=self.cfg.grad_loss.name, **self.cfg.grad_loss.kwargs
    )

    # --- Metrics ---
    self.metric_funcs = [getattr(metric, name) for name in cfg.eval.eval_metrics]
    self.train_metrics = MetricTracker("loss", "grad_loss")
    self.val_metrics = MetricTracker(*[func.__name__ for func in self.metric_funcs])
    # Main metric that decides when a "best" checkpoint is saved.
    self.main_val_metric = cfg.validation.main_val_metric
    self.main_val_metric_goal = cfg.validation.main_val_metric_goal
    assert (
        self.main_val_metric in cfg.eval.eval_metrics
    ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
    if "minimize" == self.main_val_metric_goal:
        self.best_metric = 1e8
    else:
        self.best_metric = -1e8

    # --- Settings copied out of the config ---
    self.max_epoch = self.cfg.max_epoch
    self.max_iter = self.cfg.max_iter
    self.gradient_accumulation_steps = accumulation_steps
    self.gt_depth_type = self.cfg.gt_depth_type
    self.gt_mask_type = self.cfg.gt_mask_type
    trainer_cfg = self.cfg.trainer
    self.save_period = trainer_cfg.save_period
    self.backup_period = trainer_cfg.backup_period
    self.val_period = trainer_cfg.validation_period
    self.vis_period = trainer_cfg.visualization_period

    # --- Internal counters and flags ---
    self.epoch = 1
    self.n_batch_in_epoch = 0  # batch index in the epoch, used when resuming
    self.effective_iter = 0  # how many times optimizer.step() was called
    self.in_evaluation = False
    # Consistent global seed sequence, kept so a resumed run stays consistent.
    self.global_seed_sequence: List = []
148
+
149
+
150
def grad(self, x):
    """Return finite differences of `x` in four directions, concatenated on dim 1.

    The last two dims are treated as the spatial grid and each difference
    map is one element smaller in both of them. The order along dim 1 is:
    horizontal, vertical, anti-diagonal, main diagonal.

    Args:
        x: Tensor with at least two dims; the inline note in the original
            code gives the expected layout as (n, c, h, w).

    Returns:
        Tensor whose dim 1 is four times that of `x`.
    """
    # Four views of the grid, each shifted by one step.
    lower_right = x[..., 1:, 1:]
    lower_left = x[..., 1:, :-1]
    upper_right = x[..., :-1, 1:]
    upper_left = x[..., :-1, :-1]

    directional_diffs = [
        lower_right - lower_left,   # along the last dim
        lower_right - upper_right,  # along the second-to-last dim
        upper_right - lower_left,   # anti-diagonal
        lower_right - upper_left,   # main diagonal
    ]
    return torch.cat(directional_diffs, dim=1)
164
+
165
def train(self, t_end=None):
    """Run the training loop with gradient accumulation.

    Visualizes once up front, finishes a pending validation if
    ``self.in_evaluation`` is set, then iterates over epochs/batches. After
    every ``self.gradient_accumulation_steps`` batches an optimizer step is
    taken, metrics are logged, and ``_train_step_callback`` is invoked.

    Args:
        t_end: Optional wall-clock deadline compared against
            ``datetime.now()`` after each optimizer step. When reached, a
            resumable "latest" checkpoint is saved and the method returns.

    Raises:
        NotImplementedError: If ``self.gt_mask_type`` is ``None``; training
            without a validity mask is not supported.

    NOTE(review): block nesting was reconstructed from a whitespace-collapsed
    view of this file. ``torch.no_grad()`` is assumed to wrap only the VAE
    encoding, since the UNet/decoder must stay in the autograd graph for
    ``loss.backward()`` -- confirm against the original file.
    """
    logging.info("Start training")

    device = self.device
    self.model.to(device)

    self.visualize()

    if self.in_evaluation:
        logging.info(
            "Last evaluation was not finished, will do evaluation before continue training."
        )
        self.validate()

    self.train_metrics.reset()
    accumulated_step = 0

    progress_bar = tqdm(
        range(0, self.max_iter),
        initial=self.effective_iter,
        desc="iter",
    )

    for epoch in range(self.epoch, self.max_epoch + 1):
        self.epoch = epoch
        logging.debug(f"epoch: {self.epoch}")

        # Skip previous batches when resume
        for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
            self.model.unet.train()

            # >>> With gradient accumulation >>>

            # Get data
            rgb = batch["rgb_norm"].to(device)
            depth_gt_for_latent = batch[self.gt_depth_type].to(device)

            # A validity mask is mandatory: both losses below index with it.
            if self.gt_mask_type is None:
                raise NotImplementedError(
                    "Training without a validity mask (gt_mask_type=None) is not supported"
                )
            valid_mask_for_latent = batch[self.gt_mask_type].to(device)

            batch_size = rgb.shape[0]

            with torch.no_grad():
                # Encode image
                rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]

            # Text embedding
            text_embed = self.empty_text_embed.to(device).repeat(
                (batch_size, 1, 1)
            )  # [B, 77, 1024]

            # Single forward pass of the UNet with the constant timestep 1
            rgb_latent = self.model.unet(
                rgb_latent, 1, text_embed
            ).sample  # [B, 4, h, w]

            depth_pred = self.model.decode_depth(rgb_latent)
            depth_gt_for_loss = depth_gt_for_latent

            aligned_pred = depth_pred

            # Pixel-wise loss on valid pixels only
            loss = self.loss(
                aligned_pred[valid_mask_for_latent].float(),
                depth_gt_for_loss[valid_mask_for_latent].float(),
            ).mean()

            self.train_metrics.update("loss", loss.item())

            # grad loss: invalid pixels are zeroed in place in both maps
            # before the gradient features are computed
            depth_gt_for_loss[~valid_mask_for_latent] = 0
            grad_gt = self.grad(depth_gt_for_loss)
            aligned_pred[~valid_mask_for_latent] = 0
            grad_pred = self.grad(aligned_pred)
            grad_loss = self.grad_loss(grad_gt, grad_pred)
            self.train_metrics.update("grad_loss", grad_loss.item())
            loss += self.cfg.grad_loss.lamda * grad_loss

            # Scale so that accumulated gradients average over the micro-batches
            loss = loss / self.gradient_accumulation_steps
            loss.backward()
            accumulated_step += 1

            self.n_batch_in_epoch += 1
            # Practical batch end

            # Perform optimization step
            if accumulated_step >= self.gradient_accumulation_steps:
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                accumulated_step = 0

                self.effective_iter += 1
                progress_bar.update(1)

                # Log to tensorboard
                accumulated_loss = self.train_metrics.result()["loss"]
                logs = {"loss": accumulated_loss}
                progress_bar.set_postfix(**logs)
                tb_logger.log_dic(
                    {
                        f"train/{k}": v
                        for k, v in self.train_metrics.result().items()
                    },
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "lr",
                    self.lr_scheduler.get_last_lr()[0],
                    global_step=self.effective_iter,
                )
                tb_logger.writer.add_scalar(
                    "n_batch_in_epoch",
                    self.n_batch_in_epoch,
                    global_step=self.effective_iter,
                )
                self.train_metrics.reset()

                # Per-step callback
                self._train_step_callback()

                # End of training
                if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                    self.save_checkpoint(
                        ckpt_name=self._get_backup_ckpt_name(),
                        save_train_state=False,
                    )
                    logging.info("Training ended.")
                    return
                # Time's up
                elif t_end is not None and datetime.now() >= t_end:
                    self.save_checkpoint(ckpt_name="latest", save_train_state=True)
                    logging.info("Time is up, training paused.")
                    return

                torch.cuda.empty_cache()
                # <<< Effective batch end <<<

        # Epoch end
        self.n_batch_in_epoch = 0
307
+
308
def encode_depth(self, depth_in):
    """Encode a depth map into latent space via the RGB encoder.

    The depth input is first expanded to three channels by
    ``stack_depth_images`` and then passed to ``self.model.encode_rgb``.
    """
    three_channel_depth = self.stack_depth_images(depth_in)
    return self.model.encode_rgb(three_channel_depth)
314
+
315
@staticmethod
def stack_depth_images(depth_in):
    """Replicate a single-channel depth tensor into three channels.

    Args:
        depth_in: Tensor of rank 4 (channel axis at dim 1, repeated 3x) or
            rank 3 (a channel axis is inserted at dim 1 first).

    Returns:
        Tensor with the channel axis repeated three times.

    Raises:
        ValueError: If ``depth_in`` is neither 3- nor 4-dimensional.
    """
    n_dims = len(depth_in.shape)
    if 4 == n_dims:
        stacked = depth_in.repeat(1, 3, 1, 1)
    elif 3 == n_dims:
        # Repeat the *unsqueezed* tensor; repeating the 3-D input directly
        # would tile the batch axis instead of creating three channels.
        stacked = depth_in.unsqueeze(1).repeat(1, 3, 1, 1)
    else:
        raise ValueError(
            f"depth_in must be 3- or 4-dimensional, got shape {tuple(depth_in.shape)}"
        )
    return stacked
323
+
324
+ def _train_step_callback(self):
325
+ """Executed after every iteration"""
326
+ # Save backup (with a larger interval, without training states)
327
+ if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
328
+ self.save_checkpoint(
329
+ ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
330
+ )
331
+
332
+ _is_latest_saved = False
333
+ # Validation
334
+ if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
335
+ self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished
336
+ self.save_checkpoint(ckpt_name="latest", save_train_state=True)
337
+ _is_latest_saved = True
338
+ self.validate()
339
+ self.in_evaluation = False
340
+ self.save_checkpoint(ckpt_name="latest", save_train_state=True)
341
+
342
+ # Save training checkpoint (can be resumed)
343
+ if (
344
+ self.save_period > 0
345
+ and 0 == self.effective_iter % self.save_period
346
+ and not _is_latest_saved
347
+ ):
348
+ self.save_checkpoint(ckpt_name="latest", save_train_state=True)
349
+
350
+ # Visualization
351
+ if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period:
352
+ self.visualize()
353
+
354
def validate(self):
    """Evaluate on every validation loader, log, and write a text report.

    For each loader the metrics are logged, sent to tensorboard, and written
    to ``eval-<dataset>-iter<NNNNNN>.txt`` under ``self.out_dir_eval``. Only
    the first loader drives ``self.best_metric``; on improvement a backup
    checkpoint (without training state) is saved.
    """
    for loader_idx, loader in enumerate(self.val_loaders):
        dataset_name = loader.dataset.disp_name
        metrics = self.validate_single_dataset(
            data_loader=loader, metric_tracker=self.val_metrics
        )
        logging.info(
            f"Iter {self.effective_iter}. Validation metrics on `{dataset_name}`: {metrics}"
        )
        tb_scalars = {f"val/{dataset_name}/{k}": v for k, v in metrics.items()}
        tb_logger.log_dic(tb_scalars, global_step=self.effective_iter)

        # save to file
        report = eval_dic_to_text(
            val_metrics=metrics,
            dataset_name=dataset_name,
            sample_list_path=loader.dataset.filename_ls_path,
        )
        report_path = os.path.join(
            self.out_dir_eval,
            f"eval-{dataset_name}-iter{self.effective_iter:06d}.txt",
        )
        with open(report_path, "w+") as report_file:
            report_file.write(report)

        # Only the first validation set updates the main eval metric
        if loader_idx != 0:
            continue

        score = metrics[self.main_val_metric]
        if "minimize" == self.main_val_metric_goal and score < self.best_metric:
            improved = True
        elif "maximize" == self.main_val_metric_goal and score > self.best_metric:
            improved = True
        else:
            improved = False

        if improved:
            self.best_metric = score
            logging.info(
                f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
            )
            # Save a checkpoint
            self.save_checkpoint(
                ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
            )
397
+
398
def visualize(self):
    """Run inference on every visualization loader and save the predictions.

    Outputs go to ``<out_dir_vis>/<backup ckpt name>/<dataset name>``; the
    metrics returned by ``validate_single_dataset`` are discarded.
    """
    for loader in self.vis_loaders:
        dataset_name = loader.dataset.disp_name
        ckpt_tag = self._get_backup_ckpt_name()
        target_dir = os.path.join(self.out_dir_vis, ckpt_tag, dataset_name)
        os.makedirs(target_dir, exist_ok=True)
        self.validate_single_dataset(
            data_loader=loader,
            metric_tracker=self.val_metrics,
            save_to_dir=target_dir,
        )
410
+
411
@torch.no_grad()
def validate_single_dataset(
    self,
    data_loader: DataLoader,
    metric_tracker: MetricTracker,
    save_to_dir: str = None,
):
    """Evaluate the model on one dataset, one sample at a time.

    Each prediction is aligned to the ground truth according to
    ``self.cfg.eval.alignment`` ("least_square", "least_square_disparity" or
    "least_square_sqrt_disp"), clipped to the dataset depth range, and scored
    with every function in ``self.metric_funcs``.

    Args:
        data_loader: Loader with ``batch_size == 1`` (asserted per batch).
        metric_tracker: Tracker that is reset and then updated per sample.
        save_to_dir: If given, the raw (unaligned) prediction is written
            there as a 16-bit PNG.

    Returns:
        ``metric_tracker.result()`` after the whole dataset was processed.

    Raises:
        RuntimeError: If the configured alignment type is unknown.
    """
    self.model.to(self.device)
    metric_tracker.reset()

    # Generate seed sequence for consistent evaluation
    # NOTE(review): the sequence is not consumed in this method; the call is
    # kept because it may have side effects on RNG state -- confirm.
    val_init_seed = self.cfg.validation.init_seed
    val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))

    for i, batch in enumerate(
        tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
        start=1,
    ):
        assert 1 == data_loader.batch_size
        # Read input image
        rgb_int = batch["rgb_int"]  # [3, H, W]
        # GT depth
        depth_raw_ts = batch["depth_raw_linear"].squeeze()
        depth_raw = depth_raw_ts.numpy()
        depth_raw_ts = depth_raw_ts.to(self.device)
        valid_mask_ts = batch["valid_mask_raw"].squeeze()
        valid_mask = valid_mask_ts.numpy()
        valid_mask_ts = valid_mask_ts.to(self.device)

        # Predict depth
        pipe_out: DepthMasterDepthOutput = self.model(
            rgb_int,
            processing_res=self.cfg.validation.processing_res,
            match_input_res=self.cfg.validation.match_input_res,
            batch_size=1,  # use batch size 1 to increase reproducibility
            color_map=None,
            show_progress_bar=False,
            resample_method=self.cfg.validation.resample_method,
        )

        depth_pred: np.ndarray = pipe_out.depth_np.squeeze()

        # NOTE: in every branch the ground-truth tensor `depth_raw_ts` is left
        # in linear depth. The aligned prediction is converted back to depth,
        # so converting the GT as well would compare depth against disparity.
        if "least_square" == self.cfg.eval.alignment:
            depth_pred, scale, shift = align_depth_least_square(
                gt_arr=depth_raw,
                pred_arr=depth_pred,
                valid_mask_arr=valid_mask,
                return_scale_shift=True,
                max_resolution=self.cfg.eval.align_max_res,
            )
        elif "least_square_disparity" == self.cfg.eval.alignment:
            gt_disparity = depth2disparity(depth_raw)
            gt_non_neg_mask = gt_disparity > 0

            # LS alignment in disparity space
            pred_non_neg_mask = depth_pred > 0
            valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask

            disparity_pred, scale, shift = align_depth_least_square(
                gt_arr=gt_disparity,
                pred_arr=depth_pred,
                valid_mask_arr=valid_nonnegative_mask,
                return_scale_shift=True,
            )
            # convert to depth
            disparity_pred = np.clip(
                disparity_pred, a_min=1e-3, a_max=None
            )  # avoid 0 disparity
            depth_pred = disparity2depth(disparity_pred)
        elif "least_square_sqrt_disp" == self.cfg.eval.alignment:
            gt_sqrt_disp = np.sqrt(depth2disparity(depth_raw))
            gt_non_neg_mask = gt_sqrt_disp > 0

            # LS alignment in sqrt space, restricted to pixels that are
            # positive in both GT and prediction (same as the disparity branch)
            pred_non_neg_mask = depth_pred > 0
            valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask
            depth_sqrt_disp_pred, scale, shift = align_depth_least_square(
                gt_arr=gt_sqrt_disp,
                pred_arr=depth_pred,
                valid_mask_arr=valid_nonnegative_mask,
                return_scale_shift=True,
            )
            # sqrt-disparity -> disparity -> depth
            disparity_pred = depth_sqrt_disp_pred ** 2
            disparity_pred = np.clip(
                disparity_pred, a_min=1e-3, a_max=None
            )  # avoid 0 disparity
            depth_pred = disparity2depth(disparity_pred)
        else:
            raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}")

        # Clip to dataset min max
        depth_pred = np.clip(
            depth_pred,
            a_min=data_loader.dataset.min_depth,
            a_max=data_loader.dataset.max_depth,
        )

        # clip to d > 0 for evaluation
        depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None)

        # Evaluate
        depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)

        for met_func in self.metric_funcs:
            _metric_name = met_func.__name__
            _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item()
            metric_tracker.update(_metric_name, _metric)

        # Save as 16-bit uint png
        if save_to_dir is not None:
            img_name = batch["rgb_relative_path"][0].replace("/", "_")
            png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
            depth_to_save = (pipe_out.depth_np.squeeze() * 65535.0).astype(np.uint16)
            Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

    return metric_tracker.result()
536
+
537
+ def _get_next_seed(self):
538
+ if 0 == len(self.global_seed_sequence):
539
+ self.global_seed_sequence = generate_seed_sequence(
540
+ initial_seed=self.seed,
541
+ length=self.max_iter * self.gradient_accumulation_steps,
542
+ )
543
+ logging.info(
544
+ f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
545
+ )
546
+ return self.global_seed_sequence.pop()
547
+
548
def save_checkpoint(self, ckpt_name, save_train_state):
    """Save the UNet (and optionally the trainer state) under ``ckpt_name``.

    An existing checkpoint directory of the same name is first renamed to
    ``_old_<name>`` and only removed after the new checkpoint was written.

    Args:
        ckpt_name: Sub-directory name inside ``self.out_dir_ckpt``.
        save_train_state: If True, also write ``trainer.ckpt`` (optimizer,
            LR scheduler, config, counters, best metric, seed sequence) and
            an empty marker file named after the current iteration.
    """
    ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
    logging.info(f"Saving checkpoint to: {ckpt_dir}")
    # Backup previous checkpoint
    temp_ckpt_dir = None
    if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
        temp_ckpt_dir = os.path.join(
            os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
        )
        if os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        os.rename(ckpt_dir, temp_ckpt_dir)
        logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

    # Save UNet
    unet_path = os.path.join(ckpt_dir, "unet")
    self.model.unet.save_pretrained(unet_path, safe_serialization=False)
    logging.info(f"UNet is saved to: {unet_path}")

    if save_train_state:
        state = {
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "config": self.cfg,
            "effective_iter": self.effective_iter,
            "epoch": self.epoch,
            "n_batch_in_epoch": self.n_batch_in_epoch,
            "best_metric": self.best_metric,
            "in_evaluation": self.in_evaluation,
            "global_seed_sequence": self.global_seed_sequence,
        }
        train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
        torch.save(state, train_state_path)
        # iteration indicator: an empty file whose name encodes the iteration;
        # the context manager guarantees the handle is closed
        with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
            pass

        logging.info(f"Trainer state is saved to: {train_state_path}")

    # Remove temp ckpt
    if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
        shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
        logging.debug("Old checkpoint backup is removed.")
592
+
593
def load_checkpoint(
    self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
):
    """Restore UNet weights and, optionally, the trainer state.

    Args:
        ckpt_path: Checkpoint directory containing
            ``unet/diffusion_pytorch_model.bin`` and, when trainer state is
            requested, ``trainer.ckpt``.
        load_trainer_state: If True, restore counters, best metric, seed
            sequence and optimizer state from ``trainer.ckpt``.
        resume_lr_scheduler: If True (and trainer state is loaded), also
            restore the LR scheduler state.
    """
    logging.info(f"Loading checkpoint from: {ckpt_path}")
    # Load UNet
    _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
    self.model.unet.load_state_dict(
        torch.load(_model_path, map_location=self.device)
    )
    self.model.unet.to(self.device)
    logging.info(f"UNet parameters are loaded from {_model_path}")

    # Load training states
    if load_trainer_state:
        # map_location keeps this consistent with the UNet load above and
        # allows resuming on a different device than the one that saved.
        checkpoint = torch.load(
            os.path.join(ckpt_path, "trainer.ckpt"), map_location=self.device
        )
        self.effective_iter = checkpoint["effective_iter"]
        self.epoch = checkpoint["epoch"]
        self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
        self.in_evaluation = checkpoint["in_evaluation"]
        self.global_seed_sequence = checkpoint["global_seed_sequence"]

        self.best_metric = checkpoint["best_metric"]

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        logging.info(f"optimizer state is loaded from {ckpt_path}")

        if resume_lr_scheduler:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            logging.info(f"LR scheduler state is loaded from {ckpt_path}")

    logging.info(
        f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
    )
628
+
629
+ def _get_backup_ckpt_name(self):
630
+ return f"iter_{self.effective_iter:06d}"
DepthMaster/src/util/alignment.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import numpy as np
26
+ import torch
27
+
28
def align_depth_medium_mask(
    gt: torch.Tensor,
    valid_mask: torch.Tensor,
    max_resolution=None,
):
    """Compute per-sample median shift and mean-absolute-deviation scale.

    For every sample ``i`` the valid pixels ``gt[i][valid_mask[i]]`` yield
    ``shift = median`` and ``scale = mean(|x - shift|)``.

    Args:
        gt: Batched tensor; the last two axes are treated as spatial and the
            first as batch.  # assumes [B, 1, H, W] -- TODO confirm
        valid_mask: Mask with the same shape as ``gt``; must be boolean when
            no downscaling is applied.
        max_resolution: If given and smaller than the spatial size, inputs
            are nearest-neighbour downscaled before the statistics are taken.

    Returns:
        Tuple ``(scale, shift)``, each of shape ``[B, 1, 1, 1]``.
    """
    ori_shape = gt.shape[-2:]  # input shape
    batch_size = gt.shape[0]

    # Downsample
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
            gt = downscaler(gt)
            # Interpolation is not defined for bool tensors: go through float
            valid_mask = downscaler(valid_mask.float()).bool()

    scale_ls = []
    shift_ls = []

    for i in range(batch_size):
        gt_masked = gt[i][valid_mask[i]]
        shift = torch.median(gt_masked).unsqueeze(0)
        scale = torch.mean(torch.abs(gt_masked - shift)).unsqueeze(0)

        scale_ls.append(scale)
        shift_ls.append(shift)

    # [B] -> [B, 1, 1, 1] so the results broadcast against image batches
    scale = torch.concat(scale_ls, dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    shift = torch.concat(shift_ls, dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    return scale, shift
65
+
66
+
67
def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Align a prediction to the ground truth with a least-squares scale and shift.

    Solves ``min ||scale * pred + shift - gt||`` over the valid pixels and
    applies the solution to the full-resolution ``pred_arr``.

    Args:
        gt_arr: Ground-truth array.
        pred_arr: Prediction array; the result has its shape.
        valid_mask_arr: Boolean array selecting the pixels used for the fit.
        return_scale_shift: If True, return ``(aligned, scale, shift)``;
            otherwise only ``aligned``.
        max_resolution: If given and smaller than the spatial size, the fit
            is computed on a nearest-neighbour downscaled copy of the inputs.

    Returns:
        The aligned prediction, optionally with ``scale`` and ``shift``
        (each an array of shape ``(1,)``).
    """
    ori_shape = pred_arr.shape  # input shape

    gt = gt_arr.squeeze()  # [H, W]
    pred = pred_arr.squeeze()
    valid_mask = valid_mask_arr.squeeze()

    # Downsample
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")

            def _downscale(ts):
                # Present the data as [1, C, H, W] so that BOTH spatial axes
                # are resampled; a 3-D input would only be resampled along
                # its last axis.
                lead_shape = ts.shape[:-2]
                out = downscaler(ts.reshape(1, -1, *ts.shape[-2:]))
                return out.reshape(*lead_shape, *out.shape[-2:])

            gt = _downscale(torch.as_tensor(gt)).numpy()
            pred = _downscale(torch.as_tensor(pred)).numpy()
            valid_mask = _downscale(torch.as_tensor(valid_mask).float()).bool().numpy()

    assert (
        gt.shape == pred.shape == valid_mask.shape
    ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    gt_masked = gt[valid_mask].reshape((-1, 1))
    pred_masked = pred[valid_mask].reshape((-1, 1))

    # numpy solver: [pred, 1] @ [scale, shift]^T ~= gt
    _ones = np.ones_like(pred_masked)
    A = np.concatenate([pred_masked, _ones], axis=-1)
    X = np.linalg.lstsq(A, gt_masked, rcond=None)[0]
    scale, shift = X

    aligned_pred = pred_arr * scale + shift

    # restore dimensions
    aligned_pred = aligned_pred.reshape(ori_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    else:
        return aligned_pred
115
+
116
+
117
+ # ******************** disparity space ********************
118
def depth2disparity(depth, return_mask=False):
    """Invert strictly positive entries element-wise; all others map to 0.

    Args:
        depth: ``torch.Tensor`` or ``np.ndarray``.
        return_mask: If True, also return the boolean mask ``depth > 0``.

    Returns:
        ``disparity`` (same type and shape as ``depth``), or
        ``(disparity, mask)`` when ``return_mask`` is True.

    Raises:
        TypeError: If ``depth`` is neither a tensor nor an ndarray.
    """
    if isinstance(depth, torch.Tensor):
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        disparity = np.zeros_like(depth)
    else:
        raise TypeError(
            f"depth must be a torch.Tensor or np.ndarray, got {type(depth).__name__}"
        )
    # Only positive values are inverted, which also avoids division by zero
    non_negtive_mask = depth > 0
    disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask]
    if return_mask:
        return disparity, non_negtive_mask
    else:
        return disparity
129
+
130
+
131
def disparity2depth(disparity, **kwargs):
    """Convert disparity to depth.

    The conversion is the same element-wise inversion as
    ``depth2disparity``, so this simply forwards to it.
    """
    depth = depth2disparity(disparity, **kwargs)
    return depth
133
+
134
+
135
def align_depth_least_square_torch_mask(
    gt: torch.Tensor,
    pred: torch.Tensor,
    valid_mask: torch.Tensor,
    max_resolution=None,
):
    """Fit a per-sample least-squares scale and shift mapping ``pred`` onto ``gt``.

    Unlike ``align_depth_least_square`` this works on batched tensors and
    returns only the (detached) coefficients, not the aligned prediction.

    Args:
        gt: Batched ground truth; first axis is the batch.
        pred: Batched prediction with the same shape as ``gt``.
        valid_mask: Mask selecting the pixels used for the fit; must be
            boolean when no downscaling is applied.
        max_resolution: If given and smaller than the spatial size, inputs
            are nearest-neighbour downscaled before fitting.

    Returns:
        Tuple ``(scale, shift)``, each of shape ``[B, 1, 1, 1]``.
    """
    ori_shape = pred.shape[-2:]  # input shape
    batch_size = gt.shape[0]

    # Downsample
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
            gt = downscaler(gt)
            pred = downscaler(pred)
            # Interpolation is not defined for bool tensors: go through float
            valid_mask = downscaler(valid_mask.float()).bool()

    assert (
        gt.shape == pred.shape
    ), f"{gt.shape}, {pred.shape}"

    scale_ls = []
    shift_ls = []

    for i in range(batch_size):
        gt_masked = gt[i][valid_mask[i]].view(-1, 1)
        pred_masked = pred[i][valid_mask[i]].view(-1, 1)

        # torch solver: [pred, 1] @ [scale, shift]^T ~= gt
        ones = torch.ones_like(pred_masked)
        A = torch.cat([pred_masked, ones], dim=-1)
        X, *_ = torch.linalg.lstsq(A, gt_masked)

        # Coefficients are detached: no gradient flows through the alignment
        scale, shift = X[0, :].detach(), X[1, :].detach()
        scale_ls.append(scale)
        shift_ls.append(shift)

    # [B] -> [B, 1, 1, 1] so the results broadcast against image batches
    scale = torch.concat(scale_ls, dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    shift = torch.concat(shift_ls, dim=0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    return scale, shift
DepthMaster/src/util/boundary_metrics.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import numpy as np
4
+
5
+
6
def connected_component(r: np.ndarray, c: np.ndarray) -> List[List[int]]:  # type: ignore
    """Group positions into runs of horizontally adjacent pixels.

    Two consecutive positions belong to the same run when they share a row
    and the column increases by exactly one.

    Args:
    ----
        r (np.ndarray): Row indices.
        c (np.ndarray): Column indices.

    Yields:
    ------
        List[int]: Positions (into ``r``/``c``) forming one run.

    """
    run = [0]
    for pos in range(1, r.size):
        last = run[-1]
        same_row = r[pos] == r[last]
        if same_row and c[pos] == c[last] + 1:
            run.append(pos)
            continue
        # Adjacency broken: emit the finished run and start a new one
        yield run
        run = [pos]
    yield run
27
+
28
+
29
def nms_horizontal(ratio: np.ndarray, threshold: float) -> np.ndarray:
    """Keep only the strongest response in each horizontal run above the threshold.

    Args:
    ----
        ratio (np.ndarray): Input ratio matrix.
        threshold (float): Responses must exceed this value to be considered.

    Returns:
    -------
        np.ndarray: Boolean mask with one True entry per horizontal run.

    """
    keep = np.zeros_like(ratio, dtype=bool)
    rows, cols = np.nonzero(ratio > threshold)
    if len(rows) == 0:
        return keep
    for run in connected_component(rows, cols):
        responses = [ratio[rows[k], cols[k]] for k in run]
        winner = run[np.argmax(responses)]
        keep[rows[winner], cols[winner]] = True
    return keep
51
+
52
+
53
def nms_vertical(ratio: np.ndarray, threshold: float) -> np.ndarray:
    """Apply non-maximum suppression along columns.

    Implemented by transposing, running the horizontal variant, and
    transposing the resulting mask back.

    Args:
    ----
        ratio (np.ndarray): Input ratio matrix.
        threshold (float): Threshold for NMS.

    Returns:
    -------
        np.ndarray: Binary mask after applying NMS.

    """
    flipped = np.transpose(ratio)
    flipped_mask = nms_horizontal(flipped, threshold)
    return np.transpose(flipped_mask)
67
+
68
+
69
def fgbg_depth(
    d: np.ndarray, t: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compare each pixel with its four neighbours by value ratio.

    Args:
    ----
        d (np.ndarray): Depth matrix.
        t (float): A neighbour pair is flagged when its ratio exceeds ``t``.

    Returns:
    -------
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Boolean
        matrices in the order left, top, right, bottom.

    """
    # Neighbouring views along the last (horizontal) axis
    west = d[..., :, :-1]
    east = d[..., :, 1:]
    # Neighbouring views along the second-to-last (vertical) axis
    north = d[..., :-1, :]
    south = d[..., 1:, :]

    right_flags = (east / west) > t
    left_flags = (west / east) > t
    bottom_flags = (south / north) > t
    top_flags = (north / south) > t
    return left_flags, top_flags, right_flags, bottom_flags
95
+
96
+
97
def fgbg_depth_thinned(
    d: np.ndarray, t: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Neighbour ratio comparison followed by non-maximum suppression.

    Args:
    ----
        d (np.ndarray): Depth matrix.
        t (float): Threshold for NMS.

    Returns:
    -------
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Thinned boolean
        matrices in the order left, top, right, bottom.

    """
    west, east = d[..., :, :-1], d[..., :, 1:]
    north, south = d[..., :-1, :], d[..., 1:, :]

    # Horizontal relations are thinned along rows, vertical ones along columns
    right_thin = nms_horizontal(east / west, t)
    left_thin = nms_horizontal(west / east, t)
    bottom_thin = nms_vertical(south / north, t)
    top_thin = nms_vertical(north / south, t)
    return left_thin, top_thin, right_thin, bottom_thin
123
+
124
+
125
def fgbg_binary_mask(
    d: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Find neighbouring pixel pairs where exactly the first one is set.

    Args:
    ----
        d (np.ndarray): Boolean mask.

    Returns:
    -------
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Boolean
        matrices in the order left, top, right, bottom.

    """
    assert d.dtype == bool
    west, east = d[..., :, :-1], d[..., :, 1:]
    north, south = d[..., :-1, :], d[..., 1:, :]

    right_edges = east & ~west
    left_edges = west & ~east
    bottom_edges = south & ~north
    top_edges = north & ~south
    return left_edges, top_edges, right_edges, bottom_edges
151
+
152
+
153
def edge_recall_matting(pr: np.ndarray, gt: np.ndarray, t: float) -> float:
    """Average, over the four directions, the share of mask edges found in the prediction.

    Args:
    ----
        pr (np.ndarray): Predicted depth matrix.
        gt (np.ndarray): Ground truth binary mask.
        t (float): Threshold for NMS.

    Returns:
    -------
        float: Edge recall value.

    """
    assert gt.dtype == bool
    pred_left, pred_top, pred_right, pred_bottom = fgbg_depth_thinned(pr, t)
    gt_left, gt_top, gt_right, gt_bottom = fgbg_binary_mask(gt)

    def _recall(pred_edges, gt_edges):
        # max(..., 1) guards against division by zero when there are no GT edges
        return np.count_nonzero(pred_edges & gt_edges) / max(
            np.count_nonzero(gt_edges), 1
        )

    return 0.25 * (
        _recall(pred_left, gt_left)
        + _recall(pred_top, gt_top)
        + _recall(pred_right, gt_right)
        + _recall(pred_bottom, gt_bottom)
    )
176
+
177
+
178
def boundary_f1(
    pr: np.ndarray,
    gt: np.ndarray,
    t: float,
    return_p: bool = False,
    return_r: bool = False,
) -> float:
    """Compute the boundary F1 score between two depth maps.

    Precision and recall are each averaged over the four neighbour
    directions. When both are zero, 0.0 is returned regardless of the flags;
    ``return_p`` takes precedence over ``return_r``.

    Args:
    ----
        pr (np.ndarray): Predicted depth matrix.
        gt (np.ndarray): Ground truth depth matrix.
        t (float): Threshold for comparison.
        return_p (bool, optional): If True, return precision. Defaults to False.
        return_r (bool, optional): If True, return recall. Defaults to False.

    Returns:
    -------
        float: Boundary F1 score, or precision, or recall depending on the flags.

    """
    pred_edges = fgbg_depth(pr, t)
    gt_edges = fgbg_depth(gt, t)

    def _averaged_hit_rate(denominators):
        # Share of edges present in BOTH maps relative to `denominators`;
        # max(..., 1) guards against empty edge sets.
        total = 0
        for pred_dir, gt_dir, denom_dir in zip(pred_edges, gt_edges, denominators):
            total = total + np.count_nonzero(pred_dir & gt_dir) / max(
                np.count_nonzero(denom_dir), 1
            )
        return 0.25 * total

    r = _averaged_hit_rate(gt_edges)
    p = _averaged_hit_rate(pred_edges)

    if r + p == 0:
        return 0.0
    if return_p:
        return p
    if return_r:
        return r
    return 2 * (r * p) / (r + p)
222
+
223
+
224
def get_thresholds_and_weights(
    t_min: float, t_max: float, N: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Build evenly spaced thresholds and weights proportional to them.

    Args:
    ----
        t_min (float): Minimum threshold.
        t_max (float): Maximum threshold.
        N (int): Number of thresholds.

    Returns:
    -------
        Tuple[np.ndarray, np.ndarray]: Thresholds and their weights; the
        weights are the thresholds normalised to sum to one.

    """
    thresholds = np.linspace(t_min, t_max, N)
    normaliser = thresholds.sum()
    return thresholds, thresholds / normaliser
243
+
244
+
245
def invert_depth(depth: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Return the element-wise inverse of a depth map.

    Values below ``eps`` are raised to ``eps`` first, so the result is
    bounded by ``1 / eps`` and no division by zero occurs.

    Args:
    ----
        depth (np.ndarray): Depth map to be inverted.
        eps (float): Lower bound applied before inversion (default is 1e-6).

    Returns:
    -------
        np.ndarray: Inverted depth map.

    """
    bounded = depth.clip(min=eps)
    return 1.0 / bounded
260
+
261
+
262
def SI_boundary_F1(
    predicted_depth: np.ndarray,
    target_depth: np.ndarray,
    t_min: float = 1.05,
    t_max: float = 1.25,
    N: int = 10,
) -> float:
    """Calculate Scale-Invariant Boundary F1 Score for depth-based ground-truth.

    Both maps are inverted once, then ``boundary_f1`` is evaluated at ``N``
    thresholds in ``[t_min, t_max]`` and the scores are combined with
    weights proportional to the thresholds.

    Args:
    ----
        predicted_depth (np.ndarray): Predicted depth matrix (2-D).
        target_depth (np.ndarray): Ground truth depth matrix (2-D).
        t_min (float, optional): Minimum threshold. Defaults to 1.05.
        t_max (float, optional): Maximum threshold. Defaults to 1.25.
        N (int, optional): Number of thresholds. Defaults to 10.

    Returns:
    -------
        float: Scale-Invariant Boundary F1 Score.

    """
    assert predicted_depth.ndim == target_depth.ndim == 2
    thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)

    # The inversions do not depend on the threshold: compute them once
    # instead of once per threshold.
    inv_pred = invert_depth(predicted_depth)
    inv_target = invert_depth(target_depth)

    f1_scores = np.array([boundary_f1(inv_pred, inv_target, t) for t in thresholds])
    return np.sum(f1_scores * weights)
293
+
294
+
295
def SI_boundary_Recall(
    predicted_depth: np.ndarray,
    target_mask: np.ndarray,
    t_min: float = 1.05,
    t_max: float = 1.25,
    N: int = 10,
    alpha_threshold: float = 0.1,
) -> float:
    """Calculate Scale-Invariant Boundary Recall Score for mask-based ground-truth.

    The target is binarised with ``alpha_threshold`` and the prediction is
    inverted once; ``edge_recall_matting`` is then evaluated at ``N``
    thresholds in ``[t_min, t_max]`` and combined with weights proportional
    to the thresholds.

    Args:
    ----
        predicted_depth (np.ndarray): Predicted depth matrix (2-D).
        target_mask (np.ndarray): Ground truth mask (2-D).
        t_min (float, optional): Minimum threshold. Defaults to 1.05.
        t_max (float, optional): Maximum threshold. Defaults to 1.25.
        N (int, optional): Number of thresholds. Defaults to 10.
        alpha_threshold (float, optional): Threshold for alpha masking. Defaults to 0.1.

    Returns:
    -------
        float: Scale-Invariant Boundary Recall Score.

    """
    assert predicted_depth.ndim == target_mask.ndim == 2
    thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
    thresholded_target = target_mask > alpha_threshold

    # The inversion does not depend on the threshold: compute it once
    # instead of once per threshold.
    inv_pred = invert_depth(predicted_depth)

    recall_scores = np.array(
        [edge_recall_matting(inv_pred, thresholded_target, t=float(t)) for t in thresholds]
    )
    weighted_recall = np.sum(recall_scores * weights)
    return weighted_recall
DepthMaster/src/util/build_mlp.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
def build_mlp_(hidden_size=640, projector_dim=1024, z_dim=768):
    """Build a three-layer MLP projector with SiLU activations.

    Layout: Linear(hidden_size -> projector_dim), SiLU,
    Linear(projector_dim -> projector_dim), SiLU,
    Linear(projector_dim -> z_dim).
    """
    layers = [
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    ]
    return nn.Sequential(*layers)
DepthMaster/src/util/config_util.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import omegaconf
26
+ from omegaconf import OmegaConf
27
+
28
+
29
def recursive_load_config(config_path: str) -> OmegaConf:
    """Load a config file and merge it on top of every config it inherits from.

    If the file has a ``base_config`` entry, it must be a list of config paths.
    Each listed config is loaded recursively and merged in list order (later
    entries override earlier ones); the file at ``config_path`` itself is
    merged last, so its values take precedence over all bases.

    Args:
        config_path: Path of the config file to load.

    Returns:
        The merged configuration.
    """
    own_conf = OmegaConf.load(config_path)
    merged = OmegaConf.create({})

    parent_paths = own_conf.get("base_config", default_value=None)
    if parent_paths is not None:
        assert isinstance(parent_paths, omegaconf.listconfig.ListConfig)
        for parent_path in parent_paths:
            # Only direct self-inclusion is detected here; longer cycles are not.
            assert (
                parent_path != config_path
            ), "Circulate merging, base_config should not include itself."
            merged = OmegaConf.merge(merged, recursive_load_config(parent_path))

    # The current file overrides everything inherited from its bases.
    return OmegaConf.merge(merged, own_conf)
49
+
50
+
51
def find_value_in_omegaconf(search_key, config):
    """Collect every value stored under ``search_key`` anywhere in ``config``.

    Dict and list configs are walked depth-first. When a dict entry's key
    equals ``search_key`` its value is recorded and that value is not searched
    any further; other entries are descended into only if they are themselves
    dict or list configs. Anything that is neither a ``DictConfig`` nor a
    ``ListConfig`` yields an empty list.

    Returns:
        list: The matching values, in traversal order.
    """
    container_types = (omegaconf.DictConfig, omegaconf.ListConfig)
    matches = []

    if isinstance(config, omegaconf.DictConfig):
        for name, child in config.items():
            if name == search_key:
                matches.append(child)
                continue
            if isinstance(child, container_types):
                matches += find_value_in_omegaconf(search_key, child)
    elif isinstance(config, omegaconf.ListConfig):
        for child in config:
            if isinstance(child, container_types):
                matches += find_value_in_omegaconf(search_key, child)

    return matches
66
+
67
+
68
if __name__ == "__main__":
    # Manual check: load the base training config and dump the merged result.
    conf = recursive_load_config("config/train_base.yaml")
    merged_yaml = OmegaConf.to_yaml(conf)
    print(merged_yaml)
DepthMaster/src/util/data_loader.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py
2
+
3
+ from torch.utils.data import BatchSampler, DataLoader, IterableDataset
4
+
5
+ # kwargs of the DataLoader in min version 1.4.0.
6
+ _PYTORCH_DATALOADER_KWARGS = {
7
+ "batch_size": 1,
8
+ "shuffle": False,
9
+ "sampler": None,
10
+ "batch_sampler": None,
11
+ "num_workers": 0,
12
+ "collate_fn": None,
13
+ "pin_memory": False,
14
+ "drop_last": False,
15
+ "timeout": 0,
16
+ "worker_init_fn": None,
17
+ "multiprocessing_context": None,
18
+ "generator": None,
19
+ "prefetch_factor": 2,
20
+ "persistent_workers": False,
21
+ }
22
+
23
+
24
class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.

    Args:
        batch_sampler: The batch sampler whose batches are forwarded.
        skip_batches (`int`, *optional*, defaults to 0):
            Number of leading batches to drop on every iteration.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        # `BatchSampler.__init__` is not invoked here; only the two attributes
        # this wrapper needs are set.
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        # The wrapped sampler is still iterated from the start; skipped batches
        # are produced by it and discarded here.
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples

    @property
    def total_length(self):
        """Number of batches in the wrapped sampler, including skipped ones."""
        return len(self.batch_sampler)

    def __len__(self):
        # Clamp at zero: when `skip_batches` exceeds the wrapped sampler's
        # length, iteration yields nothing, and a negative `__len__` would make
        # `len()` raise ValueError instead of reporting 0.
        return max(len(self.batch_sampler) - self.skip_batches, 0)
44
+
45
+
46
class SkipDataLoader(DataLoader):
    """
    A PyTorch `DataLoader` that discards its first `skip_batches` batches on every iteration.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            How many batches to drop at the start of each pass.
        kwargs:
            Forwarded unchanged to the regular `DataLoader` constructor.
    """

    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __iter__(self):
        # Skipped batches are still loaded by the parent iterator; they are
        # simply not handed to the caller.
        still_to_skip = self.skip_batches
        for batch in super().__iter__():
            if still_to_skip > 0:
                still_to_skip -= 1
                continue
            yield batch
67
+
68
+
69
+ # Adapted from https://github.com/huggingface/accelerate
70
def skip_first_batches(dataloader, num_batches=0):
    """
    Build a new `torch.utils.data.DataLoader` over the same dataset that omits the first `num_batches` batches.

    For map-style datasets the existing batch sampler is wrapped in a `SkipBatchSampler`, so skipped
    batches are dropped at the index level. For an `IterableDataset` there is no batch sampler to wrap,
    so a `SkipDataLoader` is returned that discards the batches while iterating.
    """
    dataset = dataloader.dataset

    # Options that describe batching/sampling; for map-style datasets they are
    # all superseded by the wrapped batch sampler and must not be passed again.
    sampler_owned = ("batch_size", "shuffle", "sampler", "batch_sampler", "drop_last")
    kwargs = {}
    for name, fallback in _PYTORCH_DATALOADER_KWARGS.items():
        if name not in sampler_owned:
            kwargs[name] = getattr(dataloader, name, fallback)

    if isinstance(dataset, IterableDataset):
        # No batch sampler exists here, so batching options are passed
        # explicitly and batches are skipped manually by the loader.
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size
        return SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)

    if isinstance(dataloader.sampler, BatchSampler):
        source_sampler = dataloader.sampler
    else:
        source_sampler = dataloader.batch_sampler
    skipping_sampler = SkipBatchSampler(source_sampler, skip_batches=num_batches)
    return DataLoader(dataset, batch_sampler=skipping_sampler, **kwargs)
DepthMaster/src/util/depth_transform.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+ import torch
26
+ import logging
27
+
28
+
29
def get_depth_normalizer(cfg_normalizer):
    """Create the depth normalisation callable described by ``cfg_normalizer``.

    Args:
        cfg_normalizer: ``None`` for no normalisation, or a config object whose
            ``type`` is ``"scale_shift_depth"`` and which provides ``norm_min``,
            ``norm_max``, ``min_max_quantile`` and ``clip``.

    Returns:
        An identity function when ``cfg_normalizer`` is ``None``, otherwise a
        ``ScaleShiftDepthNormalizer`` built from the config values.

    Raises:
        NotImplementedError: If ``cfg_normalizer.type`` is not supported.
    """
    if cfg_normalizer is None:

        def identical(x):
            return x

        return identical

    if cfg_normalizer.type != "scale_shift_depth":
        raise NotImplementedError

    return ScaleShiftDepthNormalizer(
        norm_min=cfg_normalizer.norm_min,
        norm_max=cfg_normalizer.norm_max,
        min_max_quantile=cfg_normalizer.min_max_quantile,
        clip=cfg_normalizer.clip,
    )
47
+
48
+
49
class DepthNormalizerBase:
    """Interface for depth normalizers.

    Every method, including ``__init__``, raises ``NotImplementedError``;
    subclasses are expected to provide their own implementations.
    """

    # Both flags are left unset here; subclasses assign concrete values.
    is_absolute = None
    far_plane_at_max = None

    def __init__(self, norm_min=-1.0, norm_max=1.0) -> None:
        self.norm_min, self.norm_max = norm_min, norm_max
        raise NotImplementedError

    def __call__(self, depth, valid_mask=None, clip=None):
        raise NotImplementedError

    def denormalize(self, depth_norm, **kwargs):
        # Metric-depth normalizers: map the prediction back to metric depth.
        # Relative-depth normalizers: map the prediction to [0, 1].
        raise NotImplementedError
69
+
70
+
71
class ScaleShiftDepthNormalizer(DepthNormalizerBase):
    """
    Use near and far plane to linearly normalize depth,
        i.e. d' = d * s + t,
        where near plane is mapped to `norm_min`, and far plane is mapped to `norm_max`
    Near and far planes are determined by taking quantile values.

    Args:
        norm_min: Value the near plane is mapped to. Defaults to -1.0.
        norm_max: Value the far plane is mapped to. Defaults to 1.0.
        min_max_quantile: Lower quantile used as the near plane; the far plane
            uses ``1 - min_max_quantile``. Defaults to 0.02.
        clip: Whether normalized values are clipped to ``[norm_min, norm_max]``
            by default. Defaults to True.
    """

    is_absolute = False
    far_plane_at_max = True

    def __init__(
        self, norm_min=-1.0, norm_max=1.0, min_max_quantile=0.02, clip=True
    ) -> None:
        # The base-class __init__ is not called: it raises NotImplementedError.
        self.norm_min = norm_min
        self.norm_max = norm_max
        self.norm_range = self.norm_max - self.norm_min
        self.min_quantile = min_max_quantile
        self.max_quantile = 1.0 - self.min_quantile
        self.clip = clip

    def __call__(self, depth_linear, valid_mask=None, clip=None):
        """Normalize ``depth_linear`` with a per-input scale and shift.

        Args:
            depth_linear: Depth tensor to normalize.
            valid_mask: Optional boolean mask of pixels used to estimate the
                near/far planes. Non-positive depths are always excluded.
            clip: Overrides the instance-level ``clip`` setting when not None.

        Returns:
            Tensor of the same shape as ``depth_linear``. Pixels outside the
            mask are transformed with the same scale and shift.
        """
        clip = clip if clip is not None else self.clip

        if valid_mask is None:
            valid_mask = torch.ones_like(depth_linear).bool()
        valid_mask = valid_mask & (depth_linear > 0)

        # Take quantiles as min and max.
        # torch.quantile requires `q` to share the input's dtype and device, so
        # the quantile tensor is built to match the depth values (otherwise
        # float64 or CUDA inputs raise a RuntimeError).
        valid_depth = depth_linear[valid_mask]
        quantiles = torch.tensor(
            [self.min_quantile, self.max_quantile],
            dtype=valid_depth.dtype,
            device=valid_depth.device,
        )
        _min, _max = torch.quantile(valid_depth, quantiles)

        # scale and shift
        # NOTE(review): if both quantiles coincide (e.g. constant depth) this
        # divides by zero and the result is not finite — confirm callers never
        # pass such inputs.
        depth_norm_linear = (depth_linear - _min) / (
            _max - _min
        ) * self.norm_range + self.norm_min

        if clip:
            depth_norm_linear = torch.clip(
                depth_norm_linear, self.norm_min, self.norm_max
            )

        return depth_norm_linear

    def scale_back(self, depth_norm):
        """Linearly map values from ``[norm_min, norm_max]`` to ``[0, 1]``."""
        # scale to [0, 1]
        depth_linear = (depth_norm - self.norm_min) / self.norm_range
        return depth_linear

    def denormalize(self, depth_norm, **kwargs):
        """Return ``depth_norm`` rescaled to ``[0, 1]`` and log a warning.

        The per-input near/far planes are not stored, so the original depth
        range cannot be restored here; extra keyword arguments are ignored.
        """
        logging.warning(f"{self.__class__} is not revertible without GT")
        return self.scale_back(depth_norm=depth_norm)
+ return self.scale_back(depth_norm=depth_norm)