dxm21 commited on
Commit
b678162
·
verified ·
1 Parent(s): bff78c9

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +75 -0
  2. .gitignore +132 -0
  3. .gitmodules +0 -0
  4. .gradio/certificate.pem +31 -0
  5. LICENSE +22 -0
  6. LICENSE-VGGT +115 -0
  7. README.md +40 -8
  8. check_model_size.py +85 -0
  9. configs/config.yaml +50 -0
  10. configs/model/dpm.yaml +3 -0
  11. configs/visualise.yaml +13 -0
  12. dpm/aggregator.py +366 -0
  13. dpm/decoder.py +416 -0
  14. dpm/model.py +149 -0
  15. examples/videos/camel.mp4 +3 -0
  16. examples/videos/car.mp4 +3 -0
  17. examples/videos/figure1.mp4 +3 -0
  18. examples/videos/figure2.mp4 +3 -0
  19. examples/videos/figure3.mp4 +3 -0
  20. examples/videos/goldfish.mp4 +3 -0
  21. examples/videos/horse.mp4 +3 -0
  22. examples/videos/paragliding.mp4 +3 -0
  23. examples/videos/pstudio.mp4 +3 -0
  24. examples/videos/stroller.mp4 +3 -0
  25. examples/videos/swing.mp4 +3 -0
  26. examples/videos/tennis.mp4 +3 -0
  27. examples/videos/tesla.mp4 +3 -0
  28. gradio_demo.py +981 -0
  29. input_images_20260127_052216_587020/images/000000.png +3 -0
  30. input_images_20260127_052216_587020/images/000001.png +3 -0
  31. input_images_20260127_052216_587020/images/000002.png +3 -0
  32. input_images_20260127_052216_587020/images/000003.png +3 -0
  33. input_images_20260127_052216_587020/images/000004.png +3 -0
  34. input_images_20260127_052216_587020/images/000005.png +3 -0
  35. input_images_20260127_052216_587020/images/000006.png +3 -0
  36. input_images_20260127_052216_587020/images/000007.png +3 -0
  37. input_images_20260127_052216_587020/images/000008.png +3 -0
  38. input_images_20260127_052216_587020/images/000009.png +3 -0
  39. input_images_20260127_052216_587020/images/000010.png +3 -0
  40. input_images_20260127_052216_587020/images/000011.png +3 -0
  41. input_images_20260127_052216_587020/images/000012.png +3 -0
  42. input_images_20260127_052216_587020/images/000013.png +3 -0
  43. input_images_20260127_052216_587020/images/000014.png +3 -0
  44. input_images_20260127_052216_587020/images/000015.png +3 -0
  45. input_images_20260127_052216_587020/images/000016.png +3 -0
  46. input_images_20260127_052216_587020/images/000017.png +3 -0
  47. input_images_20260127_052439_748027/images/000000.png +3 -0
  48. input_images_20260127_052439_748027/images/000001.png +3 -0
  49. input_images_20260127_052439_748027/images/000002.png +3 -0
  50. input_images_20260127_052439_748027/images/000003.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,78 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/videos/camel.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ examples/videos/car.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ examples/videos/figure1.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ examples/videos/figure2.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ examples/videos/figure3.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ examples/videos/goldfish.mp4 filter=lfs diff=lfs merge=lfs -text
42
+ examples/videos/horse.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ examples/videos/paragliding.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ examples/videos/pstudio.mp4 filter=lfs diff=lfs merge=lfs -text
45
+ examples/videos/stroller.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ examples/videos/swing.mp4 filter=lfs diff=lfs merge=lfs -text
47
+ examples/videos/tennis.mp4 filter=lfs diff=lfs merge=lfs -text
48
+ examples/videos/tesla.mp4 filter=lfs diff=lfs merge=lfs -text
49
+ input_images_20260127_052216_587020/images/000000.png filter=lfs diff=lfs merge=lfs -text
50
+ input_images_20260127_052216_587020/images/000001.png filter=lfs diff=lfs merge=lfs -text
51
+ input_images_20260127_052216_587020/images/000002.png filter=lfs diff=lfs merge=lfs -text
52
+ input_images_20260127_052216_587020/images/000003.png filter=lfs diff=lfs merge=lfs -text
53
+ input_images_20260127_052216_587020/images/000004.png filter=lfs diff=lfs merge=lfs -text
54
+ input_images_20260127_052216_587020/images/000005.png filter=lfs diff=lfs merge=lfs -text
55
+ input_images_20260127_052216_587020/images/000006.png filter=lfs diff=lfs merge=lfs -text
56
+ input_images_20260127_052216_587020/images/000007.png filter=lfs diff=lfs merge=lfs -text
57
+ input_images_20260127_052216_587020/images/000008.png filter=lfs diff=lfs merge=lfs -text
58
+ input_images_20260127_052216_587020/images/000009.png filter=lfs diff=lfs merge=lfs -text
59
+ input_images_20260127_052216_587020/images/000010.png filter=lfs diff=lfs merge=lfs -text
60
+ input_images_20260127_052216_587020/images/000011.png filter=lfs diff=lfs merge=lfs -text
61
+ input_images_20260127_052216_587020/images/000012.png filter=lfs diff=lfs merge=lfs -text
62
+ input_images_20260127_052216_587020/images/000013.png filter=lfs diff=lfs merge=lfs -text
63
+ input_images_20260127_052216_587020/images/000014.png filter=lfs diff=lfs merge=lfs -text
64
+ input_images_20260127_052216_587020/images/000015.png filter=lfs diff=lfs merge=lfs -text
65
+ input_images_20260127_052216_587020/images/000016.png filter=lfs diff=lfs merge=lfs -text
66
+ input_images_20260127_052216_587020/images/000017.png filter=lfs diff=lfs merge=lfs -text
67
+ input_images_20260127_052439_748027/images/000000.png filter=lfs diff=lfs merge=lfs -text
68
+ input_images_20260127_052439_748027/images/000001.png filter=lfs diff=lfs merge=lfs -text
69
+ input_images_20260127_052439_748027/images/000002.png filter=lfs diff=lfs merge=lfs -text
70
+ input_images_20260127_052439_748027/images/000003.png filter=lfs diff=lfs merge=lfs -text
71
+ input_images_20260127_052439_748027/images/000004.png filter=lfs diff=lfs merge=lfs -text
72
+ input_images_20260127_052439_748027/images/000005.png filter=lfs diff=lfs merge=lfs -text
73
+ input_images_20260127_052439_748027/images/000006.png filter=lfs diff=lfs merge=lfs -text
74
+ input_images_20260127_052439_748027/images/000007.png filter=lfs diff=lfs merge=lfs -text
75
+ input_images_20260127_052439_748027/images/000008.png filter=lfs diff=lfs merge=lfs -text
76
+ input_images_20260127_052439_748027/images/000009.png filter=lfs diff=lfs merge=lfs -text
77
+ input_images_20260127_052521_840381/images/cam00.png filter=lfs diff=lfs merge=lfs -text
78
+ input_images_20260127_052521_840381/images/cam01.png filter=lfs diff=lfs merge=lfs -text
79
+ input_images_20260127_052521_840381/images/cam18.png filter=lfs diff=lfs merge=lfs -text
80
+ input_images_20260127_052521_840381/images/cam19.png filter=lfs diff=lfs merge=lfs -text
81
+ input_images_20260127_053256_343581/images/cam00.png filter=lfs diff=lfs merge=lfs -text
82
+ input_images_20260127_053256_343581/images/cam01.png filter=lfs diff=lfs merge=lfs -text
83
+ input_images_20260127_053256_343581/images/cam18.png filter=lfs diff=lfs merge=lfs -text
84
+ input_images_20260127_053256_343581/images/cam19.png filter=lfs diff=lfs merge=lfs -text
85
+ input_images_20260127_053630_522657/images/cam18.png filter=lfs diff=lfs merge=lfs -text
86
+ input_images_20260127_053630_522657/images/cam19.png filter=lfs diff=lfs merge=lfs -text
87
+ input_images_20260127_161749_937508/images/000000.png filter=lfs diff=lfs merge=lfs -text
88
+ input_images_20260127_161749_937508/images/000001.png filter=lfs diff=lfs merge=lfs -text
89
+ input_images_20260127_161749_937508/images/000002.png filter=lfs diff=lfs merge=lfs -text
90
+ input_images_20260127_161749_937508/images/000003.png filter=lfs diff=lfs merge=lfs -text
91
+ input_images_20260127_162743_468054/images/000000.png filter=lfs diff=lfs merge=lfs -text
92
+ input_images_20260127_162743_468054/images/000001.png filter=lfs diff=lfs merge=lfs -text
93
+ input_images_20260127_162743_468054/images/000002.png filter=lfs diff=lfs merge=lfs -text
94
+ input_images_20260127_162743_468054/images/000003.png filter=lfs diff=lfs merge=lfs -text
95
+ input_images_20260127_163859_002158/images/000000.png filter=lfs diff=lfs merge=lfs -text
96
+ input_images_20260127_163859_002158/images/000001.png filter=lfs diff=lfs merge=lfs -text
97
+ input_images_20260127_163859_002158/images/000002.png filter=lfs diff=lfs merge=lfs -text
98
+ input_images_20260127_163859_002158/images/000003.png filter=lfs diff=lfs merge=lfs -text
99
+ input_images_20260127_170350_777007/images/000000.png filter=lfs diff=lfs merge=lfs -text
100
+ input_images_20260127_170350_777007/images/000001.png filter=lfs diff=lfs merge=lfs -text
101
+ input_images_20260127_170350_777007/images/000002.png filter=lfs diff=lfs merge=lfs -text
102
+ input_images_20260127_170350_777007/images/000003.png filter=lfs diff=lfs merge=lfs -text
103
+ input_images_20260127_170350_777007/images/000004.png filter=lfs diff=lfs merge=lfs -text
104
+ input_images_20260127_170350_777007/images/000005.png filter=lfs diff=lfs merge=lfs -text
105
+ input_images_20260127_170350_777007/images/000006.png filter=lfs diff=lfs merge=lfs -text
106
+ input_images_20260127_170350_777007/images/000007.png filter=lfs diff=lfs merge=lfs -text
107
+ input_images_20260127_170350_777007/images/000008.png filter=lfs diff=lfs merge=lfs -text
108
+ input_images_20260127_170350_777007/images/000009.png filter=lfs diff=lfs merge=lfs -text
109
+ input_images_20260127_170350_777007/images/000010.png filter=lfs diff=lfs merge=lfs -text
110
+ input_images_20260127_170350_777007/images/000011.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ checkpoints/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ pip-wheel-metadata/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
.gitmodules ADDED
File without changes
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
LICENSE ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Eldar Insafutdinov, Edgar Sucar
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
LICENSE-VGGT ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ VGGT License
2
+
3
+ v1 Last Updated: July 29, 2025
4
+
5
+ “Acceptable Use Policy” means the Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
6
+
7
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
8
+
9
+
10
+ “Documentation” means the specifications, manuals and documentation accompanying
11
+ Research Materials distributed by Meta.
12
+
13
+
14
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
15
+
16
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
17
+ “Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
18
+
19
+ By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
20
+
21
+
22
+ 1. License Rights and Redistribution.
23
+
24
+
25
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
26
+
27
+ b. Redistribution and Use.
28
+
29
+
30
+ i. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
31
+
32
+
33
+ ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
34
+
35
+
36
+ iii. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
37
+ 2. User Support. Your use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
38
+
39
+
40
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
41
+
42
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
43
+
44
+ 5. Intellectual Property.
45
+
46
+
47
+ a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
48
+
49
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
50
+
51
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.
52
+
53
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
54
+
55
+
56
+ 8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
57
+
58
+
59
+ Acceptable Use Policy
60
+
61
+ Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
62
+
63
+ As part of this mission, Meta makes certain research materials available for use in accordance with this Agreement (including the Acceptable Use Policy). Meta is committed to promoting the safe and responsible use of such research materials.
64
+
65
+ Prohibited Uses
66
+
67
+ You agree you will not use, or allow others to use, Research Materials to:
68
+
69
+ Violate the law or others’ rights, including to:
70
+ Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
71
+ Violence or terrorism
72
+ Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
73
+ Human trafficking, exploitation, and sexual violence
74
+ The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
75
+ Sexual solicitation
76
+ Any other criminal activity
77
+
78
+ Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
79
+
80
+ Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
81
+
82
+ Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
83
+
84
+ Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
85
+
86
+ Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using Research Materials
87
+
88
+ Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
89
+
90
+ 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
91
+
92
+ Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic in Arms Regulations (ITAR) maintained by the United States Department of State
93
+
94
+ Guns and illegal weapons (including weapon development)
95
+
96
+ Illegal drugs and regulated/controlled substances
97
+ Operation of critical infrastructure, transportation technologies, or heavy machinery
98
+
99
+ Self-harm or harm to others, including suicide, cutting, and eating disorders
100
+ Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
101
+
102
+ 3. Intentionally deceive or mislead others, including use of Research Materials related to the following:
103
+
104
+ Generating, promoting, or furthering fraud or the creation or promotion of disinformation
105
+ Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
106
+
107
+ Generating, promoting, or further distributing spam
108
+
109
+ Impersonating another individual without consent, authorization, or legal right
110
+
111
+ Representing that outputs of research materials or outputs from technology using Research Materials are human-generated
112
+
113
+ Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
114
+
115
+ 4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
README.md CHANGED
@@ -1,12 +1,44 @@
1
  ---
