Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +75 -0
- .gitignore +132 -0
- .gitmodules +0 -0
- .gradio/certificate.pem +31 -0
- LICENSE +22 -0
- LICENSE-VGGT +115 -0
- README.md +40 -8
- check_model_size.py +85 -0
- configs/config.yaml +50 -0
- configs/model/dpm.yaml +3 -0
- configs/visualise.yaml +13 -0
- dpm/aggregator.py +366 -0
- dpm/decoder.py +416 -0
- dpm/model.py +149 -0
- examples/videos/camel.mp4 +3 -0
- examples/videos/car.mp4 +3 -0
- examples/videos/figure1.mp4 +3 -0
- examples/videos/figure2.mp4 +3 -0
- examples/videos/figure3.mp4 +3 -0
- examples/videos/goldfish.mp4 +3 -0
- examples/videos/horse.mp4 +3 -0
- examples/videos/paragliding.mp4 +3 -0
- examples/videos/pstudio.mp4 +3 -0
- examples/videos/stroller.mp4 +3 -0
- examples/videos/swing.mp4 +3 -0
- examples/videos/tennis.mp4 +3 -0
- examples/videos/tesla.mp4 +3 -0
- gradio_demo.py +981 -0
- input_images_20260127_052216_587020/images/000000.png +3 -0
- input_images_20260127_052216_587020/images/000001.png +3 -0
- input_images_20260127_052216_587020/images/000002.png +3 -0
- input_images_20260127_052216_587020/images/000003.png +3 -0
- input_images_20260127_052216_587020/images/000004.png +3 -0
- input_images_20260127_052216_587020/images/000005.png +3 -0
- input_images_20260127_052216_587020/images/000006.png +3 -0
- input_images_20260127_052216_587020/images/000007.png +3 -0
- input_images_20260127_052216_587020/images/000008.png +3 -0
- input_images_20260127_052216_587020/images/000009.png +3 -0
- input_images_20260127_052216_587020/images/000010.png +3 -0
- input_images_20260127_052216_587020/images/000011.png +3 -0
- input_images_20260127_052216_587020/images/000012.png +3 -0
- input_images_20260127_052216_587020/images/000013.png +3 -0
- input_images_20260127_052216_587020/images/000014.png +3 -0
- input_images_20260127_052216_587020/images/000015.png +3 -0
- input_images_20260127_052216_587020/images/000016.png +3 -0
- input_images_20260127_052216_587020/images/000017.png +3 -0
- input_images_20260127_052439_748027/images/000000.png +3 -0
- input_images_20260127_052439_748027/images/000001.png +3 -0
- input_images_20260127_052439_748027/images/000002.png +3 -0
- input_images_20260127_052439_748027/images/000003.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,78 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/videos/camel.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
examples/videos/car.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
examples/videos/figure1.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/videos/figure2.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/videos/figure3.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
examples/videos/goldfish.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
examples/videos/horse.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
examples/videos/paragliding.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
examples/videos/pstudio.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
examples/videos/stroller.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
examples/videos/swing.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
examples/videos/tennis.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
examples/videos/tesla.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
input_images_20260127_052216_587020/images/000000.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
input_images_20260127_052216_587020/images/000001.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
input_images_20260127_052216_587020/images/000002.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
input_images_20260127_052216_587020/images/000003.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
input_images_20260127_052216_587020/images/000004.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
input_images_20260127_052216_587020/images/000005.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
input_images_20260127_052216_587020/images/000006.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
input_images_20260127_052216_587020/images/000007.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
input_images_20260127_052216_587020/images/000008.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
input_images_20260127_052216_587020/images/000009.png filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
input_images_20260127_052216_587020/images/000010.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
input_images_20260127_052216_587020/images/000011.png filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
input_images_20260127_052216_587020/images/000012.png filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
input_images_20260127_052216_587020/images/000013.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
input_images_20260127_052216_587020/images/000014.png filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
input_images_20260127_052216_587020/images/000015.png filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
input_images_20260127_052216_587020/images/000016.png filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
input_images_20260127_052216_587020/images/000017.png filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
input_images_20260127_052439_748027/images/000000.png filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
input_images_20260127_052439_748027/images/000001.png filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
input_images_20260127_052439_748027/images/000002.png filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
input_images_20260127_052439_748027/images/000003.png filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
input_images_20260127_052439_748027/images/000004.png filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
input_images_20260127_052439_748027/images/000005.png filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
input_images_20260127_052439_748027/images/000006.png filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
input_images_20260127_052439_748027/images/000007.png filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
input_images_20260127_052439_748027/images/000008.png filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
input_images_20260127_052439_748027/images/000009.png filter=lfs diff=lfs merge=lfs -text
|
| 77 |
+
input_images_20260127_052521_840381/images/cam00.png filter=lfs diff=lfs merge=lfs -text
|
| 78 |
+
input_images_20260127_052521_840381/images/cam01.png filter=lfs diff=lfs merge=lfs -text
|
| 79 |
+
input_images_20260127_052521_840381/images/cam18.png filter=lfs diff=lfs merge=lfs -text
|
| 80 |
+
input_images_20260127_052521_840381/images/cam19.png filter=lfs diff=lfs merge=lfs -text
|
| 81 |
+
input_images_20260127_053256_343581/images/cam00.png filter=lfs diff=lfs merge=lfs -text
|
| 82 |
+
input_images_20260127_053256_343581/images/cam01.png filter=lfs diff=lfs merge=lfs -text
|
| 83 |
+
input_images_20260127_053256_343581/images/cam18.png filter=lfs diff=lfs merge=lfs -text
|
| 84 |
+
input_images_20260127_053256_343581/images/cam19.png filter=lfs diff=lfs merge=lfs -text
|
| 85 |
+
input_images_20260127_053630_522657/images/cam18.png filter=lfs diff=lfs merge=lfs -text
|
| 86 |
+
input_images_20260127_053630_522657/images/cam19.png filter=lfs diff=lfs merge=lfs -text
|
| 87 |
+
input_images_20260127_161749_937508/images/000000.png filter=lfs diff=lfs merge=lfs -text
|
| 88 |
+
input_images_20260127_161749_937508/images/000001.png filter=lfs diff=lfs merge=lfs -text
|
| 89 |
+
input_images_20260127_161749_937508/images/000002.png filter=lfs diff=lfs merge=lfs -text
|
| 90 |
+
input_images_20260127_161749_937508/images/000003.png filter=lfs diff=lfs merge=lfs -text
|
| 91 |
+
input_images_20260127_162743_468054/images/000000.png filter=lfs diff=lfs merge=lfs -text
|
| 92 |
+
input_images_20260127_162743_468054/images/000001.png filter=lfs diff=lfs merge=lfs -text
|
| 93 |
+
input_images_20260127_162743_468054/images/000002.png filter=lfs diff=lfs merge=lfs -text
|
| 94 |
+
input_images_20260127_162743_468054/images/000003.png filter=lfs diff=lfs merge=lfs -text
|
| 95 |
+
input_images_20260127_163859_002158/images/000000.png filter=lfs diff=lfs merge=lfs -text
|
| 96 |
+
input_images_20260127_163859_002158/images/000001.png filter=lfs diff=lfs merge=lfs -text
|
| 97 |
+
input_images_20260127_163859_002158/images/000002.png filter=lfs diff=lfs merge=lfs -text
|
| 98 |
+
input_images_20260127_163859_002158/images/000003.png filter=lfs diff=lfs merge=lfs -text
|
| 99 |
+
input_images_20260127_170350_777007/images/000000.png filter=lfs diff=lfs merge=lfs -text
|
| 100 |
+
input_images_20260127_170350_777007/images/000001.png filter=lfs diff=lfs merge=lfs -text
|
| 101 |
+
input_images_20260127_170350_777007/images/000002.png filter=lfs diff=lfs merge=lfs -text
|
| 102 |
+
input_images_20260127_170350_777007/images/000003.png filter=lfs diff=lfs merge=lfs -text
|
| 103 |
+
input_images_20260127_170350_777007/images/000004.png filter=lfs diff=lfs merge=lfs -text
|
| 104 |
+
input_images_20260127_170350_777007/images/000005.png filter=lfs diff=lfs merge=lfs -text
|
| 105 |
+
input_images_20260127_170350_777007/images/000006.png filter=lfs diff=lfs merge=lfs -text
|
| 106 |
+
input_images_20260127_170350_777007/images/000007.png filter=lfs diff=lfs merge=lfs -text
|
| 107 |
+
input_images_20260127_170350_777007/images/000008.png filter=lfs diff=lfs merge=lfs -text
|
| 108 |
+
input_images_20260127_170350_777007/images/000009.png filter=lfs diff=lfs merge=lfs -text
|
| 109 |
+
input_images_20260127_170350_777007/images/000010.png filter=lfs diff=lfs merge=lfs -text
|
| 110 |
+
input_images_20260127_170350_777007/images/000011.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/
|
| 2 |
+
checkpoints/
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*$py.class
|
| 8 |
+
|
| 9 |
+
# C extensions
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Distribution / packaging
|
| 13 |
+
.Python
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
pip-wheel-metadata/
|
| 27 |
+
share/python-wheels/
|
| 28 |
+
*.egg-info/
|
| 29 |
+
.installed.cfg
|
| 30 |
+
*.egg
|
| 31 |
+
MANIFEST
|
| 32 |
+
|
| 33 |
+
# PyInstaller
|
| 34 |
+
# Usually these files are written by a python script from a template
|
| 35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 36 |
+
*.manifest
|
| 37 |
+
*.spec
|
| 38 |
+
|
| 39 |
+
# Installer logs
|
| 40 |
+
pip-log.txt
|
| 41 |
+
pip-delete-this-directory.txt
|
| 42 |
+
|
| 43 |
+
# Unit test / coverage reports
|
| 44 |
+
htmlcov/
|
| 45 |
+
.tox/
|
| 46 |
+
.nox/
|
| 47 |
+
.coverage
|
| 48 |
+
.coverage.*
|
| 49 |
+
.cache
|
| 50 |
+
nosetests.xml
|
| 51 |
+
coverage.xml
|
| 52 |
+
*.cover
|
| 53 |
+
*.py,cover
|
| 54 |
+
.hypothesis/
|
| 55 |
+
.pytest_cache/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
target/
|
| 79 |
+
|
| 80 |
+
# Jupyter Notebook
|
| 81 |
+
.ipynb_checkpoints
|
| 82 |
+
|
| 83 |
+
# IPython
|
| 84 |
+
profile_default/
|
| 85 |
+
ipython_config.py
|
| 86 |
+
|
| 87 |
+
# pyenv
|
| 88 |
+
.python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 98 |
+
__pypackages__/
|
| 99 |
+
|
| 100 |
+
# Celery stuff
|
| 101 |
+
celerybeat-schedule
|
| 102 |
+
celerybeat.pid
|
| 103 |
+
|
| 104 |
+
# SageMath parsed files
|
| 105 |
+
*.sage.py
|
| 106 |
+
|
| 107 |
+
# Environments
|
| 108 |
+
.env
|
| 109 |
+
.venv
|
| 110 |
+
env/
|
| 111 |
+
venv/
|
| 112 |
+
ENV/
|
| 113 |
+
env.bak/
|
| 114 |
+
venv.bak/
|
| 115 |
+
|
| 116 |
+
# Spyder project settings
|
| 117 |
+
.spyderproject
|
| 118 |
+
.spyproject
|
| 119 |
+
|
| 120 |
+
# Rope project settings
|
| 121 |
+
.ropeproject
|
| 122 |
+
|
| 123 |
+
# mkdocs documentation
|
| 124 |
+
/site
|
| 125 |
+
|
| 126 |
+
# mypy
|
| 127 |
+
.mypy_cache/
|
| 128 |
+
.dmypy.json
|
| 129 |
+
dmypy.json
|
| 130 |
+
|
| 131 |
+
# Pyre type checker
|
| 132 |
+
.pyre/
|
.gitmodules
ADDED
|
File without changes
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
LICENSE
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Eldar Insafutdinov, Edgar Sucar
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
| 22 |
+
|
LICENSE-VGGT
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
VGGT License
|
| 2 |
+
|
| 3 |
+
v1 Last Updated: July 29, 2025
|
| 4 |
+
|
| 5 |
+
“Acceptable Use Policy” means the Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
|
| 6 |
+
|
| 7 |
+
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
“Documentation” means the specifications, manuals and documentation accompanying
|
| 11 |
+
Research Materials distributed by Meta.
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
| 15 |
+
|
| 16 |
+
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
| 17 |
+
“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
|
| 18 |
+
|
| 19 |
+
By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
1. License Rights and Redistribution.
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
|
| 26 |
+
|
| 27 |
+
b. Redistribution and Use.
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
i. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
iii. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
|
| 37 |
+
2. User Support. Your use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
|
| 41 |
+
|
| 42 |
+
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 43 |
+
|
| 44 |
+
5. Intellectual Property.
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
| 48 |
+
|
| 49 |
+
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
|
| 50 |
+
|
| 51 |
+
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.
|
| 52 |
+
|
| 53 |
+
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
Acceptable Use Policy
|
| 60 |
+
|
| 61 |
+
Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
|
| 62 |
+
|
| 63 |
+
As part of this mission, Meta makes certain research materials available for use in accordance with this Agreement (including the Acceptable Use Policy). Meta is committed to promoting the safe and responsible use of such research materials.
|
| 64 |
+
|
| 65 |
+
Prohibited Uses
|
| 66 |
+
|
| 67 |
+
You agree you will not use, or allow others to use, Research Materials to:
|
| 68 |
+
|
| 69 |
+
Violate the law or others’ rights, including to:
|
| 70 |
+
Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
|
| 71 |
+
Violence or terrorism
|
| 72 |
+
Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
|
| 73 |
+
Human trafficking, exploitation, and sexual violence
|
| 74 |
+
The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
|
| 75 |
+
Sexual solicitation
|
| 76 |
+
Any other criminal activity
|
| 77 |
+
|
| 78 |
+
Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
|
| 79 |
+
|
| 80 |
+
Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
|
| 81 |
+
|
| 82 |
+
Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
|
| 83 |
+
|
| 84 |
+
Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
|
| 85 |
+
|
| 86 |
+
Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using Research Materials
|
| 87 |
+
|
| 88 |
+
Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
|
| 89 |
+
|
| 90 |
+
2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
|
| 91 |
+
|
| 92 |
+
Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
|
| 93 |
+
|
| 94 |
+
Guns and illegal weapons (including weapon development)
|
| 95 |
+
|
| 96 |
+
Illegal drugs and regulated/controlled substances
|
| 97 |
+
Operation of critical infrastructure, transportation technologies, or heavy machinery
|
| 98 |
+
|
| 99 |
+
Self-harm or harm to others, including suicide, cutting, and eating disorders
|
| 100 |
+
Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
|
| 101 |
+
|
| 102 |
+
3. Intentionally deceive or mislead others, including use of Research Materials related to the following:
|
| 103 |
+
|
| 104 |
+
Generating, promoting, or furthering fraud or the creation or promotion of disinformation
|
| 105 |
+
Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
|
| 106 |
+
|
| 107 |
+
Generating, promoting, or further distributing spam
|
| 108 |
+
|
| 109 |
+
Impersonating another individual without consent, authorization, or legal right
|
| 110 |
+
|
| 111 |
+
Representing that outputs of research materials or outputs from technology using Research Materials are human-generated
|
| 112 |
+
|
| 113 |
+
Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
|
| 114 |
+
|
| 115 |
+
4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
|
README.md
CHANGED
|
@@ -1,12 +1,44 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
colorFrom: indigo
|
| 5 |
-
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
---
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: vdpm
|
| 3 |
+
app_file: gradio_demo.py
|
|
|
|
|
|
|
| 4 |
sdk: gradio
|
| 5 |
+
sdk_version: 5.17.1
|
|
|
|
|
|
|
| 6 |
---
|
| 7 |
+
<div align="center">
|
| 8 |
+
<h1>V-DPM: 4D Video Reconstruction with Dynamic Point Maps</h1>
|
| 9 |
|
| 10 |
+
<a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
|
| 11 |
+
<a href="https://huggingface.co/spaces/edgarsucar/vdpm"><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
|
| 12 |
+
|
| 13 |
+
**[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
[Edgar Sucar](https://edgarsucar.github.io/)\*, [Eldar Insafutdinov](https://eldar.insafutdinov.com/)\*, [Zihang Lai](https://scholar.google.com/citations?user=31eXgMYAAAAJ), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/)
|
| 17 |
+
</div>
|
| 18 |
+
|
| 19 |
+
## Setup
|
| 20 |
+
|
| 21 |
+
First, clone the repository and setup a virtual environment with [uv](https://github.com/astral-sh/uv):
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
git clone git@github.com:eldar/vdpm.git
|
| 25 |
+
cd vdpm
|
| 26 |
+
uv venv --python 3.12
|
| 27 |
+
. .venv/bin/activate
|
| 28 |
+
|
| 29 |
+
# Install PyTorch with CUDA 11.8 first
|
| 30 |
+
uv pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
|
| 31 |
+
|
| 32 |
+
# Then install remaining dependencies
|
| 33 |
+
uv pip install -r requirements.txt
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Viser demo
|
| 37 |
+
```bash
|
| 38 |
+
python visualise.py ++vis.input_video=examples/videos/camel.mp4
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Gradio demo
|
| 42 |
+
```bash
|
| 43 |
+
python gradio_demo.py
|
| 44 |
+
```
|
check_model_size.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# Add parent directory to path
|
| 6 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 7 |
+
|
| 8 |
+
def check_model_memory():
    """Estimate whether the V-DPM model fits into an 8 GB GPU (RTX 3070 Ti).

    Instantiates the model on CPU, counts its parameters, and prints rough
    memory estimates for FP32/FP16/BF16/INT8 weights plus a conservative
    activation estimate, followed by a fit/no-fit recommendation. All output
    goes to stdout; nothing is returned.
    """
    # Minimal stand-in for the hydra config object the model constructor expects.
    class SimpleConfig:
        class ModelConfig:
            decoder_depth = 4
        model = ModelConfig()

    cfg = SimpleConfig()

    # Import after sys.path has been adjusted at module level.
    from dpm.model import VDPM

    # Create model on CPU first to count parameters.
    print("Creating model...")
    model = VDPM(cfg)

    # Count parameters.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\n{'='*60}")
    print(f"MODEL SIZE ANALYSIS FOR RTX 3070 Ti (8GB)")
    print(f"{'='*60}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"\nEstimated model weights memory:")
    print(f"  - FP32 (float32): {total_params * 4 / 1024**3:.2f} GB")
    print(f"  - FP16 (float16): {total_params * 2 / 1024**3:.2f} GB")
    print(f"  - BF16 (bfloat16): {total_params * 2 / 1024**3:.2f} GB")
    print(f"  - INT8 (quantized): {total_params * 1 / 1024**3:.2f} GB <-- RECOMMENDED for 8GB GPU")

    # Estimate activation memory for a typical input.
    batch_size = 1
    num_frames = 5  # typical video length
    img_size = 518
    print(f"\nEstimated activation memory (batch={batch_size}, frames={num_frames}, img_size={img_size}):")

    # Input images: [B, S, 3, H, W]
    input_mem = batch_size * num_frames * 3 * img_size * img_size * 4 / 1024**3
    print(f"  - Input images (FP32): {input_mem:.2f} GB")

    # Rough estimate for activations (can be 2-4x model size during forward pass).
    activation_mem_estimate = total_params * 2 * 3 / 1024**3  # conservative estimate
    print(f"  - Activations (estimate): {activation_mem_estimate:.2f} GB")

    # Calculate total for different precision modes.
    total_fp16 = (total_params * 2 / 1024**3) + input_mem + activation_mem_estimate
    total_int8 = (total_params * 1 / 1024**3) + input_mem + (activation_mem_estimate * 0.6)  # INT8 reduces activations too

    print(f"\nTotal estimated GPU memory needed:")
    print(f"  - With FP16/BF16: {total_fp16:.2f} GB")
    print(f"  - With INT8 quantization: {total_int8:.2f} GB <-- FITS IN 8GB!")
    print(f"Your RTX 3070 Ti has: 8 GB VRAM")

    # Recommendation. FIX: total_int8 is always < total_fp16, so the original
    # ordering (int8-fits / fp16-too-big / else) made the "fits with FP16"
    # branch unreachable. Check FP16 first so the best-quality mode is
    # recommended when it actually fits.
    if total_fp16 <= 8:
        print(f"\n✓ Model should fit with FP16!")
    elif total_int8 <= 8:
        print(f"\n✓ With INT8 quantization, model will fit in GPU memory!")
        print(f"   Set USE_QUANTIZATION = True in gradio_demo.py")
    else:
        print(f"\n⚠️ WARNING: Even with INT8 ({total_int8:.2f} GB), memory is tight")
        print(f"   Recommendations:")
        print(f"   1. Use INT8 quantization (USE_QUANTIZATION = True)")
        print(f"   2. Reduce number of input frames to {num_frames} or fewer")
        print(f"   3. Clear CUDA cache between batches")

    print(f"{'='*60}\n")

    # Report actual GPU memory if CUDA is available.
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
        print(f"Current GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Current GPU memory cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
+
if __name__ == "__main__":
    # Script entry point: run the memory-fit analysis and print the report.
    check_model_memory()
configs/config.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- hydra: defaults
|
| 4 |
+
- model: dpm
|
| 5 |
+
|
| 6 |
+
config:
|
| 7 |
+
exp_name: "debug"
|
| 8 |
+
file: "config.yaml"
|
| 9 |
+
|
| 10 |
+
data_loader:
|
| 11 |
+
batch_size: 2
|
| 12 |
+
num_workers: 8
|
| 13 |
+
dynamic_batch: false
|
| 14 |
+
|
| 15 |
+
train:
|
| 16 |
+
logging: true
|
| 17 |
+
num_gpus: 4
|
| 18 |
+
amp: bfloat16
|
| 19 |
+
amp_dpt: false
|
| 20 |
+
dry_run: false
|
| 21 |
+
camera_loss_lambda: 5.0
|
| 22 |
+
|
| 23 |
+
optimiser:
|
| 24 |
+
lr: 0.00005 # absolute lr
|
| 25 |
+
blr: 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 26 |
+
start_epoch:
|
| 27 |
+
epochs: 70
|
| 28 |
+
accum_iter: 1
|
| 29 |
+
warmup_epochs: 3
|
| 30 |
+
min_lr: 1e-06
|
| 31 |
+
|
| 32 |
+
run:
|
| 33 |
+
resume: false
|
| 34 |
+
dirpath: null
|
| 35 |
+
debug: false
|
| 36 |
+
random_seed: 42
|
| 37 |
+
git_hash: null
|
| 38 |
+
log_frequency: 250
|
| 39 |
+
training_progress_bar: false
|
| 40 |
+
save_freq: 5
|
| 41 |
+
eval_freq: 1
|
| 42 |
+
keep_freq: 5
|
| 43 |
+
print_freq: 20
|
| 44 |
+
num_keep_ckpts: 5
|
| 45 |
+
# Old Dust3r params
|
| 46 |
+
world_size: -1
|
| 47 |
+
local_rank: -1
|
| 48 |
+
dist_url: "env://"
|
| 49 |
+
seed: 0
|
| 50 |
+
|
configs/model/dpm.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: dpm-video
|
| 2 |
+
pretrained: /work/eldar/models/vggt/VGGT-1B.pt
|
| 3 |
+
decoder_depth: 4
|
configs/visualise.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- model: dpm
|
| 4 |
+
|
| 5 |
+
hydra:
|
| 6 |
+
output_subdir: null # Disable saving of config files.
|
| 7 |
+
job:
|
| 8 |
+
chdir: False
|
| 9 |
+
|
| 10 |
+
vis:
|
| 11 |
+
port: 8080
|
| 12 |
+
input_video:
|
| 13 |
+
|
dpm/aggregator.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE-VGGT file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.utils.checkpoint import checkpoint
|
| 12 |
+
from typing import Optional, Tuple, Union, List, Dict, Any
|
| 13 |
+
|
| 14 |
+
from vggt.layers import PatchEmbed
|
| 15 |
+
from vggt.layers.block import Block
|
| 16 |
+
from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
|
| 17 |
+
from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
_RESNET_MEAN = [0.485, 0.456, 0.406]
|
| 22 |
+
_RESNET_STD = [0.229, 0.224, 0.225]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Aggregator(nn.Module):
    """
    The Aggregator applies alternating-attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.

    Tokens per frame are laid out as: [camera (1), time-conditioning (1),
    register (num_register_tokens), patches], so patch_start_idx ends up as
    2 + num_register_tokens.

    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=["frame", "global"],  # NOTE(review): mutable default — shared across instances; only stored/iterated here, but avoid mutating it
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()

        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        # Per-frame (intra-frame) attention blocks.
        self.frame_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        # Cross-frame (whole-sequence) attention blocks.
        self.global_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        # Extra token (shared by all frames) used for time conditioning;
        # it occupies one more slot before the patch tokens.
        self.time_conditioning_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.patch_start_idx += 1

        # Initialize parameters with small values
        # NOTE(review): time_conditioning_token keeps the default randn init
        # (std=1), unlike camera/register tokens — confirm this is intentional.
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

        # Register normalization constants as buffers
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 1, 3, 1, 1),
                persistent=False,
            )

        self.use_reentrant = False  # hardcoded to False

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
        """

        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            # Pretrained DINOv2 backbones; KeyError here means an unknown
            # patch_embed name was requested.
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

        # Disable gradient updates for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    def forward(
        self,
        images: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks,
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images and reshape for patch embed
        images = (images - self._resnet_mean) / self._resnet_std

        # Reshape to [B*S, C, H, W] for patch embedding
        images = images.view(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        # DINOv2 backbones return a dict; a plain conv PatchEmbed returns a tensor.
        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P, C = patch_tokens.shape

        # Expand camera and register tokens to match batch size and sequence length
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)
        # do something similar for time_conditioning_token
        time_conditioning_token = slice_expand_and_flatten_single(self.time_conditioning_token, B, S)
        # Concatenate special tokens with patch tokens
        # (order must stay in sync with patch_start_idx computed in __init__)
        tokens = torch.cat([camera_token, time_conditioning_token, register_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            # (patch positions are shifted by +1 so that 0 is reserved for specials)
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        # Alternate frame-local and global attention, collecting the
        # concatenated intermediates of each frame/global pair for the decoder.
        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            for i in range(len(frame_intermediates)):
                # concat frame and global intermediates, [B x S x P x 2C]
                concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                output_list.append(concat_inter)

        # Drop references promptly to reduce peak memory before returning.
        del concat_inter
        del frame_intermediates
        del global_intermediates
        return output_list, self.patch_start_idx

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                # Gradient checkpointing to trade compute for activation memory.
                tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                # Gradient checkpointing to trade compute for activation memory.
                tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates
| 321 |
+
def slice_expand_and_flatten(token_tensor, B, S):
    """
    Broadcast special tokens of shape (1, 2, X, C) across a batch of frame
    sequences.

    Position 0 along dim 1 supplies the token(s) for the first frame of every
    sequence; position 1 supplies the token(s) shared by the remaining S-1
    frames. Both are expanded to batch size B, stacked along the frame axis,
    and flattened frame-major.

    Args:
        token_tensor: Tensor of shape (1, 2, X, C).
        B: Batch size.
        S: Sequence length (number of frames).

    Returns:
        torch.Tensor: Tokens with shape (B*S, X, C).
    """
    tail = token_tensor.shape[2:]

    # First-frame token(s), broadcast over the batch => (B, 1, X, C)
    first_frame = token_tensor[:, :1].expand(B, 1, *tail)
    # Token(s) used by every remaining frame => (B, S-1, X, C)
    later_frames = token_tensor[:, 1:].expand(B, S - 1, *tail)

    # Stack along the frame axis, then fold (B, S) into a single leading axis.
    per_frame = torch.cat((first_frame, later_frames), dim=1)
    return per_frame.view(B * S, *tail)
| 346 |
+
|
| 347 |
+
def slice_expand_and_flatten_single(token_tensor, B, S):
    """
    Broadcasts a single shared token of shape (1, 1, C) to every frame:
    1) Expands the token to (B, S, C) so each of the B*S frames gets a copy
    2) Flattens to (B*S, 1, C), matching the (batch*frame, tokens, channels)
       layout of the other special tokens

    Unlike slice_expand_and_flatten, there is no separate first-frame token:
    every frame receives the same conditioning token.

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, 1, C)
    """

    # Broadcast the single token over batch and sequence => shape (B, S, C)
    token = token_tensor.expand(B, S, *token_tensor.shape[2:])

    # Flatten frames into the batch axis => shape (B*S, 1, C)
    token = token.view(B * S, 1, *token.shape[2:])
    return token
dpm/decoder.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE-VGGT file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn, Tensor
|
| 10 |
+
from torch.utils.checkpoint import checkpoint
|
| 11 |
+
from typing import List, Callable
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
|
| 14 |
+
from einops import repeat
|
| 15 |
+
|
| 16 |
+
from vggt.layers.block import drop_add_residual_stochastic_depth
|
| 17 |
+
from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
|
| 18 |
+
|
| 19 |
+
from vggt.layers.attention import Attention
|
| 20 |
+
from vggt.layers.drop_path import DropPath
|
| 21 |
+
from vggt.layers.layer_scale import LayerScale
|
| 22 |
+
from vggt.layers.mlp import Mlp
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class ModulationOut:
    """AdaLN-style modulation parameters produced by `Modulation`.

    Each field has shape (B, 1, C) (a singleton token axis inserted by
    `Modulation.forward`) so it broadcasts over all tokens of a frame.
    """

    shift: Tensor  # additive offset applied after normalization
    scale: Tensor  # multiplicative factor, applied as (1 + scale)
    gate: Tensor   # gates the residual-branch output
| 33 |
+
|
| 34 |
+
class Modulation(nn.Module):
    """Project a conditioning vector into AdaLN modulation parameters.

    A SiLU followed by a linear layer maps the conditioning vector to
    3 (single) or 6 (double) chunks of size `dim`, grouped into one or two
    `ModulationOut` triples of (shift, scale, gate). A singleton token axis
    is inserted so the outputs broadcast over token sequences.
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        activated = nn.functional.silu(vec)
        # (B, multiplier*dim) -> (B, 1, multiplier*dim) -> multiplier chunks of (B, 1, dim)
        chunks = self.lin(activated)[:, None, :].chunk(self.multiplier, dim=-1)

        primary = ModulationOut(*chunks[:3])
        secondary = ModulationOut(*chunks[3:]) if self.is_double else None
        return primary, secondary
| 49 |
+
|
| 50 |
+
class ConditionalBlock(nn.Module):
    """Transformer block with AdaLN (DiT/Flux-style) conditioning.

    The attention branch is modulated by a per-frame conditioning vector:
    the input is scaled/shifted before attention ((1 + scale) * norm(x) + shift)
    and the branch output is gated afterwards (see `Modulation`). The MLP
    branch is a standard pre-norm residual with optional layer scale.

    Fix vs. original: the FFN residual now uses `drop_path2` — it previously
    reused `drop_path1` and carried a `# FIXME: drop_path2` comment. Both
    modules are configured identically (same drop rate), so expected behavior
    is unchanged, but each residual branch now owns its stochastic-depth
    module as intended.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()

        # No affine parameters on norm1: scale/shift come from the modulation.
        self.norm1 = norm_layer(dim, elementwise_affine=False)
        self.modulation = Modulation(dim, double=False)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, cond=None, is_global=False) -> Tensor:
        """
        Args:
            x: Tokens of shape (B*S, P, C) for frame attention, or
               (B, S*P, C) when `is_global` is True.
            pos: Optional rotary position indices, passed through to attention.
            cond: Per-frame conditioning vectors of shape (B, S, C); required.
            is_global: Whether `x` holds the whole sequence in one attention span.

        Returns:
            Tensor with the same shape as `x`.
        """
        B, S = cond.shape[:2]
        C = x.shape[-1]
        if is_global:
            P = x.shape[1] // S
        # One modulation vector per frame: (B*S, C).
        cond = cond.view(B * S, C)
        mod, _ = self.modulation(cond)

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            """
            conditional attention following DiT implementation from Flux
            https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py#L194-L239
            """
            def prepare_for_mod(y):
                """reshape to modulate the patch tokens with correct conditioning one"""
                return y.view(B, S, P, C).view(B * S, P, C) if is_global else y

            def restore_after_mod(y):
                """reshape back to global sequence"""
                return y.view(B, S, P, C).view(B, S * P, C) if is_global else y

            x = prepare_for_mod(x)
            x = (1 + mod.scale) * self.norm1(x) + mod.shift
            x = restore_after_mod(x)

            x = self.attn(x, pos=pos)

            x = prepare_for_mod(x)
            x = mod.gate * x
            x = restore_after_mod(x)

            return x

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))  # fixed: was drop_path1 (FIXME resolved)
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
| 160 |
+
|
| 161 |
+
class Decoder(nn.Module):
    """Attention blocks after encoder per DPT input feature
    to generate point maps at a given time.

    The decoder takes the aggregator's intermediate token maps (one per entry
    in ``intermediate_layer_idx``), stacks them along the batch dimension, and
    runs alternating frame / global attention blocks conditioned on a per-view
    "camera" token gathered at the query time.
    """

    def __init__(
        self,
        cfg,
        dim_in: int,
        # NOTE(review): mutable default arguments; they are only read here,
        # but tuples would be safer.
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        patch_size=14,
        embed_dim=1024,
        depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        block_fn=ConditionalBlock,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()
        self.cfg = cfg
        self.intermediate_layer_idx = intermediate_layer_idx

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # 2D RoPE for patch tokens; disabled when rope_freq <= 0.
        self.rope = (
            RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        )
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.dim_in = dim_in

        # Legacy layout kept for checkpoint compatibility: flat ModuleLists
        # (old) vs. nested ModuleList-of-ModuleLists (new, depths = [depth]).
        self.old_decoder = False
        if self.old_decoder:
            self.frame_blocks = nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim*2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(depth)
                ]
            )
            self.global_blocks = nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim*2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(depth)
                ]
            )
        else:
            depths = [depth]
            self.frame_blocks = nn.ModuleList([
                nn.ModuleList([
                    block_fn(
                        dim=embed_dim*2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(d)
                ])
                for d in depths
            ])

            self.global_blocks = nn.ModuleList([
                nn.ModuleList([
                    block_fn(
                        dim=embed_dim*2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(d)
                ])
                for d in depths
            ])

        self.use_reentrant = False  # hardcoded to False

    def get_condition_tokens(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        cond_view_idxs: torch.Tensor
    ):
        """Gather the camera token of the conditioning (query-time) view.

        Args:
            aggregated_tokens_list: per-layer aggregator tokens.
            cond_view_idxs: [B, S] index of the view each frame is
                conditioned on.

        Returns:
            [B, S, 1, D] conditioning tokens.
        """
        # Use tokens from the last block for conditioning
        tokens_last = aggregated_tokens_list[-1]  # [B S N_tok D]
        # Extract the camera tokens (list index keeps the token dimension)
        cond_token_idx = 1
        camera_tokens = tokens_last[:, :, [cond_token_idx]]  # [B S 1 D]

        # Expand indices so gather picks, per frame, the camera token of the
        # view it is conditioned on.
        cond_view_idxs = cond_view_idxs.to(camera_tokens.device)
        cond_view_idxs = repeat(
            cond_view_idxs,
            "b s -> b s c d",
            c=camera_tokens.shape[2],
            d=camera_tokens.shape[3],
        )
        cond_tokens = torch.gather(camera_tokens, 1, cond_view_idxs)

        return cond_tokens

    def forward(
        self,
        images: torch.Tensor,
        aggregated_tokens_list: List[torch.Tensor],
        patch_start_idx: int,
        cond_view_idxs: torch.Tensor,
    ):
        """Decode all selected intermediate layers at the query time.

        Returns a list (one entry per ``intermediate_layer_idx``) of
        [B, S, P, C] token maps for the DPT head.
        """
        B, S, _, H, W = images.shape

        cond_tokens = self.get_condition_tokens(
            aggregated_tokens_list, cond_view_idxs
        )

        # Clone so the decoder does not modify the aggregator's tokens.
        input_tokens = []
        for k, layer_idx in enumerate(self.intermediate_layer_idx):
            layer_tokens = aggregated_tokens_list[layer_idx].clone()
            input_tokens.append(layer_tokens)

        _, _, P, C = input_tokens[0].shape

        pos = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )
            if patch_start_idx > 0:
                # do not use position embedding for special tokens (camera and register tokens)
                # so set pos to 0 for the special tokens
                pos = pos + 1
                pos_special = torch.zeros(B * S, patch_start_idx, 2).to(images.device).to(pos.dtype)
                pos = torch.cat([pos_special, pos], dim=1)

        frame_idx = 0
        global_idx = 0
        depth = len(self.frame_blocks[0])
        N = len(input_tokens)
        # stack all intermediate layer tokens along batch dimension
        # they are all processed by the same decoder
        s_tokens = torch.cat(input_tokens)
        s_cond_tokens = torch.cat([cond_tokens] * N, dim=0)
        # NOTE(review): assumes pos is not None, i.e. rope_freq > 0 — confirm.
        s_pos = torch.cat([pos] * N, dim=0)

        # perform time conditioned attention
        for _ in range(depth):
            for attn_type in self.aa_order:
                # token_idx selects the (single) nested block group; always 0.
                token_idx = 0

                if attn_type == "frame":
                    s_tokens, frame_idx, _ = self._process_frame_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, frame_idx, pos=s_pos, token_idx=token_idx
                    )
                elif attn_type == "global":
                    s_tokens, global_idx, _ = self._process_global_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, global_idx, pos=s_pos, token_idx=token_idx
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")
        # Undo the batch stacking: one [B, S, P, C] map per intermediate layer.
        processed = [t.view(B, S, P, C) for t in s_tokens.split(B, dim=0)]

        return processed

    def _process_frame_attention(self, tokens, cond_tokens, B, S, P, C, frame_idx, pos=None, token_idx=0):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                # checkpoint to trade compute for activation memory
                tokens = checkpoint(self.frame_blocks[token_idx][frame_idx], tokens, pos, cond_tokens, use_reentrant=self.use_reentrant)
            else:
                if self.old_decoder:
                    tokens = self.frame_blocks[frame_idx](tokens, pos=pos, cond=cond_tokens)
                else:
                    tokens = self.frame_blocks[0][frame_idx](tokens, pos=pos, cond=cond_tokens)

            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, cond_tokens, B, S, P, C, global_idx, pos=None, token_idx=0):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                # positional True is the block's is_global flag
                tokens = checkpoint(self.global_blocks[token_idx][global_idx], tokens, pos, cond_tokens, True, use_reentrant=self.use_reentrant)
            else:
                if self.old_decoder:
                    tokens = self.global_blocks[global_idx](tokens, pos=pos, cond=cond_tokens, is_global=True)
                else:
                    tokens = self.global_blocks[0][global_idx](tokens, pos=pos, cond=cond_tokens, is_global=True)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates
