Xuweiyi committed on
Commit c7b663e · verified · 1 Parent(s): 4c3c562

Initial SAB3R demo release

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +11 -0
  2. .gitignore +148 -0
  3. LICENSE +7 -0
  4. README.md +185 -6
  5. app.py +18 -0
  6. assets/network_architecture.png +3 -0
  7. assets/qualitative_2.jpg +3 -0
  8. assets/teaser.jpg +3 -0
  9. assets/teaser_v5.jpg +3 -0
  10. config/deepspeed.json +38 -0
  11. config/training_config.yaml +55 -0
  12. config/training_config_full.yaml +57 -0
  13. demo.py +118 -0
  14. docker/docker-compose-cpu.yml +16 -0
  15. docker/docker-compose-cuda.yml +23 -0
  16. docker/files/cpu.Dockerfile +39 -0
  17. docker/files/cuda.Dockerfile +29 -0
  18. docker/files/entrypoint.sh +8 -0
  19. docker/run.sh +68 -0
  20. dust3r/.gitignore +132 -0
  21. dust3r/LICENSE +7 -0
  22. dust3r/NOTICE +12 -0
  23. dust3r/README.md +390 -0
  24. dust3r/assets/demo.jpg +3 -0
  25. dust3r/assets/dust3r.jpg +0 -0
  26. dust3r/assets/dust3r_archi.jpg +0 -0
  27. dust3r/assets/matching.jpg +3 -0
  28. dust3r/assets/pipeline1.jpg +0 -0
  29. dust3r/croco/LICENSE +52 -0
  30. dust3r/croco/NOTICE +21 -0
  31. dust3r/croco/README.MD +124 -0
  32. dust3r/croco/assets/Chateau1.png +3 -0
  33. dust3r/croco/assets/Chateau2.png +3 -0
  34. dust3r/croco/assets/arch.jpg +0 -0
  35. dust3r/croco/croco-stereo-flow-demo.ipynb +191 -0
  36. dust3r/croco/datasets_croco/__init__.py +0 -0
  37. dust3r/croco/datasets_croco/crops/README.MD +104 -0
  38. dust3r/croco/datasets_croco/crops/extract_crops_from_images.py +159 -0
  39. dust3r/croco/datasets_croco/habitat_sim/README.MD +76 -0
  40. dust3r/croco/datasets_croco/habitat_sim/__init__.py +0 -0
  41. dust3r/croco/datasets_croco/habitat_sim/generate_from_metadata.py +92 -0
  42. dust3r/croco/datasets_croco/habitat_sim/generate_from_metadata_files.py +27 -0
  43. dust3r/croco/datasets_croco/habitat_sim/generate_multiview_images.py +177 -0
  44. dust3r/croco/datasets_croco/habitat_sim/multiview_habitat_sim_generator.py +390 -0
  45. dust3r/croco/datasets_croco/habitat_sim/pack_metadata_files.py +69 -0
  46. dust3r/croco/datasets_croco/habitat_sim/paths.py +129 -0
  47. dust3r/croco/datasets_croco/pairs_dataset.py +109 -0
  48. dust3r/croco/datasets_croco/transforms.py +95 -0
  49. dust3r/croco/demo.py +55 -0
  50. dust3r/croco/interactive_demo.ipynb +271 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/network_architecture.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/qualitative_2.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/teaser.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/teaser_v5.jpg filter=lfs diff=lfs merge=lfs -text
40
+ dust3r/assets/demo.jpg filter=lfs diff=lfs merge=lfs -text
41
+ dust3r/assets/matching.jpg filter=lfs diff=lfs merge=lfs -text
42
+ dust3r/croco/assets/Chateau1.png filter=lfs diff=lfs merge=lfs -text
43
+ dust3r/croco/assets/Chateau2.png filter=lfs diff=lfs merge=lfs -text
44
+ dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-311/.ninja_deps filter=lfs diff=lfs merge=lfs -text
45
+ dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-311/curope.o filter=lfs diff=lfs merge=lfs -text
46
+ dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-311/kernels.o filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,148 @@
1
+ *gl-outputs/
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+ launch_script/output/
7
+ preprocess_data/
8
+ images
9
+ launch_script/
10
+ paper/
11
+ .gradio/
12
+ semantic_extraction_ade20k*
13
+ semantic_extraction_voc*
14
+ semantic_seg_ade20k_clip*
15
+ zero_shot_semantic_seg_ade20k_inference*
16
+ # C extensions
17
+ *.so
18
+
19
+ # Distribution / packaging
20
+ .Python
21
+ build/
22
+ develop-eggs/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ wheels/
33
+ pip-wheel-metadata/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+
64
+ # Translations
65
+ *.mo
66
+ *.pot
67
+
68
+ # Django stuff:
69
+ *.log
70
+ local_settings.py
71
+ db.sqlite3
72
+ db.sqlite3-journal
73
+
74
+ # Flask stuff:
75
+ instance/
76
+ .webassets-cache
77
+
78
+ # Scrapy stuff:
79
+ .scrapy
80
+
81
+ # Sphinx documentation
82
+ docs/_build/
83
+
84
+ # PyBuilder
85
+ target/
86
+
87
+ # Jupyter Notebook
88
+ .ipynb_checkpoints
89
+
90
+ # IPython
91
+ profile_default/
92
+ ipython_config.py
93
+
94
+ # pyenv
95
+ .python-version
96
+
97
+ # pipenv
98
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
100
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
101
+ # install all needed dependencies.
102
+ #Pipfile.lock
103
+
104
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105
+ __pypackages__/
106
+
107
+ # Celery stuff
108
+ celerybeat-schedule
109
+ celerybeat.pid
110
+
111
+ # SageMath parsed files
112
+ *.sage.py
113
+
114
+ # Environments
115
+ .env
116
+ .venv
117
+ env/
118
+ venv/
119
+ ENV/
120
+ env.bak/
121
+ venv.bak/
122
+
123
+ # Spyder project settings
124
+ .spyderproject
125
+ .spyproject
126
+
127
+ # Rope project settings
128
+ .ropeproject
129
+
130
+ # mkdocs documentation
131
+ /site
132
+
133
+ # mypy
134
+ .mypy_cache/
135
+ .dmypy.json
136
+ dmypy.json
137
+
138
+ # Pyre type checker
139
+ .pyre/
140
+
141
+ # data
142
+ data/
143
+ # checkpoints
144
+ checkpoints/
145
+ # wandb
146
+ wandb/
147
+ # outputs
148
+ outputs/
LICENSE ADDED
@@ -0,0 +1,7 @@
1
+ DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
README.md CHANGED
@@ -1,13 +1,192 @@
1
  ---
2
  title: SAB3R
3
- emoji: 🐢
4
- colorFrom: green
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 6.13.0
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: SAB3R
3
+ emoji: 🌐
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.44.1
 
8
  app_file: app.py
9
  pinned: false
10
+ license: cc-by-nc-sa-4.0
11
  ---
12
 
13
+ # SAB3R: Semantic-Augmented Backbone in 3D Reconstruction
14
+
15
+ <div align="center">
16
+
17
+ **3D-LLM/VLA Workshop @ CVPR 2025**
18
+
19
+ [**Xuweiyi Chen**](https://xuweiyichen.github.io/)<sup>*,1</sup> · [**Tian Xia**](https://tianx-ia.github.io/)<sup>*,2</sup> · [**Sihan Xu**](https://sihanxu.github.io/)<sup>2</sup> · [**Jed Jianing Yang**](https://jedyang.com/)<sup>2</sup> · [**Joyce Chai**](https://web.eecs.umich.edu/~chaijy/)<sup>2</sup> · [**Zezhou Cheng**](https://sites.google.com/site/zezhoucheng/)<sup>1</sup>
20
+
21
+ <sup>1</sup>University of Virginia · <sup>2</sup>University of Michigan
22
+
23
+ <sup>*</sup>Denotes Equal Contribution
24
+
25
+ ---
26
+
27
+ [![Paper](https://img.shields.io/badge/arXiv-2506.02112-b31b1b.svg)](https://www.arxiv.org/abs/2506.02112)
28
+ [![Project Page](https://img.shields.io/badge/Project-Page-orange)](https://uva-computer-vision-lab.github.io/sab3r/)
29
+ [![Demo](https://img.shields.io/badge/🤗-Live%20Demo-yellow)](https://huggingface.co/spaces/uva-cv-lab/SAB3R)
30
+ [![Data](https://img.shields.io/badge/Dataset-Coming%20Soon-blue)](#)
31
+ [![Code](https://img.shields.io/badge/GitHub-Code-green)](https://github.com/uva-computer-vision-lab/sab3r)
32
+
33
+ </div>
34
+
35
+ ---
36
+
37
+ ![SAB3R Teaser](./assets/teaser_v5.jpg)
38
+
39
+ *Given an unposed input video, we show ground truth for: open-vocabulary semantic segmentation (per-pixel labels for the prompt "a black office chair"), 3D reconstruction (ground-truth point cloud), and the proposed **Map and Locate** task (open-vocabulary segmentation and point cloud). The Map and Locate task: (1) encompasses both 2D and 3D tasks, (2) bridges reconstruction and recognition, and (3) introduces practical questions in robotics and embodied AI.*
40
+
41
+ ## Release Plan
42
+
43
+ - [x] Demo Release
44
+ - [x] Training and Inference Code Release
45
+ - [ ] Release Map and Locate Dataset
46
+
47
+ ## Abstract
48
+
49
+ We introduce **Map and Locate**, a task that unifies open-vocabulary segmentation and 3D reconstruction from unposed videos. Our method, **SAB3R**, builds upon MASt3R and incorporates lightweight distillation from CLIP and DINOv2 to generate semantic point clouds in a single forward pass. SAB3R achieves superior performance compared to separate deployment of MASt3R and CLIP on the Map and Locate benchmark.
50
+
51
+ ## Network Architecture
52
+
53
+ ![SAB3R Architecture](assets/network_architecture.png)
54
+
55
+ **SAB3R** distills dense features from CLIP and DINO into the MASt3R framework, enriching it with 2D semantic understanding. Each encoder-decoder pair operates on multi-view images, sharing weights and exchanging information to ensure consistent feature extraction across views. The model simultaneously generates depth, dense DINOv2, and dense CLIP features, which are then used for multi-view 3D reconstruction and semantic segmentation. This architecture enables SAB3R to seamlessly integrate 2D and 3D representations, achieving both geometric and semantic comprehension in a unified model.
56
+
57
+ ## Installation
58
+
59
+ 1. Clone the repository:
60
+ ```bash
61
+ git clone --recursive https://github.com/uva-computer-vision-lab/sab3r
62
+ cd sab3r
63
+ # if you have already cloned:
64
+ # git submodule update --init --recursive
65
+ ```
66
+
67
+ 2. Create the environment:
68
+ ```bash
69
+ conda create -n sab3r python=3.11 cmake=3.14.0
70
+ conda activate sab3r
71
+ conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia
72
+ pip install -r requirements.txt
73
+
74
+ # FeatUp (not on PyPI) — required for the CLIP/DINO semantic heads.
75
+ pip install git+https://github.com/mhamilton723/FeatUp
76
+ ```
77
+
78
+ 3. (Optional) Compile RoPE CUDA kernels for faster inference:
79
+ ```bash
80
+ cd dust3r/croco/models/curope/
81
+ python setup.py build_ext --inplace
82
+ cd ../../../../
83
+ ```
84
+
85
+ 4. (Optional) Pre-download the CLIP BPE vocab (the demo will fetch it on first run):
86
+ ```bash
87
+ mkdir -p ~/.cache/clip
88
+ cd ~/.cache/clip
89
+ wget https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz
90
+ ```
91
+
92
+ ## Demo
93
+
94
+ The demo launches a Gradio UI for 3D reconstruction + open-vocabulary text queries.
95
+
96
+ **Checkpoint from HF Hub (default)**
97
+ ```bash
98
+ python demo.py \
99
+ --model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
100
+ --local_network --share
101
+ ```
102
+ This downloads `demo_ckpt/base/base.pt` from [`uva-cv-lab/SAB3R`](https://huggingface.co/uva-cv-lab/SAB3R) on first launch and caches it in `~/.cache/huggingface/`.
103
+
104
+ **Local checkpoint**
105
+ ```bash
106
+ python demo.py \
107
+ --model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
108
+ --weights /path/to/your.pt \
109
+ --local_network --share
110
+ ```
111
+
112
+ **Override the HF Hub repo / filename**
113
+ ```bash
114
+ python demo.py \
115
+ --model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
116
+ --model_repo your-org/your-sab3r-ckpt \
117
+ --ckpt_filename model.pt
118
+ ```
119
+
120
+ **Local dev with a checkpoint dropdown** — if you keep multiple checkpoints under a directory (one sub-directory per checkpoint, each holding `<name>.pt`), pass `--checkpoint_dir`:
121
+ ```bash
122
+ python demo.py --checkpoint_dir /path/to/ckpt_root --local_network --share
123
+ ```
124
+
125
+ ## Hugging Face Spaces
126
+
127
+ This repo ships with an `app.py` entry point for Hugging Face Spaces. To deploy:
128
+
129
+ 1. Create a Gradio Space at https://huggingface.co/new-space (SDK: `gradio`).
130
+ 2. Upload the SAB3R checkpoint to a model repo (default expected: `uva-cv-lab/SAB3R`, filename `base.pt`). To use a different repo, set the `SAB3R_MODEL_REPO` / `SAB3R_CKPT_FILENAME` env vars on the Space, or edit `demo.py`'s defaults.
131
+ 3. Push this repo (minus the `env/` conda env and the heavyweight submodule binaries) to the Space. The Space's `README.md` must begin with the YAML frontmatter below; add it to your Space copy of this README before pushing:
132
+
133
+ ```yaml
134
+ ---
135
+ title: SAB3R
136
+ emoji: 🌐
137
+ colorFrom: blue
138
+ colorTo: green
139
+ sdk: gradio
140
+ sdk_version: 4.44.1
141
+ app_file: app.py
142
+ pinned: false
143
+ license: cc-by-nc-sa-4.0
144
+ ---
145
+ ```
146
+
147
+ FeatUp needs to be installed from GitHub; Spaces handles this via `requirements.txt`. The copy in this repo already includes the install line as a comment; uncomment it for the Space:
148
+
149
+ ```txt
150
+ git+https://github.com/mhamilton723/FeatUp
151
+ ```
152
+
153
+ ## Training
154
+
155
+ Two canonical configs are provided under `config/`:
156
+
157
+ - `training_config.yaml` — minimal dev recipe (CLIP distillation on a Co3D subset).
158
+ - `training_config_full.yaml` — full paper recipe (CLIP + DINO distillation on Habitat + ScanNet++ + ARKitScenes + Co3D).
159
+
160
+ Both configs reference paths relative to the repo root (e.g. `./data`, `./checkpoints`, `./outputs`); override them via Hydra:
161
+
162
+ ```bash
163
+ torchrun --nproc_per_node=8 train.py \
164
+ --config-name training_config_full \
165
+ dataset_url=/path/to/data \
166
+ output_url=/path/to/outputs
167
+ ```
168
+
169
+ Set `WANDB_API_KEY` in your shell (do **not** commit it) if you want experiment tracking.
170
+
171
+ ## Citation
172
+
173
+ ```bibtex
174
+ @article{chen2025sab3rsemanticaugmentedbackbone3d,
175
+ title={SAB3R: Semantic-Augmented Backbone in 3D Reconstruction},
176
+ author={Xuweiyi Chen and Tian Xia and Sihan Xu and Jianing Yang and Joyce Chai and Zezhou Cheng},
177
+ year={2025},
178
+ eprint={2506.02112},
179
+ archivePrefix={arXiv},
180
+ primaryClass={cs.CV},
181
+ url={https://arxiv.org/abs/2506.02112},
182
+ }
183
+ ```
184
+
185
+ ## Acknowledgments
186
+
187
+ This work builds upon [MASt3R](https://github.com/naver/mast3r), [DUSt3R](https://github.com/naver/dust3r) and [FeatUp](https://github.com/mhamilton723/FeatUp). We thank the original authors for their excellent work and open-source contributions.
188
+
189
+ ## License
190
+
191
+ The code is distributed under the CC BY-NC-SA 4.0 License.
192
+ See [LICENSE](LICENSE) for more information.
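
The "Checkpoint from HF Hub (default)" flow in the Demo section above reduces to a single `hf_hub_download` call, the same call `demo.py` makes. A minimal standalone sketch using the default repo id and filename quoted in the README:

```python
# Minimal sketch of the default checkpoint resolution: fetch the demo weights
# from the uva-cv-lab/SAB3R model repo and return the local cache path
# (stored under ~/.cache/huggingface/ by huggingface_hub).
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="uva-cv-lab/SAB3R",          # demo.py's --model_repo default
    filename="demo_ckpt/base/base.pt",   # demo.py's --ckpt_filename default
)
print(ckpt_path)  # can be passed back to demo.py via --weights
```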
app.py ADDED
@@ -0,0 +1,18 @@
1
+ #!/usr/bin/env python3
2
+ # --------------------------------------------------------
3
+ # Hugging Face Spaces entry point for the SAB3R demo.
4
+ #
5
+ # Spaces looks for `app.py` by default. This wrapper just sets
6
+ # Spaces-appropriate defaults and delegates to demo.py. For local
7
+ # dev, run demo.py directly (it exposes --share, --local_network, etc).
8
+ # --------------------------------------------------------
9
+ import os
10
+
11
+ from demo import main
12
+
13
+ if __name__ == "__main__":
14
+ main([
15
+ "--model_name", "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric",
16
+ "--server_name", "0.0.0.0",
17
+ "--server_port", os.environ.get("GRADIO_SERVER_PORT", "7860"),
18
+ ])
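
Because `demo.py` reads `SAB3R_MODEL_REPO` and `SAB3R_CKPT_FILENAME` when building its argument parser, the same wrapper logic can be exercised locally before deploying to a Space. A hedged sketch; the repo id and filename below are placeholders, not shipped defaults:

```python
# Hypothetical local smoke test of the Spaces entry point with overridden
# checkpoint defaults; demo.py picks these env vars up at parser-build time.
import os

os.environ["SAB3R_MODEL_REPO"] = "your-org/your-sab3r-ckpt"  # placeholder repo id
os.environ["SAB3R_CKPT_FILENAME"] = "model.pt"               # placeholder filename
os.environ.setdefault("GRADIO_SERVER_PORT", "7860")

from demo import main

main([
    "--model_name", "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric",
    "--server_name", "127.0.0.1",
    "--server_port", os.environ["GRADIO_SERVER_PORT"],
])
```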
assets/network_architecture.png ADDED

Git LFS Details

  • SHA256: cc03b7cf5029a5a635105e68ae53058daefcb9686e2006023cbaaa1e3991db2a
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
assets/qualitative_2.jpg ADDED

Git LFS Details

  • SHA256: f3a0fabcabff6320ed89f0388f1133e248c620b7d45ceb933b7fdd445bc6c450
  • Pointer size: 131 Bytes
  • Size of remote file: 158 kB
assets/teaser.jpg ADDED

Git LFS Details

  • SHA256: 50a04659d29edb419d46ddc836330e76bb8106f60eae79657972ad9cad439f55
  • Pointer size: 132 Bytes
  • Size of remote file: 2.08 MB
assets/teaser_v5.jpg ADDED

Git LFS Details

  • SHA256: c58c2ebb4a8e0b20bfb2e9a9a02f1d8660e650c2f667c89c3d508cda1034a521
  • Pointer size: 131 Bytes
  • Size of remote file: 369 kB
config/deepspeed.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "train_micro_batch_size_per_gpu": 1,
3
+ "gradient_accumulation_steps": 8,
4
+ "steps_per_print": 1,
5
+ "optimizer": {
6
+ "type": "AdamW",
7
+ "params": {
8
+ "lr": 1e-4,
9
+ "betas": [0.9, 0.95],
10
+ "weight_decay": 0.05
11
+ }
12
+ },
13
+ "scheduler": {
14
+ "type": "WarmupLR",
15
+ "params": {
16
+ "warmup_min_lr": 0.0,
17
+ "warmup_max_lr": 1e-4,
18
+ "warmup_num_steps": 4000
19
+ }
20
+ },
21
+ "gradient_clipping": 1.0,
22
+ "zero_optimization": {
23
+ "stage": 2,
24
+ "offload_optimizer": {
25
+ "device": "cpu",
26
+ "pin_memory": true
27
+ },
28
+ "overlap_comm": true,
29
+ "contiguous_gradients": true,
30
+ "reduce_bucket_size": 5e7
31
+ },
32
+ "activation_checkpointing": {
33
+ "partition_activations": true,
34
+ "contiguous_memory_optimization": true,
35
+ "cpu_checkpointing": true,
36
+ "number_checkpoints": 2
37
+ }
38
+ }
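
The training configs point at this file through their `deepspeed_config` key. `train.py` itself is outside this 50-file view, so the following is only an assumption about how such a config is typically wired in, not the repo's actual code:

```python
# Illustrative only: typical DeepSpeed wiring for a config like
# config/deepspeed.json (ZeRO-2 with CPU optimizer offload, micro-batch 1,
# gradient accumulation 8). The real train.py may differ.
import deepspeed
import torch

model = torch.nn.Linear(16, 16)  # stand-in for the SAB3R model

model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="config/deepspeed.json",
)
```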
config/training_config.yaml ADDED
@@ -0,0 +1,55 @@
1
+ seed: 8
2
+
3
+ # Override these paths via hydra (e.g. `python train.py dataset_url=/path/to/data`)
4
+ # or by editing this file.
5
+ dataset_url: "./data"
6
+ code_url: "."
7
+ output_url: "./outputs"
8
+
9
+ model:
10
+ name: "AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', clip_head_type='dpt', dino_head_type=None, depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True)"
11
+ pretrained: "checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
12
+
13
+
14
+ clip_checkpoint_weights_path: "checkpoints/maskclip_jbu_stack_cocostuff.ckpt"
15
+ dino_checkpoint_weights_path: "checkpoints/dinov2_jbu_stack_cocostuff.ckpt"
16
+
17
+ train_criterion: "ConfFeatLoss(Regr3D(L21, norm_mode='?avg_dis'), alpha=0.2, beta=0.4, gamma=0.4, need_clip=True, need_dino=False, clip_checkpoint_weights_path='${clip_checkpoint_weights_path}', dino_checkpoint_weights_path='${dino_checkpoint_weights_path}') + 0.075*ConfMatchingLoss(MatchingLoss(InfoNCE(mode='proper', temperature=0.05), negatives_padding=0, blocksize=8192), alpha=10.0, confmode='mean')"
18
+ test_criterion: "FeatRegr3D_ScaleShiftInv(L21, norm_mode='?avg_dis', need_clip=True, need_dino=False, gt_scale=True, sky_loss_value=0, clip_checkpoint_weights_path='${clip_checkpoint_weights_path}', dino_checkpoint_weights_path='${dino_checkpoint_weights_path}') + -1.*MatchingLoss(APLoss(nq='torch', fp=torch.float16), negatives_padding=12288)"
19
+
20
+
21
+ dataset:
22
+ train: "1000 @ Co3d(split='train', ROOT='${dataset_url}/co3d_subset_processed', aug_crop='auto', aug_monocular=0.005, aug_rot90='diff', mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], n_corres=8192, nneg=0.5, transform=ColorJitter)"
23
+ test: "100 @ Co3d(split='test', ROOT='${dataset_url}/co3d_subset_processed', resolution=(512,384), n_corres=1024, seed=777)"
24
+
25
+ training:
26
+ epochs: 200
27
+ disable_cudnn_benchmark: true
28
+ print_freq: 1
29
+
30
+ saving:
31
+ save_freq: 10
32
+ keep_freq: 40
33
+ eval_freq: 10
34
+
35
+ output_dir: "${output_url}/sab3r"
36
+
37
+ wandb:
38
+ project_name: "sab3r"
39
+ entity: "uva-computer-vision-lab"
40
+ group: "sab3r"
41
+
42
+ num_workers: 8
43
+ dist_url: "env://"
44
+ distributed: True
45
+
46
+ world_size: -1
47
+ gpu: -1
48
+ rank: -1
49
+ dist_backend: "nccl"
50
+
51
+ resume: ""
52
+ start_epoch: 0
53
+ disable_cudnn_benchmark: True
54
+ amp: 1
55
+ deepspeed_config: "${code_url}/config/deepspeed.json"
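
The README notes that these values are meant to be overridden via Hydra (e.g. `dataset_url=/path/to/data`). A sketch of a Hydra entry point that would compose this file and resolve its `${dataset_url}` / `${output_url}` interpolations; the decorator arguments are assumptions, since `train.py` is not part of this view:

```python
# Illustrative Hydra entry point for config/training_config.yaml.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="config", config_name="training_config")
def main(cfg: DictConfig) -> None:
    # Command-line overrides (e.g. dataset_url=/path/to/data) are merged in
    # before the ${dataset_url} / ${output_url} interpolations are resolved.
    print(cfg.dataset_url)
    print(OmegaConf.to_yaml(cfg, resolve=True))


if __name__ == "__main__":
    main()
```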
config/training_config_full.yaml ADDED
@@ -0,0 +1,57 @@
1
+ seed: 8
2
+
3
+ # Full SAB3R training recipe (CLIP + DINO distillation on Habitat + ScanNet++ + ARKitScenes + Co3D).
4
+ # Override these paths via hydra (e.g. `python train.py --config-name training_config_full dataset_url=/path/to/data`)
5
+ # or by editing this file.
6
+ dataset_url: "./data"
7
+ code_url: "."
8
+ output_url: "./outputs"
9
+
10
+ model:
11
+ name: "AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', clip_head_type='dpt', dino_head_type='dpt', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True)"
12
+ pretrained: "checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
13
+
14
+
15
+ clip_checkpoint_weights_path: "checkpoints/maskclip_jbu_stack_cocostuff.ckpt"
16
+ dino_checkpoint_weights_path: "checkpoints/dinov2_jbu_stack_cocostuff.ckpt"
17
+
18
+ train_criterion: "ConfFeatLoss(Regr3D(L21, norm_mode='?avg_dis'), alpha=1, beta=20, gamma=4, need_clip=True, need_dino=True, need_mask=False, clip_checkpoint_weights_path='${clip_checkpoint_weights_path}', dino_checkpoint_weights_path='${dino_checkpoint_weights_path}') + 0.75*ConfMatchingLoss(MatchingLoss(InfoNCE(mode='proper', temperature=0.05), negatives_padding=0, blocksize=8192), alpha=10.0, confmode='mean')"
19
+ test_criterion: "FeatRegr3D_ScaleShiftInv(L21, norm_mode='?avg_dis', need_clip=True, need_dino=True, gt_scale=True, sky_loss_value=0, clip_checkpoint_weights_path='${clip_checkpoint_weights_path}', dino_checkpoint_weights_path='${dino_checkpoint_weights_path}') + -1.*MatchingLoss(APLoss(nq='torch', fp=torch.float16), negatives_padding=12288)"
20
+
21
+ dataset:
22
+ train: "57_000 @ Habitat512(1_000_000, split='train', ROOT='${dataset_url}/habitat_processed', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], aug_crop='auto', aug_monocular=0.005, transform=ColorJitter, n_corres=8192, nneg=0.5) + 45_600 @ ScanNetpp(split='train', ROOT='${dataset_url}/scannetpp_processed', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], aug_crop='auto', aug_monocular=0.005, transform=ColorJitter, n_corres=8192, nneg=0.5) + 45_600 @ ARKitScenes(split='train', ROOT='${dataset_url}/arkitscenes_processed', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], aug_crop='auto', aug_monocular=0.005, transform=ColorJitter, n_corres=8192, nneg=0.5) + 22_800 @ Co3d(split='train', ROOT='${dataset_url}/co3d_50_seqs_per_category_subset_processed', aug_crop='auto', aug_monocular=0.005, aug_rot90='diff', mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], n_corres=8192, nneg=0.5, transform=ColorJitter)"
23
+ test: "100 @ Co3d(split='test', ROOT='${dataset_url}/co3d_50_seqs_per_category_subset_processed', resolution=(512,384), n_corres=1024, seed=777)"
24
+
25
+ training:
26
+ epochs: 5
27
+ disable_cudnn_benchmark: true
28
+ print_freq: 1
29
+
30
+ saving:
31
+ save_freq: 1
32
+ keep_freq: 1
33
+ eval_freq: 1
34
+ save_steps: 5000
35
+ plot_steps: 200
36
+
37
+ output_dir: "${output_url}/sab3r_full"
38
+
39
+ wandb:
40
+ project_name: "sab3r"
41
+ entity: "uva-computer-vision-lab"
42
+ group: "sab3r_full"
43
+
44
+ num_workers: 8
45
+ dist_url: "env://"
46
+ distributed: True
47
+
48
+ world_size: -1
49
+ gpu: -1
50
+ rank: -1
51
+ dist_backend: "nccl"
52
+
53
+ resume: ""
54
+ start_epoch: 0
55
+ disable_cudnn_benchmark: True
56
+ amp: 0
57
+ deepspeed_config: "${code_url}/config/deepspeed.json"
demo.py ADDED
@@ -0,0 +1,118 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ #
5
+ # --------------------------------------------------------
6
+ # SAB3R gradio demo executable
7
+ # --------------------------------------------------------
8
+ import os
9
+ import argparse
10
+ import tempfile
11
+ from contextlib import nullcontext
12
+
13
+ import torch
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ from mast3r.demo import get_args_parser as sab3r_get_args_parser, main_demo
17
+ from mast3r.model import AsymmetricMASt3R # noqa: F401 (referenced via eval() below)
18
+ from mast3r.utils.misc import hash_md5
19
+
20
+ import mast3r.utils.path_to_dust3r # noqa: F401 (side-effect: puts vendored dust3r on sys.path)
21
+ from dust3r.demo import set_print_with_timestamp
22
+
23
+ import matplotlib.pyplot as pl
24
+ pl.ion()
25
+
26
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
27
+
28
+ inf = float("inf")
29
+
30
+ DEFAULT_MODEL_REPO = "uva-cv-lab/SAB3R"
31
+ DEFAULT_CKPT_FILENAME = "demo_ckpt/base/base.pt"
32
+
33
+
34
+ def get_args_parser():
35
+ parser = sab3r_get_args_parser()
36
+ parser.add_argument(
37
+ "--model_repo",
38
+ default=os.environ.get("SAB3R_MODEL_REPO", DEFAULT_MODEL_REPO),
39
+ help="Hugging Face Hub repo id hosting the SAB3R checkpoint "
40
+ "(used only when --weights is not provided).",
41
+ )
42
+ parser.add_argument(
43
+ "--ckpt_filename",
44
+ default=os.environ.get("SAB3R_CKPT_FILENAME", DEFAULT_CKPT_FILENAME),
45
+ help="Checkpoint filename inside --model_repo.",
46
+ )
47
+ parser.add_argument(
48
+ "--checkpoint_dir",
49
+ default=os.environ.get("SAB3R_CHECKPOINT_DIR", None),
50
+ help="Optional local directory containing one sub-directory per "
51
+ "checkpoint (each sub-dir must hold `<name>.pt`). When provided, "
52
+ "the UI exposes a dropdown to switch between them. Useful for "
53
+ "local dev; leave unset for single-checkpoint HF Spaces deployments.",
54
+ )
55
+ return parser
56
+
57
+
58
+ def load_weights(model, ckp_path, device):
59
+ ckp = torch.load(ckp_path, map_location='cpu')
60
+ if ckp_path.endswith('.pth'):
61
+ model.load_state_dict(ckp['model'], strict=False)
62
+ elif ckp_path.endswith('.pt'):
63
+ model.load_state_dict(ckp['module'])
64
+ else:
65
+ raise ValueError(f"Unknown checkpoint format: {ckp_path}")
66
+ model.to(device)
67
+
68
+
69
+ def build_model_config():
70
+ return (
71
+ "AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R', "
72
+ "img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', "
73
+ "clip_head_type='dpt', dino_head_type='dpt', "
74
+ "depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), "
75
+ "enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, "
76
+ "dec_embed_dim=768, dec_depth=12, dec_num_heads=12, "
77
+ "two_confs=True, landscape_only=False)"
78
+ )
79
+
80
+
81
+ def resolve_weights_path(args: argparse.Namespace) -> str:
82
+ if args.weights:
83
+ return args.weights
84
+ print(f"[sab3r] Downloading checkpoint from HF Hub: {args.model_repo}/{args.ckpt_filename}")
85
+ return hf_hub_download(repo_id=args.model_repo, filename=args.ckpt_filename)
86
+
87
+
88
+ def main(argv=None):
89
+ parser = get_args_parser()
90
+ args = parser.parse_args(argv)
91
+ set_print_with_timestamp()
92
+
93
+ if args.server_name is not None:
94
+ server_name = args.server_name
95
+ else:
96
+ server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
97
+
98
+ model = eval(build_model_config())
99
+ ckp_path = resolve_weights_path(args)
100
+ load_weights(model, ckp_path, args.device)
101
+ chkpt_tag = hash_md5(ckp_path)
102
+
103
+ def get_context(tmp_dir):
104
+ return (tempfile.TemporaryDirectory(suffix='_sab3r_gradio_demo') if tmp_dir is None
105
+ else nullcontext(tmp_dir))
106
+
107
+ with get_context(args.tmp_dir) as tmpdirname:
108
+ cache_path = os.path.join(tmpdirname, chkpt_tag)
109
+ os.makedirs(cache_path, exist_ok=True)
110
+ main_demo(
111
+ cache_path, model, args.device, args.image_size, server_name, args.server_port,
112
+ silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache,
113
+ checkpoint_dir=args.checkpoint_dir,
114
+ )
115
+
116
+
117
+ if __name__ == '__main__':
118
+ main()
docker/docker-compose-cpu.yml ADDED
@@ -0,0 +1,16 @@
1
+ version: '3.8'
2
+ services:
3
+ mast3r-demo:
4
+ build:
5
+ context: ./files
6
+ dockerfile: cpu.Dockerfile
7
+ ports:
8
+ - "7860:7860"
9
+ volumes:
10
+ - ./files/checkpoints:/mast3r/checkpoints
11
+ environment:
12
+ - DEVICE=cpu
13
+ - MODEL=${MODEL:-MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth}
14
+ cap_add:
15
+ - IPC_LOCK
16
+ - SYS_RESOURCE
docker/docker-compose-cuda.yml ADDED
@@ -0,0 +1,23 @@
1
+ version: '3.8'
2
+ services:
3
+ mast3r-demo:
4
+ build:
5
+ context: ./files
6
+ dockerfile: cuda.Dockerfile
7
+ ports:
8
+ - "7860:7860"
9
+ environment:
10
+ - DEVICE=cuda
11
+ - MODEL=${MODEL:-MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth}
12
+ volumes:
13
+ - ./files/checkpoints:/mast3r/checkpoints
14
+ cap_add:
15
+ - IPC_LOCK
16
+ - SYS_RESOURCE
17
+ deploy:
18
+ resources:
19
+ reservations:
20
+ devices:
21
+ - driver: nvidia
22
+ count: 1
23
+ capabilities: [gpu]
docker/files/cpu.Dockerfile ADDED
@@ -0,0 +1,39 @@
1
+ FROM python:3.11-slim
2
+
3
+ LABEL description="Docker container for MASt3R with dependencies installed. CPU VERSION"
4
+
5
+ ENV DEVICE="cpu"
6
+ ENV MODEL="MASt3R_ViTLarge_BaseDecoder_512_dpt.pth"
7
+ ARG DEBIAN_FRONTEND=noninteractive
8
+
9
+ RUN apt-get update && apt-get install -y \
10
+ git \
11
+ libgl1-mesa-glx \
12
+ libegl1-mesa \
13
+ libxrandr2 \
14
+ libxrandr2 \
15
+ libxss1 \
16
+ libxcursor1 \
17
+ libxcomposite1 \
18
+ libasound2 \
19
+ libxi6 \
20
+ libxtst6 \
21
+ libglib2.0-0 \
22
+ && apt-get clean \
23
+ && rm -rf /var/lib/apt/lists/*
24
+
25
+ RUN git clone --recursive https://github.com/naver/mast3r /mast3r
26
+ WORKDIR /mast3r/dust3r
27
+
28
+ RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
29
+ RUN pip install -r requirements.txt
30
+ RUN pip install -r requirements_optional.txt
31
+ RUN pip install opencv-python==4.8.0.74
32
+
33
+ WORKDIR /mast3r
34
+ RUN pip install -r requirements.txt
35
+
36
+ COPY entrypoint.sh /entrypoint.sh
37
+ RUN chmod +x /entrypoint.sh
38
+
39
+ ENTRYPOINT ["/entrypoint.sh"]
docker/files/cuda.Dockerfile ADDED
@@ -0,0 +1,29 @@
1
+ FROM nvcr.io/nvidia/pytorch:24.01-py3
2
+
3
+ LABEL description="Docker container for MASt3R with dependencies installed. CUDA VERSION"
4
+ ENV DEVICE="cuda"
5
+ ENV MODEL="MASt3R_ViTLarge_BaseDecoder_512_dpt.pth"
6
+ ARG DEBIAN_FRONTEND=noninteractive
7
+
8
+ RUN apt-get update && apt-get install -y \
9
+ git=1:2.34.1-1ubuntu1.10 \
10
+ libglib2.0-0=2.72.4-0ubuntu2.2 \
11
+ && apt-get clean \
12
+ && rm -rf /var/lib/apt/lists/*
13
+
14
+ RUN git clone --recursive https://github.com/naver/mast3r /mast3r
15
+ WORKDIR /mast3r/dust3r
16
+ RUN pip install -r requirements.txt
17
+ RUN pip install -r requirements_optional.txt
18
+ RUN pip install opencv-python==4.8.0.74
19
+
20
+ WORKDIR /mast3r/dust3r/croco/models/curope/
21
+ RUN python setup.py build_ext --inplace
22
+
23
+ WORKDIR /mast3r
24
+ RUN pip install -r requirements.txt
25
+
26
+ COPY entrypoint.sh /entrypoint.sh
27
+ RUN chmod +x /entrypoint.sh
28
+
29
+ ENTRYPOINT ["/entrypoint.sh"]
docker/files/entrypoint.sh ADDED
@@ -0,0 +1,8 @@
1
+ #!/bin/bash
2
+
3
+ set -eux
4
+
5
+ DEVICE=${DEVICE:-cuda}
6
+ MODEL=${MODEL:-MASt3R_ViTLarge_BaseDecoder_512_dpt.pth}
7
+
8
+ exec python3 demo.py --weights "checkpoints/$MODEL" --device "$DEVICE" --local_network "$@"
docker/run.sh ADDED
@@ -0,0 +1,68 @@
1
+ #!/bin/bash
2
+
3
+ set -eux
4
+
5
+ # Default model name
6
+ model_name="MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
7
+
8
+ check_docker() {
9
+ if ! command -v docker &>/dev/null; then
10
+ echo "Docker could not be found. Please install Docker and try again."
11
+ exit 1
12
+ fi
13
+ }
14
+
15
+ download_model_checkpoint() {
16
+ if [ -f "./files/checkpoints/${model_name}" ]; then
17
+ echo "Model checkpoint ${model_name} already exists. Skipping download."
18
+ return
19
+ fi
20
+ echo "Downloading model checkpoint ${model_name}..."
21
+ wget "https://download.europe.naverlabs.com/ComputerVision/MASt3R/${model_name}" -P ./files/checkpoints
22
+ }
23
+
24
+ set_dcomp() {
25
+ if command -v docker-compose &>/dev/null; then
26
+ dcomp="docker-compose"
27
+ elif command -v docker &>/dev/null && docker compose version &>/dev/null; then
28
+ dcomp="docker compose"
29
+ else
30
+ echo "Docker Compose could not be found. Please install Docker Compose and try again."
31
+ exit 1
32
+ fi
33
+ }
34
+
35
+ run_docker() {
36
+ export MODEL=${model_name}
37
+ if [ "$with_cuda" -eq 1 ]; then
38
+ $dcomp -f docker-compose-cuda.yml up --build
39
+ else
40
+ $dcomp -f docker-compose-cpu.yml up --build
41
+ fi
42
+ }
43
+
44
+ with_cuda=0
45
+ for arg in "$@"; do
46
+ case $arg in
47
+ --with-cuda)
48
+ with_cuda=1
49
+ ;;
50
+ --model_name=*)
51
+ model_name="${arg#*=}.pth"
52
+ ;;
53
+ *)
54
+ echo "Unknown parameter passed: $arg"
55
+ exit 1
56
+ ;;
57
+ esac
58
+ done
59
+
60
+
61
+ main() {
62
+ check_docker
63
+ download_model_checkpoint
64
+ set_dcomp
65
+ run_docker
66
+ }
67
+
68
+ main
dust3r/.gitignore ADDED
@@ -0,0 +1,132 @@
1
+ data/
2
+ checkpoints/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ pip-wheel-metadata/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
dust3r/LICENSE ADDED
@@ -0,0 +1,7 @@
1
+ DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
dust3r/NOTICE ADDED
@@ -0,0 +1,12 @@
1
+ DUSt3R
2
+ Copyright 2024-present NAVER Corp.
3
+
4
+ This project contains subcomponents with separate copyright notices and license terms.
5
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6
+
7
+ ====
8
+
9
+ naver/croco
10
+ https://github.com/naver/croco/
11
+
12
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0
dust3r/README.md ADDED
@@ -0,0 +1,390 @@
1
+ ![demo](assets/dust3r.jpg)
2
+
3
+ Official implementation of `DUSt3R: Geometric 3D Vision Made Easy`
4
+ [[Project page](https://dust3r.europe.naverlabs.com/)], [[DUSt3R arxiv](https://arxiv.org/abs/2312.14132)]
5
+
6
+ > **Make sure to also check [MASt3R](https://github.com/naver/mast3r): Our new model with a local feature head, metric pointmaps, and a more scalable global alignment!**
7
+
8
+ ![Example of reconstruction from two images](assets/pipeline1.jpg)
9
+
10
+ ![High level overview of DUSt3R capabilities](assets/dust3r_archi.jpg)
11
+
12
+ ```bibtex
13
+ @inproceedings{dust3r_cvpr24,
14
+ title={DUSt3R: Geometric 3D Vision Made Easy},
15
+ author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
16
+ booktitle = {CVPR},
17
+ year = {2024}
18
+ }
19
+
20
+ @misc{dust3r_arxiv23,
21
+ title={DUSt3R: Geometric 3D Vision Made Easy},
22
+ author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
23
+ year={2023},
24
+ eprint={2312.14132},
25
+ archivePrefix={arXiv},
26
+ primaryClass={cs.CV}
27
+ }
28
+ ```
29
+
30
+ ## Table of Contents
31
+
32
+ - [Table of Contents](#table-of-contents)
33
+ - [License](#license)
34
+ - [Get Started](#get-started)
35
+ - [Installation](#installation)
36
+ - [Checkpoints](#checkpoints)
37
+ - [Interactive demo](#interactive-demo)
38
+ - [Interactive demo with docker](#interactive-demo-with-docker)
39
+ - [Usage](#usage)
40
+ - [Training](#training)
41
+ - [Datasets](#datasets)
42
+ - [Demo](#demo)
43
+ - [Our Hyperparameters](#our-hyperparameters)
44
+
45
+ ## License
46
+
47
+ The code is distributed under the CC BY-NC-SA 4.0 License.
48
+ See [LICENSE](LICENSE) for more information.
49
+
50
+ ```python
51
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
52
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
53
+ ```
54
+
55
+ ## Get Started
56
+
57
+ ### Installation
58
+
59
+ 1. Clone DUSt3R.
60
+ ```bash
61
+ git clone --recursive https://github.com/naver/dust3r
62
+ cd dust3r
63
+ # if you have already cloned dust3r:
64
+ # git submodule update --init --recursive
65
+ ```
66
+
67
+ 2. Create the environment, here we show an example using conda.
68
+ ```bash
69
+ conda create -n dust3r python=3.11 cmake=3.14.0
70
+ conda activate dust3r
71
+ conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # use the correct version of cuda for your system
72
+ pip install -r requirements.txt
73
+ # Optional: you can also install additional packages to:
74
+ # - add support for HEIC images
75
+ # - add pyrender, used to render depthmap in some datasets preprocessing
76
+ # - add required packages for visloc.py
77
+ pip install -r requirements_optional.txt
78
+ ```
79
+
80
+ 3. Optional, compile the cuda kernels for RoPE (as in CroCo v2).
81
+ ```bash
82
+ # DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime.
83
+ cd croco/models/curope/
84
+ python setup.py build_ext --inplace
85
+ cd ../../../
86
+ ```
87
+
88
+ ### Checkpoints
89
+
90
+ You can obtain the checkpoints in two ways:
91
+
92
+ 1) You can use our huggingface_hub integration: the models will be downloaded automatically.
93
+
94
+ 2) Otherwise, we provide several pre-trained models:
95
+
96
+ | Modelname | Training resolutions | Head | Encoder | Decoder |
97
+ |-------------|----------------------|------|---------|---------|
98
+ | [`DUSt3R_ViTLarge_BaseDecoder_224_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth) | 224x224 | Linear | ViT-L | ViT-B |
99
+ | [`DUSt3R_ViTLarge_BaseDecoder_512_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | Linear | ViT-L | ViT-B |
100
+ | [`DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | DPT | ViT-L | ViT-B |
101
+
102
+ You can check the hyperparameters we used to train these models in the [section: Our Hyperparameters](#our-hyperparameters)
103
+
104
+ To download a specific model, for example `DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`:
105
+ ```bash
106
+ mkdir -p checkpoints/
107
+ wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/
108
+ ```
109
+
110
+ For the checkpoints, make sure to agree to the license of all the public training datasets and base checkpoints we used, in addition to CC-BY-NC-SA 4.0. Again, see [section: Our Hyperparameters](#our-hyperparameters) for details.
111
+
112
+ ### Interactive demo
113
+
114
+ In this demo, you should be able to run DUSt3R on your machine to reconstruct a scene.
115
+ First, select images that depict the same scene.
116
+
117
+ You can adjust the global alignment schedule and its number of iterations.
118
+
119
+ > [!NOTE]
120
+ > If you selected one or two images, the global alignment procedure will be skipped (mode=GlobalAlignerMode.PairViewer)
121
+
122
+ Hit "Run" and wait.
123
+ When the global alignment ends, the reconstruction appears.
124
+ Use the slider "min_conf_thr" to show or remove low confidence areas.
125
+
126
+ ```bash
127
+ python3 demo.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt
128
+
129
+ # Use --weights to load a checkpoint from a local file, eg --weights checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
130
+ # Use --image_size to select the correct resolution for the selected checkpoint. 512 (default) or 224
131
+ # Use --local_network to make it accessible on the local network, or --server_name to specify the url manually
132
+ # Use --server_port to change the port, by default it will search for an available port starting at 7860
133
+ # Use --device to use a different device, by default it's "cuda"
134
+ ```
135
+
136
+ ### Interactive demo with docker
137
+
138
+ To run DUSt3R using Docker, including with NVIDIA CUDA support, follow these instructions:
139
+
140
+ 1. **Install Docker**: If not already installed, download and install `docker` and `docker compose` from the [Docker website](https://www.docker.com/get-started).
141
+
142
+ 2. **Install NVIDIA Docker Toolkit**: For GPU support, install the NVIDIA Docker toolkit from the [Nvidia website](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
143
+
144
+ 3. **Build the Docker image and run it**: `cd` into the `./docker` directory and run the following commands:
145
+
146
+ ```bash
147
+ cd docker
148
+ bash run.sh --with-cuda --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt"
149
+ ```
150
+
151
+ Or if you want to run the demo without CUDA support, run the following command:
152
+
153
+ ```bash
154
+ cd docker
155
+ bash run.sh --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt"
156
+ ```
157
+
158
+ By default, `demo.py` is launched with the option `--local_network`.
159
+ Visit `http://localhost:7860/` to access the web UI (or replace `localhost` with the machine's name to access it from the network).
160
+
161
+ `run.sh` will launch docker-compose using either the [docker-compose-cuda.yml](docker/docker-compose-cuda.yml) or [docker-compose-cpu.yml](docker/docker-compose-cpu.yml) config file, then it starts the demo using [entrypoint.sh](docker/files/entrypoint.sh).
162
+
163
+
164
+ ![demo](assets/demo.jpg)
165
+
166
+ ## Usage
167
+
168
+ ```python
169
+ from dust3r.inference import inference
170
+ from dust3r.model import AsymmetricCroCo3DStereo
171
+ from dust3r.utils.image import load_images
172
+ from dust3r.image_pairs import make_pairs
173
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
174
+
175
+ if __name__ == '__main__':
176
+ device = 'cuda'
177
+ batch_size = 1
178
+ schedule = 'cosine'
179
+ lr = 0.01
180
+ niter = 300
181
+
182
+ model_name = "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
183
+ # you can put the path to a local checkpoint in model_name if needed
184
+ model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)
185
+ # load_images can take a list of images or a directory
186
+ images = load_images(['croco/assets/Chateau1.png', 'croco/assets/Chateau2.png'], size=512)
187
+ pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
188
+ output = inference(pairs, model, device, batch_size=batch_size)
189
+
190
+ # at this stage, you have the raw dust3r predictions
191
+ view1, pred1 = output['view1'], output['pred1']
192
+ view2, pred2 = output['view2'], output['pred2']
193
+ # here, view1, pred1, view2, pred2 are dicts of lists of len(2)
194
+ # -> because we symmetrize we have (im1, im2) and (im2, im1) pairs
195
+ # in each view you have:
196
+ # an integer image identifier: view1['idx'] and view2['idx']
197
+ # the img: view1['img'] and view2['img']
198
+ # the image shape: view1['true_shape'] and view2['true_shape']
199
+ # an instance string output by the dataloader: view1['instance'] and view2['instance']
200
+ # pred1 and pred2 contains the confidence values: pred1['conf'] and pred2['conf']
201
+ # pred1 contains 3D points for view1['img'] in view1['img'] space: pred1['pts3d']
202
+ # pred2 contains 3D points for view2['img'] in view1['img'] space: pred2['pts3d_in_other_view']
203
+
204
+ # next we'll use the global_aligner to align the predictions
205
+ # depending on your task, you may be fine with the raw output and not need it
206
+ # with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output
207
+ # if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment
208
+ scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
209
+ loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
210
+
211
+ # retrieve useful values from scene:
212
+ imgs = scene.imgs
213
+ focals = scene.get_focals()
214
+ poses = scene.get_im_poses()
215
+ pts3d = scene.get_pts3d()
216
+ confidence_masks = scene.get_masks()
217
+
218
+ # visualize reconstruction
219
+ scene.show()
220
+
221
+ # find 2D-2D matches between the two images
222
+ from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
223
+ pts2d_list, pts3d_list = [], []
224
+ for i in range(2):
225
+ conf_i = confidence_masks[i].cpu().numpy()
226
+ pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i]) # imgs[i].shape[:2] = (H, W)
227
+ pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
228
+ reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
229
+ print(f'found {num_matches} matches')
230
+ matches_im1 = pts2d_list[1][reciprocal_in_P2]
231
+ matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
232
+
233
+ # visualize a few matches
234
+ import numpy as np
235
+ from matplotlib import pyplot as pl
236
+ n_viz = 10
237
+ match_idx_to_viz = np.round(np.linspace(0, num_matches-1, n_viz)).astype(int)
238
+ viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
239
+
240
+ H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
241
+ img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
242
+ img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
243
+ img = np.concatenate((img0, img1), axis=1)
244
+ pl.figure()
245
+ pl.imshow(img)
246
+ cmap = pl.get_cmap('jet')
247
+ for i in range(n_viz):
248
+ (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
249
+ pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
250
+ pl.show(block=True)
251
+
252
+ ```
253
+ ![matching example on croco pair](assets/matching.jpg)
254
+
255
+ ## Training
256
+
257
+ In this section, we present a short demonstration to get started with training DUSt3R.
258
+
259
+ ### Datasets
260
+ At this moment, we have added the following training datasets:
261
+ - [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE)
262
+ - [ARKitScenes](https://github.com/apple/ARKitScenes) - [Creative Commons Attribution-NonCommercial-ShareAlike 4.0](https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license)
263
+ - [ScanNet++](https://kaldir.vc.in.tum.de/scannetpp/) - [non-commercial research and educational purposes](https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf)
264
+ - [BlendedMVS](https://github.com/YoYo000/BlendedMVS) - [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/)
265
+ - [WayMo Open dataset](https://github.com/waymo-research/waymo-open-dataset) - [Non-Commercial Use](https://waymo.com/open/terms/)
266
+ - [Habitat-Sim](https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md)
267
+ - [MegaDepth](https://www.cs.cornell.edu/projects/megadepth/)
268
+ - [StaticThings3D](https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d)
269
+ - [WildRGB-D](https://github.com/wildrgbd/wildrgbd/)
270
+
271
+ For each dataset, we provide a preprocessing script in the `datasets_preprocess` directory and an archive containing the list of pairs when needed.
272
+ You have to download the datasets yourself from their official sources, agree to their license, download our list of pairs, and run the preprocessing script.
273
+
274
+ Links:
275
+
276
+ [ARKitScenes pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/arkitscenes_pairs.zip)
277
+ [ScanNet++ pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/scannetpp_pairs.zip)
278
+ [BlendedMVS pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/blendedmvs_pairs.npy)
279
+ [WayMo Open dataset pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/waymo_pairs.npz)
280
+ [Habitat metadata](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz)
281
+ [MegaDepth pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/megadepth_pairs.npz)
282
+ [StaticThings3D pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/staticthings_pairs.npy)
283
+
284
+ > [!NOTE]
285
+ > They are not strictly equivalent to what was used to train DUSt3R, but they should be close enough.
286
+
287
+ ### Demo
288
+ For this training demo, we're going to download and prepare a subset of [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE) and launch the training code on it.
289
+ The demo model will be trained for a few epochs on a very small dataset.
290
+ It will not be very good.
291
+
292
+ ```bash
293
+ # download and prepare the co3d subset
294
+ mkdir -p data/co3d_subset
295
+ cd data/co3d_subset
296
+ git clone https://github.com/facebookresearch/co3d
297
+ cd co3d
298
+ python3 ./co3d/download_dataset.py --download_folder ../ --single_sequence_subset
299
+ rm ../*.zip
300
+ cd ../../..
301
+
302
+ python3 datasets_preprocess/preprocess_co3d.py --co3d_dir data/co3d_subset --output_dir data/co3d_subset_processed --single_sequence_subset
303
+
304
+ # download the pretrained croco v2 checkpoint
305
+ mkdir -p checkpoints/
306
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth -P checkpoints/
307
+
308
+ # the training of dust3r is done in 3 steps.
309
+ # for this example we'll do fewer epochs, for the actual hyperparameters we used in the paper, see the next section: "Our Hyperparameters"
310
+ # step 1 - train dust3r for 224 resolution
311
+ torchrun --nproc_per_node=4 train.py \
312
+ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)" \
313
+ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=224, seed=777)" \
314
+ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
315
+ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
316
+ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
317
+ --pretrained "checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
318
+ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 16 --accum_iter 1 \
319
+ --save_freq 1 --keep_freq 5 --eval_freq 1 \
320
+ --output_dir "checkpoints/dust3r_demo_224"
321
+
322
+ # step 2 - train dust3r for 512 resolution
323
+ torchrun --nproc_per_node=4 train.py \
324
+ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
325
+ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
326
+ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
327
+ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
328
+ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
329
+ --pretrained "checkpoints/dust3r_demo_224/checkpoint-best.pth" \
330
+ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 4 --accum_iter 4 \
331
+ --save_freq 1 --keep_freq 5 --eval_freq 1 \
332
+ --output_dir "checkpoints/dust3r_demo_512"
333
+
334
+ # step 3 - train dust3r for 512 resolution with dpt
335
+ torchrun --nproc_per_node=4 train.py \
336
+ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
337
+ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
338
+ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
339
+ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
340
+ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
341
+ --pretrained "checkpoints/dust3r_demo_512/checkpoint-best.pth" \
342
+ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 2 --accum_iter 8 \
343
+ --save_freq 1 --keep_freq 5 --eval_freq 1 --disable_cudnn_benchmark \
344
+ --output_dir "checkpoints/dust3r_demo_512dpt"
345
+
346
+ ```
347
+
348
+ ### Our Hyperparameters
349
+
350
+ Here are the commands we used for training the models:
351
+
352
+ ```bash
353
+ # NOTE: ROOT path omitted for datasets
354
+ # 224 linear
355
+ torchrun --nproc_per_node 8 train.py \
356
+ --train_dataset=" + 100_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ BlendedMVS(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ MegaDepth(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ ARKitScenes(aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ ScanNetpp(split='train', aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=224, transform=ColorJitter) " \
357
+ --test_dataset=" Habitat(1_000, split='val', resolution=224, seed=777) + 1_000 @ BlendedMVS(split='val', resolution=224, seed=777) + 1_000 @ MegaDepth(split='val', resolution=224, seed=777) + 1_000 @ Co3d(split='test', mask_bg='rand', resolution=224, seed=777) " \
358
+ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
359
+ --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
360
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
361
+ --pretrained="checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
362
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=10 --epochs=100 --batch_size=16 --accum_iter=1 \
363
+ --save_freq=5 --keep_freq=10 --eval_freq=1 \
364
+ --output_dir="checkpoints/dust3r_224"
365
+
366
+ # 512 linear
367
+ torchrun --nproc_per_node 8 train.py \
368
+ --train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
369
+ --test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \
370
+ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
371
+ --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
372
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
373
+ --pretrained="checkpoints/dust3r_224/checkpoint-best.pth" \
374
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=20 --epochs=100 --batch_size=4 --accum_iter=2 \
375
+ --save_freq=10 --keep_freq=10 --eval_freq=1 --print_freq=10 \
376
+ --output_dir="checkpoints/dust3r_512"
377
+
378
+ # 512 dpt
379
+ torchrun --nproc_per_node 8 train.py \
380
+ --train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
381
+ --test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \
382
+ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
383
+ --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
384
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
385
+ --pretrained="checkpoints/dust3r_512/checkpoint-best.pth" \
386
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=15 --epochs=90 --batch_size=4 --accum_iter=2 \
387
+ --save_freq=5 --keep_freq=10 --eval_freq=1 --print_freq=10 --disable_cudnn_benchmark \
388
+ --output_dir="checkpoints/dust3r_512dpt"
389
+
390
+ ```
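+
+ In the commands above, `N @ Dataset(...)` asks the data loader to draw roughly N pairs from that dataset every epoch, and `+` concatenates datasets. As a rough, self-contained illustration of that resampling idea only (this is not the actual dust3r data-loading code), such a wrapper could look like:
+
+ ```python
+ # Toy sketch of "N @ Dataset": resample a fixed number of items from a base dataset every epoch.
+ import random
+ from torch.utils.data import Dataset
+
+ class Resampled(Dataset):
+     def __init__(self, base, num_samples, seed=777):
+         self.base, self.num_samples, self.seed = base, num_samples, seed
+         self.set_epoch(0)
+
+     def set_epoch(self, epoch):
+         # Draw a fresh set of indices at the start of each epoch (deterministic given the seed).
+         rng = random.Random(self.seed + epoch)
+         self.indices = [rng.randrange(len(self.base)) for _ in range(self.num_samples)]
+
+     def __len__(self):
+         return self.num_samples
+
+     def __getitem__(self, i):
+         return self.base[self.indices[i]]
+ ```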
dust3r/assets/demo.jpg ADDED

Git LFS Details

  • SHA256: 957a892f9033fb3e733546a202e3c07e362618c708eacf050979d4c4edd5435f
  • Pointer size: 131 Bytes
  • Size of remote file: 340 kB
dust3r/assets/dust3r.jpg ADDED
dust3r/assets/dust3r_archi.jpg ADDED
dust3r/assets/matching.jpg ADDED

Git LFS Details

  • SHA256: ecfe07fd00505045a155902c5686cc23060782a8b020f7596829fb60584a79ee
  • Pointer size: 131 Bytes
  • Size of remote file: 159 kB
dust3r/assets/pipeline1.jpg ADDED
dust3r/croco/LICENSE ADDED
@@ -0,0 +1,52 @@
1
+ CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
8
+
9
+
10
+ SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py
11
+
12
+ ***************************
13
+
14
+ NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
15
+
16
+ This software is being redistributed in a modified form. The original form is available here:
17
+
18
+ https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
+
20
+ This software in this file incorporates parts of the following software available here:
21
+
22
+ Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
23
+ available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
24
+
25
+ MoCo v3: https://github.com/facebookresearch/moco-v3
26
+ available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
27
+
28
+ DeiT: https://github.com/facebookresearch/deit
29
+ available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
30
+
31
+
32
+ ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCED BELOW:
33
+
34
+ https://github.com/facebookresearch/mae/blob/main/LICENSE
35
+
36
+ Attribution-NonCommercial 4.0 International
37
+
38
+ ***************************
39
+
40
+ NOTICE WITH RESPECT TO THE FILE: models/blocks.py
41
+
42
+ This software is being redistributed in a modified form. The original form is available here:
43
+
44
+ https://github.com/rwightman/pytorch-image-models
45
+
46
+ ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCED BELOW:
47
+
48
+ https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
49
+
50
+ Apache License
51
+ Version 2.0, January 2004
52
+ http://www.apache.org/licenses/
dust3r/croco/NOTICE ADDED
@@ -0,0 +1,21 @@
1
+ CroCo
2
+ Copyright 2022-present NAVER Corp.
3
+
4
+ This project contains subcomponents with separate copyright notices and license terms.
5
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6
+
7
+ ====
8
+
9
+ facebookresearch/mae
10
+ https://github.com/facebookresearch/mae
11
+
12
+ Attribution-NonCommercial 4.0 International
13
+
14
+ ====
15
+
16
+ rwightman/pytorch-image-models
17
+ https://github.com/rwightman/pytorch-image-models
18
+
19
+ Apache License
20
+ Version 2.0, January 2004
21
+ http://www.apache.org/licenses/
dust3r/croco/README.MD ADDED
@@ -0,0 +1,124 @@
1
+ # CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow
2
+
3
+ [[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]
4
+
5
+ This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), referred to as CroCo v2:
6
+
7
+ ![image](assets/arch.jpg)
8
+
9
+ ```bibtex
10
+ @inproceedings{croco,
11
+ title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
12
+ author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}},
13
+ booktitle={{NeurIPS}},
14
+ year={2022}
15
+ }
16
+
17
+ @inproceedings{croco_v2,
18
+ title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
19
+ author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me},
20
+ booktitle={ICCV},
21
+ year={2023}
22
+ }
23
+ ```
24
+
25
+ ## License
26
+
27
+ The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
28
+ Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
29
+ Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.
30
+
31
+ ## Preparation
32
+
33
+ 1. Install dependencies on a machine with an NVIDIA GPU, using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can skip the line installing it and use a more recent Python version.
34
+
35
+ ```bash
36
+ conda create -n croco python=3.7 cmake=3.14.0
37
+ conda activate croco
38
+ conda install habitat-sim headless -c conda-forge -c aihabitat
39
+ conda install pytorch torchvision -c pytorch
40
+ conda install notebook ipykernel matplotlib
41
+ conda install ipywidgets widgetsnbextension
42
+ conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation
43
+
44
+ ```
45
+
46
+ 2. Compile cuda kernels for RoPE
47
+
48
+ CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels.
49
+ ```bash
50
+ cd models/curope/
51
+ python setup.py build_ext --inplace
52
+ cd ../../
53
+ ```
54
+
55
+ This can take a while, as we compile for all CUDA architectures; feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only.
56
+ You might also need to set the environment variable `CUDA_HOME` if you use a custom CUDA installation.
57
+
58
+ If you cannot compile the kernels, we also provide a slower pure-PyTorch version, which will be loaded automatically.
59
+
60
+ 3. Download pre-trained model
61
+
62
+ We provide several pre-trained models:
63
+
64
+ | modelname | pre-training data | pos. embed. | Encoder | Decoder |
65
+ |------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
66
+ | [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small |
67
+ | [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small |
68
+ | [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base |
69
+ | [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base |
70
+
71
+ To download a specific model, e.g. the first one (`CroCo.pth`):
72
+ ```bash
73
+ mkdir -p pretrained_models/
74
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
75
+ ```
76
+
77
+ ## Reconstruction example
78
+
79
+ After downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or updating the corresponding line in `demo.py`), simply run:
80
+ ```bash
81
+ python demo.py
82
+ ```
83
+
84
+ ## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator
85
+
86
+ First download the test scene from Habitat:
87
+ ```bash
88
+ python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
89
+ ```
90
+
91
+ Then, run the Notebook demo `interactive_demo.ipynb`.
92
+
93
+ In this demo, you should be able to sample a random reference viewpoint from a [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change the viewpoint and select a masked target view to reconstruct using CroCo.
94
+ ![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg)
95
+
96
+ ## Pre-training
97
+
98
+ ### CroCo
99
+
100
+ To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command:
101
+ ```
102
+ torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
103
+ ```
104
+
105
+ Our CroCo pre-training was launched on a single server with 4 GPUs.
106
+ It should take around 10 days on A100s or 15 days on V100s to complete the 400 pre-training epochs, but decent performance is obtained earlier in training.
107
+ Note that, while the code contains the same learning-rate scaling rule as MAE when changing the effective batch size, we did not experiment to check that it remains valid in our case.
108
+ The first run can take a few minutes to start, to parse all available pre-training pairs.
109
+
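+ For reference, the MAE-style rule mentioned above scales the learning rate linearly with the effective batch size. Below is a minimal sketch; the base learning rate of 1.5e-4 and the reference batch size of 256 are MAE's defaults, used here purely for illustration and not necessarily the values used for CroCo.
+
+ ```python
+ # Linear learning-rate scaling as in MAE (illustrative values only).
+ base_lr = 1.5e-4            # learning rate defined for a reference effective batch size of 256
+ batch_size_per_gpu = 64
+ num_gpus = 4
+ accum_iter = 1
+
+ eff_batch_size = batch_size_per_gpu * num_gpus * accum_iter
+ lr = base_lr * eff_batch_size / 256
+ print(f"effective batch size = {eff_batch_size}, scaled lr = {lr:.2e}")
+ ```
+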
110
+ ### CroCo v2
111
+
112
+ For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
113
+ Then, run the following command for the largest model (ViT-L encoder, Base decoder):
114
+ ```
115
+ torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
116
+ ```
117
+
118
+ Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per GPU in all cases.
119
+ The largest model should take around 12 days on A100.
120
+ Note that, while the code contains the same learning-rate scaling rule as MAE when changing the effective batch size, we did not experiment to check that it remains valid in our case.
121
+
122
+ ## Stereo matching and Optical flow downstream tasks
123
+
124
+ For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).
dust3r/croco/assets/Chateau1.png ADDED

Git LFS Details

  • SHA256: 71ffb8c7d77e5ced0bb3dcd2cb0db84d0e98e6ff5ffd2d02696a7156e5284857
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB
dust3r/croco/assets/Chateau2.png ADDED

Git LFS Details

  • SHA256: c3a0be9e19f6b89491d692c71e3f2317c2288a898a990561d48b7667218b47c8
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
dust3r/croco/assets/arch.jpg ADDED
dust3r/croco/croco-stereo-flow-demo.ipynb ADDED
@@ -0,0 +1,191 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9bca0f41",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Simple inference example with CroCo-Stereo or CroCo-Flow"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "80653ef7",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
19
+ "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "4f033862",
25
+ "metadata": {},
26
+ "source": [
27
+ "First download the model(s) of your choice by running\n",
28
+ "```\n",
29
+ "bash stereoflow/download_model.sh crocostereo.pth\n",
30
+ "bash stereoflow/download_model.sh crocoflow.pth\n",
31
+ "```"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "id": "1fb2e392",
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "import torch\n",
42
+ "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
43
+ "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
44
+ "import matplotlib.pylab as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "e0e25d77",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "from stereoflow.test import _load_model_and_criterion\n",
55
+ "from stereoflow.engine import tiled_pred\n",
56
+ "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
57
+ "from stereoflow.datasets_flow import flowToColor\n",
58
+ "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "id": "86a921f5",
64
+ "metadata": {},
65
+ "source": [
66
+ "### CroCo-Stereo example"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "id": "64e483cb",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "image1 = np.asarray(Image.open('<path_to_left_image>'))\n",
77
+ "image2 = np.asarray(Image.open('<path_to_right_image>'))"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "f0d04303",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "47dc14b5",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
98
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
99
+ "with torch.inference_mode():\n",
100
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
101
+ "pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "id": "583b9f16",
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "plt.imshow(vis_disparity(pred))\n",
112
+ "plt.axis('off')"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "markdown",
117
+ "id": "d2df5d70",
118
+ "metadata": {},
119
+ "source": [
120
+ "### CroCo-Flow example"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "id": "9ee257a7",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "image1 = np.asarray(Image.open('<path_to_first_image>'))\n",
131
+ "image2 = np.asarray(Image.open('<path_to_second_image>'))"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "id": "d5edccf0",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "b19692c3",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
152
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
153
+ "with torch.inference_mode():\n",
154
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
155
+ "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "id": "26f79db3",
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "plt.imshow(flowToColor(pred))\n",
166
+ "plt.axis('off')"
167
+ ]
168
+ }
169
+ ],
170
+ "metadata": {
171
+ "kernelspec": {
172
+ "display_name": "Python 3 (ipykernel)",
173
+ "language": "python",
174
+ "name": "python3"
175
+ },
176
+ "language_info": {
177
+ "codemirror_mode": {
178
+ "name": "ipython",
179
+ "version": 3
180
+ },
181
+ "file_extension": ".py",
182
+ "mimetype": "text/x-python",
183
+ "name": "python",
184
+ "nbconvert_exporter": "python",
185
+ "pygments_lexer": "ipython3",
186
+ "version": "3.9.7"
187
+ }
188
+ },
189
+ "nbformat": 4,
190
+ "nbformat_minor": 5
191
+ }
dust3r/croco/datasets_croco/__init__.py ADDED
File without changes
dust3r/croco/datasets_croco/crops/README.MD ADDED
@@ -0,0 +1,104 @@
1
+ ## Generation of crops from the real datasets
2
+
3
+ The instructions below allow you to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
4
+
5
+ ### Download the metadata of the crops to generate
6
+
7
+ First, download the metadata and put them in `./data/`:
8
+ ```
9
+ mkdir -p data
10
+ cd data/
11
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
12
+ unzip crop_metadata.zip
13
+ rm crop_metadata.zip
14
+ cd ..
15
+ ```
16
+
17
+ ### Prepare the original datasets
18
+
19
+ Second, download the original datasets in `./data/original_datasets/`.
20
+ ```
21
+ mkdir -p data/original_datasets
22
+ ```
23
+
24
+ ##### ARKitScenes
25
+
26
+ Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
27
+ The resulting file structure should be like:
28
+ ```
29
+ ./data/original_datasets/ARKitScenes/
30
+ └───Training
31
+ └───40753679
32
+ │ │ ultrawide
33
+ │ │ ...
34
+ └───40753686
35
+
36
+ ...
37
+ ```
38
+
39
+ ##### MegaDepth
40
+
41
+ Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
42
+ The resulting file structure should be like:
43
+
44
+ ```
45
+ ./data/original_datasets/MegaDepth/
46
+ └───0000
47
+ │ └───images
48
+ │ │ │ 1000557903_87fa96b8a4_o.jpg
49
+ │ │ └ ...
50
+ │ └─── ...
51
+ └───0001
52
+ │ │
53
+ │ └ ...
54
+ └─── ...
55
+ ```
56
+
57
+ ##### 3DStreetView
58
+
59
+ Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
60
+ The resulting file structure should be like:
61
+
62
+ ```
63
+ ./data/original_datasets/3DStreetView/
64
+ └───dataset_aligned
65
+ │ └───0002
66
+ │ │ │ 0000002_0000001_0000002_0000001.jpg
67
+ │ │ └ ...
68
+ │ └─── ...
69
+ └───dataset_unaligned
70
+ │ └───0003
71
+ │ │ │ 0000003_0000001_0000002_0000001.jpg
72
+ │ │ └ ...
73
+ │ └─── ...
74
+ ```
75
+
76
+ ##### IndoorVL
77
+
78
+ Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
79
+
80
+ ```
81
+ pip install kapture
82
+ mkdir -p ./data/original_datasets/IndoorVL
83
+ cd ./data/original_datasets/IndoorVL
84
+ kapture_download_dataset.py update
85
+ kapture_download_dataset.py install "HyundaiDepartmentStore_*"
86
+ kapture_download_dataset.py install "GangnamStation_*"
87
+ cd -
88
+ ```
89
+
90
+ ### Extract the crops
91
+
92
+ Now, extract the crops for each of the datasets:
93
+ ```
94
+ for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
95
+ do
96
+ python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
97
+ done
98
+ ```
99
+
100
+ ##### Note for IndoorVL
101
+
102
+ Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
103
+ To account for this in terms of the number of pre-training iterations, the pre-training command in this repository uses 125 training epochs (including 12 warm-up epochs) and a cosine learning-rate schedule over 250 epochs, instead of 100, 10 and 200 respectively.
104
+ The impact on the performance is negligible.
dust3r/croco/datasets_croco/crops/extract_crops_from_images.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Extracting crops for pre-training
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import argparse
10
+ from tqdm import tqdm
11
+ from PIL import Image
12
+ import functools
13
+ from multiprocessing import Pool
14
+ import math
15
+
16
+
17
+ def arg_parser():
18
+ parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list')
19
+
20
+ parser.add_argument('--crops', type=str, required=True, help='crop file')
21
+ parser.add_argument('--root-dir', type=str, required=True, help='root directory')
22
+ parser.add_argument('--output-dir', type=str, required=True, help='output directory')
23
+ parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
24
+ parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
25
+ parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
26
+ parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
27
+ return parser
28
+
29
+
30
+ def main(args):
31
+ listing_path = os.path.join(args.output_dir, 'listing.txt')
32
+
33
+ print(f'Loading list of crops ... ({args.nthread} threads)')
34
+ crops, num_crops_to_generate = load_crop_file(args.crops)
35
+
36
+ print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
37
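+ # Pick a directory-tree depth so that each leaf directory ends up holding roughly --ideal-number-pairs-in-dir crops, capped at --max-subdir-levels.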
+ num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels)
38
+ num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels))
39
+
40
+ jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
41
+ del crops
42
+
43
+ os.makedirs(args.output_dir, exist_ok=True)
44
+ mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
45
+ call = functools.partial(save_image_crops, args)
46
+
47
+ print(f"Generating cropped images to {args.output_dir} ...")
48
+ with open(listing_path, 'w') as listing:
49
+ listing.write('# pair_path\n')
50
+ for results in tqdm(mmap(call, jobs), total=len(jobs)):
51
+ for path in results:
52
+ listing.write(f'{path}\n')
53
+ print('Finished writing listing to', listing_path)
54
+
55
+
56
+ def load_crop_file(path):
57
+ data = open(path).read().splitlines()
58
+ pairs = []
59
+ num_crops_to_generate = 0
60
+ for line in tqdm(data):
61
+ if line.startswith('#'):
62
+ continue
63
+ line = line.split(', ')
64
+ if len(line) < 8:
65
+ img1, img2, rotation = line
66
+ pairs.append((img1, img2, int(rotation), []))
67
+ else:
68
+ l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
69
+ rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
70
+ pairs[-1][-1].append((rect1, rect2))
71
+ num_crops_to_generate += 1
72
+ return pairs, num_crops_to_generate
73
+
74
+
75
+ def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
76
+ jobs = []
77
+ powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
78
+
79
+ def get_path(idx):
80
+ idx_array = []
81
+ d = idx
82
+ for level in range(num_levels - 1):
83
+ idx_array.append(idx // powers[level])
84
+ idx = idx % powers[level]
85
+ idx_array.append(d)
86
+ return '/'.join(map(lambda x: hex(x)[2:], idx_array))
87
+
88
+ idx = 0
89
+ for pair_data in tqdm(pairs):
90
+ img1, img2, rotation, crops = pair_data
91
+ if -60 <= rotation and rotation <= 60:
92
+ rotation = 0 # most likely not a true rotation
93
+ paths = [get_path(idx + k) for k in range(len(crops))]
94
+ idx += len(crops)
95
+ jobs.append(((img1, img2), rotation, crops, paths))
96
+ return jobs
97
+
98
+
99
+ def load_image(path):
100
+ try:
101
+ return Image.open(path).convert('RGB')
102
+ except Exception as e:
103
+ print('skipping', path, e)
104
+ raise OSError()
105
+
106
+
107
+ def save_image_crops(args, data):
108
+ # load images
109
+ img_pair, rot, crops, paths = data
110
+ try:
111
+ img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
112
+ except OSError as e:
113
+ return []
114
+
115
+ def area(sz):
116
+ return sz[0] * sz[1]
117
+
118
+ tgt_size = (args.imsize, args.imsize)
119
+
120
+ def prepare_crop(img, rect, rot=0):
121
+ # actual crop
122
+ img = img.crop(rect)
123
+
124
+ # resize to desired size
125
+ interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
126
+ img = img.resize(tgt_size, resample=interp)
127
+
128
+ # rotate the image
129
+ rot90 = (round(rot/90) % 4) * 90
130
+ if rot90 == 90:
131
+ img = img.transpose(Image.Transpose.ROTATE_90)
132
+ elif rot90 == 180:
133
+ img = img.transpose(Image.Transpose.ROTATE_180)
134
+ elif rot90 == 270:
135
+ img = img.transpose(Image.Transpose.ROTATE_270)
136
+ return img
137
+
138
+ results = []
139
+ for (rect1, rect2), path in zip(crops, paths):
140
+ crop1 = prepare_crop(img1, rect1)
141
+ crop2 = prepare_crop(img2, rect2, rot)
142
+
143
+ fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
144
+ fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
145
+ os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
146
+
147
+ assert not os.path.isfile(fullpath1), fullpath1
148
+ assert not os.path.isfile(fullpath2), fullpath2
149
+ crop1.save(fullpath1)
150
+ crop2.save(fullpath2)
151
+ results.append(path)
152
+
153
+ return results
154
+
155
+
156
+ if __name__ == '__main__':
157
+ args = arg_parser().parse_args()
158
+ main(args)
159
+
dust3r/croco/datasets_croco/habitat_sim/README.MD ADDED
@@ -0,0 +1,76 @@
1
+ ## Generation of synthetic image pairs using Habitat-Sim
2
+
3
+ These instructions allow you to generate pre-training pairs from the Habitat simulator.
4
+ As we did not save the metadata of the pairs used in the original paper, the pairs generated here are not strictly the same, but they use the same settings and are equivalent.
5
+
6
+ ### Download Habitat-Sim scenes
7
+ Download Habitat-Sim scenes:
8
+ - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
9
+ - We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
10
+ - Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or manually update the paths in `paths.py`.
11
+ ```
12
+ ./data/
13
+ └──habitat-sim-data/
14
+ └──scene_datasets/
15
+ ├──hm3d/
16
+ ├──gibson/
17
+ ├──habitat-test-scenes/
18
+ ├──replica_cad_baked_lighting/
19
+ ├──replica_cad/
20
+ ├──ReplicaDataset/
21
+ └──scannet/
22
+ ```
23
+
24
+ ### Image pairs generation
25
+ We provide metadata to generate reproducible image pairs for pre-training and validation.
26
+ Experiments described in the paper used similar data, whose generation was not reproducible at the time.
27
+
28
+ Specifications:
29
+ - 256x256 resolution images, with a 60-degree field of view.
30
+ - Up to 1000 image pairs per scene.
31
+ - Number of scenes considered / number of image pairs per dataset:
32
+ - ScanNet: 1097 scenes / 985,209 pairs
33
+ - HM3D:
34
+ - hm3d/train: 800 / 800k pairs
35
+ - hm3d/val: 100 scenes / 100k pairs
36
+ - hm3d/minival: 10 scenes / 10k pairs
37
+ - habitat-test-scenes: 3 scenes / 3k pairs
38
+ - replica_cad_baked_lighting: 13 scenes / 13k pairs
39
+
40
+ - Scenes from hm3d/val and hm3d/minival were not used for pre-training but were kept for validation purposes.
41
+
42
+ Download metadata and extract it:
43
+ ```bash
44
+ mkdir -p data/habitat_release_metadata/
45
+ cd data/habitat_release_metadata/
46
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
47
+ tar -xvf multiview_habitat_metadata.tar.gz
48
+ cd ../..
49
+ # Location of the metadata
50
+ METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
51
+ ```
52
+
53
+ Generate image pairs from metadata:
54
+ - The following command will print a list of commandlines to generate image pairs for each scene:
55
+ ```bash
56
+ # Target output directory
57
+ PAIRS_DATASET_DIR="./data/habitat_release/"
58
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
59
+ ```
60
+ - Several such commands can be launched in parallel, e.g. using GNU Parallel:
61
+ ```bash
62
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
63
+ ```
64
+
65
+ ## Metadata generation
66
+
67
+ Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
68
+ ```bash
69
+ # Print commandlines to generate image pairs from the different scenes available.
70
+ PAIRS_DATASET_DIR=MY_CUSTOM_PATH
71
+ python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
72
+
73
+ # Once a dataset is generated, pack metadata files for reproducibility.
74
+ METADATA_DIR=MY_CUSTOM_PATH
75
+ python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
76
+ ```
dust3r/croco/datasets_croco/habitat_sim/__init__.py ADDED
File without changes
dust3r/croco/datasets_croco/habitat_sim/generate_from_metadata.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
6
+ """
7
+ import os
8
+ from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator
9
+ from datasets.habitat_sim.paths import SCENES_DATASET
10
+ import argparse
11
+ import quaternion
12
+ import PIL.Image
13
+ import cv2
14
+ import json
15
+ from tqdm import tqdm
16
+
17
+ def generate_multiview_images_from_metadata(metadata_filename,
18
+ output_dir,
19
+ overload_params = dict(),
20
+ scene_datasets_paths=None,
21
+ exist_ok=False):
22
+ """
23
+ Generate images from a metadata file for reproducibility purposes.
24
+ """
25
+ # Reorder paths by decreasing label length, to avoid collisions when testing whether a path starts with a given label
26
+ if scene_datasets_paths is not None:
27
+ scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True))
28
+
29
+ with open(metadata_filename, 'r') as f:
30
+ input_metadata = json.load(f)
31
+ metadata = dict()
32
+ for key, value in input_metadata.items():
33
+ # Optionally replace some paths
34
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
35
+ if scene_datasets_paths is not None:
36
+ for dataset_label, dataset_path in scene_datasets_paths.items():
37
+ if value.startswith(dataset_label):
38
+ value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label)))
39
+ break
40
+ metadata[key] = value
41
+
42
+ # Overload some parameters
43
+ for key, value in overload_params.items():
44
+ metadata[key] = value
45
+
46
+ generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))])
47
+ generate_depth = metadata["generate_depth"]
48
+
49
+ os.makedirs(output_dir, exist_ok=exist_ok)
50
+
51
+ generator = MultiviewHabitatSimGenerator(**generation_entries)
52
+
53
+ # Generate views
54
+ for idx_label, data in tqdm(metadata['multiviews'].items()):
55
+ positions = data["positions"]
56
+ orientations = data["orientations"]
57
+ n = len(positions)
58
+ for oidx in range(n):
59
+ observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx]))
60
+ observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
61
+ # Color image saved using PIL
62
+ img = PIL.Image.fromarray(observation['color'][:,:,:3])
63
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
64
+ img.save(filename)
65
+ if generate_depth:
66
+ # Depth image as EXR file
67
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
68
+ cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
69
+ # Camera parameters
70
+ camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
71
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
72
+ with open(filename, "w") as f:
73
+ json.dump(camera_params, f)
74
+ # Save metadata
75
+ with open(os.path.join(output_dir, "metadata.json"), "w") as f:
76
+ json.dump(metadata, f)
77
+
78
+ generator.close()
79
+
80
+ if __name__ == "__main__":
81
+ parser = argparse.ArgumentParser()
82
+ parser.add_argument("--metadata_filename", required=True)
83
+ parser.add_argument("--output_dir", required=True)
84
+ args = parser.parse_args()
85
+
86
+ generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename,
87
+ output_dir=args.output_dir,
88
+ scene_datasets_paths=SCENES_DATASET,
89
+ overload_params=dict(),
90
+ exist_ok=True)
91
+
92
+
dust3r/croco/datasets_croco/habitat_sim/generate_from_metadata_files.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Script generating commandlines to generate image pairs from metadata files.
6
+ """
7
+ import os
8
+ import glob
9
+ from tqdm import tqdm
10
+ import argparse
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--input_dir", required=True)
15
+ parser.add_argument("--output_dir", required=True)
16
+ parser.add_argument("--prefix", default="", help="Command-line prefix, useful e.g. to set up the environment.")
17
+ args = parser.parse_args()
18
+
19
+ input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True)
20
+
21
+ for metadata_filename in tqdm(input_metadata_filenames):
22
+ output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir))
23
+ # Do not process the scene if the metadata file already exists
24
+ if os.path.exists(os.path.join(output_dir, "metadata.json")):
25
+ continue
26
+ commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
27
+ print(commandline)
dust3r/croco/datasets_croco/habitat_sim/generate_multiview_images.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from tqdm import tqdm
6
+ import argparse
7
+ import PIL.Image
8
+ import numpy as np
9
+ import json
10
+ from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError
11
+ from datasets.habitat_sim.paths import list_scenes_available
12
+ import cv2
13
+ import quaternion
14
+ import shutil
15
+
16
+ def generate_multiview_images_for_scene(scene_dataset_config_file,
17
+ scene,
18
+ navmesh,
19
+ output_dir,
20
+ views_count,
21
+ size,
22
+ exist_ok=False,
23
+ generate_depth=False,
24
+ **kwargs):
25
+ """
26
+ Generate tuples of overlapping views for a given scene.
27
+ generate_depth: generate depth images and camera parameters.
28
+ """
29
+ if os.path.exists(output_dir) and not exist_ok:
30
+ print(f"Scene {scene}: data already generated. Ignoring generation.")
31
+ return
32
+ try:
33
+ print(f"Scene {scene}: {size} multiview acquisitions to generate...")
34
+ os.makedirs(output_dir, exist_ok=exist_ok)
35
+
36
+ metadata_filename = os.path.join(output_dir, "metadata.json")
37
+
38
+ metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file,
39
+ scene=scene,
40
+ navmesh=navmesh,
41
+ views_count=views_count,
42
+ size=size,
43
+ generate_depth=generate_depth,
44
+ **kwargs)
45
+ metadata_template["multiviews"] = dict()
46
+
47
+ if os.path.exists(metadata_filename):
48
+ print("Metadata file already exists:", metadata_filename)
49
+ print("Loading already generated metadata file...")
50
+ with open(metadata_filename, "r") as f:
51
+ metadata = json.load(f)
52
+
53
+ for key in metadata_template.keys():
54
+ if key != "multiviews":
55
+ assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}."
56
+ else:
57
+ print("No temporary file found. Starting generation from scratch...")
58
+ metadata = metadata_template
59
+
60
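+ # Resume from a previous run: indices below starting_id are already stored in the metadata file.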
+ starting_id = len(metadata["multiviews"])
61
+ print(f"Starting generation from index {starting_id}/{size}...")
62
+ if starting_id >= size:
63
+ print("Generation already done.")
64
+ return
65
+
66
+ generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file,
67
+ scene=scene,
68
+ navmesh=navmesh,
69
+ views_count = views_count,
70
+ size = size,
71
+ **kwargs)
72
+
73
+ for idx in tqdm(range(starting_id, size)):
74
+ # Generate / re-generate the observations
75
+ try:
76
+ data = generator[idx]
77
+ observations = data["observations"]
78
+ positions = data["positions"]
79
+ orientations = data["orientations"]
80
+
81
+ idx_label = f"{idx:08}"
82
+ for oidx, observation in enumerate(observations):
83
+ observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
84
+ # Color image saved using PIL
85
+ img = PIL.Image.fromarray(observation['color'][:,:,:3])
86
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
87
+ img.save(filename)
88
+ if generate_depth:
89
+ # Depth image as EXR file
90
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
91
+ cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
92
+ # Camera parameters
93
+ camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
94
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
95
+ with open(filename, "w") as f:
96
+ json.dump(camera_params, f)
97
+ metadata["multiviews"][idx_label] = {"positions": positions.tolist(),
98
+ "orientations": orientations.tolist(),
99
+ "covisibility_ratios": data["covisibility_ratios"].tolist(),
100
+ "valid_fractions": data["valid_fractions"].tolist(),
101
+ "pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()}
102
+ except RecursionError:
103
+ print("Recursion error: unable to sample observations for this scene. We will stop there.")
104
+ break
105
+
106
+ # Regularly save a temporary metadata file, in case we need to restart the generation
107
+ if idx % 10 == 0:
108
+ with open(metadata_filename, "w") as f:
109
+ json.dump(metadata, f)
110
+
111
+ # Save metadata
112
+ with open(metadata_filename, "w") as f:
113
+ json.dump(metadata, f)
114
+
115
+ generator.close()
116
+ except NoNaviguableSpaceError:
117
+ pass
118
+
119
+ def create_commandline(scene_data, generate_depth, exist_ok=False):
120
+ """
121
+ Create a commandline string to generate a scene.
122
+ """
123
+ def my_formatting(val):
124
+ if val is None or val == "":
125
+ return '""'
126
+ else:
127
+ return val
128
+ commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)}
129
+ --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)}
130
+ --navmesh {my_formatting(scene_data.navmesh)}
131
+ --output_dir {my_formatting(scene_data.output_dir)}
132
+ --generate_depth {int(generate_depth)}
133
+ --exist_ok {int(exist_ok)}
134
+ """
135
+ commandline = " ".join(commandline.split())
136
+ return commandline
137
+
138
+ if __name__ == "__main__":
139
+ os.umask(2)
140
+
141
+ parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for scenes available:
142
+ > python datasets/habitat_sim/generate_multiview_images.py --list_commands
143
+ """)
144
+
145
+ parser.add_argument("--output_dir", type=str, required=True)
146
+ parser.add_argument("--list_commands", action='store_true', help="list commandlines to run if true")
147
+ parser.add_argument("--scene", type=str, default="")
148
+ parser.add_argument("--scene_dataset_config_file", type=str, default="")
149
+ parser.add_argument("--navmesh", type=str, default="")
150
+
151
+ parser.add_argument("--generate_depth", type=int, default=1)
152
+ parser.add_argument("--exist_ok", type=int, default=0)
153
+
154
+ kwargs = dict(resolution=(256,256), hfov=60, views_count = 2, size=1000)
155
+
156
+ args = parser.parse_args()
157
+ generate_depth=bool(args.generate_depth)
158
+ exist_ok = bool(args.exist_ok)
159
+
160
+ if args.list_commands:
161
+ # Listing scenes available...
162
+ scenes_data = list_scenes_available(base_output_dir=args.output_dir)
163
+
164
+ for scene_data in scenes_data:
165
+ print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok))
166
+ else:
167
+ if args.scene == "" or args.output_dir == "":
168
+ print("Missing scene or output dir argument!")
169
+ print(parser.format_help())
170
+ else:
171
+ generate_multiview_images_for_scene(scene=args.scene,
172
+ scene_dataset_config_file = args.scene_dataset_config_file,
173
+ navmesh = args.navmesh,
174
+ output_dir = args.output_dir,
175
+ exist_ok=exist_ok,
176
+ generate_depth=generate_depth,
177
+ **kwargs)
dust3r/croco/datasets_croco/habitat_sim/multiview_habitat_sim_generator.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ import numpy as np
6
+ import quaternion
7
+ import habitat_sim
8
+ import json
9
+ from sklearn.neighbors import NearestNeighbors
10
+ import cv2
11
+
12
+ # OpenCV to habitat camera convention transformation
13
+ R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0)
14
+ R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
15
+ DEG2RAD = np.pi / 180
16
+
17
+ def compute_camera_intrinsics(height, width, hfov):
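+ # Pinhole model: focal length derived from the horizontal field of view (in degrees), principal point at the image center.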
18
+ f = width/2 / np.tan(hfov/2 * np.pi/180)
19
+ cu, cv = width/2, height/2
20
+ return f, cu, cv
21
+
22
+ def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
23
+ R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT
24
+ t_cam2world = np.asarray(camera_position)
25
+ return R_cam2world, t_cam2world
26
+
27
+ def compute_pointmap(depthmap, hfov):
28
+ """ Compute a HxWx3 pointmap in camera frame from a HxW depth map."""
29
+ height, width = depthmap.shape
30
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
31
+ # Cast depth map to point
32
+ z_cam = depthmap
33
+ u, v = np.meshgrid(range(width), range(height))
34
+ x_cam = (u - cu) / f * z_cam
35
+ y_cam = (v - cv) / f * z_cam
36
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1)
37
+ return X_cam
38
+
39
+ def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
40
+ """Return a 3D point cloud corresponding to valid pixels of the depth map"""
41
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation)
42
+
43
+ X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
44
+ valid_mask = (X_cam[:,:,2] != 0.0)
45
+
46
+ X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()]
47
+ X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3)
48
+ return X_world
49
+
50
+ def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False):
51
+ """
52
+ Compute 'overlapping' metrics based on a distance threshold between two point clouds.
53
+ """
54
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2)
55
+ distances, indices = nbrs.kneighbors(pointcloud1)
56
+ intersection1 = np.count_nonzero(distances.flatten() < distance_threshold)
57
+
58
+ data = {"intersection1": intersection1,
59
+ "size1": len(pointcloud1)}
60
+ if compute_symmetric:
61
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1)
62
+ distances, indices = nbrs.kneighbors(pointcloud2)
63
+ intersection2 = np.count_nonzero(distances.flatten() < distance_threshold)
64
+ data["intersection2"] = intersection2
65
+ data["size2"] = len(pointcloud2)
66
+
67
+ return data
68
+
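compute_pointcloud_overlaps_scikit measures covisibility as the number of points in one cloud that have a neighbour in the other cloud closer than distance_threshold. A small, self-contained sketch of the same nearest-neighbour test on synthetic data (independent of the Habitat-specific code above):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
cloud1 = rng.uniform(0, 1, size=(500, 3))
cloud2 = cloud1 + rng.normal(scale=0.01, size=cloud1.shape)  # slightly perturbed copy

distance_threshold = 0.05
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(cloud2)
distances, _ = nbrs.kneighbors(cloud1)
intersection = np.count_nonzero(distances.flatten() < distance_threshold)
print(f"{intersection}/{len(cloud1)} points of cloud1 have a close neighbour in cloud2")
```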
69
+ def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
70
+ """
71
+    Add camera parameters to the observation dictionary produced by Habitat-Sim
72
+ In-place modifications.
73
+ """
74
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation)
75
+ height, width = observation['depth'].shape
76
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
77
+ K = np.asarray([[f, 0, cu],
78
+ [0, f, cv],
79
+ [0, 0, 1.0]])
80
+ observation["camera_intrinsics"] = K
81
+ observation["t_cam2world"] = t_cam2world
82
+ observation["R_cam2world"] = R_cam2world
83
+
84
+ def look_at(eye, center, up, return_cam2world=True):
85
+ """
86
+ Return camera pose looking at a given center point.
87
+    Analogous to the gluLookAt function, using the OpenCV camera convention.
88
+ """
89
+ z = center - eye
90
+ z /= np.linalg.norm(z, axis=-1, keepdims=True)
91
+ y = -up
92
+ y = y - np.sum(y * z, axis=-1, keepdims=True) * z
93
+ y /= np.linalg.norm(y, axis=-1, keepdims=True)
94
+ x = np.cross(y, z, axis=-1)
95
+
96
+ if return_cam2world:
97
+ R = np.stack((x, y, z), axis=-1)
98
+ t = eye
99
+ else:
100
+ # World to camera transformation
101
+ # Transposed matrix
102
+ R = np.stack((x, y, z), axis=-2)
103
+ t = - np.einsum('...ij, ...j', R, eye)
104
+ return R, t
105
+
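look_at returns a cam2world rotation whose third column is the viewing direction (OpenCV convention: z forward, y down). A quick sanity check, assuming this module is importable from the current folder (its module-level imports require habitat_sim, quaternion, scikit-learn and OpenCV to be installed):

```python
import numpy as np
# Assumption: running from this folder so the module resolves by its file name.
from multiview_habitat_sim_generator import look_at

eye = np.array([0.0, 0.0, 0.0])
center = np.array([0.0, 0.0, 5.0])
up = np.array([0.0, 1.0, 0.0])
R, t = look_at(eye, center, up)            # cam2world rotation and translation
print(R[:, 2])                             # viewing direction, here (0, 0, 1)
print(np.allclose(R @ R.T, np.eye(3)), t)  # R is orthonormal, t equals eye
```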
106
+ def look_at_for_habitat(eye, center, up, return_cam2world=True):
107
+ R, t = look_at(eye, center, up)
108
+ orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T)
109
+ return orientation, t
110
+
111
+ def generate_orientation_noise(pan_range, tilt_range, roll_range):
112
+ return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP)
113
+ * quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT)
114
+ * quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT))
115
+
116
+
117
+ class NoNaviguableSpaceError(RuntimeError):
118
+ def __init__(self, *args):
119
+ super().__init__(*args)
120
+
121
+ class MultiviewHabitatSimGenerator:
122
+ def __init__(self,
123
+ scene,
124
+ navmesh,
125
+ scene_dataset_config_file,
126
+ resolution = (240, 320),
127
+ views_count=2,
128
+ hfov = 60,
129
+ gpu_id = 0,
130
+ size = 10000,
131
+ minimum_covisibility = 0.5,
132
+ transform = None):
133
+ self.scene = scene
134
+ self.navmesh = navmesh
135
+ self.scene_dataset_config_file = scene_dataset_config_file
136
+ self.resolution = resolution
137
+ self.views_count = views_count
138
+ assert(self.views_count >= 1)
139
+ self.hfov = hfov
140
+ self.gpu_id = gpu_id
141
+ self.size = size
142
+ self.transform = transform
143
+
144
+ # Noise added to camera orientation
145
+ self.pan_range = (-3, 3)
146
+ self.tilt_range = (-10, 10)
147
+ self.roll_range = (-5, 5)
148
+
149
+ # Height range to sample cameras
150
+ self.height_range = (1.2, 1.8)
151
+
152
+ # Random steps between the camera views
153
+ self.random_steps_count = 5
154
+ self.random_step_variance = 2.0
155
+
156
+ # Minimum fraction of the scene which should be valid (well defined depth)
157
+ self.minimum_valid_fraction = 0.7
158
+
159
+        # Distance threshold used to select pairs
160
+ self.distance_threshold = 0.05
161
+ # Minimum IoU of a view point cloud with respect to the reference view to be kept.
162
+ self.minimum_covisibility = minimum_covisibility
163
+
164
+ # Maximum number of retries.
165
+ self.max_attempts_count = 100
166
+
167
+ self.seed = None
168
+ self._lazy_initialization()
169
+
170
+ def _lazy_initialization(self):
171
+ # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
172
+        if self.seed is None:
173
+ # Re-seed numpy generator
174
+ np.random.seed()
175
+ self.seed = np.random.randint(2**32-1)
176
+ sim_cfg = habitat_sim.SimulatorConfiguration()
177
+ sim_cfg.scene_id = self.scene
178
+ if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
179
+ sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
180
+ sim_cfg.random_seed = self.seed
181
+ sim_cfg.load_semantic_mesh = False
182
+ sim_cfg.gpu_device_id = self.gpu_id
183
+
184
+ depth_sensor_spec = habitat_sim.CameraSensorSpec()
185
+ depth_sensor_spec.uuid = "depth"
186
+ depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
187
+ depth_sensor_spec.resolution = self.resolution
188
+ depth_sensor_spec.hfov = self.hfov
189
+ depth_sensor_spec.position = [0.0, 0.0, 0]
190
+        depth_sensor_spec.orientation = [0.0, 0.0, 0.0]
191
+
192
+ rgb_sensor_spec = habitat_sim.CameraSensorSpec()
193
+ rgb_sensor_spec.uuid = "color"
194
+ rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
195
+ rgb_sensor_spec.resolution = self.resolution
196
+ rgb_sensor_spec.hfov = self.hfov
197
+ rgb_sensor_spec.position = [0.0, 0.0, 0]
198
+ agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec])
199
+
200
+ cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
201
+ self.sim = habitat_sim.Simulator(cfg)
202
+ if self.navmesh is not None and self.navmesh != "":
203
+ # Use pre-computed navmesh when available (usually better than those generated automatically)
204
+ self.sim.pathfinder.load_nav_mesh(self.navmesh)
205
+
206
+ if not self.sim.pathfinder.is_loaded:
207
+ # Try to compute a navmesh
208
+ navmesh_settings = habitat_sim.NavMeshSettings()
209
+ navmesh_settings.set_defaults()
210
+ self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)
211
+
212
+ # Ensure that the navmesh is not empty
213
+ if not self.sim.pathfinder.is_loaded:
214
+ raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})")
215
+
216
+ self.agent = self.sim.initialize_agent(agent_id=0)
217
+
218
+ def close(self):
219
+ self.sim.close()
220
+
221
+ def __del__(self):
222
+ self.sim.close()
223
+
224
+ def __len__(self):
225
+ return self.size
226
+
227
+ def sample_random_viewpoint(self):
228
+ """ Sample a random viewpoint using the navmesh """
229
+ nav_point = self.sim.pathfinder.get_random_navigable_point()
230
+
231
+ # Sample a random viewpoint height
232
+ viewpoint_height = np.random.uniform(*self.height_range)
233
+ viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
234
+ viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
235
+ return viewpoint_position, viewpoint_orientation, nav_point
236
+
237
+ def sample_other_random_viewpoint(self, observed_point, nav_point):
238
+ """ Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point."""
239
+ other_nav_point = nav_point
240
+
241
+ walk_directions = self.random_step_variance * np.asarray([1,0,1])
242
+ for i in range(self.random_steps_count):
243
+ temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3))
244
+ # Snapping may return nan when it fails
245
+ if not np.isnan(temp[0]):
246
+ other_nav_point = temp
247
+
248
+ other_viewpoint_height = np.random.uniform(*self.height_range)
249
+ other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP
250
+
251
+ # Set viewing direction towards the central point
252
+ rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True)
253
+ rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
254
+ return position, rotation, other_nav_point
255
+
256
+ def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
257
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
258
+ # Observation
259
+ pixels_count = self.resolution[0] * self.resolution[1]
260
+ valid_fraction = len(other_pointcloud) / pixels_count
261
+ assert valid_fraction <= 1.0 and valid_fraction >= 0.0
262
+ overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True)
263
+ covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count)
264
+ is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility)
265
+ return is_valid, valid_fraction, covisibility
266
+
267
+ def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation):
268
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
269
+ # Observation
270
+ other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation)
271
+ return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
272
+
273
+ def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
274
+ agent_state = habitat_sim.AgentState()
275
+ agent_state.position = viewpoint_position
276
+ agent_state.rotation = viewpoint_orientation
277
+ self.agent.set_state(agent_state)
278
+ viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
279
+ _append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation)
280
+ return viewpoint_observations
281
+
282
+ def __getitem__(self, useless_idx):
283
+ ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
284
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
285
+ # Extract point cloud
286
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
287
+ camera_position=ref_position, camera_rotation=ref_orientation)
288
+
289
+ pixels_count = self.resolution[0] * self.resolution[1]
290
+ ref_valid_fraction = len(ref_pointcloud) / pixels_count
291
+ assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
292
+ if ref_valid_fraction < self.minimum_valid_fraction:
293
+ # This should produce a recursion error at some point when something is very wrong.
294
+ return self[0]
295
+        # Pick a reference observed point in the point cloud
296
+ observed_point = np.mean(ref_pointcloud, axis=0)
297
+
298
+ # Add the first image as reference
299
+ viewpoints_observations = [ref_observations]
300
+ viewpoints_covisibility = [ref_valid_fraction]
301
+ viewpoints_positions = [ref_position]
302
+ viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
303
+ viewpoints_clouds = [ref_pointcloud]
304
+ viewpoints_valid_fractions = [ref_valid_fraction]
305
+
306
+ for _ in range(self.views_count - 1):
307
+            # Generate another viewpoint using a short random walk
308
+ successful_sampling = False
309
+ for sampling_attempt in range(self.max_attempts_count):
310
+ position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point)
311
+ # Observation
312
+ other_viewpoint_observations = self.render_viewpoint(position, rotation)
313
+ other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation)
314
+
315
+ is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
316
+ if is_valid:
317
+ successful_sampling = True
318
+ break
319
+ if not successful_sampling:
320
+ print("WARNING: Maximum number of attempts reached.")
321
+ # Dirty hack, try using a novel original viewpoint
322
+ return self[0]
323
+ viewpoints_observations.append(other_viewpoint_observations)
324
+ viewpoints_covisibility.append(covisibility)
325
+ viewpoints_positions.append(position)
326
+ viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding.
327
+ viewpoints_clouds.append(other_pointcloud)
328
+ viewpoints_valid_fractions.append(valid_fraction)
329
+
330
+ # Estimate relations between all pairs of images
331
+ pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations)))
332
+ for i in range(len(viewpoints_observations)):
333
+ pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i]
334
+ for j in range(i+1, len(viewpoints_observations)):
335
+ overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True)
336
+ pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count
337
+ pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count
338
+
339
+ # IoU is relative to the image 0
340
+ data = {"observations": viewpoints_observations,
341
+ "positions": np.asarray(viewpoints_positions),
342
+ "orientations": np.asarray(viewpoints_orientations),
343
+ "covisibility_ratios": np.asarray(viewpoints_covisibility),
344
+ "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
345
+ "pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float),
346
+ }
347
+
348
+ if self.transform is not None:
349
+ data = self.transform(data)
350
+ return data
351
+
352
+ def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False):
353
+ """
354
+ Return a list of images corresponding to a spiral trajectory from a random starting point.
355
+ Useful to generate nice visualisations.
356
+ Use an even number of half turns to get a nice "C1-continuous" loop effect
357
+ """
358
+ ref_position, ref_orientation, navpoint = self.sample_random_viewpoint()
359
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
360
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
361
+ camera_position=ref_position, camera_rotation=ref_orientation)
362
+ pixels_count = self.resolution[0] * self.resolution[1]
363
+ if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
364
+ # Dirty hack: ensure that the valid part of the image is significant
365
+ return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation)
366
+
367
+ # Pick an observed point in the point cloud
368
+ observed_point = np.mean(ref_pointcloud, axis=0)
369
+ ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation)
370
+
371
+ images = []
372
+ is_valid = []
373
+        # Spiral trajectory, optionally with a constant orientation
374
+ for i, alpha in enumerate(np.linspace(0, 1, images_count)):
375
+ r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius
376
+ theta = alpha * half_turns * np.pi
377
+ x = r * np.cos(theta)
378
+ y = r * np.sin(theta)
379
+ z = 0.0
380
+ position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten()
381
+ if use_constant_orientation:
382
+ orientation = ref_orientation
383
+ else:
384
+ # trajectory looking at a mean point in front of the ref observation
385
+ orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP)
386
+ observations = self.render_viewpoint(position, orientation)
387
+ images.append(observations['color'][...,:3])
388
+ _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation)
389
+ is_valid.append(_is_valid)
390
+ return images, np.all(is_valid)
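A minimal sketch of driving the generator directly, assuming habitat_sim is installed and the scene/navmesh paths below are replaced by real assets (they are placeholders, not files shipped with this repository); constructing the object starts a Habitat simulator, so this is not a lightweight call:

```python
from multiview_habitat_sim_generator import MultiviewHabitatSimGenerator

generator = MultiviewHabitatSimGenerator(
    scene="path/to/scene.glb",            # placeholder paths, substitute local assets
    navmesh="path/to/scene.navmesh",
    scene_dataset_config_file="",
    resolution=(240, 320),
    views_count=2,
    minimum_covisibility=0.5,
    size=10,
)
sample = generator[0]   # dict of observations, poses and covisibility statistics
print(sample["covisibility_ratios"], sample["positions"].shape)
generator.close()
```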
dust3r/croco/datasets_croco/habitat_sim/pack_metadata_files.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ """
4
+ Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere.
5
+ """
6
+ import os
7
+ import glob
8
+ from tqdm import tqdm
9
+ import shutil
10
+ import json
11
+ from datasets.habitat_sim.paths import *
12
+ import argparse
13
+ import collections
14
+
15
+ if __name__ == "__main__":
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("input_dir")
18
+ parser.add_argument("output_dir")
19
+ args = parser.parse_args()
20
+
21
+ input_dirname = args.input_dir
22
+ output_dirname = args.output_dir
23
+
24
+ input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True)
25
+
26
+ images_count = collections.defaultdict(lambda : 0)
27
+
28
+ os.makedirs(output_dirname)
29
+ for input_filename in tqdm(input_metadata_filenames):
30
+        # Skip metadata files with no recorded views
31
+ with open(input_filename, "r") as f:
32
+ original_metadata = json.load(f)
33
+ if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0:
34
+ print("No views in", input_filename)
35
+ continue
36
+
37
+ relpath = os.path.relpath(input_filename, input_dirname)
38
+ print(relpath)
39
+
40
+ # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
41
+ # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern.
42
+ scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True))
43
+ metadata = dict()
44
+ for key, value in original_metadata.items():
45
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
46
+ known_path = False
47
+ for dataset, dataset_path in scenes_dataset_paths.items():
48
+ if value.startswith(dataset_path):
49
+ value = os.path.join(dataset, os.path.relpath(value, dataset_path))
50
+ known_path = True
51
+ break
52
+ if not known_path:
53
+ raise KeyError("Unknown path:" + value)
54
+ metadata[key] = value
55
+
56
+ # Compile some general statistics while packing data
57
+ scene_split = metadata["scene"].split("/")
58
+ upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
59
+ images_count[upper_level] += len(metadata["multiviews"])
60
+
61
+ output_filename = os.path.join(output_dirname, relpath)
62
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
63
+ with open(output_filename, "w") as f:
64
+ json.dump(metadata, f)
65
+
66
+ # Print statistics
67
+ print("Images count:")
68
+ for upper_level, count in images_count.items():
69
+ print(f"- {upper_level}: {count}")
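The script takes the raw generation output and a destination as positional arguments, e.g. `python pack_metadata_files.py <input_dir> <output_dir>`. It expects every scene, navmesh and scene_dataset_config_file path found in the metadata.json files to start with one of the roots declared in SCENES_DATASET (see paths.py below); any other prefix raises a KeyError.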
dust3r/croco/datasets_croco/habitat_sim/paths.py ADDED
@@ -0,0 +1,129 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Paths to Habitat-Sim scenes
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import collections
11
+ from tqdm import tqdm
12
+
13
+
14
+ # Hardcoded path to the different scene datasets
15
+ SCENES_DATASET = {
16
+ "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
17
+ "gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
18
+ "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
19
+ "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
20
+ "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
21
+ "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
22
+ "scannet": "./data/habitat-sim/scene_datasets/scannet/"
23
+ }
24
+
25
+ SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"])
26
+
27
+ def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
28
+ scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json")
29
+ scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
30
+ navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
31
+ scenes_data = []
32
+ for idx in range(len(scenes)):
33
+ output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
34
+ # Add scene
35
+ data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
36
+ scene = scenes[idx] + ".scene_instance.json",
37
+ navmesh = os.path.join(base_path, navmeshes[idx]),
38
+ output_dir = output_dir)
39
+ scenes_data.append(data)
40
+ return scenes_data
41
+
42
+ def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]):
43
+ scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json")
44
+ scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [])
45
+ navmeshes = ""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
46
+ scenes_data = []
47
+ for idx in range(len(scenes)):
48
+ output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx])
49
+ data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
50
+ scene = scenes[idx],
51
+ navmesh = "",
52
+ output_dir = output_dir)
53
+ scenes_data.append(data)
54
+ return scenes_data
55
+
56
+ def list_replica_scenes(base_output_dir, base_path):
57
+ scenes_data = []
58
+ for scene_id in os.listdir(base_path):
59
+ scene = os.path.join(base_path, scene_id, "mesh.ply")
60
+ navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it
61
+ scene_dataset_config_file = ""
62
+ output_dir = os.path.join(base_output_dir, scene_id)
63
+ # Add scene only if it does not exist already, or if exist_ok
64
+ data = SceneData(scene_dataset_config_file = scene_dataset_config_file,
65
+ scene = scene,
66
+ navmesh = navmesh,
67
+ output_dir = output_dir)
68
+ scenes_data.append(data)
69
+ return scenes_data
70
+
71
+
72
+ def list_scenes(base_output_dir, base_path):
73
+ """
74
+ Generic method iterating through a base_path folder to find scenes.
75
+ """
76
+ scenes_data = []
77
+ for root, dirs, files in os.walk(base_path, followlinks=True):
78
+ folder_scenes_data = []
79
+ for file in files:
80
+ name, ext = os.path.splitext(file)
81
+ if ext == ".glb":
82
+ scene = os.path.join(root, name + ".glb")
83
+ navmesh = os.path.join(root, name + ".navmesh")
84
+ if not os.path.exists(navmesh):
85
+ navmesh = ""
86
+ relpath = os.path.relpath(root, base_path)
87
+ output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name))
88
+ data = SceneData(scene_dataset_config_file="",
89
+ scene = scene,
90
+ navmesh = navmesh,
91
+ output_dir = output_dir)
92
+ folder_scenes_data.append(data)
93
+
94
+ # Specific check for HM3D:
95
+        # When two meshes xxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
96
+ basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")]
97
+ if len(basis_scenes) != 0:
98
+ folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)]
99
+
100
+ scenes_data.extend(folder_scenes_data)
101
+ return scenes_data
102
+
103
+ def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
104
+ scenes_data = []
105
+
106
+ # HM3D
107
+ for split in ("minival", "train", "val", "examples"):
108
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
109
+ base_path=f"{scenes_dataset_paths['hm3d']}/{split}")
110
+
111
+ # Gibson
112
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"),
113
+ base_path=scenes_dataset_paths["gibson"])
114
+
115
+ # Habitat test scenes (just a few)
116
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
117
+ base_path=scenes_dataset_paths["habitat-test-scenes"])
118
+
119
+ # ReplicaCAD (baked lightning)
120
+ scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir)
121
+
122
+ # ScanNet
123
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"),
124
+ base_path=scenes_dataset_paths["scannet"])
125
+
126
+ # Replica
127
+    scenes_data += list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"),
128
+ base_path=scenes_dataset_paths["replica"])
129
+ return scenes_data
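Assuming the dataset folders referenced in SCENES_DATASET have been downloaded to the hardcoded ./data/... locations (the Replica listing in particular reads its directory), a quick way to see what the generation scripts can find is a sketch like:

```python
# Assumption: this file is importable as 'paths' from the current folder.
from paths import list_scenes_available

# base_output_dir is only where rendered views will later be written; any folder works.
scenes_data = list_scenes_available(base_output_dir="./data/habitat_processed/")
print(f"{len(scenes_data)} scenes found")
for scene_data in scenes_data[:3]:
    print(scene_data.scene, "->", scene_data.output_dir)
```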
dust3r/croco/datasets_croco/pairs_dataset.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+
8
+ from datasets.transforms import get_pair_transforms
9
+
10
+ def load_image(impath):
11
+ return Image.open(impath)
12
+
13
+ def load_pairs_from_cache_file(fname, root=''):
14
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
15
+ with open(fname, 'r') as fid:
16
+ lines = fid.read().strip().splitlines()
17
+ pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines]
18
+ return pairs
19
+
20
+ def load_pairs_from_list_file(fname, root=''):
21
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
22
+ with open(fname, 'r') as fid:
23
+ lines = fid.read().strip().splitlines()
24
+ pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')]
25
+ return pairs
26
+
27
+
28
+ def write_cache_file(fname, pairs, root=''):
29
+ if len(root)>0:
30
+ if not root.endswith('/'): root+='/'
31
+ assert os.path.isdir(root)
32
+ s = ''
33
+ for im1, im2 in pairs:
34
+ if len(root)>0:
35
+ assert im1.startswith(root), im1
36
+ assert im2.startswith(root), im2
37
+ s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):])
38
+ with open(fname, 'w') as fid:
39
+ fid.write(s[:-1])
40
+
41
+ def parse_and_cache_all_pairs(dname, data_dir='./data/'):
42
+ if dname=='habitat_release':
43
+ dirname = os.path.join(data_dir, 'habitat_release')
44
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
45
+ cache_file = os.path.join(dirname, 'pairs.txt')
46
+ assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file
47
+
48
+ print('Parsing pairs for dataset: '+dname)
49
+ pairs = []
50
+ for root, dirs, files in os.walk(dirname):
51
+ if 'val' in root: continue
52
+ dirs.sort()
53
+ pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')]
54
+ print('Found {:,} pairs'.format(len(pairs)))
55
+ print('Writing cache to: '+cache_file)
56
+ write_cache_file(cache_file, pairs, root=dirname)
57
+
58
+ else:
59
+ raise NotImplementedError('Unknown dataset: '+dname)
60
+
61
+ def dnames_to_image_pairs(dnames, data_dir='./data/'):
62
+ """
63
+ dnames: list of datasets with image pairs, separated by +
64
+ """
65
+ all_pairs = []
66
+ for dname in dnames.split('+'):
67
+ if dname=='habitat_release':
68
+ dirname = os.path.join(data_dir, 'habitat_release')
69
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
70
+ cache_file = os.path.join(dirname, 'pairs.txt')
71
+ assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file
72
+ pairs = load_pairs_from_cache_file(cache_file, root=dirname)
73
+ elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']:
74
+ dirname = os.path.join(data_dir, dname+'_crops')
75
+ assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
76
+ list_file = os.path.join(dirname, 'listing.txt')
77
+ assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file)
78
+ pairs = load_pairs_from_list_file(list_file, root=dirname)
79
+ print(' {:s}: {:,} pairs'.format(dname, len(pairs)))
80
+ all_pairs += pairs
81
+ if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs)))
82
+ return all_pairs
83
+
84
+
85
+ class PairsDataset(Dataset):
86
+
87
+ def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'):
88
+ super().__init__()
89
+ self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
90
+ self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize)
91
+
92
+ def __len__(self):
93
+ return len(self.image_pairs)
94
+
95
+ def __getitem__(self, index):
96
+ im1path, im2path = self.image_pairs[index]
97
+ im1 = load_image(im1path)
98
+ im2 = load_image(im2path)
99
+ if self.transforms is not None: im1, im2 = self.transforms(im1, im2)
100
+ return im1, im2
101
+
102
+
103
+ if __name__=="__main__":
104
+ import argparse
105
+ parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset")
106
+ parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
107
+ parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset")
108
+ args = parser.parse_args()
109
+ parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
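The habitat_release pairs are consumed in two steps: the listing is cached once by running this file as a script, and the dataset is then built from that cache. A sketch, assuming the data sits under ./data/habitat_release/ and that the package layout expected by the import at the top (datasets.transforms) is on the path:

```python
# One-off caching step, run from the shell:
#   python pairs_dataset.py --data_dir ./data/ --dataset habitat_release
from pairs_dataset import PairsDataset   # adjust to the actual package path

# 'crop224+acolor' are the augmentations understood by get_pair_transforms (transforms.py)
dataset = PairsDataset("habitat_release", trfs="crop224+acolor", data_dir="./data/")
img1, img2 = dataset[0]   # two augmented, normalized tensors of shape (3, 224, 224)
print(len(dataset), img1.shape)
```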
dust3r/croco/datasets_croco/transforms.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+ import torchvision.transforms
6
+ import torchvision.transforms.functional as F
7
+
8
+ # "Pair": apply a transform on a pair
9
+ # "Both": apply the exact same transform to both images
10
+
11
+ class ComposePair(torchvision.transforms.Compose):
12
+ def __call__(self, img1, img2):
13
+ for t in self.transforms:
14
+ img1, img2 = t(img1, img2)
15
+ return img1, img2
16
+
17
+ class NormalizeBoth(torchvision.transforms.Normalize):
18
+ def forward(self, img1, img2):
19
+ img1 = super().forward(img1)
20
+ img2 = super().forward(img2)
21
+ return img1, img2
22
+
23
+ class ToTensorBoth(torchvision.transforms.ToTensor):
24
+ def __call__(self, img1, img2):
25
+ img1 = super().__call__(img1)
26
+ img2 = super().__call__(img2)
27
+ return img1, img2
28
+
29
+ class RandomCropPair(torchvision.transforms.RandomCrop):
30
+ # the crop will be intentionally different for the two images with this class
31
+ def forward(self, img1, img2):
32
+ img1 = super().forward(img1)
33
+ img2 = super().forward(img2)
34
+ return img1, img2
35
+
36
+ class ColorJitterPair(torchvision.transforms.ColorJitter):
37
+    # can be symmetric (same for both images) or asymmetric (different jitter params for each image) depending on assymetric_prob
38
+ def __init__(self, assymetric_prob, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.assymetric_prob = assymetric_prob
41
+ def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor):
42
+ for fn_id in fn_idx:
43
+ if fn_id == 0 and brightness_factor is not None:
44
+ img = F.adjust_brightness(img, brightness_factor)
45
+ elif fn_id == 1 and contrast_factor is not None:
46
+ img = F.adjust_contrast(img, contrast_factor)
47
+ elif fn_id == 2 and saturation_factor is not None:
48
+ img = F.adjust_saturation(img, saturation_factor)
49
+ elif fn_id == 3 and hue_factor is not None:
50
+ img = F.adjust_hue(img, hue_factor)
51
+ return img
52
+
53
+ def forward(self, img1, img2):
54
+
55
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
56
+ self.brightness, self.contrast, self.saturation, self.hue
57
+ )
58
+ img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
59
+ if torch.rand(1) < self.assymetric_prob: # assymetric:
60
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
61
+ self.brightness, self.contrast, self.saturation, self.hue
62
+ )
63
+ img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
64
+ return img1, img2
65
+
66
+ def get_pair_transforms(transform_str, totensor=True, normalize=True):
67
+    # transform_str is e.g. 'crop224+acolor'; supported tokens here are 'cropN' and 'acolor'
68
+ trfs = []
69
+ for s in transform_str.split('+'):
70
+ if s.startswith('crop'):
71
+ size = int(s[len('crop'):])
72
+ trfs.append(RandomCropPair(size))
73
+ elif s=='acolor':
74
+ trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0))
75
+ elif s=='': # if transform_str was ""
76
+ pass
77
+ else:
78
+ raise NotImplementedError('Unknown augmentation: '+s)
79
+
80
+ if totensor:
81
+ trfs.append( ToTensorBoth() )
82
+ if normalize:
83
+ trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) )
84
+
85
+ if len(trfs)==0:
86
+ return None
87
+ elif len(trfs)==1:
88
+        return trfs[0]
89
+ else:
90
+ return ComposePair(trfs)
91
+
92
+
93
+
94
+
95
+
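A short sketch of get_pair_transforms on a pair of synthetic PIL images, assuming this file is importable (adjust the import to the actual package path; pairs_dataset.py imports it as datasets.transforms). 'crop224+acolor' yields independent 224x224 random crops and asymmetric colour jitter, followed by ToTensor and ImageNet normalization:

```python
import numpy as np
from PIL import Image
from transforms import get_pair_transforms   # assumed import path

im1 = Image.fromarray(np.random.randint(0, 255, (256, 320, 3), dtype=np.uint8))
im2 = Image.fromarray(np.random.randint(0, 255, (256, 320, 3), dtype=np.uint8))

pair_trfs = get_pair_transforms("crop224+acolor")
t1, t2 = pair_trfs(im1, im2)
print(t1.shape, t2.shape)   # torch.Size([3, 224, 224]) for both
```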
dust3r/croco/demo.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+ from models.croco import CroCoNet
6
+ from PIL import Image
7
+ import torchvision.transforms
8
+ from torchvision.transforms import ToTensor, Normalize, Compose
9
+
10
+ def main():
11
+ device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu')
12
+
13
+ # load 224x224 images and transform them to tensor
14
+ imagenet_mean = [0.485, 0.456, 0.406]
15
+ imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True)
16
+ imagenet_std = [0.229, 0.224, 0.225]
17
+ imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True)
18
+ trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)])
19
+ image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
20
+ image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
21
+
22
+ # load model
23
+ ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')
24
+ model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device)
25
+ model.eval()
26
+ msg = model.load_state_dict(ckpt['model'], strict=True)
27
+
28
+ # forward
29
+ with torch.inference_mode():
30
+ out, mask, target = model(image1, image2)
31
+
32
+ # the output is normalized, thus use the mean/std of the actual image to go back to RGB space
33
+ patchified = model.patchify(image1)
34
+ mean = patchified.mean(dim=-1, keepdim=True)
35
+ var = patchified.var(dim=-1, keepdim=True)
36
+ decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean)
37
+ # undo imagenet normalization, prepare masked image
38
+ decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor
39
+ input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor
40
+ ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor
41
+ image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])
42
+ masked_input_image = ((1 - image_masks) * input_image)
43
+
44
+ # make visualization
45
+ visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4
46
+ B, C, H, W = visualization.shape
47
+ visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W)
48
+ visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1))
49
+ fname = "demo_output.png"
50
+ visualization.save(fname)
51
+    print('Visualization saved in '+fname)
52
+
53
+
54
+ if __name__=="__main__":
55
+ main()
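Running the demo only requires the two example images under assets/ and the CroCo_V2_ViTLarge_BaseDecoder.pth checkpoint placed in pretrained_models/ (see the accompanying CroCo README for the checkpoint download); `python demo.py` then writes the side-by-side visualization to demo_output.png.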
dust3r/croco/interactive_demo.ipynb ADDED
@@ -0,0 +1,271 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Interactive demo of Cross-view Completion."
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
17
+ "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "import torch\n",
27
+ "import numpy as np\n",
28
+ "from models.croco import CroCoNet\n",
29
+ "from ipywidgets import interact, interactive, fixed, interact_manual\n",
30
+ "import ipywidgets as widgets\n",
31
+ "import matplotlib.pyplot as plt\n",
32
+ "import quaternion\n",
33
+ "import models.masking"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "metadata": {},
39
+ "source": [
40
+ "### Load CroCo model"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n",
50
+ "model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n",
51
+ "msg = model.load_state_dict(ckpt['model'], strict=True)\n",
52
+ "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
53
+ "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
54
+ "model = model.eval()\n",
55
+ "model = model.to(device=device)\n",
56
+ "print(msg)\n",
57
+ "\n",
58
+ "def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n",
59
+ " \"\"\"\n",
60
+ " Perform Cross-View completion using two input images, specified using Numpy arrays.\n",
61
+ " \"\"\"\n",
62
+ " # Replace the mask generator\n",
63
+ " model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n",
64
+ "\n",
65
+ " # ImageNet-1k color normalization\n",
66
+ " imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n",
67
+ " imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n",
68
+ "\n",
69
+ " normalize_input_colors = True\n",
70
+ " is_output_normalized = True\n",
71
+ " with torch.no_grad():\n",
72
+ " # Cast data to torch\n",
73
+ " target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
74
+ " ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
75
+ "\n",
76
+ " if normalize_input_colors:\n",
77
+ " ref_image = (ref_image - imagenet_mean) / imagenet_std\n",
78
+ " target_image = (target_image - imagenet_mean) / imagenet_std\n",
79
+ "\n",
80
+ " out, mask, _ = model(target_image, ref_image)\n",
81
+ " # # get target\n",
82
+ " if not is_output_normalized:\n",
83
+ " predicted_image = model.unpatchify(out)\n",
84
+ " else:\n",
85
+ " # The output only contains higher order information,\n",
86
+ " # we retrieve mean and standard deviation from the actual target image\n",
87
+ " patchified = model.patchify(target_image)\n",
88
+ " mean = patchified.mean(dim=-1, keepdim=True)\n",
89
+ " var = patchified.var(dim=-1, keepdim=True)\n",
90
+ " pred_renorm = out * (var + 1.e-6)**.5 + mean\n",
91
+ " predicted_image = model.unpatchify(pred_renorm)\n",
92
+ "\n",
93
+ " image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n",
94
+ " masked_target_image = (1 - image_masks) * target_image\n",
95
+ " \n",
96
+ " if not reconstruct_unmasked_patches:\n",
97
+ " # Replace unmasked patches by their actual values\n",
98
+ " predicted_image = predicted_image * image_masks + masked_target_image\n",
99
+ "\n",
100
+ " # Unapply color normalization\n",
101
+ " if normalize_input_colors:\n",
102
+ " predicted_image = predicted_image * imagenet_std + imagenet_mean\n",
103
+ " masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n",
104
+ " \n",
105
+ " # Cast to Numpy\n",
106
+ " masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
107
+ " predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
108
+ " return masked_target_image, predicted_image"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "markdown",
113
+ "metadata": {},
114
+ "source": [
115
+ "### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "import os\n",
125
+ "os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n",
126
+ "os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n",
127
+ "import habitat_sim\n",
128
+ "\n",
129
+ "scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n",
130
+ "navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n",
131
+ "\n",
132
+ "sim_cfg = habitat_sim.SimulatorConfiguration()\n",
133
+ "if use_gpu: sim_cfg.gpu_device_id = 0\n",
134
+ "sim_cfg.scene_id = scene\n",
135
+ "sim_cfg.load_semantic_mesh = False\n",
136
+ "rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n",
137
+ "rgb_sensor_spec.uuid = \"color\"\n",
138
+ "rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n",
139
+ "rgb_sensor_spec.resolution = (224,224)\n",
140
+ "rgb_sensor_spec.hfov = 56.56\n",
141
+ "rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n",
142
+ "rgb_sensor_spec.orientation = [0, 0, 0]\n",
143
+ "agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n",
144
+ "\n",
145
+ "\n",
146
+ "cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n",
147
+ "sim = habitat_sim.Simulator(cfg)\n",
148
+ "if navmesh is not None:\n",
149
+ " sim.pathfinder.load_nav_mesh(navmesh)\n",
150
+ "agent = sim.initialize_agent(agent_id=0)\n",
151
+ "\n",
152
+ "def sample_random_viewpoint():\n",
153
+ " \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n",
154
+ " nav_point = sim.pathfinder.get_random_navigable_point()\n",
155
+ " # Sample a random viewpoint height\n",
156
+ " viewpoint_height = np.random.uniform(1.0, 1.6)\n",
157
+ " viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n",
158
+ " viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n",
159
+ " return viewpoint_position, viewpoint_orientation\n",
160
+ "\n",
161
+ "def render_viewpoint(position, orientation):\n",
162
+ " agent_state = habitat_sim.AgentState()\n",
163
+ " agent_state.position = position\n",
164
+ " agent_state.rotation = orientation\n",
165
+ " agent.set_state(agent_state)\n",
166
+ " viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n",
167
+ " image = viewpoint_observations['color'][:,:,:3]\n",
168
+ " image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n",
169
+ " return image"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "markdown",
174
+ "metadata": {},
175
+ "source": [
176
+ "### Sample a random reference view"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "ref_position, ref_orientation = sample_random_viewpoint()\n",
186
+ "ref_image = render_viewpoint(ref_position, ref_orientation)\n",
187
+ "plt.clf()\n",
188
+ "fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n",
189
+ "axes[0,0].imshow(ref_image)\n",
190
+ "for ax in axes.flatten():\n",
191
+ " ax.set_xticks([])\n",
192
+ " ax.set_yticks([])"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {},
198
+ "source": [
199
+ "### Interactive cross-view completion using CroCo"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "reconstruct_unmasked_patches = False\n",
209
+ "\n",
210
+ "def show_demo(masking_ratio, x, y, z, panorama, elevation):\n",
211
+ " R = quaternion.as_rotation_matrix(ref_orientation)\n",
212
+ " target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n",
213
+ " target_orientation = (ref_orientation\n",
214
+ " * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n",
215
+ " * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n",
216
+ " \n",
217
+ " ref_image = render_viewpoint(ref_position, ref_orientation)\n",
218
+ " target_image = render_viewpoint(target_position, target_orientation)\n",
219
+ "\n",
220
+ " masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n",
221
+ "\n",
222
+ " fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n",
223
+ " axes[0].imshow(ref_image)\n",
224
+ " axes[0].set_xlabel(\"Reference\")\n",
225
+ " axes[1].imshow(masked_target_image)\n",
226
+ " axes[1].set_xlabel(\"Masked target\")\n",
227
+ " axes[2].imshow(predicted_image)\n",
228
+ " axes[2].set_xlabel(\"Reconstruction\") \n",
229
+ " axes[3].imshow(target_image)\n",
230
+ " axes[3].set_xlabel(\"Target\")\n",
231
+ " for ax in axes.flatten():\n",
232
+ " ax.set_xticks([])\n",
233
+ " ax.set_yticks([])\n",
234
+ "\n",
235
+ "interact(show_demo,\n",
236
+ " masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n",
237
+ " x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
238
+ " y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
239
+ " z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
240
+ " panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n",
241
+ " elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));"
242
+ ]
243
+ }
244
+ ],
245
+ "metadata": {
246
+ "kernelspec": {
247
+ "display_name": "Python 3 (ipykernel)",
248
+ "language": "python",
249
+ "name": "python3"
250
+ },
251
+ "language_info": {
252
+ "codemirror_mode": {
253
+ "name": "ipython",
254
+ "version": 3
255
+ },
256
+ "file_extension": ".py",
257
+ "mimetype": "text/x-python",
258
+ "name": "python",
259
+ "nbconvert_exporter": "python",
260
+ "pygments_lexer": "ipython3",
261
+ "version": "3.7.13"
262
+ },
263
+ "vscode": {
264
+ "interpreter": {
265
+ "hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67"
266
+ }
267
+ }
268
+ },
269
+ "nbformat": 4,
270
+ "nbformat_minor": 2
271
+ }