2
- title: Vdpm
3
- emoji: 💻
4
- colorFrom: indigo
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.4.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: vdpm
3
+ app_file: gradio_demo.py
 
 
4
  sdk: gradio
5
+ sdk_version: 5.17.1
 
 
6
  ---
7
+ <div align="center">
8
+ <h1>V-DPM: 4D Video Reconstruction with Dynamic Point Maps</h1>
9
 
10
+ <a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
11
+ <a href="https://huggingface.co/spaces/edgarsucar/vdpm"><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
12
+
13
+ **[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**
14
+
15
+
16
+ [Edgar Sucar](https://edgarsucar.github.io/)\*, [Eldar Insafutdinov](https://eldar.insafutdinov.com/)\*, [Zihang Lai](https://scholar.google.com/citations?user=31eXgMYAAAAJ), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/)
17
+ </div>
18
+
19
+ ## Setup
20
+
21
+ First, clone the repository and setup a virtual environment with [uv](https://github.com/astral-sh/uv):
22
+
23
+ ```bash
24
+ git clone git@github.com:eldar/vdpm.git
25
+ cd vdpm
26
+ uv venv --python 3.12
27
+ . .venv/bin/activate
28
+
29
+ # Install PyTorch with CUDA 11.8 first
30
+ uv pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
31
+
32
+ # Then install remaining dependencies
33
+ uv pip install -r requirements.txt
34
+ ```
35
+
36
+ ## Viser demo
37
+ ```bash
38
+ python visualise.py ++vis.input_video=examples/videos/camel.mp4
39
+ ```
40
+
41
+ ## Gradio demo
42
+ ```bash
43
+ python gradio_demo.py
44
+ ```
check_model_size.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import sys
from pathlib import Path

# Make the repository root importable so the local `dpm` package resolves
# when this script is run directly from the checkout.
sys.path.insert(0, str(Path(__file__).parent))


def check_model_memory(gpu_mem_gb: float = 8.0) -> None:
    """Estimate whether the VDPM model fits in a GPU's memory budget.

    Builds the model on the CPU, counts its parameters, and prints rough
    memory estimates for several precisions (FP32/FP16/BF16/INT8) plus a
    coarse activation estimate, then reports whether the model is expected
    to fit in ``gpu_mem_gb`` gigabytes of VRAM. All printed numbers are
    back-of-the-envelope estimates, not measurements.

    Args:
        gpu_mem_gb: Target GPU memory budget in GB. Defaults to 8.0
            (an RTX 3070 Ti), matching the original hard-coded value.
    """
    GB = 1024 ** 3  # bytes per gigabyte (same factor the estimates divide by)

    # Minimal stand-in for the hydra config the model constructor expects;
    # only `cfg.model.decoder_depth` appears to be read — TODO confirm
    # against dpm/model.py.
    class SimpleConfig:
        class ModelConfig:
            decoder_depth = 4
        model = ModelConfig()

    cfg = SimpleConfig()

    # Import after sys.path is set so the local `dpm` package is found.
    from dpm.model import VDPM

    # Create model on CPU first — we only need parameter counts here.
    print("Creating model...")
    model = VDPM(cfg)

    # Count parameters (total and trainable).
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\n{'='*60}")
    print(f"MODEL SIZE ANALYSIS FOR RTX 3070 Ti (8GB)")
    print(f"{'='*60}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"\nEstimated model weights memory:")
    print(f"  - FP32 (float32): {total_params * 4 / GB:.2f} GB")
    print(f"  - FP16 (float16): {total_params * 2 / GB:.2f} GB")
    print(f"  - BF16 (bfloat16): {total_params * 2 / GB:.2f} GB")
    print(f"  - INT8 (quantized): {total_params * 1 / GB:.2f} GB  <-- RECOMMENDED for 8GB GPU")

    # Estimate activation memory for a typical input.
    batch_size = 1
    num_frames = 5  # typical video length
    img_size = 518
    print(f"\nEstimated activation memory (batch={batch_size}, frames={num_frames}, img_size={img_size}):")

    # Input images: [B, S, 3, H, W] in FP32.
    input_mem = batch_size * num_frames * 3 * img_size * img_size * 4 / GB
    print(f"  - Input images (FP32): {input_mem:.2f} GB")

    # Forward-pass activations are commonly 2-4x the weight footprint;
    # use 3x of the FP16 weights as a conservative estimate.
    activation_mem_estimate = total_params * 2 * 3 / GB
    print(f"  - Activations (estimate): {activation_mem_estimate:.2f} GB")

    # Totals for different precision modes. INT8 halves weights vs FP16
    # and is assumed to shrink activations to ~60% as well.
    total_fp16 = (total_params * 2 / GB) + input_mem + activation_mem_estimate
    total_int8 = (total_params * 1 / GB) + input_mem + (activation_mem_estimate * 0.6)

    print(f"\nTotal estimated GPU memory needed:")
    print(f"  - With FP16/BF16: {total_fp16:.2f} GB")
    # BUG FIX: the original printed "<-- FITS IN 8GB!" unconditionally,
    # even when the INT8 estimate exceeded the budget. Only claim a fit
    # when the estimate actually fits.
    int8_note = f" <-- FITS IN {gpu_mem_gb:g}GB!" if total_int8 <= gpu_mem_gb else ""
    print(f"  - With INT8 quantization: {total_int8:.2f} GB{int8_note}")
    print(f"Your RTX 3070 Ti has: {gpu_mem_gb:g} GB VRAM")

    if total_int8 <= gpu_mem_gb:
        print(f"\n✓ With INT8 quantization, model will fit in GPU memory!")
        print(f"  Set USE_QUANTIZATION = True in gradio_demo.py")
    elif total_fp16 > gpu_mem_gb:
        print(f"\n⚠️  WARNING: Even with INT8 ({total_int8:.2f} GB), memory is tight")
        print(f"   Recommendations:")
        print(f"   1. Use INT8 quantization (USE_QUANTIZATION = True)")
        print(f"   2. Reduce number of input frames to {num_frames} or fewer")
        print(f"   3. Clear CUDA cache between batches")
    else:
        # NOTE: unreachable in practice — total_int8 is always strictly
        # below total_fp16, so total_int8 > budget implies total_fp16 >
        # budget. Kept for parity with the original control flow.
        print(f"\n✓ Model should fit with FP16!")

    print(f"{'='*60}\n")

    # Report actual device memory when CUDA is available.
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / GB:.2f} GB")
        print(f"Current GPU memory allocated: {torch.cuda.memory_allocated() / GB:.2f} GB")
        print(f"Current GPU memory cached: {torch.cuda.memory_reserved() / GB:.2f} GB")


if __name__ == "__main__":
    check_model_memory()
configs/config.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - hydra: defaults
4
+ - model: dpm
5
+
6
+ config:
7
+ exp_name: "debug"
8
+ file: "config.yaml"
9
+
10
+ data_loader:
11
+ batch_size: 2
12
+ num_workers: 8
13
+ dynamic_batch: false
14
+
15
+ train:
16
+ logging: true
17
+ num_gpus: 4
18
+ amp: bfloat16
19
+ amp_dpt: false
20
+ dry_run: false
21
+ camera_loss_lambda: 5.0
22
+
23
+ optimiser:
24
+ lr: 0.00005 # absolute lr
25
+ blr: 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
26
+ start_epoch:
27
+ epochs: 70
28
+ accum_iter: 1
29
+ warmup_epochs: 3
30
+ min_lr: 1e-06
31
+
32
+ run:
33
+ resume: false
34
+ dirpath: null
35
+ debug: false
36
+ random_seed: 42
37
+ git_hash: null
38
+ log_frequency: 250
39
+ training_progress_bar: false
40
+ save_freq: 5
41
+ eval_freq: 1
42
+ keep_freq: 5
43
+ print_freq: 20
44
+ num_keep_ckpts: 5
45
+ # Legacy DUSt3R distributed-training parameters (kept for compatibility)
46
+ world_size: -1
47
+ local_rank: -1
48
+ dist_url: "env://"
49
+ seed: 0
50
+
configs/model/dpm.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: dpm-video
2
+ pretrained: /work/eldar/models/vggt/VGGT-1B.pt
3
+ decoder_depth: 4
configs/visualise.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - model: dpm
4
+
5
+ hydra:
6
+ output_subdir: null # Disable saving of config files.
7
+ job:
8
+ chdir: False
9
+
10
+ vis:
11
+ port: 8080
12
+ input_video:
13
+
dpm/aggregator.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE-VGGT file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.utils.checkpoint import checkpoint
12
+ from typing import Optional, Tuple, Union, List, Dict, Any
13
+
14
+ from vggt.layers import PatchEmbed
15
+ from vggt.layers.block import Block
16
+ from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
17
+ from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
22
+ _RESNET_STD = [0.229, 0.224, 0.225]
23
+
24
+
25
class Aggregator(nn.Module):
    """
    The Aggregator applies alternating-attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.


    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()

        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        # Intra-frame attention: each frame's tokens attend only within that frame.
        self.frame_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        # Global attention: tokens from all frames of a sequence attend jointly.
        self.global_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        # Single learnable token shared by every frame; it is inserted between the
        # camera and register tokens in forward(), so patch tokens shift by one.
        self.time_conditioning_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.patch_start_idx += 1

        # Initialize parameters with small values
        # NOTE(review): time_conditioning_token keeps its randn init (std=1) while
        # camera/register tokens are re-initialized with std=1e-6 — confirm intended.
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

        # Register normalization constants as buffers
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 1, 3, 1, 1),
                persistent=False,
            )

        self.use_reentrant = False  # hardcoded to False

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
        """

        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

        # Disable gradient updates for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    def forward(
        self,
        images: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks
                (one [B, S, P, 2C] tensor per frame/global block pair),
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images and reshape for patch embed
        images = (images - self._resnet_mean) / self._resnet_std

        # Reshape to [B*S, C, H, W] for patch embedding
        images = images.view(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        # DINOv2 backbones return a dict; pull out the normalized patch tokens.
        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P, C = patch_tokens.shape

        # Expand camera and register tokens to match batch size and sequence length
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)
        # do something similar for time_conditioning_token
        time_conditioning_token = slice_expand_and_flatten_single(self.time_conditioning_token, B, S)
        # Concatenate special tokens with patch tokens
        # (order matters: camera, time, register, patch — patch_start_idx counts the first three groups).
        tokens = torch.cat([camera_token, time_conditioning_token, register_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

            if self.patch_start_idx > 0:
                # do not use position embedding for special tokens (camera and register tokens)
                # so set pos to 0 for the special tokens; patch positions shift up by 1.
                pos = pos + 1
                pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
                pos = torch.cat([pos_special, pos], dim=1)

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            for i in range(len(frame_intermediates)):
                # concat frame and global intermediates, [B x S x P x 2C]
                concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                output_list.append(concat_inter)

        # Free the last iteration's temporaries before returning (memory hygiene).
        del concat_inter
        del frame_intermediates
        del global_intermediates
        return output_list, self.patch_start_idx

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                # Gradient checkpointing saves activation memory during training.
                tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates
319
+
320
+
321
def slice_expand_and_flatten(token_tensor, B, S):
    """Broadcast role-specific special tokens across a batch of frame sequences.

    `token_tensor` has shape (1, 2, X, C): slot 0 holds the token(s) used for
    the first frame of every sequence, slot 1 holds the token(s) shared by the
    remaining S-1 frames. Each sequence therefore receives one slot-0 entry
    followed by (S-1) slot-1 entries, expanded over the batch.

    Returns:
        torch.Tensor: tokens of shape (B*S, X, C).
    """
    trailing = token_tensor.shape[2:]

    # Frame 0 uses the first slot; frames 1..S-1 all share the second slot.
    first_frame = token_tensor[:, 0:1, ...].expand(B, 1, *trailing)
    later_frames = token_tensor[:, 1:, ...].expand(B, S - 1, *trailing)

    # (B, S, X, C): per-sequence layout with the first-frame token in front.
    per_sequence = torch.cat((first_frame, later_frames), dim=1)

    # Collapse batch and sequence dims for per-frame processing.
    return per_sequence.reshape(B * S, *per_sequence.shape[2:])
345
+
346
+
347
def slice_expand_and_flatten_single(token_tensor, B, S):
    """Broadcast a single shared token to every frame of every sequence.

    Unlike `slice_expand_and_flatten`, there is no first-frame special case:
    the one token in `token_tensor` (shape (1, 1, C)) is reused for all S
    frames of all B sequences.

    Returns:
        torch.Tensor: tokens of shape (B*S, 1, C).
    """
    # Expand (stride-0 broadcast, no copy) over batch and sequence dims.
    broadcast = token_tensor.expand(B, S, *token_tensor.shape[2:])

    # Flatten to one token per frame for per-frame processing.
    return broadcast.reshape(B * S, 1, *broadcast.shape[2:])
dpm/decoder.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE-VGGT file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ from torch import nn, Tensor
10
+ from torch.utils.checkpoint import checkpoint
11
+ from typing import List, Callable
12
+ from dataclasses import dataclass
13
+
14
+ from einops import repeat
15
+
16
+ from vggt.layers.block import drop_add_residual_stochastic_depth
17
+ from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
18
+
19
+ from vggt.layers.attention import Attention
20
+ from vggt.layers.drop_path import DropPath
21
+ from vggt.layers.layer_scale import LayerScale
22
+ from vggt.layers.mlp import Mlp
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
@dataclass
class ModulationOut:
    """Adaptive-normalisation parameters produced by `Modulation`:
    a per-sample shift, scale and residual gate tensor."""

    shift: Tensor
    scale: Tensor
    gate: Tensor
32
+
33
+
34
class Modulation(nn.Module):
    """Project a conditioning vector to (shift, scale, gate) modulation triples.

    A single linear layer (after SiLU) emits 3 chunks per triple; with
    ``double=True`` six chunks are produced and a second triple is returned,
    otherwise the second element of the result is ``None``.
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        # SiLU-activate the conditioning vector, project, and add a token axis
        # so the chunks broadcast over the sequence dimension downstream.
        projected = self.lin(nn.functional.silu(vec))[:, None, :]
        chunks = projected.chunk(self.multiplier, dim=-1)

        primary = ModulationOut(shift=chunks[0], scale=chunks[1], gate=chunks[2])
        secondary = (
            ModulationOut(shift=chunks[3], scale=chunks[4], gate=chunks[5])
            if self.is_double
            else None
        )
        return primary, secondary
48
+
49
+
50
class ConditionalBlock(nn.Module):
    """Transformer block whose attention branch is modulated by a conditioning
    vector (adaptive layer norm, DiT/Flux style).

    `Modulation` maps `cond` to a (shift, scale, gate) triple: the output of the
    (affine-free) pre-attention norm is transformed as ``(1 + scale) * x + shift``
    and the attention output is multiplied by ``gate`` before the residual add.
    The MLP branch is a standard unconditioned residual branch.

    Fix vs. original: the FFN residual is now wrapped in ``self.drop_path2``
    (the original used ``drop_path1`` and carried a FIXME). Both modules are
    constructed with the same drop rate, so the sampling distribution is
    unchanged; the intended module is now used.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()

        # Affine parameters come from the conditioning signal, so the norm
        # itself carries no learnable affine.
        self.norm1 = norm_layer(dim, elementwise_affine=False)
        self.modulation = Modulation(dim, double=False)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, cond=None, is_global=False) -> Tensor:
        """Run the block.

        Args:
            x: tokens, shape (B*S, P, C) for frame attention or (B, S*P, C)
               when ``is_global`` is True.
            pos: optional rotary positions matching x's layout.
            cond: conditioning tokens with leading dims (B, S, ...); one
                  conditioning vector per frame.
            is_global: whether x carries the flattened global layout.
        """
        B, S = cond.shape[:2]
        C = x.shape[-1]
        if is_global:
            # Recover the per-frame token count so modulation can be applied
            # frame-aligned even in the global (B, S*P, C) layout.
            P = x.shape[1] // S
        cond = cond.view(B * S, C)
        mod, _ = self.modulation(cond)

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            """
            conditional attention following DiT implementation from Flux
            https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py#L194-L239
            """
            def prepare_for_mod(y):
                """reshape to modulate the patch tokens with correct conditioning one"""
                return y.view(B, S, P, C).view(B * S, P, C) if is_global else y

            def restore_after_mod(y):
                """reshape back to global sequence"""
                return y.view(B, S, P, C).view(B, S * P, C) if is_global else y

            x = prepare_for_mod(x)
            x = (1 + mod.scale) * self.norm1(x) + mod.shift
            x = restore_after_mod(x)

            x = self.attn(x, pos=pos)

            x = prepare_for_mod(x)
            x = mod.gate * x
            x = restore_after_mod(x)

            return x

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            # Fixed: use drop_path2 for the FFN branch (was drop_path1 + FIXME).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
159
+
160
+
161
class Decoder(nn.Module):
    """Attention blocks after encoder per DPT input feature
    to generate point maps at a given time.

    Runs a small alternating (frame/global) attention stack of
    `ConditionalBlock`s over selected intermediate Aggregator layers,
    conditioned per frame on a time token chosen by `cond_view_idxs`.
    """

    def __init__(
        self,
        cfg,
        dim_in: int,
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        patch_size=14,
        embed_dim=1024,
        depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        block_fn=ConditionalBlock,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        # NOTE(review): intermediate_layer_idx uses a mutable list default;
        # it is only read, never mutated, so this is harmless as-is.
        super().__init__()
        self.cfg = cfg
        self.intermediate_layer_idx = intermediate_layer_idx

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        self.rope = (
            RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        )
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.dim_in = dim_in

        # Legacy layout kept for reference; hardcoded off. NOTE(review): if this
        # were ever re-enabled, forward()'s `len(self.frame_blocks[0])` would not
        # work with the flat ModuleList layout below — dead code path.
        self.old_decoder = False
        if self.old_decoder:
            self.frame_blocks = nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim * 2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(depth)
                ]
            )
            self.global_blocks = nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim * 2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(depth)
                ]
            )
        else:
            # Nested layout: one inner ModuleList per entry of `depths`.
            # Currently a single shared stack is used for all DPT input features.
            depths = [depth]
            self.frame_blocks = nn.ModuleList([
                nn.ModuleList([
                    block_fn(
                        dim=embed_dim * 2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(d)
                ])
                for d in depths
            ])

            self.global_blocks = nn.ModuleList([
                nn.ModuleList([
                    block_fn(
                        dim=embed_dim * 2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(d)
                ])
                for d in depths
            ])

        self.use_reentrant = False  # hardcoded to False

    def get_condition_tokens(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        cond_view_idxs: torch.Tensor
    ):
        """Gather, per frame, the conditioning token of the requested view/time.

        Args:
            aggregated_tokens_list: Aggregator outputs, each [B, S, N_tok, D].
            cond_view_idxs: [B, S] integer indices of the view whose token
                conditions each frame.
        """
        # Use tokens from the last block for conditioning
        tokens_last = aggregated_tokens_list[-1]  # [B S N_tok D]
        # Token at index 1 is the time-conditioning token inserted by the
        # Aggregator (camera token is index 0); the local name is historical.
        cond_token_idx = 1
        camera_tokens = tokens_last[:, :, [cond_token_idx]]  # [B, S, 1, D]

        cond_view_idxs = cond_view_idxs.to(camera_tokens.device)
        cond_view_idxs = repeat(
            cond_view_idxs,
            "b s -> b s c d",
            c=camera_tokens.shape[2],
            d=camera_tokens.shape[3],
        )
        # For each output frame s, pick the token of frame cond_view_idxs[b, s].
        cond_tokens = torch.gather(camera_tokens, 1, cond_view_idxs)

        return cond_tokens

    def forward(
        self,
        images: torch.Tensor,
        aggregated_tokens_list: List[torch.Tensor],
        patch_start_idx: int,
        cond_view_idxs: torch.Tensor,
    ):
        """Refine the selected intermediate token maps, conditioned on time.

        Returns a list (one per intermediate_layer_idx entry) of [B, S, P, C]
        tensors ready for the DPT point head.
        """
        B, S, _, H, W = images.shape

        cond_tokens = self.get_condition_tokens(
            aggregated_tokens_list, cond_view_idxs
        )

        # Clone the selected encoder layers so the originals stay untouched.
        input_tokens = []
        for k, layer_idx in enumerate(self.intermediate_layer_idx):
            layer_tokens = aggregated_tokens_list[layer_idx].clone()
            input_tokens.append(layer_tokens)

        _, _, P, C = input_tokens[0].shape

        pos = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )
            if patch_start_idx > 0:
                # do not use position embedding for special tokens (camera and register tokens)
                # so set pos to 0 for the special tokens
                pos = pos + 1
                pos_special = torch.zeros(B * S, patch_start_idx, 2).to(images.device).to(pos.dtype)
                pos = torch.cat([pos_special, pos], dim=1)

        frame_idx = 0
        global_idx = 0
        depth = len(self.frame_blocks[0])
        N = len(input_tokens)
        # stack all intermediate layer tokens along batch dimension
        # they are all processed by the same decoder
        # NOTE(review): s_pos assumes rope is enabled (pos not None) — with
        # rope_freq <= 0 this would raise; default config always enables rope.
        s_tokens = torch.cat(input_tokens)
        s_cond_tokens = torch.cat([cond_tokens] * N, dim=0)
        s_pos = torch.cat([pos] * N, dim=0)

        # perform time conditioned attention
        for _ in range(depth):
            for attn_type in self.aa_order:
                token_idx = 0  # single shared stack; always index 0

                if attn_type == "frame":
                    s_tokens, frame_idx, _ = self._process_frame_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, frame_idx, pos=s_pos, token_idx=token_idx
                    )
                elif attn_type == "global":
                    s_tokens, global_idx, _ = self._process_global_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, global_idx, pos=s_pos, token_idx=token_idx
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")
        # Undo the batch-dim stacking: one [B, S, P, C] tensor per input layer.
        processed = [t.view(B, S, P, C) for t in s_tokens.split(B, dim=0)]

        return processed

    def _process_frame_attention(self, tokens, cond_tokens, B, S, P, C, frame_idx, pos=None, token_idx=0):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                # Checkpoint to trade compute for activation memory.
                tokens = checkpoint(self.frame_blocks[token_idx][frame_idx], tokens, pos, cond_tokens, use_reentrant=self.use_reentrant)
            else:
                if self.old_decoder:
                    tokens = self.frame_blocks[frame_idx](tokens, pos=pos, cond=cond_tokens)
                else:
                    tokens = self.frame_blocks[0][frame_idx](tokens, pos=pos, cond=cond_tokens)

            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, cond_tokens, B, S, P, C, global_idx, pos=None, token_idx=0):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                # Positional arg `True` is ConditionalBlock's is_global flag.
                tokens = checkpoint(self.global_blocks[token_idx][global_idx], tokens, pos, cond_tokens, True, use_reentrant=self.use_reentrant)
            else:
                if self.old_decoder:
                    tokens = self.global_blocks[global_idx](tokens, pos=pos, cond=cond_tokens, is_global=True)
                else:
                    tokens = self.global_blocks[0][global_idx](tokens, pos=pos, cond=cond_tokens, is_global=True)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates
dpm/model.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from vggt.heads.camera_head import CameraHead
5
+ from vggt.heads.dpt_head import DPTHead
6
+
7
+ from .aggregator import Aggregator
8
+ from .decoder import Decoder
9
+
10
+
11
def freeze_all_params(modules):
    """Disable gradient updates for everything in *modules*.

    Each entry may be an ``nn.Module`` (all of its parameters are frozen)
    or a bare parameter/tensor (frozen directly).
    """
    for entry in modules:
        collect = getattr(entry, "named_parameters", None)
        if collect is None:
            # entry is directly a parameter
            entry.requires_grad = False
        else:
            for _, param in collect():
                param.requires_grad = False
19
+
20
+
21
class VDPM(nn.Module):
    """Video dynamic point-map model built on a VGGT-style backbone.

    Components:
      * ``aggregator`` — alternating frame/global attention encoder,
      * ``point_head`` — DPT head producing (pts3d, conf) maps,
      * ``decoder`` — time-conditioned refinement of selected encoder layers,
        reusing the same DPT head for dynamic point maps,
      * ``camera_head`` — per-frame pose encodings.

    Fixes vs. original: the bare ``except:`` in ``inference`` is narrowed to
    ``except Exception`` (a bare except also swallows KeyboardInterrupt /
    SystemExit), and a stray placeholder comment was removed.
    """

    def __init__(self, cfg, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()
        self.cfg = cfg

        self.aggregator = Aggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
        )
        self.decoder = Decoder(
            cfg,
            dim_in=2 * embed_dim,
            embed_dim=embed_dim,
            depth=cfg.model.decoder_depth
        )
        # Heads consume concatenated frame+global tokens, hence 2 * embed_dim.
        self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")

        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.set_freeze()

    def set_freeze(self):
        """Freeze the pretrained patch embedding; everything else trains."""
        to_be_frozen = [self.aggregator.patch_embed]
        freeze_all_params(to_be_frozen)

    def forward(
        self,
        views, autocast_dpt=None
    ):
        """Training forward pass.

        Args:
            views: list of per-frame dicts with "img" [B, 3, H, W] and
                "view_idxs" (column 1 selects the conditioning view per frame).
            autocast_dpt: optional autocast context for the DPT head; defaults
                to a disabled CUDA autocast (head runs in full precision).

        Returns:
            (res_static, res_dynamic): static {pts3d, conf} and dynamic
            {pts3d, conf, pose_enc_list} prediction dicts.
        """
        images = torch.stack([view["img"] for view in views], dim=1)
        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        res_dynamic = dict()

        # NOTE(review): if self.decoder were ever None, pts3d below would be
        # undefined; __init__ always constructs the decoder, so this holds.
        if self.decoder is not None:
            cond_view_idxs = torch.stack([view["view_idxs"][:, 1] for view in views], dim=1)
            decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)

            if autocast_dpt is None:
                autocast_dpt = torch.amp.autocast("cuda", enabled=False)

            with autocast_dpt:
                pts3d, pts3d_conf = self.point_head(
                    aggregated_tokens_list, images, patch_start_idx
                )

                # The DPT head indexes specific encoder layers; place decoded
                # tokens at those slots and leave the rest unused (None).
                padded_decoded_tokens = [None] * len(aggregated_tokens_list)
                for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
                    padded_decoded_tokens[layer_idx] = decoded_tokens[idx]
                pts3d_dyn, pts3d_dyn_conf = self.point_head(
                    padded_decoded_tokens, images, patch_start_idx
                )

            res_dynamic |= {
                "pts3d": pts3d_dyn,
                "conf": pts3d_dyn_conf
            }

        pose_enc_list = self.camera_head(aggregated_tokens_list)
        res_dynamic |= {"pose_enc_list": pose_enc_list}

        res_static = dict(
            pts3d=pts3d,
            conf=pts3d_conf
        )
        return res_static, res_dynamic

    def inference(
        self,
        views,
        images=None,
        num_timesteps=None
    ):
        """Run the encoder once, then decode a dynamic point map per timestep.

        Args:
            views: list of per-frame dicts (may supply "img" and "view_idxs").
            images: optional pre-stacked [B, S, 3, H, W] tensor; built from
                views when omitted.
            num_timesteps: how many query times to decode; inferred from
                view_idxs (max index + 1) or defaulting to S when omitted.

        Returns:
            dict with "pose_enc", "pose_enc_list" and "pointmaps"
            (one {pts3d, conf} dict per timestep).
        """
        autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)

        if images is None:
            images = torch.stack([view["img"] for view in views], dim=1)

        with autocast_amp:
            aggregated_tokens_list, patch_start_idx = self.aggregator(images)
        S = images.shape[1]

        # Determine number of timesteps to query
        if num_timesteps is None:
            # Default to S if not specified (legacy behavior)
            # But if views has indices, try to infer max time
            if views is not None and "view_idxs" in views[0]:
                try:
                    all_idxs = torch.cat([v["view_idxs"][:, 1] for v in views])
                    num_timesteps = int(all_idxs.max().item()) + 1
                except Exception:
                    # Best-effort fallback: malformed view_idxs -> one timestep
                    # per frame. (Was a bare except; narrowed to Exception.)
                    num_timesteps = S
            else:
                num_timesteps = S

        predictions = dict()
        pointmaps = []
        ones = torch.ones(1, S, dtype=torch.int64)
        for time_ in range(num_timesteps):
            # Condition every frame on the same query timestep.
            cond_view_idxs = ones * time_

            with autocast_amp:
                decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)
                padded_decoded_tokens = [None] * len(aggregated_tokens_list)
                for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
                    padded_decoded_tokens[layer_idx] = decoded_tokens[idx]

                pts3d, pts3d_conf = self.point_head(
                    padded_decoded_tokens, images, patch_start_idx
                )

            pointmaps.append(dict(
                pts3d=pts3d,
                conf=pts3d_conf
            ))

        pose_enc_list = self.camera_head(aggregated_tokens_list)
        predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration
        predictions["pose_enc_list"] = pose_enc_list
        predictions["pointmaps"] = pointmaps
        return predictions

    def load_state_dict(self, ckpt, is_VGGT_static=False, **kw):
        """Load weights, skipping VGGT heads this model does not use.

        Args:
            ckpt: state dict (e.g. a pretrained VGGT checkpoint).
            is_VGGT_static: accepted for caller compatibility; currently unused.
            **kw: forwarded to ``nn.Module.load_state_dict`` (e.g. strict=).
        """
        # don't load these VGGT heads as not needed
        exclude = ["depth_head", "track_head"]
        ckpt = {k: v for k, v in ckpt.items() if k.split('.')[0] not in exclude}
        return super().load_state_dict(ckpt, **kw)