|
dpm/model.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from vggt.heads.camera_head import CameraHead
|
| 5 |
+
from vggt.heads.dpt_head import DPTHead
|
| 6 |
+
|
| 7 |
+
from .aggregator import Aggregator
|
| 8 |
+
from .decoder import Decoder
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def freeze_all_params(modules):
    """Disable gradient tracking for everything in *modules*.

    Each entry may be an ``nn.Module`` (all of its parameters are frozen)
    or a bare ``nn.Parameter`` (its own ``requires_grad`` is cleared).
    """
    for module in modules:
        try:
            # named_parameters() was used before but the names were unused;
            # parameters() iterates the same set.
            for param in module.parameters():
                param.requires_grad = False
        except AttributeError:
            # module is directly a parameter
            module.requires_grad = False
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class VDPM(nn.Module):
    """Video Dynamic Point Map model.

    A VGGT-style aggregator extracts multi-view tokens; a time-conditioned
    decoder re-processes selected intermediate layers for a query time, and a
    DPT point head regresses 3D points + confidence. A camera head predicts
    pose encodings from the aggregator tokens.
    """

    def __init__(self, cfg, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()
        self.cfg = cfg

        # Shared multi-view feature extractor.
        self.aggregator = Aggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
        )
        # Time-conditioned decoder over the aggregator's intermediate tokens.
        self.decoder = Decoder(
            cfg,
            dim_in=2*embed_dim,
            embed_dim=embed_dim,
            depth=cfg.model.decoder_depth
        )
        # DPT head regressing XYZ + confidence (output_dim=4).
        self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")

        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.set_freeze()

    def set_freeze(self):
        """Freeze the pretrained patch embedding; everything else trains."""
        to_be_frozen = [self.aggregator.patch_embed]
        freeze_all_params(to_be_frozen)

    def forward(
        self,
        views, autocast_dpt=None
    ):
        """Training forward pass.

        Args:
            views: list of per-frame dicts with "img" and "view_idxs" tensors.
            autocast_dpt: optional autocast context for the DPT heads; defaults
                to full precision (autocast disabled).

        Returns:
            (res_static, res_dynamic): static point maps from aggregator
            tokens, and dynamic point maps + pose encodings from the decoder.
        """
        images = torch.stack([view["img"] for view in views], dim=1)
        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        res_dynamic = dict()

        if self.decoder is not None:
            # Second column of view_idxs holds the conditioning (query-time)
            # view index for each frame.
            cond_view_idxs = torch.stack([view["view_idxs"][:, 1] for view in views], dim=1)
            decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)

            if autocast_dpt is None:
                # Run the DPT head in full precision by default.
                autocast_dpt = torch.amp.autocast("cuda", enabled=False)

            with autocast_dpt:
                pts3d, pts3d_conf = self.point_head(
                    aggregated_tokens_list, images, patch_start_idx
                )

                # The DPT head indexes specific layers, so place the decoded
                # tokens into the matching slots of a padded list.
                padded_decoded_tokens = [None] * len(aggregated_tokens_list)
                for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
                    padded_decoded_tokens[layer_idx] = decoded_tokens[idx]
                pts3d_dyn, pts3d_dyn_conf = self.point_head(
                    padded_decoded_tokens, images, patch_start_idx
                )

            res_dynamic |= {
                "pts3d": pts3d_dyn,
                "conf": pts3d_dyn_conf
            }

        pose_enc_list = self.camera_head(aggregated_tokens_list)
        res_dynamic |= {"pose_enc_list": pose_enc_list}

        # NOTE(review): pts3d/pts3d_conf are only bound when self.decoder is
        # not None; current construction always builds a decoder — confirm.
        res_static = dict(
            pts3d=pts3d,
            conf=pts3d_conf
        )
        return res_static, res_dynamic

    def inference(
        self,
        views,
        images=None,
        num_timesteps=None
    ):
        """Query point maps for every timestep of a sequence.

        Args:
            views: list of per-frame dicts (may carry "view_idxs" to infer the
                number of timesteps).
            images: optional pre-stacked [B, S, 3, H, W] tensor; built from
                ``views`` when omitted.
            num_timesteps: number of query times; inferred from view_idxs or
                defaulted to S when None.

        Returns:
            dict with "pose_enc", "pose_enc_list" and per-time "pointmaps".
        """
        autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)

        if images is None:
            images = torch.stack([view["img"] for view in views], dim=1)

        with autocast_amp:
            aggregated_tokens_list, patch_start_idx = self.aggregator(images)
        S = images.shape[1]

        # Determine number of timesteps to query
        if num_timesteps is None:
            # Default to S if not specified (legacy behavior)
            # But if views has indices, try to infer max time
            if views is not None and "view_idxs" in views[0]:
                try:
                    all_idxs = torch.cat([v["view_idxs"][:, 1] for v in views])
                    num_timesteps = int(all_idxs.max().item()) + 1
                except Exception:
                    # Fixed: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit. Any failure here still
                    # falls back to the legacy default.
                    num_timesteps = S
            else:
                num_timesteps = S

        predictions = dict()
        pointmaps = []
        ones = torch.ones(1, S, dtype=torch.int64)
        for time_ in range(num_timesteps):
            # Condition every view on the same query time.
            cond_view_idxs = ones * time_

            with autocast_amp:
                decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)
            padded_decoded_tokens = [None] * len(aggregated_tokens_list)
            for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
                padded_decoded_tokens[layer_idx] = decoded_tokens[idx]

            pts3d, pts3d_conf = self.point_head(
                padded_decoded_tokens, images, patch_start_idx
            )

            pointmaps.append(dict(
                pts3d=pts3d,
                conf=pts3d_conf
            ))

        pose_enc_list = self.camera_head(aggregated_tokens_list)
        predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration
        predictions["pose_enc_list"] = pose_enc_list
        predictions["pointmaps"] = pointmaps
        return predictions

    def load_state_dict(self, ckpt, is_VGGT_static=False, **kw):
        """Load weights, dropping VGGT heads this model does not use.

        ``is_VGGT_static`` is accepted for interface compatibility and is
        currently unused.
        """
        # don't load these VGGT heads as not needed
        exclude = ["depth_head", "track_head"]
        ckpt = {k: v for k, v in ckpt.items() if k.split('.')[0] not in exclude}
        return super().load_state_dict(ckpt, **kw)
