Spaces:
Running on Zero
Running on Zero
Initial SAB3R demo release
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +11 -0
- .gitignore +148 -0
- LICENSE +7 -0
- README.md +185 -6
- app.py +18 -0
- assets/network_architecture.png +3 -0
- assets/qualitative_2.jpg +3 -0
- assets/teaser.jpg +3 -0
- assets/teaser_v5.jpg +3 -0
- config/deepspeed.json +38 -0
- config/training_config.yaml +55 -0
- config/training_config_full.yaml +57 -0
- demo.py +118 -0
- docker/docker-compose-cpu.yml +16 -0
- docker/docker-compose-cuda.yml +23 -0
- docker/files/cpu.Dockerfile +39 -0
- docker/files/cuda.Dockerfile +29 -0
- docker/files/entrypoint.sh +8 -0
- docker/run.sh +68 -0
- dust3r/.gitignore +132 -0
- dust3r/LICENSE +7 -0
- dust3r/NOTICE +12 -0
- dust3r/README.md +390 -0
- dust3r/assets/demo.jpg +3 -0
- dust3r/assets/dust3r.jpg +0 -0
- dust3r/assets/dust3r_archi.jpg +0 -0
- dust3r/assets/matching.jpg +3 -0
- dust3r/assets/pipeline1.jpg +0 -0
- dust3r/croco/LICENSE +52 -0
- dust3r/croco/NOTICE +21 -0
- dust3r/croco/README.MD +124 -0
- dust3r/croco/assets/Chateau1.png +3 -0
- dust3r/croco/assets/Chateau2.png +3 -0
- dust3r/croco/assets/arch.jpg +0 -0
- dust3r/croco/croco-stereo-flow-demo.ipynb +191 -0
- dust3r/croco/datasets_croco/__init__.py +0 -0
- dust3r/croco/datasets_croco/crops/README.MD +104 -0
- dust3r/croco/datasets_croco/crops/extract_crops_from_images.py +159 -0
- dust3r/croco/datasets_croco/habitat_sim/README.MD +76 -0
- dust3r/croco/datasets_croco/habitat_sim/__init__.py +0 -0
- dust3r/croco/datasets_croco/habitat_sim/generate_from_metadata.py +92 -0
- dust3r/croco/datasets_croco/habitat_sim/generate_from_metadata_files.py +27 -0
- dust3r/croco/datasets_croco/habitat_sim/generate_multiview_images.py +177 -0
- dust3r/croco/datasets_croco/habitat_sim/multiview_habitat_sim_generator.py +390 -0
- dust3r/croco/datasets_croco/habitat_sim/pack_metadata_files.py +69 -0
- dust3r/croco/datasets_croco/habitat_sim/paths.py +129 -0
- dust3r/croco/datasets_croco/pairs_dataset.py +109 -0
- dust3r/croco/datasets_croco/transforms.py +95 -0
- dust3r/croco/demo.py +55 -0
- dust3r/croco/interactive_demo.ipynb +271 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/network_architecture.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/qualitative_2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/teaser.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/teaser_v5.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
dust3r/assets/demo.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
dust3r/assets/matching.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
dust3r/croco/assets/Chateau1.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
dust3r/croco/assets/Chateau2.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-311/.ninja_deps filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-311/curope.o filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-311/kernels.o filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*gl-outputs/
|
| 2 |
+
# Byte-compiled / optimized / DLL files
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.py[cod]
|
| 5 |
+
*$py.class
|
| 6 |
+
launch_script/output/
|
| 7 |
+
preprocess_data/
|
| 8 |
+
images
|
| 9 |
+
launch_script/
|
| 10 |
+
paper/
|
| 11 |
+
.gradio/
|
| 12 |
+
semantic_extraction_ade20k*
|
| 13 |
+
semantic_extraction_voc*
|
| 14 |
+
semantic_seg_ade20k_clip*
|
| 15 |
+
zero_shot_semantic_seg_ade20k_inference*
|
| 16 |
+
# C extensions
|
| 17 |
+
*.so
|
| 18 |
+
|
| 19 |
+
# Distribution / packaging
|
| 20 |
+
.Python
|
| 21 |
+
build/
|
| 22 |
+
develop-eggs/
|
| 23 |
+
dist/
|
| 24 |
+
downloads/
|
| 25 |
+
eggs/
|
| 26 |
+
.eggs/
|
| 27 |
+
lib/
|
| 28 |
+
lib64/
|
| 29 |
+
parts/
|
| 30 |
+
sdist/
|
| 31 |
+
var/
|
| 32 |
+
wheels/
|
| 33 |
+
pip-wheel-metadata/
|
| 34 |
+
share/python-wheels/
|
| 35 |
+
*.egg-info/
|
| 36 |
+
.installed.cfg
|
| 37 |
+
*.egg
|
| 38 |
+
MANIFEST
|
| 39 |
+
|
| 40 |
+
# PyInstaller
|
| 41 |
+
# Usually these files are written by a python script from a template
|
| 42 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 43 |
+
*.manifest
|
| 44 |
+
*.spec
|
| 45 |
+
|
| 46 |
+
# Installer logs
|
| 47 |
+
pip-log.txt
|
| 48 |
+
pip-delete-this-directory.txt
|
| 49 |
+
|
| 50 |
+
# Unit test / coverage reports
|
| 51 |
+
htmlcov/
|
| 52 |
+
.tox/
|
| 53 |
+
.nox/
|
| 54 |
+
.coverage
|
| 55 |
+
.coverage.*
|
| 56 |
+
.cache
|
| 57 |
+
nosetests.xml
|
| 58 |
+
coverage.xml
|
| 59 |
+
*.cover
|
| 60 |
+
*.py,cover
|
| 61 |
+
.hypothesis/
|
| 62 |
+
.pytest_cache/
|
| 63 |
+
|
| 64 |
+
# Translations
|
| 65 |
+
*.mo
|
| 66 |
+
*.pot
|
| 67 |
+
|
| 68 |
+
# Django stuff:
|
| 69 |
+
*.log
|
| 70 |
+
local_settings.py
|
| 71 |
+
db.sqlite3
|
| 72 |
+
db.sqlite3-journal
|
| 73 |
+
|
| 74 |
+
# Flask stuff:
|
| 75 |
+
instance/
|
| 76 |
+
.webassets-cache
|
| 77 |
+
|
| 78 |
+
# Scrapy stuff:
|
| 79 |
+
.scrapy
|
| 80 |
+
|
| 81 |
+
# Sphinx documentation
|
| 82 |
+
docs/_build/
|
| 83 |
+
|
| 84 |
+
# PyBuilder
|
| 85 |
+
target/
|
| 86 |
+
|
| 87 |
+
# Jupyter Notebook
|
| 88 |
+
.ipynb_checkpoints
|
| 89 |
+
|
| 90 |
+
# IPython
|
| 91 |
+
profile_default/
|
| 92 |
+
ipython_config.py
|
| 93 |
+
|
| 94 |
+
# pyenv
|
| 95 |
+
.python-version
|
| 96 |
+
|
| 97 |
+
# pipenv
|
| 98 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 99 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 100 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 101 |
+
# install all needed dependencies.
|
| 102 |
+
#Pipfile.lock
|
| 103 |
+
|
| 104 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 105 |
+
__pypackages__/
|
| 106 |
+
|
| 107 |
+
# Celery stuff
|
| 108 |
+
celerybeat-schedule
|
| 109 |
+
celerybeat.pid
|
| 110 |
+
|
| 111 |
+
# SageMath parsed files
|
| 112 |
+
*.sage.py
|
| 113 |
+
|
| 114 |
+
# Environments
|
| 115 |
+
.env
|
| 116 |
+
.venv
|
| 117 |
+
env/
|
| 118 |
+
venv/
|
| 119 |
+
ENV/
|
| 120 |
+
env.bak/
|
| 121 |
+
venv.bak/
|
| 122 |
+
|
| 123 |
+
# Spyder project settings
|
| 124 |
+
.spyderproject
|
| 125 |
+
.spyproject
|
| 126 |
+
|
| 127 |
+
# Rope project settings
|
| 128 |
+
.ropeproject
|
| 129 |
+
|
| 130 |
+
# mkdocs documentation
|
| 131 |
+
/site
|
| 132 |
+
|
| 133 |
+
# mypy
|
| 134 |
+
.mypy_cache/
|
| 135 |
+
.dmypy.json
|
| 136 |
+
dmypy.json
|
| 137 |
+
|
| 138 |
+
# Pyre type checker
|
| 139 |
+
.pyre/
|
| 140 |
+
|
| 141 |
+
# data
|
| 142 |
+
data/
|
| 143 |
+
# checkpoints
|
| 144 |
+
checkpoints/
|
| 145 |
+
# wandb
|
| 146 |
+
wandb/
|
| 147 |
+
# outputs
|
| 148 |
+
outputs/
|
LICENSE
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
|
| 2 |
+
|
| 3 |
+
A summary of the CC BY-NC-SA 4.0 license is located here:
|
| 4 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 5 |
+
|
| 6 |
+
The CC BY-NC-SA 4.0 license is located here:
|
| 7 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
README.md
CHANGED
|
@@ -1,13 +1,192 @@
|
|
| 1 |
---
|
| 2 |
title: SAB3R
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
python_version: '3.12'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: SAB3R
|
| 3 |
+
emoji: 🌐
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.1
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: cc-by-nc-sa-4.0
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# SAB3R: Semantic-Augmented Backbone in 3D Reconstruction
|
| 14 |
+
|
| 15 |
+
<div align="center">
|
| 16 |
+
|
| 17 |
+
**3D-LLM/VLA Workshop @ CVPR 2025**
|
| 18 |
+
|
| 19 |
+
[**Xuweiyi Chen**](https://xuweiyichen.github.io/)<sup>*,1</sup> · [**Tian Xia**](https://tianx-ia.github.io/)<sup>*,2</sup> · [**Sihan Xu**](https://sihanxu.github.io/)<sup>2</sup> · [**Jed Jianing Yang**](https://jedyang.com/)<sup>2</sup> · [**Joyce Chai**](https://web.eecs.umich.edu/~chaijy/)<sup>2</sup> · [**Zezhou Cheng**](https://sites.google.com/site/zezhoucheng/)<sup>1</sup>
|
| 20 |
+
|
| 21 |
+
<sup>1</sup>University of Virginia · <sup>2</sup>University of Michigan
|
| 22 |
+
|
| 23 |
+
<sup>*</sup>Denotes Equal Contribution
|
| 24 |
+
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
[](https://www.arxiv.org/abs/2506.02112)
|
| 28 |
+
[](https://uva-computer-vision-lab.github.io/sab3r/)
|
| 29 |
+
[](https://huggingface.co/spaces/uva-cv-lab/SAB3R)
|
| 30 |
+
[](#)
|
| 31 |
+
[](https://github.com/uva-computer-vision-lab/sab3r)
|
| 32 |
+
|
| 33 |
+
</div>
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+

|
| 38 |
+
|
| 39 |
+
*Given an unposed input video, we show ground truth for: open-vocabulary semantic segmentation (per-pixel labels for the prompt "a black office chair"), 3D reconstruction (ground-truth point cloud), and the proposed **Map and Locate** task (open-vocabulary segmentation and point cloud). The Map and Locate task: (1) encompasses both 2D and 3D tasks, (2) bridges reconstruction and recognition, and (3) introduces practical questions in robotics and embodied AI.*
|
| 40 |
+
|
| 41 |
+
## Release Plan
|
| 42 |
+
|
| 43 |
+
- [x] Demo Release
|
| 44 |
+
- [x] Training and Inference Code Release
|
| 45 |
+
- [ ] Release Map and Locate Dataset
|
| 46 |
+
|
| 47 |
+
## Abstract
|
| 48 |
+
|
| 49 |
+
We introduce **Map and Locate**, a task that unifies open-vocabulary segmentation and 3D reconstruction from unposed videos. Our method, **SAB3R**, builds upon MASt3R and incorporates lightweight distillation from CLIP and DINOv2 to generate semantic point clouds in a single forward pass. SAB3R achieves superior performance compared to separate deployment of MASt3R and CLIP on the Map and Locate benchmark.
|
| 50 |
+
|
| 51 |
+
## Network Architecture
|
| 52 |
+
|
| 53 |
+

|
| 54 |
+
|
| 55 |
+
**SAB3R** distills dense features from CLIP and DINO into the MASt3R framework, enriching it with 2D semantic understanding. Each encoder-decoder pair operates on multi-view images, sharing weights and exchanging information to ensure consistent feature extraction across views. The model simultaneously generates depth, dense DINOv2, and dense CLIP features, which are then used for multi-view 3D reconstruction and semantic segmentation. This architecture enables SAB3R to seamlessly integrate 2D and 3D representations, achieving both geometric and semantic comprehension in a unified model.
|
| 56 |
+
|
| 57 |
+
## Installation
|
| 58 |
+
|
| 59 |
+
1. Clone the repository:
|
| 60 |
+
```bash
|
| 61 |
+
git clone --recursive https://github.com/uva-computer-vision-lab/sab3r
|
| 62 |
+
cd sab3r
|
| 63 |
+
# if you have already cloned:
|
| 64 |
+
# git submodule update --init --recursive
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
2. Create the environment:
|
| 68 |
+
```bash
|
| 69 |
+
conda create -n sab3r python=3.11 cmake=3.14.0
|
| 70 |
+
conda activate sab3r
|
| 71 |
+
conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia
|
| 72 |
+
pip install -r requirements.txt
|
| 73 |
+
|
| 74 |
+
# FeatUp (not on PyPI) — required for the CLIP/DINO semantic heads.
|
| 75 |
+
pip install git+https://github.com/mhamilton723/FeatUp
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
3. (Optional) Compile RoPE CUDA kernels for faster inference:
|
| 79 |
+
```bash
|
| 80 |
+
cd dust3r/croco/models/curope/
|
| 81 |
+
python setup.py build_ext --inplace
|
| 82 |
+
cd ../../../../
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
4. (Optional) Pre-download the CLIP BPE vocab (the demo will fetch it on first run):
|
| 86 |
+
```bash
|
| 87 |
+
mkdir -p ~/.cache/clip
|
| 88 |
+
cd ~/.cache/clip
|
| 89 |
+
wget https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
## Demo
|
| 93 |
+
|
| 94 |
+
The demo launches a Gradio UI for 3D reconstruction + open-vocabulary text queries.
|
| 95 |
+
|
| 96 |
+
**Checkpoint from HF Hub (default)**
|
| 97 |
+
```bash
|
| 98 |
+
python demo.py \
|
| 99 |
+
--model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
|
| 100 |
+
--local_network --share
|
| 101 |
+
```
|
| 102 |
+
This downloads `demo_ckpt/base/base.pt` from [`uva-cv-lab/SAB3R`](https://huggingface.co/uva-cv-lab/SAB3R) on first launch and caches it in `~/.cache/huggingface/`.
|
| 103 |
+
|
| 104 |
+
**Local checkpoint**
|
| 105 |
+
```bash
|
| 106 |
+
python demo.py \
|
| 107 |
+
--model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
|
| 108 |
+
--weights /path/to/your.pt \
|
| 109 |
+
--local_network --share
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
**Override the HF Hub repo / filename**
|
| 113 |
+
```bash
|
| 114 |
+
python demo.py \
|
| 115 |
+
--model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
|
| 116 |
+
--model_repo your-org/your-sab3r-ckpt \
|
| 117 |
+
--ckpt_filename model.pt
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
**Local dev with a checkpoint dropdown** — if you keep multiple checkpoints under a directory (one sub-directory per checkpoint, each holding `<name>.pt`), pass `--checkpoint_dir`:
|
| 121 |
+
```bash
|
| 122 |
+
python demo.py --checkpoint_dir /path/to/ckpt_root --local_network --share
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
## Hugging Face Spaces
|
| 126 |
+
|
| 127 |
+
This repo ships with an `app.py` entry point for Hugging Face Spaces. To deploy:
|
| 128 |
+
|
| 129 |
+
1. Create a Gradio Space at https://huggingface.co/new-space (SDK: `gradio`).
|
| 130 |
+
2. Upload the SAB3R checkpoint to a model repo (default expected: `uva-cv-lab/SAB3R`, filename `base.pt`). To use a different repo, set the `SAB3R_MODEL_REPO` / `SAB3R_CKPT_FILENAME` env vars on the Space, or edit `demo.py`'s defaults.
|
| 131 |
+
3. Push this repo (minus the `env/` conda env and the heavyweight submodule binaries) to the Space. The Space's `README.md` must begin with the YAML frontmatter below; add it to your Space copy of this README before pushing:
|
| 132 |
+
|
| 133 |
+
```yaml
|
| 134 |
+
---
|
| 135 |
+
title: SAB3R
|
| 136 |
+
emoji: 🌐
|
| 137 |
+
colorFrom: blue
|
| 138 |
+
colorTo: green
|
| 139 |
+
sdk: gradio
|
| 140 |
+
sdk_version: 4.44.1
|
| 141 |
+
app_file: app.py
|
| 142 |
+
pinned: false
|
| 143 |
+
license: cc-by-nc-sa-4.0
|
| 144 |
+
---
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
FeatUp needs to be installed from GitHub, which Spaces handles via `requirements.txt` — the one in this repo already leaves a comment pointing at the install line; uncomment it for the Space:
|
| 148 |
+
|
| 149 |
+
```txt
|
| 150 |
+
git+https://github.com/mhamilton723/FeatUp
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## Training
|
| 154 |
+
|
| 155 |
+
Two canonical configs are provided under `config/`:
|
| 156 |
+
|
| 157 |
+
- `training_config.yaml` — minimal dev recipe (CLIP distillation on a Co3D subset).
|
| 158 |
+
- `training_config_full.yaml` — full paper recipe (CLIP + DINO distillation on Habitat + ScanNet++ + ARKitScenes + Co3D).
|
| 159 |
+
|
| 160 |
+
Both reference paths relative to the repo root (e.g. `./data`, `./checkpoints`, `./outputs`); override them via Hydra:
|
| 161 |
+
|
| 162 |
+
```bash
|
| 163 |
+
torchrun --nproc_per_node=8 train.py \
|
| 164 |
+
--config-name training_config_full \
|
| 165 |
+
dataset_url=/path/to/data \
|
| 166 |
+
output_url=/path/to/outputs
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
Set `WANDB_API_KEY` in your shell (do **not** commit it) if you want experiment tracking.
|
| 170 |
+
|
| 171 |
+
## Citation
|
| 172 |
+
|
| 173 |
+
```bibtex
|
| 174 |
+
@article{chen2025sab3rsemanticaugmentedbackbone3d,
|
| 175 |
+
title={SAB3R: Semantic-Augmented Backbone in 3D Reconstruction},
|
| 176 |
+
author={Xuweiyi Chen and Tian Xia and Sihan Xu and Jianing Yang and Joyce Chai and Zezhou Cheng},
|
| 177 |
+
year={2025},
|
| 178 |
+
eprint={2506.02112},
|
| 179 |
+
archivePrefix={arXiv},
|
| 180 |
+
primaryClass={cs.CV},
|
| 181 |
+
url={https://arxiv.org/abs/2506.02112},
|
| 182 |
+
}
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
## Acknowledgments
|
| 186 |
+
|
| 187 |
+
This work builds upon [MASt3R](https://github.com/naver/mast3r), [DUSt3R](https://github.com/naver/dust3r) and [FeatUp](https://github.com/mhamilton723/FeatUp). We thank the original authors for their excellent work and open-source contributions.
|
| 188 |
+
|
| 189 |
+
## License
|
| 190 |
+
|
| 191 |
+
The code is distributed under the CC BY-NC-SA 4.0 License.
|
| 192 |
+
See [LICENSE](LICENSE) for more information.
|
app.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# --------------------------------------------------------
|
| 3 |
+
# Hugging Face Spaces entry point for the SAB3R demo.
|
| 4 |
+
#
|
| 5 |
+
# Spaces looks for `app.py` by default. This wrapper just sets
|
| 6 |
+
# Spaces-appropriate defaults and delegates to demo.py. For local
|
| 7 |
+
# dev, run demo.py directly (it exposes --share, --local_network, etc).
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
from demo import main
|
| 12 |
+
|
| 13 |
+
if __name__ == "__main__":
|
| 14 |
+
main([
|
| 15 |
+
"--model_name", "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric",
|
| 16 |
+
"--server_name", "0.0.0.0",
|
| 17 |
+
"--server_port", os.environ.get("GRADIO_SERVER_PORT", "7860"),
|
| 18 |
+
])
|
assets/network_architecture.png
ADDED
|
Git LFS Details
|
assets/qualitative_2.jpg
ADDED
|
Git LFS Details
|
assets/teaser.jpg
ADDED
|
Git LFS Details
|
assets/teaser_v5.jpg
ADDED
|
Git LFS Details
|
config/deepspeed.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 3 |
+
"gradient_accumulation_steps": 8,
|
| 4 |
+
"steps_per_print": 1,
|
| 5 |
+
"optimizer": {
|
| 6 |
+
"type": "AdamW",
|
| 7 |
+
"params": {
|
| 8 |
+
"lr": 1e-4,
|
| 9 |
+
"betas": [0.9, 0.95],
|
| 10 |
+
"weight_decay": 0.05
|
| 11 |
+
}
|
| 12 |
+
},
|
| 13 |
+
"scheduler": {
|
| 14 |
+
"type": "WarmupLR",
|
| 15 |
+
"params": {
|
| 16 |
+
"warmup_min_lr": 0.0,
|
| 17 |
+
"warmup_max_lr": 1e-4,
|
| 18 |
+
"warmup_num_steps": 4000
|
| 19 |
+
}
|
| 20 |
+
},
|
| 21 |
+
"gradient_clipping": 1.0,
|
| 22 |
+
"zero_optimization": {
|
| 23 |
+
"stage": 2,
|
| 24 |
+
"offload_optimizer": {
|
| 25 |
+
"device": "cpu",
|
| 26 |
+
"pin_memory": true
|
| 27 |
+
},
|
| 28 |
+
"overlap_comm": true,
|
| 29 |
+
"contiguous_gradients": true,
|
| 30 |
+
"reduce_bucket_size": 5e7
|
| 31 |
+
},
|
| 32 |
+
"activation_checkpointing": {
|
| 33 |
+
"partition_activations": true,
|
| 34 |
+
"contiguous_memory_optimization": true,
|
| 35 |
+
"cpu_checkpointing": true,
|
| 36 |
+
"number_checkpoints": 2
|
| 37 |
+
}
|
| 38 |
+
}
|
config/training_config.yaml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed: 8
|
| 2 |
+
|
| 3 |
+
# Override these paths via hydra (e.g. `python train.py dataset_url=/path/to/data`)
|
| 4 |
+
# or by editing this file.
|
| 5 |
+
dataset_url: "./data"
|
| 6 |
+
code_url: "."
|
| 7 |
+
output_url: "./outputs"
|
| 8 |
+
|
| 9 |
+
model:
|
| 10 |
+
name: "AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', clip_head_type='dpt', dino_head_type=None, depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True)"
|
| 11 |
+
pretrained: "checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
clip_checkpoint_weights_path: "checkpoints/maskclip_jbu_stack_cocostuff.ckpt"
|
| 15 |
+
dino_checkpoint_weights_path: "checkpoints/dinov2_jbu_stack_cocostuff.ckpt"
|
| 16 |
+
|
| 17 |
+
train_criterion: "ConfFeatLoss(Regr3D(L21, norm_mode='?avg_dis'), alpha=0.2, beta=0.4, gamma=0.4, need_clip=True, need_dino=False, clip_checkpoint_weights_path='${clip_checkpoint_weights_path}', dino_checkpoint_weights_path='${dino_checkpoint_weights_path}') + 0.075*ConfMatchingLoss(MatchingLoss(InfoNCE(mode='proper', temperature=0.05), negatives_padding=0, blocksize=8192), alpha=10.0, confmode='mean')"
|
| 18 |
+
test_criterion: "FeatRegr3D_ScaleShiftInv(L21, norm_mode='?avg_dis', need_clip=True, need_dino=False, gt_scale=True, sky_loss_value=0, clip_checkpoint_weights_path='${clip_checkpoint_weights_path}', dino_checkpoint_weights_path='${dino_checkpoint_weights_path}') + -1.*MatchingLoss(APLoss(nq='torch', fp=torch.float16), negatives_padding=12288)"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
dataset:
|
| 22 |
+
train: "1000 @ Co3d(split='train', ROOT='${dataset_url}/co3d_subset_processed', aug_crop='auto', aug_monocular=0.005, aug_rot90='diff', mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], n_corres=8192, nneg=0.5, transform=ColorJitter)"
|
| 23 |
+
test: "100 @ Co3d(split='test', ROOT='${dataset_url}/co3d_subset_processed', resolution=(512,384), n_corres=1024, seed=777)"
|
| 24 |
+
|
| 25 |
+
training:
|
| 26 |
+
epochs: 200
|
| 27 |
+
disable_cudnn_benchmark: true
|
| 28 |
+
print_freq: 1
|
| 29 |
+
|
| 30 |
+
saving:
|
| 31 |
+
save_freq: 10
|
| 32 |
+
keep_freq: 40
|
| 33 |
+
eval_freq: 10
|
| 34 |
+
|
| 35 |
+
output_dir: "${output_url}/sab3r"
|
| 36 |
+
|
| 37 |
+
wandb:
|
| 38 |
+
project_name: "sab3r"
|
| 39 |
+
entity: "uva-computer-vision-lab"
|
| 40 |
+
group: "sab3r"
|
| 41 |
+
|
| 42 |
+
num_workers: 8
|
| 43 |
+
dist_url: "env://"
|
| 44 |
+
distributed: True
|
| 45 |
+
|
| 46 |
+
world_size: -1
|
| 47 |
+
gpu: -1
|
| 48 |
+
rank: -1
|
| 49 |
+
dist_backend: "nccl"
|
| 50 |
+
|
| 51 |
+
resume: ""
|
| 52 |
+
start_epoch: 0
|
| 53 |
+
disable_cudnn_benchmark: True
|
| 54 |
+
amp: 1
|
| 55 |
+
deepspeed_config: "${code_url}/config/deepspeed.json"
|
config/training_config_full.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed: 8
|
| 2 |
+
|
| 3 |
+
# Full SAB3R training recipe (CLIP + DINO distillation on Habitat + ScanNet++ + ARKitScenes + Co3D).
|
| 4 |
+
# Override these paths via hydra (e.g. `python train.py --config-name training_config_full dataset_url=/path/to/data`)
|
| 5 |
+
# or by editing this file.
|
| 6 |
+
dataset_url: "./data"
|
| 7 |
+
code_url: "."
|
| 8 |
+
output_url: "./outputs"
|
| 9 |
+
|
| 10 |
+
model:
|
| 11 |
+
name: "AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', clip_head_type='dpt', dino_head_type='dpt', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True)"
|
| 12 |
+
pretrained: "checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
clip_checkpoint_weights_path: "checkpoints/maskclip_jbu_stack_cocostuff.ckpt"
|
| 16 |
+
dino_checkpoint_weights_path: "checkpoints/dinov2_jbu_stack_cocostuff.ckpt"
|
| 17 |
+
|
| 18 |
+
train_criterion: "ConfFeatLoss(Regr3D(L21, norm_mode='?avg_dis'), alpha=1, beta=20, gamma=4, need_clip=True, need_dino=True, need_mask=False, clip_checkpoint_weights_path='${clip_checkpoint_weights_path}', dino_checkpoint_weights_path='${dino_checkpoint_weights_path}') + 0.75*ConfMatchingLoss(MatchingLoss(InfoNCE(mode='proper', temperature=0.05), negatives_padding=0, blocksize=8192), alpha=10.0, confmode='mean')"
|
| 19 |
+
test_criterion: "FeatRegr3D_ScaleShiftInv(L21, norm_mode='?avg_dis', need_clip=True, need_dino=True, gt_scale=True, sky_loss_value=0, clip_checkpoint_weights_path='${clip_checkpoint_weights_path}', dino_checkpoint_weights_path='${dino_checkpoint_weights_path}') + -1.*MatchingLoss(APLoss(nq='torch', fp=torch.float16), negatives_padding=12288)"
|
| 20 |
+
|
| 21 |
+
dataset:
|
| 22 |
+
train: "57_000 @ Habitat512(1_000_000, split='train', ROOT='${dataset_url}/habitat_processed', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], aug_crop='auto', aug_monocular=0.005, transform=ColorJitter, n_corres=8192, nneg=0.5) + 45_600 @ ScanNetpp(split='train', ROOT='${dataset_url}/scannetpp_processed', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], aug_crop='auto', aug_monocular=0.005, transform=ColorJitter, n_corres=8192, nneg=0.5) + 45_600 @ ARKitScenes(split='train', ROOT='${dataset_url}/arkitscenes_processed', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], aug_crop='auto', aug_monocular=0.005, transform=ColorJitter, n_corres=8192, nneg=0.5) + 22_800 @ Co3d(split='train', ROOT='${dataset_url}/co3d_50_seqs_per_category_subset_processed', aug_crop='auto', aug_monocular=0.005, aug_rot90='diff', mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], n_corres=8192, nneg=0.5, transform=ColorJitter)"
|
| 23 |
+
test: "100 @ Co3d(split='test', ROOT='${dataset_url}/co3d_50_seqs_per_category_subset_processed', resolution=(512,384), n_corres=1024, seed=777)"
|
| 24 |
+
|
| 25 |
+
training:
|
| 26 |
+
epochs: 5
|
| 27 |
+
disable_cudnn_benchmark: true
|
| 28 |
+
print_freq: 1
|
| 29 |
+
|
| 30 |
+
saving:
|
| 31 |
+
save_freq: 1
|
| 32 |
+
keep_freq: 1
|
| 33 |
+
eval_freq: 1
|
| 34 |
+
save_steps: 5000
|
| 35 |
+
plot_steps: 200
|
| 36 |
+
|
| 37 |
+
output_dir: "${output_url}/sab3r_full"
|
| 38 |
+
|
| 39 |
+
wandb:
|
| 40 |
+
project_name: "sab3r"
|
| 41 |
+
entity: "uva-computer-vision-lab"
|
| 42 |
+
group: "sab3r_full"
|
| 43 |
+
|
| 44 |
+
num_workers: 8
|
| 45 |
+
dist_url: "env://"
|
| 46 |
+
distributed: True
|
| 47 |
+
|
| 48 |
+
world_size: -1
|
| 49 |
+
gpu: -1
|
| 50 |
+
rank: -1
|
| 51 |
+
dist_backend: "nccl"
|
| 52 |
+
|
| 53 |
+
resume: ""
|
| 54 |
+
start_epoch: 0
|
| 55 |
+
disable_cudnn_benchmark: True
|
| 56 |
+
amp: 0
|
| 57 |
+
deepspeed_config: "${code_url}/config/deepspeed.json"
|
demo.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 3 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 4 |
+
#
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# SAB3R gradio demo executable
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
import os
|
| 9 |
+
import argparse
|
| 10 |
+
import tempfile
|
| 11 |
+
from contextlib import nullcontext
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from huggingface_hub import hf_hub_download
|
| 15 |
+
|
| 16 |
+
from mast3r.demo import get_args_parser as sab3r_get_args_parser, main_demo
|
| 17 |
+
from mast3r.model import AsymmetricMASt3R # noqa: F401 (referenced via eval() below)
|
| 18 |
+
from mast3r.utils.misc import hash_md5
|
| 19 |
+
|
| 20 |
+
import mast3r.utils.path_to_dust3r # noqa: F401 (side-effect: puts vendored dust3r on sys.path)
|
| 21 |
+
from dust3r.demo import set_print_with_timestamp
|
| 22 |
+
|
| 23 |
+
import matplotlib.pyplot as pl
|
| 24 |
+
pl.ion()
|
| 25 |
+
|
| 26 |
+
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
|
| 27 |
+
|
| 28 |
+
inf = float("inf")
|
| 29 |
+
|
| 30 |
+
DEFAULT_MODEL_REPO = "uva-cv-lab/SAB3R"
|
| 31 |
+
DEFAULT_CKPT_FILENAME = "demo_ckpt/base/base.pt"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_args_parser():
|
| 35 |
+
parser = sab3r_get_args_parser()
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--model_repo",
|
| 38 |
+
default=os.environ.get("SAB3R_MODEL_REPO", DEFAULT_MODEL_REPO),
|
| 39 |
+
help="Hugging Face Hub repo id hosting the SAB3R checkpoint "
|
| 40 |
+
"(used only when --weights is not provided).",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--ckpt_filename",
|
| 44 |
+
default=os.environ.get("SAB3R_CKPT_FILENAME", DEFAULT_CKPT_FILENAME),
|
| 45 |
+
help="Checkpoint filename inside --model_repo.",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--checkpoint_dir",
|
| 49 |
+
default=os.environ.get("SAB3R_CHECKPOINT_DIR", None),
|
| 50 |
+
help="Optional local directory containing one sub-directory per "
|
| 51 |
+
"checkpoint (each sub-dir must hold `<name>.pt`). When provided, "
|
| 52 |
+
"the UI exposes a dropdown to switch between them. Useful for "
|
| 53 |
+
"local dev; leave unset for single-checkpoint HF Spaces deployments.",
|
| 54 |
+
)
|
| 55 |
+
return parser
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_weights(model, ckp_path, device):
|
| 59 |
+
ckp = torch.load(ckp_path, map_location='cpu')
|
| 60 |
+
if ckp_path.endswith('.pth'):
|
| 61 |
+
model.load_state_dict(ckp['model'], strict=False)
|
| 62 |
+
elif ckp_path.endswith('.pt'):
|
| 63 |
+
model.load_state_dict(ckp['module'])
|
| 64 |
+
else:
|
| 65 |
+
raise ValueError(f"Unknown checkpoint format: {ckp_path}")
|
| 66 |
+
model.to(device)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def build_model_config():
|
| 70 |
+
return (
|
| 71 |
+
"AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R', "
|
| 72 |
+
"img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', "
|
| 73 |
+
"clip_head_type='dpt', dino_head_type='dpt', "
|
| 74 |
+
"depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), "
|
| 75 |
+
"enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, "
|
| 76 |
+
"dec_embed_dim=768, dec_depth=12, dec_num_heads=12, "
|
| 77 |
+
"two_confs=True, landscape_only=False)"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def resolve_weights_path(args: argparse.Namespace) -> str:
|
| 82 |
+
if args.weights:
|
| 83 |
+
return args.weights
|
| 84 |
+
print(f"[sab3r] Downloading checkpoint from HF Hub: {args.model_repo}/{args.ckpt_filename}")
|
| 85 |
+
return hf_hub_download(repo_id=args.model_repo, filename=args.ckpt_filename)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def main(argv=None):
|
| 89 |
+
parser = get_args_parser()
|
| 90 |
+
args = parser.parse_args(argv)
|
| 91 |
+
set_print_with_timestamp()
|
| 92 |
+
|
| 93 |
+
if args.server_name is not None:
|
| 94 |
+
server_name = args.server_name
|
| 95 |
+
else:
|
| 96 |
+
server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
|
| 97 |
+
|
| 98 |
+
model = eval(build_model_config())
|
| 99 |
+
ckp_path = resolve_weights_path(args)
|
| 100 |
+
load_weights(model, ckp_path, args.device)
|
| 101 |
+
chkpt_tag = hash_md5(ckp_path)
|
| 102 |
+
|
| 103 |
+
def get_context(tmp_dir):
|
| 104 |
+
return (tempfile.TemporaryDirectory(suffix='_sab3r_gradio_demo') if tmp_dir is None
|
| 105 |
+
else nullcontext(tmp_dir))
|
| 106 |
+
|
| 107 |
+
with get_context(args.tmp_dir) as tmpdirname:
|
| 108 |
+
cache_path = os.path.join(tmpdirname, chkpt_tag)
|
| 109 |
+
os.makedirs(cache_path, exist_ok=True)
|
| 110 |
+
main_demo(
|
| 111 |
+
cache_path, model, args.device, args.image_size, server_name, args.server_port,
|
| 112 |
+
silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache,
|
| 113 |
+
checkpoint_dir=args.checkpoint_dir,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
if __name__ == '__main__':
|
| 118 |
+
main()
|
docker/docker-compose-cpu.yml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: '3.8'
|
| 2 |
+
services:
|
| 3 |
+
mast3r-demo:
|
| 4 |
+
build:
|
| 5 |
+
context: ./files
|
| 6 |
+
dockerfile: cpu.Dockerfile
|
| 7 |
+
ports:
|
| 8 |
+
- "7860:7860"
|
| 9 |
+
volumes:
|
| 10 |
+
- ./files/checkpoints:/mast3r/checkpoints
|
| 11 |
+
environment:
|
| 12 |
+
- DEVICE=cpu
|
| 13 |
+
- MODEL=${MODEL:-MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth}
|
| 14 |
+
cap_add:
|
| 15 |
+
- IPC_LOCK
|
| 16 |
+
- SYS_RESOURCE
|
docker/docker-compose-cuda.yml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: '3.8'
|
| 2 |
+
services:
|
| 3 |
+
mast3r-demo:
|
| 4 |
+
build:
|
| 5 |
+
context: ./files
|
| 6 |
+
dockerfile: cuda.Dockerfile
|
| 7 |
+
ports:
|
| 8 |
+
- "7860:7860"
|
| 9 |
+
environment:
|
| 10 |
+
- DEVICE=cuda
|
| 11 |
+
- MODEL=${MODEL:-MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth}
|
| 12 |
+
volumes:
|
| 13 |
+
- ./files/checkpoints:/mast3r/checkpoints
|
| 14 |
+
cap_add:
|
| 15 |
+
- IPC_LOCK
|
| 16 |
+
- SYS_RESOURCE
|
| 17 |
+
deploy:
|
| 18 |
+
resources:
|
| 19 |
+
reservations:
|
| 20 |
+
devices:
|
| 21 |
+
- driver: nvidia
|
| 22 |
+
count: 1
|
| 23 |
+
capabilities: [gpu]
|
docker/files/cpu.Dockerfile
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
LABEL description="Docker container for MASt3R with dependencies installed. CPU VERSION"
|
| 4 |
+
|
| 5 |
+
ENV DEVICE="cpu"
|
| 6 |
+
ENV MODEL="MASt3R_ViTLarge_BaseDecoder_512_dpt.pth"
|
| 7 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
| 8 |
+
|
| 9 |
+
RUN apt-get update && apt-get install -y \
|
| 10 |
+
git \
|
| 11 |
+
libgl1-mesa-glx \
|
| 12 |
+
libegl1-mesa \
|
| 13 |
+
libxrandr2 \
|
| 14 |
+
libxrandr2 \
|
| 15 |
+
libxss1 \
|
| 16 |
+
libxcursor1 \
|
| 17 |
+
libxcomposite1 \
|
| 18 |
+
libasound2 \
|
| 19 |
+
libxi6 \
|
| 20 |
+
libxtst6 \
|
| 21 |
+
libglib2.0-0 \
|
| 22 |
+
&& apt-get clean \
|
| 23 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 24 |
+
|
| 25 |
+
RUN git clone --recursive https://github.com/naver/mast3r /mast3r
|
| 26 |
+
WORKDIR /mast3r/dust3r
|
| 27 |
+
|
| 28 |
+
RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
| 29 |
+
RUN pip install -r requirements.txt
|
| 30 |
+
RUN pip install -r requirements_optional.txt
|
| 31 |
+
RUN pip install opencv-python==4.8.0.74
|
| 32 |
+
|
| 33 |
+
WORKDIR /mast3r
|
| 34 |
+
RUN pip install -r requirements.txt
|
| 35 |
+
|
| 36 |
+
COPY entrypoint.sh /entrypoint.sh
|
| 37 |
+
RUN chmod +x /entrypoint.sh
|
| 38 |
+
|
| 39 |
+
ENTRYPOINT ["/entrypoint.sh"]
|
docker/files/cuda.Dockerfile
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvcr.io/nvidia/pytorch:24.01-py3
|
| 2 |
+
|
| 3 |
+
LABEL description="Docker container for MASt3R with dependencies installed. CUDA VERSION"
|
| 4 |
+
ENV DEVICE="cuda"
|
| 5 |
+
ENV MODEL="MASt3R_ViTLarge_BaseDecoder_512_dpt.pth"
|
| 6 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
| 7 |
+
|
| 8 |
+
RUN apt-get update && apt-get install -y \
|
| 9 |
+
git=1:2.34.1-1ubuntu1.10 \
|
| 10 |
+
libglib2.0-0=2.72.4-0ubuntu2.2 \
|
| 11 |
+
&& apt-get clean \
|
| 12 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
RUN git clone --recursive https://github.com/naver/mast3r /mast3r
|
| 15 |
+
WORKDIR /mast3r/dust3r
|
| 16 |
+
RUN pip install -r requirements.txt
|
| 17 |
+
RUN pip install -r requirements_optional.txt
|
| 18 |
+
RUN pip install opencv-python==4.8.0.74
|
| 19 |
+
|
| 20 |
+
WORKDIR /mast3r/dust3r/croco/models/curope/
|
| 21 |
+
RUN python setup.py build_ext --inplace
|
| 22 |
+
|
| 23 |
+
WORKDIR /mast3r
|
| 24 |
+
RUN pip install -r requirements.txt
|
| 25 |
+
|
| 26 |
+
COPY entrypoint.sh /entrypoint.sh
|
| 27 |
+
RUN chmod +x /entrypoint.sh
|
| 28 |
+
|
| 29 |
+
ENTRYPOINT ["/entrypoint.sh"]
|
docker/files/entrypoint.sh
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
set -eux
|
| 4 |
+
|
| 5 |
+
DEVICE=${DEVICE:-cuda}
|
| 6 |
+
MODEL=${MODEL:-MASt3R_ViTLarge_BaseDecoder_512_dpt.pth}
|
| 7 |
+
|
| 8 |
+
exec python3 demo.py --weights "checkpoints/$MODEL" --device "$DEVICE" --local_network "$@"
|
docker/run.sh
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
set -eux
|
| 4 |
+
|
| 5 |
+
# Default model name
|
| 6 |
+
model_name="MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
|
| 7 |
+
|
| 8 |
+
check_docker() {
|
| 9 |
+
if ! command -v docker &>/dev/null; then
|
| 10 |
+
echo "Docker could not be found. Please install Docker and try again."
|
| 11 |
+
exit 1
|
| 12 |
+
fi
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
download_model_checkpoint() {
|
| 16 |
+
if [ -f "./files/checkpoints/${model_name}" ]; then
|
| 17 |
+
echo "Model checkpoint ${model_name} already exists. Skipping download."
|
| 18 |
+
return
|
| 19 |
+
fi
|
| 20 |
+
echo "Downloading model checkpoint ${model_name}..."
|
| 21 |
+
wget "https://download.europe.naverlabs.com/ComputerVision/MASt3R/${model_name}" -P ./files/checkpoints
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
set_dcomp() {
|
| 25 |
+
if command -v docker-compose &>/dev/null; then
|
| 26 |
+
dcomp="docker-compose"
|
| 27 |
+
elif command -v docker &>/dev/null && docker compose version &>/dev/null; then
|
| 28 |
+
dcomp="docker compose"
|
| 29 |
+
else
|
| 30 |
+
echo "Docker Compose could not be found. Please install Docker Compose and try again."
|
| 31 |
+
exit 1
|
| 32 |
+
fi
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
run_docker() {
|
| 36 |
+
export MODEL=${model_name}
|
| 37 |
+
if [ "$with_cuda" -eq 1 ]; then
|
| 38 |
+
$dcomp -f docker-compose-cuda.yml up --build
|
| 39 |
+
else
|
| 40 |
+
$dcomp -f docker-compose-cpu.yml up --build
|
| 41 |
+
fi
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
with_cuda=0
|
| 45 |
+
for arg in "$@"; do
|
| 46 |
+
case $arg in
|
| 47 |
+
--with-cuda)
|
| 48 |
+
with_cuda=1
|
| 49 |
+
;;
|
| 50 |
+
--model_name=*)
|
| 51 |
+
model_name="${arg#*=}.pth"
|
| 52 |
+
;;
|
| 53 |
+
*)
|
| 54 |
+
echo "Unknown parameter passed: $arg"
|
| 55 |
+
exit 1
|
| 56 |
+
;;
|
| 57 |
+
esac
|
| 58 |
+
done
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
main() {
|
| 62 |
+
check_docker
|
| 63 |
+
download_model_checkpoint
|
| 64 |
+
set_dcomp
|
| 65 |
+
run_docker
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
main
|
dust3r/.gitignore
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/
|
| 2 |
+
checkpoints/
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*$py.class
|
| 8 |
+
|
| 9 |
+
# C extensions
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Distribution / packaging
|
| 13 |
+
.Python
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
pip-wheel-metadata/
|
| 27 |
+
share/python-wheels/
|
| 28 |
+
*.egg-info/
|
| 29 |
+
.installed.cfg
|
| 30 |
+
*.egg
|
| 31 |
+
MANIFEST
|
| 32 |
+
|
| 33 |
+
# PyInstaller
|
| 34 |
+
# Usually these files are written by a python script from a template
|
| 35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 36 |
+
*.manifest
|
| 37 |
+
*.spec
|
| 38 |
+
|
| 39 |
+
# Installer logs
|
| 40 |
+
pip-log.txt
|
| 41 |
+
pip-delete-this-directory.txt
|
| 42 |
+
|
| 43 |
+
# Unit test / coverage reports
|
| 44 |
+
htmlcov/
|
| 45 |
+
.tox/
|
| 46 |
+
.nox/
|
| 47 |
+
.coverage
|
| 48 |
+
.coverage.*
|
| 49 |
+
.cache
|
| 50 |
+
nosetests.xml
|
| 51 |
+
coverage.xml
|
| 52 |
+
*.cover
|
| 53 |
+
*.py,cover
|
| 54 |
+
.hypothesis/
|
| 55 |
+
.pytest_cache/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
target/
|
| 79 |
+
|
| 80 |
+
# Jupyter Notebook
|
| 81 |
+
.ipynb_checkpoints
|
| 82 |
+
|
| 83 |
+
# IPython
|
| 84 |
+
profile_default/
|
| 85 |
+
ipython_config.py
|
| 86 |
+
|
| 87 |
+
# pyenv
|
| 88 |
+
.python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 98 |
+
__pypackages__/
|
| 99 |
+
|
| 100 |
+
# Celery stuff
|
| 101 |
+
celerybeat-schedule
|
| 102 |
+
celerybeat.pid
|
| 103 |
+
|
| 104 |
+
# SageMath parsed files
|
| 105 |
+
*.sage.py
|
| 106 |
+
|
| 107 |
+
# Environments
|
| 108 |
+
.env
|
| 109 |
+
.venv
|
| 110 |
+
env/
|
| 111 |
+
venv/
|
| 112 |
+
ENV/
|
| 113 |
+
env.bak/
|
| 114 |
+
venv.bak/
|
| 115 |
+
|
| 116 |
+
# Spyder project settings
|
| 117 |
+
.spyderproject
|
| 118 |
+
.spyproject
|
| 119 |
+
|
| 120 |
+
# Rope project settings
|
| 121 |
+
.ropeproject
|
| 122 |
+
|
| 123 |
+
# mkdocs documentation
|
| 124 |
+
/site
|
| 125 |
+
|
| 126 |
+
# mypy
|
| 127 |
+
.mypy_cache/
|
| 128 |
+
.dmypy.json
|
| 129 |
+
dmypy.json
|
| 130 |
+
|
| 131 |
+
# Pyre type checker
|
| 132 |
+
.pyre/
|
dust3r/LICENSE
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
|
| 2 |
+
|
| 3 |
+
A summary of the CC BY-NC-SA 4.0 license is located here:
|
| 4 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 5 |
+
|
| 6 |
+
The CC BY-NC-SA 4.0 license is located here:
|
| 7 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
dust3r/NOTICE
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DUSt3R
|
| 2 |
+
Copyright 2024-present NAVER Corp.
|
| 3 |
+
|
| 4 |
+
This project contains subcomponents with separate copyright notices and license terms.
|
| 5 |
+
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
|
| 6 |
+
|
| 7 |
+
====
|
| 8 |
+
|
| 9 |
+
naver/croco
|
| 10 |
+
https://github.com/naver/croco/
|
| 11 |
+
|
| 12 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0
|
dust3r/README.md
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+

|
| 2 |
+
|
| 3 |
+
Official implementation of `DUSt3R: Geometric 3D Vision Made Easy`
|
| 4 |
+
[[Project page](https://dust3r.europe.naverlabs.com/)], [[DUSt3R arxiv](https://arxiv.org/abs/2312.14132)]
|
| 5 |
+
|
| 6 |
+
> **Make sure to also check [MASt3R](https://github.com/naver/mast3r): Our new model with a local feature head, metric pointmaps, and a more scalable global alignment!**
|
| 7 |
+
|
| 8 |
+

|
| 9 |
+
|
| 10 |
+

|
| 11 |
+
|
| 12 |
+
```bibtex
|
| 13 |
+
@inproceedings{dust3r_cvpr24,
|
| 14 |
+
title={DUSt3R: Geometric 3D Vision Made Easy},
|
| 15 |
+
author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
|
| 16 |
+
booktitle = {CVPR},
|
| 17 |
+
year = {2024}
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
@misc{dust3r_arxiv23,
|
| 21 |
+
title={DUSt3R: Geometric 3D Vision Made Easy},
|
| 22 |
+
author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
|
| 23 |
+
year={2023},
|
| 24 |
+
eprint={2312.14132},
|
| 25 |
+
archivePrefix={arXiv},
|
| 26 |
+
primaryClass={cs.CV}
|
| 27 |
+
}
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
## Table of Contents
|
| 31 |
+
|
| 32 |
+
- [Table of Contents](#table-of-contents)
|
| 33 |
+
- [License](#license)
|
| 34 |
+
- [Get Started](#get-started)
|
| 35 |
+
- [Installation](#installation)
|
| 36 |
+
- [Checkpoints](#checkpoints)
|
| 37 |
+
- [Interactive demo](#interactive-demo)
|
| 38 |
+
- [Interactive demo with docker](#interactive-demo-with-docker)
|
| 39 |
+
- [Usage](#usage)
|
| 40 |
+
- [Training](#training)
|
| 41 |
+
- [Datasets](#datasets)
|
| 42 |
+
- [Demo](#demo)
|
| 43 |
+
- [Our Hyperparameters](#our-hyperparameters)
|
| 44 |
+
|
| 45 |
+
## License
|
| 46 |
+
|
| 47 |
+
The code is distributed under the CC BY-NC-SA 4.0 License.
|
| 48 |
+
See [LICENSE](LICENSE) for more information.
|
| 49 |
+
|
| 50 |
+
```python
|
| 51 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 52 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## Get Started
|
| 56 |
+
|
| 57 |
+
### Installation
|
| 58 |
+
|
| 59 |
+
1. Clone DUSt3R.
|
| 60 |
+
```bash
|
| 61 |
+
git clone --recursive https://github.com/naver/dust3r
|
| 62 |
+
cd dust3r
|
| 63 |
+
# if you have already cloned dust3r:
|
| 64 |
+
# git submodule update --init --recursive
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
2. Create the environment, here we show an example using conda.
|
| 68 |
+
```bash
|
| 69 |
+
conda create -n dust3r python=3.11 cmake=3.14.0
|
| 70 |
+
conda activate dust3r
|
| 71 |
+
conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # use the correct version of cuda for your system
|
| 72 |
+
pip install -r requirements.txt
|
| 73 |
+
# Optional: you can also install additional packages to:
|
| 74 |
+
# - add support for HEIC images
|
| 75 |
+
# - add pyrender, used to render depthmap in some datasets preprocessing
|
| 76 |
+
# - add required packages for visloc.py
|
| 77 |
+
pip install -r requirements_optional.txt
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
3. Optional, compile the cuda kernels for RoPE (as in CroCo v2).
|
| 81 |
+
```bash
|
| 82 |
+
# DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime.
|
| 83 |
+
cd croco/models/curope/
|
| 84 |
+
python setup.py build_ext --inplace
|
| 85 |
+
cd ../../../
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
### Checkpoints
|
| 89 |
+
|
| 90 |
+
You can obtain the checkpoints by two ways:
|
| 91 |
+
|
| 92 |
+
1) You can use our huggingface_hub integration: the models will be downloaded automatically.
|
| 93 |
+
|
| 94 |
+
2) Otherwise, We provide several pre-trained models:
|
| 95 |
+
|
| 96 |
+
| Modelname | Training resolutions | Head | Encoder | Decoder |
|
| 97 |
+
|-------------|----------------------|------|---------|---------|
|
| 98 |
+
| [`DUSt3R_ViTLarge_BaseDecoder_224_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth) | 224x224 | Linear | ViT-L | ViT-B |
|
| 99 |
+
| [`DUSt3R_ViTLarge_BaseDecoder_512_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | Linear | ViT-L | ViT-B |
|
| 100 |
+
| [`DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | DPT | ViT-L | ViT-B |
|
| 101 |
+
|
| 102 |
+
You can check the hyperparameters we used to train these models in the [section: Our Hyperparameters](#our-hyperparameters)
|
| 103 |
+
|
| 104 |
+
To download a specific model, for example `DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`:
|
| 105 |
+
```bash
|
| 106 |
+
mkdir -p checkpoints/
|
| 107 |
+
wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
For the checkpoints, make sure to agree to the license of all the public training datasets and base checkpoints we used, in addition to CC-BY-NC-SA 4.0. Again, see [section: Our Hyperparameters](#our-hyperparameters) for details.
|
| 111 |
+
|
| 112 |
+
### Interactive demo
|
| 113 |
+
|
| 114 |
+
In this demo, you should be able run DUSt3R on your machine to reconstruct a scene.
|
| 115 |
+
First select images that depicts the same scene.
|
| 116 |
+
|
| 117 |
+
You can adjust the global alignment schedule and its number of iterations.
|
| 118 |
+
|
| 119 |
+
> [!NOTE]
|
| 120 |
+
> If you selected one or two images, the global alignment procedure will be skipped (mode=GlobalAlignerMode.PairViewer)
|
| 121 |
+
|
| 122 |
+
Hit "Run" and wait.
|
| 123 |
+
When the global alignment ends, the reconstruction appears.
|
| 124 |
+
Use the slider "min_conf_thr" to show or remove low confidence areas.
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
python3 demo.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt
|
| 128 |
+
|
| 129 |
+
# Use --weights to load a checkpoint from a local file, eg --weights checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
|
| 130 |
+
# Use --image_size to select the correct resolution for the selected checkpoint. 512 (default) or 224
|
| 131 |
+
# Use --local_network to make it accessible on the local network, or --server_name to specify the url manually
|
| 132 |
+
# Use --server_port to change the port, by default it will search for an available port starting at 7860
|
| 133 |
+
# Use --device to use a different device, by default it's "cuda"
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
### Interactive demo with docker
|
| 137 |
+
|
| 138 |
+
To run DUSt3R using Docker, including with NVIDIA CUDA support, follow these instructions:
|
| 139 |
+
|
| 140 |
+
1. **Install Docker**: If not already installed, download and install `docker` and `docker compose` from the [Docker website](https://www.docker.com/get-started).
|
| 141 |
+
|
| 142 |
+
2. **Install NVIDIA Docker Toolkit**: For GPU support, install the NVIDIA Docker toolkit from the [Nvidia website](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
|
| 143 |
+
|
| 144 |
+
3. **Build the Docker image and run it**: `cd` into the `./docker` directory and run the following commands:
|
| 145 |
+
|
| 146 |
+
```bash
|
| 147 |
+
cd docker
|
| 148 |
+
bash run.sh --with-cuda --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt"
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
Or if you want to run the demo without CUDA support, run the following command:
|
| 152 |
+
|
| 153 |
+
```bash
|
| 154 |
+
cd docker
|
| 155 |
+
bash run.sh --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt"
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
By default, `demo.py` is lanched with the option `--local_network`.
|
| 159 |
+
Visit `http://localhost:7860/` to access the web UI (or replace `localhost` with the machine's name to access it from the network).
|
| 160 |
+
|
| 161 |
+
`run.sh` will launch docker-compose using either the [docker-compose-cuda.yml](docker/docker-compose-cuda.yml) or [docker-compose-cpu.ym](docker/docker-compose-cpu.yml) config file, then it starts the demo using [entrypoint.sh](docker/files/entrypoint.sh).
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+

|
| 165 |
+
|
| 166 |
+
## Usage
|
| 167 |
+
|
| 168 |
+
```python
|
| 169 |
+
from dust3r.inference import inference
|
| 170 |
+
from dust3r.model import AsymmetricCroCo3DStereo
|
| 171 |
+
from dust3r.utils.image import load_images
|
| 172 |
+
from dust3r.image_pairs import make_pairs
|
| 173 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
| 174 |
+
|
| 175 |
+
if __name__ == '__main__':
|
| 176 |
+
device = 'cuda'
|
| 177 |
+
batch_size = 1
|
| 178 |
+
schedule = 'cosine'
|
| 179 |
+
lr = 0.01
|
| 180 |
+
niter = 300
|
| 181 |
+
|
| 182 |
+
model_name = "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
|
| 183 |
+
# you can put the path to a local checkpoint in model_name if needed
|
| 184 |
+
model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)
|
| 185 |
+
# load_images can take a list of images or a directory
|
| 186 |
+
images = load_images(['croco/assets/Chateau1.png', 'croco/assets/Chateau2.png'], size=512)
|
| 187 |
+
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
|
| 188 |
+
output = inference(pairs, model, device, batch_size=batch_size)
|
| 189 |
+
|
| 190 |
+
# at this stage, you have the raw dust3r predictions
|
| 191 |
+
view1, pred1 = output['view1'], output['pred1']
|
| 192 |
+
view2, pred2 = output['view2'], output['pred2']
|
| 193 |
+
# here, view1, pred1, view2, pred2 are dicts of lists of len(2)
|
| 194 |
+
# -> because we symmetrize we have (im1, im2) and (im2, im1) pairs
|
| 195 |
+
# in each view you have:
|
| 196 |
+
# an integer image identifier: view1['idx'] and view2['idx']
|
| 197 |
+
# the img: view1['img'] and view2['img']
|
| 198 |
+
# the image shape: view1['true_shape'] and view2['true_shape']
|
| 199 |
+
# an instance string output by the dataloader: view1['instance'] and view2['instance']
|
| 200 |
+
# pred1 and pred2 contains the confidence values: pred1['conf'] and pred2['conf']
|
| 201 |
+
# pred1 contains 3D points for view1['img'] in view1['img'] space: pred1['pts3d']
|
| 202 |
+
# pred2 contains 3D points for view2['img'] in view1['img'] space: pred2['pts3d_in_other_view']
|
| 203 |
+
|
| 204 |
+
# next we'll use the global_aligner to align the predictions
|
| 205 |
+
# depending on your task, you may be fine with the raw output and not need it
|
| 206 |
+
# with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output
|
| 207 |
+
# if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment
|
| 208 |
+
scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
|
| 209 |
+
loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
|
| 210 |
+
|
| 211 |
+
# retrieve useful values from scene:
|
| 212 |
+
imgs = scene.imgs
|
| 213 |
+
focals = scene.get_focals()
|
| 214 |
+
poses = scene.get_im_poses()
|
| 215 |
+
pts3d = scene.get_pts3d()
|
| 216 |
+
confidence_masks = scene.get_masks()
|
| 217 |
+
|
| 218 |
+
# visualize reconstruction
|
| 219 |
+
scene.show()
|
| 220 |
+
|
| 221 |
+
# find 2D-2D matches between the two images
|
| 222 |
+
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
|
| 223 |
+
pts2d_list, pts3d_list = [], []
|
| 224 |
+
for i in range(2):
|
| 225 |
+
conf_i = confidence_masks[i].cpu().numpy()
|
| 226 |
+
pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i]) # imgs[i].shape[:2] = (H, W)
|
| 227 |
+
pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
|
| 228 |
+
reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
|
| 229 |
+
print(f'found {num_matches} matches')
|
| 230 |
+
matches_im1 = pts2d_list[1][reciprocal_in_P2]
|
| 231 |
+
matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
|
| 232 |
+
|
| 233 |
+
# visualize a few matches
|
| 234 |
+
import numpy as np
|
| 235 |
+
from matplotlib import pyplot as pl
|
| 236 |
+
n_viz = 10
|
| 237 |
+
match_idx_to_viz = np.round(np.linspace(0, num_matches-1, n_viz)).astype(int)
|
| 238 |
+
viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
|
| 239 |
+
|
| 240 |
+
H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
|
| 241 |
+
img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
|
| 242 |
+
img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
|
| 243 |
+
img = np.concatenate((img0, img1), axis=1)
|
| 244 |
+
pl.figure()
|
| 245 |
+
pl.imshow(img)
|
| 246 |
+
cmap = pl.get_cmap('jet')
|
| 247 |
+
for i in range(n_viz):
|
| 248 |
+
(x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
|
| 249 |
+
pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
|
| 250 |
+
pl.show(block=True)
|
| 251 |
+
|
| 252 |
+
```
|
| 253 |
+

|
| 254 |
+
|
| 255 |
+
## Training
|
| 256 |
+
|
| 257 |
+
In this section, we present a short demonstration to get started with training DUSt3R.
|
| 258 |
+
|
| 259 |
+
### Datasets
|
| 260 |
+
At this moment, we have added the following training datasets:
|
| 261 |
+
- [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE)
|
| 262 |
+
- [ARKitScenes](https://github.com/apple/ARKitScenes) - [Creative Commons Attribution-NonCommercial-ShareAlike 4.0](https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license)
|
| 263 |
+
- [ScanNet++](https://kaldir.vc.in.tum.de/scannetpp/) - [non-commercial research and educational purposes](https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf)
|
| 264 |
+
- [BlendedMVS](https://github.com/YoYo000/BlendedMVS) - [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/)
|
| 265 |
+
- [WayMo Open dataset](https://github.com/waymo-research/waymo-open-dataset) - [Non-Commercial Use](https://waymo.com/open/terms/)
|
| 266 |
+
- [Habitat-Sim](https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md)
|
| 267 |
+
- [MegaDepth](https://www.cs.cornell.edu/projects/megadepth/)
|
| 268 |
+
- [StaticThings3D](https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d)
|
| 269 |
+
- [WildRGB-D](https://github.com/wildrgbd/wildrgbd/)
|
| 270 |
+
|
| 271 |
+
For each dataset, we provide a preprocessing script in the `datasets_preprocess` directory and an archive containing the list of pairs when needed.
|
| 272 |
+
You have to download the datasets yourself from their official sources, agree to their license, download our list of pairs, and run the preprocessing script.
|
| 273 |
+
|
| 274 |
+
Links:
|
| 275 |
+
|
| 276 |
+
[ARKitScenes pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/arkitscenes_pairs.zip)
|
| 277 |
+
[ScanNet++ pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/scannetpp_pairs.zip)
|
| 278 |
+
[BlendedMVS pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/blendedmvs_pairs.npy)
|
| 279 |
+
[WayMo Open dataset pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/waymo_pairs.npz)
|
| 280 |
+
[Habitat metadata](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz)
|
| 281 |
+
[MegaDepth pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/megadepth_pairs.npz)
|
| 282 |
+
[StaticThings3D pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/staticthings_pairs.npy)
|
| 283 |
+
|
| 284 |
+
> [!NOTE]
|
| 285 |
+
> They are not strictly equivalent to what was used to train DUSt3R, but they should be close enough.
|
| 286 |
+
|
| 287 |
+
### Demo
|
| 288 |
+
For this training demo, we're going to download and prepare a subset of [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE) and launch the training code on it.
|
| 289 |
+
The demo model will be trained for a few epochs on a very small dataset.
|
| 290 |
+
It will not be very good.
|
| 291 |
+
|
| 292 |
+
```bash
|
| 293 |
+
# download and prepare the co3d subset
|
| 294 |
+
mkdir -p data/co3d_subset
|
| 295 |
+
cd data/co3d_subset
|
| 296 |
+
git clone https://github.com/facebookresearch/co3d
|
| 297 |
+
cd co3d
|
| 298 |
+
python3 ./co3d/download_dataset.py --download_folder ../ --single_sequence_subset
|
| 299 |
+
rm ../*.zip
|
| 300 |
+
cd ../../..
|
| 301 |
+
|
| 302 |
+
python3 datasets_preprocess/preprocess_co3d.py --co3d_dir data/co3d_subset --output_dir data/co3d_subset_processed --single_sequence_subset
|
| 303 |
+
|
| 304 |
+
# download the pretrained croco v2 checkpoint
|
| 305 |
+
mkdir -p checkpoints/
|
| 306 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth -P checkpoints/
|
| 307 |
+
|
| 308 |
+
# the training of dust3r is done in 3 steps.
|
| 309 |
+
# for this example we'll do fewer epochs, for the actual hyperparameters we used in the paper, see the next section: "Our Hyperparameters"
|
| 310 |
+
# step 1 - train dust3r for 224 resolution
|
| 311 |
+
torchrun --nproc_per_node=4 train.py \
|
| 312 |
+
--train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)" \
|
| 313 |
+
--test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=224, seed=777)" \
|
| 314 |
+
--model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 315 |
+
--train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 316 |
+
--test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
|
| 317 |
+
--pretrained "checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
|
| 318 |
+
--lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 16 --accum_iter 1 \
|
| 319 |
+
--save_freq 1 --keep_freq 5 --eval_freq 1 \
|
| 320 |
+
--output_dir "checkpoints/dust3r_demo_224"
|
| 321 |
+
|
| 322 |
+
# step 2 - train dust3r for 512 resolution
|
| 323 |
+
torchrun --nproc_per_node=4 train.py \
|
| 324 |
+
--train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
|
| 325 |
+
--test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
|
| 326 |
+
--model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 327 |
+
--train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 328 |
+
--test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
|
| 329 |
+
--pretrained "checkpoints/dust3r_demo_224/checkpoint-best.pth" \
|
| 330 |
+
--lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 4 --accum_iter 4 \
|
| 331 |
+
--save_freq 1 --keep_freq 5 --eval_freq 1 \
|
| 332 |
+
--output_dir "checkpoints/dust3r_demo_512"
|
| 333 |
+
|
| 334 |
+
# step 3 - train dust3r for 512 resolution with dpt
|
| 335 |
+
torchrun --nproc_per_node=4 train.py \
|
| 336 |
+
--train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
|
| 337 |
+
--test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
|
| 338 |
+
--model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 339 |
+
--train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 340 |
+
--test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
|
| 341 |
+
--pretrained "checkpoints/dust3r_demo_512/checkpoint-best.pth" \
|
| 342 |
+
--lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 2 --accum_iter 8 \
|
| 343 |
+
--save_freq 1 --keep_freq 5 --eval_freq 1 --disable_cudnn_benchmark \
|
| 344 |
+
--output_dir "checkpoints/dust3r_demo_512dpt"
|
| 345 |
+
|
| 346 |
+
```
|
| 347 |
+
|
| 348 |
+
### Our Hyperparameters
|
| 349 |
+
|
| 350 |
+
Here are the commands we used for training the models:
|
| 351 |
+
|
| 352 |
+
```bash
|
| 353 |
+
# NOTE: ROOT path omitted for datasets
|
| 354 |
+
# 224 linear
|
| 355 |
+
torchrun --nproc_per_node 8 train.py \
|
| 356 |
+
--train_dataset=" + 100_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ BlendedMVS(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ MegaDepth(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ ARKitScenes(aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ ScanNetpp(split='train', aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=224, transform=ColorJitter) " \
|
| 357 |
+
--test_dataset=" Habitat(1_000, split='val', resolution=224, seed=777) + 1_000 @ BlendedMVS(split='val', resolution=224, seed=777) + 1_000 @ MegaDepth(split='val', resolution=224, seed=777) + 1_000 @ Co3d(split='test', mask_bg='rand', resolution=224, seed=777) " \
|
| 358 |
+
--train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 359 |
+
--test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
|
| 360 |
+
--model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 361 |
+
--pretrained="checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
|
| 362 |
+
--lr=0.0001 --min_lr=1e-06 --warmup_epochs=10 --epochs=100 --batch_size=16 --accum_iter=1 \
|
| 363 |
+
--save_freq=5 --keep_freq=10 --eval_freq=1 \
|
| 364 |
+
--output_dir="checkpoints/dust3r_224"
|
| 365 |
+
|
| 366 |
+
# 512 linear
|
| 367 |
+
torchrun --nproc_per_node 8 train.py \
|
| 368 |
+
--train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
|
| 369 |
+
--test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \
|
| 370 |
+
--train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 371 |
+
--test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
|
| 372 |
+
--model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 373 |
+
--pretrained="checkpoints/dust3r_224/checkpoint-best.pth" \
|
| 374 |
+
--lr=0.0001 --min_lr=1e-06 --warmup_epochs=20 --epochs=100 --batch_size=4 --accum_iter=2 \
|
| 375 |
+
--save_freq=10 --keep_freq=10 --eval_freq=1 --print_freq=10 \
|
| 376 |
+
--output_dir="checkpoints/dust3r_512"
|
| 377 |
+
|
| 378 |
+
# 512 dpt
|
| 379 |
+
torchrun --nproc_per_node 8 train.py \
|
| 380 |
+
--train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
|
| 381 |
+
--test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \
|
| 382 |
+
--train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 383 |
+
--test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
|
| 384 |
+
--model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 385 |
+
--pretrained="checkpoints/dust3r_512/checkpoint-best.pth" \
|
| 386 |
+
--lr=0.0001 --min_lr=1e-06 --warmup_epochs=15 --epochs=90 --batch_size=4 --accum_iter=2 \
|
| 387 |
+
--save_freq=5 --keep_freq=10 --eval_freq=1 --print_freq=10 --disable_cudnn_benchmark \
|
| 388 |
+
--output_dir="checkpoints/dust3r_512dpt"
|
| 389 |
+
|
| 390 |
+
```
|
dust3r/assets/demo.jpg
ADDED
|
Git LFS Details
|
dust3r/assets/dust3r.jpg
ADDED
|
dust3r/assets/dust3r_archi.jpg
ADDED
|
dust3r/assets/matching.jpg
ADDED
|
Git LFS Details
|
dust3r/assets/pipeline1.jpg
ADDED
|
dust3r/croco/LICENSE
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
|
| 2 |
+
|
| 3 |
+
A summary of the CC BY-NC-SA 4.0 license is located here:
|
| 4 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 5 |
+
|
| 6 |
+
The CC BY-NC-SA 4.0 license is located here:
|
| 7 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py
|
| 11 |
+
|
| 12 |
+
***************************
|
| 13 |
+
|
| 14 |
+
NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
|
| 15 |
+
|
| 16 |
+
This software is being redistributed in a modifiled form. The original form is available here:
|
| 17 |
+
|
| 18 |
+
https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 19 |
+
|
| 20 |
+
This software in this file incorporates parts of the following software available here:
|
| 21 |
+
|
| 22 |
+
Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
|
| 23 |
+
available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
|
| 24 |
+
|
| 25 |
+
MoCo v3: https://github.com/facebookresearch/moco-v3
|
| 26 |
+
available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
|
| 27 |
+
|
| 28 |
+
DeiT: https://github.com/facebookresearch/deit
|
| 29 |
+
available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
|
| 33 |
+
|
| 34 |
+
https://github.com/facebookresearch/mae/blob/main/LICENSE
|
| 35 |
+
|
| 36 |
+
Attribution-NonCommercial 4.0 International
|
| 37 |
+
|
| 38 |
+
***************************
|
| 39 |
+
|
| 40 |
+
NOTICE WITH RESPECT TO THE FILE: models/blocks.py
|
| 41 |
+
|
| 42 |
+
This software is being redistributed in a modifiled form. The original form is available here:
|
| 43 |
+
|
| 44 |
+
https://github.com/rwightman/pytorch-image-models
|
| 45 |
+
|
| 46 |
+
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
|
| 47 |
+
|
| 48 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
|
| 49 |
+
|
| 50 |
+
Apache License
|
| 51 |
+
Version 2.0, January 2004
|
| 52 |
+
http://www.apache.org/licenses/
|
dust3r/croco/NOTICE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CroCo
|
| 2 |
+
Copyright 2022-present NAVER Corp.
|
| 3 |
+
|
| 4 |
+
This project contains subcomponents with separate copyright notices and license terms.
|
| 5 |
+
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
|
| 6 |
+
|
| 7 |
+
====
|
| 8 |
+
|
| 9 |
+
facebookresearch/mae
|
| 10 |
+
https://github.com/facebookresearch/mae
|
| 11 |
+
|
| 12 |
+
Attribution-NonCommercial 4.0 International
|
| 13 |
+
|
| 14 |
+
====
|
| 15 |
+
|
| 16 |
+
rwightman/pytorch-image-models
|
| 17 |
+
https://github.com/rwightman/pytorch-image-models
|
| 18 |
+
|
| 19 |
+
Apache License
|
| 20 |
+
Version 2.0, January 2004
|
| 21 |
+
http://www.apache.org/licenses/
|
dust3r/croco/README.MD
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow
|
| 2 |
+
|
| 3 |
+
[[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]
|
| 4 |
+
|
| 5 |
+
This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), refered to as CroCo v2:
|
| 6 |
+
|
| 7 |
+

|
| 8 |
+
|
| 9 |
+
```bibtex
|
| 10 |
+
@inproceedings{croco,
|
| 11 |
+
title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
|
| 12 |
+
author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}},
|
| 13 |
+
booktitle={{NeurIPS}},
|
| 14 |
+
year={2022}
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
@inproceedings{croco_v2,
|
| 18 |
+
title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
|
| 19 |
+
author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me},
|
| 20 |
+
booktitle={ICCV},
|
| 21 |
+
year={2023}
|
| 22 |
+
}
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## License
|
| 26 |
+
|
| 27 |
+
The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
|
| 28 |
+
Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
|
| 29 |
+
Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.
|
| 30 |
+
|
| 31 |
+
## Preparation
|
| 32 |
+
|
| 33 |
+
1. Install dependencies on a machine with a NVidia GPU using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can ignore the line installing it and use a more recent python version.
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
conda create -n croco python=3.7 cmake=3.14.0
|
| 37 |
+
conda activate croco
|
| 38 |
+
conda install habitat-sim headless -c conda-forge -c aihabitat
|
| 39 |
+
conda install pytorch torchvision -c pytorch
|
| 40 |
+
conda install notebook ipykernel matplotlib
|
| 41 |
+
conda install ipywidgets widgetsnbextension
|
| 42 |
+
conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
2. Compile cuda kernels for RoPE
|
| 47 |
+
|
| 48 |
+
CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels.
|
| 49 |
+
```bash
|
| 50 |
+
cd models/curope/
|
| 51 |
+
python setup.py build_ext --inplace
|
| 52 |
+
cd ../../
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
This can be a bit long as we compile for all cuda architectures, feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only.
|
| 56 |
+
You might also need to set the environment `CUDA_HOME` in case you use a custom cuda installation.
|
| 57 |
+
|
| 58 |
+
In case you cannot provide, we also provide a slow pytorch version, which will be automatically loaded.
|
| 59 |
+
|
| 60 |
+
3. Download pre-trained model
|
| 61 |
+
|
| 62 |
+
We provide several pre-trained models:
|
| 63 |
+
|
| 64 |
+
| modelname | pre-training data | pos. embed. | Encoder | Decoder |
|
| 65 |
+
|------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
|
| 66 |
+
| [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small |
|
| 67 |
+
| [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small |
|
| 68 |
+
| [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base |
|
| 69 |
+
| [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base |
|
| 70 |
+
|
| 71 |
+
To download a specific model, i.e., the first one (`CroCo.pth`)
|
| 72 |
+
```bash
|
| 73 |
+
mkdir -p pretrained_models/
|
| 74 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## Reconstruction example
|
| 78 |
+
|
| 79 |
+
Simply run after downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or update the corresponding line in `demo.py`)
|
| 80 |
+
```bash
|
| 81 |
+
python demo.py
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator
|
| 85 |
+
|
| 86 |
+
First download the test scene from Habitat:
|
| 87 |
+
```bash
|
| 88 |
+
python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
Then, run the Notebook demo `interactive_demo.ipynb`.
|
| 92 |
+
|
| 93 |
+
In this demo, you should be able to sample a random reference viewpoint from an [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change viewpoint and select a masked target view to reconstruct using CroCo.
|
| 94 |
+

|
| 95 |
+
|
| 96 |
+
## Pre-training
|
| 97 |
+
|
| 98 |
+
### CroCo
|
| 99 |
+
|
| 100 |
+
To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command:
|
| 101 |
+
```
|
| 102 |
+
torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
Our CroCo pre-training was launched on a single server with 4 GPUs.
|
| 106 |
+
It should take around 10 days with A100 or 15 days with V100 to do the 400 pre-training epochs, but decent performances are obtained earlier in training.
|
| 107 |
+
Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case.
|
| 108 |
+
The first run can take a few minutes to start, to parse all available pre-training pairs.
|
| 109 |
+
|
| 110 |
+
### CroCo v2
|
| 111 |
+
|
| 112 |
+
For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
|
| 113 |
+
Then, run the following command for the largest model (ViT-L encoder, Base decoder):
|
| 114 |
+
```
|
| 115 |
+
torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per gpu in all cases.
|
| 119 |
+
The largest model should take around 12 days on A100.
|
| 120 |
+
Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case.
|
| 121 |
+
|
| 122 |
+
## Stereo matching and Optical flow downstream tasks
|
| 123 |
+
|
| 124 |
+
For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).
|
dust3r/croco/assets/Chateau1.png
ADDED
|
Git LFS Details
|
dust3r/croco/assets/Chateau2.png
ADDED
|
Git LFS Details
|
dust3r/croco/assets/arch.jpg
ADDED
|
dust3r/croco/croco-stereo-flow-demo.ipynb
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "9bca0f41",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Simple inference example with CroCo-Stereo or CroCo-Flow"
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "80653ef7",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
|
| 19 |
+
"# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "markdown",
|
| 24 |
+
"id": "4f033862",
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"source": [
|
| 27 |
+
"First download the model(s) of your choice by running\n",
|
| 28 |
+
"```\n",
|
| 29 |
+
"bash stereoflow/download_model.sh crocostereo.pth\n",
|
| 30 |
+
"bash stereoflow/download_model.sh crocoflow.pth\n",
|
| 31 |
+
"```"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": null,
|
| 37 |
+
"id": "1fb2e392",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"outputs": [],
|
| 40 |
+
"source": [
|
| 41 |
+
"import torch\n",
|
| 42 |
+
"use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
|
| 43 |
+
"device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
|
| 44 |
+
"import matplotlib.pylab as plt"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": null,
|
| 50 |
+
"id": "e0e25d77",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"from stereoflow.test import _load_model_and_criterion\n",
|
| 55 |
+
"from stereoflow.engine import tiled_pred\n",
|
| 56 |
+
"from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
|
| 57 |
+
"from stereoflow.datasets_flow import flowToColor\n",
|
| 58 |
+
"tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
|
| 59 |
+
]
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"cell_type": "markdown",
|
| 63 |
+
"id": "86a921f5",
|
| 64 |
+
"metadata": {},
|
| 65 |
+
"source": [
|
| 66 |
+
"### CroCo-Stereo example"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": null,
|
| 72 |
+
"id": "64e483cb",
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"outputs": [],
|
| 75 |
+
"source": [
|
| 76 |
+
"image1 = np.asarray(Image.open('<path_to_left_image>'))\n",
|
| 77 |
+
"image2 = np.asarray(Image.open('<path_to_right_image>'))"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"id": "f0d04303",
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [],
|
| 86 |
+
"source": [
|
| 87 |
+
"model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"id": "47dc14b5",
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"outputs": [],
|
| 96 |
+
"source": [
|
| 97 |
+
"im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
|
| 98 |
+
"im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
|
| 99 |
+
"with torch.inference_mode():\n",
|
| 100 |
+
" pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
|
| 101 |
+
"pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
|
| 102 |
+
]
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"cell_type": "code",
|
| 106 |
+
"execution_count": null,
|
| 107 |
+
"id": "583b9f16",
|
| 108 |
+
"metadata": {},
|
| 109 |
+
"outputs": [],
|
| 110 |
+
"source": [
|
| 111 |
+
"plt.imshow(vis_disparity(pred))\n",
|
| 112 |
+
"plt.axis('off')"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "markdown",
|
| 117 |
+
"id": "d2df5d70",
|
| 118 |
+
"metadata": {},
|
| 119 |
+
"source": [
|
| 120 |
+
"### CroCo-Flow example"
|
| 121 |
+
]
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"cell_type": "code",
|
| 125 |
+
"execution_count": null,
|
| 126 |
+
"id": "9ee257a7",
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"outputs": [],
|
| 129 |
+
"source": [
|
| 130 |
+
"image1 = np.asarray(Image.open('<path_to_first_image>'))\n",
|
| 131 |
+
"image2 = np.asarray(Image.open('<path_to_second_image>'))"
|
| 132 |
+
]
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"cell_type": "code",
|
| 136 |
+
"execution_count": null,
|
| 137 |
+
"id": "d5edccf0",
|
| 138 |
+
"metadata": {},
|
| 139 |
+
"outputs": [],
|
| 140 |
+
"source": [
|
| 141 |
+
"model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
|
| 142 |
+
]
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"cell_type": "code",
|
| 146 |
+
"execution_count": null,
|
| 147 |
+
"id": "b19692c3",
|
| 148 |
+
"metadata": {},
|
| 149 |
+
"outputs": [],
|
| 150 |
+
"source": [
|
| 151 |
+
"im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
|
| 152 |
+
"im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
|
| 153 |
+
"with torch.inference_mode():\n",
|
| 154 |
+
" pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
|
| 155 |
+
"pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
|
| 156 |
+
]
|
| 157 |
+
},
|
| 158 |
+
{
|
| 159 |
+
"cell_type": "code",
|
| 160 |
+
"execution_count": null,
|
| 161 |
+
"id": "26f79db3",
|
| 162 |
+
"metadata": {},
|
| 163 |
+
"outputs": [],
|
| 164 |
+
"source": [
|
| 165 |
+
"plt.imshow(flowToColor(pred))\n",
|
| 166 |
+
"plt.axis('off')"
|
| 167 |
+
]
|
| 168 |
+
}
|
| 169 |
+
],
|
| 170 |
+
"metadata": {
|
| 171 |
+
"kernelspec": {
|
| 172 |
+
"display_name": "Python 3 (ipykernel)",
|
| 173 |
+
"language": "python",
|
| 174 |
+
"name": "python3"
|
| 175 |
+
},
|
| 176 |
+
"language_info": {
|
| 177 |
+
"codemirror_mode": {
|
| 178 |
+
"name": "ipython",
|
| 179 |
+
"version": 3
|
| 180 |
+
},
|
| 181 |
+
"file_extension": ".py",
|
| 182 |
+
"mimetype": "text/x-python",
|
| 183 |
+
"name": "python",
|
| 184 |
+
"nbconvert_exporter": "python",
|
| 185 |
+
"pygments_lexer": "ipython3",
|
| 186 |
+
"version": "3.9.7"
|
| 187 |
+
}
|
| 188 |
+
},
|
| 189 |
+
"nbformat": 4,
|
| 190 |
+
"nbformat_minor": 5
|
| 191 |
+
}
|
dust3r/croco/datasets_croco/__init__.py
ADDED
|
File without changes
|
dust3r/croco/datasets_croco/crops/README.MD
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Generation of crops from the real datasets
|
| 2 |
+
|
| 3 |
+
The instructions below allow to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
|
| 4 |
+
|
| 5 |
+
### Download the metadata of the crops to generate
|
| 6 |
+
|
| 7 |
+
First, download the metadata and put them in `./data/`:
|
| 8 |
+
```
|
| 9 |
+
mkdir -p data
|
| 10 |
+
cd data/
|
| 11 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
|
| 12 |
+
unzip crop_metadata.zip
|
| 13 |
+
rm crop_metadata.zip
|
| 14 |
+
cd ..
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
### Prepare the original datasets
|
| 18 |
+
|
| 19 |
+
Second, download the original datasets in `./data/original_datasets/`.
|
| 20 |
+
```
|
| 21 |
+
mkdir -p data/original_datasets
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
##### ARKitScenes
|
| 25 |
+
|
| 26 |
+
Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
|
| 27 |
+
The resulting file structure should be like:
|
| 28 |
+
```
|
| 29 |
+
./data/original_datasets/ARKitScenes/
|
| 30 |
+
└───Training
|
| 31 |
+
└───40753679
|
| 32 |
+
│ │ ultrawide
|
| 33 |
+
│ │ ...
|
| 34 |
+
└───40753686
|
| 35 |
+
│
|
| 36 |
+
...
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
##### MegaDepth
|
| 40 |
+
|
| 41 |
+
Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
|
| 42 |
+
The resulting file structure should be like:
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
./data/original_datasets/MegaDepth/
|
| 46 |
+
└───0000
|
| 47 |
+
│ └───images
|
| 48 |
+
│ │ │ 1000557903_87fa96b8a4_o.jpg
|
| 49 |
+
│ │ └ ...
|
| 50 |
+
│ └─── ...
|
| 51 |
+
└───0001
|
| 52 |
+
│ │
|
| 53 |
+
│ └ ...
|
| 54 |
+
└─── ...
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
##### 3DStreetView
|
| 58 |
+
|
| 59 |
+
Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
|
| 60 |
+
The resulting file structure should be like:
|
| 61 |
+
|
| 62 |
+
```
|
| 63 |
+
./data/original_datasets/3DStreetView/
|
| 64 |
+
└───dataset_aligned
|
| 65 |
+
│ └───0002
|
| 66 |
+
│ │ │ 0000002_0000001_0000002_0000001.jpg
|
| 67 |
+
│ │ └ ...
|
| 68 |
+
│ └─── ...
|
| 69 |
+
└───dataset_unaligned
|
| 70 |
+
│ └───0003
|
| 71 |
+
│ │ │ 0000003_0000001_0000002_0000001.jpg
|
| 72 |
+
│ │ └ ...
|
| 73 |
+
│ └─── ...
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
##### IndoorVL
|
| 77 |
+
|
| 78 |
+
Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
|
| 79 |
+
|
| 80 |
+
```
|
| 81 |
+
pip install kapture
|
| 82 |
+
mkdir -p ./data/original_datasets/IndoorVL
|
| 83 |
+
cd ./data/original_datasets/IndoorVL
|
| 84 |
+
kapture_download_dataset.py update
|
| 85 |
+
kapture_download_dataset.py install "HyundaiDepartmentStore_*"
|
| 86 |
+
kapture_download_dataset.py install "GangnamStation_*"
|
| 87 |
+
cd -
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Extract the crops
|
| 91 |
+
|
| 92 |
+
Now, extract the crops for each of the dataset:
|
| 93 |
+
```
|
| 94 |
+
for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
|
| 95 |
+
do
|
| 96 |
+
python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
|
| 97 |
+
done
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
##### Note for IndoorVL
|
| 101 |
+
|
| 102 |
+
Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
|
| 103 |
+
To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively.
|
| 104 |
+
The impact on the performance is negligible.
|
dust3r/croco/datasets_croco/crops/extract_crops_from_images.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Extracting crops for pre-training
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import argparse
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import functools
|
| 13 |
+
from multiprocessing import Pool
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def arg_parser():
|
| 18 |
+
parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list')
|
| 19 |
+
|
| 20 |
+
parser.add_argument('--crops', type=str, required=True, help='crop file')
|
| 21 |
+
parser.add_argument('--root-dir', type=str, required=True, help='root directory')
|
| 22 |
+
parser.add_argument('--output-dir', type=str, required=True, help='output directory')
|
| 23 |
+
parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
|
| 24 |
+
parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
|
| 25 |
+
parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
|
| 26 |
+
parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
|
| 27 |
+
return parser
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def main(args):
|
| 31 |
+
listing_path = os.path.join(args.output_dir, 'listing.txt')
|
| 32 |
+
|
| 33 |
+
print(f'Loading list of crops ... ({args.nthread} threads)')
|
| 34 |
+
crops, num_crops_to_generate = load_crop_file(args.crops)
|
| 35 |
+
|
| 36 |
+
print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
|
| 37 |
+
num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels)
|
| 38 |
+
num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels))
|
| 39 |
+
|
| 40 |
+
jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
|
| 41 |
+
del crops
|
| 42 |
+
|
| 43 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 44 |
+
mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
|
| 45 |
+
call = functools.partial(save_image_crops, args)
|
| 46 |
+
|
| 47 |
+
print(f"Generating cropped images to {args.output_dir} ...")
|
| 48 |
+
with open(listing_path, 'w') as listing:
|
| 49 |
+
listing.write('# pair_path\n')
|
| 50 |
+
for results in tqdm(mmap(call, jobs), total=len(jobs)):
|
| 51 |
+
for path in results:
|
| 52 |
+
listing.write(f'{path}\n')
|
| 53 |
+
print('Finished writing listing to', listing_path)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_crop_file(path):
|
| 57 |
+
data = open(path).read().splitlines()
|
| 58 |
+
pairs = []
|
| 59 |
+
num_crops_to_generate = 0
|
| 60 |
+
for line in tqdm(data):
|
| 61 |
+
if line.startswith('#'):
|
| 62 |
+
continue
|
| 63 |
+
line = line.split(', ')
|
| 64 |
+
if len(line) < 8:
|
| 65 |
+
img1, img2, rotation = line
|
| 66 |
+
pairs.append((img1, img2, int(rotation), []))
|
| 67 |
+
else:
|
| 68 |
+
l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
|
| 69 |
+
rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
|
| 70 |
+
pairs[-1][-1].append((rect1, rect2))
|
| 71 |
+
num_crops_to_generate += 1
|
| 72 |
+
return pairs, num_crops_to_generate
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
|
| 76 |
+
jobs = []
|
| 77 |
+
powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
|
| 78 |
+
|
| 79 |
+
def get_path(idx):
|
| 80 |
+
idx_array = []
|
| 81 |
+
d = idx
|
| 82 |
+
for level in range(num_levels - 1):
|
| 83 |
+
idx_array.append(idx // powers[level])
|
| 84 |
+
idx = idx % powers[level]
|
| 85 |
+
idx_array.append(d)
|
| 86 |
+
return '/'.join(map(lambda x: hex(x)[2:], idx_array))
|
| 87 |
+
|
| 88 |
+
idx = 0
|
| 89 |
+
for pair_data in tqdm(pairs):
|
| 90 |
+
img1, img2, rotation, crops = pair_data
|
| 91 |
+
if -60 <= rotation and rotation <= 60:
|
| 92 |
+
rotation = 0 # most likely not a true rotation
|
| 93 |
+
paths = [get_path(idx + k) for k in range(len(crops))]
|
| 94 |
+
idx += len(crops)
|
| 95 |
+
jobs.append(((img1, img2), rotation, crops, paths))
|
| 96 |
+
return jobs
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_image(path):
|
| 100 |
+
try:
|
| 101 |
+
return Image.open(path).convert('RGB')
|
| 102 |
+
except Exception as e:
|
| 103 |
+
print('skipping', path, e)
|
| 104 |
+
raise OSError()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def save_image_crops(args, data):
|
| 108 |
+
# load images
|
| 109 |
+
img_pair, rot, crops, paths = data
|
| 110 |
+
try:
|
| 111 |
+
img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
|
| 112 |
+
except OSError as e:
|
| 113 |
+
return []
|
| 114 |
+
|
| 115 |
+
def area(sz):
|
| 116 |
+
return sz[0] * sz[1]
|
| 117 |
+
|
| 118 |
+
tgt_size = (args.imsize, args.imsize)
|
| 119 |
+
|
| 120 |
+
def prepare_crop(img, rect, rot=0):
|
| 121 |
+
# actual crop
|
| 122 |
+
img = img.crop(rect)
|
| 123 |
+
|
| 124 |
+
# resize to desired size
|
| 125 |
+
interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
|
| 126 |
+
img = img.resize(tgt_size, resample=interp)
|
| 127 |
+
|
| 128 |
+
# rotate the image
|
| 129 |
+
rot90 = (round(rot/90) % 4) * 90
|
| 130 |
+
if rot90 == 90:
|
| 131 |
+
img = img.transpose(Image.Transpose.ROTATE_90)
|
| 132 |
+
elif rot90 == 180:
|
| 133 |
+
img = img.transpose(Image.Transpose.ROTATE_180)
|
| 134 |
+
elif rot90 == 270:
|
| 135 |
+
img = img.transpose(Image.Transpose.ROTATE_270)
|
| 136 |
+
return img
|
| 137 |
+
|
| 138 |
+
results = []
|
| 139 |
+
for (rect1, rect2), path in zip(crops, paths):
|
| 140 |
+
crop1 = prepare_crop(img1, rect1)
|
| 141 |
+
crop2 = prepare_crop(img2, rect2, rot)
|
| 142 |
+
|
| 143 |
+
fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
|
| 144 |
+
fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
|
| 145 |
+
os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
|
| 146 |
+
|
| 147 |
+
assert not os.path.isfile(fullpath1), fullpath1
|
| 148 |
+
assert not os.path.isfile(fullpath2), fullpath2
|
| 149 |
+
crop1.save(fullpath1)
|
| 150 |
+
crop2.save(fullpath2)
|
| 151 |
+
results.append(path)
|
| 152 |
+
|
| 153 |
+
return results
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
if __name__ == '__main__':
|
| 157 |
+
args = arg_parser().parse_args()
|
| 158 |
+
main(args)
|
| 159 |
+
|
dust3r/croco/datasets_croco/habitat_sim/README.MD
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Generation of synthetic image pairs using Habitat-Sim
|
| 2 |
+
|
| 3 |
+
These instructions allow to generate pre-training pairs from the Habitat simulator.
|
| 4 |
+
As we did not save metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same setting and are equivalent.
|
| 5 |
+
|
| 6 |
+
### Download Habitat-Sim scenes
|
| 7 |
+
Download Habitat-Sim scenes:
|
| 8 |
+
- Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
|
| 9 |
+
- We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
|
| 10 |
+
- Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`.
|
| 11 |
+
```
|
| 12 |
+
./data/
|
| 13 |
+
└──habitat-sim-data/
|
| 14 |
+
└──scene_datasets/
|
| 15 |
+
├──hm3d/
|
| 16 |
+
├──gibson/
|
| 17 |
+
├──habitat-test-scenes/
|
| 18 |
+
├──replica_cad_baked_lighting/
|
| 19 |
+
├──replica_cad/
|
| 20 |
+
├──ReplicaDataset/
|
| 21 |
+
└──scannet/
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
### Image pairs generation
|
| 25 |
+
We provide metadata to generate reproducible images pairs for pretraining and validation.
|
| 26 |
+
Experiments described in the paper used similar data, but whose generation was not reproducible at the time.
|
| 27 |
+
|
| 28 |
+
Specifications:
|
| 29 |
+
- 256x256 resolution images, with 60 degrees field of view .
|
| 30 |
+
- Up to 1000 image pairs per scene.
|
| 31 |
+
- Number of scenes considered/number of images pairs per dataset:
|
| 32 |
+
- Scannet: 1097 scenes / 985 209 pairs
|
| 33 |
+
- HM3D:
|
| 34 |
+
- hm3d/train: 800 / 800k pairs
|
| 35 |
+
- hm3d/val: 100 scenes / 100k pairs
|
| 36 |
+
- hm3d/minival: 10 scenes / 10k pairs
|
| 37 |
+
- habitat-test-scenes: 3 scenes / 3k pairs
|
| 38 |
+
- replica_cad_baked_lighting: 13 scenes / 13k pairs
|
| 39 |
+
|
| 40 |
+
- Scenes from hm3d/val and hm3d/minival pairs were not used for the pre-training but kept for validation purposes.
|
| 41 |
+
|
| 42 |
+
Download metadata and extract it:
|
| 43 |
+
```bash
|
| 44 |
+
mkdir -p data/habitat_release_metadata/
|
| 45 |
+
cd data/habitat_release_metadata/
|
| 46 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
|
| 47 |
+
tar -xvf multiview_habitat_metadata.tar.gz
|
| 48 |
+
cd ../..
|
| 49 |
+
# Location of the metadata
|
| 50 |
+
METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
Generate image pairs from metadata:
|
| 54 |
+
- The following command will print a list of commandlines to generate image pairs for each scene:
|
| 55 |
+
```bash
|
| 56 |
+
# Target output directory
|
| 57 |
+
PAIRS_DATASET_DIR="./data/habitat_release/"
|
| 58 |
+
python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
|
| 59 |
+
```
|
| 60 |
+
- One can launch multiple of such commands in parallel e.g. using GNU Parallel:
|
| 61 |
+
```bash
|
| 62 |
+
python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Metadata generation
|
| 66 |
+
|
| 67 |
+
Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
|
| 68 |
+
```bash
|
| 69 |
+
# Print commandlines to generate image pairs from the different scenes available.
|
| 70 |
+
PAIRS_DATASET_DIR=MY_CUSTOM_PATH
|
| 71 |
+
python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
|
| 72 |
+
|
| 73 |
+
# Once a dataset is generated, pack metadata files for reproducibility.
|
| 74 |
+
METADATA_DIR=MY_CUSTON_PATH
|
| 75 |
+
python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
|
| 76 |
+
```
|
dust3r/croco/datasets_croco/habitat_sim/__init__.py
ADDED
|
File without changes
|
dust3r/croco/datasets_croco/habitat_sim/generate_from_metadata.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator
|
| 9 |
+
from datasets.habitat_sim.paths import SCENES_DATASET
|
| 10 |
+
import argparse
|
| 11 |
+
import quaternion
|
| 12 |
+
import PIL.Image
|
| 13 |
+
import cv2
|
| 14 |
+
import json
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
def generate_multiview_images_from_metadata(metadata_filename,
|
| 18 |
+
output_dir,
|
| 19 |
+
overload_params = dict(),
|
| 20 |
+
scene_datasets_paths=None,
|
| 21 |
+
exist_ok=False):
|
| 22 |
+
"""
|
| 23 |
+
Generate images from a metadata file for reproducibility purposes.
|
| 24 |
+
"""
|
| 25 |
+
# Reorder paths by decreasing label length, to avoid collisions when testing if a string by such label
|
| 26 |
+
if scene_datasets_paths is not None:
|
| 27 |
+
scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True))
|
| 28 |
+
|
| 29 |
+
with open(metadata_filename, 'r') as f:
|
| 30 |
+
input_metadata = json.load(f)
|
| 31 |
+
metadata = dict()
|
| 32 |
+
for key, value in input_metadata.items():
|
| 33 |
+
# Optionally replace some paths
|
| 34 |
+
if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
|
| 35 |
+
if scene_datasets_paths is not None:
|
| 36 |
+
for dataset_label, dataset_path in scene_datasets_paths.items():
|
| 37 |
+
if value.startswith(dataset_label):
|
| 38 |
+
value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label)))
|
| 39 |
+
break
|
| 40 |
+
metadata[key] = value
|
| 41 |
+
|
| 42 |
+
# Overload some parameters
|
| 43 |
+
for key, value in overload_params.items():
|
| 44 |
+
metadata[key] = value
|
| 45 |
+
|
| 46 |
+
generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))])
|
| 47 |
+
generate_depth = metadata["generate_depth"]
|
| 48 |
+
|
| 49 |
+
os.makedirs(output_dir, exist_ok=exist_ok)
|
| 50 |
+
|
| 51 |
+
generator = MultiviewHabitatSimGenerator(**generation_entries)
|
| 52 |
+
|
| 53 |
+
# Generate views
|
| 54 |
+
for idx_label, data in tqdm(metadata['multiviews'].items()):
|
| 55 |
+
positions = data["positions"]
|
| 56 |
+
orientations = data["orientations"]
|
| 57 |
+
n = len(positions)
|
| 58 |
+
for oidx in range(n):
|
| 59 |
+
observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx]))
|
| 60 |
+
observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
|
| 61 |
+
# Color image saved using PIL
|
| 62 |
+
img = PIL.Image.fromarray(observation['color'][:,:,:3])
|
| 63 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
|
| 64 |
+
img.save(filename)
|
| 65 |
+
if generate_depth:
|
| 66 |
+
# Depth image as EXR file
|
| 67 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
|
| 68 |
+
cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
|
| 69 |
+
# Camera parameters
|
| 70 |
+
camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
|
| 71 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
|
| 72 |
+
with open(filename, "w") as f:
|
| 73 |
+
json.dump(camera_params, f)
|
| 74 |
+
# Save metadata
|
| 75 |
+
with open(os.path.join(output_dir, "metadata.json"), "w") as f:
|
| 76 |
+
json.dump(metadata, f)
|
| 77 |
+
|
| 78 |
+
generator.close()
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
parser = argparse.ArgumentParser()
|
| 82 |
+
parser.add_argument("--metadata_filename", required=True)
|
| 83 |
+
parser.add_argument("--output_dir", required=True)
|
| 84 |
+
args = parser.parse_args()
|
| 85 |
+
|
| 86 |
+
generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename,
|
| 87 |
+
output_dir=args.output_dir,
|
| 88 |
+
scene_datasets_paths=SCENES_DATASET,
|
| 89 |
+
overload_params=dict(),
|
| 90 |
+
exist_ok=True)
|
| 91 |
+
|
| 92 |
+
|
dust3r/croco/datasets_croco/habitat_sim/generate_from_metadata_files.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Script generating commandlines to generate image pairs from metadata files.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import glob
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import argparse
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
parser = argparse.ArgumentParser()
|
| 14 |
+
parser.add_argument("--input_dir", required=True)
|
| 15 |
+
parser.add_argument("--output_dir", required=True)
|
| 16 |
+
parser.add_argument("--prefix", default="", help="Commanline prefix, useful e.g. to setup environment.")
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
|
| 19 |
+
input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True)
|
| 20 |
+
|
| 21 |
+
for metadata_filename in tqdm(input_metadata_filenames):
|
| 22 |
+
output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir))
|
| 23 |
+
# Do not process the scene if the metadata file already exists
|
| 24 |
+
if os.path.exists(os.path.join(output_dir, "metadata.json")):
|
| 25 |
+
continue
|
| 26 |
+
commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
|
| 27 |
+
print(commandline)
|
dust3r/croco/datasets_croco/habitat_sim/generate_multiview_images.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import argparse
|
| 7 |
+
import PIL.Image
|
| 8 |
+
import numpy as np
|
| 9 |
+
import json
|
| 10 |
+
from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError
|
| 11 |
+
from datasets.habitat_sim.paths import list_scenes_available
|
| 12 |
+
import cv2
|
| 13 |
+
import quaternion
|
| 14 |
+
import shutil
|
| 15 |
+
|
| 16 |
+
def generate_multiview_images_for_scene(scene_dataset_config_file,
|
| 17 |
+
scene,
|
| 18 |
+
navmesh,
|
| 19 |
+
output_dir,
|
| 20 |
+
views_count,
|
| 21 |
+
size,
|
| 22 |
+
exist_ok=False,
|
| 23 |
+
generate_depth=False,
|
| 24 |
+
**kwargs):
|
| 25 |
+
"""
|
| 26 |
+
Generate tuples of overlapping views for a given scene.
|
| 27 |
+
generate_depth: generate depth images and camera parameters.
|
| 28 |
+
"""
|
| 29 |
+
if os.path.exists(output_dir) and not exist_ok:
|
| 30 |
+
print(f"Scene {scene}: data already generated. Ignoring generation.")
|
| 31 |
+
return
|
| 32 |
+
try:
|
| 33 |
+
print(f"Scene {scene}: {size} multiview acquisitions to generate...")
|
| 34 |
+
os.makedirs(output_dir, exist_ok=exist_ok)
|
| 35 |
+
|
| 36 |
+
metadata_filename = os.path.join(output_dir, "metadata.json")
|
| 37 |
+
|
| 38 |
+
metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file,
|
| 39 |
+
scene=scene,
|
| 40 |
+
navmesh=navmesh,
|
| 41 |
+
views_count=views_count,
|
| 42 |
+
size=size,
|
| 43 |
+
generate_depth=generate_depth,
|
| 44 |
+
**kwargs)
|
| 45 |
+
metadata_template["multiviews"] = dict()
|
| 46 |
+
|
| 47 |
+
if os.path.exists(metadata_filename):
|
| 48 |
+
print("Metadata file already exists:", metadata_filename)
|
| 49 |
+
print("Loading already generated metadata file...")
|
| 50 |
+
with open(metadata_filename, "r") as f:
|
| 51 |
+
metadata = json.load(f)
|
| 52 |
+
|
| 53 |
+
for key in metadata_template.keys():
|
| 54 |
+
if key != "multiviews":
|
| 55 |
+
assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}."
|
| 56 |
+
else:
|
| 57 |
+
print("No temporary file found. Starting generation from scratch...")
|
| 58 |
+
metadata = metadata_template
|
| 59 |
+
|
| 60 |
+
starting_id = len(metadata["multiviews"])
|
| 61 |
+
print(f"Starting generation from index {starting_id}/{size}...")
|
| 62 |
+
if starting_id >= size:
|
| 63 |
+
print("Generation already done.")
|
| 64 |
+
return
|
| 65 |
+
|
| 66 |
+
generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file,
|
| 67 |
+
scene=scene,
|
| 68 |
+
navmesh=navmesh,
|
| 69 |
+
views_count = views_count,
|
| 70 |
+
size = size,
|
| 71 |
+
**kwargs)
|
| 72 |
+
|
| 73 |
+
for idx in tqdm(range(starting_id, size)):
|
| 74 |
+
# Generate / re-generate the observations
|
| 75 |
+
try:
|
| 76 |
+
data = generator[idx]
|
| 77 |
+
observations = data["observations"]
|
| 78 |
+
positions = data["positions"]
|
| 79 |
+
orientations = data["orientations"]
|
| 80 |
+
|
| 81 |
+
idx_label = f"{idx:08}"
|
| 82 |
+
for oidx, observation in enumerate(observations):
|
| 83 |
+
observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
|
| 84 |
+
# Color image saved using PIL
|
| 85 |
+
img = PIL.Image.fromarray(observation['color'][:,:,:3])
|
| 86 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
|
| 87 |
+
img.save(filename)
|
| 88 |
+
if generate_depth:
|
| 89 |
+
# Depth image as EXR file
|
| 90 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
|
| 91 |
+
cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
|
| 92 |
+
# Camera parameters
|
| 93 |
+
camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
|
| 94 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
|
| 95 |
+
with open(filename, "w") as f:
|
| 96 |
+
json.dump(camera_params, f)
|
| 97 |
+
metadata["multiviews"][idx_label] = {"positions": positions.tolist(),
|
| 98 |
+
"orientations": orientations.tolist(),
|
| 99 |
+
"covisibility_ratios": data["covisibility_ratios"].tolist(),
|
| 100 |
+
"valid_fractions": data["valid_fractions"].tolist(),
|
| 101 |
+
"pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()}
|
| 102 |
+
except RecursionError:
|
| 103 |
+
print("Recursion error: unable to sample observations for this scene. We will stop there.")
|
| 104 |
+
break
|
| 105 |
+
|
| 106 |
+
# Regularly save a temporary metadata file, in case we need to restart the generation
|
| 107 |
+
if idx % 10 == 0:
|
| 108 |
+
with open(metadata_filename, "w") as f:
|
| 109 |
+
json.dump(metadata, f)
|
| 110 |
+
|
| 111 |
+
# Save metadata
|
| 112 |
+
with open(metadata_filename, "w") as f:
|
| 113 |
+
json.dump(metadata, f)
|
| 114 |
+
|
| 115 |
+
generator.close()
|
| 116 |
+
except NoNaviguableSpaceError:
|
| 117 |
+
pass
|
| 118 |
+
|
| 119 |
+
def create_commandline(scene_data, generate_depth, exist_ok=False):
|
| 120 |
+
"""
|
| 121 |
+
Create a commandline string to generate a scene.
|
| 122 |
+
"""
|
| 123 |
+
def my_formatting(val):
|
| 124 |
+
if val is None or val == "":
|
| 125 |
+
return '""'
|
| 126 |
+
else:
|
| 127 |
+
return val
|
| 128 |
+
commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)}
|
| 129 |
+
--scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)}
|
| 130 |
+
--navmesh {my_formatting(scene_data.navmesh)}
|
| 131 |
+
--output_dir {my_formatting(scene_data.output_dir)}
|
| 132 |
+
--generate_depth {int(generate_depth)}
|
| 133 |
+
--exist_ok {int(exist_ok)}
|
| 134 |
+
"""
|
| 135 |
+
commandline = " ".join(commandline.split())
|
| 136 |
+
return commandline
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
os.umask(2)
|
| 140 |
+
|
| 141 |
+
parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for scenes available:
|
| 142 |
+
> python datasets/habitat_sim/generate_multiview_habitat_images.py --list_commands
|
| 143 |
+
""")
|
| 144 |
+
|
| 145 |
+
parser.add_argument("--output_dir", type=str, required=True)
|
| 146 |
+
parser.add_argument("--list_commands", action='store_true', help="list commandlines to run if true")
|
| 147 |
+
parser.add_argument("--scene", type=str, default="")
|
| 148 |
+
parser.add_argument("--scene_dataset_config_file", type=str, default="")
|
| 149 |
+
parser.add_argument("--navmesh", type=str, default="")
|
| 150 |
+
|
| 151 |
+
parser.add_argument("--generate_depth", type=int, default=1)
|
| 152 |
+
parser.add_argument("--exist_ok", type=int, default=0)
|
| 153 |
+
|
| 154 |
+
kwargs = dict(resolution=(256,256), hfov=60, views_count = 2, size=1000)
|
| 155 |
+
|
| 156 |
+
args = parser.parse_args()
|
| 157 |
+
generate_depth=bool(args.generate_depth)
|
| 158 |
+
exist_ok = bool(args.exist_ok)
|
| 159 |
+
|
| 160 |
+
if args.list_commands:
|
| 161 |
+
# Listing scenes available...
|
| 162 |
+
scenes_data = list_scenes_available(base_output_dir=args.output_dir)
|
| 163 |
+
|
| 164 |
+
for scene_data in scenes_data:
|
| 165 |
+
print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok))
|
| 166 |
+
else:
|
| 167 |
+
if args.scene == "" or args.output_dir == "":
|
| 168 |
+
print("Missing scene or output dir argument!")
|
| 169 |
+
print(parser.format_help())
|
| 170 |
+
else:
|
| 171 |
+
generate_multiview_images_for_scene(scene=args.scene,
|
| 172 |
+
scene_dataset_config_file = args.scene_dataset_config_file,
|
| 173 |
+
navmesh = args.navmesh,
|
| 174 |
+
output_dir = args.output_dir,
|
| 175 |
+
exist_ok=exist_ok,
|
| 176 |
+
generate_depth=generate_depth,
|
| 177 |
+
**kwargs)
|
dust3r/croco/datasets_croco/habitat_sim/multiview_habitat_sim_generator.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import quaternion
|
| 7 |
+
import habitat_sim
|
| 8 |
+
import json
|
| 9 |
+
from sklearn.neighbors import NearestNeighbors
|
| 10 |
+
import cv2
|
| 11 |
+
|
| 12 |
+
# OpenCV to habitat camera convention transformation
|
| 13 |
+
R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0)
|
| 14 |
+
R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
|
| 15 |
+
DEG2RAD = np.pi / 180
|
| 16 |
+
|
| 17 |
+
def compute_camera_intrinsics(height, width, hfov):
|
| 18 |
+
f = width/2 / np.tan(hfov/2 * np.pi/180)
|
| 19 |
+
cu, cv = width/2, height/2
|
| 20 |
+
return f, cu, cv
|
| 21 |
+
|
| 22 |
+
def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
|
| 23 |
+
R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT
|
| 24 |
+
t_cam2world = np.asarray(camera_position)
|
| 25 |
+
return R_cam2world, t_cam2world
|
| 26 |
+
|
| 27 |
+
def compute_pointmap(depthmap, hfov):
|
| 28 |
+
""" Compute a HxWx3 pointmap in camera frame from a HxW depth map."""
|
| 29 |
+
height, width = depthmap.shape
|
| 30 |
+
f, cu, cv = compute_camera_intrinsics(height, width, hfov)
|
| 31 |
+
# Cast depth map to point
|
| 32 |
+
z_cam = depthmap
|
| 33 |
+
u, v = np.meshgrid(range(width), range(height))
|
| 34 |
+
x_cam = (u - cu) / f * z_cam
|
| 35 |
+
y_cam = (v - cv) / f * z_cam
|
| 36 |
+
X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1)
|
| 37 |
+
return X_cam
|
| 38 |
+
|
| 39 |
+
def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
|
| 40 |
+
"""Return a 3D point cloud corresponding to valid pixels of the depth map"""
|
| 41 |
+
R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation)
|
| 42 |
+
|
| 43 |
+
X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
|
| 44 |
+
valid_mask = (X_cam[:,:,2] != 0.0)
|
| 45 |
+
|
| 46 |
+
X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()]
|
| 47 |
+
X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3)
|
| 48 |
+
return X_world
|
| 49 |
+
|
| 50 |
+
def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False):
|
| 51 |
+
"""
|
| 52 |
+
Compute 'overlapping' metrics based on a distance threshold between two point clouds.
|
| 53 |
+
"""
|
| 54 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2)
|
| 55 |
+
distances, indices = nbrs.kneighbors(pointcloud1)
|
| 56 |
+
intersection1 = np.count_nonzero(distances.flatten() < distance_threshold)
|
| 57 |
+
|
| 58 |
+
data = {"intersection1": intersection1,
|
| 59 |
+
"size1": len(pointcloud1)}
|
| 60 |
+
if compute_symmetric:
|
| 61 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1)
|
| 62 |
+
distances, indices = nbrs.kneighbors(pointcloud2)
|
| 63 |
+
intersection2 = np.count_nonzero(distances.flatten() < distance_threshold)
|
| 64 |
+
data["intersection2"] = intersection2
|
| 65 |
+
data["size2"] = len(pointcloud2)
|
| 66 |
+
|
| 67 |
+
return data
|
| 68 |
+
|
| 69 |
+
def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
|
| 70 |
+
"""
|
| 71 |
+
Add camera parameters to the observation dictionnary produced by Habitat-Sim
|
| 72 |
+
In-place modifications.
|
| 73 |
+
"""
|
| 74 |
+
R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation)
|
| 75 |
+
height, width = observation['depth'].shape
|
| 76 |
+
f, cu, cv = compute_camera_intrinsics(height, width, hfov)
|
| 77 |
+
K = np.asarray([[f, 0, cu],
|
| 78 |
+
[0, f, cv],
|
| 79 |
+
[0, 0, 1.0]])
|
| 80 |
+
observation["camera_intrinsics"] = K
|
| 81 |
+
observation["t_cam2world"] = t_cam2world
|
| 82 |
+
observation["R_cam2world"] = R_cam2world
|
| 83 |
+
|
| 84 |
+
def look_at(eye, center, up, return_cam2world=True):
|
| 85 |
+
"""
|
| 86 |
+
Return camera pose looking at a given center point.
|
| 87 |
+
Analogous of gluLookAt function, using OpenCV camera convention.
|
| 88 |
+
"""
|
| 89 |
+
z = center - eye
|
| 90 |
+
z /= np.linalg.norm(z, axis=-1, keepdims=True)
|
| 91 |
+
y = -up
|
| 92 |
+
y = y - np.sum(y * z, axis=-1, keepdims=True) * z
|
| 93 |
+
y /= np.linalg.norm(y, axis=-1, keepdims=True)
|
| 94 |
+
x = np.cross(y, z, axis=-1)
|
| 95 |
+
|
| 96 |
+
if return_cam2world:
|
| 97 |
+
R = np.stack((x, y, z), axis=-1)
|
| 98 |
+
t = eye
|
| 99 |
+
else:
|
| 100 |
+
# World to camera transformation
|
| 101 |
+
# Transposed matrix
|
| 102 |
+
R = np.stack((x, y, z), axis=-2)
|
| 103 |
+
t = - np.einsum('...ij, ...j', R, eye)
|
| 104 |
+
return R, t
|
| 105 |
+
|
| 106 |
+
def look_at_for_habitat(eye, center, up, return_cam2world=True):
|
| 107 |
+
R, t = look_at(eye, center, up)
|
| 108 |
+
orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T)
|
| 109 |
+
return orientation, t
|
| 110 |
+
|
| 111 |
+
def generate_orientation_noise(pan_range, tilt_range, roll_range):
|
| 112 |
+
return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP)
|
| 113 |
+
* quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT)
|
| 114 |
+
* quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT))
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class NoNaviguableSpaceError(RuntimeError):
|
| 118 |
+
def __init__(self, *args):
|
| 119 |
+
super().__init__(*args)
|
| 120 |
+
|
| 121 |
+
class MultiviewHabitatSimGenerator:
|
| 122 |
+
def __init__(self,
|
| 123 |
+
scene,
|
| 124 |
+
navmesh,
|
| 125 |
+
scene_dataset_config_file,
|
| 126 |
+
resolution = (240, 320),
|
| 127 |
+
views_count=2,
|
| 128 |
+
hfov = 60,
|
| 129 |
+
gpu_id = 0,
|
| 130 |
+
size = 10000,
|
| 131 |
+
minimum_covisibility = 0.5,
|
| 132 |
+
transform = None):
|
| 133 |
+
self.scene = scene
|
| 134 |
+
self.navmesh = navmesh
|
| 135 |
+
self.scene_dataset_config_file = scene_dataset_config_file
|
| 136 |
+
self.resolution = resolution
|
| 137 |
+
self.views_count = views_count
|
| 138 |
+
assert(self.views_count >= 1)
|
| 139 |
+
self.hfov = hfov
|
| 140 |
+
self.gpu_id = gpu_id
|
| 141 |
+
self.size = size
|
| 142 |
+
self.transform = transform
|
| 143 |
+
|
| 144 |
+
# Noise added to camera orientation
|
| 145 |
+
self.pan_range = (-3, 3)
|
| 146 |
+
self.tilt_range = (-10, 10)
|
| 147 |
+
self.roll_range = (-5, 5)
|
| 148 |
+
|
| 149 |
+
# Height range to sample cameras
|
| 150 |
+
self.height_range = (1.2, 1.8)
|
| 151 |
+
|
| 152 |
+
# Random steps between the camera views
|
| 153 |
+
self.random_steps_count = 5
|
| 154 |
+
self.random_step_variance = 2.0
|
| 155 |
+
|
| 156 |
+
# Minimum fraction of the scene which should be valid (well defined depth)
|
| 157 |
+
self.minimum_valid_fraction = 0.7
|
| 158 |
+
|
| 159 |
+
# Distance threshold to see to select pairs
|
| 160 |
+
self.distance_threshold = 0.05
|
| 161 |
+
# Minimum IoU of a view point cloud with respect to the reference view to be kept.
|
| 162 |
+
self.minimum_covisibility = minimum_covisibility
|
| 163 |
+
|
| 164 |
+
# Maximum number of retries.
|
| 165 |
+
self.max_attempts_count = 100
|
| 166 |
+
|
| 167 |
+
self.seed = None
|
| 168 |
+
self._lazy_initialization()
|
| 169 |
+
|
| 170 |
+
def _lazy_initialization(self):
|
| 171 |
+
# Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
|
| 172 |
+
if self.seed == None:
|
| 173 |
+
# Re-seed numpy generator
|
| 174 |
+
np.random.seed()
|
| 175 |
+
self.seed = np.random.randint(2**32-1)
|
| 176 |
+
sim_cfg = habitat_sim.SimulatorConfiguration()
|
| 177 |
+
sim_cfg.scene_id = self.scene
|
| 178 |
+
if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
|
| 179 |
+
sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
|
| 180 |
+
sim_cfg.random_seed = self.seed
|
| 181 |
+
sim_cfg.load_semantic_mesh = False
|
| 182 |
+
sim_cfg.gpu_device_id = self.gpu_id
|
| 183 |
+
|
| 184 |
+
depth_sensor_spec = habitat_sim.CameraSensorSpec()
|
| 185 |
+
depth_sensor_spec.uuid = "depth"
|
| 186 |
+
depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
|
| 187 |
+
depth_sensor_spec.resolution = self.resolution
|
| 188 |
+
depth_sensor_spec.hfov = self.hfov
|
| 189 |
+
depth_sensor_spec.position = [0.0, 0.0, 0]
|
| 190 |
+
depth_sensor_spec.orientation
|
| 191 |
+
|
| 192 |
+
rgb_sensor_spec = habitat_sim.CameraSensorSpec()
|
| 193 |
+
rgb_sensor_spec.uuid = "color"
|
| 194 |
+
rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
|
| 195 |
+
rgb_sensor_spec.resolution = self.resolution
|
| 196 |
+
rgb_sensor_spec.hfov = self.hfov
|
| 197 |
+
rgb_sensor_spec.position = [0.0, 0.0, 0]
|
| 198 |
+
agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec])
|
| 199 |
+
|
| 200 |
+
cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
|
| 201 |
+
self.sim = habitat_sim.Simulator(cfg)
|
| 202 |
+
if self.navmesh is not None and self.navmesh != "":
|
| 203 |
+
# Use pre-computed navmesh when available (usually better than those generated automatically)
|
| 204 |
+
self.sim.pathfinder.load_nav_mesh(self.navmesh)
|
| 205 |
+
|
| 206 |
+
if not self.sim.pathfinder.is_loaded:
|
| 207 |
+
# Try to compute a navmesh
|
| 208 |
+
navmesh_settings = habitat_sim.NavMeshSettings()
|
| 209 |
+
navmesh_settings.set_defaults()
|
| 210 |
+
self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)
|
| 211 |
+
|
| 212 |
+
# Ensure that the navmesh is not empty
|
| 213 |
+
if not self.sim.pathfinder.is_loaded:
|
| 214 |
+
raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})")
|
| 215 |
+
|
| 216 |
+
self.agent = self.sim.initialize_agent(agent_id=0)
|
| 217 |
+
|
| 218 |
+
def close(self):
|
| 219 |
+
self.sim.close()
|
| 220 |
+
|
| 221 |
+
def __del__(self):
|
| 222 |
+
self.sim.close()
|
| 223 |
+
|
| 224 |
+
def __len__(self):
|
| 225 |
+
return self.size
|
| 226 |
+
|
| 227 |
+
def sample_random_viewpoint(self):
|
| 228 |
+
""" Sample a random viewpoint using the navmesh """
|
| 229 |
+
nav_point = self.sim.pathfinder.get_random_navigable_point()
|
| 230 |
+
|
| 231 |
+
# Sample a random viewpoint height
|
| 232 |
+
viewpoint_height = np.random.uniform(*self.height_range)
|
| 233 |
+
viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
|
| 234 |
+
viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
|
| 235 |
+
return viewpoint_position, viewpoint_orientation, nav_point
|
| 236 |
+
|
| 237 |
+
def sample_other_random_viewpoint(self, observed_point, nav_point):
|
| 238 |
+
""" Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point."""
|
| 239 |
+
other_nav_point = nav_point
|
| 240 |
+
|
| 241 |
+
walk_directions = self.random_step_variance * np.asarray([1,0,1])
|
| 242 |
+
for i in range(self.random_steps_count):
|
| 243 |
+
temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3))
|
| 244 |
+
# Snapping may return nan when it fails
|
| 245 |
+
if not np.isnan(temp[0]):
|
| 246 |
+
other_nav_point = temp
|
| 247 |
+
|
| 248 |
+
other_viewpoint_height = np.random.uniform(*self.height_range)
|
| 249 |
+
other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP
|
| 250 |
+
|
| 251 |
+
# Set viewing direction towards the central point
|
| 252 |
+
rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True)
|
| 253 |
+
rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
|
| 254 |
+
return position, rotation, other_nav_point
|
| 255 |
+
|
| 256 |
+
def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
|
| 257 |
+
""" Check if a viewpoint is valid and overlaps significantly with a reference one. """
|
| 258 |
+
# Observation
|
| 259 |
+
pixels_count = self.resolution[0] * self.resolution[1]
|
| 260 |
+
valid_fraction = len(other_pointcloud) / pixels_count
|
| 261 |
+
assert valid_fraction <= 1.0 and valid_fraction >= 0.0
|
| 262 |
+
overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True)
|
| 263 |
+
covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count)
|
| 264 |
+
is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility)
|
| 265 |
+
return is_valid, valid_fraction, covisibility
|
| 266 |
+
|
| 267 |
+
def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation):
|
| 268 |
+
""" Check if a viewpoint is valid and overlaps significantly with a reference one. """
|
| 269 |
+
# Observation
|
| 270 |
+
other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation)
|
| 271 |
+
return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
|
| 272 |
+
|
| 273 |
+
def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
|
| 274 |
+
agent_state = habitat_sim.AgentState()
|
| 275 |
+
agent_state.position = viewpoint_position
|
| 276 |
+
agent_state.rotation = viewpoint_orientation
|
| 277 |
+
self.agent.set_state(agent_state)
|
| 278 |
+
viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
|
| 279 |
+
_append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation)
|
| 280 |
+
return viewpoint_observations
|
| 281 |
+
|
| 282 |
+
def __getitem__(self, useless_idx):
|
| 283 |
+
ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
|
| 284 |
+
ref_observations = self.render_viewpoint(ref_position, ref_orientation)
|
| 285 |
+
# Extract point cloud
|
| 286 |
+
ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
|
| 287 |
+
camera_position=ref_position, camera_rotation=ref_orientation)
|
| 288 |
+
|
| 289 |
+
pixels_count = self.resolution[0] * self.resolution[1]
|
| 290 |
+
ref_valid_fraction = len(ref_pointcloud) / pixels_count
|
| 291 |
+
assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
|
| 292 |
+
if ref_valid_fraction < self.minimum_valid_fraction:
|
| 293 |
+
# This should produce a recursion error at some point when something is very wrong.
|
| 294 |
+
return self[0]
|
| 295 |
+
# Pick an reference observed point in the point cloud
|
| 296 |
+
observed_point = np.mean(ref_pointcloud, axis=0)
|
| 297 |
+
|
| 298 |
+
# Add the first image as reference
|
| 299 |
+
viewpoints_observations = [ref_observations]
|
| 300 |
+
viewpoints_covisibility = [ref_valid_fraction]
|
| 301 |
+
viewpoints_positions = [ref_position]
|
| 302 |
+
viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
|
| 303 |
+
viewpoints_clouds = [ref_pointcloud]
|
| 304 |
+
viewpoints_valid_fractions = [ref_valid_fraction]
|
| 305 |
+
|
| 306 |
+
for _ in range(self.views_count - 1):
|
| 307 |
+
# Generate an other viewpoint using some dummy random walk
|
| 308 |
+
successful_sampling = False
|
| 309 |
+
for sampling_attempt in range(self.max_attempts_count):
|
| 310 |
+
position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point)
|
| 311 |
+
# Observation
|
| 312 |
+
other_viewpoint_observations = self.render_viewpoint(position, rotation)
|
| 313 |
+
other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation)
|
| 314 |
+
|
| 315 |
+
is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
|
| 316 |
+
if is_valid:
|
| 317 |
+
successful_sampling = True
|
| 318 |
+
break
|
| 319 |
+
if not successful_sampling:
|
| 320 |
+
print("WARNING: Maximum number of attempts reached.")
|
| 321 |
+
# Dirty hack, try using a novel original viewpoint
|
| 322 |
+
return self[0]
|
| 323 |
+
viewpoints_observations.append(other_viewpoint_observations)
|
| 324 |
+
viewpoints_covisibility.append(covisibility)
|
| 325 |
+
viewpoints_positions.append(position)
|
| 326 |
+
viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding.
|
| 327 |
+
viewpoints_clouds.append(other_pointcloud)
|
| 328 |
+
viewpoints_valid_fractions.append(valid_fraction)
|
| 329 |
+
|
| 330 |
+
# Estimate relations between all pairs of images
|
| 331 |
+
pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations)))
|
| 332 |
+
for i in range(len(viewpoints_observations)):
|
| 333 |
+
pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i]
|
| 334 |
+
for j in range(i+1, len(viewpoints_observations)):
|
| 335 |
+
overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True)
|
| 336 |
+
pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count
|
| 337 |
+
pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count
|
| 338 |
+
|
| 339 |
+
# IoU is relative to the image 0
|
| 340 |
+
data = {"observations": viewpoints_observations,
|
| 341 |
+
"positions": np.asarray(viewpoints_positions),
|
| 342 |
+
"orientations": np.asarray(viewpoints_orientations),
|
| 343 |
+
"covisibility_ratios": np.asarray(viewpoints_covisibility),
|
| 344 |
+
"valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
|
| 345 |
+
"pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float),
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
if self.transform is not None:
|
| 349 |
+
data = self.transform(data)
|
| 350 |
+
return data
|
| 351 |
+
|
| 352 |
+
def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False):
|
| 353 |
+
"""
|
| 354 |
+
Return a list of images corresponding to a spiral trajectory from a random starting point.
|
| 355 |
+
Useful to generate nice visualisations.
|
| 356 |
+
Use an even number of half turns to get a nice "C1-continuous" loop effect
|
| 357 |
+
"""
|
| 358 |
+
ref_position, ref_orientation, navpoint = self.sample_random_viewpoint()
|
| 359 |
+
ref_observations = self.render_viewpoint(ref_position, ref_orientation)
|
| 360 |
+
ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
|
| 361 |
+
camera_position=ref_position, camera_rotation=ref_orientation)
|
| 362 |
+
pixels_count = self.resolution[0] * self.resolution[1]
|
| 363 |
+
if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
|
| 364 |
+
# Dirty hack: ensure that the valid part of the image is significant
|
| 365 |
+
return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation)
|
| 366 |
+
|
| 367 |
+
# Pick an observed point in the point cloud
|
| 368 |
+
observed_point = np.mean(ref_pointcloud, axis=0)
|
| 369 |
+
ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation)
|
| 370 |
+
|
| 371 |
+
images = []
|
| 372 |
+
is_valid = []
|
| 373 |
+
# Spiral trajectory, use_constant orientation
|
| 374 |
+
for i, alpha in enumerate(np.linspace(0, 1, images_count)):
|
| 375 |
+
r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius
|
| 376 |
+
theta = alpha * half_turns * np.pi
|
| 377 |
+
x = r * np.cos(theta)
|
| 378 |
+
y = r * np.sin(theta)
|
| 379 |
+
z = 0.0
|
| 380 |
+
position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten()
|
| 381 |
+
if use_constant_orientation:
|
| 382 |
+
orientation = ref_orientation
|
| 383 |
+
else:
|
| 384 |
+
# trajectory looking at a mean point in front of the ref observation
|
| 385 |
+
orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP)
|
| 386 |
+
observations = self.render_viewpoint(position, orientation)
|
| 387 |
+
images.append(observations['color'][...,:3])
|
| 388 |
+
_is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation)
|
| 389 |
+
is_valid.append(_is_valid)
|
| 390 |
+
return images, np.all(is_valid)
|
dust3r/croco/datasets_croco/habitat_sim/pack_metadata_files.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
"""
|
| 4 |
+
Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import glob
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import shutil
|
| 10 |
+
import json
|
| 11 |
+
from datasets.habitat_sim.paths import *
|
| 12 |
+
import argparse
|
| 13 |
+
import collections
|
| 14 |
+
|
| 15 |
+
if __name__ == "__main__":
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument("input_dir")
|
| 18 |
+
parser.add_argument("output_dir")
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
|
| 21 |
+
input_dirname = args.input_dir
|
| 22 |
+
output_dirname = args.output_dir
|
| 23 |
+
|
| 24 |
+
input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True)
|
| 25 |
+
|
| 26 |
+
images_count = collections.defaultdict(lambda : 0)
|
| 27 |
+
|
| 28 |
+
os.makedirs(output_dirname)
|
| 29 |
+
for input_filename in tqdm(input_metadata_filenames):
|
| 30 |
+
# Ignore empty files
|
| 31 |
+
with open(input_filename, "r") as f:
|
| 32 |
+
original_metadata = json.load(f)
|
| 33 |
+
if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0:
|
| 34 |
+
print("No views in", input_filename)
|
| 35 |
+
continue
|
| 36 |
+
|
| 37 |
+
relpath = os.path.relpath(input_filename, input_dirname)
|
| 38 |
+
print(relpath)
|
| 39 |
+
|
| 40 |
+
# Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
|
| 41 |
+
# Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern.
|
| 42 |
+
scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True))
|
| 43 |
+
metadata = dict()
|
| 44 |
+
for key, value in original_metadata.items():
|
| 45 |
+
if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
|
| 46 |
+
known_path = False
|
| 47 |
+
for dataset, dataset_path in scenes_dataset_paths.items():
|
| 48 |
+
if value.startswith(dataset_path):
|
| 49 |
+
value = os.path.join(dataset, os.path.relpath(value, dataset_path))
|
| 50 |
+
known_path = True
|
| 51 |
+
break
|
| 52 |
+
if not known_path:
|
| 53 |
+
raise KeyError("Unknown path:" + value)
|
| 54 |
+
metadata[key] = value
|
| 55 |
+
|
| 56 |
+
# Compile some general statistics while packing data
|
| 57 |
+
scene_split = metadata["scene"].split("/")
|
| 58 |
+
upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
|
| 59 |
+
images_count[upper_level] += len(metadata["multiviews"])
|
| 60 |
+
|
| 61 |
+
output_filename = os.path.join(output_dirname, relpath)
|
| 62 |
+
os.makedirs(os.path.dirname(output_filename), exist_ok=True)
|
| 63 |
+
with open(output_filename, "w") as f:
|
| 64 |
+
json.dump(metadata, f)
|
| 65 |
+
|
| 66 |
+
# Print statistics
|
| 67 |
+
print("Images count:")
|
| 68 |
+
for upper_level, count in images_count.items():
|
| 69 |
+
print(f"- {upper_level}: {count}")
|
dust3r/croco/datasets_croco/habitat_sim/paths.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Paths to Habitat-Sim scenes
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import collections
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Hardcoded path to the different scene datasets
|
| 15 |
+
SCENES_DATASET = {
|
| 16 |
+
"hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
|
| 17 |
+
"gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
|
| 18 |
+
"habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
|
| 19 |
+
"replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
|
| 20 |
+
"replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
|
| 21 |
+
"replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
|
| 22 |
+
"scannet": "./data/habitat-sim/scene_datasets/scannet/"
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"])
|
| 26 |
+
|
| 27 |
+
def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
|
| 28 |
+
scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json")
|
| 29 |
+
scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
|
| 30 |
+
navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
|
| 31 |
+
scenes_data = []
|
| 32 |
+
for idx in range(len(scenes)):
|
| 33 |
+
output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
|
| 34 |
+
# Add scene
|
| 35 |
+
data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
|
| 36 |
+
scene = scenes[idx] + ".scene_instance.json",
|
| 37 |
+
navmesh = os.path.join(base_path, navmeshes[idx]),
|
| 38 |
+
output_dir = output_dir)
|
| 39 |
+
scenes_data.append(data)
|
| 40 |
+
return scenes_data
|
| 41 |
+
|
| 42 |
+
def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]):
|
| 43 |
+
scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json")
|
| 44 |
+
scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [])
|
| 45 |
+
navmeshes = ""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
|
| 46 |
+
scenes_data = []
|
| 47 |
+
for idx in range(len(scenes)):
|
| 48 |
+
output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx])
|
| 49 |
+
data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
|
| 50 |
+
scene = scenes[idx],
|
| 51 |
+
navmesh = "",
|
| 52 |
+
output_dir = output_dir)
|
| 53 |
+
scenes_data.append(data)
|
| 54 |
+
return scenes_data
|
| 55 |
+
|
| 56 |
+
def list_replica_scenes(base_output_dir, base_path):
|
| 57 |
+
scenes_data = []
|
| 58 |
+
for scene_id in os.listdir(base_path):
|
| 59 |
+
scene = os.path.join(base_path, scene_id, "mesh.ply")
|
| 60 |
+
navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it
|
| 61 |
+
scene_dataset_config_file = ""
|
| 62 |
+
output_dir = os.path.join(base_output_dir, scene_id)
|
| 63 |
+
# Add scene only if it does not exist already, or if exist_ok
|
| 64 |
+
data = SceneData(scene_dataset_config_file = scene_dataset_config_file,
|
| 65 |
+
scene = scene,
|
| 66 |
+
navmesh = navmesh,
|
| 67 |
+
output_dir = output_dir)
|
| 68 |
+
scenes_data.append(data)
|
| 69 |
+
return scenes_data
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def list_scenes(base_output_dir, base_path):
|
| 73 |
+
"""
|
| 74 |
+
Generic method iterating through a base_path folder to find scenes.
|
| 75 |
+
"""
|
| 76 |
+
scenes_data = []
|
| 77 |
+
for root, dirs, files in os.walk(base_path, followlinks=True):
|
| 78 |
+
folder_scenes_data = []
|
| 79 |
+
for file in files:
|
| 80 |
+
name, ext = os.path.splitext(file)
|
| 81 |
+
if ext == ".glb":
|
| 82 |
+
scene = os.path.join(root, name + ".glb")
|
| 83 |
+
navmesh = os.path.join(root, name + ".navmesh")
|
| 84 |
+
if not os.path.exists(navmesh):
|
| 85 |
+
navmesh = ""
|
| 86 |
+
relpath = os.path.relpath(root, base_path)
|
| 87 |
+
output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name))
|
| 88 |
+
data = SceneData(scene_dataset_config_file="",
|
| 89 |
+
scene = scene,
|
| 90 |
+
navmesh = navmesh,
|
| 91 |
+
output_dir = output_dir)
|
| 92 |
+
folder_scenes_data.append(data)
|
| 93 |
+
|
| 94 |
+
# Specific check for HM3D:
|
| 95 |
+
# When two meshesxxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
|
| 96 |
+
basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")]
|
| 97 |
+
if len(basis_scenes) != 0:
|
| 98 |
+
folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)]
|
| 99 |
+
|
| 100 |
+
scenes_data.extend(folder_scenes_data)
|
| 101 |
+
return scenes_data
|
| 102 |
+
|
| 103 |
+
def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
|
| 104 |
+
scenes_data = []
|
| 105 |
+
|
| 106 |
+
# HM3D
|
| 107 |
+
for split in ("minival", "train", "val", "examples"):
|
| 108 |
+
scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
|
| 109 |
+
base_path=f"{scenes_dataset_paths['hm3d']}/{split}")
|
| 110 |
+
|
| 111 |
+
# Gibson
|
| 112 |
+
scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"),
|
| 113 |
+
base_path=scenes_dataset_paths["gibson"])
|
| 114 |
+
|
| 115 |
+
# Habitat test scenes (just a few)
|
| 116 |
+
scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
|
| 117 |
+
base_path=scenes_dataset_paths["habitat-test-scenes"])
|
| 118 |
+
|
| 119 |
+
# ReplicaCAD (baked lightning)
|
| 120 |
+
scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir)
|
| 121 |
+
|
| 122 |
+
# ScanNet
|
| 123 |
+
scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"),
|
| 124 |
+
base_path=scenes_dataset_paths["scannet"])
|
| 125 |
+
|
| 126 |
+
# Replica
|
| 127 |
+
list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"),
|
| 128 |
+
base_path=scenes_dataset_paths["replica"])
|
| 129 |
+
return scenes_data
|
dust3r/croco/datasets_croco/pairs_dataset.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from datasets.transforms import get_pair_transforms
|
| 9 |
+
|
| 10 |
+
def load_image(impath):
|
| 11 |
+
return Image.open(impath)
|
| 12 |
+
|
| 13 |
+
def load_pairs_from_cache_file(fname, root=''):
|
| 14 |
+
assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
|
| 15 |
+
with open(fname, 'r') as fid:
|
| 16 |
+
lines = fid.read().strip().splitlines()
|
| 17 |
+
pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines]
|
| 18 |
+
return pairs
|
| 19 |
+
|
| 20 |
+
def load_pairs_from_list_file(fname, root=''):
|
| 21 |
+
assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
|
| 22 |
+
with open(fname, 'r') as fid:
|
| 23 |
+
lines = fid.read().strip().splitlines()
|
| 24 |
+
pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')]
|
| 25 |
+
return pairs
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def write_cache_file(fname, pairs, root=''):
|
| 29 |
+
if len(root)>0:
|
| 30 |
+
if not root.endswith('/'): root+='/'
|
| 31 |
+
assert os.path.isdir(root)
|
| 32 |
+
s = ''
|
| 33 |
+
for im1, im2 in pairs:
|
| 34 |
+
if len(root)>0:
|
| 35 |
+
assert im1.startswith(root), im1
|
| 36 |
+
assert im2.startswith(root), im2
|
| 37 |
+
s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):])
|
| 38 |
+
with open(fname, 'w') as fid:
|
| 39 |
+
fid.write(s[:-1])
|
| 40 |
+
|
| 41 |
+
def parse_and_cache_all_pairs(dname, data_dir='./data/'):
|
| 42 |
+
if dname=='habitat_release':
|
| 43 |
+
dirname = os.path.join(data_dir, 'habitat_release')
|
| 44 |
+
assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
|
| 45 |
+
cache_file = os.path.join(dirname, 'pairs.txt')
|
| 46 |
+
assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file
|
| 47 |
+
|
| 48 |
+
print('Parsing pairs for dataset: '+dname)
|
| 49 |
+
pairs = []
|
| 50 |
+
for root, dirs, files in os.walk(dirname):
|
| 51 |
+
if 'val' in root: continue
|
| 52 |
+
dirs.sort()
|
| 53 |
+
pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')]
|
| 54 |
+
print('Found {:,} pairs'.format(len(pairs)))
|
| 55 |
+
print('Writing cache to: '+cache_file)
|
| 56 |
+
write_cache_file(cache_file, pairs, root=dirname)
|
| 57 |
+
|
| 58 |
+
else:
|
| 59 |
+
raise NotImplementedError('Unknown dataset: '+dname)
|
| 60 |
+
|
| 61 |
+
def dnames_to_image_pairs(dnames, data_dir='./data/'):
|
| 62 |
+
"""
|
| 63 |
+
dnames: list of datasets with image pairs, separated by +
|
| 64 |
+
"""
|
| 65 |
+
all_pairs = []
|
| 66 |
+
for dname in dnames.split('+'):
|
| 67 |
+
if dname=='habitat_release':
|
| 68 |
+
dirname = os.path.join(data_dir, 'habitat_release')
|
| 69 |
+
assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
|
| 70 |
+
cache_file = os.path.join(dirname, 'pairs.txt')
|
| 71 |
+
assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file
|
| 72 |
+
pairs = load_pairs_from_cache_file(cache_file, root=dirname)
|
| 73 |
+
elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']:
|
| 74 |
+
dirname = os.path.join(data_dir, dname+'_crops')
|
| 75 |
+
assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
|
| 76 |
+
list_file = os.path.join(dirname, 'listing.txt')
|
| 77 |
+
assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file)
|
| 78 |
+
pairs = load_pairs_from_list_file(list_file, root=dirname)
|
| 79 |
+
print(' {:s}: {:,} pairs'.format(dname, len(pairs)))
|
| 80 |
+
all_pairs += pairs
|
| 81 |
+
if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs)))
|
| 82 |
+
return all_pairs
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class PairsDataset(Dataset):
|
| 86 |
+
|
| 87 |
+
def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'):
|
| 88 |
+
super().__init__()
|
| 89 |
+
self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
|
| 90 |
+
self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize)
|
| 91 |
+
|
| 92 |
+
def __len__(self):
|
| 93 |
+
return len(self.image_pairs)
|
| 94 |
+
|
| 95 |
+
def __getitem__(self, index):
|
| 96 |
+
im1path, im2path = self.image_pairs[index]
|
| 97 |
+
im1 = load_image(im1path)
|
| 98 |
+
im2 = load_image(im2path)
|
| 99 |
+
if self.transforms is not None: im1, im2 = self.transforms(im1, im2)
|
| 100 |
+
return im1, im2
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__=="__main__":
|
| 104 |
+
import argparse
|
| 105 |
+
parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset")
|
| 106 |
+
parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
|
| 107 |
+
parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset")
|
| 108 |
+
args = parser.parse_args()
|
| 109 |
+
parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
|
dust3r/croco/datasets_croco/transforms.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision.transforms
|
| 6 |
+
import torchvision.transforms.functional as F
|
| 7 |
+
|
| 8 |
+
# "Pair": apply a transform on a pair
|
| 9 |
+
# "Both": apply the exact same transform to both images
|
| 10 |
+
|
| 11 |
+
class ComposePair(torchvision.transforms.Compose):
|
| 12 |
+
def __call__(self, img1, img2):
|
| 13 |
+
for t in self.transforms:
|
| 14 |
+
img1, img2 = t(img1, img2)
|
| 15 |
+
return img1, img2
|
| 16 |
+
|
| 17 |
+
class NormalizeBoth(torchvision.transforms.Normalize):
|
| 18 |
+
def forward(self, img1, img2):
|
| 19 |
+
img1 = super().forward(img1)
|
| 20 |
+
img2 = super().forward(img2)
|
| 21 |
+
return img1, img2
|
| 22 |
+
|
| 23 |
+
class ToTensorBoth(torchvision.transforms.ToTensor):
|
| 24 |
+
def __call__(self, img1, img2):
|
| 25 |
+
img1 = super().__call__(img1)
|
| 26 |
+
img2 = super().__call__(img2)
|
| 27 |
+
return img1, img2
|
| 28 |
+
|
| 29 |
+
class RandomCropPair(torchvision.transforms.RandomCrop):
|
| 30 |
+
# the crop will be intentionally different for the two images with this class
|
| 31 |
+
def forward(self, img1, img2):
|
| 32 |
+
img1 = super().forward(img1)
|
| 33 |
+
img2 = super().forward(img2)
|
| 34 |
+
return img1, img2
|
| 35 |
+
|
| 36 |
+
class ColorJitterPair(torchvision.transforms.ColorJitter):
|
| 37 |
+
# can be symmetric (same for both images) or assymetric (different jitter params for each image) depending on assymetric_prob
|
| 38 |
+
def __init__(self, assymetric_prob, **kwargs):
|
| 39 |
+
super().__init__(**kwargs)
|
| 40 |
+
self.assymetric_prob = assymetric_prob
|
| 41 |
+
def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor):
|
| 42 |
+
for fn_id in fn_idx:
|
| 43 |
+
if fn_id == 0 and brightness_factor is not None:
|
| 44 |
+
img = F.adjust_brightness(img, brightness_factor)
|
| 45 |
+
elif fn_id == 1 and contrast_factor is not None:
|
| 46 |
+
img = F.adjust_contrast(img, contrast_factor)
|
| 47 |
+
elif fn_id == 2 and saturation_factor is not None:
|
| 48 |
+
img = F.adjust_saturation(img, saturation_factor)
|
| 49 |
+
elif fn_id == 3 and hue_factor is not None:
|
| 50 |
+
img = F.adjust_hue(img, hue_factor)
|
| 51 |
+
return img
|
| 52 |
+
|
| 53 |
+
def forward(self, img1, img2):
|
| 54 |
+
|
| 55 |
+
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
|
| 56 |
+
self.brightness, self.contrast, self.saturation, self.hue
|
| 57 |
+
)
|
| 58 |
+
img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
|
| 59 |
+
if torch.rand(1) < self.assymetric_prob: # assymetric:
|
| 60 |
+
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
|
| 61 |
+
self.brightness, self.contrast, self.saturation, self.hue
|
| 62 |
+
)
|
| 63 |
+
img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
|
| 64 |
+
return img1, img2
|
| 65 |
+
|
| 66 |
+
def get_pair_transforms(transform_str, totensor=True, normalize=True):
|
| 67 |
+
# transform_str is eg crop224+color
|
| 68 |
+
trfs = []
|
| 69 |
+
for s in transform_str.split('+'):
|
| 70 |
+
if s.startswith('crop'):
|
| 71 |
+
size = int(s[len('crop'):])
|
| 72 |
+
trfs.append(RandomCropPair(size))
|
| 73 |
+
elif s=='acolor':
|
| 74 |
+
trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0))
|
| 75 |
+
elif s=='': # if transform_str was ""
|
| 76 |
+
pass
|
| 77 |
+
else:
|
| 78 |
+
raise NotImplementedError('Unknown augmentation: '+s)
|
| 79 |
+
|
| 80 |
+
if totensor:
|
| 81 |
+
trfs.append( ToTensorBoth() )
|
| 82 |
+
if normalize:
|
| 83 |
+
trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) )
|
| 84 |
+
|
| 85 |
+
if len(trfs)==0:
|
| 86 |
+
return None
|
| 87 |
+
elif len(trfs)==1:
|
| 88 |
+
return trfs
|
| 89 |
+
else:
|
| 90 |
+
return ComposePair(trfs)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
dust3r/croco/demo.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from models.croco import CroCoNet
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import torchvision.transforms
|
| 8 |
+
from torchvision.transforms import ToTensor, Normalize, Compose
|
| 9 |
+
|
| 10 |
+
def main():
|
| 11 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu')
|
| 12 |
+
|
| 13 |
+
# load 224x224 images and transform them to tensor
|
| 14 |
+
imagenet_mean = [0.485, 0.456, 0.406]
|
| 15 |
+
imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True)
|
| 16 |
+
imagenet_std = [0.229, 0.224, 0.225]
|
| 17 |
+
imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True)
|
| 18 |
+
trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)])
|
| 19 |
+
image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
|
| 20 |
+
image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
|
| 21 |
+
|
| 22 |
+
# load model
|
| 23 |
+
ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')
|
| 24 |
+
model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device)
|
| 25 |
+
model.eval()
|
| 26 |
+
msg = model.load_state_dict(ckpt['model'], strict=True)
|
| 27 |
+
|
| 28 |
+
# forward
|
| 29 |
+
with torch.inference_mode():
|
| 30 |
+
out, mask, target = model(image1, image2)
|
| 31 |
+
|
| 32 |
+
# the output is normalized, thus use the mean/std of the actual image to go back to RGB space
|
| 33 |
+
patchified = model.patchify(image1)
|
| 34 |
+
mean = patchified.mean(dim=-1, keepdim=True)
|
| 35 |
+
var = patchified.var(dim=-1, keepdim=True)
|
| 36 |
+
decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean)
|
| 37 |
+
# undo imagenet normalization, prepare masked image
|
| 38 |
+
decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor
|
| 39 |
+
input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor
|
| 40 |
+
ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor
|
| 41 |
+
image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])
|
| 42 |
+
masked_input_image = ((1 - image_masks) * input_image)
|
| 43 |
+
|
| 44 |
+
# make visualization
|
| 45 |
+
visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4
|
| 46 |
+
B, C, H, W = visualization.shape
|
| 47 |
+
visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W)
|
| 48 |
+
visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1))
|
| 49 |
+
fname = "demo_output.png"
|
| 50 |
+
visualization.save(fname)
|
| 51 |
+
print('Visualization save in '+fname)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__=="__main__":
|
| 55 |
+
main()
|
dust3r/croco/interactive_demo.ipynb
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Interactive demo of Cross-view Completion."
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "code",
|
| 12 |
+
"execution_count": null,
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [],
|
| 15 |
+
"source": [
|
| 16 |
+
"# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
|
| 17 |
+
"# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "code",
|
| 22 |
+
"execution_count": null,
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [],
|
| 25 |
+
"source": [
|
| 26 |
+
"import torch\n",
|
| 27 |
+
"import numpy as np\n",
|
| 28 |
+
"from models.croco import CroCoNet\n",
|
| 29 |
+
"from ipywidgets import interact, interactive, fixed, interact_manual\n",
|
| 30 |
+
"import ipywidgets as widgets\n",
|
| 31 |
+
"import matplotlib.pyplot as plt\n",
|
| 32 |
+
"import quaternion\n",
|
| 33 |
+
"import models.masking"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "markdown",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"source": [
|
| 40 |
+
"### Load CroCo model"
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "code",
|
| 45 |
+
"execution_count": null,
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n",
|
| 50 |
+
"model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n",
|
| 51 |
+
"msg = model.load_state_dict(ckpt['model'], strict=True)\n",
|
| 52 |
+
"use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
|
| 53 |
+
"device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
|
| 54 |
+
"model = model.eval()\n",
|
| 55 |
+
"model = model.to(device=device)\n",
|
| 56 |
+
"print(msg)\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n",
|
| 59 |
+
" \"\"\"\n",
|
| 60 |
+
" Perform Cross-View completion using two input images, specified using Numpy arrays.\n",
|
| 61 |
+
" \"\"\"\n",
|
| 62 |
+
" # Replace the mask generator\n",
|
| 63 |
+
" model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n",
|
| 64 |
+
"\n",
|
| 65 |
+
" # ImageNet-1k color normalization\n",
|
| 66 |
+
" imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n",
|
| 67 |
+
" imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n",
|
| 68 |
+
"\n",
|
| 69 |
+
" normalize_input_colors = True\n",
|
| 70 |
+
" is_output_normalized = True\n",
|
| 71 |
+
" with torch.no_grad():\n",
|
| 72 |
+
" # Cast data to torch\n",
|
| 73 |
+
" target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
|
| 74 |
+
" ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
|
| 75 |
+
"\n",
|
| 76 |
+
" if normalize_input_colors:\n",
|
| 77 |
+
" ref_image = (ref_image - imagenet_mean) / imagenet_std\n",
|
| 78 |
+
" target_image = (target_image - imagenet_mean) / imagenet_std\n",
|
| 79 |
+
"\n",
|
| 80 |
+
" out, mask, _ = model(target_image, ref_image)\n",
|
| 81 |
+
" # # get target\n",
|
| 82 |
+
" if not is_output_normalized:\n",
|
| 83 |
+
" predicted_image = model.unpatchify(out)\n",
|
| 84 |
+
" else:\n",
|
| 85 |
+
" # The output only contains higher order information,\n",
|
| 86 |
+
" # we retrieve mean and standard deviation from the actual target image\n",
|
| 87 |
+
" patchified = model.patchify(target_image)\n",
|
| 88 |
+
" mean = patchified.mean(dim=-1, keepdim=True)\n",
|
| 89 |
+
" var = patchified.var(dim=-1, keepdim=True)\n",
|
| 90 |
+
" pred_renorm = out * (var + 1.e-6)**.5 + mean\n",
|
| 91 |
+
" predicted_image = model.unpatchify(pred_renorm)\n",
|
| 92 |
+
"\n",
|
| 93 |
+
" image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n",
|
| 94 |
+
" masked_target_image = (1 - image_masks) * target_image\n",
|
| 95 |
+
" \n",
|
| 96 |
+
" if not reconstruct_unmasked_patches:\n",
|
| 97 |
+
" # Replace unmasked patches by their actual values\n",
|
| 98 |
+
" predicted_image = predicted_image * image_masks + masked_target_image\n",
|
| 99 |
+
"\n",
|
| 100 |
+
" # Unapply color normalization\n",
|
| 101 |
+
" if normalize_input_colors:\n",
|
| 102 |
+
" predicted_image = predicted_image * imagenet_std + imagenet_mean\n",
|
| 103 |
+
" masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n",
|
| 104 |
+
" \n",
|
| 105 |
+
" # Cast to Numpy\n",
|
| 106 |
+
" masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
|
| 107 |
+
" predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
|
| 108 |
+
" return masked_target_image, predicted_image"
|
| 109 |
+
]
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"cell_type": "markdown",
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"source": [
|
| 115 |
+
"### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)"
|
| 116 |
+
]
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"cell_type": "code",
|
| 120 |
+
"execution_count": null,
|
| 121 |
+
"metadata": {},
|
| 122 |
+
"outputs": [],
|
| 123 |
+
"source": [
|
| 124 |
+
"import os\n",
|
| 125 |
+
"os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n",
|
| 126 |
+
"os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n",
|
| 127 |
+
"import habitat_sim\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n",
|
| 130 |
+
"navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"sim_cfg = habitat_sim.SimulatorConfiguration()\n",
|
| 133 |
+
"if use_gpu: sim_cfg.gpu_device_id = 0\n",
|
| 134 |
+
"sim_cfg.scene_id = scene\n",
|
| 135 |
+
"sim_cfg.load_semantic_mesh = False\n",
|
| 136 |
+
"rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n",
|
| 137 |
+
"rgb_sensor_spec.uuid = \"color\"\n",
|
| 138 |
+
"rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n",
|
| 139 |
+
"rgb_sensor_spec.resolution = (224,224)\n",
|
| 140 |
+
"rgb_sensor_spec.hfov = 56.56\n",
|
| 141 |
+
"rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n",
|
| 142 |
+
"rgb_sensor_spec.orientation = [0, 0, 0]\n",
|
| 143 |
+
"agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n",
|
| 147 |
+
"sim = habitat_sim.Simulator(cfg)\n",
|
| 148 |
+
"if navmesh is not None:\n",
|
| 149 |
+
" sim.pathfinder.load_nav_mesh(navmesh)\n",
|
| 150 |
+
"agent = sim.initialize_agent(agent_id=0)\n",
|
| 151 |
+
"\n",
|
| 152 |
+
"def sample_random_viewpoint():\n",
|
| 153 |
+
" \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n",
|
| 154 |
+
" nav_point = sim.pathfinder.get_random_navigable_point()\n",
|
| 155 |
+
" # Sample a random viewpoint height\n",
|
| 156 |
+
" viewpoint_height = np.random.uniform(1.0, 1.6)\n",
|
| 157 |
+
" viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n",
|
| 158 |
+
" viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n",
|
| 159 |
+
" return viewpoint_position, viewpoint_orientation\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"def render_viewpoint(position, orientation):\n",
|
| 162 |
+
" agent_state = habitat_sim.AgentState()\n",
|
| 163 |
+
" agent_state.position = position\n",
|
| 164 |
+
" agent_state.rotation = orientation\n",
|
| 165 |
+
" agent.set_state(agent_state)\n",
|
| 166 |
+
" viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n",
|
| 167 |
+
" image = viewpoint_observations['color'][:,:,:3]\n",
|
| 168 |
+
" image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n",
|
| 169 |
+
" return image"
|
| 170 |
+
]
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"cell_type": "markdown",
|
| 174 |
+
"metadata": {},
|
| 175 |
+
"source": [
|
| 176 |
+
"### Sample a random reference view"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"cell_type": "code",
|
| 181 |
+
"execution_count": null,
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"outputs": [],
|
| 184 |
+
"source": [
|
| 185 |
+
"ref_position, ref_orientation = sample_random_viewpoint()\n",
|
| 186 |
+
"ref_image = render_viewpoint(ref_position, ref_orientation)\n",
|
| 187 |
+
"plt.clf()\n",
|
| 188 |
+
"fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n",
|
| 189 |
+
"axes[0,0].imshow(ref_image)\n",
|
| 190 |
+
"for ax in axes.flatten():\n",
|
| 191 |
+
" ax.set_xticks([])\n",
|
| 192 |
+
" ax.set_yticks([])"
|
| 193 |
+
]
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"cell_type": "markdown",
|
| 197 |
+
"metadata": {},
|
| 198 |
+
"source": [
|
| 199 |
+
"### Interactive cross-view completion using CroCo"
|
| 200 |
+
]
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"cell_type": "code",
|
| 204 |
+
"execution_count": null,
|
| 205 |
+
"metadata": {},
|
| 206 |
+
"outputs": [],
|
| 207 |
+
"source": [
|
| 208 |
+
"reconstruct_unmasked_patches = False\n",
|
| 209 |
+
"\n",
|
| 210 |
+
"def show_demo(masking_ratio, x, y, z, panorama, elevation):\n",
|
| 211 |
+
" R = quaternion.as_rotation_matrix(ref_orientation)\n",
|
| 212 |
+
" target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n",
|
| 213 |
+
" target_orientation = (ref_orientation\n",
|
| 214 |
+
" * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n",
|
| 215 |
+
" * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n",
|
| 216 |
+
" \n",
|
| 217 |
+
" ref_image = render_viewpoint(ref_position, ref_orientation)\n",
|
| 218 |
+
" target_image = render_viewpoint(target_position, target_orientation)\n",
|
| 219 |
+
"\n",
|
| 220 |
+
" masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n",
|
| 221 |
+
"\n",
|
| 222 |
+
" fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n",
|
| 223 |
+
" axes[0].imshow(ref_image)\n",
|
| 224 |
+
" axes[0].set_xlabel(\"Reference\")\n",
|
| 225 |
+
" axes[1].imshow(masked_target_image)\n",
|
| 226 |
+
" axes[1].set_xlabel(\"Masked target\")\n",
|
| 227 |
+
" axes[2].imshow(predicted_image)\n",
|
| 228 |
+
" axes[2].set_xlabel(\"Reconstruction\") \n",
|
| 229 |
+
" axes[3].imshow(target_image)\n",
|
| 230 |
+
" axes[3].set_xlabel(\"Target\")\n",
|
| 231 |
+
" for ax in axes.flatten():\n",
|
| 232 |
+
" ax.set_xticks([])\n",
|
| 233 |
+
" ax.set_yticks([])\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"interact(show_demo,\n",
|
| 236 |
+
" masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n",
|
| 237 |
+
" x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
|
| 238 |
+
" y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
|
| 239 |
+
" z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
|
| 240 |
+
" panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n",
|
| 241 |
+
" elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));"
|
| 242 |
+
]
|
| 243 |
+
}
|
| 244 |
+
],
|
| 245 |
+
"metadata": {
|
| 246 |
+
"kernelspec": {
|
| 247 |
+
"display_name": "Python 3 (ipykernel)",
|
| 248 |
+
"language": "python",
|
| 249 |
+
"name": "python3"
|
| 250 |
+
},
|
| 251 |
+
"language_info": {
|
| 252 |
+
"codemirror_mode": {
|
| 253 |
+
"name": "ipython",
|
| 254 |
+
"version": 3
|
| 255 |
+
},
|
| 256 |
+
"file_extension": ".py",
|
| 257 |
+
"mimetype": "text/x-python",
|
| 258 |
+
"name": "python",
|
| 259 |
+
"nbconvert_exporter": "python",
|
| 260 |
+
"pygments_lexer": "ipython3",
|
| 261 |
+
"version": "3.7.13"
|
| 262 |
+
},
|
| 263 |
+
"vscode": {
|
| 264 |
+
"interpreter": {
|
| 265 |
+
"hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67"
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
},
|
| 269 |
+
"nbformat": 4,
|
| 270 |
+
"nbformat_minor": 2
|
| 271 |
+
}
|