examples/videos/camel.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3db92c240efbd1b97a466565988a9a06687fd422086656dc0a29e12c5b99b9bb
3
+ size 1301172
examples/videos/car.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd74efdb4d4d59fc17356fefa5dadd4c5b787641c98ce3172ecd8e5a180e76a6
3
+ size 1015132
examples/videos/figure1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae285726e5d247e904bb1ea7887ee96733c0beea913b421abba39150a3299cd5
3
+ size 465850
examples/videos/figure2.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b2b030dd564cffbb9b2795e7fcdf97fa50e3a518df5b71dfb3dfb36f431dfa4
3
+ size 516209
examples/videos/figure3.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a4144a53f14bd2dc671376d26ecbb42b06c9b8810e1700f21a16d3e11dfbf5c
3
+ size 559096
examples/videos/goldfish.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28912e59d0d9e6b20d26973efee4806e89e115c7f1e63aec7206384ac3d0bf78
3
+ size 668862
examples/videos/horse.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8227c7d901a936aeab6a2b41f104dd17e5544315d4cde7dac37f5787319947e7
3
+ size 1223145
examples/videos/paragliding.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:acab90bf2ac105ac0a263559722503e7e7a09aed22164d83cb8b3a5ca1d1504d
3
+ size 814941
examples/videos/pstudio.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12a598cce787ce9b39c2e4abd33c28e1223b59823258287f3a6f5ffb8abe3b47
3
+ size 167366
examples/videos/stroller.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7a88fc629c994bb0067535e181226519ad0750df0d1decbddd01a3e7c5d3c92
3
+ size 1137267
examples/videos/swing.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11d0cb318514b326ac2ed94d0765e29181d58401a84d64aa57ddd4cb1e865dcc
3
+ size 812797
examples/videos/tennis.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50bee7fd257e38c66765a82412734be5abacbe30aa9e9ba04f5387b7865a380e
3
+ size 467310
examples/videos/tesla.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c74333ee25fdb87d26a5a8db080d64d51edc418bab56f3aae98792ad4dee2704
3
+ size 588952
gradio_demo.py ADDED
@@ -0,0 +1,981 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import glob
3
+ import os
4
+ import shutil
5
+ import sys
6
+ import time
7
+ from datetime import datetime
8
+
9
+ import cv2
10
+ import gradio as gr
11
+ import matplotlib
12
+ import numpy as np
13
+ import plotly.graph_objects as go
14
+ import torch
15
+ from hydra import compose, initialize
16
+ from hydra.core.global_hydra import GlobalHydra
17
+
18
+ from dpm.model import VDPM
19
+ from vggt.utils.load_fn import load_and_preprocess_images
20
+
21
+
22
+ # ============================================================================
23
+ # MEMORY OPTIMIZATION SETTINGS FOR 8GB GPUs (RTX 3070 Ti, 3060 Ti, etc.)
24
+ # ============================================================================
25
+ # Model size: 1.66B parameters (~3GB FP16 weights, ~12GB total with activations)
26
+ #
27
+ # Memory reduction options (choose one):
28
+ # USE_HALF_PRECISION = True: FP16 model -> ~1.5GB weights, ~6-7GB total (RECOMMENDED FOR GPU)
29
+ # USE_QUANTIZATION = True: INT8 quantization -> CPU ONLY, not supported on CUDA
30
+ # Both False: FP16/BF16 inference only -> ~3GB weights, ~8-10GB total (may OOM)
31
+ #
32
+ # MAX_FRAMES: Limit input frames (5 recommended for 8GB GPUs)
33
+ # ============================================================================
34
+
35
# Visualisation limits.
MAX_POINTS_PER_FRAME = 50_000  # cap on rendered points per frame in the plot
TRAIL_LENGTH = 20              # frames of motion-trail history drawn per track
MAX_TRACKS = 150               # cap on number of rendered motion tracks
STATIC_THRESHOLD = 0.025       # min displacement before a point counts as dynamic — units presumably scene/world units, TODO confirm
VIDEO_SAMPLE_HZ = 1.0          # target frames sampled per second of input video