|
examples/videos/camel.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3db92c240efbd1b97a466565988a9a06687fd422086656dc0a29e12c5b99b9bb
|
| 3 |
+
size 1301172
|
examples/videos/car.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dd74efdb4d4d59fc17356fefa5dadd4c5b787641c98ce3172ecd8e5a180e76a6
|
| 3 |
+
size 1015132
|
examples/videos/figure1.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae285726e5d247e904bb1ea7887ee96733c0beea913b421abba39150a3299cd5
|
| 3 |
+
size 465850
|
examples/videos/figure2.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b2b030dd564cffbb9b2795e7fcdf97fa50e3a518df5b71dfb3dfb36f431dfa4
|
| 3 |
+
size 516209
|
examples/videos/figure3.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0a4144a53f14bd2dc671376d26ecbb42b06c9b8810e1700f21a16d3e11dfbf5c
|
| 3 |
+
size 559096
|
examples/videos/goldfish.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:28912e59d0d9e6b20d26973efee4806e89e115c7f1e63aec7206384ac3d0bf78
|
| 3 |
+
size 668862
|
examples/videos/horse.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8227c7d901a936aeab6a2b41f104dd17e5544315d4cde7dac37f5787319947e7
|
| 3 |
+
size 1223145
|
examples/videos/paragliding.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:acab90bf2ac105ac0a263559722503e7e7a09aed22164d83cb8b3a5ca1d1504d
|
| 3 |
+
size 814941
|
examples/videos/pstudio.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:12a598cce787ce9b39c2e4abd33c28e1223b59823258287f3a6f5ffb8abe3b47
|
| 3 |
+
size 167366
|
examples/videos/stroller.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e7a88fc629c994bb0067535e181226519ad0750df0d1decbddd01a3e7c5d3c92
|
| 3 |
+
size 1137267
|
examples/videos/swing.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:11d0cb318514b326ac2ed94d0765e29181d58401a84d64aa57ddd4cb1e865dcc
|
| 3 |
+
size 812797
|
examples/videos/tennis.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:50bee7fd257e38c66765a82412734be5abacbe30aa9e9ba04f5387b7865a380e
|
| 3 |
+
size 467310
|
examples/videos/tesla.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c74333ee25fdb87d26a5a8db080d64d51edc418bab56f3aae98792ad4dee2704
|
| 3 |
+
size 588952
|
gradio_demo.py
ADDED
|
@@ -0,0 +1,981 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import matplotlib
|
| 12 |
+
import numpy as np
|
| 13 |
+
import plotly.graph_objects as go
|
| 14 |
+
import torch
|
| 15 |
+
from hydra import compose, initialize
|
| 16 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 17 |
+
|
| 18 |
+
from dpm.model import VDPM
|
| 19 |
+
from vggt.utils.load_fn import load_and_preprocess_images
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ============================================================================
|
| 23 |
+
# MEMORY OPTIMIZATION SETTINGS FOR 8GB GPUs (RTX 3070 Ti, 3060 Ti, etc.)
|
| 24 |
+
# ============================================================================
|
| 25 |
+
# Model size: 1.66B parameters (~3GB FP16 weights, ~12GB total with activations)
|
| 26 |
+
#
|
| 27 |
+
# Memory reduction options (choose one):
|
| 28 |
+
# USE_HALF_PRECISION = True: FP16 model -> ~1.5GB weights, ~6-7GB total (RECOMMENDED FOR GPU)
|
| 29 |
+
# USE_QUANTIZATION = True: INT8 quantization -> CPU ONLY, not supported on CUDA
|
| 30 |
+
# Both False: FP16/BF16 inference only -> ~3GB weights, ~8-10GB total (may OOM)
|
| 31 |
+
#
|
| 32 |
+
# MAX_FRAMES: Limit input frames (5 recommended for 8GB GPUs)
|
| 33 |
+
# ============================================================================
|
| 34 |
+
|
| 35 |
+
# Visualisation / processing limits.
MAX_POINTS_PER_FRAME = 50_000  # cap on points plotted per frame
TRAIL_LENGTH = 20              # past positions kept per visualised track
MAX_TRACKS = 150               # cap on number of tracks drawn
STATIC_THRESHOLD = 0.025       # presumably motion magnitude below this counts as static — TODO confirm
VIDEO_SAMPLE_HZ = 1.0          # frame sampling rate when reading input videos