# Dynamic Configuration based on Helper/Hardware
USE_QUANTIZATION = False       # INT8 dynamic quantization (CPU only, see banner above)
USE_HALF_PRECISION = True      # FP16 weights — recommended for GPU inference
MAX_FRAMES = 5  # Default for 8GB

device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # Enable TF32 (faster matmul/conv on Ampere+ at slightly reduced precision)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Check VRAM to auto-scale MAX_FRAMES
    vram_bytes = torch.cuda.get_device_properties(0).total_memory
    vram_gb = vram_bytes / (1024**3)

    print(f"\u2713 GPU Detected: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)")

    if vram_gb > 22:  # A10G (24GB), A100 (40/80GB), RTX 3090/4090 (24GB)
        MAX_FRAMES = 80
        print(f" -> High VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
    elif vram_gb > 14:  # T4 (16GB), 4080 (16GB)
        MAX_FRAMES = 16
        print(f" -> Medium VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
    else:
        MAX_FRAMES = 5
        print(f" -> Low VRAM (<14GB). Keeping MAX_FRAMES at {MAX_FRAMES} to prevent OOM")
    print(f"\u2713 TF32 enabled for faster matrix operations")
69
+
70
+
71
+
72
def load_cfg_from_cli() -> "omegaconf.DictConfig":
    """Compose the Hydra 'visualise' config, applying CLI overrides from sys.argv."""
    hydra_state = GlobalHydra.instance()
    if hydra_state.is_initialized():
        # Allow re-initialisation (e.g. on script reload inside one process).
        hydra_state.clear()
    with initialize(config_path="configs"):
        return compose(config_name="visualise", overrides=sys.argv[1:])
78
+
79
+
80
def load_model(cfg) -> VDPM:
    """Build a VDPM model, download its weights, and apply memory optimisations.

    The checkpoint is fetched (and cached by torch.hub) from HuggingFace.
    Depending on the module-level flags, the model is converted to FP16,
    INT8-quantized (CPU only), and/or compiled with torch.compile.
    """
    model = VDPM(cfg).to(device)

    _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"
    sd = torch.hub.load_state_dict_from_url(
        _URL,
        file_name="vdpm_model.pt",
        progress=True
    )
    # load_state_dict's return (missing/unexpected keys) is printed for visibility.
    print(model.load_state_dict(sd, strict=True))

    model.eval()

    # Option 1: Use FP16 for all model weights (simple, ~2x memory reduction)
    if USE_HALF_PRECISION and not USE_QUANTIZATION:
        print("Converting model to FP16 precision...")
        model = model.half()
        print("✓ Model converted to FP16: ~2x memory reduction (3GB -> 1.5GB)")

    # Option 2: Apply INT8 dynamic quantization (more aggressive, ~3-4x reduction)
    if USE_QUANTIZATION:
        try:
            print("Applying INT8 dynamic quantization to reduce memory usage...")
            # Dynamic quantization runs on CPU; move there, quantize, then back.
            model = model.cpu()
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.Conv2d},  # Quantize these layer types
                dtype=torch.qint8
            )
            model = model.to(device)
            print("✓ Model quantized: ~3x memory reduction (3GB -> 1GB)")
        except Exception as e:
            # Best-effort: fall back to the unquantized model on the target device.
            print(f"⚠️ Quantization failed: {e}")
            print("Continuing with FP16/BF16 precision...")
            model = model.to(device)

    # Enable torch.compile for faster inference (PyTorch 2.0+)
    # Note: Disable compile if using quantization as they may conflict
    if not USE_QUANTIZATION:
        try:
            print("Compiling model with torch.compile for faster inference...")
            model = torch.compile(model, mode="reduce-overhead")
            print("✓ Model compilation successful")
        except Exception as e:
            print(f"Warning: torch.compile not available or failed: {e}")

    return model
128
+
129
+
130
def require_cuda():
    """Raise ValueError unless the module-level device is CUDA."""
    if device == "cuda":
        return
    raise ValueError("CUDA is not available. Check your environment.")
133
+
134
+
135
def gradio_file_path(file_obj):
    """Normalise a Gradio upload object to a filesystem path (or None).

    Older Gradio versions pass dicts with a "name" key; newer versions pass
    the path (or file wrapper) directly, which is returned unchanged.
    """
    if file_obj is None:
        return None
    is_legacy_dict = isinstance(file_obj, dict) and "name" in file_obj
    return file_obj["name"] if is_legacy_dict else file_obj
141
+
142
+
143
def ensure_nhwc_images(images: np.ndarray) -> np.ndarray:
    """Convert an (S, 3, H, W) stack to (S, H, W, 3); pass other layouts through."""
    looks_nchw = images.ndim == 4 and images.shape[1] == 3
    return images.transpose(0, 2, 3, 1) if looks_nchw else images
147
+
148
+
149
def compute_scene_bounds(world_points: np.ndarray):
    """Return (global_min, global_max, aspectratio) for a padded scene bounding box.

    The box is centred on the cloud, expanded by 5%, made degenerate-safe
    (flat axes inherit the largest extent; a fully degenerate cloud gets a
    unit box), and the aspect ratio is normalised to the largest half-extent.
    """
    flat = world_points.reshape(-1, 3)
    lo = flat.min(axis=0)
    hi = flat.max(axis=0)

    center = 0.5 * (lo + hi)
    half = 0.5 * (hi - lo) * 1.05  # 5% padding around the cloud

    if np.all(half < 1e-6):
        # Fully degenerate cloud: fall back to a unit box.
        half[:] = 1.0
    else:
        # Flat dimensions inherit the largest extent so axes stay visible.
        half[half < 1e-6] = half.max()

    global_min = center - half
    global_max = center + half

    biggest = half.max()
    aspectratio = {axis: float(half[i] / biggest) for i, axis in enumerate(("x", "y", "z"))}
    return global_min, global_max, aspectratio
172
+
173
+
174
def stride_downsample(pts: np.ndarray, cols: np.ndarray, max_points: int):
    """Uniformly subsample parallel (pts, cols) arrays to at most max_points rows."""
    count = pts.shape[0]
    if count <= max_points:
        return pts, cols
    stride = int(np.ceil(count / max_points))
    keep = np.arange(0, count, stride)[:max_points]
    return pts[keep], cols[keep]
181
+
182
+
183
+ # ============================================================
184
+ # NEW: Single shared mask function (used by points + tracks)
185
+ # ============================================================
186
+ def compute_point_mask(
187
+ conf_score: np.ndarray | None,
188
+ cols: np.ndarray,
189
+ conf_thres: float,
190
+ mask_black_bg: bool,
191
+ mask_white_bg: bool,
192
+ ) -> np.ndarray:
193
+ """
194
+ conf_score: (N,) or None
195
+ cols: (N,3) uint8
196
+ Returns: (N,) boolean mask
197
+ """
198
+ mask = np.ones(cols.shape[0], dtype=bool)
199
+
200
+ # confidence percentile threshold (same semantics as before)
201
+ if conf_score is not None and conf_thres > 0:
202
+ thresh = np.percentile(conf_score, conf_thres)
203
+ mask &= (conf_score >= thresh) & (conf_score > 1e-5)
204
+
205
+ # background masks (same as before)
206
+ if mask_black_bg:
207
+ mask &= (cols.sum(axis=1) >= 16)
208
+ if mask_white_bg:
209
+ mask &= ~((cols[:, 0] > 240) & (cols[:, 1] > 240) & (cols[:, 2] > 240))
210
+
211
+ return mask
212
+
213
+
214
def sample_frame_points(
    world_points: np.ndarray,
    images_nhwc: np.ndarray,
    conf: np.ndarray | None,
    idx: int,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
    max_points: int,
):
    """Extract filtered, downsampled (points, hex colour strings) for one frame."""
    frame = int(np.clip(idx, 0, world_points.shape[0] - 1))
    pts = world_points[frame].reshape(-1, 3)
    cols = (images_nhwc[frame].reshape(-1, 3) * 255).astype(np.uint8)

    frame_conf = None if conf is None else conf[frame].reshape(-1)

    keep = compute_point_mask(
        conf_score=frame_conf,
        cols=cols,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )
    pts, cols = pts[keep], cols[keep]
    pts, cols = stride_downsample(pts, cols, max_points)

    # Guarantee at least one point so the plot always renders something.
    if pts.size == 0:
        pts = np.array([[0.0, 0.0, 0.0]])
        cols = np.array([[255, 255, 255]], dtype=np.uint8)

    hex_colors = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in cols]
    return pts, hex_colors
249
+
250
+
251
+ # ============================================================
252
+ # UPDATED: prepare_tracks now applies the SAME masks as points
253
+ # ============================================================
254
def prepare_tracks(
    world_points: np.ndarray,
    images_nhwc: np.ndarray,
    conf: np.ndarray | None,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
):
    """Select dynamic point tracks and a per-track colour scale.

    world_points: (S, H, W, 3) per-frame 3D positions of the same pixel grid.
    Returns (tracks_xyz (S, T, 3), plotly colorscale, track_ids (T,)) or
    (None, None, None) when there is nothing to track.
    """
    S, H, W, _ = world_points.shape
    N = H * W
    if S < 2 or N == 0:
        return None, None, None

    tracks_xyz = world_points.reshape(S, N, 3)

    # A point is "dynamic" when its max displacement from frame 0 exceeds the threshold.
    disp = np.linalg.norm(tracks_xyz - tracks_xyz[0:1], axis=-1)
    dynamic_mask = disp.max(axis=0) > STATIC_THRESHOLD

    # Per-point confidence averaged over time (used by the shared mask below).
    conf_score = None
    if conf is not None:
        conf_flat = conf.reshape(S, N)
        conf_score = conf_flat.mean(axis=0)

    # Use reference-frame colors for background masking (stable, consistent)
    ref_cols = (images_nhwc[0].reshape(-1, 3) * 255).astype(np.uint8)

    # Same filtering as the point cloud so tracks and points agree visually.
    point_mask = compute_point_mask(
        conf_score=conf_score,
        cols=ref_cols,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )

    dynamic_mask &= point_mask

    idx_tracks = np.nonzero(dynamic_mask)[0]
    if idx_tracks.size == 0:
        return None, None, None

    # Uniformly thin to at most MAX_TRACKS tracks.
    if idx_tracks.size > MAX_TRACKS:
        step = int(np.ceil(idx_tracks.size / MAX_TRACKS))
        idx_tracks = idx_tracks[::step][:MAX_TRACKS]

    tracks_xyz = tracks_xyz[:, idx_tracks, :]
    # Sort tracks by frame-0 height (y) so the hue gradient varies smoothly in space.
    order = np.argsort(tracks_xyz[0, :, 1])
    tracks_xyz = tracks_xyz[:, order, :]

    num_tracks = tracks_xyz.shape[1]
    cmap = matplotlib.colormaps.get_cmap("hsv")
    norm = matplotlib.colors.Normalize(vmin=0, vmax=max(num_tracks - 1, 1))

    # Build a plotly colorscale mapping track index -> hsv colour.
    colorscale = []
    for t in range(num_tracks):
        r, g, b, _ = cmap(norm(t))
        r, g, b = int(r * 255), int(g * 255), int(b * 255)
        pos = t / max(num_tracks - 1, 1)
        colorscale.append([pos, f"rgb({r},{g},{b})"])

    track_ids = np.arange(num_tracks, dtype=float)
    return tracks_xyz, colorscale, track_ids
316
+
317
+
318
+ def track_segments_for_frame(tracks_xyz: np.ndarray | None, track_ids: np.ndarray | None, f: int):
319
+ if tracks_xyz is None or track_ids is None or f <= 0:
320
+ return np.array([]), np.array([]), np.array([]), np.array([])
321
+
322
+ start_t = max(0, f - TRAIL_LENGTH)
323
+ num_tracks = tracks_xyz.shape[1]
324
+
325
+ xs, ys, zs, cs = [], [], [], []
326
+ for j in range(num_tracks):
327
+ seg = tracks_xyz[start_t : f + 1, j, :]
328
+ if seg.shape[0] < 2:
329
+ continue
330
+
331
+ xs.extend([seg[:, 0], np.array([np.nan])])
332
+ ys.extend([seg[:, 1], np.array([np.nan])])
333
+ zs.extend([seg[:, 2], np.array([np.nan])])
334
+ cs.append(np.full(seg.shape[0] + 1, track_ids[j], dtype=float))
335
+
336
+ x = np.concatenate(xs) if xs else np.array([])
337
+ y = np.concatenate(ys) if ys else np.array([])
338
+ z = np.concatenate(zs) if zs else np.array([])
339
+ c = np.concatenate(cs) if cs else np.array([])
340
+
341
+ return x, y, z, c
342
+
343
+
344
def build_pointcloud_figure_update(
    data,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
):
    """Build the interactive Plotly figure: per-frame point cloud + motion trails.

    data: dict with "world_points" (S, H, W, 3), optional "world_points_conf",
    and "images"; returns an empty figure when data is None. A slider step is
    precomputed for every frame so scrubbing needs no server round-trip.
    """
    if data is None:
        return go.Figure()

    world_points = data["world_points"]
    conf = data.get("world_points_conf")
    images = ensure_nhwc_images(data["images"])
    S = world_points.shape[0]

    global_min, global_max, aspectratio = compute_scene_bounds(world_points)

    # Same confidence/background masks as the points, so tracks stay consistent.
    tracks_xyz, colorscale, track_ids = prepare_tracks(
        world_points=world_points,
        images_nhwc=images,
        conf=conf,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )
    track_cmax = max(len(track_ids) - 1, 1) if track_ids is not None else 1

    # Precompute per-frame points and trail segments for every slider step.
    pts_xyz = [None] * S
    pts_cols = [None] * S
    trk_xyz = [None] * S
    trk_c = [None] * S

    for i in range(S):
        pts_i, cols_i = sample_frame_points(
            world_points=world_points,
            images_nhwc=images,
            conf=conf,
            idx=i,
            conf_thres=conf_thres,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            max_points=MAX_POINTS_PER_FRAME,
        )
        pts_xyz[i] = pts_i
        pts_cols[i] = cols_i

        x, y, z, c = track_segments_for_frame(tracks_xyz, track_ids, f=i)
        trk_xyz[i] = (x, y, z)
        trk_c[i] = c

    # Frame 0 provides the figure's initial state.
    p0 = pts_xyz[0]
    c0 = pts_cols[0]
    x0, y0, z0 = trk_xyz[0]
    tc0 = trk_c[0]

    # Axes hidden; fixed ranges + manual aspect keep the view stable across frames.
    scene_cfg = dict(
        xaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[0]), float(global_max[0])],
        ),
        yaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[1]), float(global_max[1])],
        ),
        zaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[2]), float(global_max[2])],
        ),
        aspectmode="manual",
        aspectratio=aspectratio,
        dragmode="orbit",
        # Camera looks along -z with y pointing down (image-style convention).
        camera=dict(
            eye=dict(x=0.0, y=0.0, z=-1.0),
            center=dict(x=0.0, y=0.0, z=0.0),
            up=dict(x=0.0, y=-1.0, z=0.0),
        ),
    )

    # Two traces: trace 0 = points, trace 1 = NaN-separated track trails.
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=p0[:, 0],
                y=p0[:, 1],
                z=p0[:, 2],
                mode="markers",
                marker=dict(size=2, color=c0),
                showlegend=False,
                name="points",
            ),
            go.Scatter3d(
                x=x0,
                y=y0,
                z=z0,
                mode="lines",
                line=dict(
                    width=2,
                    color=tc0 if (tc0 is not None and tc0.size) else None,
                    colorscale=colorscale if colorscale is not None else None,
                    cmin=0,
                    cmax=track_cmax,
                ),
                hoverinfo="skip",
                showlegend=False,
                name="tracks",
            ),
        ]
    )

    # One slider step per frame; each step restyles both traces in place.
    steps = []
    for i in range(S):
        pi = pts_xyz[i]
        ci = pts_cols[i]
        xi, yi, zi = trk_xyz[i]
        ti = trk_c[i]

        steps.append(
            dict(
                method="update",
                label=str(i),
                args=[
                    {
                        "x": [pi[:, 0], xi],
                        "y": [pi[:, 1], yi],
                        "z": [pi[:, 2], zi],
                        "marker.color": [ci, None],
                        "line.color": [None, ti if (ti is not None and len(ti)) else None],
                    },
                    {},
                ],
            )
        )

    sliders = [
        dict(
            active=0,
            currentvalue={"prefix": "Frame: ", "visible": True, "font": {"size": 14}},
            pad={"t": 10},
            len=0.6,
            x=0.2,
            font={"size": 8},
            steps=steps,
        )
    ]

    fig.update_layout(
        margin=dict(l=0, r=0, t=30, b=0),
        scene=scene_cfg,
        sliders=sliders,
        showlegend=False,
        title="Scrub frames with the slider below",
        # uirevision keeps the user's camera orientation across restyles.
        uirevision="keep-camera",
        height=700,
    )
    return fig