# Dynamic Configuration based on Helper/Hardware
USE_QUANTIZATION = False       # INT8 dynamic quantization (CPU-only per the banner above)
USE_HALF_PRECISION = True      # convert model weights to FP16
MAX_FRAMES = 5  # Default for 8GB

device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # Enable TF32
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Check VRAM to auto-scale MAX_FRAMES
    vram_bytes = torch.cuda.get_device_properties(0).total_memory
    vram_gb = vram_bytes / (1024**3)

    print(f"\u2713 GPU Detected: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)")

    if vram_gb > 22:  # A10G (24GB), A100 (40/80GB), RTX 3090/4090 (24GB)
        MAX_FRAMES = 80
        print(f" -> High VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
    elif vram_gb > 14:  # T4 (16GB), 4080 (16GB)
        MAX_FRAMES = 16
        print(f" -> Medium VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
    else:
        MAX_FRAMES = 5
        print(f" -> Low VRAM (<14GB). Keeping MAX_FRAMES at {MAX_FRAMES} to prevent OOM")
    print(f"\u2713 TF32 enabled for faster matrix operations")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def load_cfg_from_cli() -> "omegaconf.DictConfig":
    """Compose the `visualise` Hydra config, applying CLI overrides."""
    hydra_state = GlobalHydra.instance()
    if hydra_state.is_initialized():
        # Hydra refuses to initialize twice; reset any previous session.
        hydra_state.clear()
    cli_overrides = sys.argv[1:]
    with initialize(config_path="configs"):
        cfg = compose(config_name="visualise", overrides=cli_overrides)
    return cfg
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_model(cfg) -> VDPM:
    """Build a VDPM, download pretrained weights, and prepare it for inference.

    Applies the module-level memory settings: optional FP16 conversion,
    optional INT8 dynamic quantization, and torch.compile when available.
    """
    model = VDPM(cfg).to(device)

    # Pretrained checkpoint hosted on the Hugging Face Hub (cached by torch.hub).
    _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"
    sd = torch.hub.load_state_dict_from_url(
        _URL,
        file_name="vdpm_model.pt",
        progress=True
    )
    # print() surfaces missing/unexpected keys from load_state_dict.
    print(model.load_state_dict(sd, strict=True))

    model.eval()

    # Option 1: Use FP16 for all model weights (simple, ~2x memory reduction)
    if USE_HALF_PRECISION and not USE_QUANTIZATION:
        print("Converting model to FP16 precision...")
        model = model.half()
        print("✓ Model converted to FP16: ~2x memory reduction (3GB -> 1.5GB)")

    # Option 2: Apply INT8 dynamic quantization (more aggressive, ~3-4x reduction)
    if USE_QUANTIZATION:
        try:
            print("Applying INT8 dynamic quantization to reduce memory usage...")
            # Move to CPU for quantization, then back to GPU
            model = model.cpu()
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.Conv2d},  # Quantize these layer types
                dtype=torch.qint8
            )
            model = model.to(device)
            print("✓ Model quantized: ~3x memory reduction (3GB -> 1GB)")
        except Exception as e:
            # Best-effort: fall back to the unquantized model on any failure.
            print(f"⚠️ Quantization failed: {e}")
            print("Continuing with FP16/BF16 precision...")
            model = model.to(device)

    # Enable torch.compile for faster inference (PyTorch 2.0+)
    # Note: Disable compile if using quantization as they may conflict
    if not USE_QUANTIZATION:
        try:
            print("Compiling model with torch.compile for faster inference...")
            model = torch.compile(model, mode="reduce-overhead")
            print("✓ Model compilation successful")
        except Exception as e:
            # torch.compile is optional; eager mode still works.
            print(f"Warning: torch.compile not available or failed: {e}")

    return model
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def require_cuda():
    """Raise ValueError unless a CUDA device was detected at startup."""
    if device == "cuda":
        return
    raise ValueError("CUDA is not available. Check your environment.")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def gradio_file_path(file_obj):
    """Resolve a Gradio upload object to a filesystem path (or None)."""
    if file_obj is None:
        return None
    # Older Gradio versions hand back a dict carrying the temp-file path.
    is_upload_dict = isinstance(file_obj, dict)
    if is_upload_dict and "name" in file_obj:
        return file_obj["name"]
    return file_obj
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def ensure_nhwc_images(images: np.ndarray) -> np.ndarray:
    """Return images in NHWC layout, converting from NCHW when detected.

    Detection is heuristic: a 4-D array whose second axis has size 3 is
    treated as channels-first.
    """
    channels_first = images.ndim == 4 and images.shape[1] == 3
    if not channels_first:
        return images
    return np.transpose(images, (0, 2, 3, 1))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def compute_scene_bounds(world_points: np.ndarray):
    """Compute padded axis-aligned scene bounds and a plotly aspect ratio.

    The bounding box of all points is inflated by 5% around its center.
    Returns (global_min, global_max, aspectratio) where aspectratio maps
    "x"/"y"/"z" to each half-extent normalized by the largest one.
    """
    flat = world_points.reshape(-1, 3)
    lo = flat.min(axis=0)
    hi = flat.max(axis=0)

    mid = (lo + hi) * 0.5
    half = (hi - lo) * 0.5 * 1.05  # 5% padding on every side

    # Guard degenerate (near-zero) extents so axis ranges stay valid.
    if np.all(half < 1e-6):
        half[:] = 1.0
    else:
        half[half < 1e-6] = half.max()

    biggest = half.max()
    aspectratio = {axis: float(h / biggest) for axis, h in zip("xyz", half)}
    return mid - half, mid + half, aspectratio
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def stride_downsample(pts: np.ndarray, cols: np.ndarray, max_points: int):
    """Thin a point set to at most ``max_points`` by uniform striding.

    Returns the inputs untouched when they already fit under the cap.
    """
    total = pts.shape[0]
    if total <= max_points:
        return pts, cols
    stride = int(np.ceil(total / max_points))
    return pts[::stride][:max_points], cols[::stride][:max_points]
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# ============================================================
|
| 184 |
+
# NEW: Single shared mask function (used by points + tracks)
|
| 185 |
+
# ============================================================
|
| 186 |
+
def compute_point_mask(
|
| 187 |
+
conf_score: np.ndarray | None,
|
| 188 |
+
cols: np.ndarray,
|
| 189 |
+
conf_thres: float,
|
| 190 |
+
mask_black_bg: bool,
|
| 191 |
+
mask_white_bg: bool,
|
| 192 |
+
) -> np.ndarray:
|
| 193 |
+
"""
|
| 194 |
+
conf_score: (N,) or None
|
| 195 |
+
cols: (N,3) uint8
|
| 196 |
+
Returns: (N,) boolean mask
|
| 197 |
+
"""
|
| 198 |
+
mask = np.ones(cols.shape[0], dtype=bool)
|
| 199 |
+
|
| 200 |
+
# confidence percentile threshold (same semantics as before)
|
| 201 |
+
if conf_score is not None and conf_thres > 0:
|
| 202 |
+
thresh = np.percentile(conf_score, conf_thres)
|
| 203 |
+
mask &= (conf_score >= thresh) & (conf_score > 1e-5)
|
| 204 |
+
|
| 205 |
+
# background masks (same as before)
|
| 206 |
+
if mask_black_bg:
|
| 207 |
+
mask &= (cols.sum(axis=1) >= 16)
|
| 208 |
+
if mask_white_bg:
|
| 209 |
+
mask &= ~((cols[:, 0] > 240) & (cols[:, 1] > 240) & (cols[:, 2] > 240))
|
| 210 |
+
|
| 211 |
+
return mask
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def sample_frame_points(
    world_points: np.ndarray,
    images_nhwc: np.ndarray,
    conf: np.ndarray | None,
    idx: int,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
    max_points: int,
):
    """Extract one frame's filtered, downsampled 3D points and hex colors.

    Returns (pts, colors_str): an (M, 3) array and a matching list of
    "#rrggbb" strings. Falls back to a single white origin point when the
    filters remove everything, so plotly never receives an empty trace.
    """
    frame = int(np.clip(idx, 0, world_points.shape[0] - 1))

    xyz = world_points[frame].reshape(-1, 3)
    rgb = (images_nhwc[frame].reshape(-1, 3) * 255).astype(np.uint8)
    scores = None if conf is None else conf[frame].reshape(-1)

    keep = compute_point_mask(
        conf_score=scores,
        cols=rgb,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )
    xyz, rgb = stride_downsample(xyz[keep], rgb[keep], max_points)

    if xyz.size == 0:
        xyz = np.array([[0.0, 0.0, 0.0]])
        rgb = np.array([[255, 255, 255]], dtype=np.uint8)

    hex_colors = ["#{:02x}{:02x}{:02x}".format(r, g, b) for r, g, b in rgb]
    return xyz, hex_colors
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# ============================================================
|
| 252 |
+
# UPDATED: prepare_tracks now applies the SAME masks as points
|
| 253 |
+
# ============================================================
|
| 254 |
+
def prepare_tracks(
    world_points: np.ndarray,
    images_nhwc: np.ndarray,
    conf: np.ndarray | None,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
):
    """Select dynamic 3D tracks and build a per-track HSV color scale.

    Applies the SAME confidence/background filtering as the point cloud
    (compute_point_mask on reference-frame colors), keeps only points that
    move more than STATIC_THRESHOLD from their frame-0 position, caps the
    count at MAX_TRACKS, and sorts tracks by initial height so neighboring
    hues land on spatially close tracks.

    Returns (tracks_xyz, colorscale, track_ids) with tracks_xyz of shape
    (S, T, 3), or (None, None, None) when there is nothing to track.
    """
    num_frames, height, width, _ = world_points.shape
    total = height * width
    if num_frames < 2 or total == 0:
        return None, None, None

    tracks = world_points.reshape(num_frames, total, 3)

    # A point is "dynamic" if it ever strays far enough from frame 0.
    motion = np.linalg.norm(tracks - tracks[0:1], axis=-1).max(axis=0)
    keep = motion > STATIC_THRESHOLD

    # Time-averaged confidence, matching the per-point filter semantics.
    mean_conf = None
    if conf is not None:
        mean_conf = conf.reshape(num_frames, total).mean(axis=0)

    # Reference-frame colors give a stable basis for background masking.
    ref_cols = (images_nhwc[0].reshape(-1, 3) * 255).astype(np.uint8)

    keep &= compute_point_mask(
        conf_score=mean_conf,
        cols=ref_cols,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )

    chosen = np.nonzero(keep)[0]
    if chosen.size == 0:
        return None, None, None

    if chosen.size > MAX_TRACKS:
        stride = int(np.ceil(chosen.size / MAX_TRACKS))
        chosen = chosen[::stride][:MAX_TRACKS]

    tracks = tracks[:, chosen, :]
    # Stable ordering by initial y keeps colors consistent across redraws.
    tracks = tracks[:, np.argsort(tracks[0, :, 1]), :]

    num_tracks = tracks.shape[1]
    cmap = matplotlib.colormaps.get_cmap("hsv")
    norm = matplotlib.colors.Normalize(vmin=0, vmax=max(num_tracks - 1, 1))

    denom = max(num_tracks - 1, 1)
    colorscale = []
    for t in range(num_tracks):
        r, g, b, _ = cmap(norm(t))
        rgb = f"rgb({int(r * 255)},{int(g * 255)},{int(b * 255)})"
        colorscale.append([t / denom, rgb])

    track_ids = np.arange(num_tracks, dtype=float)
    return tracks, colorscale, track_ids
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def track_segments_for_frame(tracks_xyz: np.ndarray | None, track_ids: np.ndarray | None, f: int):
|
| 319 |
+
if tracks_xyz is None or track_ids is None or f <= 0:
|
| 320 |
+
return np.array([]), np.array([]), np.array([]), np.array([])
|
| 321 |
+
|
| 322 |
+
start_t = max(0, f - TRAIL_LENGTH)
|
| 323 |
+
num_tracks = tracks_xyz.shape[1]
|
| 324 |
+
|
| 325 |
+
xs, ys, zs, cs = [], [], [], []
|
| 326 |
+
for j in range(num_tracks):
|
| 327 |
+
seg = tracks_xyz[start_t : f + 1, j, :]
|
| 328 |
+
if seg.shape[0] < 2:
|
| 329 |
+
continue
|
| 330 |
+
|
| 331 |
+
xs.extend([seg[:, 0], np.array([np.nan])])
|
| 332 |
+
ys.extend([seg[:, 1], np.array([np.nan])])
|
| 333 |
+
zs.extend([seg[:, 2], np.array([np.nan])])
|
| 334 |
+
cs.append(np.full(seg.shape[0] + 1, track_ids[j], dtype=float))
|
| 335 |
+
|
| 336 |
+
x = np.concatenate(xs) if xs else np.array([])
|
| 337 |
+
y = np.concatenate(ys) if ys else np.array([])
|
| 338 |
+
z = np.concatenate(zs) if zs else np.array([])
|
| 339 |
+
c = np.concatenate(cs) if cs else np.array([])
|
| 340 |
+
|
| 341 |
+
return x, y, z, c
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def build_pointcloud_figure_update(
    data,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
):
    """Build an interactive plotly figure: per-frame point clouds + motion tracks.

    data: dict with "world_points" (S, H, W, 3), optional "world_points_conf",
    and "images" (NHWC or NCHW float, presumably in [0, 1] — confirm upstream).
    One slider step per frame swaps both traces' arrays in place, so the user's
    camera orientation is preserved while scrubbing.
    """
    if data is None:
        return go.Figure()

    world_points = data["world_points"]
    conf = data.get("world_points_conf")
    images = ensure_nhwc_images(data["images"])
    S = world_points.shape[0]  # number of frames

    # Fixed global axis ranges so the scene does not jump between frames.
    global_min, global_max, aspectratio = compute_scene_bounds(world_points)

    # UPDATED: pass same masks into prepare_tracks
    tracks_xyz, colorscale, track_ids = prepare_tracks(
        world_points=world_points,
        images_nhwc=images,
        conf=conf,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )
    track_cmax = max(len(track_ids) - 1, 1) if track_ids is not None else 1

    # Precompute every frame's points and track segments up front so the
    # slider callbacks only swap arrays.
    pts_xyz = [None] * S
    pts_cols = [None] * S
    trk_xyz = [None] * S
    trk_c = [None] * S

    for i in range(S):
        pts_i, cols_i = sample_frame_points(
            world_points=world_points,
            images_nhwc=images,
            conf=conf,
            idx=i,
            conf_thres=conf_thres,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            max_points=MAX_POINTS_PER_FRAME,
        )
        pts_xyz[i] = pts_i
        pts_cols[i] = cols_i

        x, y, z, c = track_segments_for_frame(tracks_xyz, track_ids, f=i)
        trk_xyz[i] = (x, y, z)
        trk_c[i] = c

    # Frame 0 provides the initial data for both traces.
    p0 = pts_xyz[0]
    c0 = pts_cols[0]
    x0, y0, z0 = trk_xyz[0]
    tc0 = trk_c[0]

    # Hide all axis chrome; manual aspect + camera (up is -y to match
    # image coordinates, eye on -z looking at the scene center).
    scene_cfg = dict(
        xaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[0]), float(global_max[0])],
        ),
        yaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[1]), float(global_max[1])],
        ),
        zaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[2]), float(global_max[2])],
        ),
        aspectmode="manual",
        aspectratio=aspectratio,
        dragmode="orbit",
        camera=dict(
            eye=dict(x=0.0, y=0.0, z=-1.0),
            center=dict(x=0.0, y=0.0, z=0.0),
            up=dict(x=0.0, y=-1.0, z=0.0),
        ),
    )

    # Trace 0: colored point cloud. Trace 1: NaN-separated track polylines.
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=p0[:, 0],
                y=p0[:, 1],
                z=p0[:, 2],
                mode="markers",
                marker=dict(size=2, color=c0),
                showlegend=False,
                name="points",
            ),
            go.Scatter3d(
                x=x0,
                y=y0,
                z=z0,
                mode="lines",
                line=dict(
                    width=2,
                    color=tc0 if (tc0 is not None and tc0.size) else None,
                    colorscale=colorscale if colorscale is not None else None,
                    cmin=0,
                    cmax=track_cmax,
                ),
                hoverinfo="skip",
                showlegend=False,
                name="tracks",
            ),
        ]
    )

    # One slider step per frame; method="update" swaps trace data without
    # a relayout (keeping zoom/camera thanks to uirevision below).
    steps = []
    for i in range(S):
        pi = pts_xyz[i]
        ci = pts_cols[i]
        xi, yi, zi = trk_xyz[i]
        ti = trk_c[i]

        steps.append(
            dict(
                method="update",
                label=str(i),
                args=[
                    {
                        "x": [pi[:, 0], xi],
                        "y": [pi[:, 1], yi],
                        "z": [pi[:, 2], zi],
                        "marker.color": [ci, None],
                        "line.color": [None, ti if (ti is not None and len(ti)) else None],
                    },
                    {},
                ],
            )
        )

    sliders = [
        dict(
            active=0,
            currentvalue={"prefix": "Frame: ", "visible": True, "font": {"size": 14}},
            pad={"t": 10},
            len=0.6,
            x=0.2,
            font={"size": 8},
            steps=steps,
        )
    ]

    fig.update_layout(
        margin=dict(l=0, r=0, t=30, b=0),
        scene=scene_cfg,
        sliders=sliders,
        showlegend=False,
        title="Scrub frames with the slider below",
        # uirevision keeps the user's camera orientation across slider updates.
        uirevision="keep-camera",
        height=700,
    )
    return fig
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def run_model(target_dir: str, model: VDPM, frame_id_arg=0) -> dict:
    """Run VDPM inference over the frames in ``target_dir/images``.

    target_dir: working directory produced by handle_uploads(); may contain
        a meta.json declaring the number of synchronized views.
    model: loaded VDPM model, already on ``device``.
    frame_id_arg: index of the reference frame to visualize; coerced to int
        and reset to 0 when unparsable or out of range.

    Returns a dict with spatially subsampled (::2) "world_points", a tiled
    "world_points_conf" mask, and matching "images" arrays for plotting.

    Raises ValueError when CUDA is unavailable or no images were uploaded.
    """
    require_cuda()

    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found. Check your upload.")

    # Load metadata for Multi-View sync (load BEFORE slicing to respect view count)
    meta_path = os.path.join(target_dir, "meta.json")
    num_views = 1
    if os.path.exists(meta_path):
        import json
        try:
            with open(meta_path, 'r') as f:
                num_views = json.load(f).get("num_views", 1)
        except (OSError, ValueError):
            # Missing/unreadable/corrupt metadata -> assume single view.
            # (Was a bare ``except:``, which also swallowed KeyboardInterrupt.)
            pass

    # Limit frames to prevent OOM on 8GB GPUs
    if len(image_names) > MAX_FRAMES:
        # Round down to nearest multiple of num_views to preserve full scenes
        limit = (MAX_FRAMES // num_views) * num_views
        if limit == 0:
            limit = num_views  # At least one full timestep
            print(f"⚠️ Warning: MAX_FRAMES={MAX_FRAMES} is smaller than num_views={num_views}. Processing 1 full timestep anyway (may OOM).")

        print(f"⚠️ Limiting to {limit} frames ({limit // num_views} timesteps * {num_views} views) to fit in GPU memory")
        image_names = image_names[:limit]

    images = load_and_preprocess_images(image_names).to(device)

    if device == "cuda":
        print(f"GPU memory before inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    print(f"Running inference on {len(image_names)} images ({num_views} synchronized views)...")

    # Construct 'views' dictionaries with correct time indices: frames are
    # stored interleaved (one per view within each timestep), so frame i
    # belongs to timestep i // num_views.
    views = []
    for i in range(len(image_names)):
        t_idx = i // num_views
        views.append({
            "img": images[i].unsqueeze(0),  # (1, C, H, W)
            "view_idxs": torch.tensor([[0, t_idx]], device=device, dtype=torch.long)
        })

    inference_start = time.time()

    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            # Pass constructed views so model uses correct time query
            predictions = model.inference(views=views)

    inference_time = time.time() - inference_start
    print(f"✓ Inference completed in {inference_time:.2f}s ({inference_time/len(image_names):.2f}s per frame)")

    # Move results to CPU immediately to free GPU memory
    pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]]

    # Clear predictions from GPU to save memory
    del predictions
    if device == "cuda":
        torch.cuda.empty_cache()
        print(f"GPU memory after inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

    world_points = np.concatenate(pts_list, axis=0)
    world_points_conf = np.concatenate(conf_list, axis=0)

    # The UI hands frame ids over as strings; coerce defensively.
    try:
        frame_id = int(frame_id_arg)
    except (TypeError, ValueError):
        frame_id = 0

    if frame_id >= world_points.shape[0]:
        frame_id = 0

    # Subsample spatially (::2) to keep the interactive plot light.
    world_points_s = world_points[:, frame_id, ::2, ::2, :]
    # NOTE(review): the double frame_id index appears to pick the confidence
    # of the frame_id-th prediction at query index frame_id and tile it
    # across all frames — confirm this is intended rather than
    # world_points_conf[:, frame_id, ::2, ::2].
    single_mask = world_points_conf[frame_id, frame_id, ::2, ::2]
    world_points_conf_s = np.tile(single_mask[np.newaxis, ...], (world_points.shape[0], 1, 1))

    # Repeat the reference frame's image so colors stay fixed while scrubbing.
    img_np = images.detach().cpu().numpy()
    img_np = img_np[frame_id : frame_id + 1, :, ::2, ::2]
    img_np = np.repeat(img_np, world_points.shape[0], axis=0)

    torch.cuda.empty_cache()
    return {
        "world_points": world_points_s,
        "world_points_conf": world_points_conf_s,
        "images": img_np,
    }
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def handle_uploads(input_video, input_images):
    """Copy uploaded images and/or sample uploaded videos into a fresh work dir.

    Creates ``input_images_<timestamp>/images`` and fills it with:
      * copies of any directly uploaded image files, and
      * frames sampled from the uploaded video(s) at roughly VIDEO_SAMPLE_HZ.
    Multiple videos are read in lockstep ("interleaved mode") so that frame
    files alternate between views within each timestep; a meta.json records
    the view count for run_model() to reconstruct the grouping.

    Returns (target_dir, image_paths) with image_paths sorted by filename.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Microsecond-resolution timestamp keeps concurrent sessions separate.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images, exist_ok=True)

    image_paths = []

    if input_images:
        for file_obj in input_images:
            src = gradio_file_path(file_obj)
            if not src:
                continue
            dst = os.path.join(target_dir_images, os.path.basename(src))
            shutil.copy(src, dst)
            image_paths.append(dst)

    if input_video:
        # Check if input is a list (Gradio 4.x/5.x or file_count="multiple")
        input_video_list = input_video if isinstance(input_video, list) else [input_video]

        # Determine starting frame number based on existing images
        existing_files = os.listdir(target_dir_images)
        frame_num = len(existing_files)

        # Modified for Interleaved/Synchronized processing
        # 1. Open all videos
        captures = []
        capture_meta = []
        for idx, vid_obj in enumerate(input_video_list):
            video_path = gradio_file_path(vid_obj)
            print(f"Preparing video {idx+1}/{len(input_video_list)}: {video_path}")

            vs = cv2.VideoCapture(video_path)
            fps = float(vs.get(cv2.CAP_PROP_FPS) or 0.0)
            if fps <= 0: fps = 30.0  # Fallback

            # Keep roughly VIDEO_SAMPLE_HZ frames per second of footage.
            frame_interval = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1)
            captures.append(vs)
            capture_meta.append({"interval": frame_interval, "name": video_path})

        # 2. Step through them together so saved frames interleave by view,
        #    producing view-major ordering within each timestep.
        print("Processing videos in interleaved mode...")
        step_count = 0
        active_videos = True

        while active_videos:
            active_videos = False
            for i, vs in enumerate(captures):
                if not vs.isOpened():
                    continue

                gotit, frame = vs.read()
                if gotit:
                    active_videos = True  # Keep going as long as at least one video has frames

                    if step_count % capture_meta[i]["interval"] == 0:
                        # Zero-padded names keep lexicographic == temporal order.
                        out_path = os.path.join(target_dir_images, f"{frame_num:06}.png")
                        cv2.imwrite(out_path, frame)
                        image_paths.append(out_path)
                        frame_num += 1
                else:
                    vs.release()

            step_count += 1

    image_paths.sort()

    # Save metadata about capture structure (num_views)
    num_views = len(input_video_list) if input_video else 1
    meta_path = os.path.join(target_dir, "meta.json")
    try:
        import json
        with open(meta_path, 'w') as f:
            json.dump({"num_views": num_views}, f)
    except Exception as e:
        # Best effort: run_model() falls back to num_views=1 without it.
        print(f"Warning: could not save metadata: {e}")

    print(f"Files copied to {target_dir_images}; took {time.time() - start_time:.3f} seconds")
    return target_dir, image_paths
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def update_gallery_on_upload(input_video, input_images):
    """Stage fresh uploads into a working dir and refresh the preview gallery."""
    if not (input_video or input_images):
        return None, None, None, None
    target_dir, image_paths = handle_uploads(input_video, input_images)
    status = "Upload complete. Click 'Reconstruct' to begin 3D processing."
    return None, target_dir, image_paths, status
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def gradio_reconstruct(
    target_dir,
    conf_thres=50.0,
    mask_black_bg=False,
    mask_white_bg=False,
    frame_id_val=0,
):
    """Run reconstruction on the staged frames and build the 3D figure.

    Returns (figure, status_message, predictions); predictions is cached in
    Gradio state so later visualization tweaks skip re-running the model.
    """
    if target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first.", None

    # Reclaim memory from any previous run before inference.
    gc.collect()
    torch.cuda.empty_cache()

    images_dir = os.path.join(target_dir, "images")
    num_frames = len(os.listdir(images_dir)) if os.path.isdir(images_dir) else 0

    with torch.no_grad():
        predictions = run_model(target_dir, model, frame_id_val)

    fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg)

    torch.cuda.empty_cache()
    status = f"Reconstruction Success ({num_frames} frames processed, showing frame {frame_id_val})."
    return fig, status, predictions
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def update_plot(
    target_dir,
    predictions,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    is_example,
):
    """Re-render the cached reconstruction with new display settings."""
    missing = predictions is None or is_example == "True"
    if missing:
        return None, "No reconstruction available. Please click the Reconstruct button first."

    fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg)
    return fig, "Updated visualization with new settings. Use the slider below the plot to scrub frames."
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def clear_fields():
    """Reset the 3D viewer output (wired to input-change events)."""
    return None
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def update_log():
    """Status text shown while reconstruction is in flight."""
    return "Loading and Reconstructing..."
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def example_pipeline(
    input_video_ex,
    num_images_str,
    input_images_ex,
    conf_thres_val,
    mask_black_bg_val,
    mask_white_bg_val,
    is_example_str,
    frame_id_val,
):
    """End-to-end path for gallery examples: stage, reconstruct, plot.

    num_images_str and is_example_str are unused here; they exist so the
    signature lines up with the example table's column order.
    """
    target_dir, image_paths = handle_uploads(input_video_ex, input_images_ex)

    viz_args = (conf_thres_val, mask_black_bg_val, mask_white_bg_val, frame_id_val)
    fig, log_msg, predictions = gradio_reconstruct(target_dir, *viz_args)

    return fig, log_msg, target_dir, predictions, image_paths
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
# Example video assets bundled with the demo (paths relative to the repo root,
# used to populate the gr.Examples gallery below).
colosseum_video = "examples/videos/Colosseum.mp4"
camel_video = "examples/videos/camel.mp4"
tennis_video = "examples/videos/tennis.mp4"
paragliding_video = "examples/videos/paragliding.mp4"
stroller_video = "examples/videos/stroller.mp4"
goldfish_video = "examples/videos/goldfish.mp4"
horse_video = "examples/videos/horse.mp4"
swing_video = "examples/videos/swing.mp4"
car_video = "examples/videos/car.mp4"
figure1_video = "examples/videos/figure1.mp4"
figure2_video = "examples/videos/figure2.mp4"
figure3_video = "examples/videos/figure3.mp4"
tesla_video = "examples/videos/tesla.mp4"
pstudio_video = "examples/videos/pstudio.mp4"
|
| 785 |
+
|
| 786 |
+
# Gradio theme: muted slate/zinc palette on a plain white background, with
# selected checkboxes restyled to match the primary button colors.
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.slate,
    secondary_hue=gr.themes.colors.zinc,
    neutral_hue=gr.themes.colors.slate,
).set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
    body_background_fill="#FFFFFF",
)
|
| 795 |
+
|
| 796 |
+
# Page-level CSS: the two "*-log" classes render status text as a
# text-clipped grey gradient; #my_radio lays its options out in a single
# centred row.
css = """
.custom-log * {
    font-style: italic;
    font-size: 22px !important;
    background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%);
    -webkit-background-clip: text;
    background-clip: text;
    font-weight: bold !important;
    color: transparent !important;
    text-align: center !important;
}