510
+
511
+
512
def run_model(target_dir: str, model: VDPM, frame_id_arg=0) -> dict:
    """Run VDPM inference on the staged images and package results for plotting.

    target_dir: directory containing an "images" subfolder (and optional
        meta.json with {"num_views": int} for synchronized multi-view input).
    model: the loaded VDPM model.
    frame_id_arg: which frame's pixel grid to track through time.
    Returns a dict with "world_points" (S, H/2, W/2, 3), "world_points_conf",
    and "images" as numpy arrays (everything moved off the GPU).
    Raises ValueError when CUDA is unavailable or no images are found.
    """
    require_cuda()

    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found. Check your upload.")

    # Load metadata for Multi-View sync (Load BEFORE slicing to respect view count)
    meta_path = os.path.join(target_dir, "meta.json")
    num_views = 1
    if os.path.exists(meta_path):
        try:
            import json
            with open(meta_path, 'r') as f:
                num_views = json.load(f).get("num_views", 1)
        except Exception:
            # Fixed: was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit.
            # Metadata is best-effort; fall back to single-view on any read/parse error.
            pass

    # Limit frames to prevent OOM on 8GB GPUs
    if len(image_names) > MAX_FRAMES:
        # Round down to nearest multiple of num_views to preserve full scenes
        limit = (MAX_FRAMES // num_views) * num_views
        if limit == 0:
            limit = num_views  # At least one full timestep
            print(f"⚠️ Warning: MAX_FRAMES={MAX_FRAMES} is smaller than num_views={num_views}. Processing 1 full timestep anyway (may OOM).")

        print(f"⚠️ Limiting to {limit} frames ({limit // num_views} timesteps * {num_views} views) to fit in GPU memory")
        image_names = image_names[:limit]

    images = load_and_preprocess_images(image_names).to(device)

    if device == "cuda":
        print(f"GPU memory before inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    print(f"Running inference on {len(image_names)} images ({num_views} synchronized views)...")

    # Construct 'views' dictionaries with correct time indices: images are
    # interleaved view-major, so frame i belongs to timestep i // num_views.
    views = []
    for i in range(len(image_names)):
        t_idx = i // num_views
        views.append({
            "img": images[i].unsqueeze(0),  # (1, C, H, W)
            "view_idxs": torch.tensor([[0, t_idx]], device=device, dtype=torch.long)
        })

    inference_start = time.time()

    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            # Pass constructed views so model uses correct time query
            predictions = model.inference(views=views)

    inference_time = time.time() - inference_start
    print(f"✓ Inference completed in {inference_time:.2f}s ({inference_time/len(image_names):.2f}s per frame)")

    # Move results to CPU immediately to free GPU memory
    pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]]

    # Clear predictions from GPU to save memory
    del predictions
    if device == "cuda":
        torch.cuda.empty_cache()
        print(f"GPU memory after inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

    world_points = np.concatenate(pts_list, axis=0)
    world_points_conf = np.concatenate(conf_list, axis=0)

    # Tolerate non-numeric frame ids coming from the UI textbox.
    try:
        frame_id = int(frame_id_arg)
    except Exception:
        frame_id = 0

    if frame_id >= world_points.shape[0]:
        frame_id = 0

    # Track the selected frame's pixel grid through time, downsampled 2x.
    world_points_s = world_points[:, frame_id, ::2, ::2, :]
    # NOTE(review): frame_id indexes BOTH the time and frame axes here
    # (conf[frame_id, frame_id]) — looks intentional (confidence of the
    # reference frame at its own timestep) but verify against the model's
    # pointmap layout. The single mask is then tiled across all timesteps.
    single_mask = world_points_conf[frame_id, frame_id, ::2, ::2]
    world_points_conf_s = np.tile(single_mask[np.newaxis, ...], (world_points.shape[0], 1, 1))

    # Repeat the reference image so every timestep colours the same pixel grid.
    img_np = images.detach().cpu().numpy()
    img_np = img_np[frame_id : frame_id + 1, :, ::2, ::2]
    img_np = np.repeat(img_np, world_points.shape[0], axis=0)

    torch.cuda.empty_cache()
    return {
        "world_points": world_points_s,
        "world_points_conf": world_points_conf_s,
        "images": img_np,
    }
602
+
603
+
604
def handle_uploads(input_video, input_images):
    """Stage uploaded images/videos into a fresh timestamped directory.

    Images are copied as-is; videos are sampled at VIDEO_SAMPLE_HZ. Multiple
    videos are read in lockstep (interleaved) so frames from the same instant
    end up adjacent, and num_views is recorded in meta.json for run_model.
    Returns (target_dir, sorted list of staged image paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Microsecond timestamp makes each upload directory unique.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images, exist_ok=True)

    image_paths = []

    # Plain image uploads: copy straight into the staging folder.
    if input_images:
        for file_obj in input_images:
            src = gradio_file_path(file_obj)
            if not src:
                continue
            dst = os.path.join(target_dir_images, os.path.basename(src))
            shutil.copy(src, dst)
            image_paths.append(dst)

    if input_video:
        # Check if input is a list (Gradio 4.x/5.x or file_count="multiple")
        input_video_list = input_video if isinstance(input_video, list) else [input_video]

        # Determine starting frame number based on existing images
        existing_files = os.listdir(target_dir_images)
        frame_num = len(existing_files)

        # Modified for Interleaved/Synchronized processing
        # 1. Open all videos
        captures = []
        capture_meta = []
        for idx, vid_obj in enumerate(input_video_list):
            video_path = gradio_file_path(vid_obj)
            print(f"Preparing video {idx+1}/{len(input_video_list)}: {video_path}")

            vs = cv2.VideoCapture(video_path)
            fps = float(vs.get(cv2.CAP_PROP_FPS) or 0.0)
            if fps <= 0: fps = 30.0  # Fallback

            # Keep one frame per 1/VIDEO_SAMPLE_HZ seconds of footage.
            frame_interval = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1)
            captures.append(vs)
            capture_meta.append({"interval": frame_interval, "name": video_path})

        # 2. Step through them together so same-instant frames are adjacent.
        print("Processing videos in interleaved mode...")
        step_count = 0
        active_videos = True

        while active_videos:
            active_videos = False
            for i, vs in enumerate(captures):
                if not vs.isOpened():
                    continue

                gotit, frame = vs.read()
                if gotit:
                    active_videos = True  # Keep going as long as at least one video has frames

                    if step_count % capture_meta[i]["interval"] == 0:
                        out_path = os.path.join(target_dir_images, f"{frame_num:06}.png")
                        cv2.imwrite(out_path, frame)
                        image_paths.append(out_path)
                        frame_num += 1
                else:
                    # Exhausted stream: release so it is skipped next pass.
                    vs.release()

            step_count += 1


    image_paths.sort()

    # Save metadata about capture structure (num_views)
    num_views = len(input_video_list) if input_video else 1
    meta_path = os.path.join(target_dir, "meta.json")
    try:
        import json
        with open(meta_path, 'w') as f:
            json.dump({"num_views": num_views}, f)
    except Exception as e:
        # Best-effort: reconstruction still works, defaulting to single-view.
        print(f"Warning: could not save metadata: {e}")

    print(f"Files copied to {target_dir_images}; took {time.time() - start_time:.3f} seconds")
    return target_dir, image_paths
692
+
693
+
694
def update_gallery_on_upload(input_video, input_images):
    """Stage uploaded media and refresh the gallery (no reconstruction yet)."""
    if not (input_video or input_images):
        return None, None, None, None
    target_dir, image_paths = handle_uploads(input_video, input_images)
    return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
699
+
700
+
701
def gradio_reconstruct(
    target_dir,
    conf_thres=50.0,
    mask_black_bg=False,
    mask_white_bg=False,
    frame_id_val=0,
):
    """Run inference on the staged images and build the interactive 3D figure.

    Returns (plotly figure, status message, predictions dict) — the
    predictions are kept in Gradio state so filters can be re-applied
    later without re-running the model.
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        return None, "No valid target directory found. Please upload first.", None

    # Free as much GPU memory as possible before inference.
    gc.collect()
    torch.cuda.empty_cache()

    images_dir = os.path.join(target_dir, "images")
    num_frames = len(os.listdir(images_dir)) if os.path.isdir(images_dir) else 0

    with torch.no_grad():
        predictions = run_model(target_dir, model, frame_id_val)

    fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg)

    torch.cuda.empty_cache()
    status = f"Reconstruction Success ({num_frames} frames processed, showing frame {frame_id_val})."
    return fig, status, predictions
725
+
726
+
727
def update_plot(
    target_dir,
    predictions,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    is_example,
):
    """Re-render the cached reconstruction with new filter settings (no new inference)."""
    nothing_to_show = predictions is None or is_example == "True"
    if nothing_to_show:
        return None, "No reconstruction available. Please click the Reconstruct button first."

    fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg)
    return fig, "Updated visualization with new settings. Use the slider below the plot to scrub frames."
740
+
741
+
742
def clear_fields():
    """Reset the reconstruction plot output (Clear-button callback)."""
    return None
744
+
745
+
746
def update_log():
    """Status text shown while reconstruction is in progress."""
    return "Loading and Reconstructing..."
748
+
749
+
750
def example_pipeline(
    input_video_ex,
    num_images_str,
    input_images_ex,
    conf_thres_val,
    mask_black_bg_val,
    mask_white_bg_val,
    is_example_str,
    frame_id_val,
):
    """One-click pipeline for gallery examples: stage the inputs, then reconstruct."""
    staged_dir, staged_paths = handle_uploads(input_video_ex, input_images_ex)
    fig, log_msg, preds = gradio_reconstruct(
        staged_dir, conf_thres_val, mask_black_bg_val, mask_white_bg_val, frame_id_val
    )
    return fig, log_msg, staged_dir, preds, staged_paths
769
+
770
+
771
# Example media shipped with the demo (stored as git-LFS objects in examples/videos/).
colosseum_video = "examples/videos/Colosseum.mp4"  # NOTE(review): no Colosseum.mp4 appears in examples/videos in this commit — verify it exists before wiring into the examples table
camel_video = "examples/videos/camel.mp4"
tennis_video = "examples/videos/tennis.mp4"
paragliding_video = "examples/videos/paragliding.mp4"
stroller_video = "examples/videos/stroller.mp4"
goldfish_video = "examples/videos/goldfish.mp4"
horse_video = "examples/videos/horse.mp4"
swing_video = "examples/videos/swing.mp4"
car_video = "examples/videos/car.mp4"
figure1_video = "examples/videos/figure1.mp4"
figure2_video = "examples/videos/figure2.mp4"
figure3_video = "examples/videos/figure3.mp4"
tesla_video = "examples/videos/tesla.mp4"
pstudio_video = "examples/videos/pstudio.mp4"
785
+
786
# Muted slate/zinc Gradio theme on a plain white background.
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.slate,
    secondary_hue=gr.themes.colors.zinc,
    neutral_hue=gr.themes.colors.slate,
).set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
    body_background_fill="#FFFFFF",
)

# Custom CSS: gradient-text status banners and a compact two-option radio layout.
css = """
.custom-log * {
    font-style: italic;
    font-size: 22px !important;
    background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%);
    -webkit-background-clip: text;
    background-clip: text;
    font-weight: bold !important;
    color: transparent !important;
    text-align: center !important;
}

.example-log * {
    font-style: italic;
    font-size: 16px !important;
    background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%);
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent !important;
}

#my_radio .wrap {
    display: flex;
    flex-wrap: nowrap;
    justify-content: center;
    align-items: center;
}

#my_radio .wrap label {
    display: flex;
    width: 50%;
    justify-content: center;
    align-items: center;
    margin: 0;
    padding: 10px 0;
    box-sizing: border-box;
}
"""