.example-log * {
    font-style: italic;
    font-size: 16px !important;
    background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%);
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent !important;
}

#my_radio .wrap {
    display: flex;
    flex-wrap: nowrap;
    justify-content: center;
    align-items: center;
}

#my_radio .wrap label {
    display: flex;
    width: 50%;
    justify-content: center;
    align-items: center;
    margin: 0;
    padding: 10px 0;
    box-sizing: border-box;
}
"""
|
| 834 |
+
|
| 835 |
+
# Parse the run configuration from the command line and load the network
# weights once at start-up, so every Gradio request reuses the same
# in-memory model.
cfg = load_cfg_from_cli()
model = load_model(cfg)
|
| 837 |
+
|
| 838 |
+
with gr.Blocks(theme=theme, css=css) as demo:
    # Hidden textboxes act as serialisable per-session state flags:
    # whether the current run came from the examples table, how many
    # frames were extracted, and which frame is currently selected.
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    frame_id_state = gr.Textbox(label="frame_id", visible=False, value="0")

    gr.HTML(
        """
    <h1>V-DPM: Video Reconstruction with Dynamic Point Maps</h1>
    <p>
    <a href="https://github.com/eldar/vdpm">🐙 GitHub Repository</a> |
    <a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm/">Project Page</a>
    </p>
    <div style="font-size: 16px; line-height: 1.5;">
    <p>Upload a video (or multiple videos for multi-view setup) or a set of images to create a dynamic point map reconstruction of a scene or object.</p>
    </div>
    """
    )

    # Working directory of the extracted frames and the cached model output.
    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
    predictions_state = gr.State(value=None)

    with gr.Row():
        with gr.Column(scale=2):
            # A gr.File input (rather than gr.Video) so that several videos
            # can be uploaded at once for the multi-view setting.
            gr.Markdown("### Input")
            input_video = gr.File(
                label="Upload Video(s)",
                file_count="multiple",
                file_types=[".mp4", ".avi", ".mov", ".mkv", ".webm"],
                interactive=True,
            )
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )

        with gr.Column(scale=5):
            gr.Markdown("**3D Reconstruction (Point Cloud)**")
            log_output = gr.Markdown(
                "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
            )
            reconstruction_output = gr.Plot(label="3D Point Cloud")

            with gr.Row():
                submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                gr.ClearButton(
                    [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
                    scale=1,
                )

            with gr.Row():
                conf_thres = gr.Slider(0, 100, value=50, step=1, label="Confidence Threshold (%)")
                with gr.Column():
                    mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
                    mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)

    # Rows: [video, num_images, images, conf_thres, mask_black_bg,
    #        mask_white_bg, is_example, frame_id] — matching `inputs` below.
    examples = [
        [camel_video, "17", None, 15.0, False, False, "True", "8"],
        [horse_video, "18", None, 50.0, False, False, "True", "2"],
        [tennis_video, "11", None, 5.0, False, False, "True", "0"],
        [paragliding_video, "11", None, 5.0, False, False, "True", "0"],
        [stroller_video, "17", None, 10.0, False, False, "True", "8"],
        [goldfish_video, "11", None, 12.0, False, False, "True", "5"],
        [swing_video, "10", None, 40.0, False, False, "True", "4"],
        [car_video, "13", None, 15.0, False, False, "True", "7"],
        [figure1_video, "10", None, 25.0, False, False, "True", "0"],
        [figure2_video, "12", None, 25.0, False, False, "True", "6"],
        [figure3_video, "13", None, 30.0, False, False, "True", "0"],
        [tesla_video, "18", None, 20.0, False, True, "True", "0"],
        [pstudio_video, "12", None, 0.0, False, False, "True", "6"],
    ]

    gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
    gr.Examples(
        examples=examples,
        inputs=[
            input_video,
            num_images,
            input_images,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            is_example,
            frame_id_state,
        ],
        outputs=[
            reconstruction_output,
            log_output,
            target_dir_output,
            predictions_state,
            image_gallery,
        ],
        fn=example_pipeline,
        cache_examples=False,
        examples_per_page=50,
    )

    # Reconstruct: clear the plot, show the progress log, run the model,
    # then mark the session as a non-example run.
    submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
        fn=update_log, inputs=[], outputs=[log_output]
    ).then(
        fn=gradio_reconstruct,
        inputs=[
            target_dir_output,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            frame_id_state,
        ],
        outputs=[reconstruction_output, log_output, predictions_state],
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]
    )

    # Any post-processing control re-renders the plot from the cached
    # predictions without re-running the model.
    for ctrl in (conf_thres, mask_black_bg, mask_white_bg):
        ctrl.change(
            fn=update_plot,
            inputs=[
                target_dir_output,
                predictions_state,
                conf_thres,
                mask_black_bg,
                mask_white_bg,
                is_example,
            ],
            outputs=[reconstruction_output, log_output],
        )

    # Refresh the preview gallery whenever either upload widget changes.
    input_video.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
    input_images.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )

demo.queue(max_size=20).launch(show_error=True, share=True)
|
input_images_20260127_052216_587020/images/000000.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000001.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000002.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000003.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000004.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000005.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000006.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000007.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000008.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000009.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000010.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000011.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000012.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000013.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000014.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000015.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000016.png
ADDED
|
Git LFS Details
|
input_images_20260127_052216_587020/images/000017.png
ADDED
|
Git LFS Details
|
input_images_20260127_052439_748027/images/000000.png
ADDED
|
Git LFS Details
|
input_images_20260127_052439_748027/images/000001.png
ADDED
|
Git LFS Details
|
input_images_20260127_052439_748027/images/000002.png
ADDED
|
Git LFS Details
|
input_images_20260127_052439_748027/images/000003.png
ADDED
|
Git LFS Details
|