# Build the config and model once at import time; the Gradio callbacks above
# reference the module-level `model`.
cfg = load_cfg_from_cli()
model = load_model(cfg)
837
+
838
+ with gr.Blocks(theme=theme, css=css) as demo:
839
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
840
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
841
+ frame_id_state = gr.Textbox(label="frame_id", visible=False, value="0")
842
+
843
+ gr.HTML(
844
+ """
845
+ <h1>V-DPM: Video Reconstruction with Dynamic Point Maps</h1>
846
+ <p>
847
+ <a href="https://github.com/eldar/vdpm">🐙 GitHub Repository</a> |
848
+ <a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm/">Project Page</a>
849
+ </p>
850
+ <div style="font-size: 16px; line-height: 1.5;">
851
+ <p>Upload a video (or multiple videos for multi-view setup) or a set of images to create a dynamic point map reconstruction of a scene or object.</p>
852
+ </div>
853
+ """
854
+ )
855
+
856
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
857
+ predictions_state = gr.State(value=None)
858
+
859
+ with gr.Row():
860
+ with gr.Column(scale=2):
861
+ # Change Video input to File input to allow multiple videos for multi-view support
862
+ gr.Markdown("### Input")
863
+ input_video = gr.File(
864
+ label="Upload Video(s)",
865
+ file_count="multiple",
866
+ file_types=[".mp4", ".avi", ".mov", ".mkv", ".webm"],
867
+ interactive=True
868
+ )
869
+ input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
870
+ image_gallery = gr.Gallery(
871
+ label="Preview",
872
+ columns=4,
873
+ height="300px",
874
+ show_download_button=True,
875
+ object_fit="contain",
876
+ preview=True,
877
+ )
878
+
879
+ with gr.Column(scale=5):
880
+ gr.Markdown("**3D Reconstruction (Point Cloud)**")
881
+ log_output = gr.Markdown(
882
+ "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
883
+ )
884
+ reconstruction_output = gr.Plot(label="3D Point Cloud")
885
+
886
+ with gr.Row():
887
+ submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
888
+ gr.ClearButton(
889
+ [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
890
+ scale=1,
891
+ )
892
+
893
+ with gr.Row():
894
+ conf_thres = gr.Slider(0, 100, value=50, step=1, label="Confidence Threshold (%)")
895
+ with gr.Column():
896
+ mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
897
+ mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
898
+
899
+ examples = [
900
+ [camel_video, "17", None, 15.0, False, False, "True", "8"],
901
+ [horse_video, "18", None, 50.0, False, False, "True", "2"],
902
+ [tennis_video, "11", None, 5.0, False, False, "True", "0"],
903
+ [paragliding_video, "11", None, 5.0, False, False, "True", "0"],
904
+ [stroller_video, "17", None, 10.0, False, False, "True", "8"],
905
+ [goldfish_video, "11", None, 12.0, False, False, "True", "5"],
906
+ [swing_video, "10", None, 40.0, False, False, "True", "4"],
907
+ [car_video, "13", None, 15.0, False, False, "True", "7"],
908
+ [figure1_video, "10", None, 25.0, False, False, "True", "0"],
909
+ [figure2_video, "12", None, 25.0, False, False, "True", "6"],
910
+ [figure3_video, "13", None, 30.0, False, False, "True", "0"],
911
+ [tesla_video, "18", None, 20.0, False, True, "True", "0"],
912
+ [pstudio_video, "12", None, 0.0, False, False, "True", "6"],
913
+ ]
914
+
915
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
916
+ gr.Examples(
917
+ examples=examples,
918
+ inputs=[
919
+ input_video,
920
+ num_images,
921
+ input_images,
922
+ conf_thres,
923
+ mask_black_bg,
924
+ mask_white_bg,
925
+ is_example,
926
+ frame_id_state,
927
+ ],
928
+ outputs=[
929
+ reconstruction_output,
930
+ log_output,
931
+ target_dir_output,
932
+ predictions_state,
933
+ image_gallery,
934
+ ],
935
+ fn=example_pipeline,
936
+ cache_examples=False,
937
+ examples_per_page=50,
938
+ )
939
+
940
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
941
+ fn=update_log, inputs=[], outputs=[log_output]
942
+ ).then(
943
+ fn=gradio_reconstruct,
944
+ inputs=[
945
+ target_dir_output,
946
+ conf_thres,
947
+ mask_black_bg,
948
+ mask_white_bg,
949
+ frame_id_state,
950
+ ],
951
+ outputs=[reconstruction_output, log_output, predictions_state],
952
+ ).then(
953
+ fn=lambda: "False", inputs=[], outputs=[is_example]
954
+ )
955
+
956
+ for ctrl in (conf_thres, mask_black_bg, mask_white_bg):
957
+ ctrl.change(
958
+ fn=update_plot,
959
+ inputs=[
960
+ target_dir_output,
961
+ predictions_state,
962
+ conf_thres,
963
+ mask_black_bg,
964
+ mask_white_bg,
965
+ is_example,
966
+ ],
967
+ outputs=[reconstruction_output, log_output],
968
+ )
969
+
970
+ input_video.change(
971
+ fn=update_gallery_on_upload,
972
+ inputs=[input_video, input_images],
973
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
974
+ )
975
+ input_images.change(
976
+ fn=update_gallery_on_upload,
977
+ inputs=[input_video, input_images],
978
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
979
+ )
980
+
981
+ demo.queue(max_size=20).launch(show_error=True, share=True)
input_images_20260127_052216_587020/images/000000.png ADDED

Git LFS Details

  • SHA256: c4adf23ac78af3407f4eee3009cce3c7ee539692e5712202c60be5d3e593f54b
  • Pointer size: 131 Bytes
  • Size of remote file: 499 kB
input_images_20260127_052216_587020/images/000001.png ADDED

Git LFS Details

  • SHA256: 53d4376634c154d147e70e8df4c08e80f1bd63a0cc0fd6123f4fb6e67455ba92
  • Pointer size: 131 Bytes
  • Size of remote file: 514 kB
input_images_20260127_052216_587020/images/000002.png ADDED

Git LFS Details

  • SHA256: a40105cdc90f6bbae632107b9e8c409676046a57842f8785d7215afd69839c52
  • Pointer size: 131 Bytes
  • Size of remote file: 503 kB
input_images_20260127_052216_587020/images/000003.png ADDED

Git LFS Details

  • SHA256: 226c74f09bdcb858f3557751837bf984bbc22b3be33734145ada4bba8e341335
  • Pointer size: 131 Bytes
  • Size of remote file: 504 kB
input_images_20260127_052216_587020/images/000004.png ADDED

Git LFS Details

  • SHA256: 196b625f9c9a56d8cf94be3caee89b1b77226b1b865c605751b045a0d6ba6a6b
  • Pointer size: 131 Bytes
  • Size of remote file: 481 kB
input_images_20260127_052216_587020/images/000005.png ADDED

Git LFS Details

  • SHA256: 375d91fc42919a5fe59f1cb9596f6c12b38b5d39609da5b6353cc080fbecafa2
  • Pointer size: 131 Bytes
  • Size of remote file: 488 kB
input_images_20260127_052216_587020/images/000006.png ADDED

Git LFS Details

  • SHA256: 9a1da18b6d826d0ccca9d8d279af4e27b6f705daf8f4f0723b9b14c26a1161d4
  • Pointer size: 131 Bytes
  • Size of remote file: 461 kB
input_images_20260127_052216_587020/images/000007.png ADDED

Git LFS Details

  • SHA256: c7c02464b8d79769bc323a136cc2a4d0b1ea302131d266e56e219e8815e23d46
  • Pointer size: 131 Bytes
  • Size of remote file: 466 kB
input_images_20260127_052216_587020/images/000008.png ADDED

Git LFS Details

  • SHA256: 601e3dea45fe59a43bef6f5d03ea698eda549551e58482fec56cfad14651768e
  • Pointer size: 131 Bytes
  • Size of remote file: 456 kB
input_images_20260127_052216_587020/images/000009.png ADDED

Git LFS Details

  • SHA256: d711883b782482526df3a7bcbfebb0878e79f2cfedf80e09ea32116c7afd4da4
  • Pointer size: 131 Bytes
  • Size of remote file: 457 kB
input_images_20260127_052216_587020/images/000010.png ADDED

Git LFS Details

  • SHA256: aaf9b1938f7d1ec88ae235e58580e0ad655581a9945bb8002bfcfd072dccc824
  • Pointer size: 131 Bytes
  • Size of remote file: 442 kB
input_images_20260127_052216_587020/images/000011.png ADDED

Git LFS Details

  • SHA256: 3e3a8f7c3916b41ea53704cc3596530315ec8bd205fe71fa19243c1900698a09
  • Pointer size: 131 Bytes
  • Size of remote file: 464 kB
input_images_20260127_052216_587020/images/000012.png ADDED

Git LFS Details

  • SHA256: 9fbdda59560e050f53e7022483aec5baade097e52d7014ee9bd23d5752320c38
  • Pointer size: 131 Bytes
  • Size of remote file: 464 kB
input_images_20260127_052216_587020/images/000013.png ADDED

Git LFS Details

  • SHA256: 4ce4adf3259a97107b04be6c2d9116f6573f043c400b8af84351f7b79abe6817
  • Pointer size: 131 Bytes
  • Size of remote file: 454 kB
input_images_20260127_052216_587020/images/000014.png ADDED

Git LFS Details

  • SHA256: 3cfd8295700f348323105258445bfbf689b10f89b93cef0bedc35eae97aa293f
  • Pointer size: 131 Bytes
  • Size of remote file: 452 kB
input_images_20260127_052216_587020/images/000015.png ADDED

Git LFS Details

  • SHA256: e3e7acc0f280b73e2854cf3b19b7d3da2ad4cd2102d4f5444c4f7e1b0ecc16f5
  • Pointer size: 131 Bytes
  • Size of remote file: 453 kB
input_images_20260127_052216_587020/images/000016.png ADDED

Git LFS Details

  • SHA256: d5e8002fb935f4b08007036872b4af42ef34ce08c1ffbdb7814607e224a6d6e6
  • Pointer size: 131 Bytes
  • Size of remote file: 456 kB
input_images_20260127_052216_587020/images/000017.png ADDED

Git LFS Details

  • SHA256: 7f38f2888354644b184cbed84137fba11426210899b3d356426eedcf006a45e5
  • Pointer size: 131 Bytes
  • Size of remote file: 404 kB
input_images_20260127_052439_748027/images/000000.png ADDED

Git LFS Details

  • SHA256: 32f33d84408f085ac1431180baa4284375652a581d6fd884668e344a887e504d
  • Pointer size: 131 Bytes
  • Size of remote file: 681 kB
input_images_20260127_052439_748027/images/000001.png ADDED

Git LFS Details

  • SHA256: bde7ae7b942849b420c30bea4456d7f25edb1911645a6893391f77478113dc2c
  • Pointer size: 131 Bytes
  • Size of remote file: 730 kB
input_images_20260127_052439_748027/images/000002.png ADDED

Git LFS Details

  • SHA256: fd5849f66dde1716f8cd753326407d4c31307f648005df532a63683f0fed95c5
  • Pointer size: 131 Bytes
  • Size of remote file: 779 kB
input_images_20260127_052439_748027/images/000003.png ADDED

Git LFS Details

  • SHA256: 6214a1df5809781ad2876ea4677e9c1a05703ceb6014165afc43005642591bb0
  • Pointer size: 131 Bytes
  • Size of remote file: 770 kB