Initial ZeroGPU Gradio Space for LingBot-Map
- .gitattributes +1 -0
- .gitignore +4 -0
- LICENSE.txt +201 -0
- README.md +45 -6
- app.py +630 -0
- assets/teaser.png +3 -0
- lingbot_map/__init__.py +0 -0
- lingbot_map/aggregator/__init__.py +2 -0
- lingbot_map/aggregator/base.py +608 -0
- lingbot_map/aggregator/stream.py +531 -0
- lingbot_map/heads/__init__.py +0 -0
- lingbot_map/heads/camera_head.py +458 -0
- lingbot_map/heads/dpt_head.py +679 -0
- lingbot_map/heads/head_act.py +125 -0
- lingbot_map/heads/utils.py +109 -0
- lingbot_map/layers/__init__.py +5 -0
- lingbot_map/layers/attention.py +766 -0
- lingbot_map/layers/block.py +514 -0
- lingbot_map/layers/drop_path.py +34 -0
- lingbot_map/layers/flashinfer_cache.py +640 -0
- lingbot_map/layers/layer_scale.py +22 -0
- lingbot_map/layers/mlp.py +40 -0
- lingbot_map/layers/patch_embed.py +85 -0
- lingbot_map/layers/rope.py +474 -0
- lingbot_map/layers/swiglu_ffn.py +67 -0
- lingbot_map/layers/vision_transformer.py +411 -0
- lingbot_map/models/__init__.py +0 -0
- lingbot_map/models/gct_base.py +359 -0
- lingbot_map/models/gct_stream.py +448 -0
- lingbot_map/models/gct_stream_window.py +1206 -0
- lingbot_map/utils/__init__.py +0 -0
- lingbot_map/utils/geometry.py +774 -0
- lingbot_map/utils/load_fn.py +243 -0
- lingbot_map/utils/pose_enc.py +331 -0
- lingbot_map/utils/rotation.py +132 -0
- lingbot_map/vis/__init__.py +59 -0
- lingbot_map/vis/glb_export.py +509 -0
- lingbot_map/vis/point_cloud_viewer.py +1437 -0
- lingbot_map/vis/sky_segmentation.py +457 -0
- lingbot_map/vis/utils.py +206 -0
- lingbot_map/vis/viser_wrapper.py +248 -0
- requirements.txt +14 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/teaser.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,4 @@
__pycache__/
.DS_Store
.gradio/
app_output/
LICENSE.txt
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
CHANGED
@@ -1,12 +1,51 @@
 ---
-title:
+title: LingBot-Map ZeroGPU Demo
+colorFrom: blue
-colorTo: gray
+colorTo: green
 sdk: gradio
-sdk_version: 6.12.0
 app_file: app.py
+python_version: 3.10.13
 pinned: false
+license: apache-2.0
+startup_duration_timeout: 1h
+models:
+  - robbyant/lingbot-map
+preload_from_hub:
+  - robbyant/lingbot-map lingbot-map.pt,lingbot-map-long.pt
 ---
 
+# LingBot-Map ZeroGPU Demo
+
+Gradio Space wrapper around `Robbyant/lingbot-map` tuned for Hugging Face ZeroGPU:
+
+- uses the upstream `lingbot_map` package directly
+- downloads checkpoints from `robbyant/lingbot-map`
+- runs the SDPA fallback path instead of FlashInfer
+- caps inputs to short clips so the app fits a shared ZeroGPU workflow
+- exports a browser-friendly `.glb` scene plus a zipped artifact bundle
+
+## Recommended Space Settings
+
+1. Create a new **Gradio** Space.
+2. In **Settings -> Hardware**, switch the Space to **ZeroGPU**.
+3. Keep the repo public or protected as needed.
+
+## Current Limits
+
+- short demos only
+- default frame cap: 24 frames
+- model preview is exported as GLB, not the local `viser` server
+- the app is optimized for `lingbot-map.pt` and `lingbot-map-long.pt`
+
+## Local Sanity Check
+
+If you want to import the app locally without downloading the checkpoint at startup:
+
+```bash
+LINGBOT_SPACE_SKIP_MODEL_LOAD=1 python app.py
+```
+
+## Upstream
+
+- GitHub: https://github.com/Robbyant/lingbot-map
+- Model: https://huggingface.co/robbyant/lingbot-map
app.py
ADDED
@@ -0,0 +1,630 @@
import contextlib
import gc
import json
import os
import shutil
import tempfile
import threading
import time
import zipfile
from pathlib import Path
from typing import Any, Iterable

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw

try:
    import spaces
except ImportError:
    class _SpacesShim:
        @staticmethod
        def GPU(*decorator_args, **decorator_kwargs):
            if decorator_args and callable(decorator_args[0]) and len(decorator_args) == 1 and not decorator_kwargs:
                return decorator_args[0]

            def _wrap(func):
                return func

            return _wrap

    spaces = _SpacesShim()

from lingbot_map.models.gct_stream import GCTStream
from lingbot_map.utils.geometry import closed_form_inverse_se3_general
from lingbot_map.utils.load_fn import load_and_preprocess_images
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
from lingbot_map.vis.glb_export import predictions_to_glb


ROOT = Path(__file__).resolve().parent
OUTPUT_ROOT = ROOT / "app_output"
OUTPUT_ROOT.mkdir(exist_ok=True)

HF_MODEL_REPO = "robbyant/lingbot-map"
MODEL_FILENAMES = {
    "balanced": "lingbot-map.pt",
    "long": "lingbot-map-long.pt",
    "stage1": "lingbot-map-stage1.pt",
}
MODEL_LABELS = {
    "balanced": "Balanced",
    "long": "Long",
    "stage1": "Stage-1",
}

IMAGE_SIZE = 518
PATCH_SIZE = 14
DEFAULT_FPS = 8
DEFAULT_MAX_FRAMES = 24
MAX_FRAMES_HARD_LIMIT = 24
DEFAULT_SCALE_FRAMES = 4
DEFAULT_KEYFRAME_INTERVAL = 2
DEFAULT_CONF_PERCENTILE = 50.0
DEFAULT_CAMERA_ITERATIONS = 1
IS_SPACE_RUNTIME = bool(os.getenv("SPACE_ID"))
SKIP_EAGER_MODEL_LOAD = os.getenv("LINGBOT_SPACE_SKIP_MODEL_LOAD") == "1"

MODEL_CACHE: dict[str, dict[str, Any]] = {}
MODEL_CACHE_LOCK = threading.Lock()
STARTUP_NOTES: list[str] = []


def _resolve_path(file_obj: Any) -> str:
    if file_obj is None:
        return ""
    if isinstance(file_obj, str):
        return file_obj
    return getattr(file_obj, "name", "")


def _cleanup_old_runs(keep_last: int = 8) -> None:
    run_dirs = sorted([p for p in OUTPUT_ROOT.iterdir() if p.is_dir()], key=lambda p: p.stat().st_mtime)
    for stale_dir in run_dirs[:-keep_last]:
        shutil.rmtree(stale_dir, ignore_errors=True)


def _pick_runtime_device() -> torch.device:
    try:
        torch.empty(1, device="cuda")
        return torch.device("cuda")
    except Exception:
        return torch.device("cpu")


def _load_model_bundle(model_variant: str) -> dict[str, Any]:
    with MODEL_CACHE_LOCK:
        cached = MODEL_CACHE.get(model_variant)
        if cached is not None:
            return cached

        if MODEL_CACHE:
            MODEL_CACHE.clear()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        device = _pick_runtime_device()
        weight_name = MODEL_FILENAMES[model_variant]
        weight_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=weight_name)

        model = GCTStream(
            img_size=IMAGE_SIZE,
            patch_size=PATCH_SIZE,
            enable_3d_rope=True,
            max_frame_num=1024,
            kv_cache_sliding_window=64,
            kv_cache_scale_frames=8,
            kv_cache_cross_frame_special=True,
            kv_cache_include_scale_frames=True,
            use_sdpa=True,
            camera_num_iterations=DEFAULT_CAMERA_ITERATIONS,
        )

        checkpoint = torch.load(weight_path, map_location="cpu", weights_only=False)
        state_dict = checkpoint.get("model", checkpoint)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)

        model = model.to(device).eval()
        inference_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
        if device.type == "cuda" and getattr(model, "aggregator", None) is not None:
            model.aggregator = model.aggregator.to(dtype=inference_dtype)

        bundle = {
            "model": model,
            "device": device,
            "dtype": inference_dtype,
            "weight_name": weight_name,
            "weight_path": str(weight_path),
            "missing_keys": len(missing),
            "unexpected_keys": len(unexpected),
        }
        MODEL_CACHE[model_variant] = bundle
        return bundle


def _eager_load_default_model() -> None:
    if not IS_SPACE_RUNTIME or SKIP_EAGER_MODEL_LOAD:
        return
    try:
        bundle = _load_model_bundle("balanced")
        STARTUP_NOTES.append(
            f"Startup preload complete on `{bundle['device']}` with `{bundle['weight_name']}`."
        )
    except Exception as exc:
        STARTUP_NOTES.append(f"Startup preload failed: {exc}")


def _copy_image_inputs(image_files: Iterable[Any], input_dir: Path, max_frames: int) -> list[str]:
    paths = sorted(filter(None, (_resolve_path(item) for item in image_files)), key=lambda value: Path(value).name)
    if not paths:
        return []

    copied = []
    for idx, src_path in enumerate(paths[:max_frames]):
        src = Path(src_path)
        suffix = src.suffix.lower() or ".png"
        dest = input_dir / f"{idx:06d}{suffix}"
        shutil.copy2(src, dest)
        copied.append(str(dest))
    return copied


def _extract_video_frames(video_file: str, frames_dir: Path, fps: int, max_frames: int) -> tuple[list[str], dict[str, Any]]:
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise ValueError("Could not open the uploaded video.")

    source_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    interval = max(1, round(source_fps / max(fps, 1)))

    saved_paths = []
    frame_idx = 0
    while len(saved_paths) < max_frames:
        ok, frame = cap.read()
        if not ok:
            break
        if frame_idx % interval == 0:
            output_path = frames_dir / f"{len(saved_paths):06d}.jpg"
            cv2.imwrite(str(output_path), frame)
            saved_paths.append(str(output_path))
        frame_idx += 1

    cap.release()

    return saved_paths, {
        "source_fps": round(source_fps, 2),
        "sample_interval": interval,
        "original_frame_count": total_frames,
    }


def _prepare_inputs(image_files: list[Any], video_file: Any, fps: int, max_frames: int) -> tuple[torch.Tensor, list[str], Path, dict[str, Any]]:
    _cleanup_old_runs()
    work_dir = Path(tempfile.mkdtemp(prefix="lingbot-map-", dir=OUTPUT_ROOT))
    input_dir = work_dir / "inputs"
    input_dir.mkdir(parents=True, exist_ok=True)

    image_paths = _copy_image_inputs(image_files or [], input_dir, max_frames=max_frames)
    input_summary = {"input_mode": None}

    if image_paths:
        input_summary["input_mode"] = "images"
        input_summary["source_fps"] = None
        input_summary["sample_interval"] = None
        input_summary["original_frame_count"] = len(image_paths)
    else:
        video_path = _resolve_path(video_file)
        if not video_path:
            raise ValueError("Upload either ordered images or a video.")
        image_paths, video_summary = _extract_video_frames(video_path, input_dir, fps=fps, max_frames=max_frames)
        input_summary["input_mode"] = "video"
        input_summary.update(video_summary)

    if len(image_paths) < 2:
        raise ValueError("Provide at least 2 frames. The Space is tuned for short multi-frame reconstructions.")

    images = load_and_preprocess_images(
        image_paths,
        mode="crop",
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
    )
    return images, image_paths, work_dir, input_summary


def _squeeze_single_batch(key: str, value: torch.Tensor) -> torch.Tensor:
    batched_dims = {
        "pose_enc": 3,
        "depth": 5,
        "depth_conf": 4,
        "world_points": 5,
        "world_points_conf": 4,
        "extrinsic": 4,
        "intrinsic": 4,
        "images": 5,
    }
    expected_ndim = batched_dims.get(key)
    if expected_ndim is None or value.ndim != expected_ndim or value.shape[0] != 1:
        return value
    return value[0]


def _postprocess_predictions(predictions: dict[str, Any], images: torch.Tensor) -> tuple[dict[str, Any], torch.Tensor]:
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    extrinsic_4x4 = torch.zeros((*extrinsic.shape[:-2], 4, 4), device=extrinsic.device, dtype=extrinsic.dtype)
    extrinsic_4x4[..., :3, :4] = extrinsic
    extrinsic_4x4[..., 3, 3] = 1.0
    extrinsic_4x4 = closed_form_inverse_se3_general(extrinsic_4x4)

    predictions["extrinsic"] = extrinsic_4x4[..., :3, :4]
    predictions["intrinsic"] = intrinsic
    predictions.pop("pose_enc_list", None)
    predictions.pop("images", None)

    for key, value in list(predictions.items()):
        if isinstance(value, torch.Tensor):
            predictions[key] = _squeeze_single_batch(key, value.detach().to("cpu"))

    images_cpu = images.detach().to("cpu")
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return predictions, images_cpu


def _prepare_for_visualization(predictions: dict[str, Any], images: torch.Tensor) -> dict[str, Any]:
    vis_predictions = {}
    for key, value in predictions.items():
        if isinstance(value, torch.Tensor):
            vis_predictions[key] = _squeeze_single_batch(key, value).detach().cpu().numpy()
        else:
            vis_predictions[key] = value
    vis_predictions["images"] = _squeeze_single_batch("images", images).detach().cpu().numpy()
    return vis_predictions


def _estimate_gpu_duration(images: torch.Tensor, model_variant: str, num_scale_frames: int, keyframe_interval: int) -> int:
    frame_count = int(getattr(images, "shape", [DEFAULT_MAX_FRAMES])[0])
    del model_variant, num_scale_frames, keyframe_interval
    return min(180, max(60, 24 + frame_count * 4))


@spaces.GPU(duration=_estimate_gpu_duration)
def _run_inference(images: torch.Tensor, model_variant: str, num_scale_frames: int, keyframe_interval: int) -> tuple[dict[str, Any], torch.Tensor, dict[str, Any]]:
    bundle = _load_model_bundle(model_variant)
    model = bundle["model"]
    device = bundle["device"]
    dtype = bundle["dtype"]

    if device.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    images = images.to(device)
    output_device = torch.device("cpu")
    autocast_context = (
        torch.amp.autocast("cuda", dtype=dtype)
        if device.type == "cuda"
        else contextlib.nullcontext()
    )

    started_at = time.time()
    with torch.no_grad():
        with autocast_context:
            predictions = model.inference_streaming(
                images,
                num_scale_frames=num_scale_frames,
                keyframe_interval=keyframe_interval,
                output_device=output_device,
            )
    inference_seconds = time.time() - started_at

    images_for_post = predictions["images"]
    del images
    if device.type == "cuda":
        torch.cuda.empty_cache()

    predictions, images_cpu = _postprocess_predictions(predictions, images_for_post)
    return predictions, images_cpu, {
        "runtime_seconds": round(inference_seconds, 2),
        "device": str(device),
        "dtype": str(dtype),
        "weight_name": bundle["weight_name"],
        "weight_path": bundle["weight_path"],
        "missing_keys": bundle["missing_keys"],
        "unexpected_keys": bundle["unexpected_keys"],
        "peak_memory_gb": round(torch.cuda.max_memory_allocated() / 1e9, 2) if device.type == "cuda" else None,
    }


def _make_preview_strip(images: torch.Tensor, output_path: Path) -> str:
    frames = images.detach().cpu()
    count = frames.shape[0]
    indices = sorted({int(round(i)) for i in np.linspace(0, count - 1, num=min(4, count))})

    tiles = []
    for idx in indices:
        rgb = (frames[idx].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
        tile = Image.fromarray(rgb).resize((320, 220))
        tiles.append(tile)

    banner = Image.new("RGB", (320 * len(tiles), 260), color=(245, 240, 228))
    draw = ImageDraw.Draw(banner)
    draw.text((18, 14), f"LingBot-Map preview | {count} frames", fill=(31, 41, 55))
    draw.text((18, 38), "ZeroGPU demo export", fill=(87, 96, 110))

    x_offset = 0
    for tile in tiles:
        banner.paste(tile, (x_offset, 72))
        x_offset += tile.width

    banner.save(output_path)
    return str(output_path)


def _save_predictions_npz(predictions: dict[str, Any], output_path: Path) -> str:
    arrays = {}
    for key, value in predictions.items():
        if isinstance(value, torch.Tensor):
            arrays[key] = value.detach().cpu().numpy()
    np.savez_compressed(output_path, **arrays)
    return str(output_path)


def _count_confident_points(vis_predictions: dict[str, Any], conf_percentile: float) -> tuple[int, float]:
    conf = vis_predictions.get("world_points_conf")
    if conf is None:
        return 0, 0.0
    conf_flat = conf.reshape(-1)
    threshold = np.percentile(conf_flat, conf_percentile) if conf_percentile > 0 else 0.0
    kept = int(((conf_flat >= threshold) & (conf_flat > 1e-5)).sum())
    return kept, float(threshold)


def _zip_outputs(work_dir: Path, paths: list[Path], output_name: str) -> str:
    zip_path = work_dir / output_name
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
        for path in paths:
            if path.exists():
                zip_file.write(path, arcname=path.name)
    return str(zip_path)


def _export_outputs(
    work_dir: Path,
    image_paths: list[str],
    predictions: dict[str, Any],
    images_cpu: torch.Tensor,
    input_summary: dict[str, Any],
    runtime_summary: dict[str, Any],
    model_variant: str,
    num_scale_frames: int,
    keyframe_interval: int,
    conf_percentile: float,
) -> tuple[str, str, dict[str, Any]]:
    vis_predictions = _prepare_for_visualization(predictions, images_cpu)

    glb_path = work_dir / "lingbot-map-reconstruction.glb"
    scene = predictions_to_glb(
        vis_predictions,
        conf_thres=conf_percentile,
        show_cam=True,
        target_dir=str(work_dir),
        mask_sky=False,
    )
    scene.export(glb_path)

    preview_path = Path(_make_preview_strip(images_cpu, work_dir / "preview.png"))
    npz_path = Path(_save_predictions_npz(predictions, work_dir / "predictions.npz"))

    points_kept, conf_threshold = _count_confident_points(vis_predictions, conf_percentile)
    summary = {
        "model_variant": MODEL_LABELS[model_variant],
        "model_filename": MODEL_FILENAMES[model_variant],
        "frames_used": len(image_paths),
        "num_scale_frames": num_scale_frames,
        "keyframe_interval": keyframe_interval,
        "confidence_percentile": conf_percentile,
        "confidence_threshold": round(conf_threshold, 4),
        "points_kept_for_glb": points_kept,
        "input_summary": input_summary,
        "runtime_summary": runtime_summary,
    }

    summary_path = work_dir / "summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

    artifact_path = _zip_outputs(
        work_dir,
        [glb_path, preview_path, npz_path, summary_path],
        output_name="lingbot-map-results.zip",
    )
    return str(glb_path), artifact_path, summary


def _format_status(summary: dict[str, Any]) -> str:
    runtime = summary["runtime_summary"]
    input_summary = summary["input_summary"]
    lines = [
        "## Run Complete",
        f"- Model: `{summary['model_filename']}`",
        f"- Frames used: `{summary['frames_used']}`",
        f"- Input mode: `{input_summary['input_mode']}`",
        f"- Runtime: `{runtime['runtime_seconds']}s` on `{runtime['device']}`",
        f"- GLB confidence percentile: `{summary['confidence_percentile']}`",
        f"- Points kept for GLB: `{summary['points_kept_for_glb']}`",
    ]
    if runtime.get("peak_memory_gb") is not None:
        lines.append(f"- Peak GPU memory: `{runtime['peak_memory_gb']} GB`")
    if input_summary.get("sample_interval"):
        lines.append(f"- Video sample interval: `every {input_summary['sample_interval']} frame(s)`")
    return "\n".join(lines)


def reconstruct_scene(
    image_files: list[Any],
    video_file: Any,
    model_variant: str,
    fps: int,
    max_frames: int,
    num_scale_frames: int,
    keyframe_interval: int,
    conf_percentile: float,
):
    max_frames = max(2, min(int(max_frames), MAX_FRAMES_HARD_LIMIT))
    num_scale_frames = max(1, int(num_scale_frames))
    keyframe_interval = max(1, int(keyframe_interval))
    conf_percentile = float(conf_percentile)

    images, image_paths, work_dir, input_summary = _prepare_inputs(
        image_files=image_files or [],
        video_file=video_file,
        fps=int(fps),
        max_frames=max_frames,
    )

    num_scale_frames = min(num_scale_frames, int(images.shape[0]))
    predictions, images_cpu, runtime_summary = _run_inference(
        images,
        model_variant=model_variant,
        num_scale_frames=num_scale_frames,
        keyframe_interval=keyframe_interval,
    )

    glb_path, artifact_path, summary = _export_outputs(
        work_dir=work_dir,
        image_paths=image_paths,
        predictions=predictions,
        images_cpu=images_cpu,
        input_summary=input_summary,
        runtime_summary=runtime_summary,
        model_variant=model_variant,
        num_scale_frames=num_scale_frames,
        keyframe_interval=keyframe_interval,
        conf_percentile=conf_percentile,
    )

    preview_path = str(work_dir / "preview.png")
    status = _format_status(summary)
    return glb_path, preview_path, artifact_path, summary, status


def _build_startup_markdown() -> str:
    if not STARTUP_NOTES:
        return (
            "Short-form LingBot-Map Space for Hugging Face ZeroGPU. "
            "It uses the upstream checkpoint files, SDPA inference, and exports a GLB scene plus a zipped results bundle."
        )
    return "\n".join([f"- {note}" for note in STARTUP_NOTES])


CSS = """
.shell {
    max-width: 1180px;
    margin: 0 auto;
}
.headline {
    background: linear-gradient(135deg, #f3ead7 0%, #d6e6d4 100%);
    border: 1px solid #d9ccb3;
    border-radius: 20px;
    padding: 20px 24px;
}
.headline h1 {
    margin: 0 0 8px 0;
    color: #14231a;
}
.headline p {
    margin: 0;
    color: #304437;
}
"""


_eager_load_default_model()


with gr.Blocks(css=CSS, theme=gr.themes.Soft(primary_hue="green", secondary_hue="amber"), title="LingBot-Map ZeroGPU Demo") as demo:
    gr.Markdown("<div class='shell'>")
    with gr.Row():
        gr.Image(value=str(ROOT / "assets" / "teaser.png"), show_label=False, interactive=False, min_width=320)
        gr.Markdown(
            """
            <div class="headline">
            <h1>LingBot-Map ZeroGPU Demo</h1>
            <p>Upload ordered images or a short video. The Space samples up to 24 frames, runs the SDPA fallback path, and exports a GLB scene plus a zipped artifact bundle.</p>
            </div>
            """
        )

    gr.Markdown(_build_startup_markdown())

    with gr.Row():
        with gr.Column(scale=1):
            image_files = gr.File(
                label="Ordered images",
                file_count="multiple",
                file_types=["image"],
                type="filepath",
            )
            video_file = gr.File(
                label="Or upload one video",
                file_types=["video"],
                type="filepath",
            )
            model_variant = gr.Dropdown(
                choices=[("Balanced", "balanced"), ("Long", "long"), ("Stage-1", "stage1")],
                value="balanced",
                label="Checkpoint",
            )
            fps = gr.Slider(minimum=1, maximum=12, step=1, value=DEFAULT_FPS, label="Video sampling FPS")
            max_frames = gr.Slider(minimum=2, maximum=MAX_FRAMES_HARD_LIMIT, step=1, value=DEFAULT_MAX_FRAMES, label="Max frames")
            num_scale_frames = gr.Slider(minimum=1, maximum=8, step=1, value=DEFAULT_SCALE_FRAMES, label="Scale frames")
            keyframe_interval = gr.Slider(minimum=1, maximum=8, step=1, value=DEFAULT_KEYFRAME_INTERVAL, label="Keyframe interval")
            conf_percentile = gr.Slider(
                minimum=0,
                maximum=90,
                step=5,
                value=DEFAULT_CONF_PERCENTILE,
                label="GLB confidence percentile",
            )
            run_button = gr.Button("Reconstruct Scene", variant="primary")

        with gr.Column(scale=1):
            model_preview = gr.Model3D(label="3D preview", clear_color=[0.97, 0.94, 0.88, 1.0])
            preview_image = gr.Image(label="Preview strip", interactive=False)
            artifact_file = gr.File(label="Download bundle")
            summary_json = gr.JSON(label="Run summary")
            status_markdown = gr.Markdown()

    run_button.click(
        fn=reconstruct_scene,
        inputs=[
            image_files,
            video_file,
            model_variant,
            fps,
            max_frames,
            num_scale_frames,
            keyframe_interval,
            conf_percentile,
        ],
        outputs=[
            model_preview,
            preview_image,
            artifact_file,
            summary_json,
            status_markdown,
        ],
        show_progress="full",
    )
    gr.Markdown("</div>")

demo.queue(default_concurrency_limit=1)


if __name__ == "__main__":
    demo.launch()
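A hedged usage sketch, not included in this commit: because the `_SpacesShim` at the top of `app.py` turns `@spaces.GPU` into a no-op outside a Space, `reconstruct_scene` can also be driven from a plain Python session without the Gradio UI. The frame paths below are placeholders.

```python
# Headless driver sketch for app.py (assumption: run from the repo root with the
# same dependencies installed). Skips the eager checkpoint preload at import time.
import os

os.environ["LINGBOT_SPACE_SKIP_MODEL_LOAD"] = "1"

from app import reconstruct_scene

glb_path, preview_path, bundle_path, summary, status = reconstruct_scene(
    image_files=["frames/000000.png", "frames/000001.png", "frames/000002.png"],  # placeholder paths
    video_file=None,
    model_variant="balanced",
    fps=8,
    max_frames=24,
    num_scale_frames=4,
    keyframe_interval=2,
    conf_percentile=50.0,
)
print(status)
print("GLB:", glb_path, "bundle:", bundle_path)
```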
assets/teaser.png
ADDED
Git LFS Details
lingbot_map/__init__.py
ADDED
File without changes
lingbot_map/aggregator/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .stream import AggregatorStream
from .base import AggregatorBase
lingbot_map/aggregator/base.py
ADDED
@@ -0,0 +1,608 @@
"""
AggregatorBase - Base class for all Aggregator implementations.

Provides shared functionality:
- Patch embedding (DINOv2)
- Special tokens (camera, register, scale)
- Block building
- Common forward pass structure

Subclasses implement mode-specific attention logic.
"""

import logging
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional, Tuple, List

from lingbot_map.layers import PatchEmbed
from lingbot_map.layers.block import Block
from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


def slice_expand_and_flatten(token, B, S, first_num_frame=1):
    """
    Helper function to slice, expand and flatten tokens.

    Args:
        token: Token tensor [1, 2, N, C] where first index is for first frames
        B: Batch size
        S: Sequence length
        first_num_frame: Number of frames to use first token for

    Returns:
        Flattened tokens [B*S, N, C]
    """
    # token shape: [1, 2, N, C]
    # Expand to [B, S, N, C]
    if first_num_frame > 1:
        # Use first token for first first_num_frame frames, second for rest
        token_first = token[:, :1].expand(B, first_num_frame, -1, -1)  # [B, first_num_frame, N, C]
        token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1)  # [B, S-first_num_frame, N, C]
        token_expanded = torch.cat([token_first, token_rest], dim=1)  # [B, S, N, C]
    else:
        # Use first token for first frame, second for rest
        token_first = token[:, :1].expand(B, 1, -1, -1)  # [B, 1, N, C]
        token_rest = token[:, 1:].expand(B, S - 1, -1, -1)  # [B, S-1, N, C]
        token_expanded = torch.cat([token_first, token_rest], dim=1)  # [B, S, N, C]

    # Flatten to [B*S, N, C]
    return token_expanded.reshape(B * S, -1, token.shape[-1])


class AggregatorBase(nn.Module, ABC):
    """
    Base class for all Aggregator implementations.

    Handles shared components:
    - Patch embedding (DINOv2 or conv)
    - Special tokens (camera, register, optionally scale)
    - Block creation (frame + global)
    - RoPE (2D rotary position embeddings)
    - Common forward pass scaffolding

    Subclasses must implement:
    - _process_global_attention(): Mode-specific cross-frame attention logic
    """

    def __init__(
        self,
        # Architecture parameters
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        # Block configuration
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        qk_norm=True,
        init_values=0.01,
        # Patch embedding
        patch_embed="dinov2_vitl14_reg",
        pretrained_path=None,
        # Attention pattern
        aa_order=["frame", "global"],
        aa_block_size=1,
        # RoPE
        rope_freq=100,
        disable_global_rope=False,
        # Gradient checkpointing
        use_reentrant: bool = False,
        use_gradient_checkpoint: bool = True,
    ):
        super().__init__()

        # Store configuration
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.num_register_tokens = num_register_tokens
        self.aa_order = aa_order
        self.aa_block_size = aa_block_size
        self.disable_global_rope = disable_global_rope
        self.use_reentrant = use_reentrant
        self.use_gradient_checkpoint = use_gradient_checkpoint
        self.pretrained_path = pretrained_path
        self.enable_ulysses_cp = False  # CP disabled

        print("pretrained_path:", self.pretrained_path)

        # Validate depth
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
        self.aa_block_num = self.depth // self.aa_block_size

        # Build patch embedding
        self._build_patch_embed(
            patch_embed=patch_embed,
            img_size=img_size,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            embed_dim=embed_dim,
            pretrained_path=pretrained_path
        )

        # Initialize RoPE
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        # Build blocks (frame + global)
        self._build_blocks(
            block_fn=block_fn,
            depth=depth,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            init_values=init_values,
            qk_norm=qk_norm,
        )

        # Setup special tokens (camera, register, optionally scale)
        self._setup_special_tokens()

        # Register normalization constants
        for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
            self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)

        # Initialize from DINO checkpoint if available
        if hasattr(self, '_dino_checkpoint') and self._dino_checkpoint is not None:
            self._init_blocks_from_dino(self._dino_checkpoint)
            del self._dino_checkpoint  # Free memory

    def _build_patch_embed(
        self,
        patch_embed: str,
        img_size: int,
        patch_size: int,
        num_register_tokens: int,
        embed_dim: int,
        pretrained_path: str,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
    ):
        """
        Build patch embedding layer.

        Supports:
        - "conv": Simple convolutional patch embedding
        - "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2)
        """
        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim
            )
            self._dino_checkpoint = None

        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            if patch_embed not in vit_models:
                raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}")

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

            # Load pretrained weights
            try:
                ckpt = torch.load(pretrained_path)
                del ckpt['pos_embed']
                logger.info("Loading pretrained weights for DINOv2")
|
| 18 |
+
|
| 19 |
+
from lingbot_map.layers import PatchEmbed
|
| 20 |
+
from lingbot_map.layers.block import Block
|
| 21 |
+
from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter
|
| 22 |
+
from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
_RESNET_MEAN = [0.485, 0.456, 0.406]
|
| 27 |
+
_RESNET_STD = [0.229, 0.224, 0.225]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def slice_expand_and_flatten(token, B, S, first_num_frame=1):
|
| 31 |
+
"""
|
| 32 |
+
Helper function to slice, expand and flatten tokens.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
token: Token tensor [1, 2, N, C] where first index is for first frames
|
| 36 |
+
B: Batch size
|
| 37 |
+
S: Sequence length
|
| 38 |
+
first_num_frame: Number of frames to use first token for
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
Flattened tokens [B*S, N, C]
|
| 42 |
+
"""
|
| 43 |
+
# token shape: [1, 2, N, C]
|
| 44 |
+
# Expand to [B, S, N, C]
|
| 45 |
+
if first_num_frame > 1:
|
| 46 |
+
# Use first token for first first_num_frame frames, second for rest
|
| 47 |
+
token_first = token[:, :1].expand(B, first_num_frame, -1, -1) # [B, first_num_frame, N, C]
|
| 48 |
+
token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1) # [B, S-first_num_frame, N, C]
|
| 49 |
+
token_expanded = torch.cat([token_first, token_rest], dim=1) # [B, S, N, C]
|
| 50 |
+
else:
|
| 51 |
+
# Use first token for first frame, second for rest
|
| 52 |
+
token_first = token[:, :1].expand(B, 1, -1, -1) # [B, 1, N, C]
|
| 53 |
+
token_rest = token[:, 1:].expand(B, S - 1, -1, -1) # [B, S-1, N, C]
|
| 54 |
+
token_expanded = torch.cat([token_first, token_rest], dim=1) # [B, S, N, C]
|
| 55 |
+
|
| 56 |
+
# Flatten to [B*S, N, C]
|
| 57 |
+
return token_expanded.reshape(B * S, -1, token.shape[-1])
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AggregatorBase(nn.Module, ABC):
|
| 61 |
+
"""
|
| 62 |
+
Base class for all Aggregator implementations.
|
| 63 |
+
|
| 64 |
+
Handles shared components:
|
| 65 |
+
- Patch embedding (DINOv2 or conv)
|
| 66 |
+
- Special tokens (camera, register, optionally scale)
|
| 67 |
+
- Block creation (frame + global)
|
| 68 |
+
- RoPE (2D rotary position embeddings)
|
| 69 |
+
- Common forward pass scaffolding
|
| 70 |
+
|
| 71 |
+
Subclasses must implement:
|
| 72 |
+
- _process_global_attention(): Mode-specific cross-frame attention logic
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
# Architecture parameters
|
| 78 |
+
img_size=518,
|
| 79 |
+
patch_size=14,
|
| 80 |
+
embed_dim=1024,
|
| 81 |
+
depth=24,
|
| 82 |
+
num_heads=16,
|
| 83 |
+
mlp_ratio=4.0,
|
| 84 |
+
num_register_tokens=4,
|
| 85 |
+
# Block configuration
|
| 86 |
+
block_fn=Block,
|
| 87 |
+
qkv_bias=True,
|
| 88 |
+
proj_bias=True,
|
| 89 |
+
ffn_bias=True,
|
| 90 |
+
qk_norm=True,
|
| 91 |
+
init_values=0.01,
|
| 92 |
+
# Patch embedding
|
| 93 |
+
patch_embed="dinov2_vitl14_reg",
|
| 94 |
+
pretrained_path=None,
|
| 95 |
+
# Attention pattern
|
| 96 |
+
aa_order=["frame", "global"],
|
| 97 |
+
aa_block_size=1,
|
| 98 |
+
# RoPE
|
| 99 |
+
rope_freq=100,
|
| 100 |
+
disable_global_rope=False,
|
| 101 |
+
# Gradient checkpointing
|
| 102 |
+
use_reentrant: bool = False,
|
| 103 |
+
use_gradient_checkpoint: bool = True,
|
| 104 |
+
):
|
| 105 |
+
super().__init__()
|
| 106 |
+
|
| 107 |
+
# Store configuration
|
| 108 |
+
self.img_size = img_size
|
| 109 |
+
self.patch_size = patch_size
|
| 110 |
+
self.embed_dim = embed_dim
|
| 111 |
+
self.depth = depth
|
| 112 |
+
self.num_heads = num_heads
|
| 113 |
+
self.mlp_ratio = mlp_ratio
|
| 114 |
+
self.num_register_tokens = num_register_tokens
|
| 115 |
+
self.aa_order = aa_order
|
| 116 |
+
self.aa_block_size = aa_block_size
|
| 117 |
+
self.disable_global_rope = disable_global_rope
|
| 118 |
+
self.use_reentrant = use_reentrant
|
| 119 |
+
self.use_gradient_checkpoint = use_gradient_checkpoint
|
| 120 |
+
self.pretrained_path = pretrained_path
|
| 121 |
+
self.enable_ulysses_cp = False # CP disabled
|
| 122 |
+
|
| 123 |
+
print("pretrained_path:", self.pretrained_path)
|
| 124 |
+
|
| 125 |
+
# Validate depth
|
| 126 |
+
if self.depth % self.aa_block_size != 0:
|
| 127 |
+
raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
|
| 128 |
+
self.aa_block_num = self.depth // self.aa_block_size
|
| 129 |
+
|
| 130 |
+
# Build patch embedding
|
| 131 |
+
self._build_patch_embed(
|
| 132 |
+
patch_embed=patch_embed,
|
| 133 |
+
img_size=img_size,
|
| 134 |
+
patch_size=patch_size,
|
| 135 |
+
num_register_tokens=num_register_tokens,
|
| 136 |
+
embed_dim=embed_dim,
|
| 137 |
+
pretrained_path=pretrained_path
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# Initialize RoPE
|
| 141 |
+
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
|
| 142 |
+
self.position_getter = PositionGetter() if self.rope is not None else None
|
| 143 |
+
|
| 144 |
+
# Build blocks (frame + global)
|
| 145 |
+
self._build_blocks(
|
| 146 |
+
block_fn=block_fn,
|
| 147 |
+
depth=depth,
|
| 148 |
+
embed_dim=embed_dim,
|
| 149 |
+
num_heads=num_heads,
|
| 150 |
+
mlp_ratio=mlp_ratio,
|
| 151 |
+
qkv_bias=qkv_bias,
|
| 152 |
+
proj_bias=proj_bias,
|
| 153 |
+
ffn_bias=ffn_bias,
|
| 154 |
+
init_values=init_values,
|
| 155 |
+
qk_norm=qk_norm,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Setup special tokens (camera, register, optionally scale)
|
| 159 |
+
self._setup_special_tokens()
|
| 160 |
+
|
| 161 |
+
# Register normalization constants
|
| 162 |
+
for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
|
| 163 |
+
self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)
|
| 164 |
+
|
| 165 |
+
# Initialize from DINO checkpoint if available
|
| 166 |
+
if hasattr(self, '_dino_checkpoint') and self._dino_checkpoint is not None:
|
| 167 |
+
self._init_blocks_from_dino(self._dino_checkpoint)
|
| 168 |
+
del self._dino_checkpoint # Free memory
|
| 169 |
+
|
| 170 |
+
def _build_patch_embed(
|
| 171 |
+
self,
|
| 172 |
+
patch_embed: str,
|
| 173 |
+
img_size: int,
|
| 174 |
+
patch_size: int,
|
| 175 |
+
num_register_tokens: int,
|
| 176 |
+
embed_dim: int,
|
| 177 |
+
pretrained_path: str,
|
| 178 |
+
interpolate_antialias=True,
|
| 179 |
+
interpolate_offset=0.0,
|
| 180 |
+
block_chunks=0,
|
| 181 |
+
init_values=1.0,
|
| 182 |
+
):
|
| 183 |
+
"""
|
| 184 |
+
Build patch embedding layer.
|
| 185 |
+
|
| 186 |
+
Supports:
|
| 187 |
+
- "conv": Simple convolutional patch embedding
|
| 188 |
+
- "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2)
|
| 189 |
+
"""
|
| 190 |
+
if "conv" in patch_embed:
|
| 191 |
+
self.patch_embed = PatchEmbed(
|
| 192 |
+
img_size=img_size,
|
| 193 |
+
patch_size=patch_size,
|
| 194 |
+
in_chans=3,
|
| 195 |
+
embed_dim=embed_dim
|
| 196 |
+
)
|
| 197 |
+
self._dino_checkpoint = None
|
| 198 |
+
|
| 199 |
+
else:
|
| 200 |
+
vit_models = {
|
| 201 |
+
"dinov2_vitl14_reg": vit_large,
|
| 202 |
+
"dinov2_vitb14_reg": vit_base,
|
| 203 |
+
"dinov2_vits14_reg": vit_small,
|
| 204 |
+
"dinov2_vitg2_reg": vit_giant2,
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
if patch_embed not in vit_models:
|
| 208 |
+
raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}")
|
| 209 |
+
|
| 210 |
+
self.patch_embed = vit_models[patch_embed](
|
| 211 |
+
img_size=img_size,
|
| 212 |
+
patch_size=patch_size,
|
| 213 |
+
num_register_tokens=num_register_tokens,
|
| 214 |
+
interpolate_antialias=interpolate_antialias,
|
| 215 |
+
interpolate_offset=interpolate_offset,
|
| 216 |
+
block_chunks=block_chunks,
|
| 217 |
+
init_values=init_values,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# Load pretrained weights
|
| 221 |
+
try:
|
| 222 |
+
ckpt = torch.load(pretrained_path)
|
| 223 |
+
del ckpt['pos_embed']
|
| 224 |
+
logger.info("Loading pretrained weights for DINOv2")
|
| 225 |
+
missing, unexpected = self.patch_embed.load_state_dict(ckpt, strict=False)
|
| 226 |
+
logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
|
| 227 |
+
|
| 228 |
+
# Store checkpoint for block initialization
|
| 229 |
+
self._dino_checkpoint = ckpt
|
| 230 |
+
except Exception as e:
|
| 231 |
+
logger.warning(f"Failed to load pretrained weights: {e}")
|
| 232 |
+
self._dino_checkpoint = None
|
| 233 |
+
|
| 234 |
+
# Disable gradients for mask token
|
| 235 |
+
if hasattr(self.patch_embed, "mask_token"):
|
| 236 |
+
self.patch_embed.mask_token.requires_grad_(False)
|
| 237 |
+
|
| 238 |
+
@abstractmethod
|
| 239 |
+
def _build_blocks(
|
| 240 |
+
self,
|
| 241 |
+
block_fn,
|
| 242 |
+
depth: int,
|
| 243 |
+
embed_dim: int,
|
| 244 |
+
num_heads: int,
|
| 245 |
+
mlp_ratio: float,
|
| 246 |
+
qkv_bias: bool,
|
| 247 |
+
proj_bias: bool,
|
| 248 |
+
ffn_bias: bool,
|
| 249 |
+
init_values: float,
|
| 250 |
+
qk_norm: bool,
|
| 251 |
+
):
|
| 252 |
+
"""
|
| 253 |
+
Build frame_blocks and global_blocks.
|
| 254 |
+
|
| 255 |
+
Subclasses implement mode-specific block creation.
|
| 256 |
+
|
| 257 |
+
Must create:
|
| 258 |
+
- self.frame_blocks: nn.ModuleList of frame attention blocks
|
| 259 |
+
- self.global_blocks: nn.ModuleList of global attention blocks
|
| 260 |
+
"""
|
| 261 |
+
pass
|
| 262 |
+
|
| 263 |
+
@abstractmethod
|
| 264 |
+
def _setup_special_tokens(self):
|
| 265 |
+
"""
|
| 266 |
+
Setup camera token, register tokens, and optionally scale token.
|
| 267 |
+
|
| 268 |
+
Subclasses implement mode-specific token initialization.
|
| 269 |
+
|
| 270 |
+
Must create:
|
| 271 |
+
- self.camera_token
|
| 272 |
+
- self.register_token
|
| 273 |
+
- self.scale_token (optional, for causal mode)
|
| 274 |
+
- self.patch_start_idx
|
| 275 |
+
- self.num_special_tokens
|
| 276 |
+
"""
|
| 277 |
+
pass
|
| 278 |
+
|
| 279 |
+
def _init_blocks_from_dino(self, dino_ckpt: dict):
|
| 280 |
+
"""
|
| 281 |
+
Initialize frame_blocks and global_blocks from DINOv2 pretrained weights.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
dino_ckpt: Checkpoint dictionary from DINOv2 model
|
| 285 |
+
"""
|
| 286 |
+
logger.info("Initializing blocks from DINOv2 pretrained weights")
|
| 287 |
+
|
| 288 |
+
# Extract block keys
|
| 289 |
+
dino_block_keys = [k for k in dino_ckpt.keys() if k.startswith('blocks.')]
|
| 290 |
+
if not dino_block_keys:
|
| 291 |
+
logger.warning("No 'blocks' found in DINO checkpoint")
|
| 292 |
+
return
|
| 293 |
+
|
| 294 |
+
# Get block indices
|
| 295 |
+
block_indices = set()
|
| 296 |
+
for key in dino_block_keys:
|
| 297 |
+
parts = key.split('.')
|
| 298 |
+
if len(parts) > 1 and parts[1].isdigit():
|
| 299 |
+
block_indices.add(int(parts[1]))
|
| 300 |
+
|
| 301 |
+
num_dino_blocks = len(block_indices)
|
| 302 |
+
print(f"Found {num_dino_blocks} blocks in DINO checkpoint")
|
| 303 |
+
|
| 304 |
+
# Initialize frame_blocks
|
| 305 |
+
for i, block in enumerate(self.frame_blocks):
|
| 306 |
+
dino_block_idx = i % num_dino_blocks
|
| 307 |
+
block_state_dict = {}
|
| 308 |
+
prefix = f'blocks.{dino_block_idx}.'
|
| 309 |
+
for key, value in dino_ckpt.items():
|
| 310 |
+
if key.startswith(prefix):
|
| 311 |
+
new_key = key[len(prefix):]
|
| 312 |
+
block_state_dict[new_key] = value
|
| 313 |
+
|
| 314 |
+
if block_state_dict:
|
| 315 |
+
missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
|
| 316 |
+
if i == 0: # Only log for first block to avoid spam
|
| 317 |
+
print(f"Frame block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
|
| 318 |
+
|
| 319 |
+
# Initialize global_blocks
|
| 320 |
+
for i, block in enumerate(self.global_blocks):
|
| 321 |
+
dino_block_idx = i % num_dino_blocks
|
| 322 |
+
block_state_dict = {}
|
| 323 |
+
prefix = f'blocks.{dino_block_idx}.'
|
| 324 |
+
for key, value in dino_ckpt.items():
|
| 325 |
+
if key.startswith(prefix):
|
| 326 |
+
new_key = key[len(prefix):]
|
| 327 |
+
block_state_dict[new_key] = value
|
| 328 |
+
|
| 329 |
+
if block_state_dict:
|
| 330 |
+
missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
|
| 331 |
+
if i == 0: # Only log for first block to avoid spam
|
| 332 |
+
print(f"Global block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
|
| 333 |
+
|
| 334 |
+
logger.info("Successfully initialized blocks from DINOv2 weights")
|
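For illustration, the prefix stripping above remaps checkpoint keys as follows (the parameter names are typical DINOv2 block entries, shown here as an assumption):

```python
prefix = "blocks.3."
key = "blocks.3.attn.qkv.weight"
new_key = key[len(prefix):]   # -> "attn.qkv.weight", loadable by a single Block

# With 24 DINO blocks and depth=24, frame/global block i reuses DINO block i % 24 == i;
# deeper models simply cycle through the pretrained blocks again.
```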
| 335 |
+
|
| 336 |
+
def _embed_images(
|
| 337 |
+
self,
|
| 338 |
+
images: torch.Tensor,
|
| 339 |
+
num_frame_for_scale: Optional[int] = None,
|
| 340 |
+
) -> Tuple[torch.Tensor, int, int, int, int, int]:
|
| 341 |
+
"""
|
| 342 |
+
Embed images and prepare for attention processing.
|
| 343 |
+
|
| 344 |
+
Handles:
|
| 345 |
+
- Image normalization
|
| 346 |
+
- Patch embedding
|
| 347 |
+
- Special token concatenation
|
| 348 |
+
- Position embedding
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
images: Input images [B, S, 3, H, W] in range [0, 1]
|
| 352 |
+
num_frame_for_scale: Number of frames for scale estimation (passed to special tokens)
|
| 353 |
+
|
| 354 |
+
Returns:
|
| 355 |
+
(tokens, B, S_local, S_global, P, C):
|
| 356 |
+
tokens: Embedded tokens [B*S, P, C]
|
| 357 |
+
B: Batch size
|
| 358 |
+
S_local: Local sequence length (equal to the input S; no CP slicing)
|
| 359 |
+
S_global: Global sequence length (also equal to the input S)
|
| 360 |
+
P: Number of tokens per frame (patches + special tokens)
|
| 361 |
+
C: Embedding dimension
|
| 362 |
+
"""
|
| 363 |
+
B, S, C_in, H, W = images.shape
|
| 364 |
+
|
| 365 |
+
if C_in != 3:
|
| 366 |
+
raise ValueError(f"Expected 3 input channels, got {C_in}")
|
| 367 |
+
|
| 368 |
+
# Normalize images
|
| 369 |
+
images = (images - self._resnet_mean) / self._resnet_std
|
| 370 |
+
|
| 371 |
+
# No CP slicing: S_local == S_global
|
| 372 |
+
S_local = S
|
| 373 |
+
S_global = S
|
| 374 |
+
|
| 375 |
+
# Reshape for patch embedding [B*S, C, H, W]
|
| 376 |
+
images = images.view(B * S, C_in, H, W)
|
| 377 |
+
|
| 378 |
+
# Patch embedding
|
| 379 |
+
patch_tokens = self.patch_embed(images)
|
| 380 |
+
if isinstance(patch_tokens, dict):
|
| 381 |
+
patch_tokens = patch_tokens["x_norm_patchtokens"]
|
| 382 |
+
|
| 383 |
+
_, P_patch, C = patch_tokens.shape
|
| 384 |
+
|
| 385 |
+
# Prepare special tokens
|
| 386 |
+
special_tokens = self._prepare_special_tokens(
|
| 387 |
+
B, S_local, S_global, C,
|
| 388 |
+
num_frame_for_scale=num_frame_for_scale
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
# Concatenate special tokens + patch tokens
|
| 392 |
+
tokens = torch.cat([special_tokens, patch_tokens], dim=1)
|
| 393 |
+
|
| 394 |
+
_, P, C = tokens.shape
|
| 395 |
+
|
| 396 |
+
return tokens, B, S_local, S_global, P, C
|
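A shape trace through `_embed_images` for one concrete input (sizes assume the default 518x518 resolution, patch size 14, embed_dim 1024, and the streaming subclass's six special tokens):

```python
# images:         [B=1, S=4, 3, 518, 518]
# normalized and reshaped to [4, 3, 518, 518]
# patch_embed:    [4, 1369, 1024]        # 37 x 37 patches per frame
# special_tokens: [4, 6, 1024]           # camera + 4 register + scale (stream mode)
# tokens:         [4, 1375, 1024], returned with B=1, S_local=S_global=4, P=1375, C=1024
```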
| 397 |
+
|
| 398 |
+
@abstractmethod
|
| 399 |
+
def _prepare_special_tokens(self, B: int, S_local: int, S_global: int, C: int, **kwargs) -> torch.Tensor:
|
| 400 |
+
"""
|
| 401 |
+
Prepare special tokens (camera, register, optionally scale).
|
| 402 |
+
|
| 403 |
+
Subclasses implement mode-specific token preparation.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
B: Batch size
|
| 407 |
+
S_local: Local sequence length
|
| 408 |
+
S_global: Global sequence length
|
| 409 |
+
C: Embedding dimension
|
| 410 |
+
**kwargs: Mode-specific parameters (e.g., num_frame_for_scale for causal mode)
|
| 411 |
+
|
| 412 |
+
Returns:
|
| 413 |
+
Special tokens [B*S, N_special, C]
|
| 414 |
+
"""
|
| 415 |
+
pass
|
| 416 |
+
|
| 417 |
+
def _get_positions(self, B: int, S: int, H: int, W: int, device) -> Optional[torch.Tensor]:
|
| 418 |
+
"""
|
| 419 |
+
Get 2D position embeddings for RoPE.
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
B: Batch size
|
| 423 |
+
S: Sequence length
|
| 424 |
+
H: Image height
|
| 425 |
+
W: Image width
|
| 426 |
+
device: Device to create positions on
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
Position tensor [B*S, P, 2] or None if rope is disabled
|
| 430 |
+
"""
|
| 431 |
+
if self.rope is None:
|
| 432 |
+
return None
|
| 433 |
+
|
| 434 |
+
# Get patch positions
|
| 435 |
+
pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
|
| 436 |
+
|
| 437 |
+
# Add offset for patch tokens (skip special tokens at pos=0)
|
| 438 |
+
if self.patch_start_idx > 0:
|
| 439 |
+
pos = pos + 1
|
| 440 |
+
pos_special = torch.zeros(B * S, self.patch_start_idx, 2, dtype=pos.dtype, device=device)
|
| 441 |
+
pos = torch.cat([pos_special, pos], dim=1)
|
| 442 |
+
|
| 443 |
+
return pos
|
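As a concrete example of the layout produced above (numbers assume the default 518x518 input and patch size 14):

```python
# H = W = 518, patch_size = 14  ->  37 x 37 = 1369 patch positions per frame.
# With patch_start_idx special tokens prepended, pos has shape [B*S, patch_start_idx + 1369, 2]:
#   pos[:, :patch_start_idx] == (0, 0)   # special tokens share the reserved position 0
#   pos[:, patch_start_idx:]             # patch grid, shifted by +1 so (0, 0) stays reserved
```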
| 444 |
+
|
| 445 |
+
def _process_frame_attention(
|
| 446 |
+
self,
|
| 447 |
+
tokens: torch.Tensor,
|
| 448 |
+
B: int,
|
| 449 |
+
S: int,
|
| 450 |
+
P: int,
|
| 451 |
+
C: int,
|
| 452 |
+
frame_idx: int,
|
| 453 |
+
pos: Optional[torch.Tensor] = None,
|
| 454 |
+
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
|
| 455 |
+
"""
|
| 456 |
+
Process frame attention blocks.
|
| 457 |
+
|
| 458 |
+
Frame attention operates independently per frame (no cross-frame communication).
|
| 459 |
+
Tokens stay in shape [B*S, P, C].
|
| 460 |
+
|
| 461 |
+
Args:
|
| 462 |
+
tokens: Input tokens [B*S, P, C]
|
| 463 |
+
B: Batch size
|
| 464 |
+
S: Sequence length
|
| 465 |
+
P: Tokens per frame
|
| 466 |
+
C: Embedding dimension
|
| 467 |
+
frame_idx: Current frame block index
|
| 468 |
+
pos: Position embeddings [B*S, P, 2]
|
| 469 |
+
|
| 470 |
+
Returns:
|
| 471 |
+
(tokens, frame_idx, intermediates):
|
| 472 |
+
tokens: Output tokens [B*S, P, C]
|
| 473 |
+
frame_idx: Updated frame block index
|
| 474 |
+
intermediates: List of intermediate outputs [B, S, P, C]
|
| 475 |
+
"""
|
| 476 |
+
# Ensure correct shape
|
| 477 |
+
if tokens.shape != (B * S, P, C):
|
| 478 |
+
tokens = tokens.view(B * S, P, C)
|
| 479 |
+
|
| 480 |
+
if pos is not None and pos.shape != (B * S, P, 2):
|
| 481 |
+
pos = pos.view(B * S, P, 2)
|
| 482 |
+
|
| 483 |
+
intermediates = []
|
| 484 |
+
|
| 485 |
+
# Process blocks
|
| 486 |
+
for i in range(self.aa_block_size):
|
| 487 |
+
if self.training and self.use_gradient_checkpoint:
|
| 488 |
+
from torch.utils.checkpoint import checkpoint
|
| 489 |
+
tokens = checkpoint(
|
| 490 |
+
self.frame_blocks[frame_idx],
|
| 491 |
+
tokens,
|
| 492 |
+
pos,
|
| 493 |
+
False, # enable_ulysses_cp (always False)
|
| 494 |
+
use_reentrant=self.use_reentrant
|
| 495 |
+
)
|
| 496 |
+
else:
|
| 497 |
+
tokens = self.frame_blocks[frame_idx](tokens, pos=pos, enable_ulysses_cp=False)
|
| 498 |
+
|
| 499 |
+
frame_idx += 1
|
| 500 |
+
intermediates.append(tokens.view(B, S, P, C))
|
| 501 |
+
|
| 502 |
+
return tokens, frame_idx, intermediates
|
| 503 |
+
|
| 504 |
+
@abstractmethod
|
| 505 |
+
def _process_global_attention(
|
| 506 |
+
self,
|
| 507 |
+
tokens: torch.Tensor,
|
| 508 |
+
B: int,
|
| 509 |
+
S_local: int,
|
| 510 |
+
S_global: int,
|
| 511 |
+
P: int,
|
| 512 |
+
C: int,
|
| 513 |
+
global_idx: int,
|
| 514 |
+
pos: Optional[torch.Tensor] = None,
|
| 515 |
+
**kwargs
|
| 516 |
+
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
|
| 517 |
+
"""
|
| 518 |
+
Process global (cross-frame) attention blocks.
|
| 519 |
+
|
| 520 |
+
Subclasses implement mode-specific attention logic.
|
| 521 |
+
|
| 522 |
+
Args:
|
| 523 |
+
tokens: Input tokens
|
| 524 |
+
B: Batch size
|
| 525 |
+
S_local: Local sequence length
|
| 526 |
+
S_global: Global sequence length
|
| 527 |
+
P: Tokens per frame
|
| 528 |
+
C: Embedding dimension
|
| 529 |
+
global_idx: Current global block index
|
| 530 |
+
pos: Position embeddings
|
| 531 |
+
**kwargs: Mode-specific parameters
|
| 532 |
+
|
| 533 |
+
Returns:
|
| 534 |
+
(tokens, global_idx, intermediates):
|
| 535 |
+
tokens: Output tokens
|
| 536 |
+
global_idx: Updated global block index
|
| 537 |
+
intermediates: List of intermediate outputs
|
| 538 |
+
"""
|
| 539 |
+
pass
|
| 540 |
+
|
| 541 |
+
def forward(
|
| 542 |
+
self,
|
| 543 |
+
images: torch.Tensor,
|
| 544 |
+
selected_idx: Optional[List[int]] = None,
|
| 545 |
+
# Mode-specific parameters
|
| 546 |
+
num_frame_for_scale: Optional[int] = None,
|
| 547 |
+
sliding_window_size: Optional[int] = None,
|
| 548 |
+
num_frame_per_block: int = 1,
|
| 549 |
+
) -> Tuple[List[torch.Tensor], int]:
|
| 550 |
+
"""
|
| 551 |
+
Forward pass.
|
| 552 |
+
|
| 553 |
+
Args:
|
| 554 |
+
images: Input images [B, S, 3, H, W] in range [0, 1]
|
| 555 |
+
selected_idx: Which block indices to output (None = all)
|
| 556 |
+
num_frame_for_scale: Number of frames for scale estimation (causal mode)
|
| 557 |
+
sliding_window_size: Sliding window size in blocks (causal mode)
|
| 558 |
+
num_frame_per_block: Number of frames per processing block (causal mode)
|
| 559 |
+
|
| 560 |
+
Returns:
|
| 561 |
+
(output_list, patch_start_idx):
|
| 562 |
+
output_list: List of block outputs [B, S, P, 2C]
|
| 563 |
+
patch_start_idx: Index where patch tokens start
|
| 564 |
+
"""
|
| 565 |
+
B, S_input, _, H, W = images.shape
|
| 566 |
+
|
| 567 |
+
# Embed images
|
| 568 |
+
tokens, B, S_local, S_global, P, C = self._embed_images(
|
| 569 |
+
images,
|
| 570 |
+
num_frame_for_scale=num_frame_for_scale,
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
# Get position embeddings
|
| 574 |
+
pos_local = self._get_positions(B, S_local, H, W, device=images.device)
|
| 575 |
+
pos_global = self._get_positions(B, S_global, H, W, device=images.device)
|
| 576 |
+
|
| 577 |
+
# Alternating attention
|
| 578 |
+
frame_idx = 0
|
| 579 |
+
global_idx = 0
|
| 580 |
+
output_list = []
|
| 581 |
+
|
| 582 |
+
for block_group_idx in range(self.aa_block_num):
|
| 583 |
+
for attn_type in self.aa_order:
|
| 584 |
+
if attn_type == "frame":
|
| 585 |
+
tokens, frame_idx, frame_intermediates = self._process_frame_attention(
|
| 586 |
+
tokens, B, S_local, P, C, frame_idx, pos=pos_local
|
| 587 |
+
)
|
| 588 |
+
elif attn_type == "global":
|
| 589 |
+
tokens, global_idx, global_intermediates = self._process_global_attention(
|
| 590 |
+
tokens, B, S_local, S_global, P, C, global_idx,
|
| 591 |
+
pos=pos_global,
|
| 592 |
+
num_frame_for_scale=num_frame_for_scale,
|
| 593 |
+
sliding_window_size=sliding_window_size,
|
| 594 |
+
num_frame_per_block=num_frame_per_block,
|
| 595 |
+
image_height=H,
|
| 596 |
+
image_width=W,
|
| 597 |
+
)
|
| 598 |
+
else:
|
| 599 |
+
raise ValueError(f"Unknown attention type: {attn_type}")
|
| 600 |
+
|
| 601 |
+
# Collect outputs
|
| 602 |
+
if selected_idx is None or block_group_idx in selected_idx:
|
| 603 |
+
for i in range(len(frame_intermediates)):
|
| 604 |
+
# Concatenate frame and global intermediates [B, S, P, 2C]
|
| 605 |
+
concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
|
| 606 |
+
output_list.append(concat_inter)
|
| 607 |
+
|
| 608 |
+
return output_list, self.patch_start_idx
|
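A hedged usage sketch for this shared forward pass; `AggregatorStream` (the concrete subclass defined in stream.py below), the SDPA backend, and the input sizes are all assumptions for illustration:

```python
import torch

model = AggregatorStream(use_sdpa=True).cuda().eval()   # use_sdpa avoids the FlashInfer dependency

images = torch.rand(1, 4, 3, 518, 518, device="cuda")   # [B, S, 3, H, W] in [0, 1]

with torch.no_grad():
    output_list, patch_start_idx = model(images)

# With selected_idx=None there is one entry per frame/global block pair (== depth),
# each of shape [B, S, P, 2C]: frame and global attention features concatenated.
print(len(output_list), output_list[-1].shape, patch_start_idx)
```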
lingbot_map/aggregator/stream.py
ADDED
|
@@ -0,0 +1,531 @@
|
| 1 |
+
"""
|
| 2 |
+
AggregatorStream - Streaming causal aggregator with FlashInfer KV cache.
|
| 3 |
+
|
| 4 |
+
Provides:
|
| 5 |
+
- Temporal causal attention
|
| 6 |
+
- Sliding window support
|
| 7 |
+
- Scale token for scale estimation frames
|
| 8 |
+
- Streaming inference with FlashInfer paged KV cache
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from typing import Optional, Tuple, List
|
| 15 |
+
|
| 16 |
+
from lingbot_map.layers.block import Block, FlashInferBlock, SDPABlock
|
| 17 |
+
from lingbot_map.layers.rope import WanRotaryPosEmbed
|
| 18 |
+
from lingbot_map.aggregator.base import AggregatorBase, slice_expand_and_flatten
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AggregatorStream(AggregatorBase):
|
| 24 |
+
"""
|
| 25 |
+
Streaming causal aggregator with FlashInfer paged KV cache.
|
| 26 |
+
|
| 27 |
+
Features:
|
| 28 |
+
- Temporal causal attention (each frame only attends to past frames)
|
| 29 |
+
- Sliding window support to limit attention scope
|
| 30 |
+
- Scale token for scale estimation frames
|
| 31 |
+
- Streaming inference with FlashInfer KV cache
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
# Causal-specific parameters
|
| 37 |
+
sliding_window_size: int = -1,
|
| 38 |
+
num_frame_for_scale: int = 1,
|
| 39 |
+
num_random_frames: int = 0,
|
| 40 |
+
attend_to_special_tokens: bool = False,
|
| 41 |
+
attend_to_scale_frames: bool = False,
|
| 42 |
+
enable_3d_rope: bool = False,
|
| 43 |
+
max_frame_num: int = 1024,
|
| 44 |
+
# KV cache parameters
|
| 45 |
+
kv_cache_sliding_window: int = 64,
|
| 46 |
+
kv_cache_scale_frames: int = 8,
|
| 47 |
+
kv_cache_cross_frame_special: bool = True,
|
| 48 |
+
kv_cache_include_scale_frames: bool = True,
|
| 49 |
+
kv_cache_camera_only: bool = False,
|
| 50 |
+
# Base class parameters via **kwargs
|
| 51 |
+
**kwargs
|
| 52 |
+
):
|
| 53 |
+
"""
|
| 54 |
+
Initialize AggregatorStream.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
sliding_window_size: Sliding window size in blocks (-1 for full causal)
|
| 58 |
+
num_frame_for_scale: Number of scale estimation frames
|
| 59 |
+
num_random_frames: Number of random frames for long-range dependencies
|
| 60 |
+
attend_to_special_tokens: Enable cross-frame special token attention
|
| 61 |
+
attend_to_scale_frames: Include scale frames in attention
|
| 62 |
+
enable_3d_rope: Enable 3D RoPE for temporal dimension in KV cache
|
| 63 |
+
max_frame_num: Maximum number of frames for 3D RoPE
|
| 64 |
+
kv_cache_sliding_window: Sliding window size for KV cache eviction
|
| 65 |
+
kv_cache_scale_frames: Number of scale frames to keep in KV cache
|
| 66 |
+
kv_cache_cross_frame_special: Keep special tokens from evicted frames
|
| 67 |
+
kv_cache_include_scale_frames: Include scale frames in KV cache
|
| 68 |
+
kv_cache_camera_only: Only keep camera tokens from evicted frames
|
| 69 |
+
**kwargs: Base class parameters
|
| 70 |
+
"""
|
| 71 |
+
self.sliding_window_size = sliding_window_size
|
| 72 |
+
self.num_frame_for_scale = num_frame_for_scale
|
| 73 |
+
self.num_random_frames = num_random_frames
|
| 74 |
+
self.attend_to_special_tokens = attend_to_special_tokens
|
| 75 |
+
self.attend_to_scale_frames = attend_to_scale_frames
|
| 76 |
+
self.enable_3d_rope = enable_3d_rope
|
| 77 |
+
self.max_frame_num = max_frame_num
|
| 78 |
+
# KV cache parameters
|
| 79 |
+
self.kv_cache_sliding_window = kv_cache_sliding_window
|
| 80 |
+
self.kv_cache_scale_frames = kv_cache_scale_frames
|
| 81 |
+
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
|
| 82 |
+
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
| 83 |
+
self.kv_cache_camera_only = kv_cache_camera_only
|
| 84 |
+
|
| 85 |
+
# Pop kwargs that are passed but not needed by base class
|
| 86 |
+
kwargs.pop('enable_stream_inference', None)
|
| 87 |
+
use_flashinfer = kwargs.pop('use_flashinfer', True)
|
| 88 |
+
kwargs.pop('use_flexflash', None)
|
| 89 |
+
use_sdpa = kwargs.pop('use_sdpa', False)
|
| 90 |
+
|
| 91 |
+
# Backend selection: SDPA (no extra deps) or FlashInfer (paged KV cache)
|
| 92 |
+
self.use_sdpa = use_sdpa
|
| 93 |
+
self.use_flashinfer = not use_sdpa # FlashInfer is default unless SDPA requested
|
| 94 |
+
|
| 95 |
+
# Call parent __init__
|
| 96 |
+
super().__init__(**kwargs)
|
| 97 |
+
|
| 98 |
+
# Initialize KV cache
|
| 99 |
+
self._init_kv_cache()
|
| 100 |
+
|
| 101 |
+
# Initialize 3D RoPE if enabled
|
| 102 |
+
if self.enable_3d_rope:
|
| 103 |
+
self._init_3d_rope()
|
| 104 |
+
|
| 105 |
+
def _build_blocks(
|
| 106 |
+
self,
|
| 107 |
+
block_fn,
|
| 108 |
+
depth: int,
|
| 109 |
+
embed_dim: int,
|
| 110 |
+
num_heads: int,
|
| 111 |
+
mlp_ratio: float,
|
| 112 |
+
qkv_bias: bool,
|
| 113 |
+
proj_bias: bool,
|
| 114 |
+
ffn_bias: bool,
|
| 115 |
+
init_values: float,
|
| 116 |
+
qk_norm: bool,
|
| 117 |
+
):
|
| 118 |
+
"""Build frame and global blocks for streaming causal mode."""
|
| 119 |
+
block_params = dict(
|
| 120 |
+
dim=embed_dim,
|
| 121 |
+
num_heads=num_heads,
|
| 122 |
+
mlp_ratio=mlp_ratio,
|
| 123 |
+
qkv_bias=qkv_bias,
|
| 124 |
+
proj_bias=proj_bias,
|
| 125 |
+
ffn_bias=ffn_bias,
|
| 126 |
+
init_values=init_values,
|
| 127 |
+
qk_norm=qk_norm,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Frame blocks: Standard Block + RoPE
|
| 131 |
+
self.frame_blocks = nn.ModuleList([
|
| 132 |
+
block_fn(**block_params, rope=self.rope)
|
| 133 |
+
for _ in range(depth)
|
| 134 |
+
])
|
| 135 |
+
|
| 136 |
+
# Global blocks: FlashInferBlock (default) or SDPABlock (fallback)
|
| 137 |
+
GlobalBlockCls = SDPABlock if self.use_sdpa else FlashInferBlock
|
| 138 |
+
self.global_blocks = nn.ModuleList([
|
| 139 |
+
GlobalBlockCls(
|
| 140 |
+
**block_params,
|
| 141 |
+
rope=self.rope if not self.disable_global_rope else None,
|
| 142 |
+
kv_cache_sliding_window=self.kv_cache_sliding_window,
|
| 143 |
+
kv_cache_scale_frames=self.kv_cache_scale_frames,
|
| 144 |
+
kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
|
| 145 |
+
kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
|
| 146 |
+
kv_cache_camera_only=self.kv_cache_camera_only,
|
| 147 |
+
)
|
| 148 |
+
for _ in range(depth)
|
| 149 |
+
])
|
| 150 |
+
|
| 151 |
+
def _setup_special_tokens(self):
|
| 152 |
+
"""Setup camera, register, and scale tokens for causal mode."""
|
| 153 |
+
# Camera token
|
| 154 |
+
self.camera_token = nn.Parameter(
|
| 155 |
+
torch.randn(1, 2, 1, self.embed_dim)
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Register tokens
|
| 159 |
+
if self.num_register_tokens > 0:
|
| 160 |
+
self.register_token = nn.Parameter(
|
| 161 |
+
torch.randn(1, 2, self.num_register_tokens, self.embed_dim)
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# Scale token (causal mode specific)
|
| 165 |
+
self.scale_token = nn.Parameter(
|
| 166 |
+
torch.ones(1, 2, 1, self.embed_dim)
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# Initialize
|
| 170 |
+
nn.init.normal_(self.camera_token, std=1e-6)
|
| 171 |
+
if self.num_register_tokens > 0:
|
| 172 |
+
nn.init.normal_(self.register_token, std=1e-6)
|
| 173 |
+
nn.init.normal_(self.scale_token, std=1e-6)
|
| 174 |
+
|
| 175 |
+
# Token indexing (includes scale token)
|
| 176 |
+
self.patch_start_idx = 1 + self.num_register_tokens + 1 # camera + register + scale
|
| 177 |
+
self.num_special_tokens = 1 + self.num_register_tokens + 1
|
| 178 |
+
|
| 179 |
+
def _init_kv_cache(self):
|
| 180 |
+
"""Initialize KV cache for streaming inference."""
|
| 181 |
+
self.kv_cache_manager = None # FlashInfer (lazy-initialized)
|
| 182 |
+
self.kv_cache = {} # Dict-based cache for SDPA
|
| 183 |
+
self.total_frames_processed = 0
|
| 184 |
+
self._cached_pos3d = None
|
| 185 |
+
|
| 186 |
+
if self.use_sdpa:
|
| 187 |
+
# Dict-based KV cache for SDPA
|
| 188 |
+
if hasattr(self, 'depth'):
|
| 189 |
+
for i in range(self.depth):
|
| 190 |
+
self.kv_cache[f"k_{i}"] = None
|
| 191 |
+
self.kv_cache[f"v_{i}"] = None
|
| 192 |
+
self.kv_cache[f"k_{i}_special"] = None
|
| 193 |
+
self.kv_cache[f"v_{i}_special"] = None
|
| 194 |
+
logger.info(f"SDPA KV cache initialized with {self.depth} blocks")
|
| 195 |
+
else:
|
| 196 |
+
logger.info("FlashInfer KV cache will be lazily initialized on first forward")
|
| 197 |
+
|
| 198 |
+
def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None):
|
| 199 |
+
"""Lazily initialize FlashInferKVCacheManager on first use.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
device: Device for cache tensors.
|
| 203 |
+
dtype: Data type for cache tensors.
|
| 204 |
+
tokens_per_frame: Actual number of tokens per frame (patches + specials).
|
| 205 |
+
If None, falls back to assuming square images of self.img_size.
|
| 206 |
+
"""
|
| 207 |
+
if self.kv_cache_manager is None:
|
| 208 |
+
from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager
|
| 209 |
+
num_heads = self.embed_dim // 64 # head_dim = 64 for ViT-L
|
| 210 |
+
head_dim = 64
|
| 211 |
+
if tokens_per_frame is None:
|
| 212 |
+
tokens_per_frame = (self.img_size // self.patch_size) ** 2 + self.num_special_tokens
|
| 213 |
+
# max_num_frames: scale + window + headroom
|
| 214 |
+
max_num_frames = self.kv_cache_scale_frames + self.kv_cache_sliding_window + 16
|
| 215 |
+
self.kv_cache_manager = FlashInferKVCacheManager(
|
| 216 |
+
num_blocks=self.depth,
|
| 217 |
+
max_num_frames=max_num_frames,
|
| 218 |
+
tokens_per_frame=tokens_per_frame,
|
| 219 |
+
num_heads=num_heads,
|
| 220 |
+
head_dim=head_dim,
|
| 221 |
+
dtype=dtype,
|
| 222 |
+
device=device,
|
| 223 |
+
num_special_tokens=self.num_special_tokens,
|
| 224 |
+
scale_frames=self.kv_cache_scale_frames,
|
| 225 |
+
sliding_window=self.kv_cache_sliding_window,
|
| 226 |
+
max_total_frames=self.max_frame_num + 100,
|
| 227 |
+
force_fp32=getattr(self, 'kv_cache_force_fp32', False),
|
| 228 |
+
fa3=getattr(self, 'kv_cache_fa3', False),
|
| 229 |
+
)
|
| 230 |
+
logger.info(
|
| 231 |
+
f"FlashInfer KV cache manager initialized: {self.depth} blocks, "
|
| 232 |
+
f"max_frames={max_num_frames}, tokens_per_frame={tokens_per_frame}"
|
| 233 |
+
)
|
| 234 |
+
return self.kv_cache_manager
|
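For the default 518x518 inputs and KV-cache settings, the fallback values above evaluate to (illustrative arithmetic):

```python
# tokens_per_frame = (518 // 14) ** 2 + 6 = 1369 + 6 = 1375
# max_num_frames   = kv_cache_scale_frames + kv_cache_sliding_window + 16 = 8 + 64 + 16 = 88
# num_heads        = 1024 // 64 = 16, head_dim = 64
```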
| 235 |
+
|
| 236 |
+
def clean_kv_cache(self):
|
| 237 |
+
"""Clean KV cache (call this when starting a new sequence)."""
|
| 238 |
+
if self.kv_cache_manager is not None:
|
| 239 |
+
self.kv_cache_manager.reset()
|
| 240 |
+
if self.kv_cache:
|
| 241 |
+
for key in list(self.kv_cache.keys()):
|
| 242 |
+
if key == "_skip_append":
|
| 243 |
+
self.kv_cache[key] = False
|
| 244 |
+
else:
|
| 245 |
+
self.kv_cache[key] = None
|
| 246 |
+
self.total_frames_processed = 0
|
| 247 |
+
self._cached_pos3d = None
|
| 248 |
+
logger.info("KV cache cleaned")
|
| 249 |
+
|
| 250 |
+
def _init_3d_rope(self):
|
| 251 |
+
"""Initialize 3D RoPE for streaming inference."""
|
| 252 |
+
if not self.enable_3d_rope:
|
| 253 |
+
self.rope3d = None
|
| 254 |
+
return
|
| 255 |
+
|
| 256 |
+
num_heads = 16
|
| 257 |
+
head_dim = self.embed_dim // num_heads
|
| 258 |
+
|
| 259 |
+
self.rope3d = WanRotaryPosEmbed(
|
| 260 |
+
attention_head_dim=head_dim,
|
| 261 |
+
patch_size=(1, self.patch_size, self.patch_size),
|
| 262 |
+
max_seq_len=self.max_frame_num,
|
| 263 |
+
)
|
| 264 |
+
logger.info(f"3D RoPE initialized for max {self.max_frame_num} frames, head_dim={head_dim}")
|
| 265 |
+
|
| 266 |
+
def _get_3d_positions_streaming(self, num_frames, H, W, device, f_start, f_end):
|
| 267 |
+
"""
|
| 268 |
+
Generate 3D RoPE positions for streaming mode with correct global frame indices.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
num_frames: Number of frames in current batch
|
| 272 |
+
H, W: Image height and width
|
| 273 |
+
device: Device to create positions on
|
| 274 |
+
f_start: Global start frame index
|
| 275 |
+
f_end: Global end frame index
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
pos3d: [1, 1, num_frames * P, head_dim//2] complex tensor
|
| 279 |
+
"""
|
| 280 |
+
if self.rope3d is None:
|
| 281 |
+
return None
|
| 282 |
+
|
| 283 |
+
pph = H // self.patch_size
|
| 284 |
+
ppw = W // self.patch_size
|
| 285 |
+
|
| 286 |
+
pos3d = self.rope3d(
|
| 287 |
+
ppf=num_frames,
|
| 288 |
+
pph=pph,
|
| 289 |
+
ppw=ppw,
|
| 290 |
+
patch_start_idx=self.num_special_tokens,
|
| 291 |
+
device=device,
|
| 292 |
+
f_start=f_start,
|
| 293 |
+
f_end=f_end
|
| 294 |
+
)
|
| 295 |
+
return pos3d
|
| 296 |
+
|
| 297 |
+
def _prepare_special_tokens(
|
| 298 |
+
self,
|
| 299 |
+
B: int,
|
| 300 |
+
S_local: int,
|
| 301 |
+
S_global: int,
|
| 302 |
+
C: int,
|
| 303 |
+
num_frame_for_scale: Optional[int] = None,
|
| 304 |
+
) -> torch.Tensor:
|
| 305 |
+
"""
|
| 306 |
+
Prepare camera, register, and scale tokens.
|
| 307 |
+
|
| 308 |
+
Args:
|
| 309 |
+
B: Batch size
|
| 310 |
+
S_local: Local sequence length
|
| 311 |
+
S_global: Global sequence length
|
| 312 |
+
C: Embedding dimension
|
| 313 |
+
num_frame_for_scale: Number of frames for scale estimation
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
Special tokens [B*S_global, N_special, C]
|
| 317 |
+
"""
|
| 318 |
+
# Get effective num_frame_for_scale
|
| 319 |
+
scale_frames = self.num_frame_for_scale if num_frame_for_scale is None else num_frame_for_scale
|
| 320 |
+
|
| 321 |
+
# Check cache state for both backends
|
| 322 |
+
has_flashinfer_cache = self.kv_cache_manager is not None and self.kv_cache_manager.num_frames > 0
|
| 323 |
+
has_sdpa_cache = self.kv_cache is not None and self.kv_cache.get("k_0") is not None
|
| 324 |
+
|
| 325 |
+
# The streaming aggregator always runs causal inference; the cache state above only determines how many frames precede the current chunk (S_true below)
|
| 326 |
+
causal_inference = True
|
| 327 |
+
|
| 328 |
+
if causal_inference and has_flashinfer_cache:
|
| 329 |
+
S_cached = self.kv_cache_manager.num_frames
|
| 330 |
+
S_true = S_cached + S_global
|
| 331 |
+
elif causal_inference and has_sdpa_cache:
|
| 332 |
+
_, _, S_cached, _, _ = self.kv_cache["k_0"].shape
|
| 333 |
+
S_true = S_cached + S_global
|
| 334 |
+
else:
|
| 335 |
+
S_true = S_global
|
| 336 |
+
|
| 337 |
+
# Expand tokens based on mode
|
| 338 |
+
if causal_inference and S_true > S_global:
|
| 339 |
+
# Streaming mode: expand with S_true, then slice to get current frames
|
| 340 |
+
effective_scale_frames = min(scale_frames, S_true)
|
| 341 |
+
|
| 342 |
+
camera_token_full = slice_expand_and_flatten(self.camera_token, B, S_true)
|
| 343 |
+
camera_token = camera_token_full[-S_global:, :, :]
|
| 344 |
+
|
| 345 |
+
register_token_full = slice_expand_and_flatten(self.register_token, B, S_true)
|
| 346 |
+
register_token = register_token_full[-S_global:, :, :]
|
| 347 |
+
scale_token_full = slice_expand_and_flatten(
|
| 348 |
+
self.scale_token, B, S_true, first_num_frame=effective_scale_frames
|
| 349 |
+
)
|
| 350 |
+
scale_token = scale_token_full[-S_global:, :, :]
|
| 351 |
+
else:
|
| 352 |
+
# Batch mode or first inference: expand directly
|
| 353 |
+
effective_scale_frames = min(scale_frames, S_global)
|
| 354 |
+
|
| 355 |
+
camera_token = slice_expand_and_flatten(self.camera_token, B, S_global)
|
| 356 |
+
register_token = slice_expand_and_flatten(self.register_token, B, S_global)
|
| 357 |
+
scale_token = slice_expand_and_flatten(
|
| 358 |
+
self.scale_token, B, S_global, first_num_frame=effective_scale_frames
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
special_tokens = torch.cat([camera_token, register_token, scale_token], dim=1)
|
| 362 |
+
|
| 363 |
+
# Verify shape
|
| 364 |
+
expected_shape = (B * S_global, self.num_special_tokens, C)
|
| 365 |
+
assert special_tokens.shape == expected_shape, \
|
| 366 |
+
f"Expected {expected_shape}, got {special_tokens.shape}"
|
| 367 |
+
|
| 368 |
+
return special_tokens
|
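A small worked example of the streaming branch above (the frame counts are assumptions):

```python
# Suppose the KV cache already holds 10 frames (S_cached = 10) and the current chunk
# brings S_global = 2 new frames, with scale_frames = 1:
#   S_true = 10 + 2 = 12
#   effective_scale_frames = min(1, 12) = 1
# Tokens are expanded as if for all 12 frames and then sliced with [-S_global:], so only
# the two new frames receive special tokens; the "first token" scale variant is used only
# for frame 0, which is already cached.
```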
| 369 |
+
|
| 370 |
+
def _process_global_attention(
|
| 371 |
+
self,
|
| 372 |
+
tokens: torch.Tensor,
|
| 373 |
+
B: int,
|
| 374 |
+
S_local: int,
|
| 375 |
+
S_global: int,
|
| 376 |
+
P: int,
|
| 377 |
+
C: int,
|
| 378 |
+
global_idx: int,
|
| 379 |
+
pos: Optional[torch.Tensor] = None,
|
| 380 |
+
# Mode-specific parameters
|
| 381 |
+
num_frame_for_scale: Optional[int] = None,
|
| 382 |
+
sliding_window_size: Optional[int] = None,
|
| 383 |
+
num_frame_per_block: int = 1,
|
| 384 |
+
**kwargs,
|
| 385 |
+
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
|
| 386 |
+
"""
|
| 387 |
+
Process causal global attention via FlashInfer streaming path.
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
tokens: Input tokens
|
| 391 |
+
B: Batch size
|
| 392 |
+
S_local: Local sequence length
|
| 393 |
+
S_global: Global sequence length
|
| 394 |
+
P: Tokens per frame
|
| 395 |
+
C: Embedding dimension
|
| 396 |
+
global_idx: Current global block index
|
| 397 |
+
pos: Position embeddings
|
| 398 |
+
num_frame_for_scale: Number of frames for scale estimation
|
| 399 |
+
sliding_window_size: Sliding window size in blocks
|
| 400 |
+
num_frame_per_block: Number of frames per processing block
|
| 401 |
+
|
| 402 |
+
Returns:
|
| 403 |
+
(tokens, global_idx, intermediates)
|
| 404 |
+
"""
|
| 405 |
+
# Extract image dimensions from kwargs for 3D RoPE
|
| 406 |
+
image_height = kwargs.get('image_height', self.img_size)
|
| 407 |
+
image_width = kwargs.get('image_width', self.img_size)
|
| 408 |
+
|
| 409 |
+
return self._process_causal_stream(
|
| 410 |
+
tokens, B, S_local, S_global, P, C, global_idx, pos,
|
| 411 |
+
num_frame_per_block, sliding_window_size, num_frame_for_scale,
|
| 412 |
+
image_height=image_height, image_width=image_width
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
def _process_causal_stream(
|
| 416 |
+
self,
|
| 417 |
+
tokens: torch.Tensor,
|
| 418 |
+
B: int,
|
| 419 |
+
S_local: int,
|
| 420 |
+
S_global: int,
|
| 421 |
+
P: int,
|
| 422 |
+
C: int,
|
| 423 |
+
global_idx: int,
|
| 424 |
+
pos: Optional[torch.Tensor] = None,
|
| 425 |
+
num_frame_per_block: int = 1,
|
| 426 |
+
sliding_window_size: Optional[int] = None,
|
| 427 |
+
num_frame_for_scale: Optional[int] = None,
|
| 428 |
+
image_height: Optional[int] = None,
|
| 429 |
+
image_width: Optional[int] = None,
|
| 430 |
+
):
|
| 431 |
+
"""
|
| 432 |
+
Causal attention for streaming inference using FlashInfer KV cache.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
tokens: Input tokens [B*S_local, P, C]
|
| 436 |
+
B: Batch size
|
| 437 |
+
S_local: Local sequence length
|
| 438 |
+
S_global: Global sequence length
|
| 439 |
+
P: Number of patches per frame (includes special tokens)
|
| 440 |
+
C: Channel dimension
|
| 441 |
+
global_idx: Starting block index
|
| 442 |
+
pos: Position embeddings [B*S_global, P, 2]
|
| 443 |
+
num_frame_per_block: Number of frames per block
|
| 444 |
+
sliding_window_size: Sliding window size in blocks
|
| 445 |
+
num_frame_for_scale: Number of scale frames
|
| 446 |
+
image_height: Image height for 3D RoPE calculation
|
| 447 |
+
image_width: Image width for 3D RoPE calculation
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
(tokens, global_idx, intermediates): Updated tokens, next block index, intermediate outputs
|
| 451 |
+
"""
|
| 452 |
+
# Get effective parameters
|
| 453 |
+
scale_frames = num_frame_for_scale if num_frame_for_scale is not None else self.num_frame_for_scale
|
| 454 |
+
|
| 455 |
+
# Reshape tokens: [B*S_local, P, C] -> [B, S_local*P, C]
|
| 456 |
+
if tokens.shape != (B, S_local * P, C):
|
| 457 |
+
tokens = tokens.view(B, S_local, P, C).view(B, S_local * P, C)
|
| 458 |
+
|
| 459 |
+
# Calculate number of frames for block mask
|
| 460 |
+
num_frames = S_global
|
| 461 |
+
num_patches = P - self.num_special_tokens
|
| 462 |
+
|
| 463 |
+
# Check if this is the first block group
|
| 464 |
+
is_first_block_group = (global_idx < self.aa_block_size)
|
| 465 |
+
|
| 466 |
+
if self.enable_3d_rope and self.rope3d is not None:
|
| 467 |
+
if is_first_block_group:
|
| 468 |
+
f_start = self.total_frames_processed
|
| 469 |
+
f_end = self.total_frames_processed + S_global
|
| 470 |
+
|
| 471 |
+
H = image_height if image_height is not None else self.img_size
|
| 472 |
+
W = image_width if image_width is not None else self.img_size
|
| 473 |
+
pos3d = self._get_3d_positions_streaming(
|
| 474 |
+
S_global, H, W, tokens.device, f_start, f_end
|
| 475 |
+
)
|
| 476 |
+
self._cached_pos3d = pos3d
|
| 477 |
+
else:
|
| 478 |
+
pos3d = self._cached_pos3d
|
| 479 |
+
pos = pos3d
|
| 480 |
+
else:
|
| 481 |
+
# Reshape pos: [B*S_global, P, 2] -> [B, S_global*P, 2]
|
| 482 |
+
if pos is not None and pos.shape != (B, S_global * P, 2):
|
| 483 |
+
pos = pos.view(B, S_global, P, 2).view(B, S_global * P, 2)
|
| 484 |
+
|
| 485 |
+
intermediates = []
|
| 486 |
+
|
| 487 |
+
# Process blocks with KV cache
|
| 488 |
+
for _ in range(self.aa_block_size):
|
| 489 |
+
num_patches = P - self.num_special_tokens
|
| 490 |
+
if self.use_sdpa:
|
| 491 |
+
# SDPA: dict-based KV cache
|
| 492 |
+
tokens = self.global_blocks[global_idx](
|
| 493 |
+
tokens,
|
| 494 |
+
pos=pos,
|
| 495 |
+
enable_ulysses_cp=False,
|
| 496 |
+
num_patches=num_patches,
|
| 497 |
+
num_special=self.num_special_tokens,
|
| 498 |
+
num_frames=num_frames,
|
| 499 |
+
enable_3d_rope=self.enable_3d_rope,
|
| 500 |
+
kv_cache=self.kv_cache,
|
| 501 |
+
global_idx=global_idx,
|
| 502 |
+
num_frame_per_block=num_frame_per_block,
|
| 503 |
+
num_frame_for_scale=scale_frames,
|
| 504 |
+
num_register_tokens=self.num_register_tokens,
|
| 505 |
+
)
|
| 506 |
+
else:
|
| 507 |
+
# FlashInfer: paged KV cache manager
|
| 508 |
+
manager = self._get_flashinfer_manager(tokens.device, tokens.dtype, tokens_per_frame=P)
|
| 509 |
+
tokens = self.global_blocks[global_idx](
|
| 510 |
+
tokens,
|
| 511 |
+
pos=pos,
|
| 512 |
+
enable_ulysses_cp=False,
|
| 513 |
+
num_patches=num_patches,
|
| 514 |
+
num_special=self.num_special_tokens,
|
| 515 |
+
num_frames=num_frames,
|
| 516 |
+
enable_3d_rope=self.enable_3d_rope,
|
| 517 |
+
kv_cache=manager,
|
| 518 |
+
global_idx=global_idx,
|
| 519 |
+
num_frame_per_block=num_frame_per_block,
|
| 520 |
+
num_frame_for_scale=scale_frames,
|
| 521 |
+
num_register_tokens=self.num_register_tokens,
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
global_idx += 1
|
| 525 |
+
intermediates.append(tokens.view(B, S_local, P, C))
|
| 526 |
+
|
| 527 |
+
# Update total frames processed counter only on the first block group
|
| 528 |
+
if is_first_block_group and not (isinstance(self.kv_cache, dict) and self.kv_cache.get("_skip_append", False)):
|
| 529 |
+
self.total_frames_processed += S_global
|
| 530 |
+
|
| 531 |
+
return tokens, global_idx, intermediates
|
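A hedged streaming-inference sketch tying the pieces above together; the chunk size, device, and SDPA backend choice are assumptions:

```python
import torch

agg = AggregatorStream(use_sdpa=True, num_frame_for_scale=1).cuda().eval()

video = torch.rand(1, 16, 3, 518, 518, device="cuda")   # [B, S, 3, H, W]

agg.clean_kv_cache()                        # start of a new sequence
with torch.no_grad():
    for chunk in video.split(4, dim=1):     # feed 4 frames at a time
        output_list, patch_start_idx = agg(chunk)
        # output_list[-1]: [B, 4, P, 2C] features for the current chunk only;
        # earlier frames are attended to through the SDPA/FlashInfer KV cache.
```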
lingbot_map/heads/__init__.py
ADDED
|
File without changes
|
lingbot_map/heads/camera_head.py
ADDED
|
@@ -0,0 +1,458 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from lingbot_map.layers import Mlp
|
| 15 |
+
from lingbot_map.layers.block import Block
|
| 16 |
+
from lingbot_map.layers.block import CameraBlock
|
| 17 |
+
from lingbot_map.heads.head_act import activate_pose
|
| 18 |
+
from lingbot_map.layers.rope import WanRotaryPosEmbed
|
| 19 |
+
from functools import partial
|
| 20 |
+
from torch.utils.checkpoint import checkpoint
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CameraHead(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
CameraHead predicts camera parameters from token representations using iterative refinement.
|
| 26 |
+
|
| 27 |
+
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
dim_in: int = 2048,
|
| 33 |
+
trunk_depth: int = 4,
|
| 34 |
+
pose_encoding_type: str = "absT_quaR_FoV",
|
| 35 |
+
num_heads: int = 16,
|
| 36 |
+
mlp_ratio: int = 4,
|
| 37 |
+
init_values: float = 0.01,
|
| 38 |
+
trans_act: str = "linear",
|
| 39 |
+
quat_act: str = "linear",
|
| 40 |
+
fl_act: str = "relu",  # Field-of-view activation: ensures FOV values are positive.
|
| 41 |
+
enable_ulysses_cp=False,
|
| 42 |
+
):
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
| 46 |
+
self.target_dim = 9
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
|
| 49 |
+
|
| 50 |
+
self.trans_act = trans_act
|
| 51 |
+
self.quat_act = quat_act
|
| 52 |
+
self.fl_act = fl_act
|
| 53 |
+
self.trunk_depth = trunk_depth
|
| 54 |
+
|
| 55 |
+
self.enable_ulysses_cp = enable_ulysses_cp
|
| 56 |
+
|
| 57 |
+
# Build the trunk using a sequence of transformer blocks.
|
| 58 |
+
self.trunk = nn.Sequential(
|
| 59 |
+
*[
|
| 60 |
+
Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
|
| 61 |
+
for _ in range(trunk_depth)
|
| 62 |
+
]
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Normalizations for camera token and trunk output.
|
| 66 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
| 67 |
+
self.trunk_norm = nn.LayerNorm(dim_in)
|
| 68 |
+
|
| 69 |
+
# Learnable empty camera pose token.
|
| 70 |
+
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
| 71 |
+
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
| 72 |
+
|
| 73 |
+
# Module for producing modulation parameters: shift, scale, and a gate.
|
| 74 |
+
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
|
| 75 |
+
|
| 76 |
+
# Adaptive layer normalization without affine parameters.
|
| 77 |
+
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
| 78 |
+
self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
|
| 79 |
+
|
| 80 |
+
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list:
|
| 81 |
+
"""
|
| 82 |
+
Forward pass to predict camera parameters.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
aggregated_tokens_list (list): List of token tensors from the network;
|
| 86 |
+
the last tensor is used for prediction.
|
| 87 |
+
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
list: A list of predicted camera encodings (post-activation) from each iteration.
|
| 91 |
+
"""
|
| 92 |
+
# Use tokens from the last block for camera prediction.
|
| 93 |
+
tokens = aggregated_tokens_list[-1]
|
| 94 |
+
|
| 95 |
+
# Extract the camera tokens
|
| 96 |
+
pose_tokens = tokens[:, :, 0]
|
| 97 |
+
pose_tokens = self.token_norm(pose_tokens)
|
| 98 |
+
|
| 99 |
+
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
|
| 100 |
+
return pred_pose_enc_list
|
| 101 |
+
|
| 102 |
+
def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
|
| 103 |
+
"""
|
| 104 |
+
Iteratively refine camera pose predictions.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
|
| 108 |
+
num_iterations (int): Number of refinement iterations.
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
list: List of activated camera encodings from each iteration.
|
| 112 |
+
"""
|
| 113 |
+
B, S, C = pose_tokens.shape # S is expected to be 1.
|
| 114 |
+
pred_pose_enc = None
|
| 115 |
+
pred_pose_enc_list = []
|
| 116 |
+
|
| 117 |
+
for _ in range(num_iterations):
|
| 118 |
+
# Use a learned empty pose for the first iteration.
|
| 119 |
+
if pred_pose_enc is None:
|
| 120 |
+
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
|
| 121 |
+
else:
|
| 122 |
+
# Detach the previous prediction to avoid backprop through time.
|
| 123 |
+
pred_pose_enc = pred_pose_enc.detach()
|
| 124 |
+
module_input = self.embed_pose(pred_pose_enc)
|
| 125 |
+
|
| 126 |
+
# Generate modulation parameters and split them into shift, scale, and gate components.
|
| 127 |
+
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
|
| 128 |
+
|
| 129 |
+
# Adaptive layer normalization and modulation.
|
| 130 |
+
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
|
| 131 |
+
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
|
| 132 |
+
|
| 133 |
+
# Apply trunk blocks with enable_ulysses_cp
|
| 134 |
+
for block in self.trunk:
|
| 135 |
+
pose_tokens_modulated = block(pose_tokens_modulated, enable_ulysses_cp=self.enable_ulysses_cp)
|
| 136 |
+
# Compute the delta update for the pose encoding.
|
| 137 |
+
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
|
| 138 |
+
|
| 139 |
+
if pred_pose_enc is None:
|
| 140 |
+
pred_pose_enc = pred_pose_enc_delta
|
| 141 |
+
else:
|
| 142 |
+
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
|
| 143 |
+
|
| 144 |
+
# Apply final activation functions for translation, quaternion, and field-of-view.
|
| 145 |
+
activated_pose = activate_pose(
|
| 146 |
+
pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
|
| 147 |
+
)
|
| 148 |
+
pred_pose_enc_list.append(activated_pose)
|
| 149 |
+
|
| 150 |
+
return pred_pose_enc_list
|
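A minimal usage sketch for `CameraHead`; the token shapes are assumptions consistent with the docstrings above (camera token at index 0, S = 1):

```python
import torch

head = CameraHead(dim_in=2048, trunk_depth=4)

# Aggregated tokens from the network: [B, S, P, 2C] with the camera token at index 0.
aggregated_tokens_list = [torch.randn(2, 1, 10, 2048)]

pred_list = head(aggregated_tokens_list, num_iterations=4)
# One "absT_quaR_FoV" encoding [B, S, 9] per refinement iteration.
assert len(pred_list) == 4 and pred_list[-1].shape == (2, 1, 9)
```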
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 154 |
+
"""
|
| 155 |
+
Modulate the input tensor using scaling and shifting parameters.
|
| 156 |
+
"""
|
| 157 |
+
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 158 |
+
return x * (1 + scale) + shift
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class CameraCausalHead(nn.Module):
|
| 162 |
+
"""
|
| 163 |
+
CameraHead predicts camera parameters from token representations using iterative refinement.
|
| 164 |
+
|
| 165 |
+
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
def __init__(
|
| 169 |
+
self,
|
| 170 |
+
dim_in: int = 2048,
|
| 171 |
+
trunk_depth: int = 4,
|
| 172 |
+
pose_encoding_type: str = "absT_quaR_FoV",
|
| 173 |
+
num_heads: int = 16,
|
| 174 |
+
mlp_ratio: int = 4,
|
| 175 |
+
init_values: float = 0.01,
|
| 176 |
+
trans_act: str = "linear",
|
| 177 |
+
quat_act: str = "linear",
|
| 178 |
+
fl_act: str = "relu",  # Field-of-view activation: ensures FOV values are positive.
|
| 179 |
+
num_iterations = 4,
|
| 180 |
+
elementwise_attn_output_gate: bool = False,
|
| 181 |
+
sliding_window_size: int = -1,
|
| 182 |
+
attend_to_scale_frames: bool = False,
|
| 183 |
+
num_random_frames: int = 0,
|
| 184 |
+
enable_ulysses_cp: bool = False,
|
| 185 |
+
attn_class: str = "flexflashattn_varlen",
|
| 186 |
+
# KV cache parameters
|
| 187 |
+
kv_cache_sliding_window: int = 64,
|
| 188 |
+
kv_cache_scale_frames: int = 8,
|
| 189 |
+
kv_cache_cross_frame_special: bool = True,
|
| 190 |
+
kv_cache_include_scale_frames: bool = True,
|
| 191 |
+
kv_cache_camera_only: bool = False,
|
| 192 |
+
# 3D RoPE parameters
|
| 193 |
+
enable_3d_rope: bool = False,
|
| 194 |
+
max_frame_num: int = 1024,
|
| 195 |
+
rope_theta: float = 10000.0,
|
| 196 |
+
):
|
| 197 |
+
super().__init__()
|
| 198 |
+
|
| 199 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
| 200 |
+
self.target_dim = 9
|
| 201 |
+
else:
|
| 202 |
+
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
|
| 203 |
+
|
| 204 |
+
self.trans_act = trans_act
|
| 205 |
+
self.quat_act = quat_act
|
| 206 |
+
self.fl_act = fl_act
|
| 207 |
+
self.trunk_depth = trunk_depth
|
| 208 |
+
self.sliding_window_size = sliding_window_size
|
| 209 |
+
self.enable_ulysses_cp = enable_ulysses_cp
|
| 210 |
+
self.num_heads = num_heads
|
| 211 |
+
|
| 212 |
+
# 3D RoPE for temporal position encoding
|
| 213 |
+
self.enable_3d_rope = enable_3d_rope
|
| 214 |
+
if enable_3d_rope:
|
| 215 |
+
head_dim = dim_in // num_heads
|
| 216 |
+
# For camera head: each frame has 1 token (frame_seqlen=1)
|
| 217 |
+
# patch_size is (max_frames, h=1, w=1) for 3D RoPE
|
| 218 |
+
# fhw_dim=None enables auto-calculation: h_dim = w_dim = 2*(head_dim//6), t_dim = remainder
|
| 219 |
+
self.rope3d = WanRotaryPosEmbed(
|
| 220 |
+
attention_head_dim=head_dim,
|
| 221 |
+
patch_size=(max_frame_num, 1, 1),
|
| 222 |
+
theta=rope_theta,
|
| 223 |
+
fhw_dim=[40, 44, 44],  # Explicit (f, h, w) dimension allocation
|
| 224 |
+
)
|
| 225 |
+
else:
|
| 226 |
+
self.rope3d = None
|
| 227 |
+
|
| 228 |
+
# Build the trunk using a sequence of transformer blocks.
|
| 229 |
+
self.trunk = nn.Sequential(
|
| 230 |
+
*[
|
| 231 |
+
CameraBlock(
    dim=dim_in,
    num_heads=num_heads,
    mlp_ratio=mlp_ratio,
    init_values=init_values,
    elementwise_attn_output_gate=elementwise_attn_output_gate,
    sliding_window_size=sliding_window_size,
    attend_to_scale_frames=attend_to_scale_frames,
    num_random_frames=num_random_frames,
    kv_cache_sliding_window=kv_cache_sliding_window,
    kv_cache_scale_frames=kv_cache_scale_frames,
    kv_cache_cross_frame_special=kv_cache_cross_frame_special,
    kv_cache_include_scale_frames=kv_cache_include_scale_frames,
    kv_cache_camera_only=kv_cache_camera_only,
)
|
| 232 |
+
for _ in range(trunk_depth)
|
| 233 |
+
]
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# Normalizations for camera token and trunk output.
|
| 237 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
| 238 |
+
self.trunk_norm = nn.LayerNorm(dim_in)
|
| 239 |
+
|
| 240 |
+
# Learnable empty camera pose token.
|
| 241 |
+
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
| 242 |
+
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
| 243 |
+
|
| 244 |
+
# Module for producing modulation parameters: shift, scale, and a gate.
|
| 245 |
+
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
|
| 246 |
+
|
| 247 |
+
# Adaptive layer normalization without affine parameters.
|
| 248 |
+
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
| 249 |
+
self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
|
| 250 |
+
|
| 251 |
+
self.num_iterations = num_iterations
|
| 252 |
+
|
| 253 |
+
self.kv_cache = None
|
| 254 |
+
self.pos_cache = None
|
| 255 |
+
self.frame_idx = 0
|
| 256 |
+
self.cp_size = 1
|
| 257 |
+
|
| 258 |
+
## Get the context-parallel (CP) world size when Ulysses CP is enabled
|
| 259 |
+
if self.enable_ulysses_cp:
|
| 260 |
+
from torchtitan.distributed.sequence_parallel import (
|
| 261 |
+
init_sequence_parallel,
|
| 262 |
+
get_ulysses_sequence_parallel_rank,
|
| 263 |
+
get_ulysses_sequence_parallel_world_size,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
self.cp_size = get_ulysses_sequence_parallel_world_size()
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def clean_kv_cache(self):
|
| 271 |
+
del self.kv_cache
|
| 272 |
+
self.kv_cache = None
|
| 273 |
+
self.frame_idx = 0
|
| 274 |
+
|
| 275 |
+
def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = None, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
|
| 276 |
+
"""
|
| 277 |
+
Forward pass to predict camera parameters.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
aggregated_tokens_list (list): List of token tensors from the network;
|
| 281 |
+
the last tensor is used for prediction.
|
| 282 |
+
num_iterations (int, optional): Number of iterative refinement steps.
|
| 283 |
+
If None, falls back to self.num_iterations (set at construction).
|
| 284 |
+
sliding_window_size (int, optional): Override the sliding window size for this forward pass.
|
| 285 |
+
If None, use the default self.sliding_window_size.
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
list: A list of predicted camera encodings (post-activation) from each iteration.
|
| 289 |
+
"""
|
| 290 |
+
if num_iterations is None:
|
| 291 |
+
num_iterations = self.num_iterations
|
| 292 |
+
|
| 293 |
+
# Use passed sliding_window_size if provided, otherwise use default
|
| 294 |
+
effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
|
| 295 |
+
|
| 296 |
+
# Use tokens from the last block for camera prediction.
|
| 297 |
+
tokens = aggregated_tokens_list[-1]
|
| 298 |
+
|
| 299 |
+
# Extract the camera tokens
|
| 300 |
+
pose_tokens = tokens[:, :, 0]
|
| 301 |
+
pose_tokens = self.token_norm(pose_tokens)
|
| 302 |
+
|
| 303 |
+
if causal_inference:
|
| 304 |
+
if self.kv_cache is None:
|
| 305 |
+
self.kv_cache = []
|
| 306 |
+
for i in range(num_iterations):
|
| 307 |
+
self.kv_cache.append({"_skip_append": False})
|
| 308 |
+
for j in range(self.trunk_depth):
|
| 309 |
+
self.kv_cache[i][f"k_{j}"] = None
|
| 310 |
+
self.kv_cache[i][f"v_{j}"] = None
|
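# Cache layout: one dict per refinement iteration; each dict keeps per-trunk-layer
# key/value tensors under "k_{layer}" / "v_{layer}" (initialised to None here and,
# presumably, filled by the CameraBlock attention layers during streaming inference).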
| 311 |
+
|
| 312 |
+
pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size)
|
| 313 |
+
return pred_pose_enc_list
|
| 314 |
+
|
| 315 |
+
def trunk_fn(self, pose_tokens: torch.Tensor, mask=None, num_iterations: int=4, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None) -> list:
|
| 316 |
+
"""
|
| 317 |
+
Iteratively refine camera pose predictions.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
|
| 321 |
+
num_iterations (int): Number of refinement iterations.
|
| 322 |
+
sliding_window_size (int, optional): Sliding window size to use.
|
| 323 |
+
|
| 324 |
+
Returns:
|
| 325 |
+
list: List of activated camera encodings from each iteration.
|
| 326 |
+
"""
|
| 327 |
+
B, S, C = pose_tokens.shape
|
| 328 |
+
pred_pose_enc = None
|
| 329 |
+
pred_pose_enc_list = []
|
| 330 |
+
|
| 331 |
+
# Check if this is the first call (processing scale frames)
|
| 332 |
+
# Scale frames should use batch mode attention for numerical consistency
|
| 333 |
+
is_scale_frames = (self.kv_cache is not None and self.frame_idx == 0)
|
| 334 |
+
|
| 335 |
+
# Generate 3D RoPE positions if enabled
|
| 336 |
+
pos3d = None
|
| 337 |
+
if self.rope3d is not None:
|
| 338 |
+
# For camera tokens: shape [B, S, C] where each frame has 1 token
|
| 339 |
+
# Position for frame f is (f, 0, 0) - temporal varies, spatial fixed
|
| 340 |
+
|
| 341 |
+
# In streaming mode with KV cache, use frame_idx to track global position
|
| 342 |
+
# Otherwise, generate positions from 0
|
| 343 |
+
if self.kv_cache is not None:
|
| 344 |
+
f_start = self.frame_idx
|
| 345 |
+
f_end = self.frame_idx + S
|
| 346 |
+
else:
|
| 347 |
+
f_start = 0
|
| 348 |
+
f_end = None # Will use ppf as frame count
|
| 349 |
+
|
| 350 |
+
pos3d = self.rope3d(
|
| 351 |
+
ppf=S * self.cp_size, # Total frames (with CP)
|
| 352 |
+
pph=1, # height = 1 (camera token)
|
| 353 |
+
ppw=1, # width = 1 (camera token)
|
| 354 |
+
patch_start_idx=0, # No special tokens before
|
| 355 |
+
device=pose_tokens.device,
|
| 356 |
+
f_start=f_start,
|
| 357 |
+
f_end=f_end,
|
| 358 |
+
) # Returns [1, 1, S*cp_size, head_dim//2] complex
|
| 359 |
+
|
| 360 |
+
for i in range(num_iterations):
|
| 361 |
+
# Use a learned empty pose for the first iteration.
|
| 362 |
+
if pred_pose_enc is None:
|
| 363 |
+
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
|
| 364 |
+
else:
|
| 365 |
+
# Detach the previous prediction to avoid backprop through time.
|
| 366 |
+
pred_pose_enc = pred_pose_enc.detach()
|
| 367 |
+
module_input = self.embed_pose(pred_pose_enc)
|
| 368 |
+
|
| 369 |
+
# Generate modulation parameters and split them into shift, scale, and gate components.
|
| 370 |
+
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
|
| 371 |
+
|
| 372 |
+
# Adaptive layer normalization and modulation.
|
| 373 |
+
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
|
| 374 |
+
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
|
| 375 |
+
|
| 376 |
+
for idx in range(self.trunk_depth):
|
| 377 |
+
pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, pos=pos3d, video_mask=mask, num_frames=S*self.cp_size, frame_seqlen=1, kv_cache=self.kv_cache[i] if self.kv_cache is not None else None, global_idx=idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=sliding_window_size, enable_ulysses_cp=self.enable_ulysses_cp, enable_3d_rope=self.enable_3d_rope, is_scale_frames=is_scale_frames)
|
| 378 |
+
# Compute the delta update for the pose encoding.
|
| 379 |
+
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
|
| 380 |
+
|
| 381 |
+
if pred_pose_enc is None:
|
| 382 |
+
pred_pose_enc = pred_pose_enc_delta
|
| 383 |
+
else:
|
| 384 |
+
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
|
| 385 |
+
|
| 386 |
+
# Apply final activation functions for translation, quaternion, and field-of-view.
|
| 387 |
+
activated_pose = activate_pose(
|
| 388 |
+
pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
|
| 389 |
+
)
|
| 390 |
+
pred_pose_enc_list.append(activated_pose)
|
| 391 |
+
|
| 392 |
+
# Update frame_idx for streaming mode (KV cache)
|
| 393 |
+
if self.kv_cache is not None:
|
| 394 |
+
self.frame_idx += S
|
| 395 |
+
|
| 396 |
+
return pred_pose_enc_list
|
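A minimal standalone sketch of the refinement pattern used in trunk_fn above: each iteration predicts a delta, detaches the running estimate so gradients do not flow through earlier iterations, and accumulates the result. `predict_delta` below is a hypothetical stand-in for the embed_pose -> trunk -> pose_branch pipeline and is not part of this module.

import torch

def refine(tokens: torch.Tensor, num_iterations: int = 4) -> list:
    # Hypothetical stand-in for embed_pose -> trunk -> pose_branch.
    predict_delta = torch.nn.Linear(tokens.shape[-1], 9)
    pred, outputs = None, []
    for _ in range(num_iterations):
        if pred is not None:
            pred = pred.detach()                  # no backprop through earlier iterations
        delta = predict_delta(tokens)             # per-iteration correction
        pred = delta if pred is None else pred + delta
        outputs.append(pred)
    return outputs

poses = refine(torch.randn(1, 8, 2048))           # one refined estimate per iteration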
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 400 |
+
"""
|
| 401 |
+
Modulate the input tensor using scaling and shifting parameters.
|
| 402 |
+
"""
|
| 403 |
+
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 404 |
+
return x * (1 + scale) + shift
|
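A tiny standalone sketch of the adaptive-LayerNorm modulation performed by modulate above, with toy shapes:

import torch

x = torch.randn(2, 4, 8)                           # normalized tokens
shift, scale = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
out = x * (1 + scale) + shift                      # same expression as modulate(x, shift, scale)
assert out.shape == x.shape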
| 405 |
+
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class CameraDecoder(nn.Module):
|
| 410 |
+
def __init__(
|
| 411 |
+
self,
|
| 412 |
+
in_dim,
|
| 413 |
+
out_dim,
|
| 414 |
+
dec_embed_dim=512,
|
| 415 |
+
depth=5,
|
| 416 |
+
dec_num_heads=8,
|
| 417 |
+
mlp_ratio=4,
|
| 418 |
+
rope=None,
|
| 419 |
+
need_project=True,
|
| 420 |
+
use_checkpoint=False,
|
| 421 |
+
):
|
| 422 |
+
super().__init__()
|
| 423 |
+
|
| 424 |
+
self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
|
| 425 |
+
self.use_checkpoint = use_checkpoint
|
| 426 |
+
|
| 427 |
+
self.blocks = nn.ModuleList([
|
| 428 |
+
Block(
|
| 429 |
+
dim=dec_embed_dim,
|
| 430 |
+
num_heads=dec_num_heads,
|
| 431 |
+
mlp_ratio=mlp_ratio,
|
| 432 |
+
qkv_bias=True,
|
| 433 |
+
proj_bias=True,
|
| 434 |
+
ffn_bias=True,
|
| 435 |
+
drop_path=0.0,
|
| 436 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 437 |
+
act_layer=nn.GELU,
|
| 438 |
+
ffn_layer=Mlp,
|
| 439 |
+
init_values=None,
|
| 440 |
+
qk_norm=False,
|
| 441 |
+
# attn_class=MemEffAttentionRope,
|
| 442 |
+
rope=rope
|
| 443 |
+
) for _ in range(depth)])
|
| 444 |
+
|
| 445 |
+
self.linear_out = nn.Linear(dec_embed_dim, out_dim)
|
| 446 |
+
|
| 447 |
+
def forward(self, hidden, xpos=None):
|
| 448 |
+
hidden = self.projects(hidden)
|
| 449 |
+
B, V, P, C = hidden.shape
|
| 450 |
+
hidden = hidden.view(B * V, P, C)
|
| 451 |
+
for i, blk in enumerate(self.blocks):
|
| 452 |
+
if self.use_checkpoint and self.training:
|
| 453 |
+
hidden = checkpoint(blk, hidden, pos=xpos, use_reentrant=False)
|
| 454 |
+
else:
|
| 455 |
+
hidden = blk(hidden, pos=xpos)
|
| 456 |
+
out = self.linear_out(hidden).view(B, V, P, -1)
|
| 457 |
+
|
| 458 |
+
return out
|
lingbot_map/heads/dpt_head.py
ADDED
|
@@ -0,0 +1,679 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from typing import List, Dict, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from .head_act import activate_head
|
| 18 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DPTHead(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
DPT Head for dense prediction tasks.
|
| 24 |
+
|
| 25 |
+
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
|
| 26 |
+
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
|
| 27 |
+
backbone and produces dense predictions by fusing multi-scale features.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
dim_in (int): Input dimension (channels).
|
| 31 |
+
patch_size (int, optional): Patch size. Default is 14.
|
| 32 |
+
output_dim (int, optional): Number of output channels. Default is 4.
|
| 33 |
+
activation (str, optional): Activation type. Default is "inv_log".
|
| 34 |
+
conf_activation (str, optional): Confidence activation type. Default is "expp1".
|
| 35 |
+
features (int, optional): Feature channels for intermediate representations. Default is 256.
|
| 36 |
+
out_channels (List[int], optional): Output channels for each intermediate layer.
|
| 37 |
+
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
|
| 38 |
+
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
|
| 39 |
+
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
|
| 40 |
+
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
dim_in: int,
|
| 46 |
+
patch_size: int = 14,
|
| 47 |
+
output_dim: int = 4,
|
| 48 |
+
activation: str = "inv_log",
|
| 49 |
+
conf_activation: str = "expp1",
|
| 50 |
+
features: int = 256,
|
| 51 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 52 |
+
intermediate_layer_idx: List[int] = [0, 1, 2, 3],
|
| 53 |
+
pos_embed: bool = True,
|
| 54 |
+
feature_only: bool = False,
|
| 55 |
+
down_ratio: int = 1,
|
| 56 |
+
) -> None:
|
| 57 |
+
super(DPTHead, self).__init__()
|
| 58 |
+
self.patch_size = patch_size
|
| 59 |
+
self.activation = activation
|
| 60 |
+
self.conf_activation = conf_activation
|
| 61 |
+
self.pos_embed = pos_embed
|
| 62 |
+
self.feature_only = feature_only
|
| 63 |
+
self.down_ratio = down_ratio
|
| 64 |
+
self.intermediate_layer_idx = intermediate_layer_idx
|
| 65 |
+
|
| 66 |
+
self.norm = nn.LayerNorm(dim_in)
|
| 67 |
+
|
| 68 |
+
# Projection layers for each output channel from tokens.
|
| 69 |
+
self.projects = nn.ModuleList(
|
| 70 |
+
[nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# Resize layers for upsampling feature maps.
|
| 74 |
+
self.resize_layers = nn.ModuleList(
|
| 75 |
+
[
|
| 76 |
+
nn.ConvTranspose2d(
|
| 77 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
| 78 |
+
),
|
| 79 |
+
nn.ConvTranspose2d(
|
| 80 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
| 81 |
+
),
|
| 82 |
+
nn.Identity(),
|
| 83 |
+
nn.Conv2d(
|
| 84 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
| 85 |
+
),
|
| 86 |
+
]
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
self.scratch = _make_scratch(out_channels, features, expand=False)
|
| 90 |
+
|
| 91 |
+
# Attach additional modules to scratch.
|
| 92 |
+
self.scratch.stem_transpose = None
|
| 93 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
| 94 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
| 95 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
| 96 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
| 97 |
+
|
| 98 |
+
head_features_1 = features
|
| 99 |
+
head_features_2 = 32
|
| 100 |
+
|
| 101 |
+
if feature_only:
|
| 102 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
|
| 103 |
+
else:
|
| 104 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 105 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
| 106 |
+
)
|
| 107 |
+
conv2_in_channels = head_features_1 // 2
|
| 108 |
+
|
| 109 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 110 |
+
nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 111 |
+
nn.ReLU(inplace=True),
|
| 112 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def forward(
|
| 116 |
+
self,
|
| 117 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 118 |
+
images: torch.Tensor,
|
| 119 |
+
patch_start_idx: int,
|
| 120 |
+
frames_chunk_size: int = 8,
|
| 121 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 122 |
+
"""
|
| 123 |
+
Forward pass through the DPT head, supports processing by chunking frames.
|
| 124 |
+
Args:
|
| 125 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 126 |
+
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
| 127 |
+
patch_start_idx (int): Starting index for patch tokens in the token sequence.
|
| 128 |
+
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
|
| 129 |
+
frames_chunk_size (int, optional): Number of frames to process in each chunk.
|
| 130 |
+
If None or larger than S, all frames are processed at once. Default: 8.
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Tensor or Tuple[Tensor, Tensor]:
|
| 134 |
+
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
|
| 135 |
+
- Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
|
| 136 |
+
"""
|
| 137 |
+
B, _, _, H, W = images.shape
|
| 138 |
+
|
| 139 |
+
S = aggregated_tokens_list[0].shape[1]
|
| 140 |
+
|
| 141 |
+
# If frames_chunk_size is not specified or greater than S, process all frames at once
|
| 142 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
| 143 |
+
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
|
| 144 |
+
|
| 145 |
+
# Otherwise, process frames in chunks to manage memory usage
|
| 146 |
+
assert frames_chunk_size > 0
|
| 147 |
+
|
| 148 |
+
# Process frames in batches
|
| 149 |
+
all_preds = []
|
| 150 |
+
all_conf = []
|
| 151 |
+
|
| 152 |
+
for frames_start_idx in range(0, S, frames_chunk_size):
|
| 153 |
+
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
|
| 154 |
+
|
| 155 |
+
# Process batch of frames
|
| 156 |
+
if self.feature_only:
|
| 157 |
+
chunk_output = self._forward_impl(
|
| 158 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 159 |
+
)
|
| 160 |
+
all_preds.append(chunk_output)
|
| 161 |
+
else:
|
| 162 |
+
chunk_preds, chunk_conf = self._forward_impl(
|
| 163 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 164 |
+
)
|
| 165 |
+
all_preds.append(chunk_preds)
|
| 166 |
+
all_conf.append(chunk_conf)
|
| 167 |
+
|
| 168 |
+
# Concatenate results along the sequence dimension
|
| 169 |
+
if self.feature_only:
|
| 170 |
+
return torch.cat(all_preds, dim=1)
|
| 171 |
+
else:
|
| 172 |
+
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
|
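A standalone sketch of the frame-chunking pattern above, with a toy per-chunk function standing in for _forward_impl: slice the sequence dimension, process each slice, and concatenate the results.

import torch

def process(frames: torch.Tensor) -> torch.Tensor:    # toy stand-in for _forward_impl
    return frames * 2.0

frames = torch.randn(1, 20, 3, 32, 32)                # [B, S, 3, H, W]
chunk_size, outputs = 8, []
for start in range(0, frames.shape[1], chunk_size):
    outputs.append(process(frames[:, start:start + chunk_size]))
result = torch.cat(outputs, dim=1)                     # matches processing all frames at once
assert torch.allclose(result, process(frames))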
| 173 |
+
|
| 174 |
+
def _forward_impl(
|
| 175 |
+
self,
|
| 176 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 177 |
+
images: torch.Tensor,
|
| 178 |
+
patch_start_idx: int,
|
| 179 |
+
frames_start_idx: int = None,
|
| 180 |
+
frames_end_idx: int = None,
|
| 181 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 182 |
+
"""
|
| 183 |
+
Implementation of the forward pass through the DPT head.
|
| 184 |
+
|
| 185 |
+
This method processes a specific chunk of frames from the sequence.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 189 |
+
images (Tensor): Input images with shape [B, S, 3, H, W].
|
| 190 |
+
patch_start_idx (int): Starting index for patch tokens.
|
| 191 |
+
frames_start_idx (int, optional): Starting index for frames to process.
|
| 192 |
+
frames_end_idx (int, optional): Ending index for frames to process.
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
B, _, _, H, W = images.shape
|
| 199 |
+
|
| 200 |
+
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
| 201 |
+
|
| 202 |
+
out = []
|
| 203 |
+
dpt_idx = 0
|
| 204 |
+
|
| 205 |
+
for layer_idx in self.intermediate_layer_idx:
|
| 206 |
+
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 211 |
+
x = x[:, frames_start_idx:frames_end_idx]
|
| 212 |
+
|
| 213 |
+
B, S = x.shape[0], x.shape[1]
|
| 214 |
+
|
| 215 |
+
x = x.reshape(B * S, -1, x.shape[-1])
|
| 216 |
+
|
| 217 |
+
x = self.norm(x)
|
| 218 |
+
|
| 219 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 220 |
+
|
| 221 |
+
x = self.projects[dpt_idx](x)
|
| 222 |
+
if self.pos_embed:
|
| 223 |
+
x = self._apply_pos_embed(x, W, H)
|
| 224 |
+
x = self.resize_layers[dpt_idx](x)
|
| 225 |
+
|
| 226 |
+
out.append(x)
|
| 227 |
+
dpt_idx += 1
|
| 228 |
+
|
| 229 |
+
# Fuse features from multiple layers.
|
| 230 |
+
out = self.scratch_forward(out)
|
| 231 |
+
# Interpolate fused output to match target image resolution.
|
| 232 |
+
out = custom_interpolate(
|
| 233 |
+
out,
|
| 234 |
+
(int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
|
| 235 |
+
mode="bilinear",
|
| 236 |
+
align_corners=True,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
if self.pos_embed:
|
| 240 |
+
out = self._apply_pos_embed(out, W, H)
|
| 241 |
+
|
| 242 |
+
if self.feature_only:
|
| 243 |
+
return out.view(B, S, *out.shape[1:])
|
| 244 |
+
|
| 245 |
+
out = self.scratch.output_conv2(out)
|
| 246 |
+
preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
|
| 247 |
+
|
| 248 |
+
preds = preds.view(B, S, *preds.shape[1:])
|
| 249 |
+
conf = conf.view(B, S, *conf.shape[1:])
|
| 250 |
+
return preds, conf
|
| 251 |
+
|
| 252 |
+
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
| 253 |
+
"""
|
| 254 |
+
Apply positional embedding to tensor x.
|
| 255 |
+
"""
|
| 256 |
+
patch_w = x.shape[-1]
|
| 257 |
+
patch_h = x.shape[-2]
|
| 258 |
+
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
| 259 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
| 260 |
+
pos_embed = pos_embed * ratio
|
| 261 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 262 |
+
return x + pos_embed
|
| 263 |
+
|
| 264 |
+
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
| 265 |
+
"""
|
| 266 |
+
Forward pass through the fusion blocks.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
features (List[Tensor]): List of feature maps from different layers.
|
| 270 |
+
|
| 271 |
+
Returns:
|
| 272 |
+
Tensor: Fused feature map.
|
| 273 |
+
"""
|
| 274 |
+
layer_1, layer_2, layer_3, layer_4 = features
|
| 275 |
+
|
| 276 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 277 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 278 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 279 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 280 |
+
|
| 281 |
+
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 282 |
+
del layer_4_rn, layer_4
|
| 283 |
+
|
| 284 |
+
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 285 |
+
del layer_3_rn, layer_3
|
| 286 |
+
|
| 287 |
+
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 288 |
+
del layer_2_rn, layer_2
|
| 289 |
+
|
| 290 |
+
out = self.scratch.refinenet1(out, layer_1_rn)
|
| 291 |
+
del layer_1_rn, layer_1
|
| 292 |
+
|
| 293 |
+
out = self.scratch.output_conv1(out)
|
| 294 |
+
return out
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
################################################################################
|
| 298 |
+
# Modules
|
| 299 |
+
################################################################################
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
|
| 303 |
+
return FeatureFusionBlock(
|
| 304 |
+
features,
|
| 305 |
+
nn.ReLU(inplace=True),
|
| 306 |
+
deconv=False,
|
| 307 |
+
bn=False,
|
| 308 |
+
expand=False,
|
| 309 |
+
align_corners=True,
|
| 310 |
+
size=size,
|
| 311 |
+
has_residual=has_residual,
|
| 312 |
+
groups=groups,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
| 317 |
+
scratch = nn.Module()
|
| 318 |
+
out_shape1 = out_shape
|
| 319 |
+
out_shape2 = out_shape
|
| 320 |
+
out_shape3 = out_shape
|
| 321 |
+
if len(in_shape) >= 4:
|
| 322 |
+
out_shape4 = out_shape
|
| 323 |
+
|
| 324 |
+
if expand:
|
| 325 |
+
out_shape1 = out_shape
|
| 326 |
+
out_shape2 = out_shape * 2
|
| 327 |
+
out_shape3 = out_shape * 4
|
| 328 |
+
if len(in_shape) >= 4:
|
| 329 |
+
out_shape4 = out_shape * 8
|
| 330 |
+
|
| 331 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 332 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 333 |
+
)
|
| 334 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 335 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 336 |
+
)
|
| 337 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 338 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 339 |
+
)
|
| 340 |
+
if len(in_shape) >= 4:
|
| 341 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 342 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 343 |
+
)
|
| 344 |
+
return scratch
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class ResidualConvUnit(nn.Module):
|
| 348 |
+
"""Residual convolution module."""
|
| 349 |
+
|
| 350 |
+
def __init__(self, features, activation, bn, groups=1):
|
| 351 |
+
"""Init.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
features (int): number of features
|
| 355 |
+
"""
|
| 356 |
+
super().__init__()
|
| 357 |
+
|
| 358 |
+
self.bn = bn
|
| 359 |
+
self.groups = groups
|
| 360 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 361 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 362 |
+
|
| 363 |
+
self.norm1 = None
|
| 364 |
+
self.norm2 = None
|
| 365 |
+
|
| 366 |
+
self.activation = activation
|
| 367 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 368 |
+
|
| 369 |
+
def forward(self, x):
|
| 370 |
+
"""Forward pass.
|
| 371 |
+
|
| 372 |
+
Args:
|
| 373 |
+
x (tensor): input
|
| 374 |
+
|
| 375 |
+
Returns:
|
| 376 |
+
tensor: output
|
| 377 |
+
"""
|
| 378 |
+
|
| 379 |
+
out = self.activation(x)
|
| 380 |
+
out = self.conv1(out)
|
| 381 |
+
if self.norm1 is not None:
|
| 382 |
+
out = self.norm1(out)
|
| 383 |
+
|
| 384 |
+
out = self.activation(out)
|
| 385 |
+
out = self.conv2(out)
|
| 386 |
+
if self.norm2 is not None:
|
| 387 |
+
out = self.norm2(out)
|
| 388 |
+
|
| 389 |
+
return self.skip_add.add(out, x)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class FeatureFusionBlock(nn.Module):
|
| 393 |
+
"""Feature fusion block."""
|
| 394 |
+
|
| 395 |
+
def __init__(
|
| 396 |
+
self,
|
| 397 |
+
features,
|
| 398 |
+
activation,
|
| 399 |
+
deconv=False,
|
| 400 |
+
bn=False,
|
| 401 |
+
expand=False,
|
| 402 |
+
align_corners=True,
|
| 403 |
+
size=None,
|
| 404 |
+
has_residual=True,
|
| 405 |
+
groups=1,
|
| 406 |
+
):
|
| 407 |
+
"""Init.
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
features (int): number of features
|
| 411 |
+
"""
|
| 412 |
+
super(FeatureFusionBlock, self).__init__()
|
| 413 |
+
|
| 414 |
+
self.deconv = deconv
|
| 415 |
+
self.align_corners = align_corners
|
| 416 |
+
self.groups = groups
|
| 417 |
+
self.expand = expand
|
| 418 |
+
out_features = features
|
| 419 |
+
if self.expand:
|
| 420 |
+
out_features = features // 2
|
| 421 |
+
|
| 422 |
+
self.out_conv = nn.Conv2d(
|
| 423 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
if has_residual:
|
| 427 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 428 |
+
|
| 429 |
+
self.has_residual = has_residual
|
| 430 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 431 |
+
|
| 432 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 433 |
+
self.size = size
|
| 434 |
+
|
| 435 |
+
def forward(self, *xs, size=None):
|
| 436 |
+
"""Forward pass.
|
| 437 |
+
|
| 438 |
+
Returns:
|
| 439 |
+
tensor: output
|
| 440 |
+
"""
|
| 441 |
+
output = xs[0]
|
| 442 |
+
|
| 443 |
+
if self.has_residual:
|
| 444 |
+
res = self.resConfUnit1(xs[1])
|
| 445 |
+
output = self.skip_add.add(output, res)
|
| 446 |
+
|
| 447 |
+
output = self.resConfUnit2(output)
|
| 448 |
+
|
| 449 |
+
if (size is None) and (self.size is None):
|
| 450 |
+
modifier = {"scale_factor": 2}
|
| 451 |
+
elif size is None:
|
| 452 |
+
modifier = {"size": self.size}
|
| 453 |
+
else:
|
| 454 |
+
modifier = {"size": size}
|
| 455 |
+
|
| 456 |
+
output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
| 457 |
+
output = self.out_conv(output)
|
| 458 |
+
|
| 459 |
+
return output
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def custom_interpolate(
|
| 463 |
+
x: torch.Tensor,
|
| 464 |
+
size: Tuple[int, int] = None,
|
| 465 |
+
scale_factor: float = None,
|
| 466 |
+
mode: str = "bilinear",
|
| 467 |
+
align_corners: bool = True,
|
| 468 |
+
) -> torch.Tensor:
|
| 469 |
+
"""
|
| 470 |
+
Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
|
| 471 |
+
"""
|
| 472 |
+
if size is None:
|
| 473 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
| 474 |
+
|
| 475 |
+
INT_MAX = 1610612736
|
| 476 |
+
|
| 477 |
+
input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
|
| 478 |
+
|
| 479 |
+
if input_elements > INT_MAX:
|
| 480 |
+
chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
|
| 481 |
+
interpolated_chunks = [
|
| 482 |
+
nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
|
| 483 |
+
]
|
| 484 |
+
x = torch.cat(interpolated_chunks, dim=0)
|
| 485 |
+
return x.contiguous()
|
| 486 |
+
else:
|
| 487 |
+
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
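# Note: the 1610612736 threshold above is an element-count guard on the interpolated
# output (B * C * out_H * out_W); when it would be exceeded, the batch is split along
# dim 0 so each F.interpolate call stays below the backend's int32 indexing limit.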
| 488 |
+
|
| 489 |
+
class DPTHead_Update(nn.Module):
|
| 490 |
+
def __init__(
|
| 491 |
+
self,
|
| 492 |
+
in_channels,
|
| 493 |
+
features=256,
|
| 494 |
+
use_bn=False,
|
| 495 |
+
out_channels=[256, 512, 1024, 1024],
|
| 496 |
+
use_clstoken=False
|
| 497 |
+
):
|
| 498 |
+
super(DPTHead_Update, self).__init__()
|
| 499 |
+
|
| 500 |
+
self.use_clstoken = use_clstoken
|
| 501 |
+
|
| 502 |
+
self.projects = nn.ModuleList([
|
| 503 |
+
nn.Conv2d(
|
| 504 |
+
in_channels=in_channels,
|
| 505 |
+
out_channels=out_channel,
|
| 506 |
+
kernel_size=1,
|
| 507 |
+
stride=1,
|
| 508 |
+
padding=0,
|
| 509 |
+
) for out_channel in out_channels
|
| 510 |
+
])
|
| 511 |
+
|
| 512 |
+
self.resize_layers = nn.ModuleList([
|
| 513 |
+
nn.ConvTranspose2d(
|
| 514 |
+
in_channels=out_channels[0],
|
| 515 |
+
out_channels=out_channels[0],
|
| 516 |
+
kernel_size=4,
|
| 517 |
+
stride=4,
|
| 518 |
+
padding=0),
|
| 519 |
+
nn.ConvTranspose2d(
|
| 520 |
+
in_channels=out_channels[1],
|
| 521 |
+
out_channels=out_channels[1],
|
| 522 |
+
kernel_size=2,
|
| 523 |
+
stride=2,
|
| 524 |
+
padding=0),
|
| 525 |
+
nn.Identity(),
|
| 526 |
+
nn.Conv2d(
|
| 527 |
+
in_channels=out_channels[3],
|
| 528 |
+
out_channels=out_channels[3],
|
| 529 |
+
kernel_size=3,
|
| 530 |
+
stride=2,
|
| 531 |
+
padding=1)
|
| 532 |
+
])
|
| 533 |
+
|
| 534 |
+
if use_clstoken:
|
| 535 |
+
self.readout_projects = nn.ModuleList()
|
| 536 |
+
for _ in range(len(self.projects)):
|
| 537 |
+
self.readout_projects.append(
|
| 538 |
+
nn.Sequential(
|
| 539 |
+
nn.Linear(2 * in_channels, in_channels),
|
| 540 |
+
nn.GELU()))
|
| 541 |
+
|
| 542 |
+
self.scratch = _make_scratch(
|
| 543 |
+
out_channels,
|
| 544 |
+
features,
|
| 545 |
+
groups=1,
|
| 546 |
+
expand=False,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
self.scratch.stem_transpose = None
|
| 550 |
+
|
| 551 |
+
self.scratch.refinenet1 = _make_fusion_block_slam(features, use_bn)
|
| 552 |
+
self.scratch.refinenet2 = _make_fusion_block_slam(features, use_bn)
|
| 553 |
+
self.scratch.refinenet3 = _make_fusion_block_slam(features, use_bn)
|
| 554 |
+
self.scratch.refinenet4 = _make_fusion_block_slam(features, use_bn)
|
| 555 |
+
|
| 556 |
+
head_features_1 = features
|
| 557 |
+
head_features_2 = 32
|
| 558 |
+
|
| 559 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
|
| 560 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 561 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 562 |
+
nn.ReLU(True),
|
| 563 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
| 564 |
+
nn.ReLU(True),
|
| 565 |
+
nn.Identity(),
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
def forward(self, out_features, patch_h, patch_w, return_intermediate=True):
|
| 569 |
+
out = []
|
| 570 |
+
for i, x in enumerate(out_features):
|
| 571 |
+
if self.use_clstoken:
|
| 572 |
+
x, cls_token = x[0], x[1]
|
| 573 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
| 574 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
| 575 |
+
|
| 576 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 577 |
+
|
| 578 |
+
x = self.projects[i](x)
|
| 579 |
+
x = self.resize_layers[i](x)
|
| 580 |
+
|
| 581 |
+
out.append(x)
|
| 582 |
+
|
| 583 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
| 584 |
+
|
| 585 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 586 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 587 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 588 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 589 |
+
|
| 590 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 591 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 592 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 593 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
| 594 |
+
out = self.scratch.output_conv1(path_1)
|
| 595 |
+
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
| 596 |
+
if return_intermediate:
|
| 597 |
+
return out, path_1, path_2, path_3, path_4
|
| 598 |
+
else:
|
| 599 |
+
out = self.scratch.output_conv2(out)
|
| 600 |
+
return out
|
| 601 |
+
|
| 602 |
+
def _make_fusion_block_slam(features, use_bn, size=None):
|
| 603 |
+
return FeatureFusionBlock_slam(
|
| 604 |
+
features,
|
| 605 |
+
nn.ReLU(False),
|
| 606 |
+
deconv=False,
|
| 607 |
+
bn=use_bn,
|
| 608 |
+
expand=False,
|
| 609 |
+
align_corners=True,
|
| 610 |
+
size=size,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
class FeatureFusionBlock_slam(nn.Module):
|
| 615 |
+
"""Feature fusion block.
|
| 616 |
+
"""
|
| 617 |
+
|
| 618 |
+
def __init__(
|
| 619 |
+
self,
|
| 620 |
+
features,
|
| 621 |
+
activation,
|
| 622 |
+
deconv=False,
|
| 623 |
+
bn=False,
|
| 624 |
+
expand=False,
|
| 625 |
+
align_corners=True,
|
| 626 |
+
size=None
|
| 627 |
+
):
|
| 628 |
+
"""Init.
|
| 629 |
+
|
| 630 |
+
Args:
|
| 631 |
+
features (int): number of features
|
| 632 |
+
"""
|
| 633 |
+
super(FeatureFusionBlock_slam, self).__init__()
|
| 634 |
+
|
| 635 |
+
self.deconv = deconv
|
| 636 |
+
self.align_corners = align_corners
|
| 637 |
+
|
| 638 |
+
self.groups = 1
|
| 639 |
+
|
| 640 |
+
self.expand = expand
|
| 641 |
+
out_features = features
|
| 642 |
+
if self.expand:
|
| 643 |
+
out_features = features // 2
|
| 644 |
+
|
| 645 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
| 646 |
+
|
| 647 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
| 648 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
| 649 |
+
|
| 650 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 651 |
+
|
| 652 |
+
self.size = size
|
| 653 |
+
|
| 654 |
+
def forward(self, *xs, size=None):
|
| 655 |
+
"""Forward pass.
|
| 656 |
+
|
| 657 |
+
Returns:
|
| 658 |
+
tensor: output
|
| 659 |
+
"""
|
| 660 |
+
output = xs[0]
|
| 661 |
+
|
| 662 |
+
if len(xs) == 2:
|
| 663 |
+
res = self.resConfUnit1(xs[1])
|
| 664 |
+
output = self.skip_add.add(output, res)
|
| 665 |
+
|
| 666 |
+
output = self.resConfUnit2(output)
|
| 667 |
+
|
| 668 |
+
if (size is None) and (self.size is None):
|
| 669 |
+
modifier = {"scale_factor": 2}
|
| 670 |
+
elif size is None:
|
| 671 |
+
modifier = {"size": self.size}
|
| 672 |
+
else:
|
| 673 |
+
modifier = {"size": size}
|
| 674 |
+
|
| 675 |
+
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
| 676 |
+
|
| 677 |
+
output = self.out_conv(output)
|
| 678 |
+
|
| 679 |
+
return output
|
lingbot_map/heads/head_act.py
ADDED
|
@@ -0,0 +1,125 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
|
| 13 |
+
"""
|
| 14 |
+
Activate pose parameters with specified activation functions.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
|
| 18 |
+
trans_act: Activation type for translation component
|
| 19 |
+
quat_act: Activation type for quaternion component
|
| 20 |
+
fl_act: Activation type for focal length component
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
Activated pose parameters tensor
|
| 24 |
+
"""
|
| 25 |
+
T = pred_pose_enc[..., :3]
|
| 26 |
+
quat = pred_pose_enc[..., 3:7]
|
| 27 |
+
fl = pred_pose_enc[..., 7:] # or fov
|
| 28 |
+
|
| 29 |
+
T = base_pose_act(T, trans_act)
|
| 30 |
+
quat = base_pose_act(quat, quat_act)
|
| 31 |
+
fl = base_pose_act(fl, fl_act) # or fov
|
| 32 |
+
|
| 33 |
+
pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
|
| 34 |
+
|
| 35 |
+
return pred_pose_enc
|
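A standalone sketch of the 9-D split performed by activate_pose above (3 translation + 4 quaternion + remaining FoV/focal channels), here with relu as the FoV activation:

import torch

enc = torch.randn(2, 9)
T, quat, fl = enc[..., :3], enc[..., 3:7], enc[..., 7:]
activated = torch.cat([T, quat, torch.relu(fl)], dim=-1)   # trans/quat linear, fl_act="relu"
assert activated.shape == enc.shape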
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def base_pose_act(pose_enc, act_type="linear"):
|
| 39 |
+
"""
|
| 40 |
+
Apply basic activation function to pose parameters.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
pose_enc: Tensor containing encoded pose parameters
|
| 44 |
+
act_type: Activation type ("linear", "inv_log", "exp", "relu")
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Activated pose parameters
|
| 48 |
+
"""
|
| 49 |
+
if act_type == "linear":
|
| 50 |
+
return pose_enc
|
| 51 |
+
elif act_type == "inv_log":
|
| 52 |
+
return inverse_log_transform(pose_enc)
|
| 53 |
+
elif act_type == "exp":
|
| 54 |
+
return torch.exp(pose_enc)
|
| 55 |
+
elif act_type == "relu":
|
| 56 |
+
return F.relu(pose_enc)
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"Unknown act_type: {act_type}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
|
| 62 |
+
"""
|
| 63 |
+
Process network output to extract 3D points and confidence values.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
out: Network output tensor (B, C, H, W)
|
| 67 |
+
activation: Activation type for 3D points
|
| 68 |
+
conf_activation: Activation type for confidence values
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
Tuple of (3D points tensor, confidence tensor)
|
| 72 |
+
"""
|
| 73 |
+
# Move channels to the last dimension => (B, H, W, C)
|
| 74 |
+
fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
|
| 75 |
+
|
| 76 |
+
# Split into xyz (first C-1 channels) and confidence (last channel)
|
| 77 |
+
xyz = fmap[:, :, :, :-1]
|
| 78 |
+
conf = fmap[:, :, :, -1]
|
| 79 |
+
|
| 80 |
+
if activation == "norm_exp":
|
| 81 |
+
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 82 |
+
xyz_normed = xyz / d
|
| 83 |
+
pts3d = xyz_normed * torch.expm1(d)
|
| 84 |
+
elif activation == "norm":
|
| 85 |
+
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
|
| 86 |
+
elif activation == "exp":
|
| 87 |
+
pts3d = torch.exp(xyz)
|
| 88 |
+
elif activation == "relu":
|
| 89 |
+
pts3d = F.relu(xyz)
|
| 90 |
+
elif activation == "inv_log":
|
| 91 |
+
pts3d = inverse_log_transform(xyz)
|
| 92 |
+
elif activation == "xy_inv_log":
|
| 93 |
+
xy, z = xyz.split([2, 1], dim=-1)
|
| 94 |
+
z = inverse_log_transform(z)
|
| 95 |
+
pts3d = torch.cat([xy * z, z], dim=-1)
|
| 96 |
+
elif activation == "sigmoid":
|
| 97 |
+
pts3d = torch.sigmoid(xyz)
|
| 98 |
+
elif activation == "linear":
|
| 99 |
+
pts3d = xyz
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Unknown activation: {activation}")
|
| 102 |
+
|
| 103 |
+
if conf_activation == "expp1":
|
| 104 |
+
conf_out = 1 + conf.exp()
|
| 105 |
+
elif conf_activation == "expp0":
|
| 106 |
+
conf_out = conf.exp()
|
| 107 |
+
elif conf_activation == "sigmoid":
|
| 108 |
+
conf_out = torch.sigmoid(conf)
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unknown conf_activation: {conf_activation}")
|
| 111 |
+
|
| 112 |
+
return pts3d, conf_out
|
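A standalone sketch of the "norm_exp" branch above: the channel norm is treated as a log-style magnitude and expanded with expm1, while the direction is preserved.

import torch

xyz = torch.randn(1, 4, 4, 3)
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
pts3d = (xyz / d) * torch.expm1(d)                 # unit direction scaled by expm1(norm)
assert pts3d.shape == xyz.shape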
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def inverse_log_transform(y):
|
| 116 |
+
"""
|
| 117 |
+
Apply inverse log transform: sign(y) * (exp(|y|) - 1)
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
y: Input tensor
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Transformed tensor
|
| 124 |
+
"""
|
| 125 |
+
return torch.sign(y) * (torch.expm1(torch.abs(y)))
|
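A standalone check that inverse_log_transform above undoes a signed log1p compression, which is presumably how the raw network outputs are interpreted:

import torch

x = torch.randn(5)
y = torch.sign(x) * torch.log1p(torch.abs(x))      # forward (compressive) transform
x_rec = torch.sign(y) * torch.expm1(torch.abs(y))  # same expression as inverse_log_transform
assert torch.allclose(x_rec, x, atol=1e-6)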
lingbot_map/heads/utils.py
ADDED
|
@@ -0,0 +1,109 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
|
| 12 |
+
"""
|
| 13 |
+
Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
|
| 17 |
+
embed_dim: Output channel dimension for embeddings
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
Tensor of shape (H, W, embed_dim) with positional embeddings
|
| 21 |
+
"""
|
| 22 |
+
H, W, grid_dim = pos_grid.shape
|
| 23 |
+
assert grid_dim == 2
|
| 24 |
+
pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
|
| 25 |
+
|
| 26 |
+
# Process x and y coordinates separately
|
| 27 |
+
emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # (H*W, D/2)
|
| 28 |
+
emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # (H*W, D/2)
|
| 29 |
+
|
| 30 |
+
# Combine and reshape
|
| 31 |
+
emb = torch.cat([emb_x, emb_y], dim=-1) # (H*W, D)
|
| 32 |
+
|
| 33 |
+
return emb.view(H, W, embed_dim) # [H, W, D]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
|
| 37 |
+
"""
|
| 38 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
- embed_dim: The embedding dimension.
|
| 42 |
+
- pos: The position to generate the embedding from.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
- emb: The generated 1D positional embedding.
|
| 46 |
+
"""
|
| 47 |
+
assert embed_dim % 2 == 0
|
| 48 |
+
device = pos.device
|
| 49 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
|
| 50 |
+
omega /= embed_dim / 2.0
|
| 51 |
+
omega = 1.0 / omega_0**omega # (D/2,)
|
| 52 |
+
|
| 53 |
+
pos = pos.reshape(-1) # (M,)
|
| 54 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 55 |
+
|
| 56 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 57 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 58 |
+
|
| 59 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 60 |
+
return emb.float()
|
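A standalone sketch of the sinusoidal embedding built by make_sincos_pos_embed above, for four scalar positions and an 8-dim embedding:

import torch

embed_dim, pos = 8, torch.arange(4, dtype=torch.double)
omega = 1.0 / 100 ** (torch.arange(embed_dim // 2, dtype=torch.double) / (embed_dim / 2.0))
out = torch.einsum("m,d->md", pos, omega)           # outer product of positions and frequencies
emb = torch.cat([out.sin(), out.cos()], dim=1)      # (4, 8): sines then cosines
assert emb.shape == (4, embed_dim)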
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Inspired by https://github.com/microsoft/moge
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def create_uv_grid(
|
| 67 |
+
width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
|
| 68 |
+
) -> torch.Tensor:
|
| 69 |
+
"""
|
| 70 |
+
Create a normalized UV grid of shape (width, height, 2).
|
| 71 |
+
|
| 72 |
+
The grid spans horizontally and vertically according to an aspect ratio,
|
| 73 |
+
ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
|
| 74 |
+
corner is at (x_span, y_span), normalized by the diagonal of the plane.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
width (int): Number of points horizontally.
|
| 78 |
+
height (int): Number of points vertically.
|
| 79 |
+
aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
|
| 80 |
+
dtype (torch.dtype, optional): Data type of the resulting tensor.
|
| 81 |
+
device (torch.device, optional): Device on which the tensor is created.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
torch.Tensor: A (width, height, 2) tensor of UV coordinates.
|
| 85 |
+
"""
|
| 86 |
+
# Derive aspect ratio if not explicitly provided
|
| 87 |
+
if aspect_ratio is None:
|
| 88 |
+
aspect_ratio = float(width) / float(height)
|
| 89 |
+
|
| 90 |
+
# Compute normalized spans for X and Y
|
| 91 |
+
diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
|
| 92 |
+
span_x = aspect_ratio / diag_factor
|
| 93 |
+
span_y = 1.0 / diag_factor
|
| 94 |
+
|
| 95 |
+
# Establish the linspace boundaries
|
| 96 |
+
left_x = -span_x * (width - 1) / width
|
| 97 |
+
right_x = span_x * (width - 1) / width
|
| 98 |
+
top_y = -span_y * (height - 1) / height
|
| 99 |
+
bottom_y = span_y * (height - 1) / height
|
| 100 |
+
|
| 101 |
+
# Generate 1D coordinates
|
| 102 |
+
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
| 103 |
+
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
| 104 |
+
|
| 105 |
+
# Create 2D meshgrid (width x height) and stack into UV
|
| 106 |
+
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
| 107 |
+
uv_grid = torch.stack((uu, vv), dim=-1)
|
| 108 |
+
|
| 109 |
+
return uv_grid
|
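A usage sketch for create_uv_grid above (the import path follows the file layout added in this commit); a square grid keeps the example independent of axis-ordering details:

import torch
from lingbot_map.heads.utils import create_uv_grid

uv = create_uv_grid(4, 4, dtype=torch.float32)
assert uv.shape == (4, 4, 2)
assert uv.abs().max() < 1.0                         # corners are normalized by the plane diagonal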
lingbot_map/layers/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
| 1 |
+
from lingbot_map.layers.mlp import Mlp
|
| 2 |
+
from lingbot_map.layers.patch_embed import PatchEmbed
|
| 3 |
+
from lingbot_map.layers.block import Block
|
| 4 |
+
from lingbot_map.layers.swiglu_ffn import SwiGLUFFN as SwiGLUFFNFused
|
| 5 |
+
from lingbot_map.layers.attention import Attention as MemEffAttention
|
lingbot_map/layers/attention.py
ADDED
|
@@ -0,0 +1,766 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import math
|
| 13 |
+
import warnings
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from torch import Tensor
|
| 17 |
+
from torch import nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
from lingbot_map.layers.rope import apply_rotary_emb
|
| 21 |
+
|
| 22 |
+
from einops import rearrange
|
| 23 |
+
|
| 24 |
+
# FlashInfer imports (optional - for paged attention)
|
| 25 |
+
try:
|
| 26 |
+
import flashinfer
|
| 27 |
+
FLASHINFER_AVAILABLE = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
FLASHINFER_AVAILABLE = False
|
| 30 |
+
print("flashinfer not available")
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from torchtitan.distributed.sequence_parallel import (
|
| 34 |
+
gather_seq_scatter_heads,
|
| 35 |
+
gather_heads_scatter_seq,
|
| 36 |
+
pad_tensor,
|
| 37 |
+
slice_input_tensor_scale_grad,
|
| 38 |
+
gather_outputs,
|
| 39 |
+
)
|
| 40 |
+
except ImportError:
|
| 41 |
+
print("torchtitan not available for ulysses cp")
|
| 42 |
+
|
| 43 |
+
def gather_seq_scatter_heads_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_dim: int, head_dim: int):
|
| 44 |
+
"""Gather sequence dimension and scatter head dimension for Q, K, V tensors."""
|
| 45 |
+
q = gather_seq_scatter_heads(q, seq_dim, head_dim)
|
| 46 |
+
k = gather_seq_scatter_heads(k, seq_dim, head_dim)
|
| 47 |
+
v = gather_seq_scatter_heads(v, seq_dim, head_dim)
|
| 48 |
+
return q, k, v
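For reference, a minimal sketch of the shape exchange this helper performs under Ulysses-style context parallelism (the torchtitan primitives are assumed to behave as their names suggest; the cp_size of 4 and the concrete shapes below are hypothetical):
# Before: each rank holds all heads but only a slice of the sequence.
#   q_local: [B, H, N / cp_size, D]   e.g. [1, 16, 256, 64] with cp_size = 4
# gather_seq_scatter_heads(q, seq_dim=2, head_dim=1) all-to-alls the tensor so that
# each rank then holds the full sequence but only a slice of the heads:
#   q_global: [B, H / cp_size, N, D]  e.g. [1, 4, 1024, 64]
# gather_heads_scatter_seq performs the inverse exchange after attention.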
|
| 49 |
+
|
| 50 |
+
from typing_extensions import List
|
| 51 |
+
from typing import Optional, Tuple
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class Attention(nn.Module):
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
dim: int,
|
| 58 |
+
num_heads: int = 8,
|
| 59 |
+
qkv_bias: bool = True,
|
| 60 |
+
proj_bias: bool = True,
|
| 61 |
+
attn_drop: float = 0.0,
|
| 62 |
+
proj_drop: float = 0.0,
|
| 63 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 64 |
+
qk_norm: bool = False,
|
| 65 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
| 66 |
+
rope=None,
|
| 67 |
+
) -> None:
|
| 68 |
+
super().__init__()
|
| 69 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 70 |
+
self.num_heads = num_heads
|
| 71 |
+
self.head_dim = dim // num_heads
|
| 72 |
+
self.scale = self.head_dim**-0.5
|
| 73 |
+
self.fused_attn = fused_attn
|
| 74 |
+
|
| 75 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 76 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 77 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 78 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 79 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 80 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 81 |
+
self.rope = rope
|
| 82 |
+
|
| 83 |
+
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
|
| 84 |
+
B, N, C = x.shape
|
| 85 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 86 |
+
q, k, v = qkv.unbind(0)
|
| 87 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 88 |
+
|
| 89 |
+
if enable_ulysses_cp:
|
| 90 |
+
q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
|
| 91 |
+
|
| 92 |
+
if self.rope is not None:
|
| 93 |
+
q = self.rope(q, pos)
|
| 94 |
+
k = self.rope(k, pos)
|
| 95 |
+
|
| 96 |
+
if self.fused_attn:
|
| 97 |
+
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
|
| 98 |
+
else:
|
| 99 |
+
q = q * self.scale
|
| 100 |
+
attn = q @ k.transpose(-2, -1)
|
| 101 |
+
attn = attn.softmax(dim=-1)
|
| 102 |
+
attn = self.attn_drop(attn)
|
| 103 |
+
x = attn @ v
|
| 104 |
+
|
| 105 |
+
if enable_ulysses_cp:
|
| 106 |
+
x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
|
| 107 |
+
|
| 108 |
+
x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
|
| 109 |
+
x = self.proj(x)
|
| 110 |
+
x = self.proj_drop(x)
|
| 111 |
+
return x
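A minimal usage sketch of this module (sizes are illustrative; RoPE and context parallelism are left disabled):
import torch
from lingbot_map.layers.attention import Attention

attn = Attention(dim=768, num_heads=12, qk_norm=True).eval()
x = torch.randn(2, 196, 768)      # [B, N, C] patch tokens
with torch.no_grad():
    y = attn(x)                   # same shape: [2, 196, 768]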
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class CausalAttention(nn.Module):
|
| 115 |
+
"""
|
| 116 |
+
Causal self-attention module with KV cache support for streaming inference.
|
| 117 |
+
Used by CasualBlockCamera in camera_head.py.
|
| 118 |
+
"""
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
dim: int,
|
| 122 |
+
num_heads: int = 8,
|
| 123 |
+
qkv_bias: bool = True,
|
| 124 |
+
proj_bias: bool = True,
|
| 125 |
+
attn_drop: float = 0.0,
|
| 126 |
+
proj_drop: float = 0.0,
|
| 127 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 128 |
+
qk_norm: bool = False,
|
| 129 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
| 130 |
+
rope=None,
|
| 131 |
+
elementwise_attn_output_gate=False,
|
| 132 |
+
# KV cache eviction parameters (matching build_attn_mask)
|
| 133 |
+
kv_cache_sliding_window: int = 64,
|
| 134 |
+
kv_cache_scale_frames: int = 8,
|
| 135 |
+
kv_cache_cross_frame_special: bool = True,
|
| 136 |
+
kv_cache_include_scale_frames: bool = True,
|
| 137 |
+
kv_cache_camera_only: bool = False, # If True, only cache camera token (no scale token)
|
| 138 |
+
) -> None:
|
| 139 |
+
super().__init__()
|
| 140 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 141 |
+
self.num_heads = num_heads
|
| 142 |
+
self.head_dim = dim // num_heads
|
| 143 |
+
self.scale = self.head_dim**-0.5
|
| 144 |
+
self.fused_attn = fused_attn
|
| 145 |
+
|
| 146 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 147 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 148 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 149 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 150 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 151 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 152 |
+
self.rope = rope
|
| 153 |
+
|
| 154 |
+
self.gate_proj = nn.Linear(dim, dim, bias=True) if elementwise_attn_output_gate else None
|
| 155 |
+
|
| 156 |
+
# Store KV cache eviction parameters
|
| 157 |
+
self.kv_cache_sliding_window = kv_cache_sliding_window
|
| 158 |
+
self.kv_cache_scale_frames = kv_cache_scale_frames
|
| 159 |
+
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
|
| 160 |
+
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
| 161 |
+
self.kv_cache_camera_only = kv_cache_camera_only
|
| 162 |
+
|
| 163 |
+
def forward(self, x: Tensor, block_mask=None, pos=None, pos_kv=None, frame_seqlen=None, video_mask=None, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=1, num_frame_for_scale=-1, enable_3d_rope=False, sliding_window_size=-1, attend_to_scale_frames=False, num_random_frames=0, attend_to_special_tokens=False, num_register_tokens=4, enable_ulysses_cp=False, is_scale_frames=False) -> Tensor:
|
| 164 |
+
B, N, C = x.shape
|
| 165 |
+
|
| 166 |
+
# Calculate special token indices
|
| 167 |
+
camera_token_idx = 0
|
| 168 |
+
scale_token_idx = camera_token_idx + num_register_tokens + 1 # camera + register tokens + scale
|
| 169 |
+
|
| 170 |
+
# [3, B, num_heads, N, head_dim]
|
| 171 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 172 |
+
q, k, v = qkv.unbind(0)
|
| 173 |
+
|
| 174 |
+
if self.gate_proj is not None:
|
| 175 |
+
gate_score = self.gate_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 176 |
+
if kv_cache is None:
|
| 177 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 178 |
+
if enable_ulysses_cp:
|
| 179 |
+
q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
|
| 180 |
+
N = q.shape[2] # Update N after gather
|
| 181 |
+
if self.rope is not None and not enable_3d_rope:
|
| 182 |
+
q = self.rope(q, pos)
|
| 183 |
+
k = self.rope(k, pos)
|
| 184 |
+
elif enable_3d_rope and pos is not None:
|
| 185 |
+
q = apply_rotary_emb(q, pos)
|
| 186 |
+
k = apply_rotary_emb(k, pos)
|
| 187 |
+
|
| 188 |
+
with torch.no_grad():
|
| 189 |
+
block_mask = block_mask.squeeze()[:q.shape[2], :k.shape[2]]
|
| 190 |
+
if block_mask.dim() == 2:
|
| 191 |
+
block_mask = block_mask.unsqueeze(0).unsqueeze(0) # [1, 1, N, N]
|
| 192 |
+
block_mask = block_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])
|
| 193 |
+
|
| 194 |
+
video_mask = video_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if video_mask is not None else torch.ones_like(block_mask, device=block_mask.device) # [1, 1, N, N]
|
| 195 |
+
video_mask = video_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])
|
| 196 |
+
|
| 197 |
+
mask = block_mask | ~video_mask
|
| 198 |
+
|
| 199 |
+
# Apply sliding window mask if sliding_window_size > 0
|
| 200 |
+
# sliding_window_size is in units of num_frame_per_block
|
| 201 |
+
if sliding_window_size > 0 and frame_seqlen is not None:
|
| 202 |
+
# Create sliding window mask: each frame can only attend to frames within the window
|
| 203 |
+
num_frames = N // frame_seqlen
|
| 204 |
+
sliding_mask = torch.zeros_like(mask, dtype=torch.bool)
|
| 205 |
+
|
| 206 |
+
for i in range(num_frames):
|
| 207 |
+
q_start = i * frame_seqlen
|
| 208 |
+
q_end = (i + 1) * frame_seqlen
|
| 209 |
+
# Calculate the window start: sliding_window_size is in units of num_frame_per_block
|
| 210 |
+
# So the actual window size in frames is sliding_window_size * num_frame_per_block
|
| 211 |
+
window_size_in_frames = sliding_window_size * num_frame_per_block
|
| 212 |
+
window_start_frame = max(0, i - window_size_in_frames + 1)
|
| 213 |
+
k_start = window_start_frame * frame_seqlen
|
| 214 |
+
k_end = (i + 1) * frame_seqlen # Can attend up to current frame (causal)
|
| 215 |
+
sliding_mask[:, :, q_start:q_end, k_start:k_end] = True
|
| 216 |
+
|
| 217 |
+
# Combine with existing mask: both masks need to allow attention
|
| 218 |
+
mask = mask & sliding_mask
|
| 219 |
+
|
| 220 |
+
# If attend_to_scale_frames is True, also allow attention to first num_frame_for_scale frames
|
| 221 |
+
if num_frame_for_scale > 0:
|
| 222 |
+
for i in range(num_frames):
|
| 223 |
+
q_start = i * frame_seqlen
|
| 224 |
+
q_end = (i + 1) * frame_seqlen
|
| 225 |
+
# Allow attending to first num_frame_for_scale frames (directly set to True, not depending on block_mask)
|
| 226 |
+
mask[:, :, q_start:q_end, :num_frame_for_scale * frame_seqlen] = True
|
| 227 |
+
|
| 228 |
+
## global attention for the first num_frame_for_scale frames
|
| 229 |
+
if num_frame_for_scale > 0:
|
| 230 |
+
mask[:, :, :num_frame_for_scale * frame_seqlen, :num_frame_for_scale * frame_seqlen] = True
|
| 231 |
+
|
| 232 |
+
if self.fused_attn:
|
| 233 |
+
x = F.scaled_dot_product_attention(
|
| 234 |
+
q,
|
| 235 |
+
k,
|
| 236 |
+
v,
|
| 237 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 238 |
+
attn_mask=mask
|
| 239 |
+
)
|
| 240 |
+
else:
|
| 241 |
+
# Apply RoPE to current k before caching
|
| 242 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 243 |
+
|
| 244 |
+
if self.rope is not None and not enable_3d_rope:
|
| 245 |
+
q = self.rope(q, pos)
|
| 246 |
+
k = self.rope(k, pos)
|
| 247 |
+
elif enable_3d_rope and pos is not None:
|
| 248 |
+
q = apply_rotary_emb(q, pos)
|
| 249 |
+
k = apply_rotary_emb(k, pos)
|
| 250 |
+
|
| 251 |
+
# Check if we should skip appending to cache (non-keyframe in keyframe mode)
|
| 252 |
+
skip_append = kv_cache.get("_skip_append", False)
|
| 253 |
+
|
| 254 |
+
k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
| 255 |
+
v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
| 256 |
+
|
| 257 |
+
if not skip_append:
|
| 258 |
+
# KEYFRAME: store in cache (original behavior)
|
| 259 |
+
if kv_cache[f"k_{global_idx}"] is None:
|
| 260 |
+
kv_cache[f"k_{global_idx}"] = k_reshaped
|
| 261 |
+
kv_cache[f"v_{global_idx}"] = v_reshaped
|
| 262 |
+
else:
|
| 263 |
+
num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
|
| 264 |
+
k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
| 265 |
+
v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
| 266 |
+
kv_cache[f"k_{global_idx}"] = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
|
| 267 |
+
kv_cache[f"v_{global_idx}"] = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
|
| 268 |
+
|
| 269 |
+
# Apply sliding window eviction BEFORE attention to match causal_3drope behavior
|
| 270 |
+
# This ensures current frame only attends to frames within the sliding window
|
| 271 |
+
self._apply_kv_cache_eviction_causal(kv_cache, global_idx, camera_token_idx, scale_token_idx)
|
| 272 |
+
|
| 273 |
+
# Retrieve full k, v from cache (already RoPE-applied, already evicted)
|
| 274 |
+
k = kv_cache[f"k_{global_idx}"].clone()
|
| 275 |
+
v = kv_cache[f"v_{global_idx}"].clone()
|
| 276 |
+
else:
|
| 277 |
+
# NON-KEYFRAME: attend to [cached + current] without storing in cache
|
| 278 |
+
if kv_cache[f"k_{global_idx}"] is not None:
|
| 279 |
+
k = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
|
| 280 |
+
v = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
|
| 281 |
+
else:
|
| 282 |
+
k = k_reshaped
|
| 283 |
+
v = v_reshaped
|
| 284 |
+
a, b, c, d, e = k.shape
|
| 285 |
+
|
| 286 |
+
k = k.reshape(a, b, c*d, e)
|
| 287 |
+
v = v.reshape(a, b, c*d, e)
|
| 288 |
+
|
| 289 |
+
# Prepend special tokens (camera + scale) from evicted frames if they exist
|
| 290 |
+
if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
|
| 291 |
+
special_k = kv_cache[f"k_{global_idx}_special"] # [B, H, num_evicted_frames, 2, D]
|
| 292 |
+
special_v = kv_cache[f"v_{global_idx}_special"]
|
| 293 |
+
sa, sb, sc, sd, se = special_k.shape
|
| 294 |
+
special_k = special_k.reshape(sa, sb, sc * sd, se) # [B, H, num_evicted*2, D]
|
| 295 |
+
special_v = special_v.reshape(sa, sb, sc * sd, se)
|
| 296 |
+
|
| 297 |
+
# Prepend special tokens (older frames first)
|
| 298 |
+
k = torch.cat([special_k, k], dim=2)
|
| 299 |
+
v = torch.cat([special_v, v], dim=2)
|
| 300 |
+
|
| 301 |
+
# Note: k from cache is already RoPE-applied, no need to apply again
|
| 302 |
+
|
| 303 |
+
if self.fused_attn:
|
| 304 |
+
# Use mask-based SDPA to ensure same kernel as batch mode
|
| 305 |
+
# The causal constraint is enforced by KV cache contents, not by mask
|
| 306 |
+
mask = torch.ones(B, 1, q.shape[2], k.shape[2], dtype=torch.bool, device=q.device)
|
| 307 |
+
x = F.scaled_dot_product_attention(
|
| 308 |
+
q,
|
| 309 |
+
k,
|
| 310 |
+
v,
|
| 311 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 312 |
+
attn_mask=mask,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
if self.gate_proj is not None:
|
| 316 |
+
x = x * torch.sigmoid(gate_score)
|
| 317 |
+
if enable_ulysses_cp:
|
| 318 |
+
x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
|
| 319 |
+
# Use actual dimensions from attention output, not original input C
|
| 320 |
+
# x shape: [B, H, seq_len, head_dim] -> [B, seq_len, H*head_dim]
|
| 321 |
+
x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
|
| 322 |
+
x = self.proj(x)
|
| 323 |
+
x = self.proj_drop(x)
|
| 324 |
+
return x
|
| 325 |
+
|
| 326 |
+
def _apply_kv_cache_eviction_causal(self, kv_cache, global_idx, camera_token_idx, scale_token_idx):
|
| 327 |
+
"""
|
| 328 |
+
Apply sliding window eviction to KV cache BEFORE attention.
|
| 329 |
+
|
| 330 |
+
This ensures current frame only attends to frames within the sliding window,
|
| 331 |
+
matching the behavior of causal_3drope's attention mask.
|
| 332 |
+
"""
|
| 333 |
+
sliding_window_frames = self.kv_cache_sliding_window
|
| 334 |
+
scale_frames = self.kv_cache_scale_frames
|
| 335 |
+
|
| 336 |
+
if kv_cache[f"k_{global_idx}"].shape[3] > 1:
|
| 337 |
+
num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
|
| 338 |
+
|
| 339 |
+
if num_cached_frames > sliding_window_frames + scale_frames:
|
| 340 |
+
evict_start = scale_frames
|
| 341 |
+
evict_end = num_cached_frames - sliding_window_frames
|
| 342 |
+
|
| 343 |
+
if evict_end > evict_start:
|
| 344 |
+
evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
|
| 345 |
+
evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]
|
| 346 |
+
|
| 347 |
+
if self.kv_cache_cross_frame_special:
|
| 348 |
+
if self.kv_cache_camera_only:
|
| 349 |
+
# Only keep camera token
|
| 350 |
+
new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
|
| 351 |
+
new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
|
| 352 |
+
else:
|
| 353 |
+
# Keep ALL special tokens (camera + register + scale) to match attention_mask behavior
|
| 354 |
+
# Special tokens are in range [camera_token_idx, scale_token_idx+1)
|
| 355 |
+
new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
|
| 356 |
+
new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
|
| 357 |
+
|
| 358 |
+
if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
|
| 359 |
+
kv_cache[f"k_{global_idx}_special"] = new_special_k
|
| 360 |
+
kv_cache[f"v_{global_idx}_special"] = new_special_v
|
| 361 |
+
else:
|
| 362 |
+
kv_cache[f"k_{global_idx}_special"] = torch.cat(
|
| 363 |
+
[kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
|
| 364 |
+
kv_cache[f"v_{global_idx}_special"] = torch.cat(
|
| 365 |
+
[kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)
|
| 366 |
+
|
| 367 |
+
if self.kv_cache_include_scale_frames:
|
| 368 |
+
kv_cache[f"k_{global_idx}"] = torch.cat([
|
| 369 |
+
kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
|
| 370 |
+
kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
| 371 |
+
], dim=2)
|
| 372 |
+
kv_cache[f"v_{global_idx}"] = torch.cat([
|
| 373 |
+
kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
|
| 374 |
+
kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
| 375 |
+
], dim=2)
|
| 376 |
+
else:
|
| 377 |
+
kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
| 378 |
+
kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class FlashInferAttention(Attention):
|
| 382 |
+
"""
|
| 383 |
+
FlashInfer variant of the GCT attention layer.
|
| 384 |
+
Uses FlashInferKVCacheManager for paged KV cache storage and
|
| 385 |
+
FlashInfer attention kernels (BatchPrefillWithPagedKVCacheWrapper).
|
| 386 |
+
Supports the same optimized token layout and KV cache streaming inference.
|
| 387 |
+
"""
|
| 388 |
+
|
| 389 |
+
def __init__(
|
| 390 |
+
self,
|
| 391 |
+
dim: int,
|
| 392 |
+
num_heads: int = 8,
|
| 393 |
+
qkv_bias: bool = True,
|
| 394 |
+
proj_bias: bool = True,
|
| 395 |
+
attn_drop: float = 0.0,
|
| 396 |
+
proj_drop: float = 0.0,
|
| 397 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 398 |
+
qk_norm: bool = False,
|
| 399 |
+
fused_attn: bool = True,
|
| 400 |
+
rope=None,
|
| 401 |
+
# KV cache eviction parameters
|
| 402 |
+
kv_cache_sliding_window: int = 64,
|
| 403 |
+
kv_cache_scale_frames: int = 8,
|
| 404 |
+
kv_cache_cross_frame_special: bool = True,
|
| 405 |
+
kv_cache_include_scale_frames: bool = True,
|
| 406 |
+
kv_cache_camera_only: bool = False,
|
| 407 |
+
) -> None:
|
| 408 |
+
if not FLASHINFER_AVAILABLE:
|
| 409 |
+
raise RuntimeError("FlashInfer is not available. Please install flashinfer.")
|
| 410 |
+
|
| 411 |
+
super().__init__(
|
| 412 |
+
dim=dim,
|
| 413 |
+
num_heads=num_heads,
|
| 414 |
+
qkv_bias=qkv_bias,
|
| 415 |
+
proj_bias=proj_bias,
|
| 416 |
+
attn_drop=attn_drop,
|
| 417 |
+
proj_drop=proj_drop,
|
| 418 |
+
norm_layer=norm_layer,
|
| 419 |
+
qk_norm=qk_norm,
|
| 420 |
+
fused_attn=fused_attn,
|
| 421 |
+
rope=rope,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Store KV cache eviction parameters
|
| 425 |
+
self.kv_cache_sliding_window = kv_cache_sliding_window
|
| 426 |
+
self.kv_cache_scale_frames = kv_cache_scale_frames
|
| 427 |
+
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
|
| 428 |
+
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
| 429 |
+
self.kv_cache_camera_only = kv_cache_camera_only
|
| 430 |
+
|
| 431 |
+
def prepare_qkv(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
|
| 432 |
+
"""Fused pre-attention ops for single-frame streaming (Phase 2).
|
| 433 |
+
|
| 434 |
+
Computes q/k/v from x, applies q_norm/k_norm/RoPE, and converts to
|
| 435 |
+
[tpf, H, D] format ready for append_frame + compute_attention.
|
| 436 |
+
|
| 437 |
+
Extracted as a method so torch.compile can capture all pre-attn ops as one
|
| 438 |
+
CUDA graph (qkv linear -> reshape -> unbind -> q_norm -> k_norm -> RoPE x2 ->
|
| 439 |
+
squeeze/permute/contiguous x3).
|
| 440 |
+
"""
|
| 441 |
+
B, N, C = x.shape
|
| 442 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 443 |
+
q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim]
|
| 444 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 445 |
+
if self.rope is not None and not enable_3d_rope:
|
| 446 |
+
q = self.rope(q, pos)
|
| 447 |
+
k = self.rope(k, pos)
|
| 448 |
+
elif self.rope is not None: # enable_3d_rope=True
|
| 449 |
+
q = apply_rotary_emb(q, pos)
|
| 450 |
+
k = apply_rotary_emb(k, pos)
|
| 451 |
+
# Convert to [tpf, H, D] format for FlashInfer (B=1 in streaming mode)
|
| 452 |
+
q_nhd = q.squeeze(0).permute(1, 0, 2).contiguous()
|
| 453 |
+
k_nhd = k.squeeze(0).permute(1, 0, 2).contiguous()
|
| 454 |
+
v_nhd = v.squeeze(0).permute(1, 0, 2).contiguous()
|
| 455 |
+
return q_nhd, k_nhd, v_nhd
|
| 456 |
+
|
| 457 |
+
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
|
| 458 |
+
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
|
| 459 |
+
# KV cache parameters (kv_cache is a FlashInferKVCacheManager or None)
|
| 460 |
+
kv_cache=None, global_idx=0, num_frame_per_block=1,
|
| 461 |
+
num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
|
| 462 |
+
"""
|
| 463 |
+
Forward pass with FlashInfer paged KV cache and attention.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
x: Input tensor [B, N, C]
|
| 467 |
+
kv_cache: FlashInferKVCacheManager instance or None (batch mode)
|
| 468 |
+
global_idx: Block index for per-block cache access
|
| 469 |
+
"""
|
| 470 |
+
from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager
|
| 471 |
+
|
| 472 |
+
B, N, C = x.shape
|
| 473 |
+
|
| 474 |
+
# Detect if using optimized layout
|
| 475 |
+
using_optimized_layout = (num_patches is not None and num_special is not None
|
| 476 |
+
and num_frames is not None)
|
| 477 |
+
|
| 478 |
+
# ========== Batch Mode (no KV cache manager) ==========
|
| 479 |
+
if not isinstance(kv_cache, FlashInferKVCacheManager):
|
| 480 |
+
# [3, B, num_heads, N, head_dim]
|
| 481 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 482 |
+
q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim]
|
| 483 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 484 |
+
|
| 485 |
+
if enable_ulysses_cp:
|
| 486 |
+
if using_optimized_layout:
|
| 487 |
+
boundary = num_frames * num_patches
|
| 488 |
+
q_patch, k_patch, v_patch = q[:, :, :boundary, :], k[:, :, :boundary, :], v[:, :, :boundary, :]
|
| 489 |
+
q_special, k_special, v_special = q[:, :, boundary:, :], k[:, :, boundary:, :], v[:, :, boundary:, :]
|
| 490 |
+
q_patch, k_patch, v_patch = gather_seq_scatter_heads_qkv(
|
| 491 |
+
q_patch, k_patch, v_patch, seq_dim=2, head_dim=1
|
| 492 |
+
)
|
| 493 |
+
q_special, k_special, v_special = gather_seq_scatter_heads_qkv(
|
| 494 |
+
q_special, k_special, v_special, seq_dim=2, head_dim=1
|
| 495 |
+
)
|
| 496 |
+
q = torch.cat([q_patch, q_special], dim=2)
|
| 497 |
+
k = torch.cat([k_patch, k_special], dim=2)
|
| 498 |
+
v = torch.cat([v_patch, v_special], dim=2)
|
| 499 |
+
else:
|
| 500 |
+
q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
|
| 501 |
+
|
| 502 |
+
if self.rope is not None and not enable_3d_rope:
|
| 503 |
+
q = self.rope(q, pos)
|
| 504 |
+
k = self.rope(k, pos)
|
| 505 |
+
elif self.rope is not None and enable_3d_rope:
|
| 506 |
+
q = apply_rotary_emb(q, pos)
|
| 507 |
+
k = apply_rotary_emb(k, pos)
|
| 508 |
+
|
| 509 |
+
# Batch mode: use SDPA for numerical consistency with SDPA variant
|
| 510 |
+
x = F.scaled_dot_product_attention(
|
| 511 |
+
q, k, v,
|
| 512 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
if enable_ulysses_cp:
|
| 516 |
+
if using_optimized_layout:
|
| 517 |
+
seq_global = x.shape[2]
|
| 518 |
+
seq_local = num_frames * (num_patches + num_special)
|
| 519 |
+
cp_size = seq_global // seq_local
|
| 520 |
+
boundary_global = num_frames * cp_size * num_patches
|
| 521 |
+
x_patch = x[:, :, :boundary_global, :]
|
| 522 |
+
x_special = x[:, :, boundary_global:, :]
|
| 523 |
+
x_patch = gather_heads_scatter_seq(x_patch, seq_dim=2, head_dim=1)
|
| 524 |
+
x_special = gather_heads_scatter_seq(x_special, seq_dim=2, head_dim=1)
|
| 525 |
+
x = torch.cat([x_patch, x_special], dim=2)
|
| 526 |
+
else:
|
| 527 |
+
x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
|
| 528 |
+
|
| 529 |
+
x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
|
| 530 |
+
|
| 531 |
+
# ========== Streaming Mode (with FlashInferKVCacheManager) ==========
|
| 532 |
+
else:
|
| 533 |
+
manager = kv_cache # FlashInferKVCacheManager
|
| 534 |
+
|
| 535 |
+
# Phase 1 (scale frames): num_frames > 1 — multi-frame batch
|
| 536 |
+
# Phase 2 (streaming): num_frames == 1 — single frame
|
| 537 |
+
is_multi_frame = (num_frames is not None and num_frames > 1)
|
| 538 |
+
|
| 539 |
+
if is_multi_frame:
|
| 540 |
+
# Phase 1: compute full self-attention via SDPA (all frames attend to each other),
|
| 541 |
+
# then append each frame's K/V to the paged cache one at a time.
|
| 542 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 543 |
+
q, k, v = qkv.unbind(0)
|
| 544 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 545 |
+
|
| 546 |
+
# Apply RoPE before caching (RoPE baked into K before append)
|
| 547 |
+
if self.rope is not None and not enable_3d_rope:
|
| 548 |
+
q = self.rope(q, pos)
|
| 549 |
+
k = self.rope(k, pos)
|
| 550 |
+
elif self.rope is not None and enable_3d_rope:
|
| 551 |
+
q = apply_rotary_emb(q, pos)
|
| 552 |
+
k = apply_rotary_emb(k, pos)
|
| 553 |
+
|
| 554 |
+
x = F.scaled_dot_product_attention(
|
| 555 |
+
q, k, v,
|
| 556 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 557 |
+
)
|
| 558 |
+
x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
|
| 559 |
+
|
| 560 |
+
# Append each frame's K/V to the paged cache individually.
|
| 561 |
+
tpf = manager.tokens_per_frame
|
| 562 |
+
k_all = k.squeeze(0).permute(1, 0, 2) # [num_frames*tpf, H, D]
|
| 563 |
+
v_all = v.squeeze(0).permute(1, 0, 2)
|
| 564 |
+
for f_idx in range(num_frames):
|
| 565 |
+
s = f_idx * tpf
|
| 566 |
+
manager.append_frame(global_idx, k_all[s:s+tpf].contiguous(), v_all[s:s+tpf].contiguous())
|
| 567 |
+
manager.evict_frames(
|
| 568 |
+
block_idx=global_idx,
|
| 569 |
+
scale_frames=self.kv_cache_scale_frames,
|
| 570 |
+
sliding_window=self.kv_cache_sliding_window,
|
| 571 |
+
cross_frame_special=self.kv_cache_cross_frame_special,
|
| 572 |
+
include_scale_frames=self.kv_cache_include_scale_frames,
|
| 573 |
+
camera_only=self.kv_cache_camera_only,
|
| 574 |
+
num_register_tokens=num_register_tokens,
|
| 575 |
+
)
|
| 576 |
+
else:
|
| 577 |
+
# Phase 2: single-frame streaming via FlashInfer paged attention.
|
| 578 |
+
q_nhd, k_nhd, v_nhd = self.prepare_qkv(x, pos=pos, enable_3d_rope=enable_3d_rope)
|
| 579 |
+
|
| 580 |
+
# 1. Append to paged cache
|
| 581 |
+
manager.append_frame(global_idx, k_nhd, v_nhd)
|
| 582 |
+
|
| 583 |
+
# 2. Apply sliding window eviction
|
| 584 |
+
manager.evict_frames(
|
| 585 |
+
block_idx=global_idx,
|
| 586 |
+
scale_frames=self.kv_cache_scale_frames,
|
| 587 |
+
sliding_window=self.kv_cache_sliding_window,
|
| 588 |
+
cross_frame_special=self.kv_cache_cross_frame_special,
|
| 589 |
+
include_scale_frames=self.kv_cache_include_scale_frames,
|
| 590 |
+
camera_only=self.kv_cache_camera_only,
|
| 591 |
+
num_register_tokens=num_register_tokens,
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
# 3. Compute attention via FlashInfer BatchPrefillWithPagedKVCacheWrapper
|
| 595 |
+
x = manager.compute_attention(global_idx, q_nhd)
|
| 596 |
+
|
| 597 |
+
# Convert back: [tpf, H, D] -> [B, tpf, C].
|
| 598 |
+
x = x.reshape(B, q_nhd.shape[0], self.num_heads * self.head_dim)
|
| 599 |
+
|
| 600 |
+
x = self.proj(x)
|
| 601 |
+
x = self.proj_drop(x)
|
| 602 |
+
return x
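A condensed sketch of the two call patterns this forward supports; manager, attn_layer, layer_idx and frame_stream are illustrative names, and the manager is assumed to be a FlashInferKVCacheManager built elsewhere:
# Phase 1 (scale frames): all frames in one call; K/V are appended to the paged cache per frame.
# x_scale: [1, num_frames * tokens_per_frame, C]
out = attn_layer(x_scale, pos=pos, kv_cache=manager, global_idx=layer_idx,
                 num_frames=8, num_register_tokens=4)

# Phase 2 (streaming): one frame at a time; append -> evict -> FlashInfer paged attention.
for x_frame, pos_frame in frame_stream:        # each x_frame: [1, tokens_per_frame, C]
    out = attn_layer(x_frame, pos=pos_frame, kv_cache=manager,
                     global_idx=layer_idx, num_frames=1, num_register_tokens=4)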
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
class SDPAAttention(Attention):
|
| 606 |
+
"""
|
| 607 |
+
SDPA variant for streaming inference.
|
| 608 |
+
Uses F.scaled_dot_product_attention with dict-based KV cache.
|
| 609 |
+
No FlashInfer dependency required — works on any CUDA GPU.
|
| 610 |
+
"""
|
| 611 |
+
|
| 612 |
+
def __init__(
|
| 613 |
+
self,
|
| 614 |
+
dim: int,
|
| 615 |
+
num_heads: int = 8,
|
| 616 |
+
qkv_bias: bool = True,
|
| 617 |
+
proj_bias: bool = True,
|
| 618 |
+
attn_drop: float = 0.0,
|
| 619 |
+
proj_drop: float = 0.0,
|
| 620 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 621 |
+
qk_norm: bool = False,
|
| 622 |
+
fused_attn: bool = True,
|
| 623 |
+
rope=None,
|
| 624 |
+
kv_cache_sliding_window: int = 64,
|
| 625 |
+
kv_cache_scale_frames: int = 8,
|
| 626 |
+
kv_cache_cross_frame_special: bool = True,
|
| 627 |
+
kv_cache_include_scale_frames: bool = True,
|
| 628 |
+
kv_cache_camera_only: bool = False,
|
| 629 |
+
) -> None:
|
| 630 |
+
super().__init__(
|
| 631 |
+
dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
|
| 632 |
+
attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer,
|
| 633 |
+
qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
|
| 634 |
+
)
|
| 635 |
+
self.kv_cache_sliding_window = kv_cache_sliding_window
|
| 636 |
+
self.kv_cache_scale_frames = kv_cache_scale_frames
|
| 637 |
+
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
|
| 638 |
+
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
| 639 |
+
self.kv_cache_camera_only = kv_cache_camera_only
|
| 640 |
+
|
| 641 |
+
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
|
| 642 |
+
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
|
| 643 |
+
kv_cache=None, global_idx=0, num_frame_per_block=1,
|
| 644 |
+
num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
|
| 645 |
+
B, N, C = x.shape
|
| 646 |
+
using_optimized_layout = (num_patches is not None and num_special is not None
|
| 647 |
+
and num_frames is not None)
|
| 648 |
+
|
| 649 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 650 |
+
q, k, v = qkv.unbind(0)
|
| 651 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 652 |
+
|
| 653 |
+
# ========== Batch Mode (no KV cache) ==========
|
| 654 |
+
if kv_cache is None:
|
| 655 |
+
if self.rope is not None and not enable_3d_rope:
|
| 656 |
+
q = self.rope(q, pos)
|
| 657 |
+
k = self.rope(k, pos)
|
| 658 |
+
elif self.rope is not None and enable_3d_rope:
|
| 659 |
+
q = apply_rotary_emb(q, pos)
|
| 660 |
+
k = apply_rotary_emb(k, pos)
|
| 661 |
+
|
| 662 |
+
x = F.scaled_dot_product_attention(
|
| 663 |
+
q, k, v,
|
| 664 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 665 |
+
)
|
| 666 |
+
x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
|
| 667 |
+
|
| 668 |
+
# ========== Streaming Mode (with KV cache dict) ==========
|
| 669 |
+
else:
|
| 670 |
+
if self.rope is not None and not enable_3d_rope:
|
| 671 |
+
q = self.rope(q, pos)
|
| 672 |
+
k = self.rope(k, pos)
|
| 673 |
+
elif self.rope is not None and enable_3d_rope:
|
| 674 |
+
q = apply_rotary_emb(q, pos)
|
| 675 |
+
k = apply_rotary_emb(k, pos)
|
| 676 |
+
|
| 677 |
+
camera_token_idx = 0
|
| 678 |
+
scale_token_idx = camera_token_idx + num_register_tokens + 1
|
| 679 |
+
|
| 680 |
+
if kv_cache[f"k_{global_idx}"] is None:
|
| 681 |
+
kv_cache[f"k_{global_idx}"] = k.view(B, self.num_heads, num_frame_per_block,
|
| 682 |
+
N // num_frame_per_block, self.head_dim)
|
| 683 |
+
kv_cache[f"v_{global_idx}"] = v.view(B, self.num_heads, num_frame_per_block,
|
| 684 |
+
N // num_frame_per_block, self.head_dim)
|
| 685 |
+
else:
|
| 686 |
+
num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
|
| 687 |
+
kv_cache[f"k_{global_idx}"] = torch.cat((
|
| 688 |
+
kv_cache[f"k_{global_idx}"],
|
| 689 |
+
k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
| 690 |
+
), dim=2)
|
| 691 |
+
kv_cache[f"v_{global_idx}"] = torch.cat((
|
| 692 |
+
kv_cache[f"v_{global_idx}"],
|
| 693 |
+
v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
| 694 |
+
), dim=2)
|
| 695 |
+
|
| 696 |
+
self._apply_kv_cache_eviction(
|
| 697 |
+
kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
k_cached = kv_cache[f"k_{global_idx}"].clone()
|
| 701 |
+
v_cached = kv_cache[f"v_{global_idx}"].clone()
|
| 702 |
+
a, b, c, d, e = k_cached.shape
|
| 703 |
+
k_full = k_cached.reshape(a, b, c * d, e)
|
| 704 |
+
v_full = v_cached.reshape(a, b, c * d, e)
|
| 705 |
+
|
| 706 |
+
if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
|
| 707 |
+
special_k = kv_cache[f"k_{global_idx}_special"]
|
| 708 |
+
special_v = kv_cache[f"v_{global_idx}_special"]
|
| 709 |
+
sa, sb, sc, sd, se = special_k.shape
|
| 710 |
+
k_full = torch.cat([special_k.reshape(sa, sb, sc * sd, se), k_full], dim=2)
|
| 711 |
+
v_full = torch.cat([special_v.reshape(sa, sb, sc * sd, se), v_full], dim=2)
|
| 712 |
+
|
| 713 |
+
q_seq_len = q.shape[2]
|
| 714 |
+
x = F.scaled_dot_product_attention(
|
| 715 |
+
q, k_full, v_full,
|
| 716 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 717 |
+
)
|
| 718 |
+
x = x.transpose(1, 2).reshape(B, q_seq_len, self.num_heads * self.head_dim)
|
| 719 |
+
|
| 720 |
+
x = self.proj(x)
|
| 721 |
+
x = self.proj_drop(x)
|
| 722 |
+
return x
|
| 723 |
+
|
| 724 |
+
def _apply_kv_cache_eviction(self, kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens):
|
| 725 |
+
"""Apply sliding window eviction to KV cache."""
|
| 726 |
+
sliding_window_frames = self.kv_cache_sliding_window
|
| 727 |
+
scale_frames = self.kv_cache_scale_frames
|
| 728 |
+
|
| 729 |
+
if kv_cache[f"k_{global_idx}"].shape[3] > 1:
|
| 730 |
+
num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
|
| 731 |
+
if num_cached_frames > sliding_window_frames + scale_frames:
|
| 732 |
+
evict_start = scale_frames
|
| 733 |
+
evict_end = num_cached_frames - sliding_window_frames
|
| 734 |
+
if evict_end > evict_start:
|
| 735 |
+
evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
|
| 736 |
+
evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]
|
| 737 |
+
|
| 738 |
+
if self.kv_cache_cross_frame_special:
|
| 739 |
+
if self.kv_cache_camera_only:
|
| 740 |
+
new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
|
| 741 |
+
new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
|
| 742 |
+
else:
|
| 743 |
+
new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
|
| 744 |
+
new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
|
| 745 |
+
|
| 746 |
+
if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
|
| 747 |
+
kv_cache[f"k_{global_idx}_special"] = new_special_k
|
| 748 |
+
kv_cache[f"v_{global_idx}_special"] = new_special_v
|
| 749 |
+
else:
|
| 750 |
+
kv_cache[f"k_{global_idx}_special"] = torch.cat(
|
| 751 |
+
[kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
|
| 752 |
+
kv_cache[f"v_{global_idx}_special"] = torch.cat(
|
| 753 |
+
[kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)
|
| 754 |
+
|
| 755 |
+
if self.kv_cache_include_scale_frames:
|
| 756 |
+
kv_cache[f"k_{global_idx}"] = torch.cat([
|
| 757 |
+
kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
|
| 758 |
+
kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
| 759 |
+
], dim=2)
|
| 760 |
+
kv_cache[f"v_{global_idx}"] = torch.cat([
|
| 761 |
+
kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
|
| 762 |
+
kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
| 763 |
+
], dim=2)
|
| 764 |
+
else:
|
| 765 |
+
kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
| 766 |
+
kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
lingbot_map/layers/block.py
ADDED
|
@@ -0,0 +1,514 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn, Tensor
|
| 18 |
+
|
| 19 |
+
from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention
|
| 20 |
+
from functools import lru_cache, partial
|
| 21 |
+
from torch.nn.attention.flex_attention import BlockMask, create_mask
|
| 22 |
+
from .drop_path import DropPath
|
| 23 |
+
from .layer_scale import LayerScale
|
| 24 |
+
from .mlp import Mlp
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Block(nn.Module):
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
dim: int,
|
| 31 |
+
num_heads: int,
|
| 32 |
+
mlp_ratio: float = 4.0,
|
| 33 |
+
qkv_bias: bool = True,
|
| 34 |
+
proj_bias: bool = True,
|
| 35 |
+
ffn_bias: bool = True,
|
| 36 |
+
drop: float = 0.0,
|
| 37 |
+
attn_drop: float = 0.0,
|
| 38 |
+
init_values=None,
|
| 39 |
+
drop_path: float = 0.0,
|
| 40 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 41 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 42 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 43 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 44 |
+
qk_norm: bool = False,
|
| 45 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
| 46 |
+
rope=None,
|
| 47 |
+
) -> None:
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
self.norm1 = norm_layer(dim)
|
| 51 |
+
|
| 52 |
+
self.attn = attn_class(
|
| 53 |
+
dim,
|
| 54 |
+
num_heads=num_heads,
|
| 55 |
+
qkv_bias=qkv_bias,
|
| 56 |
+
proj_bias=proj_bias,
|
| 57 |
+
attn_drop=attn_drop,
|
| 58 |
+
proj_drop=drop,
|
| 59 |
+
qk_norm=qk_norm,
|
| 60 |
+
fused_attn=fused_attn,
|
| 61 |
+
rope=rope,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 65 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 66 |
+
|
| 67 |
+
self.norm2 = norm_layer(dim)
|
| 68 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 69 |
+
self.mlp = ffn_layer(
|
| 70 |
+
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
|
| 71 |
+
)
|
| 72 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 73 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 74 |
+
|
| 75 |
+
self.sample_drop_ratio = drop_path
|
| 76 |
+
|
| 77 |
+
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
|
| 78 |
+
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
|
| 79 |
+
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
|
| 80 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
|
| 81 |
+
num_patches=num_patches, num_special=num_special, num_frames=num_frames,
|
| 82 |
+
enable_3d_rope=enable_3d_rope))
|
| 83 |
+
|
| 84 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 85 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 86 |
+
|
| 87 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 88 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 89 |
+
x = drop_add_residual_stochastic_depth(
|
| 90 |
+
x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
|
| 91 |
+
)
|
| 92 |
+
x = drop_add_residual_stochastic_depth(
|
| 93 |
+
x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
|
| 94 |
+
)
|
| 95 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 96 |
+
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
|
| 97 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 98 |
+
else:
|
| 99 |
+
x = x + attn_residual_func(x, pos=pos)
|
| 100 |
+
x = x + ffn_residual_func(x)
|
| 101 |
+
return x
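A minimal sketch of how these blocks compose into an encoder (sizes are illustrative; RoPE is omitted and the drop-path branches are left at their defaults):
import torch
from torch import nn
from lingbot_map.layers.block import Block

blocks = nn.ModuleList([Block(dim=768, num_heads=12, init_values=1e-5) for _ in range(4)])
x = torch.randn(1, 196, 768)                   # [B, N, C]
for blk in blocks:
    x = blk(x)                                  # pre-norm attention + MLP residuals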
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def drop_add_residual_stochastic_depth(
|
| 105 |
+
x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
|
| 106 |
+
) -> Tensor:
|
| 107 |
+
# 1) extract subset using permutation
|
| 108 |
+
b, n, d = x.shape
|
| 109 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 110 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 111 |
+
x_subset = x[brange]
|
| 112 |
+
|
| 113 |
+
# 2) apply residual_func to get residual
|
| 114 |
+
if pos is not None:
|
| 115 |
+
# if necessary, apply rope to the subset
|
| 116 |
+
pos = pos[brange]
|
| 117 |
+
residual = residual_func(x_subset, pos=pos)
|
| 118 |
+
else:
|
| 119 |
+
residual = residual_func(x_subset)
|
| 120 |
+
|
| 121 |
+
x_flat = x.flatten(1)
|
| 122 |
+
residual = residual.flatten(1)
|
| 123 |
+
|
| 124 |
+
residual_scale_factor = b / sample_subset_size
|
| 125 |
+
|
| 126 |
+
# 3) add the residual
|
| 127 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 128 |
+
return x_plus_residual.view_as(x)
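A short numeric illustration of why the residual is rescaled (the batch size and drop ratio below are hypothetical):
# With b = 8 samples and sample_drop_ratio = 0.25, only 6 samples are kept:
#   sample_subset_size = max(int(8 * 0.75), 1) = 6
#   residual_scale_factor = 8 / 6 ≈ 1.33
# torch.index_add then writes the residual back only for the 6 sampled rows, scaled by
# 8/6, so the expected residual per sample matches the full (non-dropped) computation.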
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 132 |
+
b, n, d = x.shape
|
| 133 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 134 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 135 |
+
residual_scale_factor = b / sample_subset_size
|
| 136 |
+
return brange, residual_scale_factor
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
| 140 |
+
if scaling_vector is None:
|
| 141 |
+
x_flat = x.flatten(1)
|
| 142 |
+
residual = residual.flatten(1)
|
| 143 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 144 |
+
else:
|
| 145 |
+
x_plus_residual = scaled_index_add(
|
| 146 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
| 147 |
+
)
|
| 148 |
+
return x_plus_residual
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class FlashInferBlock(nn.Module):
|
| 152 |
+
"""
|
| 153 |
+
FlashInfer variant of causal block for GCT.
|
| 154 |
+
Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels).
|
| 155 |
+
Supports optimized token layout and KV cache streaming inference.
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
def __init__(
|
| 159 |
+
self,
|
| 160 |
+
dim: int,
|
| 161 |
+
num_heads: int,
|
| 162 |
+
mlp_ratio: float = 4.0,
|
| 163 |
+
qkv_bias: bool = True,
|
| 164 |
+
proj_bias: bool = True,
|
| 165 |
+
ffn_bias: bool = True,
|
| 166 |
+
drop: float = 0.0,
|
| 167 |
+
attn_drop: float = 0.0,
|
| 168 |
+
init_values=None,
|
| 169 |
+
drop_path: float = 0.0,
|
| 170 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 171 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 172 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 173 |
+
qk_norm: bool = False,
|
| 174 |
+
rope=None,
|
| 175 |
+
kv_cache_sliding_window: int = 64,
|
| 176 |
+
kv_cache_scale_frames: int = 8,
|
| 177 |
+
kv_cache_cross_frame_special: bool = True,
|
| 178 |
+
kv_cache_include_scale_frames: bool = True,
|
| 179 |
+
kv_cache_camera_only: bool = False,
|
| 180 |
+
) -> None:
|
| 181 |
+
super().__init__()
|
| 182 |
+
|
| 183 |
+
self.norm1 = norm_layer(dim)
|
| 184 |
+
self.attn = FlashInferAttention(
|
| 185 |
+
dim=dim,
|
| 186 |
+
num_heads=num_heads,
|
| 187 |
+
qk_norm=qk_norm,
|
| 188 |
+
qkv_bias=qkv_bias,
|
| 189 |
+
proj_bias=proj_bias,
|
| 190 |
+
attn_drop=attn_drop,
|
| 191 |
+
proj_drop=drop,
|
| 192 |
+
rope=rope,
|
| 193 |
+
kv_cache_sliding_window=kv_cache_sliding_window,
|
| 194 |
+
kv_cache_scale_frames=kv_cache_scale_frames,
|
| 195 |
+
kv_cache_cross_frame_special=kv_cache_cross_frame_special,
|
| 196 |
+
kv_cache_include_scale_frames=kv_cache_include_scale_frames,
|
| 197 |
+
kv_cache_camera_only=kv_cache_camera_only,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 201 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 202 |
+
|
| 203 |
+
self.norm2 = norm_layer(dim)
|
| 204 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 205 |
+
self.mlp = ffn_layer(
|
| 206 |
+
in_features=dim,
|
| 207 |
+
hidden_features=mlp_hidden_dim,
|
| 208 |
+
act_layer=act_layer,
|
| 209 |
+
drop=drop,
|
| 210 |
+
bias=ffn_bias
|
| 211 |
+
)
|
| 212 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 213 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 214 |
+
|
| 215 |
+
self.sample_drop_ratio = drop_path
|
| 216 |
+
|
| 217 |
+
def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
|
| 218 |
+
"""Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit.
|
| 219 |
+
|
| 220 |
+
Extracted as a named method so torch.compile can capture norm1 + qkv-linear +
|
| 221 |
+
reshape + q_norm + k_norm + RoPE + format as a single CUDA graph.
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
(q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim],
|
| 225 |
+
ready for manager.append_frame + manager.compute_attention.
|
| 226 |
+
"""
|
| 227 |
+
return self.attn.prepare_qkv(self.norm1(x), pos=pos, enable_3d_rope=enable_3d_rope)
|
| 228 |
+
|
| 229 |
+
def forward(
|
| 230 |
+
self,
|
| 231 |
+
x: Tensor,
|
| 232 |
+
pos=None,
|
| 233 |
+
enable_ulysses_cp=False,
|
| 234 |
+
num_patches=None,
|
| 235 |
+
num_special=None,
|
| 236 |
+
num_frames=None,
|
| 237 |
+
enable_3d_rope=False,
|
| 238 |
+
kv_cache=None,
|
| 239 |
+
global_idx=0,
|
| 240 |
+
num_frame_per_block=1,
|
| 241 |
+
num_frame_for_scale=-1,
|
| 242 |
+
num_register_tokens=4,
|
| 243 |
+
) -> Tensor:
|
| 244 |
+
# Phase 2 (streaming): single-frame FlashInfer paged attention.
|
| 245 |
+
# Handle inline so attn_pre (norm1+prepare_qkv) can be compiled as one CUDA graph.
|
| 246 |
+
is_streaming = (kv_cache is not None and (num_frames is None or num_frames <= 1))
|
| 247 |
+
if is_streaming:
|
| 248 |
+
manager = kv_cache
|
| 249 |
+
# Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format
|
| 250 |
+
q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope)
|
| 251 |
+
# Eager: write frame K/V to paged cache
|
| 252 |
+
manager.append_frame(global_idx, k_nhd, v_nhd)
|
| 253 |
+
# CPU-only: update eviction state (deque ops, no GPU kernel)
|
| 254 |
+
manager.evict_frames(
|
| 255 |
+
block_idx=global_idx,
|
| 256 |
+
scale_frames=self.attn.kv_cache_scale_frames,
|
| 257 |
+
sliding_window=self.attn.kv_cache_sliding_window,
|
| 258 |
+
cross_frame_special=self.attn.kv_cache_cross_frame_special,
|
| 259 |
+
include_scale_frames=self.attn.kv_cache_include_scale_frames,
|
| 260 |
+
camera_only=self.attn.kv_cache_camera_only,
|
| 261 |
+
num_register_tokens=num_register_tokens,
|
| 262 |
+
)
|
| 263 |
+
# Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper
|
| 264 |
+
attn_x = manager.compute_attention(global_idx, q_nhd)
|
| 265 |
+
# [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output)
|
| 266 |
+
attn_x = attn_x.reshape(x.shape[0], q_nhd.shape[0],
|
| 267 |
+
self.attn.num_heads * self.attn.head_dim)
|
| 268 |
+
# Compiled: output projection
|
| 269 |
+
attn_x = self.attn.proj(attn_x)
|
| 270 |
+
x = x + self.ls1(attn_x)
|
| 271 |
+
else:
|
| 272 |
+
# Phase 1 (multi-frame scale pass) or non-streaming training path
|
| 273 |
+
x = x + self.ls1(self.attn(
|
| 274 |
+
self.norm1(x),
|
| 275 |
+
pos=pos,
|
| 276 |
+
enable_ulysses_cp=enable_ulysses_cp,
|
| 277 |
+
num_patches=num_patches,
|
| 278 |
+
num_special=num_special,
|
| 279 |
+
num_frames=num_frames,
|
| 280 |
+
enable_3d_rope=enable_3d_rope,
|
| 281 |
+
kv_cache=kv_cache,
|
| 282 |
+
global_idx=global_idx,
|
| 283 |
+
num_frame_per_block=num_frame_per_block,
|
| 284 |
+
num_frame_for_scale=num_frame_for_scale,
|
| 285 |
+
num_register_tokens=num_register_tokens,
|
| 286 |
+
))
|
| 287 |
+
x = self.ffn_residual(x)
|
| 288 |
+
return x
|
| 289 |
+
|
| 290 |
+
def ffn_residual(self, x: Tensor) -> Tensor:
|
| 291 |
+
"""FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in.
|
| 292 |
+
|
| 293 |
+
Includes the residual add (x + ...) so torch.compile captures the entire
|
| 294 |
+
ffn branch as one CUDA graph.
|
| 295 |
+
"""
|
| 296 |
+
return x + self.ls2(self.mlp(self.norm2(x)))
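A compact sketch of how one of these blocks is driven per frame in Phase 2 streaming; layers, manager, layer_idx and frame_stream are illustrative names, with the manager assumed to be a FlashInferKVCacheManager shared across calls:
# Each frame passes through the stack; every block appends its K/V to the paged cache,
# applies eviction, and attends to [special tokens + scale frames + sliding window].
for x_frame, pos_frame in frame_stream:         # x_frame: [1, tokens_per_frame, C]
    h = x_frame
    for layer_idx, blk in enumerate(layers):    # layers: list of FlashInferBlock
        h = blk(h, pos=pos_frame, kv_cache=manager, global_idx=layer_idx,
                num_frames=1, num_register_tokens=4)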
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class CameraBlock(nn.Module):
|
| 300 |
+
def __init__(
|
| 301 |
+
self,
|
| 302 |
+
dim: int,
|
| 303 |
+
num_heads: int,
|
| 304 |
+
mlp_ratio: float = 4.0,
|
| 305 |
+
qkv_bias: bool = True,
|
| 306 |
+
proj_bias: bool = True,
|
| 307 |
+
ffn_bias: bool = True,
|
| 308 |
+
drop: float = 0.0,
|
| 309 |
+
attn_drop: float = 0.0,
|
| 310 |
+
init_values=None,
|
| 311 |
+
drop_path: float = 0.0,
|
| 312 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 313 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 314 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 315 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 316 |
+
qk_norm: bool = False,
|
| 317 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
| 318 |
+
rope=None,
|
| 319 |
+
elementwise_attn_output_gate: bool = False,
|
| 320 |
+
sliding_window_size: int = -1,
|
| 321 |
+
attend_to_scale_frames: bool = False,
|
| 322 |
+
num_random_frames: int = 0,
|
| 323 |
+
# KV cache parameters
|
| 324 |
+
kv_cache_sliding_window: int = 64,
|
| 325 |
+
kv_cache_scale_frames: int = 8,
|
| 326 |
+
kv_cache_cross_frame_special: bool = True,
|
| 327 |
+
kv_cache_include_scale_frames: bool = True,
|
| 328 |
+
kv_cache_camera_only: bool = False,
|
| 329 |
+
) -> None:
|
| 330 |
+
super().__init__()
|
| 331 |
+
|
| 332 |
+
self.norm1 = norm_layer(dim)
|
| 333 |
+
self.attn = CausalAttention(dim=dim, num_heads=num_heads,
|
| 334 |
+
qk_norm=qk_norm, qkv_bias=qkv_bias,
|
| 335 |
+
rope=rope, elementwise_attn_output_gate=elementwise_attn_output_gate,
|
| 336 |
+
kv_cache_sliding_window=kv_cache_sliding_window,
|
| 337 |
+
kv_cache_scale_frames=kv_cache_scale_frames,
|
| 338 |
+
kv_cache_cross_frame_special=kv_cache_cross_frame_special,
|
| 339 |
+
kv_cache_include_scale_frames=kv_cache_include_scale_frames,
|
| 340 |
+
kv_cache_camera_only=kv_cache_camera_only)
|
| 341 |
+
|
| 342 |
+
self.sliding_window_size = sliding_window_size
|
| 343 |
+
self.attend_to_scale_frames = attend_to_scale_frames
|
| 344 |
+
self.num_random_frames = num_random_frames
|
| 345 |
+
|
| 346 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 347 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 348 |
+
|
| 349 |
+
self.norm2 = norm_layer(dim)
|
| 350 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 351 |
+
self.mlp = ffn_layer(
|
| 352 |
+
in_features=dim,
|
| 353 |
+
hidden_features=mlp_hidden_dim,
|
| 354 |
+
act_layer=act_layer,
|
| 355 |
+
drop=drop,
|
| 356 |
+
bias=ffn_bias
|
| 357 |
+
)
|
| 358 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 359 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 360 |
+
|
| 361 |
+
self.sample_drop_ratio = drop_path
|
| 362 |
+
self.masks = {}
|
| 363 |
+
|
| 364 |
+
@torch.no_grad()
|
| 365 |
+
def _prepare_blockwise_causal_attn_mask(self,
|
| 366 |
+
device: torch.device | str, num_frames: int = 21,
|
| 367 |
+
frame_seqlen: int = 1560, num_frame_per_block=1
|
| 368 |
+
) -> BlockMask:
|
| 369 |
+
"""
|
| 370 |
+
We divide the token sequence into the following format:
|
| 371 |
+
[1 latent frame] [1 latent frame] ... [1 latent frame]
|
| 372 |
+
We use FlexAttention to construct the attention mask.
|
| 373 |
+
"""
|
| 374 |
+
total_length = num_frames * frame_seqlen
|
| 375 |
+
|
| 376 |
+
# we do right padding to get to a multiple of 128
|
| 377 |
+
padded_length = math.ceil(total_length / 128) * 128 - total_length
|
| 378 |
+
|
| 379 |
+
ends = torch.zeros(total_length + padded_length,
|
| 380 |
+
device=device, dtype=torch.long)
|
| 381 |
+
|
| 382 |
+
# Block-wise causal mask will attend to all elements that are before the end of the current chunk
|
| 383 |
+
frame_indices = torch.arange(
|
| 384 |
+
start=0,
|
| 385 |
+
end=total_length,
|
| 386 |
+
step=frame_seqlen * num_frame_per_block,
|
| 387 |
+
device=device
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
for tmp in frame_indices:
|
| 391 |
+
ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
|
| 392 |
+
frame_seqlen * num_frame_per_block
|
| 393 |
+
|
| 394 |
+
def attention_mask(b, h, q_idx, kv_idx):
|
| 395 |
+
return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
|
| 396 |
+
# return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
|
| 397 |
+
|
| 398 |
+
block_mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
|
| 399 |
+
KV_LEN=total_length + padded_length, device=device)
|
| 400 |
+
|
| 401 |
+
return block_mask
|
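A tiny numeric illustration of the block-wise causal rule built above, with toy sizes rather than the real frame_seqlen: every token may attend to all tokens up to the end of its own frame block, plus itself.

import torch

num_frames, frame_seqlen, frames_per_block = 3, 2, 1   # toy numbers
total = num_frames * frame_seqlen

ends = torch.zeros(total, dtype=torch.long)
for start in range(0, total, frame_seqlen * frames_per_block):
    ends[start:start + frame_seqlen * frames_per_block] = start + frame_seqlen * frames_per_block

q_idx = torch.arange(total).view(-1, 1)
kv_idx = torch.arange(total).view(1, -1)
mask = (kv_idx < ends.view(-1, 1)) | (q_idx == kv_idx)
print(mask.int())   # rows 0-1 see cols 0-1, rows 2-3 see cols 0-3, rows 4-5 see everything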
| 402 |
+
|
| 403 |
+
def forward(self, x: Tensor, pos=None, video_mask=None, num_frames=0, frame_seqlen=0, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=8, num_frame_for_scale=-1, sliding_window_size=None, enable_ulysses_cp=False, full_attention=False, enable_3d_rope=False, is_scale_frames=False) -> Tensor:
|
| 404 |
+
# Use passed sliding_window_size if provided, otherwise use self.sliding_window_size
|
| 405 |
+
effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
|
| 406 |
+
|
| 407 |
+
# Fast path for full attention (camera head) - skip mask computation
|
| 408 |
+
if full_attention:
|
| 409 |
+
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
|
| 410 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos, full_attention=True, enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope))
|
| 411 |
+
|
| 412 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 413 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 414 |
+
|
| 415 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 416 |
+
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
|
| 417 |
+
x = x + self.drop_path1(ffn_residual_func(x))
|
| 418 |
+
else:
|
| 419 |
+
x = x + attn_residual_func(x, pos=pos)
|
| 420 |
+
x = x + ffn_residual_func(x)
|
| 421 |
+
return x
|
| 422 |
+
|
| 423 |
+
mask_block = self._prepare_blockwise_causal_attn_mask(
|
| 424 |
+
device=x.device, num_frames=num_frames, frame_seqlen=frame_seqlen, num_frame_per_block=num_frame_per_block)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
|
| 428 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos, block_mask=mask_block, frame_seqlen=frame_seqlen, video_mask=video_mask, current_start=current_start, current_end=current_end, kv_cache=kv_cache, global_idx=global_idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size, attend_to_scale_frames=self.attend_to_scale_frames, num_random_frames=self.num_random_frames,
|
| 429 |
+
enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope, is_scale_frames=is_scale_frames))
|
| 430 |
+
|
| 431 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 432 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 433 |
+
|
| 434 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 435 |
+
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
|
| 436 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 437 |
+
else:
|
| 438 |
+
x = x + attn_residual_func(x, pos=pos)
|
| 439 |
+
x = x + ffn_residual_func(x)
|
| 440 |
+
return x
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class SDPABlock(nn.Module):
|
| 444 |
+
"""
|
| 445 |
+
SDPA variant for streaming inference. Uses F.scaled_dot_product_attention
|
| 446 |
+
with dict-based KV cache. No FlashInfer dependency required.
|
| 447 |
+
"""
|
| 448 |
+
|
| 449 |
+
def __init__(
|
| 450 |
+
self,
|
| 451 |
+
dim: int,
|
| 452 |
+
num_heads: int,
|
| 453 |
+
mlp_ratio: float = 4.0,
|
| 454 |
+
qkv_bias: bool = True,
|
| 455 |
+
proj_bias: bool = True,
|
| 456 |
+
ffn_bias: bool = True,
|
| 457 |
+
drop: float = 0.0,
|
| 458 |
+
attn_drop: float = 0.0,
|
| 459 |
+
init_values=None,
|
| 460 |
+
drop_path: float = 0.0,
|
| 461 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 462 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 463 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 464 |
+
qk_norm: bool = False,
|
| 465 |
+
rope=None,
|
| 466 |
+
kv_cache_sliding_window: int = 64,
|
| 467 |
+
kv_cache_scale_frames: int = 8,
|
| 468 |
+
kv_cache_cross_frame_special: bool = True,
|
| 469 |
+
kv_cache_include_scale_frames: bool = True,
|
| 470 |
+
kv_cache_camera_only: bool = False,
|
| 471 |
+
) -> None:
|
| 472 |
+
super().__init__()
|
| 473 |
+
self.norm1 = norm_layer(dim)
|
| 474 |
+
self.attn = SDPAAttention(
|
| 475 |
+
dim=dim, num_heads=num_heads, qk_norm=qk_norm, qkv_bias=qkv_bias,
|
| 476 |
+
proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, rope=rope,
|
| 477 |
+
kv_cache_sliding_window=kv_cache_sliding_window,
|
| 478 |
+
kv_cache_scale_frames=kv_cache_scale_frames,
|
| 479 |
+
kv_cache_cross_frame_special=kv_cache_cross_frame_special,
|
| 480 |
+
kv_cache_include_scale_frames=kv_cache_include_scale_frames,
|
| 481 |
+
kv_cache_camera_only=kv_cache_camera_only,
|
| 482 |
+
)
|
| 483 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 484 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 485 |
+
self.norm2 = norm_layer(dim)
|
| 486 |
+
self.mlp = ffn_layer(in_features=dim, hidden_features=int(dim * mlp_ratio),
|
| 487 |
+
act_layer=act_layer, drop=drop, bias=ffn_bias)
|
| 488 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 489 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 490 |
+
self.sample_drop_ratio = drop_path
|
| 491 |
+
|
| 492 |
+
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
|
| 493 |
+
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
|
| 494 |
+
kv_cache=None, global_idx=0, num_frame_per_block=1,
|
| 495 |
+
num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
|
| 496 |
+
def attn_residual_func(x, pos=None):
|
| 497 |
+
return self.ls1(self.attn(
|
| 498 |
+
self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
|
| 499 |
+
num_patches=num_patches, num_special=num_special, num_frames=num_frames,
|
| 500 |
+
enable_3d_rope=enable_3d_rope, kv_cache=kv_cache, global_idx=global_idx,
|
| 501 |
+
num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale,
|
| 502 |
+
num_register_tokens=num_register_tokens,
|
| 503 |
+
))
|
| 504 |
+
|
| 505 |
+
def ffn_residual_func(x):
|
| 506 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 507 |
+
|
| 508 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 509 |
+
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
|
| 510 |
+
x = x + self.drop_path1(ffn_residual_func(x))
|
| 511 |
+
else:
|
| 512 |
+
x = x + attn_residual_func(x, pos=pos)
|
| 513 |
+
x = x + ffn_residual_func(x)
|
| 514 |
+
return x
|
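SDPABlock above delegates caching to SDPAAttention, whose exact dict layout is defined elsewhere in the repo and not shown here. As a hedged illustration of the general idea only (function name, cache keys, and tensor layout are assumptions), a per-layer dict cache that concatenates new K/V onto the stored window before calling F.scaled_dot_product_attention can look like this:

import torch
import torch.nn.functional as F

def sdpa_with_dict_cache(q, k_new, v_new, cache, layer_id):
    # q/k/v: [B, H, T, D]; cache maps layer_id -> {"k": ..., "v": ...}
    if layer_id in cache:
        k = torch.cat([cache[layer_id]["k"], k_new], dim=2)
        v = torch.cat([cache[layer_id]["v"], v_new], dim=2)
    else:
        k, v = k_new, v_new
    cache[layer_id] = {"k": k, "v": v}   # a real streaming cache would also trim to a sliding window
    return F.scaled_dot_product_attention(q, k, v)

cache = {}
B, H, T, D = 1, 4, 8, 16
for _ in range(3):                       # three streamed "frames"
    out = sdpa_with_dict_cache(torch.randn(B, H, T, D), torch.randn(B, H, T, D),
                               torch.randn(B, H, T, D), cache, layer_id=0)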
lingbot_map/layers/drop_path.py
ADDED
|
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
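A quick behavioural check of the module above: in training mode each sample's residual branch is either zeroed or rescaled by 1/(1 - drop_prob), so the expected output matches eval mode.

import torch
from lingbot_map.layers.drop_path import DropPath

dp = DropPath(drop_prob=0.5)
dp.train()
x = torch.ones(8, 4)
print(dp(x))    # each row is either all 0.0 or all 2.0
dp.eval()
print(dp(x))    # identity at eval time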
|
lingbot_map/layers/flashinfer_cache.py
ADDED
|
@@ -0,0 +1,640 @@
|
| 1 |
+
"""
|
| 2 |
+
FlashInfer KV Cache Manager — Two-Stream Paged Design.
|
| 3 |
+
|
| 4 |
+
Two logical streams sharing one physical page pool per layer:
|
| 5 |
+
|
| 6 |
+
Patch stream (recyclable):
|
| 7 |
+
- page_size = patches_per_frame (256 for 224×224; 972 for 504×378)
|
| 8 |
+
- Exactly 1 patch page per frame
|
| 9 |
+
- Scale frames → scale_patch_pages (never evicted, maxlen=scale_frames)
|
| 10 |
+
- Recent frames → live_window_patch_pages (evicted when > sliding_window)
|
| 11 |
+
|
| 12 |
+
Special stream (append-only, never recycled):
|
| 13 |
+
- num_special_tokens (6) special tokens per frame
|
| 14 |
+
- Packed continuously: one special page holds floor(page_size/6) frames
|
| 15 |
+
e.g. page_size=256 → 42 frames per special page, 4 slots wasted
|
| 16 |
+
- Specials written for EVERY frame (including scale + window), not just evicted ones.
|
| 17 |
+
|
| 18 |
+
Physical layout per block:
|
| 19 |
+
kv_caches[block_idx]: [max_num_pages, 2, page_size, H, D]
|
| 20 |
+
Pages 0 .. max_patch_pages-1 : patch page pool (recyclable)
|
| 21 |
+
Pages max_patch_pages .. max_pages-1: special page pool (append-only)
|
| 22 |
+
dim 1: 0=K 1=V
|
| 23 |
+
|
| 24 |
+
Attention computation:
|
| 25 |
+
visible = scale_patch_pages + live_window_patch_pages + all_special_pages
|
| 26 |
+
Special pages placed LAST → paged_kv_last_page_len naturally describes
|
| 27 |
+
the partial special-tail without a custom mask.
|
| 28 |
+
|
| 29 |
+
plan() is called ONCE per frame step (when block_idx == 0).
|
| 30 |
+
run() is called per layer, reusing the same plan. All layers at the
|
| 31 |
+
same frame step have identical page structures (same page IDs in same
|
| 32 |
+
positions), so reusing the plan across layers is correct.
|
| 33 |
+
|
| 34 |
+
Public API is drop-in compatible with the previous FlashInferKVCacheManager:
|
| 35 |
+
append_frame(block_idx, k, v)
|
| 36 |
+
evict_frames(block_idx, scale_frames, sliding_window, ...)
|
| 37 |
+
compute_attention(block_idx, q) -> out
|
| 38 |
+
reset()
|
| 39 |
+
"""
|
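A hedged sketch of the per-frame driver loop implied by this docstring; the constructor values and the new_kv_and_q helper are illustrative placeholders, and the real call sites live elsewhere in the repo.

import torch
from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager

mgr = FlashInferKVCacheManager(
    num_blocks=24, max_num_frames=88, tokens_per_frame=262,
    num_heads=16, head_dim=64, dtype=torch.bfloat16,
    device=torch.device("cuda"),
)

for frame_idx in range(100):
    for block_idx in range(24):
        # new_kv_and_q is a placeholder for the per-layer q/k/v projections,
        # each shaped [tokens_per_frame, H, D] with the special tokens first.
        k, v, q = new_kv_and_q(block_idx)
        mgr.append_frame(block_idx, k, v)
        mgr.evict_frames(block_idx, scale_frames=8, sliding_window=64)
        out = mgr.compute_attention(block_idx, q)   # block 0 plans; later blocks reuse the plan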
| 40 |
+
|
| 41 |
+
import collections
|
| 42 |
+
import math
|
| 43 |
+
from typing import List
|
| 44 |
+
|
| 45 |
+
import torch
|
| 46 |
+
from torch import Tensor
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
import flashinfer
|
| 50 |
+
FLASHINFER_AVAILABLE = True
|
| 51 |
+
except ImportError:
|
| 52 |
+
FLASHINFER_AVAILABLE = False
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class FlashInferKVCacheManager:
|
| 56 |
+
"""
|
| 57 |
+
Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only).
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
num_blocks: Number of Transformer blocks (one cache per block).
|
| 61 |
+
max_num_frames: Maximum frames held in the KV window at once
|
| 62 |
+
(scale_frames + sliding_window + headroom).
|
| 63 |
+
tokens_per_frame: Total tokens per frame = patches + specials (e.g. 262).
|
| 64 |
+
num_heads: Number of KV heads (= QO heads; MHA assumed).
|
| 65 |
+
head_dim: Head dimension (64 for ViT-L).
|
| 66 |
+
dtype: Storage dtype (bfloat16 / float16).
|
| 67 |
+
device: CUDA device.
|
| 68 |
+
num_special_tokens: Special tokens per frame: camera + register×N + scale (6).
|
| 69 |
+
scale_frames: Number of always-resident scale frames (8).
|
| 70 |
+
sliding_window: Sliding window size (64).
|
| 71 |
+
max_total_frames: Upper bound on total frames ever processed; used to
|
| 72 |
+
pre-allocate the special page pool (default 2048).
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
num_blocks: int,
|
| 78 |
+
max_num_frames: int,
|
| 79 |
+
tokens_per_frame: int,
|
| 80 |
+
num_heads: int,
|
| 81 |
+
head_dim: int,
|
| 82 |
+
dtype: torch.dtype,
|
| 83 |
+
device: torch.device,
|
| 84 |
+
num_special_tokens: int = 6,
|
| 85 |
+
scale_frames: int = 8,
|
| 86 |
+
sliding_window: int = 64,
|
| 87 |
+
max_total_frames: int = 2048,
|
| 88 |
+
force_fp32: bool = False,
|
| 89 |
+
fa3: bool = False,
|
| 90 |
+
):
|
| 91 |
+
if not FLASHINFER_AVAILABLE:
|
| 92 |
+
raise RuntimeError("FlashInfer is not available. Please install flashinfer.")
|
| 93 |
+
|
| 94 |
+
self.num_blocks = num_blocks
|
| 95 |
+
self.num_special_tokens = num_special_tokens # 6
|
| 96 |
+
self.patches_per_frame = tokens_per_frame - num_special_tokens # 256 / 999 / ...
|
| 97 |
+
# Use exact page_size = patches_per_frame to eliminate zero-padded slots.
|
| 98 |
+
# FA2 (backend="fa2") supports non-power-of-2 page sizes.
|
| 99 |
+
# FA3 (sm90) requires power-of-2 page sizes; use next_power_of_2 when fa3=True.
|
| 100 |
+
p = self.patches_per_frame
|
| 101 |
+
if fa3:
|
| 102 |
+
# Round up to next power-of-2 for FA3 SM90 kernel requirement.
|
| 103 |
+
# e.g. 999 → 1024 (25 zero-padded slots per patch page)
|
| 104 |
+
self.page_size = 1 << (p - 1).bit_length()
|
| 105 |
+
else:
|
| 106 |
+
self.page_size = p # exact: no zero padding in patch pages
|
| 107 |
+
self.scale_frames = scale_frames # 8
|
| 108 |
+
self.sliding_window = sliding_window # 64
|
| 109 |
+
self.num_heads = num_heads
|
| 110 |
+
self.head_dim = head_dim
|
| 111 |
+
self.tokens_per_frame = tokens_per_frame
|
| 112 |
+
|
| 113 |
+
assert self.patches_per_frame > 0, (
|
| 114 |
+
f"tokens_per_frame={tokens_per_frame} <= num_special_tokens={num_special_tokens}"
|
| 115 |
+
)
|
| 116 |
+
assert self.page_size > 0
|
| 117 |
+
|
| 118 |
+
# force_fp32: bypass FlashInfer FA2 kernel (which only supports fp16/bf16) and
|
| 119 |
+
# instead gather paged K/V into a dense tensor and use F.scaled_dot_product_attention
|
| 120 |
+
# in fp32 for accuracy comparison. Storage dtype is also kept as fp32 in this mode.
|
| 121 |
+
self.force_fp32 = force_fp32
|
| 122 |
+
if force_fp32:
|
| 123 |
+
self.dtype = torch.float32
|
| 124 |
+
else:
|
| 125 |
+
if dtype == torch.float32:
|
| 126 |
+
dtype = torch.bfloat16
|
| 127 |
+
self.dtype = dtype
|
| 128 |
+
self.device = device
|
| 129 |
+
|
| 130 |
+
# ── Page pool sizing ─────────────────────────────────────────────────
|
| 131 |
+
# Patch: scale + window + 16 headroom (pages recycled → fixed count)
|
| 132 |
+
max_patch_pages = scale_frames + sliding_window + 16 # e.g. 88
|
| 133 |
+
# Special: enough for max_total_frames × 6 tokens, plus 16 headroom
|
| 134 |
+
max_special_pages = (
|
| 135 |
+
math.ceil(max_total_frames * num_special_tokens / self.page_size) + 16
|
| 136 |
+
)
|
| 137 |
+
self.max_patch_pages = max_patch_pages
|
| 138 |
+
self.max_num_pages = max_patch_pages + max_special_pages
|
| 139 |
+
|
| 140 |
+
# ── Physical paged KV caches ─────────────────────────────────────────
|
| 141 |
+
# Shape per block: [max_num_pages, 2, page_size, H, D] (NHD, K=dim0, V=dim1)
|
| 142 |
+
self.kv_caches: List[Tensor] = [
|
| 143 |
+
torch.zeros(
|
| 144 |
+
self.max_num_pages, 2, self.page_size, num_heads, head_dim,
|
| 145 |
+
dtype=dtype, device=device,
|
| 146 |
+
)
|
| 147 |
+
for _ in range(num_blocks)
|
| 148 |
+
]
|
| 149 |
+
|
| 150 |
+
# ── Per-block state ──────────────────────────────────────────────────
|
| 151 |
+
# Patch pages (IDs 0 .. max_patch_pages-1)
|
| 152 |
+
self.scale_patch_pages: List[collections.deque] = [
|
| 153 |
+
collections.deque() for _ in range(num_blocks)
|
| 154 |
+
]
|
| 155 |
+
self.live_window_patch_pages: List[collections.deque] = [
|
| 156 |
+
collections.deque() for _ in range(num_blocks)
|
| 157 |
+
]
|
| 158 |
+
self.free_patch_pages: List[List[int]] = [
|
| 159 |
+
list(range(max_patch_pages)) for _ in range(num_blocks)
|
| 160 |
+
]
|
| 161 |
+
|
| 162 |
+
# Special pages (IDs max_patch_pages .. max_num_pages-1)
|
| 163 |
+
self.all_special_pages: List[List[int]] = [[] for _ in range(num_blocks)]
|
| 164 |
+
self.free_special_pages: List[List[int]] = [
|
| 165 |
+
list(range(max_patch_pages, self.max_num_pages)) for _ in range(num_blocks)
|
| 166 |
+
]
|
| 167 |
+
self.special_token_count: List[int] = [0] * num_blocks
|
| 168 |
+
|
| 169 |
+
# Frame counter per block (determines scale vs window routing)
|
| 170 |
+
self.frame_count: List[int] = [0] * num_blocks
|
| 171 |
+
|
| 172 |
+
# Deferred eviction support for flow-based keyframe selection.
|
| 173 |
+
# When True, evict_frames() becomes a no-op; caller must later call
|
| 174 |
+
# execute_deferred_eviction() or rollback_last_frame().
|
| 175 |
+
self._defer_eviction: bool = False
|
| 176 |
+
|
| 177 |
+
# ── FlashInfer wrapper ───────────────────────────────────────────────
|
| 178 |
+
# plan() is called once per frame step (block_idx == 0).
|
| 179 |
+
# run() is called per layer, reusing the same aux structures.
|
| 180 |
+
# backend: "fa2" (default) or "fa3" (SM90/H100, requires power-of-2 page_size).
|
| 181 |
+
# FA2 supports non-power-of-2 page sizes and avoids a FA3 NaN bug seen in
|
| 182 |
+
# FlashInfer 0.2.5 at 518×378 resolution.
|
| 183 |
+
_fi_backend = "fa3" if fa3 else "fa2"
|
| 184 |
+
self.workspace_buffer = torch.zeros(
|
| 185 |
+
128 * 1024 * 1024, dtype=torch.uint8, device=device
|
| 186 |
+
)
|
| 187 |
+
self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
| 188 |
+
self.workspace_buffer,
|
| 189 |
+
kv_layout="NHD",
|
| 190 |
+
backend=_fi_backend,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# plan() inputs (indices/indptr built fresh each step; qo_indptr is fixed)
|
| 194 |
+
self._qo_indptr = torch.tensor(
|
| 195 |
+
[0, tokens_per_frame], dtype=torch.int32, device=device
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# =========================================================================
|
| 199 |
+
# Public API (drop-in compatible with previous FlashInferKVCacheManager)
|
| 200 |
+
# =========================================================================
|
| 201 |
+
|
| 202 |
+
def append_frame(self, block_idx: int, k: Tensor, v: Tensor) -> None:
|
| 203 |
+
"""
|
| 204 |
+
Append one frame's K/V tensors to the two-stream cache.
|
| 205 |
+
|
| 206 |
+
Token layout must be: [camera, reg0, ..., regN, scale, patch0, ..., patchP-1]
|
| 207 |
+
i.e. specials come first (matching stream.py's patch_start_idx convention).
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
block_idx: Block/layer index (0 … num_blocks-1).
|
| 211 |
+
k: [tokens_per_frame, H, D] NHD layout.
|
| 212 |
+
v: [tokens_per_frame, H, D] NHD layout.
|
| 213 |
+
"""
|
| 214 |
+
n = self.num_special_tokens # 6
|
| 215 |
+
sp_k = k[:n].to(self.dtype) # [6, H, D]
|
| 216 |
+
patch_k = k[n:].to(self.dtype) # [256, H, D]
|
| 217 |
+
sp_v = v[:n].to(self.dtype)
|
| 218 |
+
patch_v = v[n:].to(self.dtype)
|
| 219 |
+
|
| 220 |
+
assert patch_k.shape[0] == self.patches_per_frame, (
|
| 221 |
+
f"block {block_idx}: expected {self.patches_per_frame} patch tokens, "
|
| 222 |
+
f"got {patch_k.shape[0]} (tokens_per_frame={k.shape[0]})"
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
self._write_patch_page(block_idx, patch_k, patch_v)
|
| 226 |
+
self._write_special_tokens(block_idx, sp_k, sp_v)
|
| 227 |
+
self.frame_count[block_idx] += 1
|
| 228 |
+
|
| 229 |
+
def evict_frames(
|
| 230 |
+
self,
|
| 231 |
+
block_idx: int,
|
| 232 |
+
scale_frames: int,
|
| 233 |
+
sliding_window: int,
|
| 234 |
+
cross_frame_special: bool = True,
|
| 235 |
+
include_scale_frames: bool = True,
|
| 236 |
+
camera_only: bool = False,
|
| 237 |
+
num_register_tokens: int = 4,
|
| 238 |
+
) -> None:
|
| 239 |
+
"""
|
| 240 |
+
Evict old window patch pages (recycle to free list).
|
| 241 |
+
|
| 242 |
+
Special pages are NEVER evicted.
|
| 243 |
+
Scale pages are NEVER evicted.
|
| 244 |
+
Only live_window_patch_pages beyond `sliding_window` are recycled.
|
| 245 |
+
|
| 246 |
+
When ``_defer_eviction`` is True, this method is a no-op. The caller
|
| 247 |
+
is expected to later call ``execute_deferred_eviction()`` (keep frame)
|
| 248 |
+
or ``rollback_last_frame()`` (discard frame).
|
| 249 |
+
"""
|
| 250 |
+
if self._defer_eviction:
|
| 251 |
+
return
|
| 252 |
+
while len(self.live_window_patch_pages[block_idx]) > sliding_window:
|
| 253 |
+
old_page = self.live_window_patch_pages[block_idx].popleft()
|
| 254 |
+
self.free_patch_pages[block_idx].append(old_page)
|
| 255 |
+
|
| 256 |
+
def execute_deferred_eviction(
|
| 257 |
+
self,
|
| 258 |
+
block_idx: int,
|
| 259 |
+
scale_frames: int,
|
| 260 |
+
sliding_window: int,
|
| 261 |
+
**kwargs,
|
| 262 |
+
) -> None:
|
| 263 |
+
"""Run the eviction that was skipped while ``_defer_eviction`` was True."""
|
| 264 |
+
while len(self.live_window_patch_pages[block_idx]) > sliding_window:
|
| 265 |
+
old_page = self.live_window_patch_pages[block_idx].popleft()
|
| 266 |
+
self.free_patch_pages[block_idx].append(old_page)
|
| 267 |
+
|
| 268 |
+
def rollback_last_frame(self, block_idx: int) -> None:
|
| 269 |
+
"""Undo the most recent ``append_frame()`` for *block_idx*.
|
| 270 |
+
|
| 271 |
+
This reverses all three sub-operations of ``append_frame``:
|
| 272 |
+
patch page allocation, special-token write, and frame_count increment.
|
| 273 |
+
It must be called **before** any eviction for that frame (i.e. while
|
| 274 |
+
``_defer_eviction`` is True or before ``evict_frames`` is called).
|
| 275 |
+
"""
|
| 276 |
+
assert self.frame_count[block_idx] > 0, (
|
| 277 |
+
f"block {block_idx}: cannot rollback, frame_count is 0"
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
# 1) Undo patch page ── pop from whichever deque it was routed to.
|
| 281 |
+
if self.frame_count[block_idx] > self.scale_frames:
|
| 282 |
+
page_id = self.live_window_patch_pages[block_idx].pop()
|
| 283 |
+
else:
|
| 284 |
+
page_id = self.scale_patch_pages[block_idx].pop()
|
| 285 |
+
self.free_patch_pages[block_idx].append(page_id)
|
| 286 |
+
|
| 287 |
+
# 2) Undo special tokens
|
| 288 |
+
n = self.num_special_tokens
|
| 289 |
+
new_count = self.special_token_count[block_idx] - n
|
| 290 |
+
assert new_count >= 0, (
|
| 291 |
+
f"block {block_idx}: special_token_count underflow "
|
| 292 |
+
f"({self.special_token_count[block_idx]} - {n})"
|
| 293 |
+
)
|
| 294 |
+
new_num_pages = math.ceil(new_count / self.page_size) if new_count > 0 else 0
|
| 295 |
+
while len(self.all_special_pages[block_idx]) > new_num_pages:
|
| 296 |
+
freed = self.all_special_pages[block_idx].pop()
|
| 297 |
+
self.free_special_pages[block_idx].append(freed)
|
| 298 |
+
self.special_token_count[block_idx] = new_count
|
| 299 |
+
|
| 300 |
+
# 3) Decrement frame count
|
| 301 |
+
self.frame_count[block_idx] -= 1
|
| 302 |
+
|
| 303 |
+
def _gather_kv(self, block_idx: int):
|
| 304 |
+
"""
|
| 305 |
+
Gather all visible K and V tokens from the paged cache into dense tensors.
|
| 306 |
+
|
| 307 |
+
Used by force_fp32 mode to bypass the FlashInfer FA2 kernel (which only
|
| 308 |
+
supports fp16/bf16) and instead run F.scaled_dot_product_attention in fp32.
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
k_flat: [kv_len, H, D] — all visible K tokens concatenated
|
| 312 |
+
v_flat: [kv_len, H, D] — all visible V tokens concatenated
|
| 313 |
+
"""
|
| 314 |
+
visible = self.build_visible_page_table(block_idx)
|
| 315 |
+
last_len = self.compute_last_page_len(block_idx)
|
| 316 |
+
P = self.page_size
|
| 317 |
+
|
| 318 |
+
parts_k, parts_v = [], []
|
| 319 |
+
for i, pid in enumerate(visible):
|
| 320 |
+
n = last_len if (i == len(visible) - 1) else P
|
| 321 |
+
parts_k.append(self.kv_caches[block_idx][pid, 0, :n]) # [n, H, D]
|
| 322 |
+
parts_v.append(self.kv_caches[block_idx][pid, 1, :n])
|
| 323 |
+
|
| 324 |
+
k_flat = torch.cat(parts_k, dim=0) # [kv_len, H, D]
|
| 325 |
+
v_flat = torch.cat(parts_v, dim=0)
|
| 326 |
+
return k_flat, v_flat
|
| 327 |
+
|
| 328 |
+
def compute_attention(self, block_idx: int, q: Tensor) -> Tensor:
|
| 329 |
+
"""
|
| 330 |
+
Compute cross-frame attention using FlashInfer BatchPrefillWithPagedKVCacheWrapper.
|
| 331 |
+
|
| 332 |
+
When self.force_fp32 is True, gathers all visible K/V into dense tensors
|
| 333 |
+
and uses F.scaled_dot_product_attention in fp32 instead of the FA2 kernel.
|
| 334 |
+
This is used for accuracy comparison since FlashInfer FA2 only supports fp16/bf16.
|
| 335 |
+
|
| 336 |
+
plan() is called once per frame step (when block_idx == 0).
|
| 337 |
+
All layers at the same step share the same visible page structure,
|
| 338 |
+
so the plan is reused by calling run() with each layer's kv_cache.
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
block_idx: Block/layer index.
|
| 342 |
+
q: [q_len, H, D] NHD layout (q_len = tokens_per_frame = 262).
|
| 343 |
+
|
| 344 |
+
Returns:
|
| 345 |
+
out: [q_len, H, D]
|
| 346 |
+
"""
|
| 347 |
+
if self.frame_count[block_idx] == 0:
|
| 348 |
+
# No KV present yet (should not occur in normal usage after append_frame)
|
| 349 |
+
return torch.zeros_like(q)
|
| 350 |
+
|
| 351 |
+
if self.force_fp32:
|
| 352 |
+
# ── fp32 gather+SDPA path ─────────────────────────────────────────
|
| 353 |
+
# Gather visible K/V from paged cache and run SDPA in fp32.
|
| 354 |
+
# This bypasses the FlashInfer FA2 kernel (fp16/bf16 only) for accuracy.
|
| 355 |
+
# q_len, H, D → 1, H, q_len, D (SDPA expects BHsD layout)
|
| 356 |
+
import torch.nn.functional as F_nn
|
| 357 |
+
k_flat, v_flat = self._gather_kv(block_idx)
|
| 358 |
+
q_b = q.float().permute(1, 0, 2).unsqueeze(0) # [1, H, q_len, D]
|
| 359 |
+
k_b = k_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D]
|
| 360 |
+
v_b = v_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D]
|
| 361 |
+
out = F_nn.scaled_dot_product_attention(q_b, k_b, v_b)
|
| 362 |
+
return out.squeeze(0).permute(1, 0, 2).to(q.dtype) # [q_len, H, D]
|
| 363 |
+
|
| 364 |
+
if block_idx == 0:
|
| 365 |
+
# ── Plan once per frame step ──────────────────────────────────────
|
| 366 |
+
# Build visible page table from block 0's state.
|
| 367 |
+
# All blocks have identical page structures, so this plan is valid
|
| 368 |
+
# for all subsequent run() calls (block_idx = 1, 2, ...).
|
| 369 |
+
visible = self.build_visible_page_table(0)
|
| 370 |
+
last_len = self.compute_last_page_len(0)
|
| 371 |
+
|
| 372 |
+
assert visible, "visible page table is empty after append_frame"
|
| 373 |
+
assert 1 <= last_len <= self.page_size, (
|
| 374 |
+
f"block 0: last_page_len={last_len} out of [1, {self.page_size}]"
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
paged_kv_indices = torch.tensor(visible, dtype=torch.int32, device=self.device)
|
| 378 |
+
paged_kv_indptr = torch.tensor([0, len(visible)], dtype=torch.int32, device=self.device)
|
| 379 |
+
paged_kv_last_page_len = torch.tensor([last_len], dtype=torch.int32, device=self.device)
|
| 380 |
+
|
| 381 |
+
self.prefill_wrapper.plan(
|
| 382 |
+
self._qo_indptr,
|
| 383 |
+
paged_kv_indptr,
|
| 384 |
+
paged_kv_indices,
|
| 385 |
+
paged_kv_last_page_len,
|
| 386 |
+
num_qo_heads = self.num_heads,
|
| 387 |
+
num_kv_heads = self.num_heads,
|
| 388 |
+
head_dim_qk = self.head_dim,
|
| 389 |
+
page_size = self.page_size,
|
| 390 |
+
causal = False, # custom page ordering; no causal mask
|
| 391 |
+
pos_encoding_mode = "NONE", # RoPE applied externally before append
|
| 392 |
+
q_data_type = self.dtype,
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
# ── Run attention for this layer ──────────────────────────────────────
|
| 396 |
+
# Cast q to storage dtype (LayerNorm may upcast to float32 under autocast).
|
| 397 |
+
return self.prefill_wrapper.run(
|
| 398 |
+
q = q.to(self.dtype).contiguous(),
|
| 399 |
+
paged_kv_cache = self.kv_caches[block_idx],
|
| 400 |
+
) # → [q_len, H, D]
|
| 401 |
+
|
| 402 |
+
def reset(self) -> None:
|
| 403 |
+
"""Reset all per-block state for a new sequence."""
|
| 404 |
+
for i in range(self.num_blocks):
|
| 405 |
+
self.scale_patch_pages[i].clear()
|
| 406 |
+
self.live_window_patch_pages[i].clear()
|
| 407 |
+
self.all_special_pages[i].clear()
|
| 408 |
+
self.free_patch_pages[i] = list(range(self.max_patch_pages))
|
| 409 |
+
self.free_special_pages[i] = list(range(self.max_patch_pages, self.max_num_pages))
|
| 410 |
+
self.special_token_count[i] = 0
|
| 411 |
+
self.frame_count[i] = 0
|
| 412 |
+
|
| 413 |
+
# =========================================================================
|
| 414 |
+
# Helper methods
|
| 415 |
+
# =========================================================================
|
| 416 |
+
|
| 417 |
+
def build_visible_page_table(self, block_idx: int) -> List[int]:
|
| 418 |
+
"""
|
| 419 |
+
Return page IDs in strict order: scale → window → special.
|
| 420 |
+
|
| 421 |
+
Placing special pages last means only the final page may be partially
|
| 422 |
+
full, so paged_kv_last_page_len = compute_last_page_len() is sufficient
|
| 423 |
+
without a custom attention mask.
|
| 424 |
+
"""
|
| 425 |
+
return (
|
| 426 |
+
list(self.scale_patch_pages[block_idx]) +
|
| 427 |
+
list(self.live_window_patch_pages[block_idx]) +
|
| 428 |
+
list(self.all_special_pages[block_idx])
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
def compute_last_page_len(self, block_idx: int) -> int:
|
| 432 |
+
"""
|
| 433 |
+
Valid token count in the last page of the visible sequence.
|
| 434 |
+
|
| 435 |
+
- No special pages → last page is a patch page.
|
| 436 |
+
Returns patches_per_frame (real tokens written),
|
| 437 |
+
which may be < page_size when page_size was rounded
|
| 438 |
+
up to a power of 2.
|
| 439 |
+
- Special tail partial → special_token_count % page_size.
|
| 440 |
+
- Special tail exactly full → page_size.
|
| 441 |
+
"""
|
| 442 |
+
if not self.all_special_pages[block_idx]:
|
| 443 |
+
# Last page is a patch page. We wrote patches_per_frame tokens (0..P-1);
|
| 444 |
+
# positions P..page_size-1 are zero padding. Tell FlashInfer the true
|
| 445 |
+
# valid count so it doesn't read beyond the real tokens.
|
| 446 |
+
return self.patches_per_frame
|
| 447 |
+
|
| 448 |
+
tail = self.special_token_count[block_idx] % self.page_size
|
| 449 |
+
return self.page_size if tail == 0 else tail
|
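A worked example of the three cases above, using the same numbers as the sanity check at the bottom of this file.

# After 100 frames with 6 specials each and page_size = 256:
#   special_token_count = 600, ceil(600 / 256) = 3 special pages,
#   and the visible sequence ends in a partially filled special page.
special_token_count, page_size = 600, 256
tail = special_token_count % page_size          # 600 % 256 = 88
last_page_len = page_size if tail == 0 else tail
assert last_page_len == 88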
| 450 |
+
|
| 451 |
+
# ── Internal write helpers ────────────────────────────────────────────────
|
| 452 |
+
|
| 453 |
+
def _write_patch_page(self, block_idx: int, patch_k: Tensor, patch_v: Tensor) -> int:
|
| 454 |
+
"""
|
| 455 |
+
Allocate one free patch page and write patches_per_frame patch tokens.
|
| 456 |
+
|
| 457 |
+
Direct tensor assignment to kv_caches[block_idx][page_id, 0/1] avoids
|
| 458 |
+
the Python→C++/CUDA dispatch overhead of flashinfer.page.append_paged_kv_cache.
|
| 459 |
+
kv_caches layout: [max_num_pages, 2, page_size, H, D] (NHD, K=0, V=1).
|
| 460 |
+
patch_k/v fill exactly one full page (patches_per_frame == page_size).
|
| 461 |
+
|
| 462 |
+
Routes to scale_patch_pages if still filling scale quota,
|
| 463 |
+
otherwise to live_window_patch_pages.
|
| 464 |
+
|
| 465 |
+
Returns:
|
| 466 |
+
page_id: Physical page index used.
|
| 467 |
+
"""
|
| 468 |
+
assert self.free_patch_pages[block_idx], (
|
| 469 |
+
f"block {block_idx}: patch page pool exhausted — "
|
| 470 |
+
f"scale={len(self.scale_patch_pages[block_idx])}, "
|
| 471 |
+
f"window={len(self.live_window_patch_pages[block_idx])}, "
|
| 472 |
+
f"free={len(self.free_patch_pages[block_idx])}"
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
page_id = self.free_patch_pages[block_idx].pop()
|
| 476 |
+
|
| 477 |
+
# Direct slice write: positions 0..patches_per_frame-1.
|
| 478 |
+
# When page_size == patches_per_frame (power-of-2 aligned, e.g. 256 for 224×224),
|
| 479 |
+
# this is equivalent to a full-page write. When page_size > patches_per_frame
|
| 480 |
+
# (rounded up for FA3 alignment, e.g. page_size=1024 for patches_per_frame=999),
|
| 481 |
+
# positions patches_per_frame..page_size-1 remain zero (kv_caches is zero-init).
|
| 482 |
+
P = self.patches_per_frame
|
| 483 |
+
self.kv_caches[block_idx][page_id, 0, :P] = patch_k # K
|
| 484 |
+
self.kv_caches[block_idx][page_id, 1, :P] = patch_v # V
|
| 485 |
+
|
| 486 |
+
if len(self.scale_patch_pages[block_idx]) < self.scale_frames:
|
| 487 |
+
self.scale_patch_pages[block_idx].append(page_id)
|
| 488 |
+
else:
|
| 489 |
+
self.live_window_patch_pages[block_idx].append(page_id)
|
| 490 |
+
|
| 491 |
+
return page_id
|
| 492 |
+
|
| 493 |
+
def _write_special_tokens(self, block_idx: int, sp_k: Tensor, sp_v: Tensor) -> None:
|
| 494 |
+
"""
|
| 495 |
+
Append num_special_tokens (6) special tokens to the special stream.
|
| 496 |
+
|
| 497 |
+
Direct tensor slice assignment to kv_caches[block_idx][tail_page, 0/1,
|
| 498 |
+
tail_offset : tail_offset+write_n] avoids the Python→C++/CUDA dispatch
|
| 499 |
+
overhead of flashinfer.page.append_paged_kv_cache.
|
| 500 |
+
|
| 501 |
+
Handles page-boundary crossing: if 6 tokens straddle two pages, performs
|
| 502 |
+
two slice writes (rare — page_size=256 >> 6).
|
| 503 |
+
"""
|
| 504 |
+
remaining = self.num_special_tokens # 6
|
| 505 |
+
written = 0
|
| 506 |
+
|
| 507 |
+
while remaining > 0:
|
| 508 |
+
tail_offset = self.special_token_count[block_idx] % self.page_size
|
| 509 |
+
|
| 510 |
+
if tail_offset == 0:
|
| 511 |
+
# Current tail page is full (or no page exists) — allocate a new one
|
| 512 |
+
assert self.free_special_pages[block_idx], (
|
| 513 |
+
f"block {block_idx}: special page pool exhausted at "
|
| 514 |
+
f"special_token_count={self.special_token_count[block_idx]}. "
|
| 515 |
+
f"Increase max_total_frames."
|
| 516 |
+
)
|
| 517 |
+
new_page = self.free_special_pages[block_idx].pop()
|
| 518 |
+
self.all_special_pages[block_idx].append(new_page)
|
| 519 |
+
|
| 520 |
+
tail_page = self.all_special_pages[block_idx][-1]
|
| 521 |
+
space = self.page_size - tail_offset # free slots in tail page
|
| 522 |
+
write_n = min(remaining, space)
|
| 523 |
+
|
| 524 |
+
# Direct slice write: kv_caches[block_idx][tail_page, 0/1, offset:offset+n]
|
| 525 |
+
# shape: [page_size, H, D]; slice [tail_offset:tail_offset+write_n, :, :]
|
| 526 |
+
end = tail_offset + write_n
|
| 527 |
+
self.kv_caches[block_idx][tail_page, 0, tail_offset:end] = sp_k[written:written + write_n]
|
| 528 |
+
self.kv_caches[block_idx][tail_page, 1, tail_offset:end] = sp_v[written:written + write_n]
|
| 529 |
+
|
| 530 |
+
self.special_token_count[block_idx] += write_n
|
| 531 |
+
written += write_n
|
| 532 |
+
remaining -= write_n
|
| 533 |
+
|
| 534 |
+
# ── Legacy property (used by stream.py) ──────────────────────────────────
|
| 535 |
+
|
| 536 |
+
@property
|
| 537 |
+
def num_frames(self) -> int:
|
| 538 |
+
"""Number of frames appended to block 0 (representative)."""
|
| 539 |
+
return self.frame_count[0] if self.frame_count else 0
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
# =============================================================================
|
| 543 |
+
# Sanity check
|
| 544 |
+
# =============================================================================
|
| 545 |
+
|
| 546 |
+
def _sanity_check():
|
| 547 |
+
"""
|
| 548 |
+
Minimal smoke test.
|
| 549 |
+
Run with: python -c "from lingbot_map.layers.flashinfer_cache import _sanity_check; _sanity_check()"
|
| 550 |
+
"""
|
| 551 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 552 |
+
if not torch.cuda.is_available():
|
| 553 |
+
print("[sanity_check] CUDA not available — skipping.")
|
| 554 |
+
return
|
| 555 |
+
|
| 556 |
+
tokens_per_frame = 262 # 256 patch + 6 special (224×224)
|
| 557 |
+
num_special = 6
|
| 558 |
+
patches_per_frame = tokens_per_frame - num_special # 256
|
| 559 |
+
page_size = patches_per_frame # 256
|
| 560 |
+
|
| 561 |
+
mgr = FlashInferKVCacheManager(
|
| 562 |
+
num_blocks = 2,
|
| 563 |
+
max_num_frames = 88,
|
| 564 |
+
tokens_per_frame = tokens_per_frame,
|
| 565 |
+
num_heads = 16,
|
| 566 |
+
head_dim = 64,
|
| 567 |
+
dtype = torch.bfloat16,
|
| 568 |
+
device = device,
|
| 569 |
+
num_special_tokens = num_special,
|
| 570 |
+
scale_frames = 8,
|
| 571 |
+
sliding_window = 64,
|
| 572 |
+
max_total_frames = 200,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
def make_kv():
|
| 576 |
+
k = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
|
| 577 |
+
v = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
|
| 578 |
+
return k, v
|
| 579 |
+
|
| 580 |
+
def make_q():
|
| 581 |
+
return torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
|
| 582 |
+
|
| 583 |
+
for block in range(2):
|
| 584 |
+
for t in range(100):
|
| 585 |
+
k, v = make_kv()
|
| 586 |
+
mgr.append_frame(block, k, v)
|
| 587 |
+
mgr.evict_frames(block, scale_frames=8, sliding_window=64)
|
| 588 |
+
|
| 589 |
+
# ── Page count checks ───────────────────────────────────────────────
|
| 590 |
+
n_scale = len(mgr.scale_patch_pages[block])
|
| 591 |
+
n_window = len(mgr.live_window_patch_pages[block])
|
| 592 |
+
n_spec = len(mgr.all_special_pages[block])
|
| 593 |
+
sp_count = mgr.special_token_count[block]
|
| 594 |
+
|
| 595 |
+
assert n_scale == 8, f"block {block}: scale pages = {n_scale}, expected 8"
|
| 596 |
+
assert n_window == 64, f"block {block}: window pages = {n_window}, expected 64"
|
| 597 |
+
# 100 frames × 6 specials = 600 tokens; ceil(600/256) = 3 pages
|
| 598 |
+
expected_spec_pages = math.ceil(100 * num_special / page_size)
|
| 599 |
+
assert n_spec == expected_spec_pages, (
|
| 600 |
+
f"block {block}: special pages = {n_spec}, expected {expected_spec_pages}"
|
| 601 |
+
)
|
| 602 |
+
assert sp_count == 100 * num_special, (
|
| 603 |
+
f"block {block}: special_token_count = {sp_count}, expected {100*num_special}"
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
# ── last_page_len ────────────────────────────────────────────────────
|
| 607 |
+
last_len = mgr.compute_last_page_len(block)
|
| 608 |
+
tail = sp_count % page_size
|
| 609 |
+
expected_len = page_size if tail == 0 else tail
|
| 610 |
+
assert last_len == expected_len, f"block {block}: last_len={last_len}, expected={expected_len}"
|
| 611 |
+
|
| 612 |
+
# ── visible page table order ─────────────────────────────────────────
|
| 613 |
+
visible = mgr.build_visible_page_table(block)
|
| 614 |
+
assert len(visible) == n_scale + n_window + n_spec, "visible page count mismatch"
|
| 615 |
+
for pid in visible[:n_scale + n_window]:
|
| 616 |
+
assert pid < mgr.max_patch_pages, f"patch page {pid} out of patch range"
|
| 617 |
+
for pid in visible[n_scale + n_window:]:
|
| 618 |
+
assert pid >= mgr.max_patch_pages, f"special page {pid} not in special range"
|
| 619 |
+
|
| 620 |
+
# ── forward pass: plan() once for block 0, run() for both blocks ─────
|
| 621 |
+
if block == 1:
|
| 622 |
+
# Simulate the actual calling pattern: plan on block 0, run on both
|
| 623 |
+
q0 = make_q()
|
| 624 |
+
out0 = mgr.compute_attention(0, q0) # triggers plan()
|
| 625 |
+
q1 = make_q()
|
| 626 |
+
out1 = mgr.compute_attention(1, q1) # reuses plan, different kv_cache
|
| 627 |
+
assert out0.shape == (tokens_per_frame, 16, 64)
|
| 628 |
+
assert out1.shape == (tokens_per_frame, 16, 64)
|
| 629 |
+
|
| 630 |
+
print(f"[block {block}] PASS: scale={n_scale}, window={n_window}, "
|
| 631 |
+
f"special_pages={n_spec}, special_tokens={sp_count}, "
|
| 632 |
+
f"last_page_len={last_len}")
|
| 633 |
+
|
| 634 |
+
mgr.reset()
|
| 635 |
+
assert mgr.frame_count[0] == 0
|
| 636 |
+
print("\n[sanity_check] All assertions passed.")
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
if __name__ == "__main__":
|
| 640 |
+
_sanity_check()
|
lingbot_map/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
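A one-line check of the module above: gamma starts at init_values, so the scaled residual branch starts near zero and the surrounding block starts near identity.

import torch
from lingbot_map.layers.layer_scale import LayerScale

ls = LayerScale(dim=4, init_values=1e-5)
print(ls(torch.ones(2, 4)))   # ~1e-5 everywhere until gamma is trained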
|
lingbot_map/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
|
lingbot_map/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
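A shape check for the defaults above: a 224x224 image with 16x16 patches yields a 14x14 grid, i.e. 196 patch tokens of width embed_dim.

import torch
from lingbot_map.layers.patch_embed import PatchEmbed

pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = pe(torch.randn(1, 3, 224, 224))
print(tokens.shape)   # torch.Size([1, 196, 768])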
|
lingbot_map/layers/rope.py
ADDED
|
@@ -0,0 +1,474 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Implementation of 2D Rotary Position Embeddings (RoPE).
|
| 8 |
+
|
| 9 |
+
# This module provides a clean implementation of 2D Rotary Position Embeddings,
|
| 10 |
+
# which extends the original RoPE concept to handle 2D spatial positions.
|
| 11 |
+
|
| 12 |
+
# Inspired by:
|
| 13 |
+
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
| 14 |
+
# https://github.com/naver-ai/rope-vit
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from typing import Dict, Tuple
|
| 22 |
+
|
| 23 |
+
from typing import List, Optional, Tuple, Union
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PositionGetter:
|
| 27 |
+
"""Generates and caches 2D spatial positions for patches in a grid.
|
| 28 |
+
|
| 29 |
+
This class efficiently manages the generation of spatial coordinates for patches
|
| 30 |
+
in a 2D grid, caching results to avoid redundant computations.
|
| 31 |
+
|
| 32 |
+
Attributes:
|
| 33 |
+
position_cache: Dictionary storing precomputed position tensors for different
|
| 34 |
+
grid dimensions.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self):
|
| 38 |
+
"""Initializes the position generator with an empty cache."""
|
| 39 |
+
self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
|
| 40 |
+
|
| 41 |
+
def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
|
| 42 |
+
"""Generates spatial positions for a batch of patches.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
batch_size: Number of samples in the batch.
|
| 46 |
+
height: Height of the grid in patches.
|
| 47 |
+
width: Width of the grid in patches.
|
| 48 |
+
device: Target device for the position tensor.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
|
| 52 |
+
for each position in the grid, repeated for each batch item.
|
| 53 |
+
"""
|
| 54 |
+
if (height, width) not in self.position_cache:
|
| 55 |
+
y_coords = torch.arange(height, device=device)
|
| 56 |
+
x_coords = torch.arange(width, device=device)
|
| 57 |
+
positions = torch.cartesian_prod(y_coords, x_coords)
|
| 58 |
+
self.position_cache[height, width] = positions
|
| 59 |
+
|
| 60 |
+
cached_positions = self.position_cache[height, width]
|
| 61 |
+
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
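A small usage example of PositionGetter above: (y, x) coordinates for a 2x3 patch grid, expanded per batch item and cached for reuse.

import torch
from lingbot_map.layers.rope import PositionGetter

get_pos = PositionGetter()
pos = get_pos(batch_size=2, height=2, width=3, device=torch.device("cpu"))
print(pos.shape)   # torch.Size([2, 6, 2])
print(pos[0])      # [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]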
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class RotaryPositionEmbedding2D(nn.Module):
|
| 65 |
+
"""2D Rotary Position Embedding implementation.
|
| 66 |
+
|
| 67 |
+
This module applies rotary position embeddings to input tokens based on their
|
| 68 |
+
2D spatial positions. It handles the position-dependent rotation of features
|
| 69 |
+
separately for vertical and horizontal dimensions.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
frequency: Base frequency for the position embeddings. Default: 100.0
|
| 73 |
+
scaling_factor: Scaling factor for frequency computation. Default: 1.0
|
| 74 |
+
|
| 75 |
+
Attributes:
|
| 76 |
+
base_frequency: Base frequency for computing position embeddings.
|
| 77 |
+
scaling_factor: Factor to scale the computed frequencies.
|
| 78 |
+
frequency_cache: Cache for storing precomputed frequency components.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
|
| 82 |
+
"""Initializes the 2D RoPE module."""
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.base_frequency = frequency
|
| 85 |
+
self.scaling_factor = scaling_factor
|
| 86 |
+
self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
|
| 87 |
+
|
| 88 |
+
def _compute_frequency_components(
|
| 89 |
+
self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
|
| 90 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 91 |
+
"""Computes frequency components for rotary embeddings.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
dim: Feature dimension (must be even).
|
| 95 |
+
seq_len: Maximum sequence length.
|
| 96 |
+
device: Target device for computations.
|
| 97 |
+
dtype: Data type for the computed tensors.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
Tuple of (cosine, sine) tensors for frequency components.
|
| 101 |
+
"""
|
| 102 |
+
cache_key = (dim, seq_len, device, dtype)
|
| 103 |
+
if cache_key not in self.frequency_cache:
|
| 104 |
+
# Compute frequency bands
|
| 105 |
+
exponents = torch.arange(0, dim, 2, device=device).float() / dim
|
| 106 |
+
inv_freq = 1.0 / (self.base_frequency**exponents)
|
| 107 |
+
|
| 108 |
+
# Generate position-dependent frequencies
|
| 109 |
+
positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
| 110 |
+
angles = torch.einsum("i,j->ij", positions, inv_freq)
|
| 111 |
+
|
| 112 |
+
# Compute and cache frequency components
|
| 113 |
+
angles = angles.to(dtype)
|
| 114 |
+
angles = torch.cat((angles, angles), dim=-1)
|
| 115 |
+
cos_components = angles.cos().to(dtype)
|
| 116 |
+
sin_components = angles.sin().to(dtype)
|
| 117 |
+
self.frequency_cache[cache_key] = (cos_components, sin_components)
|
| 118 |
+
|
| 119 |
+
return self.frequency_cache[cache_key]
|
| 120 |
+
|
| 121 |
+
@staticmethod
|
| 122 |
+
def _rotate_features(x: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
"""Performs feature rotation by splitting and recombining feature dimensions.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
x: Input tensor to rotate.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
Rotated feature tensor.
|
| 130 |
+
"""
|
| 131 |
+
feature_dim = x.shape[-1]
|
| 132 |
+
x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
|
| 133 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 134 |
+
|
| 135 |
+
def _apply_1d_rope(
|
| 136 |
+
self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
|
| 137 |
+
) -> torch.Tensor:
|
| 138 |
+
"""Applies 1D rotary position embeddings along one dimension.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
tokens: Input token features.
|
| 142 |
+
positions: Position indices.
|
| 143 |
+
cos_comp: Cosine components for rotation.
|
| 144 |
+
sin_comp: Sine components for rotation.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
Tokens with applied rotary position embeddings.
|
| 148 |
+
"""
|
| 149 |
+
# Embed positions with frequency components
|
| 150 |
+
cos = F.embedding(positions, cos_comp)[:, None, :, :]
|
| 151 |
+
sin = F.embedding(positions, sin_comp)[:, None, :, :]
|
| 152 |
+
|
| 153 |
+
# Apply rotation
|
| 154 |
+
return (tokens * cos) + (self._rotate_features(tokens) * sin)
|
| 155 |
+
|
| 156 |
+
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
| 157 |
+
"""Applies 2D rotary position embeddings to input tokens.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
|
| 161 |
+
The feature dimension (dim) must be divisible by 4.
|
| 162 |
+
positions: Position tensor of shape (batch_size, n_tokens, 2) containing
|
| 163 |
+
the y and x coordinates for each token.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
Tensor of same shape as input with applied 2D rotary position embeddings.
|
| 167 |
+
|
| 168 |
+
Raises:
|
| 169 |
+
AssertionError: If input dimensions are invalid or positions are malformed.
|
| 170 |
+
"""
|
| 171 |
+
# Validate inputs
|
| 172 |
+
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
|
| 173 |
+
assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
|
| 174 |
+
|
| 175 |
+
# Compute feature dimension for each spatial direction
|
| 176 |
+
feature_dim = tokens.size(-1) // 2
|
| 177 |
+
|
| 178 |
+
# Get frequency components
|
| 179 |
+
max_position = int(positions.max()) + 1
|
| 180 |
+
cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
|
| 181 |
+
|
| 182 |
+
# Split features for vertical and horizontal processing
|
| 183 |
+
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
|
| 184 |
+
|
| 185 |
+
# Apply RoPE separately for each dimension
|
| 186 |
+
vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
|
| 187 |
+
horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
|
| 188 |
+
|
| 189 |
+
# Combine processed features
|
| 190 |
+
return torch.cat((vertical_features, horizontal_features), dim=-1)
|
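For reference, a short sketch of how this module can be driven together with PositionGetter (tensor sizes are made up; the feature dimension must be divisible by 4):

    rope = RotaryPositionEmbedding2D(frequency=100.0)
    pos = PositionGetter()(batch_size=1, height=4, width=4, device=torch.device("cpu"))
    q = torch.randn(1, 8, 16, 64)   # (batch, heads, tokens, dim)
    q_rot = rope(q, pos)            # same shape as q, rotated by (y, x) position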
| 191 |
+
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_1d_rotary_pos_embed(
|
| 195 |
+
dim: int,
|
| 196 |
+
pos: Union[np.ndarray, int],
|
| 197 |
+
theta: float = 10000.0,
|
| 198 |
+
use_real=False,
|
| 199 |
+
linear_factor=1.0,
|
| 200 |
+
ntk_factor=1.0,
|
| 201 |
+
repeat_interleave_real=True,
|
| 202 |
+
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
|
| 203 |
+
):
|
    """
    Compute the frequency tensor for 1D rotary position embeddings (RoPE).

    Core idea of RoPE: encode positions with rotation matrices so that relative position
    relationships are preserved. For position m and dimension index i, the frequency is
    θ_i = θ^(-2i/d), where θ is the base frequency (default 10000).

    Args:
        dim: Feature dimension; must be even (features are processed in pairs).
        pos: Position indices, either an int (positions 0..pos-1 are generated) or an array of positions [S].
        theta: Base frequency controlling the periodicity of the encoding (default 10000).
        use_real: Return real-valued tensors (separate cos and sin) instead of a complex tensor.
        linear_factor: Linear scaling factor used for context extension.
        ntk_factor: NTK-aware scaling factor for handling sequences longer than seen in training.
        repeat_interleave_real: When use_real=True, whether to interleave-repeat the components (used by some architectures).
        freqs_dtype: Data type of the frequency tensor.

    Returns:
        Complex form: a complex tensor of shape [S, D/2] representing e^(i*m*θ_j).
        Real form: two tensors of shape [S, D] (cos and sin).
    """
    # Ensure the dimension is even (RoPE processes feature dimensions in pairs).
    assert dim % 2 == 0

    # Convert the positions to a torch tensor.
    if isinstance(pos, int):
        pos = torch.arange(pos)  # generate [0, 1, 2, ..., pos-1]
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # [S]

    # Apply NTK scaling (Neural Tangent Kernel, for sequences longer than seen in training).
    theta = theta * ntk_factor

    # Step 1: compute the frequencies θ_i = 1 / (θ^(2i/d))
    # with i ∈ {0, 2, 4, ..., dim-2} (even indices only, since dimensions are paired).
    # Formula: freq_i = 1 / (theta^(2i/d) * linear_factor)
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2], one frequency per dimension pair

    # Step 2: build the position-frequency matrix via an outer product:
    # pos[m] * freqs[i] = m * θ_i, i.e. every combination of position m and frequency i.
    freqs = torch.outer(pos, freqs)  # [S, D/2]

    # Step 3: convert to the requested output format.
    if use_real and repeat_interleave_real:
        # Variant 1: interleaved repeat (used by flux, hunyuan-dit, cogvideox, ...):
        # cos and sin of each frequency are repeated pairwise, [cos_0, cos_0, cos_1, cos_1, ...].
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # Variant 2: concatenated repeat (used by stable audio, allegro, ...):
        # the cos block is concatenated with itself, [cos_0, ..., cos_n, cos_0, ..., cos_n], likewise for sin.
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # Variant 3: complex form (used by lumina, ...), via Euler's formula e^(iθ) = cos(θ) + i*sin(θ).
        # torch.polar(r, θ) returns r * e^(iθ); with r = 1 this is exactly e^(i*freqs).
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64: [S, D/2]
        return freqs_cis

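A quick shape sanity check for the helper above (values chosen arbitrarily):

    freqs_cis = get_1d_rotary_pos_embed(dim=32, pos=8, use_real=False, repeat_interleave_real=False)
    print(freqs_cis.shape)  # torch.Size([8, 16]): one complex e^(i*m*θ_j) entry per (position, frequency) pair
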
class WanRotaryPosEmbed(nn.Module):
    """
    3D rotary position embedding (3D RoPE) module.

    Core idea: extend RoPE to three axes (time, height, width) so that video or other
    3D data gets a position encoding. Each axis (t, h, w) uses its own 1D RoPE, and the
    three encodings are concatenated.

    For a 3D position (f, h, w) (frame, height, width):
    - the frame axis uses dim_f feature dimensions,
    - the height axis uses dim_h feature dimensions,
    - the width axis uses dim_w feature dimensions,
    where dim_f + dim_h + dim_w = attention_head_dim.
    """
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int = 1024,
        theta: float = 10000.0,
        fhw_dim: Optional[Tuple[int, int, int]] = (20, 22, 22),
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim  # total dimension of one attention head
        self.patch_size = patch_size  # patch size (patch_f, patch_h, patch_w)
        self.max_seq_len = max_seq_len  # maximum sequence length used to precompute frequencies

        # Step 1: split the head dimension across the three axes.
        if fhw_dim is not None:
            # Use the explicitly provided split.
            assert attention_head_dim == sum(
                fhw_dim
            ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
            t_dim, h_dim, w_dim = fhw_dim
        else:
            # Otherwise split automatically: h and w each take roughly one third, t takes the rest.
            # For example, with attention_head_dim=64: h_dim = w_dim = 20 and t_dim = 24.
            h_dim = w_dim = 2 * (attention_head_dim // 6)
            t_dim = attention_head_dim - h_dim - w_dim

        # Keep the split so it can be reused in forward.
        self.fhw_dim = (t_dim, h_dim, w_dim)

        # Step 2: precompute the RoPE frequencies per axis (time, height, width).
        # Each axis calls the 1D helper independently and returns complex frequencies
        # of shape [max_seq_len, dim//2].
        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        # Concatenate the three axes along the last dimension: [max_seq_len, (t_dim + h_dim + w_dim)//2]
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
        """
        Generate rotary position encodings for a 3D input (video frames plus patches).

        Args:
            ppf (int): number of frames; used as the frame count when f_end is None.
            pph (int): number of patches along the height of each frame.
            ppw (int): number of patches along the width of each frame.
            patch_start_idx (int): number of special tokens per frame (placed before the patches).
            device: target device (CPU/GPU).
            f_start (int): start frame index (for causal mode), default 0.
            f_end (Optional[int]): end frame index (for causal mode); if None, ppf is used as the frame count.

        Returns:
            freqs: complex frequency tensor of shape [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim//2].

        Token layout:
            [frame0_special_token_0, ..., frame0_special_token_N,
             frame0_patch_0, ..., frame0_patch_M,
             frame1_special_token_0, ..., frame1_special_token_N,
             frame1_patch_0, ..., frame1_patch_M,
             ...]

        Modes:
            - Non-causal: f_end is None; ppf frames starting at position 0 are used.
            - Causal: f_end is not None; frames in [f_start, f_end) are used and ppf is recomputed.
        """

        # Step 1: move the precomputed frequencies to the target device and split them per axis.
        self.freqs = self.freqs.to(device)
        # Recover the per-axis dimension split.
        if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
            t_dim, h_dim, w_dim = self.fhw_dim
        else:
            # Automatic split.
            h_dim = w_dim = 2 * (self.attention_head_dim // 6)
            t_dim = self.attention_head_dim - h_dim - w_dim

        # Split with the correct sizes (half of each axis dimension, since the frequencies are complex).
        freqs = self.freqs.split_with_sizes(
            [
                t_dim // 2,  # time axis
                h_dim // 2,  # height axis
                w_dim // 2,  # width axis
            ],
            dim=1,
        )

        # Causal mode: if f_end is given, recompute ppf and restrict the frame range.
        if f_end is not None:
            ppf = f_end - f_start
            frame_slice = slice(f_start, f_end)
        else:
            # Non-causal mode: use ppf frames starting at 0.
            frame_slice = slice(0, ppf)

        # Step 2: handle special tokens, if any.
        if patch_start_idx > 0:
            # 2.1 Position encodings for the special tokens.
            # Special tokens sit on the diagonal (f, i, i), so each one gets a unique position:
            # camera: (f, 0, 0), register_0: (f, 1, 1), ..., scale: (f, 5, 5).
            # Shape: (ppf, patch_start_idx, dim)
            freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim_f) frame axis varies
            freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim_h) height = 0, 1, 2, ...
            freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim_w) width = 0, 1, 2, ...
            freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1)  # concatenate the three axes
            freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim)

            # 2.2 Position encodings for the image patches.
            # Patches sit at (f, patch_start_idx + h, patch_start_idx + w); offsetting h and w by
            # patch_start_idx keeps patch positions from clashing with the special tokens while
            # treating h and w symmetrically.
            # Shape: (ppf, pph, ppw, dim)
            freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_f) frame axis
            freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_h) height starts at patch_start_idx
            freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_w) width starts at patch_start_idx
            freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1)  # concatenate the three axes
            freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1)  # (ppf, pph * ppw, dim) flatten the spatial axes

            # Step 3: combine special tokens and patches in the correct order.
            # Within each frame the order is [special tokens, patches].
            # Shape: (ppf, patch_start_idx + pph * ppw, dim)
            freqs = torch.cat([freqs_special, freqs_patches], dim=1)  # (ppf, patch_start_idx + pph * ppw, dim)

            # Step 4: flatten to the final shape and add batch and head dimensions.
            freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
            freqs = freqs.unsqueeze(0).unsqueeze(0)  # (1, 1, ppf * (patch_start_idx + pph * ppw), dim)
            return freqs

        # No special tokens (patch_start_idx == 0): handle the image patches only.
        # All patches sit at (f, 0:pph, 0:ppw).
        freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_f) frame axis
        freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_h) height starts at 0
        freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_w) width starts at 0
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)  # (1, 1, ppf * pph * ppw, dim)
        return freqs
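A shape sketch for the non-causal, no-special-token path (the constructor arguments are illustrative values only):

    rope3d = WanRotaryPosEmbed(attention_head_dim=64, patch_size=(1, 2, 2), max_seq_len=1024)
    freqs = rope3d(ppf=4, pph=16, ppw=16, patch_start_idx=0, device=torch.device("cpu"))
    print(freqs.shape)  # torch.Size([1, 1, 1024, 32]): one complex frequency vector per (frame, h, w) token
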

def apply_rotary_emb(x, freqs):
    """
    Apply rotary position embeddings to the input features.

    Core idea: perform the feature rotation as a complex multiplication, which preserves
    relative position information.

    Math: rotating a 2D vector [x1, x2] by an angle θ is the complex multiplication
        (x1 + i*x2) * e^(iθ) = (x1 + i*x2) * (cos(θ) + i*sin(θ))
                             = (x1*cos(θ) - x2*sin(θ)) + i*(x1*sin(θ) + x2*cos(θ)),
    which is equivalent to applying the rotation matrix
        [cos(θ) -sin(θ)] [x1]
        [sin(θ)  cos(θ)] [x2]

    Args:
        x: input features [batch, heads, seq_len, head_dim]
        freqs: rotation frequencies (complex) [1, 1, seq_len, head_dim//2]

    Returns:
        x_out: rotated features [batch, heads, seq_len, head_dim]

    Steps:
        1. Treat each pair of consecutive features of x as one complex number (real, imag).
        2. Multiply by the precomputed complex frequencies e^(iθ).
        3. Convert back to the real-valued representation.
    """
    # Step 1: reshape to [..., head_dim//2, 2]; the last axis holds (real, imag),
    # e.g. [b, h, seq, 64] -> [b, h, seq, 32, 2].
    x_reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)

    # Step 2: view as complex numbers [b, h, seq, 32]; each element is real + imag*i.
    x_complex = torch.view_as_complex(x_reshaped)

    # Step 3: the complex multiplication performs the rotation.
    # freqs is already e^(iθ) = cos(θ) + i*sin(θ), so this rotates each feature pair by θ.
    x_rotated = x_complex * freqs

    # Step 4: back to the real representation [b, h, seq, 32, 2].
    x_real = torch.view_as_real(x_rotated)

    # Step 5: flatten the last two axes [b, h, seq, 64].
    x_out = x_real.flatten(3)

    # Step 6: cast back to the original dtype.
    return x_out.to(x.dtype)
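And a sketch of applying those frequencies to an attention tensor (all sizes are illustrative):

    freqs = WanRotaryPosEmbed(attention_head_dim=64, patch_size=(1, 2, 2))(
        ppf=4, pph=16, ppw=16, patch_start_idx=0, device=torch.device("cpu")
    )
    q = torch.randn(1, 8, 4 * 16 * 16, 64)   # (batch, heads, seq_len, head_dim)
    q_rot = apply_rotary_emb(q, freqs)       # same shape as q, rotated per 3D position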
lingbot_map/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,67 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SwiGLUFFN(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
in_features: int,
|
| 18 |
+
hidden_features: Optional[int] = None,
|
| 19 |
+
out_features: Optional[int] = None,
|
| 20 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 21 |
+
drop: float = 0.0,
|
| 22 |
+
bias: bool = True,
|
| 23 |
+
) -> None:
|
| 24 |
+
super().__init__()
|
| 25 |
+
out_features = out_features or in_features
|
| 26 |
+
hidden_features = hidden_features or in_features
|
| 27 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
| 28 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 29 |
+
|
| 30 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 31 |
+
x12 = self.w12(x)
|
| 32 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 33 |
+
hidden = F.silu(x1) * x2
|
| 34 |
+
return self.w3(hidden)
|
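A small usage sketch (feature sizes are arbitrary): the single w12 projection produces both the gate and the value branch, which are combined as silu(x1) * x2 before the output projection.

    import torch
    ffn = SwiGLUFFN(in_features=256, hidden_features=512)
    y = ffn(torch.randn(2, 10, 256))
    print(y.shape)  # torch.Size([2, 10, 256])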
| 35 |
+
|
| 36 |
+
|
| 37 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 38 |
+
# try:
|
| 39 |
+
# if XFORMERS_ENABLED:
|
| 40 |
+
# from xformers.ops import SwiGLU
|
| 41 |
+
|
| 42 |
+
# XFORMERS_AVAILABLE = True
|
| 43 |
+
# warnings.warn("xFormers is available (SwiGLU)")
|
| 44 |
+
# else:
|
| 45 |
+
# warnings.warn("xFormers is disabled (SwiGLU)")
|
| 46 |
+
# raise ImportError
|
| 47 |
+
# except ImportError:
|
| 48 |
+
SwiGLU = SwiGLUFFN
|
| 49 |
+
XFORMERS_AVAILABLE = False
|
| 50 |
+
|
| 51 |
+
# warnings.warn("xFormers is not available (SwiGLU)")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SwiGLUFFNFused(SwiGLU):
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
in_features: int,
|
| 58 |
+
hidden_features: Optional[int] = None,
|
| 59 |
+
out_features: Optional[int] = None,
|
| 60 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 61 |
+
drop: float = 0.0,
|
| 62 |
+
bias: bool = True,
|
| 63 |
+
) -> None:
|
| 64 |
+
out_features = out_features or in_features
|
| 65 |
+
hidden_features = hidden_features or in_features
|
| 66 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
| 67 |
+
super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
|
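The fused variant shrinks the hidden width to roughly two thirds of the requested size (so the gated two-branch SwiGLU has about the same parameter count as a plain 4x MLP) and rounds it up to a multiple of 8. A quick check of the arithmetic:

    hidden = 4 * 1024                                # e.g. mlp_ratio * embed_dim
    rounded = (int(hidden * 2 / 3) + 7) // 8 * 8
    print(rounded)  # 2736: close to 2/3 of 4096 and divisible by 8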
lingbot_map/layers/vision_transformer.py
ADDED
|
@@ -0,0 +1,411 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.utils.checkpoint import checkpoint
|
| 18 |
+
from torch.nn.init import trunc_normal_
|
| 19 |
+
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention#, NestedTensorBlock as Block
|
| 20 |
+
|
| 21 |
+
# TODO: Check this
|
| 22 |
+
# We replace NestedTensorBlock with Block
|
| 23 |
+
from .block import Block
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger("dinov2")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
| 29 |
+
if not depth_first and include_root:
|
| 30 |
+
fn(module=module, name=name)
|
| 31 |
+
for child_name, child_module in module.named_children():
|
| 32 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
| 33 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
| 34 |
+
if depth_first and include_root:
|
| 35 |
+
fn(module=module, name=name)
|
| 36 |
+
return module
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BlockChunk(nn.ModuleList):
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
for b in self:
|
| 42 |
+
x = b(x)
|
| 43 |
+
return x
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class DinoVisionTransformer(nn.Module):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
img_size=224,
|
| 50 |
+
patch_size=16,
|
| 51 |
+
in_chans=3,
|
| 52 |
+
embed_dim=768,
|
| 53 |
+
depth=12,
|
| 54 |
+
num_heads=12,
|
| 55 |
+
mlp_ratio=4.0,
|
| 56 |
+
qkv_bias=True,
|
| 57 |
+
ffn_bias=True,
|
| 58 |
+
proj_bias=True,
|
| 59 |
+
drop_path_rate=0.0,
|
| 60 |
+
drop_path_uniform=False,
|
| 61 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
| 62 |
+
embed_layer=PatchEmbed,
|
| 63 |
+
act_layer=nn.GELU,
|
| 64 |
+
block_fn=Block,
|
| 65 |
+
ffn_layer="mlp",
|
| 66 |
+
block_chunks=1,
|
| 67 |
+
num_register_tokens=0,
|
| 68 |
+
interpolate_antialias=False,
|
| 69 |
+
interpolate_offset=0.1,
|
| 70 |
+
drop_cls_token=False,
|
| 71 |
+
qk_norm=False,
|
| 72 |
+
):
|
| 73 |
+
"""
|
| 74 |
+
Args:
|
| 75 |
+
img_size (int, tuple): input image size
|
| 76 |
+
patch_size (int, tuple): patch size
|
| 77 |
+
in_chans (int): number of input channels
|
| 78 |
+
embed_dim (int): embedding dimension
|
| 79 |
+
depth (int): depth of transformer
|
| 80 |
+
num_heads (int): number of attention heads
|
| 81 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 82 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 83 |
+
proj_bias (bool): enable bias for proj in attn if True
|
| 84 |
+
ffn_bias (bool): enable bias for ffn if True
|
| 85 |
+
drop_path_rate (float): stochastic depth rate
|
| 86 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
| 87 |
+
weight_init (str): weight init scheme
|
| 88 |
+
init_values (float): layer-scale init values
|
| 89 |
+
embed_layer (nn.Module): patch embedding layer
|
| 90 |
+
act_layer (nn.Module): MLP activation layer
|
| 91 |
+
block_fn (nn.Module): transformer block class
|
| 92 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
| 93 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
| 94 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
| 95 |
+
interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
|
| 96 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
| 97 |
+
"""
|
| 98 |
+
super().__init__()
|
| 99 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 100 |
+
|
| 101 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 102 |
+
self.num_tokens = 1 if not drop_cls_token else 0
|
| 103 |
+
self.n_blocks = depth
|
| 104 |
+
self.num_heads = num_heads
|
| 105 |
+
self.patch_size = patch_size
|
| 106 |
+
self.num_register_tokens = num_register_tokens
|
| 107 |
+
self.interpolate_antialias = interpolate_antialias
|
| 108 |
+
self.interpolate_offset = interpolate_offset
|
| 109 |
+
self.use_reentrant = False # hardcoded to False
|
| 110 |
+
|
| 111 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 112 |
+
num_patches = self.patch_embed.num_patches
|
| 113 |
+
|
| 114 |
+
self.drop_cls_token = drop_cls_token
|
| 115 |
+
|
| 116 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if not drop_cls_token else None
|
| 117 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
| 118 |
+
assert num_register_tokens >= 0
|
| 119 |
+
self.register_tokens = (
|
| 120 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
if drop_path_uniform is True:
|
| 124 |
+
dpr = [drop_path_rate] * depth
|
| 125 |
+
else:
|
| 126 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 127 |
+
|
| 128 |
+
if ffn_layer == "mlp":
|
| 129 |
+
logger.info("using MLP layer as FFN")
|
| 130 |
+
ffn_layer = Mlp
|
| 131 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
| 132 |
+
logger.info("using SwiGLU layer as FFN")
|
| 133 |
+
ffn_layer = SwiGLUFFNFused
|
| 134 |
+
elif ffn_layer == "identity":
|
| 135 |
+
logger.info("using Identity layer as FFN")
|
| 136 |
+
|
| 137 |
+
def f(*args, **kwargs):
|
| 138 |
+
return nn.Identity()
|
| 139 |
+
|
| 140 |
+
ffn_layer = f
|
| 141 |
+
else:
|
| 142 |
+
raise NotImplementedError
|
| 143 |
+
|
| 144 |
+
blocks_list = [
|
| 145 |
+
block_fn(
|
| 146 |
+
dim=embed_dim,
|
| 147 |
+
num_heads=num_heads,
|
| 148 |
+
mlp_ratio=mlp_ratio,
|
| 149 |
+
qkv_bias=qkv_bias,
|
| 150 |
+
proj_bias=proj_bias,
|
| 151 |
+
ffn_bias=ffn_bias,
|
| 152 |
+
drop_path=dpr[i],
|
| 153 |
+
norm_layer=norm_layer,
|
| 154 |
+
act_layer=act_layer,
|
| 155 |
+
ffn_layer=ffn_layer,
|
| 156 |
+
init_values=init_values,
|
| 157 |
+
qk_norm=qk_norm,
|
| 158 |
+
)
|
| 159 |
+
for i in range(depth)
|
| 160 |
+
]
|
| 161 |
+
if block_chunks > 0:
|
| 162 |
+
self.chunked_blocks = True
|
| 163 |
+
chunked_blocks = []
|
| 164 |
+
chunksize = depth // block_chunks
|
| 165 |
+
for i in range(0, depth, chunksize):
|
| 166 |
+
# this is to keep the block index consistent if we chunk the block list
|
| 167 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
| 168 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
| 169 |
+
else:
|
| 170 |
+
self.chunked_blocks = False
|
| 171 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 172 |
+
|
| 173 |
+
self.norm = norm_layer(embed_dim)
|
| 174 |
+
self.head = nn.Identity()
|
| 175 |
+
|
| 176 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
| 177 |
+
|
| 178 |
+
self.init_weights()
|
| 179 |
+
|
| 180 |
+
def init_weights(self):
|
| 181 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 182 |
+
nn.init.normal_(self.cls_token, std=1e-6) if self.cls_token is not None else None
|
| 183 |
+
if self.register_tokens is not None:
|
| 184 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
| 185 |
+
named_apply(init_weights_vit_timm, self)
|
| 186 |
+
|
| 187 |
+
def interpolate_pos_encoding(self, x, w, h):
|
| 188 |
+
previous_dtype = x.dtype
|
| 189 |
+
npatch = x.shape[1] - 1
|
| 190 |
+
N = self.pos_embed.shape[1] - 1 if not self.drop_cls_token else self.pos_embed.shape[1]
|
| 191 |
+
if npatch == N and w == h:
|
| 192 |
+
return self.pos_embed
|
| 193 |
+
pos_embed = self.pos_embed.float()
|
| 194 |
+
if not self.drop_cls_token:
|
| 195 |
+
class_pos_embed = pos_embed[:, 0]
|
| 196 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 197 |
+
else:
|
| 198 |
+
patch_pos_embed = pos_embed
|
| 199 |
+
dim = x.shape[-1]
|
| 200 |
+
w0 = w // self.patch_size
|
| 201 |
+
h0 = h // self.patch_size
|
| 202 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
| 203 |
+
assert N == M * M
|
| 204 |
+
kwargs = {}
|
| 205 |
+
if self.interpolate_offset:
|
| 206 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
| 207 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
| 208 |
+
sx = float(w0 + self.interpolate_offset) / M
|
| 209 |
+
sy = float(h0 + self.interpolate_offset) / M
|
| 210 |
+
kwargs["scale_factor"] = (sx, sy)
|
| 211 |
+
else:
|
| 212 |
+
# Simply specify an output size instead of a scale factor
|
| 213 |
+
kwargs["size"] = (w0, h0)
|
| 214 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 215 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
| 216 |
+
mode="bicubic",
|
| 217 |
+
antialias=self.interpolate_antialias,
|
| 218 |
+
**kwargs,
|
| 219 |
+
)
|
| 220 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
| 221 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 222 |
+
if not self.drop_cls_token:
|
| 223 |
+
x = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
| 224 |
+
else:
|
| 225 |
+
x = patch_pos_embed
|
| 226 |
+
return x.to(previous_dtype)
|
| 227 |
+
|
| 228 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
| 229 |
+
B, nc, w, h = x.shape
|
| 230 |
+
x = self.patch_embed(x)
|
| 231 |
+
if masks is not None:
|
| 232 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 233 |
+
|
| 234 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.cls_token is not None else x
|
| 235 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 236 |
+
|
| 237 |
+
if self.register_tokens is not None:
|
| 238 |
+
x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
|
| 239 |
+
|
| 240 |
+
return x
|
| 241 |
+
|
| 242 |
+
def forward_features_list(self, x_list, masks_list):
|
| 243 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
| 244 |
+
|
| 245 |
+
for blk in self.blocks:
|
| 246 |
+
if self.training:
|
| 247 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
| 248 |
+
else:
|
| 249 |
+
x = blk(x)
|
| 250 |
+
|
| 251 |
+
all_x = x
|
| 252 |
+
output = []
|
| 253 |
+
for x, masks in zip(all_x, masks_list):
|
| 254 |
+
x_norm = self.norm(x)
|
| 255 |
+
output.append(
|
| 256 |
+
{
|
| 257 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 258 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 259 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 260 |
+
"x_prenorm": x,
|
| 261 |
+
"masks": masks,
|
| 262 |
+
}
|
| 263 |
+
)
|
| 264 |
+
return output
|
| 265 |
+
|
| 266 |
+
def forward_features(self, x, masks=None):
|
| 267 |
+
if isinstance(x, list):
|
| 268 |
+
return self.forward_features_list(x, masks)
|
| 269 |
+
|
| 270 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
| 271 |
+
|
| 272 |
+
for blk in self.blocks:
|
| 273 |
+
if self.training:
|
| 274 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
| 275 |
+
else:
|
| 276 |
+
x = blk(x)
|
| 277 |
+
|
| 278 |
+
x_norm = self.norm(x)
|
| 279 |
+
return {
|
| 280 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 281 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 282 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 283 |
+
"x_prenorm": x,
|
| 284 |
+
"masks": masks,
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
| 288 |
+
x = self.prepare_tokens_with_masks(x)
|
| 289 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 290 |
+
output, total_block_len = [], len(self.blocks)
|
| 291 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 292 |
+
for i, blk in enumerate(self.blocks):
|
| 293 |
+
x = blk(x)
|
| 294 |
+
if i in blocks_to_take:
|
| 295 |
+
output.append(x)
|
| 296 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 297 |
+
return output
|
| 298 |
+
|
| 299 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
| 300 |
+
x = self.prepare_tokens_with_masks(x)
|
| 301 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
| 302 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 303 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 304 |
+
for block_chunk in self.blocks:
|
| 305 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
| 306 |
+
x = blk(x)
|
| 307 |
+
if i in blocks_to_take:
|
| 308 |
+
output.append(x)
|
| 309 |
+
i += 1
|
| 310 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 311 |
+
return output
|
| 312 |
+
|
| 313 |
+
def get_intermediate_layers(
|
| 314 |
+
self,
|
| 315 |
+
x: torch.Tensor,
|
| 316 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
| 317 |
+
reshape: bool = False,
|
| 318 |
+
return_class_token: bool = False,
|
| 319 |
+
norm=True,
|
| 320 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 321 |
+
if self.chunked_blocks:
|
| 322 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
| 323 |
+
else:
|
| 324 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
| 325 |
+
if norm:
|
| 326 |
+
outputs = [self.norm(out) for out in outputs]
|
| 327 |
+
class_tokens = [out[:, 0] for out in outputs]
|
| 328 |
+
outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
|
| 329 |
+
if reshape:
|
| 330 |
+
B, _, w, h = x.shape
|
| 331 |
+
outputs = [
|
| 332 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
| 333 |
+
for out in outputs
|
| 334 |
+
]
|
| 335 |
+
if return_class_token:
|
| 336 |
+
return tuple(zip(outputs, class_tokens))
|
| 337 |
+
return tuple(outputs)
|
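A sketch of pulling a small feature pyramid out of the backbone with this method (model size and input resolution are arbitrary illustrative choices):

    vit = vit_small(patch_size=16, img_size=224, block_chunks=0)
    maps = vit.get_intermediate_layers(torch.randn(1, 3, 224, 224), n=4, reshape=True)
    print([m.shape for m in maps])  # four tensors of shape (1, 384, 14, 14)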
| 338 |
+
|
| 339 |
+
def forward(self, *args, is_training=True, **kwargs):
|
| 340 |
+
ret = self.forward_features(*args, **kwargs)
|
| 341 |
+
if is_training:
|
| 342 |
+
return ret
|
| 343 |
+
else:
|
| 344 |
+
return self.head(ret["x_norm_clstoken"])
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 348 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 349 |
+
if isinstance(module, nn.Linear):
|
| 350 |
+
trunc_normal_(module.weight, std=0.02)
|
| 351 |
+
if module.bias is not None:
|
| 352 |
+
nn.init.zeros_(module.bias)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
| 356 |
+
model = DinoVisionTransformer(
|
| 357 |
+
patch_size=patch_size,
|
| 358 |
+
embed_dim=384,
|
| 359 |
+
depth=12,
|
| 360 |
+
num_heads=6,
|
| 361 |
+
mlp_ratio=4,
|
| 362 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 363 |
+
num_register_tokens=num_register_tokens,
|
| 364 |
+
**kwargs,
|
| 365 |
+
)
|
| 366 |
+
return model
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
| 370 |
+
model = DinoVisionTransformer(
|
| 371 |
+
patch_size=patch_size,
|
| 372 |
+
embed_dim=768,
|
| 373 |
+
depth=12,
|
| 374 |
+
num_heads=12,
|
| 375 |
+
mlp_ratio=4,
|
| 376 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 377 |
+
num_register_tokens=num_register_tokens,
|
| 378 |
+
**kwargs,
|
| 379 |
+
)
|
| 380 |
+
return model
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
| 384 |
+
model = DinoVisionTransformer(
|
| 385 |
+
patch_size=patch_size,
|
| 386 |
+
embed_dim=1024,
|
| 387 |
+
depth=24,
|
| 388 |
+
num_heads=16,
|
| 389 |
+
mlp_ratio=4,
|
| 390 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 391 |
+
num_register_tokens=num_register_tokens,
|
| 392 |
+
**kwargs,
|
| 393 |
+
)
|
| 394 |
+
return model
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
| 398 |
+
"""
|
| 399 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
| 400 |
+
"""
|
| 401 |
+
model = DinoVisionTransformer(
|
| 402 |
+
patch_size=patch_size,
|
| 403 |
+
embed_dim=1536,
|
| 404 |
+
depth=40,
|
| 405 |
+
num_heads=24,
|
| 406 |
+
mlp_ratio=4,
|
| 407 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 408 |
+
num_register_tokens=num_register_tokens,
|
| 409 |
+
**kwargs,
|
| 410 |
+
)
|
| 411 |
+
return model
|
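For example, a DINOv2-style backbone can be instantiated from these factories as follows (the argument values are illustrative; the configuration actually used by LingBot-Map is defined elsewhere in the package):

    backbone = vit_large(patch_size=14, img_size=518, num_register_tokens=4, block_chunks=0).eval()
    out = backbone.forward_features(torch.randn(1, 3, 518, 518))
    print(out["x_norm_patchtokens"].shape)  # (1, 1369, 1024): one token per 14x14 patch of a 518x518 image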
lingbot_map/models/__init__.py
ADDED
|
File without changes
|
lingbot_map/models/gct_base.py
ADDED
|
@@ -0,0 +1,359 @@
|
| 1 |
+
"""
|
| 2 |
+
GCTBase - Base class for GCT model implementations.
|
| 3 |
+
|
| 4 |
+
Provides shared functionality:
|
| 5 |
+
- Prediction heads (camera, depth, point)
|
| 6 |
+
- Forward pass structure
|
| 7 |
+
- Model hub mixin (PyTorchModelHubMixin)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from abc import ABC, abstractmethod
|
| 15 |
+
from typing import Optional, Dict, Any, List, Union
|
| 16 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 17 |
+
|
| 18 |
+
from lingbot_map.heads.dpt_head import DPTHead
|
| 19 |
+
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
|
| 20 |
+
from lingbot_map.utils.geometry import closed_form_inverse_se3
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class GCTBase(nn.Module, PyTorchModelHubMixin, ABC):
|
| 26 |
+
"""
|
| 27 |
+
Base class for GCT model implementations.
|
| 28 |
+
|
| 29 |
+
Handles shared components:
|
| 30 |
+
- Prediction heads (camera, depth, point)
|
| 31 |
+
- Forward pass structure
|
| 32 |
+
- Input normalization
|
| 33 |
+
|
| 34 |
+
Subclasses must implement:
|
| 35 |
+
- _build_aggregator(): Create mode-specific aggregator
|
| 36 |
+
- _build_camera_head(): Create mode-specific camera head
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
# Architecture parameters
|
| 42 |
+
img_size: int = 518,
|
| 43 |
+
patch_size: int = 14,
|
| 44 |
+
embed_dim: int = 1024,
|
| 45 |
+
patch_embed: str = 'dinov2_vitl14_reg',
|
| 46 |
+
disable_global_rope: bool = False,
|
| 47 |
+
# Head configuration
|
| 48 |
+
enable_camera: bool = True,
|
| 49 |
+
enable_point: bool = True,
|
| 50 |
+
enable_local_point: bool = False,
|
| 51 |
+
enable_depth: bool = True,
|
| 52 |
+
enable_track: bool = False,
|
| 53 |
+
# Camera head sliding window
|
| 54 |
+
enable_camera_sliding_window: bool = False,
|
| 55 |
+
# 3D RoPE
|
| 56 |
+
enable_3d_rope: bool = False,
|
| 57 |
+
# Context Parallelism (kept for checkpoint compatibility but not used)
|
| 58 |
+
enable_ulysses_cp: bool = False,
|
| 59 |
+
# Normalization
|
| 60 |
+
enable_normalize: bool = False,
|
| 61 |
+
# Prediction normalization
|
| 62 |
+
pred_normalization: bool = False,
|
| 63 |
+
pred_normalization_detach_scale: bool = False,
|
| 64 |
+
# Gradient checkpointing
|
| 65 |
+
use_gradient_checkpoint: bool = True,
|
| 66 |
+
):
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
# Store configuration
|
| 70 |
+
self.img_size = img_size
|
| 71 |
+
self.patch_size = patch_size
|
| 72 |
+
self.embed_dim = embed_dim
|
| 73 |
+
self.patch_embed = patch_embed
|
| 74 |
+
self.disable_global_rope = disable_global_rope
|
| 75 |
+
|
| 76 |
+
self.enable_ulysses_cp = False # CP disabled in standalone package
|
| 77 |
+
self.enable_normalize = enable_normalize
|
| 78 |
+
self.pred_normalization = pred_normalization
|
| 79 |
+
self.pred_normalization_detach_scale = pred_normalization_detach_scale
|
| 80 |
+
self.use_gradient_checkpoint = use_gradient_checkpoint
|
| 81 |
+
|
| 82 |
+
# Head flags
|
| 83 |
+
self.enable_camera = enable_camera
|
| 84 |
+
self.enable_point = enable_point
|
| 85 |
+
self.enable_local_point = enable_local_point
|
| 86 |
+
self.enable_depth = enable_depth
|
| 87 |
+
self.enable_track = enable_track
|
| 88 |
+
self.enable_camera_sliding_window = enable_camera_sliding_window
|
| 89 |
+
self.enable_3d_rope = enable_3d_rope
|
| 90 |
+
|
| 91 |
+
# Build aggregator (subclass-specific)
|
| 92 |
+
self.aggregator = self._build_aggregator()
|
| 93 |
+
|
| 94 |
+
# Build prediction heads (subclass-specific)
|
| 95 |
+
self.camera_head = self._build_camera_head() if enable_camera else None
|
| 96 |
+
self.point_head = self._build_point_head() if enable_point else None
|
| 97 |
+
self.local_point_head = self._build_local_point_head() if enable_local_point else None
|
| 98 |
+
self.depth_head = self._build_depth_head() if enable_depth else None
|
| 99 |
+
|
| 100 |
+
@abstractmethod
|
| 101 |
+
def _build_aggregator(self) -> nn.Module:
|
| 102 |
+
pass
|
| 103 |
+
|
| 104 |
+
@abstractmethod
|
| 105 |
+
def _build_camera_head(self) -> nn.Module:
|
| 106 |
+
pass
|
| 107 |
+
|
| 108 |
+
def _build_depth_head(self) -> nn.Module:
|
| 109 |
+
return DPTHead(
|
| 110 |
+
dim_in=2 * self.embed_dim,
|
| 111 |
+
patch_size=self.patch_size,
|
| 112 |
+
output_dim=2,
|
| 113 |
+
activation="exp",
|
| 114 |
+
conf_activation="expp1"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def _build_point_head(self) -> nn.Module:
|
| 118 |
+
return DPTHead(
|
| 119 |
+
dim_in=2 * self.embed_dim,
|
| 120 |
+
patch_size=self.patch_size,
|
| 121 |
+
output_dim=4,
|
| 122 |
+
activation="inv_log",
|
| 123 |
+
conf_activation="expp1"
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def _build_local_point_head(self) -> nn.Module:
|
| 127 |
+
return DPTHead(
|
| 128 |
+
dim_in=2 * self.embed_dim,
|
| 129 |
+
patch_size=self.patch_size,
|
| 130 |
+
output_dim=4,
|
| 131 |
+
activation="inv_log",
|
| 132 |
+
conf_activation="expp1"
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
def _normalize_input(self, images: torch.Tensor, query_points=None):
|
| 136 |
+
if len(images.shape) == 4:
|
| 137 |
+
images = images.unsqueeze(0)
|
| 138 |
+
if query_points is not None and len(query_points.shape) == 2:
|
| 139 |
+
query_points = query_points.unsqueeze(0)
|
| 140 |
+
return images, query_points
|
| 141 |
+
|
| 142 |
+
@abstractmethod
|
| 143 |
+
def _aggregate_features(
|
| 144 |
+
self,
|
| 145 |
+
images: torch.Tensor,
|
| 146 |
+
num_frame_for_scale: Optional[int] = None,
|
| 147 |
+
sliding_window_size: Optional[int] = None,
|
| 148 |
+
num_frame_per_block: int = 1,
|
| 149 |
+
view_graphs: Optional[torch.Tensor] = None,
|
| 150 |
+
causal_graphs: Optional[Union[torch.Tensor, List[np.ndarray]]] = None,
|
| 151 |
+
ordered_video: Optional[torch.Tensor] = None,
|
| 152 |
+
is_cp_sliced: bool = False,
|
| 153 |
+
) -> tuple:
|
| 154 |
+
pass
|
| 155 |
+
|
| 156 |
+
def _predict_camera(
|
| 157 |
+
self,
|
| 158 |
+
aggregated_tokens_list: list,
|
| 159 |
+
mask: Optional[torch.Tensor] = None,
|
| 160 |
+
causal_inference: bool = False,
|
| 161 |
+
num_frame_for_scale: Optional[int] = None,
|
| 162 |
+
sliding_window_size: Optional[int] = None,
|
| 163 |
+
num_frame_per_block: int = 1,
|
| 164 |
+
gather_outputs: bool = True,
|
| 165 |
+
) -> Dict[str, torch.Tensor]:
|
| 166 |
+
if self.camera_head is None:
|
| 167 |
+
return {}
|
| 168 |
+
|
| 169 |
+
aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
|
| 170 |
+
|
| 171 |
+
camera_sliding_window = sliding_window_size if self.enable_camera_sliding_window else -1
|
| 172 |
+
|
| 173 |
+
with torch.amp.autocast('cuda', enabled=False):
|
| 174 |
+
pose_enc_list = self.camera_head(
|
| 175 |
+
aggregated_tokens_list_fp32,
|
| 176 |
+
mask=mask,
|
| 177 |
+
causal_inference=causal_inference,
|
| 178 |
+
num_frame_for_scale=num_frame_for_scale if num_frame_for_scale is not None else -1,
|
| 179 |
+
sliding_window_size=camera_sliding_window,
|
| 180 |
+
num_frame_per_block=num_frame_per_block,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
return {
|
| 184 |
+
"pose_enc": pose_enc_list[-1],
|
| 185 |
+
"pose_enc_list": pose_enc_list,
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
def _predict_depth(
|
| 189 |
+
self,
|
| 190 |
+
aggregated_tokens_list: list,
|
| 191 |
+
images: torch.Tensor,
|
| 192 |
+
patch_start_idx: int,
|
| 193 |
+
gather_outputs: bool = True,
|
| 194 |
+
) -> Dict[str, torch.Tensor]:
|
| 195 |
+
if self.depth_head is None:
|
| 196 |
+
return {}
|
| 197 |
+
|
| 198 |
+
aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
|
| 199 |
+
images_fp32 = images.float()
|
| 200 |
+
|
| 201 |
+
with torch.amp.autocast('cuda', enabled=False):
|
| 202 |
+
depth, depth_conf = self.depth_head(
|
| 203 |
+
aggregated_tokens_list_fp32,
|
| 204 |
+
images=images_fp32,
|
| 205 |
+
patch_start_idx=patch_start_idx
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
return {"depth": depth, "depth_conf": depth_conf}
|
| 209 |
+
|
| 210 |
+
def _predict_points(
|
| 211 |
+
self,
|
| 212 |
+
aggregated_tokens_list: list,
|
| 213 |
+
images: torch.Tensor,
|
| 214 |
+
patch_start_idx: int,
|
| 215 |
+
gather_outputs: bool = True,
|
| 216 |
+
) -> Dict[str, torch.Tensor]:
|
| 217 |
+
if self.point_head is None:
|
| 218 |
+
return {}
|
| 219 |
+
|
| 220 |
+
aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
|
| 221 |
+
images_fp32 = images.float()
|
| 222 |
+
|
| 223 |
+
with torch.amp.autocast('cuda', enabled=False):
|
| 224 |
+
pts3d, pts3d_conf = self.point_head(
|
| 225 |
+
aggregated_tokens_list_fp32,
|
| 226 |
+
images=images_fp32,
|
| 227 |
+
patch_start_idx=patch_start_idx
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
return {"world_points": pts3d, "world_points_conf": pts3d_conf}
|
| 231 |
+
|
| 232 |
+
def _predict_local_points(
|
| 233 |
+
self,
|
| 234 |
+
aggregated_tokens_list: list,
|
| 235 |
+
images: torch.Tensor,
|
| 236 |
+
patch_start_idx: int,
|
| 237 |
+
gather_outputs: bool = True,
|
| 238 |
+
) -> Dict[str, torch.Tensor]:
|
| 239 |
+
if self.local_point_head is None:
|
| 240 |
+
return {}
|
| 241 |
+
|
| 242 |
+
aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
|
| 243 |
+
images_fp32 = images.float()
|
| 244 |
+
|
| 245 |
+
with torch.amp.autocast('cuda', enabled=False):
|
| 246 |
+
pts3d, pts3d_conf = self.local_point_head(
|
| 247 |
+
aggregated_tokens_list_fp32,
|
| 248 |
+
images=images_fp32,
|
| 249 |
+
patch_start_idx=patch_start_idx
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
return {"cam_points": pts3d, "cam_points_conf": pts3d_conf}
|
| 253 |
+
|
| 254 |
+
def _unproject_depth_to_world(
|
| 255 |
+
self,
|
| 256 |
+
depth: torch.Tensor,
|
| 257 |
+
pose_enc: torch.Tensor,
|
| 258 |
+
) -> torch.Tensor:
|
| 259 |
+
B, S, H, W, _ = depth.shape
|
| 260 |
+
device = depth.device
|
| 261 |
+
dtype = depth.dtype
|
| 262 |
+
|
| 263 |
+
image_size_hw = (H, W)
|
| 264 |
+
extrinsics, intrinsics = pose_encoding_to_extri_intri(
|
| 265 |
+
pose_enc, image_size_hw=image_size_hw, build_intrinsics=True
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
extrinsics_flat = extrinsics.view(B * S, 3, 4)
|
| 269 |
+
extrinsics_4x4 = torch.zeros(B * S, 4, 4, device=device, dtype=dtype)
|
| 270 |
+
extrinsics_4x4[:, :3, :] = extrinsics_flat
|
| 271 |
+
extrinsics_4x4[:, 3, 3] = 1.0
|
| 272 |
+
c2w = closed_form_inverse_se3(extrinsics_4x4).view(B, S, 4, 4)
|
| 273 |
+
|
| 274 |
+
y_grid, x_grid = torch.meshgrid(
|
| 275 |
+
torch.arange(H, device=device, dtype=dtype),
|
| 276 |
+
torch.arange(W, device=device, dtype=dtype),
|
| 277 |
+
indexing='ij'
|
| 278 |
+
)
|
| 279 |
+
pixel_coords = torch.stack([x_grid, y_grid, torch.ones_like(x_grid)], dim=-1)
|
| 280 |
+
|
| 281 |
+
intrinsics_inv = torch.inverse(intrinsics)
|
| 282 |
+
camera_coords = torch.einsum('bsij,hwj->bshwi', intrinsics_inv, pixel_coords)
|
| 283 |
+
camera_points = camera_coords * depth
|
| 284 |
+
|
| 285 |
+
ones = torch.ones_like(camera_points[..., :1])
|
| 286 |
+
camera_points_h = torch.cat([camera_points, ones], dim=-1)
|
| 287 |
+
world_points_h = torch.einsum('bsij,bshwj->bshwi', c2w, camera_points_h)
|
| 288 |
+
|
| 289 |
+
return world_points_h[..., :3]
|
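The unprojection above is the standard pinhole-camera relation applied batch-wise. A single-image version of the same math (purely illustrative, independent of the model classes):

    import torch

    def unproject_single(depth, K, c2w):
        # depth: (H, W, 1), K: (3, 3) intrinsics, c2w: (4, 4) camera-to-world pose
        H, W = depth.shape[:2]
        y, x = torch.meshgrid(torch.arange(H, dtype=depth.dtype), torch.arange(W, dtype=depth.dtype), indexing="ij")
        pix = torch.stack([x, y, torch.ones_like(x)], dim=-1)        # homogeneous pixel coordinates
        rays = torch.einsum("ij,hwj->hwi", torch.inverse(K), pix)    # camera-space ray directions
        cam = rays * depth                                           # scale rays by depth
        cam_h = torch.cat([cam, torch.ones_like(cam[..., :1])], dim=-1)
        world = torch.einsum("ij,hwj->hwi", c2w, cam_h)
        return world[..., :3]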
| 290 |
+
|
| 291 |
+
def forward(
|
| 292 |
+
self,
|
| 293 |
+
images: torch.Tensor,
|
| 294 |
+
query_points: Optional[torch.Tensor] = None,
|
| 295 |
+
num_frame_for_scale: Optional[int] = None,
|
| 296 |
+
sliding_window_size: Optional[int] = None,
|
| 297 |
+
num_frame_per_block: int = 1,
|
| 298 |
+
mask: Optional[torch.Tensor] = None,
|
| 299 |
+
causal_inference: bool = False,
|
| 300 |
+
ordered_video: Optional[torch.Tensor] = None,
|
| 301 |
+
gather_outputs: bool = True,
|
| 302 |
+
point_masks: Optional[torch.Tensor] = None,
|
| 303 |
+
**kwargs,
|
| 304 |
+
) -> Dict[str, torch.Tensor]:
|
| 305 |
+
"""
|
| 306 |
+
Forward pass of the GCT model.
|
| 307 |
+
|
| 308 |
+
Args:
|
| 309 |
+
images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
|
| 310 |
+
query_points: Optional query points [N, 2] or [B, N, 2]
|
| 311 |
+
|
| 312 |
+
Returns:
|
| 313 |
+
Dictionary containing predictions:
|
| 314 |
+
- pose_enc: Camera pose encoding [B, S, 9]
|
| 315 |
+
- depth: Depth maps [B, S, H, W, 1]
|
| 316 |
+
- depth_conf: Depth confidence [B, S, H, W]
|
| 317 |
+
- world_points: 3D world coordinates [B, S, H, W, 3]
|
| 318 |
+
- world_points_conf: Point confidence [B, S, H, W]
|
| 319 |
+
"""
|
| 320 |
+
images, query_points = self._normalize_input(images, query_points)
|
| 321 |
+
|
| 322 |
+
aggregated_tokens_list, patch_start_idx = self._aggregate_features(
|
| 323 |
+
images,
|
| 324 |
+
num_frame_for_scale=num_frame_for_scale,
|
| 325 |
+
sliding_window_size=sliding_window_size,
|
| 326 |
+
num_frame_per_block=num_frame_per_block,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
predictions = {}
|
| 330 |
+
|
| 331 |
+
predictions.update(self._predict_camera(
|
| 332 |
+
aggregated_tokens_list,
|
| 333 |
+
mask=ordered_video,
|
| 334 |
+
causal_inference=causal_inference,
|
| 335 |
+
num_frame_for_scale=num_frame_for_scale,
|
| 336 |
+
sliding_window_size=sliding_window_size,
|
| 337 |
+
num_frame_per_block=num_frame_per_block,
|
| 338 |
+
gather_outputs=gather_outputs,
|
| 339 |
+
))
|
| 340 |
+
|
| 341 |
+
predictions.update(self._predict_depth(
|
| 342 |
+
aggregated_tokens_list, images, patch_start_idx,
|
| 343 |
+
gather_outputs=gather_outputs,
|
| 344 |
+
))
|
| 345 |
+
|
| 346 |
+
predictions.update(self._predict_points(
|
| 347 |
+
aggregated_tokens_list, images, patch_start_idx,
|
| 348 |
+
gather_outputs=gather_outputs,
|
| 349 |
+
))
|
| 350 |
+
|
| 351 |
+
predictions.update(self._predict_local_points(
|
| 352 |
+
aggregated_tokens_list, images, patch_start_idx,
|
| 353 |
+
gather_outputs=gather_outputs,
|
| 354 |
+
))
|
| 355 |
+
|
| 356 |
+
if not self.training:
|
| 357 |
+
predictions["images"] = images
|
| 358 |
+
|
| 359 |
+
return predictions
|
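For orientation, a minimal sketch of how this forward pass can be driven offline on a short clip. The construction call is an assumption (checkpoint loading and preprocessing live in app.py, not in this file, and the class used here is the GCTStream subclass added in the next file); only the input range and the output keys and shapes follow the docstring above.

import torch

# Assumed default construction; the Space's app.py is responsible for loading the real weights.
model = GCTStream().eval().cuda()

# 8-frame clip, [S, 3, H, W], values in [0, 1]; a leading batch dim is also accepted.
images = torch.rand(8, 3, 518, 518, device='cuda')

with torch.no_grad():
    preds = model(images)

print(preds["pose_enc"].shape)      # [1, 8, 9]
print(preds["depth"].shape)         # [1, 8, H, W, 1]
print(preds["world_points"].shape)  # [1, 8, H, W, 3]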
lingbot_map/models/gct_stream.py
ADDED
|
@@ -0,0 +1,448 @@
"""
GCTStream - Streaming GCT with KV cache for online inference.

Provides streaming inference functionality:
- Temporal causal attention with KV cache
- Sliding window support
- Efficient frame-by-frame processing
- 3D RoPE support for temporal consistency
"""

import logging
import torch
import torch.nn as nn
from typing import Optional, Dict, Any, List
from tqdm.auto import tqdm

from lingbot_map.heads.camera_head import CameraCausalHead
from lingbot_map.models.gct_base import GCTBase
from lingbot_map.aggregator.stream import AggregatorStream

logger = logging.getLogger(__name__)


class GCTStream(GCTBase):
    """
    Streaming GCT model with KV cache for efficient online inference.

    Features:
    - AggregatorStream with KV cache support (FlashInfer backend)
    - CameraCausalHead for pose refinement
    - Sliding window attention for memory efficiency
    - Frame-by-frame streaming inference
    """

    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        pretrained_path: str = '',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        # Stream-specific parameters
        sliding_window_size: int = -1,
        num_frame_for_scale: int = 1,
        num_random_frames: int = 0,
        attend_to_special_tokens: bool = False,
        attend_to_scale_frames: bool = False,
        enable_stream_inference: bool = True,  # Default to True for streaming
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        # Camera head 3D RoPE (separate from aggregator 3D RoPE)
        enable_camera_3d_rope: bool = False,
        camera_rope_theta: float = 10000.0,
        # Scale token configuration (kept for checkpoint compat, ignored)
        use_scale_token: bool = True,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # Backend selection
        use_sdpa: bool = False,  # If True, use SDPA (no flashinfer needed); default: FlashInfer
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
        # Camera head iterative refinement (lower = faster inference; default 4)
        camera_num_iterations: int = 4,
    ):
        """
        Initialize GCTStream.

        Args:
            img_size: Input image size
            patch_size: Patch size for embedding
            embed_dim: Embedding dimension
            patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.)
            pretrained_path: Path to pretrained DINOv2 weights
            disable_global_rope: Disable RoPE in global attention
            enable_camera/point/depth/track: Enable prediction heads
            enable_normalize: Enable normalization
            sliding_window_size: Sliding window size in blocks (-1 for full causal)
            num_frame_for_scale: Number of scale estimation frames
            num_random_frames: Number of random frames for long-range dependencies
            attend_to_special_tokens: Enable cross-frame special token attention
            attend_to_scale_frames: Whether to attend to scale frames
            enable_stream_inference: Enable streaming inference with KV cache
            enable_3d_rope: Enable 3D RoPE for temporal consistency
            max_frame_num: Maximum number of frames for 3D RoPE
            use_scale_token: Kept for checkpoint compatibility, ignored
            kv_cache_sliding_window: Sliding window size for KV cache eviction
            kv_cache_scale_frames: Number of scale frames to keep in KV cache
            kv_cache_cross_frame_special: Keep special tokens from evicted frames
            kv_cache_include_scale_frames: Include scale frames in KV cache
            kv_cache_camera_only: Only keep camera tokens from evicted frames
        """
        # Store stream-specific parameters before calling super().__init__()
        self.pretrained_path = pretrained_path
        self.sliding_window_size = sliding_window_size
        self.num_frame_for_scale = num_frame_for_scale
        self.num_random_frames = num_random_frames
        self.attend_to_special_tokens = attend_to_special_tokens
        self.attend_to_scale_frames = attend_to_scale_frames
        self.enable_stream_inference = enable_stream_inference
        self.enable_3d_rope = enable_3d_rope
        self.max_frame_num = max_frame_num
        # Camera head 3D RoPE settings
        self.enable_camera_3d_rope = enable_camera_3d_rope
        self.camera_rope_theta = camera_rope_theta
        # KV cache parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
        self.use_sdpa = use_sdpa
        self.camera_num_iterations = camera_num_iterations

        # Call base class __init__ (will call _build_aggregator)
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            patch_embed=patch_embed,
            disable_global_rope=disable_global_rope,
            enable_camera=enable_camera,
            enable_point=enable_point,
            enable_local_point=enable_local_point,
            enable_depth=enable_depth,
            enable_track=enable_track,
            enable_normalize=enable_normalize,
            pred_normalization=pred_normalization,
            enable_3d_rope=enable_3d_rope,
            use_gradient_checkpoint=use_gradient_checkpoint,
        )

    def _build_aggregator(self) -> nn.Module:
        """
        Build streaming aggregator with KV cache support (FlashInfer backend).

        Returns:
            AggregatorStream module
        """
        return AggregatorStream(
            img_size=self.img_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            patch_embed=self.patch_embed,
            pretrained_path=self.pretrained_path,
            disable_global_rope=self.disable_global_rope,
            sliding_window_size=self.sliding_window_size,
            num_frame_for_scale=self.num_frame_for_scale,
            num_random_frames=self.num_random_frames,
            attend_to_special_tokens=self.attend_to_special_tokens,
            attend_to_scale_frames=self.attend_to_scale_frames,
            enable_stream_inference=self.enable_stream_inference,
            enable_3d_rope=self.enable_3d_rope,
            max_frame_num=self.max_frame_num,
            # Backend: FlashInfer (default) or SDPA (fallback)
            use_flashinfer=not self.use_sdpa,
            use_sdpa=self.use_sdpa,
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            use_gradient_checkpoint=self.use_gradient_checkpoint,
        )

    def _build_camera_head(self) -> nn.Module:
        """
        Build causal camera head for streaming inference.

        Returns:
            CameraCausalHead module or None
        """
        return CameraCausalHead(
            dim_in=2 * self.embed_dim,
            sliding_window_size=self.sliding_window_size,
            attend_to_scale_frames=self.attend_to_scale_frames,
            num_iterations=self.camera_num_iterations,
            # KV cache parameters
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            # Camera head 3D RoPE parameters
            enable_3d_rope=self.enable_camera_3d_rope,
            max_frame_num=self.max_frame_num,
            rope_theta=self.camera_rope_theta,
        )

    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        **kwargs,
    ) -> tuple:
        """
        Run aggregator to get multi-scale features.

        Args:
            images: Input images [B, S, 3, H, W]
            num_frame_for_scale: Number of frames for scale estimation
            sliding_window_size: Override sliding window size
            num_frame_per_block: Number of frames per block

        Returns:
            (aggregated_tokens_list, patch_start_idx)
        """
        aggregated_tokens_list, patch_start_idx = self.aggregator(
            images,
            selected_idx=[4, 11, 17, 23],
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )
        return aggregated_tokens_list, patch_start_idx

    def clean_kv_cache(self):
        """
        Clean KV cache in aggregator.

        Call this method when starting a new video sequence to clear
        cached key-value pairs from previous sequences.
        """
        if hasattr(self.aggregator, 'clean_kv_cache'):
            self.aggregator.clean_kv_cache()
        else:
            logger.warning("Aggregator does not support KV cache cleaning")
        if hasattr(self.camera_head, 'kv_cache'):
            self.camera_head.clean_kv_cache()
        else:
            logger.warning("Camera head does not support KV cache cleaning")

    def _set_skip_append(self, skip: bool):
        """Set _skip_append flag on all KV caches (aggregator + camera head).

        When skip=True, attention layers will attend to [cached_kv + current_kv]
        but will NOT store the current frame's KV in cache. This is used for
        non-keyframe processing in keyframe-based streaming inference.

        Args:
            skip: If True, subsequent forward passes will not append KV to cache.
        """
        if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None:
            self.aggregator.kv_cache["_skip_append"] = skip
        if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
            for cache_dict in self.camera_head.kv_cache:
                cache_dict["_skip_append"] = skip

    def get_kv_cache_info(self) -> Dict[str, Any]:
        """
        Get information about current KV cache state.

        Returns:
            Dictionary with cache statistics:
            - num_cached_blocks: Number of blocks with cached KV
            - cache_memory_mb: Approximate memory usage in MB
        """
        if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None:
            return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}

        kv_cache = self.aggregator.kv_cache
        num_cached = sum(1 for k in kv_cache.keys() if k.startswith('k_') and not k.endswith('_special'))

        # Estimate memory usage
        total_elements = 0
        for _, v in kv_cache.items():
            if v is not None and torch.is_tensor(v):
                total_elements += v.numel()

        # Assume bfloat16 (2 bytes per element)
        cache_memory_mb = (total_elements * 2) / (1024 * 1024)

        return {
            "num_cached_blocks": num_cached,
            "cache_memory_mb": round(cache_memory_mb, 2)
        }

    @torch.no_grad()
    def inference_streaming(
        self,
        images: torch.Tensor,
        num_scale_frames: Optional[int] = None,
        keyframe_interval: int = 1,
        output_device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Streaming inference: process scale frames first, then frame-by-frame.

        This method enables efficient online inference by:
        1. Processing initial scale frames together (bidirectional attention via scale token)
        2. Processing remaining frames one-by-one with KV cache (causal streaming)

        Keyframe mode (keyframe_interval > 1):
        - Every keyframe_interval-th frame (after scale frames) is a keyframe
        - Keyframes: KV is stored in cache (normal behavior)
        - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard)
        - All frames produce full predictions regardless of keyframe status
        - Reduces KV cache memory growth by ~1/keyframe_interval

        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
            num_scale_frames: Number of initial frames for scale estimation.
                If None, uses self.num_frame_for_scale.
            keyframe_interval: Every N-th frame (after scale frames) is a keyframe
                whose KV persists in cache. 1 = every frame is a
                keyframe (default, same as original behavior).
            output_device: Device to store output predictions on. If None, keeps on
                the same device as the model. Set to torch.device('cpu')
                to offload predictions per-frame and avoid GPU OOM on
                long sequences.

        Returns:
            Dictionary containing predictions for all frames:
            - pose_enc: [B, S, 9]
            - depth: [B, S, H, W, 1]
            - depth_conf: [B, S, H, W]
            - world_points: [B, S, H, W, 3]
            - world_points_conf: [B, S, H, W]
        """
        # Normalize input shape
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        B, S, C, H, W = images.shape

        # Determine number of scale frames
        scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale
        scale_frames = min(scale_frames, S)  # Cap to available frames

        # Helper to move tensor to output device
        def _to_out(t: torch.Tensor) -> torch.Tensor:
            if output_device is not None:
                return t.to(output_device)
            return t

        # Clean KV caches before starting new sequence
        self.clean_kv_cache()

        # Phase 1: Process scale frames together
        # These frames get bidirectional attention among themselves via scale token
        logger.info(f'Processing {scale_frames} scale frames...')
        scale_images = images[:, :scale_frames]
        scale_output = self.forward(
            scale_images,
            num_frame_for_scale=scale_frames,
            num_frame_per_block=scale_frames,  # Process all scale frames as one block
            causal_inference=True,
        )

        # Initialize output lists with scale frame predictions (offload if needed)
        all_pose_enc = [_to_out(scale_output["pose_enc"])]
        all_depth = [_to_out(scale_output["depth"])] if "depth" in scale_output else []
        all_depth_conf = [_to_out(scale_output["depth_conf"])] if "depth_conf" in scale_output else []
        all_world_points = [_to_out(scale_output["world_points"])] if "world_points" in scale_output else []
        all_world_points_conf = [_to_out(scale_output["world_points_conf"])] if "world_points_conf" in scale_output else []
        del scale_output

        # Phase 2: Process remaining frames one-by-one
        pbar = tqdm(
            range(scale_frames, S),
            desc='Streaming inference',
            initial=scale_frames,
            total=S,
        )
        for i in pbar:
            frame_image = images[:, i:i+1]

            # Determine if this frame is a keyframe
            is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0)

            if not is_keyframe:
                self._set_skip_append(True)

            frame_output = self.forward(
                frame_image,
                num_frame_for_scale=scale_frames,  # Keep same for scale token logic
                num_frame_per_block=1,  # Single frame per block
                causal_inference=True,
            )

            if not is_keyframe:
                self._set_skip_append(False)

            all_pose_enc.append(_to_out(frame_output["pose_enc"]))
            if "depth" in frame_output:
                all_depth.append(_to_out(frame_output["depth"]))
            if "depth_conf" in frame_output:
                all_depth_conf.append(_to_out(frame_output["depth_conf"]))
            if "world_points" in frame_output:
                all_world_points.append(_to_out(frame_output["world_points"]))
            if "world_points_conf" in frame_output:
                all_world_points_conf.append(_to_out(frame_output["world_points_conf"]))
            del frame_output

        # Free GPU memory before concatenation
        if output_device is not None:
            # Move images to output device, then free GPU copy
            images_out = _to_out(images)
            del images
            # Clean KV cache (no longer needed after inference)
            self.clean_kv_cache()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            images_out = images

        # Concatenate all predictions along sequence dimension
        predictions = {
            "pose_enc": torch.cat(all_pose_enc, dim=1),
        }
        del all_pose_enc
        if all_depth:
            predictions["depth"] = torch.cat(all_depth, dim=1)
            del all_depth
        if all_depth_conf:
            predictions["depth_conf"] = torch.cat(all_depth_conf, dim=1)
            del all_depth_conf
        if all_world_points:
            predictions["world_points"] = torch.cat(all_world_points, dim=1)
            del all_world_points
        if all_world_points_conf:
            predictions["world_points_conf"] = torch.cat(all_world_points_conf, dim=1)
            del all_world_points_conf

        # Store images for visualization
        predictions["images"] = images_out

        # Apply prediction normalization if enabled
        if self.pred_normalization:
            predictions = self._normalize_predictions(predictions)

        return predictions
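A usage sketch for the streaming entry point above. The constructor arguments are assumptions (use_sdpa=True is shown only to avoid the FlashInfer dependency, and real weights would come from a checkpoint); the inference_streaming arguments and returned keys follow its docstring.

import torch

model = GCTStream(use_sdpa=True).eval().cuda()        # assumed construction; app.py loads the released checkpoint
video = torch.rand(300, 3, 518, 518, device='cuda')   # [S, 3, H, W], values in [0, 1]

preds = model.inference_streaming(
    video,
    num_scale_frames=8,                  # first 8 frames processed jointly to anchor the metric scale
    keyframe_interval=3,                 # only every 3rd streamed frame keeps its KV in the cache
    output_device=torch.device('cpu'),   # offload per-frame outputs to avoid GPU OOM on long clips
)

print(preds["pose_enc"].shape)   # [1, 300, 9]
print(preds["depth"].shape)      # [1, 300, H, W, 1]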
lingbot_map/models/gct_stream_window.py
ADDED
|
@@ -0,0 +1,1206 @@
"""
GCTStream - Streaming GCT with KV cache for online inference.

Provides streaming inference functionality:
- Temporal causal attention with KV cache
- Sliding window support
- Efficient frame-by-frame processing
- 3D RoPE support for temporal consistency
"""

import logging
import torch
import torch.nn as nn
from typing import Optional, Dict, Any, List
from tqdm.auto import tqdm

from lingbot_map.utils.rotation import quat_to_mat, mat_to_quat

from lingbot_map.heads.camera_head import CameraCausalHead
from lingbot_map.models.gct_base import GCTBase
from lingbot_map.aggregator.stream import AggregatorStream
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
from lingbot_map.utils.geometry import closed_form_inverse_se3

logger = logging.getLogger(__name__)


@torch.no_grad()
def _compute_flow_magnitude(
    cur_pose_enc: torch.Tensor,
    kf_pose_enc: torch.Tensor,
    cur_depth: torch.Tensor,
    image_size_hw: tuple,
    stride: int = 8,
) -> float:
    """Compute mean optical flow magnitude induced by camera motion.

    Projects current frame pixels into the last keyframe camera using the
    current depth map and both frames' poses, then returns the average
    pixel displacement (L2 norm of flow) over valid pixels.

    Args:
        cur_pose_enc: Current frame pose encoding [B, 1, 9].
        kf_pose_enc: Last keyframe pose encoding [B, 1, 9].
        cur_depth: Current frame depth map [B, 1, H, W, 1].
        image_size_hw: (H, W) of the depth map.
        stride: Subsampling stride for efficiency.

    Returns:
        Mean flow magnitude in pixels (scalar float).
    """
    H, W = image_size_hw
    device = cur_pose_enc.device
    dtype = cur_depth.dtype

    cur_ext, cur_intr = pose_encoding_to_extri_intri(
        cur_pose_enc, image_size_hw=image_size_hw
    )
    kf_ext, kf_intr = pose_encoding_to_extri_intri(
        kf_pose_enc, image_size_hw=image_size_hw
    )
    B = cur_ext.shape[0]

    cur_ext = cur_ext[:, 0]
    cur_intr = cur_intr[:, 0]
    kf_ext = kf_ext[:, 0]
    kf_intr = kf_intr[:, 0]

    depth = cur_depth[:, 0, ::stride, ::stride, 0].to(dtype)
    Hs, Ws = depth.shape[1], depth.shape[2]

    v_coords = torch.arange(0, H, stride, device=device, dtype=dtype)
    u_coords = torch.arange(0, W, stride, device=device, dtype=dtype)
    v_grid, u_grid = torch.meshgrid(v_coords, u_coords, indexing='ij')
    ones = torch.ones_like(u_grid)
    pixel_coords = torch.stack([u_grid, v_grid, ones], dim=-1)

    intr_inv = torch.inverse(cur_intr)
    cam_coords = torch.einsum('bij,hwj->bhwi', intr_inv, pixel_coords)
    cam_pts = cam_coords * depth.unsqueeze(-1)

    c2w = torch.zeros(B, 4, 4, device=device, dtype=dtype)
    c2w[:, :3, :] = cur_ext
    c2w[:, 3, 3] = 1.0

    ones_hw = torch.ones(B, Hs, Ws, 1, device=device, dtype=dtype)
    cam_pts_h = torch.cat([cam_pts, ones_hw], dim=-1)
    world_pts = torch.einsum('bij,bhwj->bhwi', c2w, cam_pts_h)[..., :3]

    kf_c2w = torch.zeros(B, 4, 4, device=device, dtype=dtype)
    kf_c2w[:, :3, :] = kf_ext
    kf_c2w[:, 3, 3] = 1.0
    kf_w2c = closed_form_inverse_se3(kf_c2w)
    world_pts_h = torch.cat([world_pts, ones_hw], dim=-1)
    kf_cam_pts = torch.einsum('bij,bhwj->bhwi', kf_w2c, world_pts_h)[..., :3]

    z = kf_cam_pts[..., 2:3].clamp(min=1e-6)
    kf_cam_norm = kf_cam_pts / z
    kf_pixels = torch.einsum('bij,bhwj->bhwi', kf_intr, kf_cam_norm)[..., :2]

    orig_pixels = torch.stack([u_grid, v_grid], dim=-1).unsqueeze(0).expand(B, -1, -1, -1)

    flow = kf_pixels - orig_pixels
    valid = (depth > 1e-6) & (kf_cam_pts[..., 2] > 1e-6)

    flow_mag = flow.norm(dim=-1)
    valid_count = valid.float().sum()
    if valid_count < 1:
        return 0.0

    mean_mag = (flow_mag * valid.float()).sum() / valid_count
    return mean_mag.item()


class GCTStream(GCTBase):
    """
    Streaming GCT model with KV cache for efficient online inference.

    Features:
    - AggregatorStream with KV cache support (FlashInfer backend)
    - CameraCausalHead for pose refinement
    - Sliding window attention for memory efficiency
    - Frame-by-frame streaming inference
    """

    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        pretrained_path: str = '',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        # Stream-specific parameters
        sliding_window_size: int = -1,
        num_frame_for_scale: int = 1,
        num_random_frames: int = 0,
        attend_to_special_tokens: bool = False,
        attend_to_scale_frames: bool = False,
        enable_stream_inference: bool = True,  # Default to True for streaming
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        # Camera head 3D RoPE (separate from aggregator 3D RoPE)
        enable_camera_3d_rope: bool = False,
        camera_rope_theta: float = 10000.0,
        # Scale token configuration (kept for checkpoint compat, ignored)
        use_scale_token: bool = True,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # Backend selection
        use_sdpa: bool = False,  # If True, use SDPA (no flashinfer needed); default: FlashInfer
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
        # Camera head iterative refinement (lower = faster inference; default 4)
        camera_num_iterations: int = 4,
    ):
        """
        Initialize GCTStream.

        Args:
            img_size: Input image size
            patch_size: Patch size for embedding
            embed_dim: Embedding dimension
            patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.)
            pretrained_path: Path to pretrained DINOv2 weights
            disable_global_rope: Disable RoPE in global attention
            enable_camera/point/depth/track: Enable prediction heads
            enable_normalize: Enable normalization
            sliding_window_size: Sliding window size in blocks (-1 for full causal)
            num_frame_for_scale: Number of scale estimation frames
            num_random_frames: Number of random frames for long-range dependencies
            attend_to_special_tokens: Enable cross-frame special token attention
            attend_to_scale_frames: Whether to attend to scale frames
            enable_stream_inference: Enable streaming inference with KV cache
            enable_3d_rope: Enable 3D RoPE for temporal consistency
            max_frame_num: Maximum number of frames for 3D RoPE
            use_scale_token: Kept for checkpoint compatibility, ignored
            kv_cache_sliding_window: Sliding window size for KV cache eviction
            kv_cache_scale_frames: Number of scale frames to keep in KV cache
            kv_cache_cross_frame_special: Keep special tokens from evicted frames
            kv_cache_include_scale_frames: Include scale frames in KV cache
            kv_cache_camera_only: Only keep camera tokens from evicted frames
        """
        # Store stream-specific parameters before calling super().__init__()
        self.pretrained_path = pretrained_path
        self.sliding_window_size = sliding_window_size
        self.num_frame_for_scale = num_frame_for_scale
        self.num_random_frames = num_random_frames
        self.attend_to_special_tokens = attend_to_special_tokens
        self.attend_to_scale_frames = attend_to_scale_frames
        self.enable_stream_inference = enable_stream_inference
        self.enable_3d_rope = enable_3d_rope
        self.max_frame_num = max_frame_num
        # Camera head 3D RoPE settings
        self.enable_camera_3d_rope = enable_camera_3d_rope
        self.camera_rope_theta = camera_rope_theta
        # KV cache parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
        self.use_sdpa = use_sdpa
        self.camera_num_iterations = camera_num_iterations

        # Call base class __init__ (will call _build_aggregator)
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            patch_embed=patch_embed,
            disable_global_rope=disable_global_rope,
            enable_camera=enable_camera,
            enable_point=enable_point,
            enable_local_point=enable_local_point,
            enable_depth=enable_depth,
            enable_track=enable_track,
            enable_normalize=enable_normalize,
            pred_normalization=pred_normalization,
            enable_3d_rope=enable_3d_rope,
            use_gradient_checkpoint=use_gradient_checkpoint,
        )

    def _build_aggregator(self) -> nn.Module:
        """
        Build streaming aggregator with KV cache support (FlashInfer backend).

        Returns:
            AggregatorStream module
        """
        return AggregatorStream(
            img_size=self.img_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            patch_embed=self.patch_embed,
            pretrained_path=self.pretrained_path,
            disable_global_rope=self.disable_global_rope,
            sliding_window_size=self.sliding_window_size,
            num_frame_for_scale=self.num_frame_for_scale,
            num_random_frames=self.num_random_frames,
            attend_to_special_tokens=self.attend_to_special_tokens,
            attend_to_scale_frames=self.attend_to_scale_frames,
            enable_stream_inference=self.enable_stream_inference,
            enable_3d_rope=self.enable_3d_rope,
            max_frame_num=self.max_frame_num,
            # Backend: FlashInfer (default) or SDPA (fallback)
            use_flashinfer=not self.use_sdpa,
            use_sdpa=self.use_sdpa,
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            use_gradient_checkpoint=self.use_gradient_checkpoint,
        )

    def _build_camera_head(self) -> nn.Module:
        """
        Build causal camera head for streaming inference.

        Returns:
            CameraCausalHead module or None
        """
        return CameraCausalHead(
            dim_in=2 * self.embed_dim,
            sliding_window_size=self.sliding_window_size,
            attend_to_scale_frames=self.attend_to_scale_frames,
            num_iterations=self.camera_num_iterations,
            # KV cache parameters
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            # Camera head 3D RoPE parameters
            enable_3d_rope=self.enable_camera_3d_rope,
            max_frame_num=self.max_frame_num,
            rope_theta=self.camera_rope_theta,
        )

    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        **kwargs,
    ) -> tuple:
        """
        Run aggregator to get multi-scale features.

        Args:
            images: Input images [B, S, 3, H, W]
            num_frame_for_scale: Number of frames for scale estimation
            sliding_window_size: Override sliding window size
            num_frame_per_block: Number of frames per block

        Returns:
            (aggregated_tokens_list, patch_start_idx)
        """
        aggregated_tokens_list, patch_start_idx = self.aggregator(
            images,
            selected_idx=[4, 11, 17, 23],
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )
        return aggregated_tokens_list, patch_start_idx

    def clean_kv_cache(self):
        """
        Clean KV cache in aggregator.

        Call this method when starting a new video sequence to clear
        cached key-value pairs from previous sequences.
        """
        if hasattr(self.aggregator, 'clean_kv_cache'):
            self.aggregator.clean_kv_cache()
        else:
            logger.warning("Aggregator does not support KV cache cleaning")
        if hasattr(self.camera_head, 'kv_cache'):
            self.camera_head.clean_kv_cache()
        else:
            logger.warning("Camera head does not support KV cache cleaning")

    def _set_skip_append(self, skip: bool):
        """Set _skip_append flag on all KV caches (aggregator + camera head).

        When skip=True, attention layers will attend to [cached_kv + current_kv]
        but will NOT store the current frame's KV in cache. This is used for
        non-keyframe processing in keyframe-based streaming inference.

        Args:
            skip: If True, subsequent forward passes will not append KV to cache.
        """
        if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None:
            self.aggregator.kv_cache["_skip_append"] = skip
        # FlashInfer manager
        if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
            self.aggregator.kv_cache_manager._skip_append = skip
        if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
            for cache_dict in self.camera_head.kv_cache:
                cache_dict["_skip_append"] = skip

    # ── Flow-based keyframe helpers ────────────────────────────────────────

    def _set_defer_eviction(self, defer: bool):
        """Set defer-eviction flag on FlashInfer manager and SDPA caches.

        While True, eviction is suppressed so that rollback can cleanly undo
        the most recent append without having to restore evicted frames.
        """
        # FlashInfer manager
        if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
            self.aggregator.kv_cache_manager._defer_eviction = defer
        # SDPA aggregator cache (dict)
        if hasattr(self.aggregator, 'kv_cache') and isinstance(self.aggregator.kv_cache, dict):
            self.aggregator.kv_cache["_defer_eviction"] = defer
        # Camera head SDPA caches
        if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
            for cache_dict in self.camera_head.kv_cache:
                cache_dict["_defer_eviction"] = defer

    def _rollback_last_frame(self):
        """Rollback the most recent frame from all caches.

        Undoes append_frame on FlashInfer manager (all blocks), trims the
        camera head SDPA cache, and decrements the aggregator frame counter.
        Must be called while eviction is still deferred.
        """
        # FlashInfer manager — rollback each transformer block
        if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
            mgr = self.aggregator.kv_cache_manager
            for block_idx in range(mgr.num_blocks):
                mgr.rollback_last_frame(block_idx)

        # SDPA aggregator cache — trim last frame along dim=2
        if hasattr(self.aggregator, 'kv_cache') and isinstance(self.aggregator.kv_cache, dict):
            kv = self.aggregator.kv_cache
            for key in list(kv.keys()):
                if key.startswith(("k_", "v_")) and kv[key] is not None and torch.is_tensor(kv[key]):
                    if kv[key].dim() >= 3 and kv[key].shape[2] > 1:
                        kv[key] = kv[key][:, :, :-1]
                    elif kv[key].dim() >= 3:
                        kv[key] = None

        # Camera head
        if self.camera_head is not None and hasattr(self.camera_head, 'rollback_last_frame'):
            self.camera_head.rollback_last_frame()

        # Aggregator frame counter (used for 3D RoPE temporal positions)
        self.aggregator.total_frames_processed -= 1

    def _execute_deferred_eviction(self):
        """Execute the eviction that was deferred during the last forward pass."""
        # FlashInfer manager
        if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
            mgr = self.aggregator.kv_cache_manager
            for block_idx in range(mgr.num_blocks):
                mgr.execute_deferred_eviction(
                    block_idx,
                    scale_frames=self.kv_cache_scale_frames,
                    sliding_window=self.kv_cache_sliding_window,
                )

    def get_kv_cache_info(self) -> Dict[str, Any]:
        """
        Get information about current KV cache state.

        Returns:
            Dictionary with cache statistics:
            - num_cached_blocks: Number of blocks with cached KV
            - cache_memory_mb: Approximate memory usage in MB
        """
        if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None:
            return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}

        kv_cache = self.aggregator.kv_cache
        num_cached = sum(1 for k in kv_cache.keys() if k.startswith('k_') and not k.endswith('_special'))

        # Estimate memory usage
        total_elements = 0
        for _, v in kv_cache.items():
            if v is not None and torch.is_tensor(v):
                total_elements += v.numel()

        # Assume bfloat16 (2 bytes per element)
        cache_memory_mb = (total_elements * 2) / (1024 * 1024)

        return {
            "num_cached_blocks": num_cached,
            "cache_memory_mb": round(cache_memory_mb, 2)
        }

    @torch.no_grad()
    def inference_streaming(
        self,
        images: torch.Tensor,
        num_scale_frames: Optional[int] = None,
        keyframe_interval: int = 1,
        output_device: Optional[torch.device] = None,
        flow_threshold: float = 0.0,
        max_non_keyframe_gap: int = 30,
    ) -> Dict[str, torch.Tensor]:
        """
        Streaming inference: process scale frames first, then frame-by-frame.

        This method enables efficient online inference by:
        1. Processing initial scale frames together (bidirectional attention via scale token)
        2. Processing remaining frames one-by-one with KV cache (causal streaming)

        Keyframe mode (keyframe_interval > 1):
        - Every keyframe_interval-th frame (after scale frames) is a keyframe
        - Keyframes: KV is stored in cache (normal behavior)
        - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard)
        - All frames produce full predictions regardless of keyframe status
        - Reduces KV cache memory growth by ~1/keyframe_interval

        Flow-based keyframe mode (flow_threshold > 0):
        - Takes precedence over keyframe_interval
        - Computes optical flow magnitude between current frame and last keyframe
        - Frame becomes keyframe if flow exceeds threshold or gap exceeds max_non_keyframe_gap
        - Uses defer-eviction + rollback for non-keyframes

        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
            num_scale_frames: Number of initial frames for scale estimation.
                If None, uses self.num_frame_for_scale.
            keyframe_interval: Every N-th frame (after scale frames) is a keyframe
                whose KV persists in cache. 1 = every frame is a
                keyframe (default, same as original behavior).
            output_device: Device to store output predictions on. If None, keeps on
                the same device as the model. Set to torch.device('cpu')
                to offload predictions per-frame and avoid GPU OOM on
                long sequences.
            flow_threshold: Mean flow magnitude threshold (pixels) for flow-based
                keyframe selection. >0 enables flow-based mode (takes precedence
                over keyframe_interval).
            max_non_keyframe_gap: Max consecutive non-keyframe frames before
                forcing a keyframe (flow mode only).

        Returns:
            Dictionary containing predictions for all frames:
            - pose_enc: [B, S, 9]
            - depth: [B, S, H, W, 1]
            - depth_conf: [B, S, H, W]
            - world_points: [B, S, H, W, 3]
            - world_points_conf: [B, S, H, W]
        """
        # Normalize input shape
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        B, S, C, H, W = images.shape

        # Determine number of scale frames
        scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale
        scale_frames = min(scale_frames, S)  # Cap to available frames

        # Helper to move tensor to output device
        def _to_out(t: torch.Tensor) -> torch.Tensor:
            if output_device is not None:
                return t.to(output_device)
            return t

        # Clean KV caches before starting new sequence
        self.clean_kv_cache()

        # Phase 1: Process scale frames together
        # These frames get bidirectional attention among themselves via scale token
        logger.info(f'Processing {scale_frames} scale frames...')
        scale_images = images[:, :scale_frames]
        scale_output = self.forward(
            scale_images,
            num_frame_for_scale=scale_frames,
            num_frame_per_block=scale_frames,  # Process all scale frames as one block
            causal_inference=True,
        )

        # Initialize output lists with scale frame predictions (offload if needed)
        all_pose_enc = [_to_out(scale_output["pose_enc"])]
        all_depth = [_to_out(scale_output["depth"])] if "depth" in scale_output else []
        all_depth_conf = [_to_out(scale_output["depth_conf"])] if "depth_conf" in scale_output else []
        all_world_points = [_to_out(scale_output["world_points"])] if "world_points" in scale_output else []
        all_world_points_conf = [_to_out(scale_output["world_points_conf"])] if "world_points_conf" in scale_output else []
        del scale_output

        # Phase 2: Process remaining frames one-by-one
        use_flow_keyframe = flow_threshold > 0.0

        # Flow state: last keyframe = last scale frame
        if use_flow_keyframe:
            last_kf_pose_enc = all_pose_enc[0][:, -1:]  # last scale frame
            last_kf_idx = scale_frames - 1

        pbar = tqdm(
            range(scale_frames, S),
            desc='Streaming inference',
            initial=scale_frames,
            total=S,
        )
        for i in pbar:
            frame_image = images[:, i:i+1]

            if use_flow_keyframe:
                # Flow-based: defer eviction, forward, then decide
                self._set_defer_eviction(True)

                frame_output = self.forward(
                    frame_image,
                    num_frame_for_scale=scale_frames,
                    num_frame_per_block=1,
                    causal_inference=True,
                )

                self._set_defer_eviction(False)

                # Compute flow to decide keyframe
                cur_depth = frame_output.get("depth", None)
                if cur_depth is not None:
                    H_pred, W_pred = cur_depth.shape[2], cur_depth.shape[3]
                    flow_mag = _compute_flow_magnitude(
                        frame_output["pose_enc"], last_kf_pose_enc,
                        cur_depth, (H_pred, W_pred),
                    )
                else:
                    flow_mag = flow_threshold + 1.0

                frames_since_kf = i - last_kf_idx
                is_keyframe = (
                    (i == scale_frames)  # first streaming frame
                    or (flow_mag > flow_threshold)
                    or (frames_since_kf >= max_non_keyframe_gap)
                )

                if is_keyframe:
                    self._execute_deferred_eviction()
                    last_kf_pose_enc = frame_output["pose_enc"]
                    last_kf_idx = i
                else:
                    self._rollback_last_frame()
            else:
                # Fixed-interval keyframe mode
                is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0)

                if not is_keyframe:
                    self._set_skip_append(True)

                frame_output = self.forward(
                    frame_image,
                    num_frame_for_scale=scale_frames,
                    num_frame_per_block=1,
                    causal_inference=True,
                )

                if not is_keyframe:
                    self._set_skip_append(False)

            all_pose_enc.append(_to_out(frame_output["pose_enc"]))
            if "depth" in frame_output:
                all_depth.append(_to_out(frame_output["depth"]))
            if "depth_conf" in frame_output:
                all_depth_conf.append(_to_out(frame_output["depth_conf"]))
            if "world_points" in frame_output:
                all_world_points.append(_to_out(frame_output["world_points"]))
            if "world_points_conf" in frame_output:
                all_world_points_conf.append(_to_out(frame_output["world_points_conf"]))
            del frame_output

        # Free GPU memory before concatenation
        if output_device is not None:
            # Move images to output device, then free GPU copy
            images_out = _to_out(images)
            del images
            # Clean KV cache (no longer needed after inference)
            self.clean_kv_cache()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            images_out = images

        # Concatenate all predictions along sequence dimension
        predictions = {
            "pose_enc": torch.cat(all_pose_enc, dim=1),
        }
        del all_pose_enc
        if all_depth:
            predictions["depth"] = torch.cat(all_depth, dim=1)
            del all_depth
        if all_depth_conf:
            predictions["depth_conf"] = torch.cat(all_depth_conf, dim=1)
            del all_depth_conf
+
if all_world_points:
|
| 648 |
+
predictions["world_points"] = torch.cat(all_world_points, dim=1)
|
| 649 |
+
del all_world_points
|
| 650 |
+
if all_world_points_conf:
|
| 651 |
+
predictions["world_points_conf"] = torch.cat(all_world_points_conf, dim=1)
|
| 652 |
+
del all_world_points_conf
|
| 653 |
+
|
| 654 |
+
# Store images for visualization
|
| 655 |
+
predictions["images"] = images_out
|
| 656 |
+
|
| 657 |
+
# Apply prediction normalization if enabled
|
| 658 |
+
if self.pred_normalization:
|
| 659 |
+
predictions = self._normalize_predictions(predictions)
|
| 660 |
+
|
| 661 |
+
return predictions
|
| 662 |
+
|
| 663 |
+
# ══════════════════════════════════════════════════════════════════════
|
| 664 |
+
# Window stitching & cross-window alignment
|
| 665 |
+
# ══════════════════════════════════════════════════════════════════════
|
| 666 |
+
|
| 667 |
+
_FRAME_AXIS_KEYS = frozenset({
|
| 668 |
+
"pose_enc", "depth", "depth_conf",
|
| 669 |
+
"world_points", "world_points_conf",
|
| 670 |
+
"frame_type", "is_keyframe",
|
| 671 |
+
})
|
| 672 |
+
|
| 673 |
+
def _stitch_windows(
|
| 674 |
+
self,
|
| 675 |
+
windows: List[Dict],
|
| 676 |
+
window_size: int,
|
| 677 |
+
overlap: int,
|
| 678 |
+
) -> Dict:
|
| 679 |
+
"""Concatenate per-window predictions while de-duplicating overlaps.
|
| 680 |
+
|
| 681 |
+
For each temporal key the method builds a slice table first — every
|
| 682 |
+
window contributes ``[0, effective_end)`` frames where
|
| 683 |
+
``effective_end = total_frames - overlap`` for non-final windows.
|
| 684 |
+
Non-temporal entries simply keep the latest available value.
|
| 685 |
+
"""
|
| 686 |
+
if len(windows) == 0:
|
| 687 |
+
return {}
|
| 688 |
+
if len(windows) == 1:
|
| 689 |
+
return windows[0]
|
| 690 |
+
|
| 691 |
+
n_win = len(windows)
|
| 692 |
+
all_keys = list(windows[0].keys())
|
| 693 |
+
stitched: Dict = {}
|
| 694 |
+
|
| 695 |
+
for key in all_keys:
|
| 696 |
+
values = [w.get(key) for w in windows]
|
| 697 |
+
if all(v is None for v in values):
|
| 698 |
+
continue
|
| 699 |
+
|
| 700 |
+
# Non-temporal entries: take latest
|
| 701 |
+
if key not in self._FRAME_AXIS_KEYS:
|
| 702 |
+
stitched[key] = next(v for v in reversed(values) if v is not None)
|
| 703 |
+
continue
|
| 704 |
+
|
| 705 |
+
# Build slice table: (start, end) for each window's contribution
|
| 706 |
+
slices = []
|
| 707 |
+
for wi, tensor in enumerate(values):
|
| 708 |
+
if tensor is None:
|
| 709 |
+
slices.append(None)
|
| 710 |
+
continue
|
| 711 |
+
total = tensor.shape[1]
|
| 712 |
+
is_last = (wi == n_win - 1)
|
| 713 |
+
end = total if is_last else max(total - overlap, 0)
|
| 714 |
+
slices.append((0, end) if end > 0 else None)
|
| 715 |
+
|
| 716 |
+
parts = [
|
| 717 |
+
values[i][:, s:e]
|
| 718 |
+
for i, s_e in enumerate(slices)
|
| 719 |
+
if s_e is not None
|
| 720 |
+
for s, e in [s_e]
|
| 721 |
+
]
|
| 722 |
+
if parts:
|
| 723 |
+
stitched[key] = torch.cat(parts, dim=1)
|
| 724 |
+
else:
|
| 725 |
+
fallback = next((v for v in reversed(values) if v is not None), None)
|
| 726 |
+
if fallback is not None:
|
| 727 |
+
stitched[key] = fallback
|
| 728 |
+
|
| 729 |
+
return stitched
|
| 730 |
+
|
| 731 |
+
@staticmethod
|
| 732 |
+
def _depth_ratio_scale(
|
| 733 |
+
anchor_depth: torch.Tensor,
|
| 734 |
+
target_depth: torch.Tensor,
|
| 735 |
+
batch_size: int,
|
| 736 |
+
device: torch.device,
|
| 737 |
+
) -> torch.Tensor:
|
| 738 |
+
"""Estimate per-batch scale as the median depth ratio anchor/target."""
|
| 739 |
+
a = anchor_depth.to(torch.float32).reshape(batch_size, -1)
|
| 740 |
+
t = target_depth.to(torch.float32).reshape(batch_size, -1)
|
| 741 |
+
ok = torch.isfinite(a) & torch.isfinite(t) & (t.abs() > torch.finfo(torch.float32).eps)
|
| 742 |
+
|
| 743 |
+
scales = []
|
| 744 |
+
for b in range(batch_size):
|
| 745 |
+
m = ok[b]
|
| 746 |
+
if m.any():
|
| 747 |
+
scales.append((a[b, m] / t[b, m]).median())
|
| 748 |
+
else:
|
| 749 |
+
scales.append(torch.tensor(1.0, device=device, dtype=torch.float32))
|
| 750 |
+
return torch.stack(scales).clamp(min=1e-3, max=1e3)
|
| 751 |
+
|
| 752 |
+
@staticmethod
|
| 753 |
+
def _pairwise_alignment(
|
| 754 |
+
prev_pred: Dict,
|
| 755 |
+
curr_pred: Dict,
|
| 756 |
+
overlap: int,
|
| 757 |
+
batch_size: int,
|
| 758 |
+
device: torch.device,
|
| 759 |
+
dtype: torch.dtype,
|
| 760 |
+
):
|
| 761 |
+
"""Compute (scale, R, t) that maps *curr* into *prev*'s coordinate frame.
|
| 762 |
+
|
| 763 |
+
Uses the first overlap frame of *curr* and the corresponding trailing
|
| 764 |
+
frame of *prev* to establish the similarity transform.
|
| 765 |
+
"""
|
| 766 |
+
unit_s = torch.ones(batch_size, device=device, dtype=dtype)
|
| 767 |
+
eye_R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1).clone()
|
| 768 |
+
zero_t = torch.zeros(batch_size, 3, device=device, dtype=dtype)
|
| 769 |
+
|
| 770 |
+
if overlap <= 0:
|
| 771 |
+
return unit_s, eye_R, zero_t
|
| 772 |
+
|
| 773 |
+
pe_prev = prev_pred.get("pose_enc")
|
| 774 |
+
pe_curr = curr_pred.get("pose_enc")
|
| 775 |
+
if pe_prev is None or pe_curr is None:
|
| 776 |
+
return unit_s, eye_R, zero_t
|
| 777 |
+
|
| 778 |
+
idx_a = max(pe_prev.shape[1] - overlap, 0)
|
| 779 |
+
|
| 780 |
+
# Decompose C2W: center ([:3]) + quaternion ([3:7])
|
| 781 |
+
Ra = quat_to_mat(pe_prev[:, idx_a, 3:7]) # (B, 3, 3)
|
| 782 |
+
ca = pe_prev[:, idx_a, :3] # (B, 3)
|
| 783 |
+
Rb = quat_to_mat(pe_curr[:, 0, 3:7])
|
| 784 |
+
cb = pe_curr[:, 0, :3]
|
| 785 |
+
|
| 786 |
+
R_ab = torch.bmm(Ra, Rb.transpose(1, 2)) # Ra = R_ab @ Rb
|
| 787 |
+
|
| 788 |
+
# Scale from depth
|
| 789 |
+
s_ab = unit_s.clone()
|
| 790 |
+
da = prev_pred.get("depth")
|
| 791 |
+
db = curr_pred.get("depth")
|
| 792 |
+
if (da is not None and db is not None
|
| 793 |
+
and da.shape[1] > idx_a and db.shape[1] > 0):
|
| 794 |
+
s_ab = GCTStream._depth_ratio_scale(
|
| 795 |
+
da[:, idx_a, ..., 0], db[:, 0, ..., 0],
|
| 796 |
+
batch_size, device,
|
| 797 |
+
).to(dtype)
|
| 798 |
+
|
| 799 |
+
# ca = s_ab * R_ab @ cb + t_ab => t_ab = ca - s_ab * R_ab @ cb
|
| 800 |
+
t_ab = ca - s_ab.unsqueeze(-1) * torch.bmm(R_ab, cb.unsqueeze(-1)).squeeze(-1)
|
| 801 |
+
|
| 802 |
+
return s_ab, R_ab.to(dtype), t_ab.to(dtype)
|
| 803 |
+
|
| 804 |
+
@staticmethod
|
| 805 |
+
def _warp_predictions(
|
| 806 |
+
pred: Dict,
|
| 807 |
+
R: torch.Tensor,
|
| 808 |
+
t: torch.Tensor,
|
| 809 |
+
s: torch.Tensor,
|
| 810 |
+
batch_size: int,
|
| 811 |
+
) -> Dict:
|
| 812 |
+
"""Apply a similarity transform (s, R, t) to one window's predictions."""
|
| 813 |
+
warped: Dict = {}
|
| 814 |
+
|
| 815 |
+
# Pose encoding: center + quaternion + intrinsics
|
| 816 |
+
pe = pred.get("pose_enc")
|
| 817 |
+
if pe is not None:
|
| 818 |
+
nf = pe.shape[1]
|
| 819 |
+
local_rot = quat_to_mat(pe[:, :, 3:7])
|
| 820 |
+
local_ctr = pe[:, :, :3]
|
| 821 |
+
|
| 822 |
+
R_exp = R[:, None].expand(-1, nf, -1, -1)
|
| 823 |
+
new_rot = torch.matmul(R_exp, local_rot)
|
| 824 |
+
new_ctr = (
|
| 825 |
+
s.view(batch_size, 1, 1) * torch.matmul(R_exp, local_ctr.unsqueeze(-1)).squeeze(-1)
|
| 826 |
+
+ t.view(batch_size, 1, 3)
|
| 827 |
+
)
|
| 828 |
+
out_pe = pe.clone()
|
| 829 |
+
out_pe[:, :, :3] = new_ctr
|
| 830 |
+
out_pe[:, :, 3:7] = mat_to_quat(new_rot)
|
| 831 |
+
warped["pose_enc"] = out_pe
|
| 832 |
+
else:
|
| 833 |
+
warped["pose_enc"] = None
|
| 834 |
+
|
| 835 |
+
# Depth: scale by s
|
| 836 |
+
d = pred.get("depth")
|
| 837 |
+
if d is not None:
|
| 838 |
+
warped["depth"] = d * s.view(batch_size, 1, 1, 1, 1)
|
| 839 |
+
else:
|
| 840 |
+
warped["depth"] = None
|
| 841 |
+
|
| 842 |
+
# World points: p_global = s * R @ p_local + t
|
| 843 |
+
wp = pred.get("world_points")
|
| 844 |
+
if wp is not None:
|
| 845 |
+
b, nf, h, w, _ = wp.shape
|
| 846 |
+
flat = wp.reshape(b, nf * h * w, 3)
|
| 847 |
+
transformed = torch.bmm(flat, R.transpose(1, 2)) * s.view(b, 1, 1)
|
| 848 |
+
transformed = transformed + t[:, None, :]
|
| 849 |
+
warped["world_points"] = transformed.reshape(b, nf, h, w, 3)
|
| 850 |
+
else:
|
| 851 |
+
warped["world_points"] = None
|
| 852 |
+
|
| 853 |
+
# Pass through all other keys untouched
|
| 854 |
+
for k, v in pred.items():
|
| 855 |
+
if k not in warped:
|
| 856 |
+
warped[k] = v
|
| 857 |
+
|
| 858 |
+
return warped
|
| 859 |
+
|
| 860 |
+
def _align_and_stitch_windows(
|
| 861 |
+
self,
|
| 862 |
+
windows: List[Dict],
|
| 863 |
+
scale_mode: str = 'median',
|
| 864 |
+
) -> Dict:
|
| 865 |
+
"""Bring all windows into the first window's coordinate frame, then stitch.
|
| 866 |
+
|
| 867 |
+
Iterates over consecutive window pairs, estimates the pairwise
|
| 868 |
+
scaled alignment, warps each window, and finally concatenates
|
| 869 |
+
via :meth:`_stitch_windows`.
|
| 870 |
+
"""
|
| 871 |
+
if len(windows) == 0:
|
| 872 |
+
return {}
|
| 873 |
+
if len(windows) == 1:
|
| 874 |
+
out = windows[0].copy()
|
| 875 |
+
out["alignment_mode"] = "scaled"
|
| 876 |
+
return out
|
| 877 |
+
|
| 878 |
+
# Discover batch / device / dtype from any available tensor
|
| 879 |
+
ref = next(
|
| 880 |
+
v
|
| 881 |
+
for w in windows
|
| 882 |
+
for k in ("pose_enc", "world_points", "depth")
|
| 883 |
+
if (v := w.get(k)) is not None
|
| 884 |
+
)
|
| 885 |
+
dev, dt, nb = ref.device, ref.dtype, ref.shape[0]
|
| 886 |
+
|
| 887 |
+
overlap = getattr(self, "_last_overlap_size", 0)
|
| 888 |
+
win_sz = getattr(self, "_last_window_size", -1)
|
| 889 |
+
|
| 890 |
+
warped_windows: List[Dict] = []
|
| 891 |
+
per_window_scales: List[torch.Tensor] = []
|
| 892 |
+
per_window_transforms: List[torch.Tensor] = []
|
| 893 |
+
|
| 894 |
+
for idx, raw in enumerate(windows):
|
| 895 |
+
if idx == 0:
|
| 896 |
+
s_rel = torch.ones(nb, device=dev, dtype=dt)
|
| 897 |
+
R_rel = torch.eye(3, device=dev, dtype=dt).unsqueeze(0).expand(nb, -1, -1).clone()
|
| 898 |
+
t_rel = torch.zeros(nb, 3, device=dev, dtype=dt)
|
| 899 |
+
else:
|
| 900 |
+
s_rel, R_rel, t_rel = self._pairwise_alignment(
|
| 901 |
+
warped_windows[-1], raw, overlap, nb, dev, dt,
|
| 902 |
+
)
|
| 903 |
+
|
| 904 |
+
per_window_scales.append(s_rel.clone())
|
| 905 |
+
T = torch.eye(4, device=dev, dtype=dt).unsqueeze(0).expand(nb, -1, -1).clone()
|
| 906 |
+
T[:, :3, :3] = R_rel
|
| 907 |
+
T[:, :3, 3] = t_rel
|
| 908 |
+
per_window_transforms.append(T)
|
| 909 |
+
|
| 910 |
+
warped_windows.append(
|
| 911 |
+
self._warp_predictions(raw, R_rel, t_rel, s_rel, nb)
|
| 912 |
+
)
|
| 913 |
+
|
| 914 |
+
merged = self._stitch_windows(warped_windows, win_sz, overlap)
|
| 915 |
+
|
| 916 |
+
# Attach alignment metadata
|
| 917 |
+
if per_window_scales:
|
| 918 |
+
merged["chunk_scales"] = torch.stack(per_window_scales, dim=1)
|
| 919 |
+
if per_window_transforms:
|
| 920 |
+
merged["chunk_transforms"] = torch.stack(per_window_transforms, dim=1)
|
| 921 |
+
merged["alignment_mode"] = "scaled"
|
| 922 |
+
return merged
|
| 923 |
+
|
| 924 |
+
@torch.no_grad()
|
| 925 |
+
def inference_windowed(
|
| 926 |
+
self,
|
| 927 |
+
images: torch.Tensor,
|
| 928 |
+
window_size: int = 16,
|
| 929 |
+
overlap_size: Optional[int] = None,
|
| 930 |
+
num_scale_frames: Optional[int] = None,
|
| 931 |
+
scale_mode: str = 'median',
|
| 932 |
+
output_device: Optional[torch.device] = None,
|
| 933 |
+
keyframe_interval: int = 1,
|
| 934 |
+
flow_threshold: float = 0.0,
|
| 935 |
+
max_non_keyframe_gap: int = 30,
|
| 936 |
+
) -> Dict[str, torch.Tensor]:
|
| 937 |
+
"""
|
| 938 |
+
Windowed inference with keyframe detection and cross-window alignment.
|
| 939 |
+
|
| 940 |
+
Each window is processed independently with a fresh KV cache.
|
| 941 |
+
Overlap frames between windows are the next window's scale frames
|
| 942 |
+
(bidirectional attention), ensuring the highest quality predictions
|
| 943 |
+
at alignment boundaries.
|
| 944 |
+
|
| 945 |
+
``window_size`` counts **keyframes** (frames stored in KV cache),
|
| 946 |
+
including scale frames. When ``keyframe_interval > 1``, each window
|
| 947 |
+
covers more actual frames than ``window_size``:
|
| 948 |
+
|
| 949 |
+
actual_frames = scale_frames + (window_size - scale_frames) * keyframe_interval
|
| 950 |
+
|
| 951 |
+
Args:
|
| 952 |
+
images: Input images [S, 3, H, W] or [B, S, 3, H, W] in [0, 1].
|
| 953 |
+
window_size: Number of **keyframes** per window (including scale
|
| 954 |
+
frames). Directly controls KV cache memory.
|
| 955 |
+
overlap_size: Number of overlapping frames between windows.
|
| 956 |
+
Defaults to ``num_scale_frames`` (overlap = scale frames).
|
| 957 |
+
num_scale_frames: Number of frames used as scale reference within
|
| 958 |
+
each window. Defaults to ``self.num_frame_for_scale``.
|
| 959 |
+
scale_mode: Scale estimation strategy for alignment.
|
| 960 |
+
output_device: Device to store per-window outputs.
|
| 961 |
+
keyframe_interval: Every N-th Phase 2 frame is a keyframe whose
|
| 962 |
+
KV persists in cache. 1 = every frame (default).
|
| 963 |
+
flow_threshold: Mean flow magnitude threshold (pixels) for
|
| 964 |
+
flow-based keyframe selection. >0 enables flow-based mode
|
| 965 |
+
(takes precedence over ``keyframe_interval``).
|
| 966 |
+
max_non_keyframe_gap: Max consecutive non-keyframe frames before
|
| 967 |
+
forcing a keyframe (flow mode only).
|
| 968 |
+
|
| 969 |
+
Returns:
|
| 970 |
+
Merged prediction dict with all frames.
|
| 971 |
+
"""
|
| 972 |
+
use_flow_keyframe = flow_threshold > 0.0
|
| 973 |
+
|
| 974 |
+
# Normalize input shape
|
| 975 |
+
if len(images.shape) == 4:
|
| 976 |
+
images = images.unsqueeze(0)
|
| 977 |
+
B, S, C, H, W = images.shape
|
| 978 |
+
|
| 979 |
+
ws = (num_scale_frames if num_scale_frames is not None
|
| 980 |
+
else self.num_frame_for_scale)
|
| 981 |
+
ws = min(ws, S)
|
| 982 |
+
|
| 983 |
+
# overlap = scale_frames by default
|
| 984 |
+
eff_overlap = min(overlap_size if overlap_size is not None else ws,
|
| 985 |
+
S - 1) if S > 1 else 0
|
| 986 |
+
|
| 987 |
+
def _to_out(t: torch.Tensor) -> torch.Tensor:
|
| 988 |
+
return t.to(output_device) if output_device is not None else t
|
| 989 |
+
|
| 990 |
+
def _collect_frame(out, w_lists):
|
| 991 |
+
w_lists['pose_enc'].append(_to_out(out["pose_enc"]))
|
| 992 |
+
if "depth" in out:
|
| 993 |
+
w_lists['depth'].append(_to_out(out["depth"]))
|
| 994 |
+
if "depth_conf" in out:
|
| 995 |
+
w_lists['depth_conf'].append(_to_out(out["depth_conf"]))
|
| 996 |
+
if "world_points" in out:
|
| 997 |
+
w_lists['world_points'].append(_to_out(out["world_points"]))
|
| 998 |
+
if "world_points_conf" in out:
|
| 999 |
+
w_lists['world_pts_conf'].append(_to_out(out["world_points_conf"]))
|
| 1000 |
+
|
| 1001 |
+
def _make_window_pred(w_lists):
|
| 1002 |
+
pred: Dict = {"pose_enc": torch.cat(w_lists['pose_enc'], dim=1)}
|
| 1003 |
+
if w_lists['depth']:
|
| 1004 |
+
pred["depth"] = torch.cat(w_lists['depth'], dim=1)
|
| 1005 |
+
if w_lists['depth_conf']:
|
| 1006 |
+
pred["depth_conf"] = torch.cat(w_lists['depth_conf'], dim=1)
|
| 1007 |
+
if w_lists['world_points']:
|
| 1008 |
+
pred["world_points"] = torch.cat(w_lists['world_points'], dim=1)
|
| 1009 |
+
if w_lists['world_pts_conf']:
|
| 1010 |
+
pred["world_points_conf"] = torch.cat(w_lists['world_pts_conf'], dim=1)
|
| 1011 |
+
# Frame type: 0=scale, 1=keyframe, 2=non-keyframe
|
| 1012 |
+
ft = torch.tensor(w_lists['frame_type'], dtype=torch.uint8).unsqueeze(0) # [1, T]
|
| 1013 |
+
pred["frame_type"] = ft
|
| 1014 |
+
pred["is_keyframe"] = (ft != 2) # scale + keyframe = True
|
| 1015 |
+
return pred
|
| 1016 |
+
|
| 1017 |
+
def _new_lists():
|
| 1018 |
+
return {
|
| 1019 |
+
'pose_enc': [], 'depth': [], 'depth_conf': [],
|
| 1020 |
+
'world_points': [], 'world_pts_conf': [],
|
| 1021 |
+
'frame_type': [], # list of ints: 0=scale, 1=keyframe, 2=non-keyframe
|
| 1022 |
+
}
|
| 1023 |
+
|
| 1024 |
+
# ================================================================
|
| 1025 |
+
# Flow-based mode: dynamic windows (can't precompute window list)
|
| 1026 |
+
# ================================================================
|
| 1027 |
+
if use_flow_keyframe:
|
| 1028 |
+
all_window_predictions: List[Dict] = []
|
| 1029 |
+
cursor = 0
|
| 1030 |
+
window_idx = 0
|
| 1031 |
+
pbar = tqdm(total=S, desc='Windowed inference (flow)', initial=0)
|
| 1032 |
+
|
| 1033 |
+
while cursor < S:
|
| 1034 |
+
window_start = cursor
|
| 1035 |
+
window_scale = min(ws, S - cursor)
|
| 1036 |
+
|
| 1037 |
+
# Fresh KV cache
|
| 1038 |
+
self.clean_kv_cache()
|
| 1039 |
+
|
| 1040 |
+
# ---------- Phase 1: scale frames ----------
|
| 1041 |
+
scale_images = images[:, cursor:cursor + window_scale]
|
| 1042 |
+
scale_out = self.forward(
|
| 1043 |
+
scale_images,
|
| 1044 |
+
num_frame_for_scale=window_scale,
|
| 1045 |
+
num_frame_per_block=window_scale,
|
| 1046 |
+
causal_inference=True,
|
| 1047 |
+
)
|
| 1048 |
+
w_lists = _new_lists()
|
| 1049 |
+
_collect_frame(scale_out, w_lists)
|
| 1050 |
+
w_lists['frame_type'].extend([0] * window_scale) # scale frames
|
| 1051 |
+
|
| 1052 |
+
# Flow state: last keyframe = last scale frame
|
| 1053 |
+
last_kf_pose_enc = scale_out["pose_enc"][:, -1:]
|
| 1054 |
+
last_kf_local_idx = window_scale - 1
|
| 1055 |
+
del scale_out
|
| 1056 |
+
|
| 1057 |
+
cursor += window_scale
|
| 1058 |
+
pbar.update(window_scale)
|
| 1059 |
+
|
| 1060 |
+
# ---------- Phase 2: stream until enough keyframes ----------
|
| 1061 |
+
target_kf = window_size - window_scale # keyframes to collect
|
| 1062 |
+
kf_count = 0
|
| 1063 |
+
|
| 1064 |
+
while cursor < S and kf_count < target_kf:
|
| 1065 |
+
frame_image = images[:, cursor:cursor + 1]
|
| 1066 |
+
|
| 1067 |
+
self._set_defer_eviction(True)
|
| 1068 |
+
frame_out = self.forward(
|
| 1069 |
+
frame_image,
|
| 1070 |
+
num_frame_for_scale=window_scale,
|
| 1071 |
+
num_frame_per_block=1,
|
| 1072 |
+
causal_inference=True,
|
| 1073 |
+
)
|
| 1074 |
+
self._set_defer_eviction(False)
|
| 1075 |
+
|
| 1076 |
+
# Compute flow
|
| 1077 |
+
cur_depth = frame_out.get("depth", None)
|
| 1078 |
+
if cur_depth is not None:
|
| 1079 |
+
H_pred, W_pred = cur_depth.shape[2], cur_depth.shape[3]
|
| 1080 |
+
flow_mag = _compute_flow_magnitude(
|
| 1081 |
+
frame_out["pose_enc"], last_kf_pose_enc,
|
| 1082 |
+
cur_depth, (H_pred, W_pred),
|
| 1083 |
+
)
|
| 1084 |
+
else:
|
| 1085 |
+
flow_mag = flow_threshold + 1.0
|
| 1086 |
+
|
| 1087 |
+
local_idx = window_scale + (cursor - window_start - window_scale)
|
| 1088 |
+
frames_since_kf = local_idx - last_kf_local_idx
|
| 1089 |
+
is_keyframe = (
|
| 1090 |
+
(kf_count == 0) # first streaming frame
|
| 1091 |
+
or (flow_mag > flow_threshold)
|
| 1092 |
+
or (frames_since_kf >= max_non_keyframe_gap)
|
| 1093 |
+
)
|
| 1094 |
+
|
| 1095 |
+
if is_keyframe:
|
| 1096 |
+
self._execute_deferred_eviction()
|
| 1097 |
+
last_kf_pose_enc = frame_out["pose_enc"]
|
| 1098 |
+
last_kf_local_idx = local_idx
|
| 1099 |
+
kf_count += 1
|
| 1100 |
+
w_lists['frame_type'].append(1) # keyframe
|
| 1101 |
+
else:
|
| 1102 |
+
self._rollback_last_frame()
|
| 1103 |
+
w_lists['frame_type'].append(2) # non-keyframe
|
| 1104 |
+
|
| 1105 |
+
_collect_frame(frame_out, w_lists)
|
| 1106 |
+
del frame_out
|
| 1107 |
+
cursor += 1
|
| 1108 |
+
pbar.update(1)
|
| 1109 |
+
|
| 1110 |
+
all_window_predictions.append(_make_window_pred(w_lists))
|
| 1111 |
+
window_idx += 1
|
| 1112 |
+
|
| 1113 |
+
# Next window starts overlap_size frames back (= scale frames)
|
| 1114 |
+
if cursor < S:
|
| 1115 |
+
cursor = max(cursor - eff_overlap, window_start + window_scale)
|
| 1116 |
+
|
| 1117 |
+
pbar.close()
|
| 1118 |
+
|
| 1119 |
+
# ================================================================
|
| 1120 |
+
# Fixed-interval / default mode: precomputable windows
|
| 1121 |
+
# ================================================================
|
| 1122 |
+
else:
|
| 1123 |
+
# Compute actual frames per window
|
| 1124 |
+
phase2_kf = max(window_size - ws, 0)
|
| 1125 |
+
kf_int = max(keyframe_interval, 1)
|
| 1126 |
+
phase2_frames = phase2_kf * kf_int
|
| 1127 |
+
actual_window_frames = ws + phase2_frames
|
| 1128 |
+
|
| 1129 |
+
eff_window = min(actual_window_frames, S)
|
| 1130 |
+
step = max(eff_window - eff_overlap, 1)
|
| 1131 |
+
|
| 1132 |
+
# Build window list
|
| 1133 |
+
if eff_window >= S:
|
| 1134 |
+
windows = [(0, S)]
|
| 1135 |
+
else:
|
| 1136 |
+
windows = []
|
| 1137 |
+
for start_idx in range(0, S, step):
|
| 1138 |
+
end_idx = min(start_idx + eff_window, S)
|
| 1139 |
+
if end_idx - start_idx >= eff_overlap or end_idx == S:
|
| 1140 |
+
windows.append((start_idx, end_idx))
|
| 1141 |
+
if end_idx == S:
|
| 1142 |
+
break
|
| 1143 |
+
|
| 1144 |
+
all_window_predictions: List[Dict] = []
|
| 1145 |
+
for start, end in tqdm(windows, desc='Windowed inference'):
|
| 1146 |
+
window_images = images[:, start:end]
|
| 1147 |
+
window_len = end - start
|
| 1148 |
+
|
| 1149 |
+
# Fresh KV cache
|
| 1150 |
+
self.clean_kv_cache()
|
| 1151 |
+
|
| 1152 |
+
window_scale = min(ws, window_len)
|
| 1153 |
+
|
| 1154 |
+
# ---------- Phase 1: scale frames ----------
|
| 1155 |
+
scale_out = self.forward(
|
| 1156 |
+
window_images[:, :window_scale],
|
| 1157 |
+
num_frame_for_scale=window_scale,
|
| 1158 |
+
num_frame_per_block=window_scale,
|
| 1159 |
+
causal_inference=True,
|
| 1160 |
+
)
|
| 1161 |
+
w_lists = _new_lists()
|
| 1162 |
+
_collect_frame(scale_out, w_lists)
|
| 1163 |
+
w_lists['frame_type'].extend([0] * window_scale) # scale frames
|
| 1164 |
+
del scale_out
|
| 1165 |
+
|
| 1166 |
+
# ---------- Phase 2: stream remaining frames ----------
|
| 1167 |
+
for i in range(window_scale, window_len):
|
| 1168 |
+
is_keyframe = (
|
| 1169 |
+
kf_int <= 1
|
| 1170 |
+
or ((i - window_scale) % kf_int == 0)
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
if not is_keyframe:
|
| 1174 |
+
self._set_skip_append(True)
|
| 1175 |
+
|
| 1176 |
+
frame_out = self.forward(
|
| 1177 |
+
window_images[:, i:i + 1],
|
| 1178 |
+
num_frame_for_scale=window_scale,
|
| 1179 |
+
num_frame_per_block=1,
|
| 1180 |
+
causal_inference=True,
|
| 1181 |
+
)
|
| 1182 |
+
|
| 1183 |
+
if not is_keyframe:
|
| 1184 |
+
self._set_skip_append(False)
|
| 1185 |
+
|
| 1186 |
+
_collect_frame(frame_out, w_lists)
|
| 1187 |
+
w_lists['frame_type'].append(1 if is_keyframe else 2)
|
| 1188 |
+
del frame_out
|
| 1189 |
+
|
| 1190 |
+
all_window_predictions.append(_make_window_pred(w_lists))
|
| 1191 |
+
|
| 1192 |
+
# Store for merge helpers
|
| 1193 |
+
self._last_window_size = eff_overlap # not used directly, but kept for compat
|
| 1194 |
+
self._last_overlap_size = eff_overlap
|
| 1195 |
+
|
| 1196 |
+
# Align and stitch windows
|
| 1197 |
+
predictions = self._align_and_stitch_windows(
|
| 1198 |
+
all_window_predictions, scale_mode=scale_mode
|
| 1199 |
+
)
|
| 1200 |
+
|
| 1201 |
+
predictions["images"] = _to_out(images)
|
| 1202 |
+
|
| 1203 |
+
if self.pred_normalization:
|
| 1204 |
+
predictions = self._normalize_predictions(predictions)
|
| 1205 |
+
|
| 1206 |
+
return predictions
|
lingbot_map/utils/__init__.py
ADDED
|
File without changes
|
lingbot_map/utils/geometry.py
ADDED
|
@@ -0,0 +1,774 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from scipy.spatial.transform import Rotation as R
|
| 11 |
+
|
| 12 |
+
from scipy.spatial.transform import Rotation
|
| 13 |
+
try:
|
| 14 |
+
from lietorch import SE3, Sim3
|
| 15 |
+
except ImportError:
|
| 16 |
+
SE3 = Sim3 = None
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from lingbot_map.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion
|
| 21 |
+
except ImportError:
|
| 22 |
+
apply_distortion = iterative_undistortion = single_undistortion = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def unproject_depth_map_to_point_map(
|
| 26 |
+
depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
|
| 27 |
+
) -> np.ndarray:
|
| 28 |
+
"""
|
| 29 |
+
Unproject a batch of depth maps to 3D world coordinates.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
|
| 33 |
+
extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
|
| 34 |
+
intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
|
| 38 |
+
"""
|
| 39 |
+
if isinstance(depth_map, torch.Tensor):
|
| 40 |
+
depth_map = depth_map.cpu().numpy()
|
| 41 |
+
if isinstance(extrinsics_cam, torch.Tensor):
|
| 42 |
+
extrinsics_cam = extrinsics_cam.cpu().numpy()
|
| 43 |
+
if isinstance(intrinsics_cam, torch.Tensor):
|
| 44 |
+
intrinsics_cam = intrinsics_cam.cpu().numpy()
|
| 45 |
+
|
| 46 |
+
world_points_list = []
|
| 47 |
+
for frame_idx in range(depth_map.shape[0]):
|
| 48 |
+
cur_world_points, _, _ = depth_to_world_coords_points(
|
| 49 |
+
depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
|
| 50 |
+
)
|
| 51 |
+
world_points_list.append(cur_world_points)
|
| 52 |
+
world_points_array = np.stack(world_points_list, axis=0)
|
| 53 |
+
|
| 54 |
+
return world_points_array
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def depth_to_world_coords_points(
|
| 58 |
+
depth_map: np.ndarray,
|
| 59 |
+
extrinsic: np.ndarray,
|
| 60 |
+
intrinsic: np.ndarray,
|
| 61 |
+
eps=1e-8,
|
| 62 |
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 63 |
+
"""
|
| 64 |
+
Convert a depth map to world coordinates.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
depth_map (np.ndarray): Depth map of shape (H, W).
|
| 68 |
+
intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
|
| 69 |
+
extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
|
| 73 |
+
"""
|
| 74 |
+
if depth_map is None:
|
| 75 |
+
return None, None, None
|
| 76 |
+
|
| 77 |
+
# Valid depth mask
|
| 78 |
+
point_mask = depth_map > eps
|
| 79 |
+
|
| 80 |
+
# Convert depth map to camera coordinates
|
| 81 |
+
cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
|
| 82 |
+
|
| 83 |
+
# Multiply with the inverse of extrinsic matrix to transform to world coordinates
|
| 84 |
+
# extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
|
| 85 |
+
cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
|
| 86 |
+
|
| 87 |
+
R_cam_to_world = cam_to_world_extrinsic[:3, :3]
|
| 88 |
+
t_cam_to_world = cam_to_world_extrinsic[:3, 3]
|
| 89 |
+
|
| 90 |
+
# Apply the rotation and translation to the camera coordinates
|
| 91 |
+
world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
|
| 92 |
+
# world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
|
| 93 |
+
|
| 94 |
+
return world_coords_points, cam_coords_points, point_mask
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
|
| 98 |
+
"""
|
| 99 |
+
Convert a depth map to camera coordinates.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
depth_map (np.ndarray): Depth map of shape (H, W).
|
| 103 |
+
intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
np.ndarray: Camera coordinates (H, W, 3)
|
| 107 |
+
"""
|
| 108 |
+
H, W = depth_map.shape
|
| 109 |
+
assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
|
| 110 |
+
assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
|
| 111 |
+
|
| 112 |
+
# Intrinsic parameters
|
| 113 |
+
fu, fv = intrinsic[0, 0], intrinsic[1, 1]
|
| 114 |
+
cu, cv = intrinsic[0, 2], intrinsic[1, 2]
|
| 115 |
+
|
| 116 |
+
# Generate grid of pixel coordinates
|
| 117 |
+
u, v = np.meshgrid(np.arange(W), np.arange(H))
|
| 118 |
+
|
| 119 |
+
# Unproject to camera coordinates
|
| 120 |
+
x_cam = (u - cu) * depth_map / fu
|
| 121 |
+
y_cam = (v - cv) * depth_map / fv
|
| 122 |
+
z_cam = depth_map
|
| 123 |
+
|
| 124 |
+
# Stack to form camera coordinates
|
| 125 |
+
cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
|
| 126 |
+
|
| 127 |
+
return cam_coords
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def closed_form_inverse_se3(se3, R=None, T=None):
|
| 131 |
+
"""
|
| 132 |
+
Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
|
| 133 |
+
|
| 134 |
+
If `R` and `T` are provided, they must correspond to the rotation and translation
|
| 135 |
+
components of `se3`. Otherwise, they will be extracted from `se3`.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
|
| 139 |
+
R (optional): Nx3x3 array or tensor of rotation matrices.
|
| 140 |
+
T (optional): Nx3x1 array or tensor of translation vectors.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
Inverted SE3 matrices with the same type and device as `se3`.
|
| 144 |
+
|
| 145 |
+
Shapes:
|
| 146 |
+
se3: (N, 4, 4)
|
| 147 |
+
R: (N, 3, 3)
|
| 148 |
+
T: (N, 3, 1)
|
| 149 |
+
"""
|
| 150 |
+
# Check if se3 is a numpy array or a torch tensor
|
| 151 |
+
is_numpy = isinstance(se3, np.ndarray)
|
| 152 |
+
|
| 153 |
+
# Validate shapes
|
| 154 |
+
if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
|
| 155 |
+
raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
|
| 156 |
+
|
| 157 |
+
# Extract R and T if not provided
|
| 158 |
+
if R is None:
|
| 159 |
+
R = se3[:, :3, :3] # (N,3,3)
|
| 160 |
+
if T is None:
|
| 161 |
+
T = se3[:, :3, 3:] # (N,3,1)
|
| 162 |
+
|
| 163 |
+
# Transpose R
|
| 164 |
+
if is_numpy:
|
| 165 |
+
# Compute the transpose of the rotation for NumPy
|
| 166 |
+
R_transposed = np.transpose(R, (0, 2, 1))
|
| 167 |
+
# -R^T t for NumPy
|
| 168 |
+
top_right = -np.matmul(R_transposed, T)
|
| 169 |
+
inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
|
| 170 |
+
else:
|
| 171 |
+
R_transposed = R.transpose(1, 2) # (N,3,3)
|
| 172 |
+
top_right = -torch.bmm(R_transposed, T) # (N,3,1)
|
| 173 |
+
inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
|
| 174 |
+
inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
|
| 175 |
+
|
| 176 |
+
inverted_matrix[:, :3, :3] = R_transposed
|
| 177 |
+
inverted_matrix[:, :3, 3:] = top_right
|
| 178 |
+
|
| 179 |
+
return inverted_matrix
|
| 180 |
+
|
| 181 |
+
def closed_form_inverse_se3_general(se3, R=None, T=None):
|
| 182 |
+
"""
|
| 183 |
+
Inverse of SE3 matrices with arbitrary batch dimensions.
|
| 184 |
+
se3: (..., 4, 4) or (..., 3, 4)
|
| 185 |
+
"""
|
| 186 |
+
batch_shape = se3.shape[:-2]
|
| 187 |
+
if R is None:
|
| 188 |
+
R = se3[..., :3, :3]
|
| 189 |
+
if T is None:
|
| 190 |
+
T = se3[..., :3, 3:]
|
| 191 |
+
R_transposed = R.transpose(-2, -1)
|
| 192 |
+
top_right = -R_transposed @ T
|
| 193 |
+
# Build the identity matrix template
|
| 194 |
+
eye = torch.eye(4, 4, dtype=R.dtype, device=R.device)
|
| 195 |
+
inverted_matrix = eye.expand(*batch_shape, 4, 4).clone()
|
| 196 |
+
inverted_matrix[..., :3, :3] = R_transposed
|
| 197 |
+
inverted_matrix[..., :3, 3:] = top_right
|
| 198 |
+
return inverted_matrix
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# TODO: this code can be further cleaned up
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def project_world_points_to_camera_points_batch(world_points, cam_extrinsics):
|
| 205 |
+
"""
|
| 206 |
+
Transforms 3D points to 2D using extrinsic and intrinsic parameters.
|
| 207 |
+
Args:
|
| 208 |
+
world_points (torch.Tensor): 3D points of shape BxSxHxWx3.
|
| 209 |
+
cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4.
|
| 210 |
+
Returns:
    torch.Tensor: Points in camera coordinates of shape BxSxHxWx3.
|
| 211 |
+
"""
|
| 212 |
+
# TODO: merge this into project_world_points_to_cam
|
| 213 |
+
|
| 214 |
+
# device = world_points.device
|
| 215 |
+
# with torch.autocast(device_type=device.type, enabled=False):
|
| 216 |
+
ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1)
|
| 217 |
+
world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4)
|
| 218 |
+
|
| 219 |
+
# extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4)
|
| 220 |
+
extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3)
|
| 221 |
+
|
| 222 |
+
# world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1)
|
| 223 |
+
world_points_h_exp = world_points_h.unsqueeze(-1)
|
| 224 |
+
|
| 225 |
+
# Now perform the matrix multiplication
|
| 226 |
+
# (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1)
|
| 227 |
+
camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1)
|
| 228 |
+
|
| 229 |
+
return camera_points
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def project_world_points_to_cam(
|
| 234 |
+
world_points,
|
| 235 |
+
cam_extrinsics,
|
| 236 |
+
cam_intrinsics=None,
|
| 237 |
+
distortion_params=None,
|
| 238 |
+
default=0,
|
| 239 |
+
only_points_cam=False,
|
| 240 |
+
):
|
| 241 |
+
"""
|
| 242 |
+
Transforms 3D points to 2D using extrinsic and intrinsic parameters.
|
| 243 |
+
Args:
|
| 244 |
+
world_points (torch.Tensor): 3D points of shape Px3.
|
| 245 |
+
cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
|
| 246 |
+
cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
|
| 247 |
+
distortion_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion.
|
| 248 |
+
Returns:
|
| 249 |
+
torch.Tensor: Transformed 2D points of shape BxNx2.
|
| 250 |
+
"""
|
| 251 |
+
device = world_points.device
|
| 252 |
+
# with torch.autocast(device_type=device.type, dtype=torch.double):
|
| 253 |
+
with torch.autocast(device_type=device.type, enabled=False):
|
| 254 |
+
N = world_points.shape[0] # Number of points
|
| 255 |
+
B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras
|
| 256 |
+
world_points_homogeneous = torch.cat(
|
| 257 |
+
[world_points, torch.ones_like(world_points[..., 0:1])], dim=1
|
| 258 |
+
) # Nx4
|
| 259 |
+
# Reshape for batch processing
|
| 260 |
+
world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand(
|
| 261 |
+
B, -1, -1
|
| 262 |
+
) # BxNx4
|
| 263 |
+
|
| 264 |
+
# Step 1: Apply extrinsic parameters
|
| 265 |
+
# Transform 3D points to camera coordinate system for all cameras
|
| 266 |
+
cam_points = torch.bmm(
|
| 267 |
+
cam_extrinsics, world_points_homogeneous.transpose(-1, -2)
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
if only_points_cam:
|
| 271 |
+
return None, cam_points
|
| 272 |
+
|
| 273 |
+
# Step 2: Apply intrinsic parameters and (optional) distortion
|
| 274 |
+
image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default)
|
| 275 |
+
|
| 276 |
+
return image_points, cam_points
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0):
|
| 281 |
+
"""
|
| 282 |
+
Applies intrinsic parameters and optional distortion to the given 3D points.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
|
| 286 |
+
cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
|
| 287 |
+
distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
|
| 288 |
+
default (float, optional): Default value to replace NaNs in the output.
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
|
| 292 |
+
"""
|
| 293 |
+
|
| 294 |
+
# Normalized device coordinates (NDC)
|
| 295 |
+
cam_points = cam_points / cam_points[:, 2:3, :]
|
| 296 |
+
ndc_xy = cam_points[:, :2, :]
|
| 297 |
+
|
| 298 |
+
# Apply distortion if distortion_params are provided
|
| 299 |
+
if distortion_params is not None:
|
| 300 |
+
x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1])
|
| 301 |
+
distorted_xy = torch.stack([x_distorted, y_distorted], dim=1)
|
| 302 |
+
else:
|
| 303 |
+
distorted_xy = ndc_xy
|
| 304 |
+
|
| 305 |
+
# Prepare cam_points for batch matrix multiplication
|
| 306 |
+
cam_coords_homo = torch.cat(
|
| 307 |
+
(distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1
|
| 308 |
+
) # Bx3xN
|
| 309 |
+
# Apply intrinsic parameters using batch matrix multiplication
|
| 310 |
+
pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN
|
| 311 |
+
|
| 312 |
+
# Extract x and y coordinates
|
| 313 |
+
pixel_coords = pixel_coords[:, :2, :] # Bx2xN
|
| 314 |
+
|
| 315 |
+
# Replace NaNs with default value
|
| 316 |
+
pixel_coords = torch.nan_to_num(pixel_coords, nan=default)
|
| 317 |
+
|
| 318 |
+
return pixel_coords.transpose(1, 2) # BxNx2
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def cam_from_img(pred_tracks, intrinsics, extra_params=None):
|
| 324 |
+
"""
|
| 325 |
+
Normalize predicted tracks based on camera intrinsics.
|
| 326 |
+
Args:
|
| 327 |
+
intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3].
|
| 328 |
+
pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2].
|
| 329 |
+
extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
|
| 330 |
+
Returns:
|
| 331 |
+
torch.Tensor: Normalized tracks tensor.
|
| 332 |
+
"""
|
| 333 |
+
|
| 334 |
+
# We don't want to do intrinsics_inv = torch.inverse(intrinsics) here
|
| 335 |
+
# otherwise we can use something like
|
| 336 |
+
# tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2))
|
| 337 |
+
|
| 338 |
+
principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2)
|
| 339 |
+
focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2)
|
| 340 |
+
tracks_normalized = (pred_tracks - principal_point) / focal_length
|
| 341 |
+
|
| 342 |
+
if extra_params is not None:
|
| 343 |
+
# Apply iterative undistortion
|
| 344 |
+
try:
|
| 345 |
+
tracks_normalized = iterative_undistortion(
|
| 346 |
+
extra_params, tracks_normalized
|
| 347 |
+
)
|
| 348 |
+
except:
|
| 349 |
+
tracks_normalized = single_undistortion(
|
| 350 |
+
extra_params, tracks_normalized
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
return tracks_normalized
|
| 354 |
+
|
| 355 |
+
## Droid SLAM Part
|
| 356 |
+
|
| 357 |
+
MIN_DEPTH = 0.2
|
| 358 |
+
|
| 359 |
+
def extract_intrinsics(intrinsics):
|
| 360 |
+
return intrinsics[...,None,None,:].unbind(dim=-1)
|
| 361 |
+
|
| 362 |
+
def projective_transform(
|
| 363 |
+
poses, depths, intrinsics, ii, jj, jacobian=False, return_depth=False
|
| 364 |
+
):
|
| 365 |
+
"""map points from ii->jj"""
|
| 366 |
+
|
| 367 |
+
# inverse project (pinhole)
|
| 368 |
+
X0, Jz = iproj(depths[:, ii], intrinsics[:, ii], jacobian=jacobian)
|
| 369 |
+
|
| 370 |
+
# transform
|
| 371 |
+
Gij = poses[:, jj] * poses[:, ii].inv()
|
| 372 |
+
|
| 373 |
+
# Gij.data[:, ii == jj] = torch.as_tensor(
|
| 374 |
+
# [-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device="cuda"
|
| 375 |
+
# )
|
| 376 |
+
X1, Ja = actp(Gij, X0, jacobian=jacobian)
|
| 377 |
+
|
| 378 |
+
# project (pinhole)
|
| 379 |
+
x1, Jp = proj(X1, intrinsics[:, jj], jacobian=jacobian, return_depth=return_depth)
|
| 380 |
+
|
| 381 |
+
# exclude points too close to camera
|
| 382 |
+
valid = ((X1[..., 2] > MIN_DEPTH) & (X0[..., 2] > MIN_DEPTH)).float()
|
| 383 |
+
valid = valid.unsqueeze(-1)
|
| 384 |
+
|
| 385 |
+
if jacobian:
|
| 386 |
+
# Ji transforms according to dual adjoint
|
| 387 |
+
Jj = torch.matmul(Jp, Ja)
|
| 388 |
+
Ji = -Gij[:, :, None, None, None].adjT(Jj)
|
| 389 |
+
|
| 390 |
+
Jz = Gij[:, :, None, None] * Jz
|
| 391 |
+
Jz = torch.matmul(Jp, Jz.unsqueeze(-1))
|
| 392 |
+
|
| 393 |
+
return x1, valid, (Ji, Jj, Jz)
|
| 394 |
+
|
| 395 |
+
return x1, valid
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def induced_flow(poses, disps, intrinsics, ii, jj):
|
| 399 |
+
"""optical flow induced by camera motion"""
|
| 400 |
+
|
| 401 |
+
ht, wd = disps.shape[2:]
|
| 402 |
+
y, x = torch.meshgrid(
|
| 403 |
+
torch.arange(ht, device=disps.device, dtype=torch.float),
|
| 404 |
+
torch.arange(wd, device=disps.device, dtype=torch.float),
|
| 405 |
+
indexing="ij",
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
coords0 = torch.stack([x, y], dim=-1)
|
| 409 |
+
coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj, False)
|
| 410 |
+
|
| 411 |
+
return coords1[..., :2] - coords0, valid
|
| 412 |
+
|
| 413 |
+
def all_pairs_distance_matrix(poses, beta=2.5):
|
| 414 |
+
""" compute distance matrix between all pairs of poses """
|
| 415 |
+
poses = np.array(poses, dtype=np.float32)
|
| 416 |
+
poses[:, :3] *= beta  # scale to balance rotation and translation
|
| 417 |
+
poses = SE3(torch.from_numpy(poses))
|
| 418 |
+
|
| 419 |
+
r = (poses[:,None].inv() * poses[None,:]).log()
|
| 420 |
+
return r.norm(dim=-1).cpu().numpy()
|
| 421 |
+
|
| 422 |
+
def pose_matrix_to_quaternion(pose):
|
| 423 |
+
""" convert 4x4 pose matrix to (t, q) """
|
| 424 |
+
q = Rotation.from_matrix(pose[..., :3, :3]).as_quat()
|
| 425 |
+
return np.concatenate([pose[..., :3, 3], q], axis=-1)
|
| 426 |
+
|
| 427 |
+
def compute_distance_matrix_flow(poses, disps, intrinsics):
|
| 428 |
+
""" compute flow magnitude between all pairs of frames """
|
| 429 |
+
if not isinstance(poses, SE3):
|
| 430 |
+
poses = torch.from_numpy(poses).float().cuda()[None]
|
| 431 |
+
poses = SE3(poses).inv()
|
| 432 |
+
|
| 433 |
+
disps = torch.from_numpy(disps).float().cuda()[None]
|
| 434 |
+
intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
|
| 435 |
+
|
| 436 |
+
N = poses.shape[1]
|
| 437 |
+
|
| 438 |
+
ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
|
| 439 |
+
ii = ii.reshape(-1).cuda()
|
| 440 |
+
jj = jj.reshape(-1).cuda()
|
| 441 |
+
|
| 442 |
+
MAX_FLOW = 100.0
|
| 443 |
+
matrix = np.zeros((N, N), dtype=np.float32)
|
| 444 |
+
|
| 445 |
+
s = 2048
|
| 446 |
+
for i in range(0, ii.shape[0], s):
|
| 447 |
+
flow1, val1 = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
|
| 448 |
+
flow2, val2 = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s])
|
| 449 |
+
|
| 450 |
+
flow = torch.stack([flow1, flow2], dim=2)
|
| 451 |
+
val = torch.stack([val1, val2], dim=2)
|
| 452 |
+
|
| 453 |
+
mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
|
| 454 |
+
mag = mag.view(mag.shape[1], -1)
|
| 455 |
+
val = val.view(val.shape[1], -1)
|
| 456 |
+
|
| 457 |
+
mag = (mag * val).mean(-1) / val.mean(-1)
|
| 458 |
+
mag[val.mean(-1) < 0.7] = np.inf
|
| 459 |
+
|
| 460 |
+
i1 = ii[i:i+s].cpu().numpy()
|
| 461 |
+
j1 = jj[i:i+s].cpu().numpy()
|
| 462 |
+
matrix[i1, j1] = mag.cpu().numpy()
|
| 463 |
+
|
| 464 |
+
return matrix
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4):
|
| 468 |
+
""" compute flow magnitude between all pairs of frames """
|
| 469 |
+
# if not isinstance(poses, SE3):
|
| 470 |
+
# poses = torch.from_numpy(poses).float().cuda()[None]
|
| 471 |
+
# poses = SE3(poses).inv()
|
| 472 |
+
|
| 473 |
+
# disps = torch.from_numpy(disps).float().cuda()[None]
|
| 474 |
+
# intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
|
| 475 |
+
|
| 476 |
+
N = poses.shape[1]
|
| 477 |
+
|
| 478 |
+
ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
|
| 479 |
+
ii = ii.reshape(-1)
|
| 480 |
+
jj = jj.reshape(-1)
|
| 481 |
+
|
| 482 |
+
MAX_FLOW = 128.0
|
| 483 |
+
matrix = np.zeros((N, N), dtype=np.float32)
|
| 484 |
+
|
| 485 |
+
s = 2048
|
| 486 |
+
for i in range(0, ii.shape[0], s):
|
| 487 |
+
flow1a, val1a = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True)
|
| 488 |
+
flow1b, val1b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
|
| 489 |
+
flow2a, val2a = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True)
|
| 490 |
+
flow2b, val2b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
|
| 491 |
+
|
| 492 |
+
flow1 = flow1a + beta * flow1b
|
| 493 |
+
val1 = val1a * val2b
|
| 494 |
+
|
| 495 |
+
flow2 = flow2a + beta * flow2b
|
| 496 |
+
val2 = val2a * val2b
|
| 497 |
+
|
| 498 |
+
flow = torch.stack([flow1, flow2], dim=2)
|
| 499 |
+
val = torch.stack([val1, val2], dim=2)
|
| 500 |
+
|
| 501 |
+
mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
|
| 502 |
+
mag = mag.view(mag.shape[1], -1)
|
| 503 |
+
val = val.view(val.shape[1], -1)
|
| 504 |
+
|
| 505 |
+
mag = (mag * val).mean(-1) / val.mean(-1)
|
| 506 |
+
mag[val.mean(-1) < 0.8] = np.inf
|
| 507 |
+
|
| 508 |
+
i1 = ii[i:i+s].cpu().numpy()
|
| 509 |
+
j1 = jj[i:i+s].cpu().numpy()
|
| 510 |
+
matrix[i1, j1] = mag.cpu().numpy()
|
| 511 |
+
|
| 512 |
+
return matrix
|
| 513 |
+
|
| 514 |
+
def coords_grid(ht, wd, **kwargs):
|
| 515 |
+
y, x = torch.meshgrid(
|
| 516 |
+
torch.arange(ht, dtype=torch.float, **kwargs),
|
| 517 |
+
torch.arange(wd, dtype=torch.float, **kwargs),
|
| 518 |
+
indexing="ij",
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
return torch.stack([x, y], dim=-1)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def iproj(disps, intrinsics, jacobian=False):
|
| 525 |
+
"""pinhole camera inverse projection"""
|
| 526 |
+
ht, wd = disps.shape[2:]
|
| 527 |
+
fx, fy, cx, cy = extract_intrinsics(intrinsics)
|
| 528 |
+
|
| 529 |
+
y, x = torch.meshgrid(
|
| 530 |
+
torch.arange(ht, device=disps.device, dtype=torch.float),
|
| 531 |
+
torch.arange(wd, device=disps.device, dtype=torch.float),
|
| 532 |
+
indexing="ij",
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
i = torch.ones_like(disps)
|
| 536 |
+
X = (x - cx) / fx
|
| 537 |
+
Y = (y - cy) / fy
|
| 538 |
+
pts = torch.stack([X, Y, i, disps], dim=-1)
|
| 539 |
+
|
| 540 |
+
if jacobian:
|
| 541 |
+
J = torch.zeros_like(pts)
|
| 542 |
+
J[..., -1] = 1.0
|
| 543 |
+
return pts, J
|
| 544 |
+
|
| 545 |
+
return pts, None
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def proj(Xs, intrinsics, jacobian=False, return_depth=False):
|
| 549 |
+
"""pinhole camera projection"""
|
| 550 |
+
fx, fy, cx, cy = extract_intrinsics(intrinsics)
|
| 551 |
+
X, Y, Z, D = Xs.unbind(dim=-1)
|
| 552 |
+
|
| 553 |
+
Z = torch.where(Z < 0.5 * MIN_DEPTH, torch.ones_like(Z), Z)
|
| 554 |
+
d = 1.0 / Z
|
| 555 |
+
|
| 556 |
+
x = fx * (X * d) + cx
|
| 557 |
+
y = fy * (Y * d) + cy
|
| 558 |
+
if return_depth:
|
| 559 |
+
coords = torch.stack([x, y, D * d], dim=-1)
|
| 560 |
+
else:
|
| 561 |
+
coords = torch.stack([x, y], dim=-1)
|
| 562 |
+
|
| 563 |
+
if jacobian:
|
| 564 |
+
B, N, H, W = d.shape
|
| 565 |
+
o = torch.zeros_like(d)
|
| 566 |
+
proj_jac = torch.stack(
|
| 567 |
+
[
|
| 568 |
+
fx * d,
|
| 569 |
+
o,
|
| 570 |
+
-fx * X * d * d,
|
| 571 |
+
o,
|
| 572 |
+
o,
|
| 573 |
+
fy * d,
|
| 574 |
+
-fy * Y * d * d,
|
| 575 |
+
o,
|
| 576 |
+
# o, o, -D*d*d, d,
|
| 577 |
+
],
|
| 578 |
+
dim=-1,
|
| 579 |
+
).view(B, N, H, W, 2, 4)
|
| 580 |
+
|
| 581 |
+
return coords, proj_jac
|
| 582 |
+
|
| 583 |
+
return coords, None
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def actp(Gij, X0, jacobian=False):
|
| 587 |
+
"""action on point cloud"""
|
| 588 |
+
X1 = Gij[:, :, None, None] * X0
|
| 589 |
+
|
| 590 |
+
if jacobian:
|
| 591 |
+
X, Y, Z, d = X1.unbind(dim=-1)
|
| 592 |
+
o = torch.zeros_like(d)
|
| 593 |
+
B, N, H, W = d.shape
|
| 594 |
+
|
| 595 |
+
if isinstance(Gij, SE3):
|
| 596 |
+
Ja = torch.stack(
|
| 597 |
+
[
|
| 598 |
+
d,
|
| 599 |
+
o,
|
| 600 |
+
o,
|
| 601 |
+
o,
|
| 602 |
+
Z,
|
| 603 |
+
-Y,
|
| 604 |
+
o,
|
| 605 |
+
d,
|
| 606 |
+
o,
|
| 607 |
+
-Z,
|
| 608 |
+
o,
|
| 609 |
+
X,
|
| 610 |
+
o,
|
| 611 |
+
o,
|
| 612 |
+
d,
|
| 613 |
+
Y,
|
| 614 |
+
-X,
|
| 615 |
+
o,
|
| 616 |
+
o,
|
| 617 |
+
o,
|
| 618 |
+
o,
|
| 619 |
+
o,
|
| 620 |
+
o,
|
| 621 |
+
o,
|
| 622 |
+
],
|
| 623 |
+
dim=-1,
|
| 624 |
+
).view(B, N, H, W, 4, 6)
|
| 625 |
+
|
| 626 |
+
elif isinstance(Gij, Sim3):
|
| 627 |
+
Ja = torch.stack(
|
| 628 |
+
[
|
| 629 |
+
d,
|
| 630 |
+
o,
|
| 631 |
+
o,
|
| 632 |
+
o,
|
| 633 |
+
Z,
|
| 634 |
+
-Y,
|
| 635 |
+
X,
|
| 636 |
+
o,
|
| 637 |
+
d,
|
| 638 |
+
o,
|
| 639 |
+
-Z,
|
| 640 |
+
o,
|
| 641 |
+
X,
|
| 642 |
+
Y,
|
| 643 |
+
o,
|
| 644 |
+
o,
|
| 645 |
+
d,
|
| 646 |
+
Y,
|
| 647 |
+
-X,
|
| 648 |
+
o,
|
| 649 |
+
Z,
|
| 650 |
+
o,
|
| 651 |
+
o,
|
| 652 |
+
o,
|
| 653 |
+
o,
|
| 654 |
+
o,
|
| 655 |
+
o,
|
| 656 |
+
o,
|
| 657 |
+
],
|
| 658 |
+
dim=-1,
|
| 659 |
+
).view(B, N, H, W, 4, 7)
|
| 660 |
+
|
| 661 |
+
return X1, Ja
|
| 662 |
+
|
| 663 |
+
return X1, None
|
| 664 |
+
|
| 665 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
| 666 |
+
"""
|
| 667 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 668 |
+
but with a zero subgradient where x is 0.
|
| 669 |
+
"""
|
| 670 |
+
ret = torch.zeros_like(x)
|
| 671 |
+
positive_mask = x > 0
|
| 672 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 673 |
+
return ret
|
| 674 |
+
|
| 675 |
+
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
| 676 |
+
"""
|
| 677 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 678 |
+
|
| 679 |
+
Args:
|
| 680 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 681 |
+
|
| 682 |
+
Returns:
|
| 683 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
| 684 |
+
"""
|
| 685 |
+
if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
|
| 686 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
| 687 |
+
|
| 688 |
+
batch_dim = matrix.shape[:-2]
|
| 689 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
| 690 |
+
matrix.reshape(batch_dim + (9,)), dim=-1
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
q_abs = _sqrt_positive_part(
|
| 694 |
+
torch.stack(
|
| 695 |
+
[
|
| 696 |
+
1.0 + m00 + m11 + m22,
|
| 697 |
+
1.0 + m00 - m11 - m22,
|
| 698 |
+
1.0 - m00 + m11 - m22,
|
| 699 |
+
1.0 - m00 - m11 + m22,
|
| 700 |
+
],
|
| 701 |
+
dim=-1,
|
| 702 |
+
)
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
quat_by_rijk = torch.stack(
|
| 706 |
+
[
|
| 707 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
| 708 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
| 709 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
| 710 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
| 711 |
+
],
|
| 712 |
+
dim=-2,
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 716 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 717 |
+
|
| 718 |
+
out = quat_candidates[
|
| 719 |
+
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
| 720 |
+
].reshape(batch_dim + (4,))
|
| 721 |
+
return standardize_quaternion(out)
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
| 725 |
+
"""
|
| 726 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 727 |
+
part is non-negative.
|
| 728 |
+
|
| 729 |
+
Args:
|
| 730 |
+
quaternions: Quaternions with real part first,
|
| 731 |
+
as tensor of shape (..., 4).
|
| 732 |
+
|
| 733 |
+
Returns:
|
| 734 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 735 |
+
"""
|
| 736 |
+
quaternions = F.normalize(quaternions, p=2, dim=-1)
|
| 737 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
| 738 |
+
|
| 739 |
+
def umeyama(X, Y):
|
| 740 |
+
"""
|
| 741 |
+
Estimates the Sim(3) transformation between `X` and `Y` point sets.
|
| 742 |
+
|
| 743 |
+
Estimates c, R and t such that c * R @ X + t ~ Y.
|
| 744 |
+
|
| 745 |
+
Parameters
|
| 746 |
+
----------
|
| 747 |
+
X : numpy.array
|
| 748 |
+
(m, n) shaped numpy array. m is the dimension of the points,
|
| 749 |
+
n is the number of points in the point set.
|
| 750 |
+
Y : numpy.array
|
| 751 |
+
(m, n) shaped numpy array. Indexes should be consistent with `X`.
|
| 752 |
+
That is, Y[:, i] must be the point corresponding to X[:, i].
|
| 753 |
+
|
| 754 |
+
Returns
|
| 755 |
+
-------
|
| 756 |
+
c : float
|
| 757 |
+
Scale factor.
|
| 758 |
+
R : numpy.array
|
| 759 |
+
(3, 3) shaped rotation matrix.
|
| 760 |
+
t : numpy.array
|
| 761 |
+
(3, 1) shaped translation vector.
|
| 762 |
+
"""
|
| 763 |
+
mu_x = X.mean(axis=1).reshape(-1, 1)
|
| 764 |
+
mu_y = Y.mean(axis=1).reshape(-1, 1)
|
| 765 |
+
var_x = np.square(X - mu_x).sum(axis=0).mean()
|
| 766 |
+
cov_xy = ((Y - mu_y) @ (X - mu_x).T) / X.shape[1]
|
| 767 |
+
U, D, VH = np.linalg.svd(cov_xy)
|
| 768 |
+
S = np.eye(X.shape[0])
|
| 769 |
+
if np.linalg.det(U) * np.linalg.det(VH) < 0:
|
| 770 |
+
S[-1, -1] = -1
|
| 771 |
+
c = np.trace(np.diag(D) @ S) / var_x
|
| 772 |
+
R = U @ S @ VH
|
| 773 |
+
t = mu_y - c * R @ mu_x
|
| 774 |
+
return c, R, t
|
lingbot_map/utils/load_fn.py
ADDED
|
@@ -0,0 +1,243 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from torchvision import transforms as TF
|
| 12 |
+
from tqdm.auto import tqdm
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_and_preprocess_images_square(image_path_list, target_size=1024):
|
| 17 |
+
"""
|
| 18 |
+
Load and preprocess images by center padding to square and resizing to target size.
|
| 19 |
+
Also returns the position information of original pixels after transformation.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
image_path_list (list): List of paths to image files
|
| 23 |
+
target_size (int, optional): Target size for both width and height. Defaults to 1024.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
tuple: (
|
| 27 |
+
torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
|
| 28 |
+
torch.Tensor: Array of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
Raises:
|
| 32 |
+
ValueError: If the input list is empty
|
| 33 |
+
"""
|
| 34 |
+
# Check for empty list
|
| 35 |
+
if len(image_path_list) == 0:
|
| 36 |
+
raise ValueError("At least 1 image is required")
|
| 37 |
+
|
| 38 |
+
images = []
|
| 39 |
+
original_coords = [] # Renamed from position_info to be more descriptive
|
| 40 |
+
to_tensor = TF.ToTensor()
|
| 41 |
+
|
| 42 |
+
for image_path in image_path_list:
|
| 43 |
+
# Open image
|
| 44 |
+
img = Image.open(image_path)
|
| 45 |
+
|
| 46 |
+
# If there's an alpha channel, blend onto white background
|
| 47 |
+
if img.mode == "RGBA":
|
| 48 |
+
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
|
| 49 |
+
img = Image.alpha_composite(background, img)
|
| 50 |
+
|
| 51 |
+
# Convert to RGB
|
| 52 |
+
img = img.convert("RGB")
|
| 53 |
+
|
| 54 |
+
# Get original dimensions
|
| 55 |
+
width, height = img.size
|
| 56 |
+
|
| 57 |
+
# Make the image square by padding the shorter dimension
|
| 58 |
+
max_dim = max(width, height)
|
| 59 |
+
|
| 60 |
+
# Calculate padding
|
| 61 |
+
left = (max_dim - width) // 2
|
| 62 |
+
top = (max_dim - height) // 2
|
| 63 |
+
|
| 64 |
+
# Calculate scale factor for resizing
|
| 65 |
+
scale = target_size / max_dim
|
| 66 |
+
|
| 67 |
+
# Calculate final coordinates of original image in target space
|
| 68 |
+
x1 = left * scale
|
| 69 |
+
y1 = top * scale
|
| 70 |
+
x2 = (left + width) * scale
|
| 71 |
+
y2 = (top + height) * scale
|
| 72 |
+
|
| 73 |
+
# Store original image coordinates and scale
|
| 74 |
+
original_coords.append(np.array([x1, y1, x2, y2, width, height]))
|
| 75 |
+
|
| 76 |
+
# Create a new black square image and paste original
|
| 77 |
+
square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
|
| 78 |
+
square_img.paste(img, (left, top))
|
| 79 |
+
|
| 80 |
+
# Resize to target size
|
| 81 |
+
square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)
|
| 82 |
+
|
| 83 |
+
# Convert to tensor
|
| 84 |
+
img_tensor = to_tensor(square_img)
|
| 85 |
+
images.append(img_tensor)
|
| 86 |
+
|
| 87 |
+
# Stack all images
|
| 88 |
+
images = torch.stack(images)
|
| 89 |
+
original_coords = torch.from_numpy(np.array(original_coords)).float()
|
| 90 |
+
|
| 91 |
+
# Add additional dimension if single image to ensure correct shape
|
| 92 |
+
if len(image_path_list) == 1:
|
| 93 |
+
if images.dim() == 3:
|
| 94 |
+
images = images.unsqueeze(0)
|
| 95 |
+
original_coords = original_coords.unsqueeze(0)
|
| 96 |
+
|
| 97 |
+
return images, original_coords
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16):
|
| 101 |
+
"""
|
| 102 |
+
A quick start function to load and preprocess images for model input.
|
| 103 |
+
This assumes all images share the same shape for easier batching, but our model can also work with different shapes.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
image_path_list (list): List of paths to image files
|
| 107 |
+
mode (str, optional): Preprocessing mode, either "crop" or "pad".
|
| 108 |
+
- "crop" (default): Sets width to 518px and center crops height if needed.
|
| 109 |
+
- "pad": Preserves all pixels by making the largest dimension 518px
|
| 110 |
+
and padding the smaller dimension to reach a square shape.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W). If fx/fy/cx/cy are given, the tuple (images, fx, fy, cx, cy) with rescaled intrinsics is returned instead.
|
| 114 |
+
|
| 115 |
+
Raises:
|
| 116 |
+
ValueError: If the input list is empty or if mode is invalid
|
| 117 |
+
|
| 118 |
+
Notes:
|
| 119 |
+
- Images with different dimensions will be padded with white (value=1.0)
|
| 120 |
+
- A warning is printed when images have different shapes
|
| 121 |
+
- When mode="crop": The function ensures width=518px while maintaining aspect ratio
|
| 122 |
+
and height is center-cropped if larger than image_size
|
| 123 |
+
- When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
|
| 124 |
+
and the smaller dimension is padded to reach a square shape (image_size x image_size)
|
| 125 |
+
- Dimensions are adjusted to be divisible by patch_size for compatibility with model requirements
|
| 126 |
+
"""
|
| 127 |
+
# Check for empty list
|
| 128 |
+
if len(image_path_list) == 0:
|
| 129 |
+
raise ValueError("At least 1 image is required")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# Validate mode
|
| 134 |
+
if mode not in ["crop", "pad"]:
|
| 135 |
+
raise ValueError("Mode must be either 'crop' or 'pad'")
|
| 136 |
+
|
| 137 |
+
target_size = image_size
|
| 138 |
+
to_tensor = TF.ToTensor()
|
| 139 |
+
|
| 140 |
+
def _load_one(idx_path):
|
| 141 |
+
i, image_path = idx_path
|
| 142 |
+
img = Image.open(image_path)
|
| 143 |
+
if img.mode == "RGBA":
|
| 144 |
+
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
|
| 145 |
+
img = Image.alpha_composite(background, img)
|
| 146 |
+
img = img.convert("RGB")
|
| 147 |
+
|
| 148 |
+
width, height = img.size
|
| 149 |
+
|
| 150 |
+
fx_val = fy_val = cx_val = cy_val = None
|
| 151 |
+
if fx is not None:
|
| 152 |
+
fx_val = fx[i] * width
|
| 153 |
+
fy_val = fy[i] * height
|
| 154 |
+
cx_val = cx[i] * width
|
| 155 |
+
cy_val = cy[i] * height
|
| 156 |
+
|
| 157 |
+
if mode == "pad":
|
| 158 |
+
if width >= height:
|
| 159 |
+
new_width = target_size
|
| 160 |
+
new_height = round(height * (new_width / width) / patch_size) * patch_size
|
| 161 |
+
else:
|
| 162 |
+
new_height = target_size
|
| 163 |
+
new_width = round(width * (new_height / height) / patch_size) * patch_size
|
| 164 |
+
else: # crop
|
| 165 |
+
new_width = target_size
|
| 166 |
+
new_height = round(height * (new_width / width) / patch_size) * patch_size
|
| 167 |
+
|
| 168 |
+
img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
|
| 169 |
+
img = to_tensor(img)
|
| 170 |
+
|
| 171 |
+
if mode == "crop" and new_height > target_size:
|
| 172 |
+
start_y = (new_height - target_size) // 2
|
| 173 |
+
img = img[:, start_y : start_y + target_size, :]
|
| 174 |
+
|
| 175 |
+
if fx is not None:
|
| 176 |
+
fx_val = fx_val * new_width / width
|
| 177 |
+
fy_val = fy_val * new_height / height
|
| 178 |
+
cx_val = img.shape[2] / 2
|
| 179 |
+
cy_val = img.shape[1] / 2
|
| 180 |
+
|
| 181 |
+
if mode == "pad":
|
| 182 |
+
h_padding = target_size - img.shape[1]
|
| 183 |
+
w_padding = target_size - img.shape[2]
|
| 184 |
+
if h_padding > 0 or w_padding > 0:
|
| 185 |
+
pad_top = h_padding // 2
|
| 186 |
+
pad_bottom = h_padding - pad_top
|
| 187 |
+
pad_left = w_padding // 2
|
| 188 |
+
pad_right = w_padding - pad_left
|
| 189 |
+
img = torch.nn.functional.pad(
|
| 190 |
+
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
return i, img, (fx_val, fy_val, cx_val, cy_val)
|
| 194 |
+
|
| 195 |
+
# Parallel load with progress bar
|
| 196 |
+
num_workers = min(16, len(image_path_list))
|
| 197 |
+
results = [None] * len(image_path_list)
|
| 198 |
+
with ThreadPoolExecutor(max_workers=num_workers) as pool:
|
| 199 |
+
futures = pool.map(_load_one, enumerate(image_path_list))
|
| 200 |
+
for i, img, calib in tqdm(futures, total=len(image_path_list), desc="Loading images"):
|
| 201 |
+
results[i] = img
|
| 202 |
+
if fx is not None:
|
| 203 |
+
fx[i], fy[i], cx[i], cy[i] = calib
|
| 204 |
+
|
| 205 |
+
images = results
|
| 206 |
+
shapes = set((img.shape[1], img.shape[2]) for img in images)
|
| 207 |
+
|
| 208 |
+
# Check if we have different shapes
|
| 209 |
+
# In theory our model can also work well with different shapes
|
| 210 |
+
if len(shapes) > 1:
|
| 211 |
+
print(f"Warning: Found images with different shapes: {shapes}")
|
| 212 |
+
# Find maximum dimensions
|
| 213 |
+
max_height = max(shape[0] for shape in shapes)
|
| 214 |
+
max_width = max(shape[1] for shape in shapes)
|
| 215 |
+
|
| 216 |
+
# Pad images if necessary
|
| 217 |
+
padded_images = []
|
| 218 |
+
for img in images:
|
| 219 |
+
h_padding = max_height - img.shape[1]
|
| 220 |
+
w_padding = max_width - img.shape[2]
|
| 221 |
+
|
| 222 |
+
if h_padding > 0 or w_padding > 0:
|
| 223 |
+
pad_top = h_padding // 2
|
| 224 |
+
pad_bottom = h_padding - pad_top
|
| 225 |
+
pad_left = w_padding // 2
|
| 226 |
+
pad_right = w_padding - pad_left
|
| 227 |
+
|
| 228 |
+
img = torch.nn.functional.pad(
|
| 229 |
+
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
|
| 230 |
+
)
|
| 231 |
+
padded_images.append(img)
|
| 232 |
+
images = padded_images
|
| 233 |
+
|
| 234 |
+
images = torch.stack(images) # concatenate images
|
| 235 |
+
|
| 236 |
+
# Ensure correct shape when single image
|
| 237 |
+
if len(image_path_list) == 1:
|
| 238 |
+
# Verify shape is (1, C, H, W)
|
| 239 |
+
if images.dim() == 3:
|
| 240 |
+
images = images.unsqueeze(0)
|
| 241 |
+
if fx is not None:
|
| 242 |
+
return images, fx, fy, cx, cy
|
| 243 |
+
return images
|
lingbot_map/utils/pose_enc.py
ADDED
|
@@ -0,0 +1,331 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from .rotation import quat_to_mat, mat_to_quat
|
| 9 |
+
import os
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import gzip
|
| 13 |
+
import json
|
| 14 |
+
import random
|
| 15 |
+
import logging
|
| 16 |
+
import warnings
|
| 17 |
+
|
| 18 |
+
from lingbot_map.utils.geometry import closed_form_inverse_se3, closed_form_inverse_se3_general
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def extri_intri_to_pose_encoding(
|
| 22 |
+
extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512)
|
| 23 |
+
):
|
| 24 |
+
"""Convert camera extrinsics and intrinsics to a compact pose encoding.
|
| 25 |
+
|
| 26 |
+
This function transforms camera parameters into a unified pose encoding format,
|
| 27 |
+
which can be used for various downstream tasks like pose prediction or representation.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
|
| 31 |
+
where B is batch size and S is sequence length.
|
| 32 |
+
In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
|
| 33 |
+
The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
|
| 34 |
+
intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
|
| 35 |
+
Defined in pixels, with format:
|
| 36 |
+
[[fx, 0, cx],
|
| 37 |
+
[0, fy, cy],
|
| 38 |
+
[0, 0, 1]]
|
| 39 |
+
where fx, fy are focal lengths and (cx, cy) is the principal point
|
| 40 |
+
image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
|
| 41 |
+
Required for computing field of view values. For example: (256, 512).
|
| 42 |
+
pose_encoding_type (str): Type of pose encoding to use. Currently only
|
| 43 |
+
supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
torch.Tensor: Encoded camera pose parameters with shape BxSx9.
|
| 47 |
+
For "absT_quaR_FoV" type, the 9 dimensions are:
|
| 48 |
+
- [:3] = absolute translation vector T (3D)
|
| 49 |
+
- [3:7] = rotation as quaternion quat (4D)
|
| 50 |
+
- [7:] = field of view (2D)
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
# extrinsics: BxSx3x4
|
| 54 |
+
# intrinsics: BxSx3x3
|
| 55 |
+
|
| 56 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
| 57 |
+
R = extrinsics[:, :, :3, :3] # BxSx3x3
|
| 58 |
+
T = extrinsics[:, :, :3, 3] # BxSx3
|
| 59 |
+
|
| 60 |
+
quat = mat_to_quat(R)
|
| 61 |
+
# Note the order of h and w here
|
| 62 |
+
H, W = image_size_hw
|
| 63 |
+
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
|
| 64 |
+
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
|
| 65 |
+
pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
|
| 66 |
+
else:
|
| 67 |
+
raise NotImplementedError
|
| 68 |
+
|
| 69 |
+
return pose_encoding
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def pose_encoding_to_extri_intri(
|
| 73 |
+
pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512)
|
| 74 |
+
):
|
| 75 |
+
"""Convert a pose encoding back to camera extrinsics and intrinsics.
|
| 76 |
+
|
| 77 |
+
This function performs the inverse operation of extri_intri_to_pose_encoding,
|
| 78 |
+
reconstructing the full camera parameters from the compact encoding.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
|
| 82 |
+
where B is batch size and S is sequence length.
|
| 83 |
+
For "absT_quaR_FoV" type, the 9 dimensions are:
|
| 84 |
+
- [:3] = absolute translation vector T (3D)
|
| 85 |
+
- [3:7] = rotation as quaternion quat (4D)
|
| 86 |
+
- [7:] = field of view (2D)
|
| 87 |
+
image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
|
| 88 |
+
Required for reconstructing intrinsics from field of view values.
|
| 89 |
+
For example: (256, 512).
|
| 90 |
+
pose_encoding_type (str): Type of pose encoding used. Currently only
|
| 91 |
+
supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
|
| 92 |
+
build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
|
| 93 |
+
If False, only extrinsics are returned and intrinsics will be None.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
tuple: (extrinsics, intrinsics)
|
| 97 |
+
- extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
|
| 98 |
+
In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
|
| 99 |
+
transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
|
| 100 |
+
a 3x1 translation vector.
|
| 101 |
+
- intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
|
| 102 |
+
or None if build_intrinsics is False. Defined in pixels, with format:
|
| 103 |
+
[[fx, 0, cx],
|
| 104 |
+
[0, fy, cy],
|
| 105 |
+
[0, 0, 1]]
|
| 106 |
+
where fx, fy are focal lengths and (cx, cy) is the principal point,
|
| 107 |
+
assumed to be at the center of the image (W/2, H/2).
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
intrinsics = None
|
| 111 |
+
|
| 112 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
| 113 |
+
T = pose_encoding[..., :3]
|
| 114 |
+
quat = pose_encoding[..., 3:7]
|
| 115 |
+
fov_h = pose_encoding[..., 7]
|
| 116 |
+
fov_w = pose_encoding[..., 8]
|
| 117 |
+
|
| 118 |
+
R = quat_to_mat(quat)
|
| 119 |
+
extrinsics = torch.cat([R, T[..., None]], dim=-1)
|
| 120 |
+
|
| 121 |
+
if build_intrinsics:
|
| 122 |
+
H, W = image_size_hw
|
| 123 |
+
fy = (H / 2.0) / torch.tan(fov_h / 2.0)
|
| 124 |
+
fx = (W / 2.0) / torch.tan(fov_w / 2.0)
|
| 125 |
+
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
|
| 126 |
+
intrinsics[..., 0, 0] = fx
|
| 127 |
+
intrinsics[..., 1, 1] = fy
|
| 128 |
+
intrinsics[..., 0, 2] = W / 2
|
| 129 |
+
intrinsics[..., 1, 2] = H / 2
|
| 130 |
+
intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
|
| 131 |
+
elif pose_encoding_type == "absT_quaR":
|
| 132 |
+
T = pose_encoding[..., :3]
|
| 133 |
+
quat = pose_encoding[..., 3:7]
|
| 134 |
+
|
| 135 |
+
R = quat_to_mat(quat)
|
| 136 |
+
extrinsics = torch.cat([R, T[..., None]], dim=-1)
|
| 137 |
+
|
| 138 |
+
intrinsics = None
|
| 139 |
+
|
| 140 |
+
return extrinsics, intrinsics
|
| 141 |
+
|
| 142 |
+
def convert_pt3d_RT_to_opencv(Rot, Trans):
|
| 143 |
+
"""
|
| 144 |
+
Convert Point3D extrinsic matrices to OpenCV convention.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
Rot: 3D rotation matrix in Point3D format
|
| 148 |
+
Trans: 3D translation vector in Point3D format
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
extri_opencv: 3x4 extrinsic matrix in OpenCV format
|
| 152 |
+
"""
|
| 153 |
+
rot_pt3d = np.array(Rot)
|
| 154 |
+
trans_pt3d = np.array(Trans)
|
| 155 |
+
|
| 156 |
+
trans_pt3d[:2] *= -1
|
| 157 |
+
rot_pt3d[:, :2] *= -1
|
| 158 |
+
rot_pt3d = rot_pt3d.transpose(1, 0)
|
| 159 |
+
extri_opencv = np.hstack((rot_pt3d, trans_pt3d[:, None]))
|
| 160 |
+
return extri_opencv
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def build_pair_index(N, B=1):
|
| 164 |
+
"""
|
| 165 |
+
Build indices for all possible pairs of frames.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
N: Number of frames
|
| 169 |
+
B: Batch size
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
i1, i2: Indices for all possible pairs
|
| 173 |
+
"""
|
| 174 |
+
i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)
|
| 175 |
+
i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]]
|
| 176 |
+
return i1, i2
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15):
|
| 180 |
+
"""
|
| 181 |
+
Calculate rotation angle error between ground truth and predicted rotations.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
rot_gt: Ground truth rotation matrices
|
| 185 |
+
rot_pred: Predicted rotation matrices
|
| 186 |
+
batch_size: Batch size for reshaping the result
|
| 187 |
+
eps: Small value to avoid numerical issues
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
Rotation angle error in degrees
|
| 191 |
+
"""
|
| 192 |
+
q_pred = mat_to_quat(rot_pred)
|
| 193 |
+
q_gt = mat_to_quat(rot_gt)
|
| 194 |
+
|
| 195 |
+
loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps)
|
| 196 |
+
err_q = torch.arccos(1 - 2 * loss_q)
|
| 197 |
+
|
| 198 |
+
rel_rangle_deg = err_q * 180 / np.pi
|
| 199 |
+
|
| 200 |
+
if batch_size is not None:
|
| 201 |
+
rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1)
|
| 202 |
+
|
| 203 |
+
return rel_rangle_deg
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True):
|
| 207 |
+
"""
|
| 208 |
+
Calculate translation angle error between ground truth and predicted translations.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
tvec_gt: Ground truth translation vectors
|
| 212 |
+
tvec_pred: Predicted translation vectors
|
| 213 |
+
batch_size: Batch size for reshaping the result
|
| 214 |
+
ambiguity: Whether to handle direction ambiguity
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
Translation angle error in degrees
|
| 218 |
+
"""
|
| 219 |
+
rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred)
|
| 220 |
+
rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi
|
| 221 |
+
|
| 222 |
+
if ambiguity:
|
| 223 |
+
rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs())
|
| 224 |
+
|
| 225 |
+
if batch_size is not None:
|
| 226 |
+
rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1)
|
| 227 |
+
|
| 228 |
+
return rel_tangle_deg
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6):
|
| 232 |
+
"""
|
| 233 |
+
Normalize the translation vectors and compute the angle between them.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
t_gt: Ground truth translation vectors
|
| 237 |
+
t: Predicted translation vectors
|
| 238 |
+
eps: Small value to avoid division by zero
|
| 239 |
+
default_err: Default error value for invalid cases
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
Angular error between translation vectors in radians
|
| 243 |
+
"""
|
| 244 |
+
t_norm = torch.norm(t, dim=1, keepdim=True)
|
| 245 |
+
t = t / (t_norm + eps)
|
| 246 |
+
|
| 247 |
+
t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True)
|
| 248 |
+
t_gt = t_gt / (t_gt_norm + eps)
|
| 249 |
+
|
| 250 |
+
loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps)
|
| 251 |
+
err_t = torch.acos(torch.sqrt(1 - loss_t))
|
| 252 |
+
|
| 253 |
+
err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err
|
| 254 |
+
return err_t
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def calculate_auc_np(r_error, t_error, max_threshold=30):
|
| 258 |
+
"""
|
| 259 |
+
Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
r_error: numpy array representing R error values (Degree)
|
| 263 |
+
t_error: numpy array representing T error values (Degree)
|
| 264 |
+
max_threshold: Maximum threshold value for binning the histogram
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
AUC value and the normalized histogram
|
| 268 |
+
"""
|
| 269 |
+
error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1)
|
| 270 |
+
max_errors = np.max(error_matrix, axis=1)
|
| 271 |
+
bins = np.arange(max_threshold + 1)
|
| 272 |
+
histogram, _ = np.histogram(max_errors, bins=bins)
|
| 273 |
+
num_pairs = float(len(max_errors))
|
| 274 |
+
normalized_histogram = histogram.astype(float) / num_pairs
|
| 275 |
+
return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames):
|
| 279 |
+
"""
|
| 280 |
+
Compute rotation and translation errors between predicted and ground truth poses.
|
| 281 |
+
This function assumes the input poses are world-to-camera (w2c) transformations.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
pred_se3: Predicted SE(3) transformations (w2c), shape (N, 4, 4)
|
| 285 |
+
gt_se3: Ground truth SE(3) transformations (w2c), shape (N, 4, 4)
|
| 286 |
+
num_frames: Number of frames (N)
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
Rotation and translation angle errors in degrees
|
| 290 |
+
"""
|
| 291 |
+
pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames)
|
| 292 |
+
|
| 293 |
+
relative_pose_gt = gt_se3[pair_idx_i1].bmm(
|
| 294 |
+
closed_form_inverse_se3(gt_se3[pair_idx_i2])
|
| 295 |
+
)
|
| 296 |
+
relative_pose_pred = pred_se3[pair_idx_i1].bmm(
|
| 297 |
+
closed_form_inverse_se3(pred_se3[pair_idx_i2])
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
rel_rangle_deg = rotation_angle(
|
| 301 |
+
relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3]
|
| 302 |
+
)
|
| 303 |
+
rel_tangle_deg = translation_angle(
|
| 304 |
+
relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3]
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
return rel_rangle_deg, rel_tangle_deg
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def colmap_to_opencv_intrinsics(K):
|
| 311 |
+
"""
|
| 312 |
+
Modify camera intrinsics to follow a different convention.
|
| 313 |
+
Coordinates of the center of the top-left pixels are by default:
|
| 314 |
+
- (0.5, 0.5) in Colmap
|
| 315 |
+
- (0,0) in OpenCV
|
| 316 |
+
"""
|
| 317 |
+
K = K.copy()
|
| 318 |
+
K[..., 0, 2] -= 0.5
|
| 319 |
+
K[..., 1, 2] -= 0.5
|
| 320 |
+
return K
|
| 321 |
+
|
| 322 |
+
def read_camera_parameters(filename):
|
| 323 |
+
with open(filename) as f:
|
| 324 |
+
lines = f.readlines()
|
| 325 |
+
lines = [line.rstrip() for line in lines]
|
| 326 |
+
# extrinsics: line [1,5), 4x4 matrix
|
| 327 |
+
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
|
| 328 |
+
# intrinsics: line [7-10), 3x3 matrix
|
| 329 |
+
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
|
| 330 |
+
|
| 331 |
+
return intrinsics, extrinsics
|
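
The encoding and decoding helpers above are inverses of each other up to numerical precision; a hedged round-trip sketch, assuming the repo's import path and simple identity test poses:

# Round-trip check: extrinsics/intrinsics -> 9D pose encoding -> extrinsics/intrinsics.
import torch
from lingbot_map.utils.pose_enc import extri_intri_to_pose_encoding, pose_encoding_to_extri_intri

B, S, H, W = 1, 2, 256, 512
extrinsics = torch.eye(4)[:3].expand(B, S, 3, 4).clone()        # identity camera-from-world poses
intrinsics = torch.tensor([[400.0, 0.0, W / 2],
                           [0.0, 400.0, H / 2],
                           [0.0, 0.0, 1.0]]).expand(B, S, 3, 3).clone()

enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))
extri_rec, intri_rec = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))
assert torch.allclose(extri_rec, extrinsics, atol=1e-5)
assert torch.allclose(intri_rec, intrinsics, atol=1e-3)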
lingbot_map/utils/rotation.py
ADDED
|
@@ -0,0 +1,132 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
"""
|
| 16 |
+
Quaternion Order: XYZW or say ijkr, scalar-last
|
| 17 |
+
|
| 18 |
+
Convert rotations given as quaternions to rotation matrices.
|
| 19 |
+
Args:
|
| 20 |
+
quaternions: quaternions with real part last,
|
| 21 |
+
as tensor of shape (..., 4).
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 25 |
+
"""
|
| 26 |
+
i, j, k, r = torch.unbind(quaternions, -1)
|
| 27 |
+
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
| 28 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 29 |
+
|
| 30 |
+
o = torch.stack(
|
| 31 |
+
(
|
| 32 |
+
1 - two_s * (j * j + k * k),
|
| 33 |
+
two_s * (i * j - k * r),
|
| 34 |
+
two_s * (i * k + j * r),
|
| 35 |
+
two_s * (i * j + k * r),
|
| 36 |
+
1 - two_s * (i * i + k * k),
|
| 37 |
+
two_s * (j * k - i * r),
|
| 38 |
+
two_s * (i * k - j * r),
|
| 39 |
+
two_s * (j * k + i * r),
|
| 40 |
+
1 - two_s * (i * i + j * j),
|
| 41 |
+
),
|
| 42 |
+
-1,
|
| 43 |
+
)
|
| 44 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
"""
|
| 49 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
quaternions with real part last, as tensor of shape (..., 4).
|
| 56 |
+
Quaternion Order: XYZW or say ijkr, scalar-last
|
| 57 |
+
"""
|
| 58 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 59 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
| 60 |
+
|
| 61 |
+
batch_dim = matrix.shape[:-2]
|
| 62 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
|
| 63 |
+
|
| 64 |
+
q_abs = _sqrt_positive_part(
|
| 65 |
+
torch.stack(
|
| 66 |
+
[1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
|
| 67 |
+
)
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
| 71 |
+
quat_by_rijk = torch.stack(
|
| 72 |
+
[
|
| 73 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 74 |
+
# `int`.
|
| 75 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
| 76 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 77 |
+
# `int`.
|
| 78 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
| 79 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 80 |
+
# `int`.
|
| 81 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
| 82 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 83 |
+
# `int`.
|
| 84 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
| 85 |
+
],
|
| 86 |
+
dim=-2,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
| 90 |
+
# the candidate won't be picked.
|
| 91 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 92 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 93 |
+
|
| 94 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
| 95 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
| 96 |
+
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
|
| 97 |
+
|
| 98 |
+
# Convert from rijk to ijkr
|
| 99 |
+
out = out[..., [1, 2, 3, 0]]
|
| 100 |
+
|
| 101 |
+
out = standardize_quaternion(out)
|
| 102 |
+
|
| 103 |
+
return out
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
| 107 |
+
"""
|
| 108 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 109 |
+
but with a zero subgradient where x is 0.
|
| 110 |
+
"""
|
| 111 |
+
ret = torch.zeros_like(x)
|
| 112 |
+
positive_mask = x > 0
|
| 113 |
+
if torch.is_grad_enabled():
|
| 114 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 115 |
+
else:
|
| 116 |
+
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
| 117 |
+
return ret
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
"""
|
| 122 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 123 |
+
part is non-negative.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
quaternions: Quaternions with real part last,
|
| 127 |
+
as tensor of shape (..., 4).
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 131 |
+
"""
|
| 132 |
+
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
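
A quick consistency sketch for the scalar-last (XYZW) quaternion helpers above; synthetic data and the import path are assumptions for illustration.

# Quaternion <-> rotation-matrix round trip on random unit quaternions.
import torch
from lingbot_map.utils.rotation import quat_to_mat, mat_to_quat, standardize_quaternion  # assumed path

torch.manual_seed(0)
q = torch.nn.functional.normalize(torch.randn(8, 4), dim=-1)  # random unit quaternions, XYZW order
q = standardize_quaternion(q)                 # make the real (last) component non-negative

R = quat_to_mat(q)                            # (8, 3, 3) rotation matrices
q_rec = mat_to_quat(R)                        # back to scalar-last quaternions
assert torch.allclose(q_rec, q, atol=1e-5)    # round trip agrees up to numerical error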
lingbot_map/vis/__init__.py
ADDED
|
@@ -0,0 +1,59 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
GCT Visualization Module
|
| 9 |
+
|
| 10 |
+
This module provides visualization utilities for 3D reconstruction results:
|
| 11 |
+
- PointCloudViewer: Interactive point cloud viewer with camera visualization
|
| 12 |
+
- viser_wrapper: Quick visualization wrapper for predictions
|
| 13 |
+
- predictions_to_glb: Export predictions to GLB 3D format
|
| 14 |
+
- Colorization and utility functions
|
| 15 |
+
|
| 16 |
+
Usage:
|
| 17 |
+
from lingbot_map.vis import PointCloudViewer, viser_wrapper, predictions_to_glb
|
| 18 |
+
|
| 19 |
+
# Interactive visualization
|
| 20 |
+
viewer = PointCloudViewer(pred_dict=predictions, port=8080)
|
| 21 |
+
viewer.run()
|
| 22 |
+
|
| 23 |
+
# Quick visualization
|
| 24 |
+
viser_wrapper(predictions, port=8080)
|
| 25 |
+
|
| 26 |
+
# Export to GLB
|
| 27 |
+
scene = predictions_to_glb(predictions)
|
| 28 |
+
scene.export("output.glb")
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from lingbot_map.vis.point_cloud_viewer import PointCloudViewer
|
| 32 |
+
from lingbot_map.vis.viser_wrapper import viser_wrapper
|
| 33 |
+
from lingbot_map.vis.utils import CameraState, colorize, colorize_np, get_vertical_colorbar
|
| 34 |
+
from lingbot_map.vis.sky_segmentation import (
|
| 35 |
+
apply_sky_segmentation,
|
| 36 |
+
download_skyseg_model,
|
| 37 |
+
load_or_create_sky_masks,
|
| 38 |
+
segment_sky,
|
| 39 |
+
)
|
| 40 |
+
from lingbot_map.vis.glb_export import predictions_to_glb
|
| 41 |
+
|
| 42 |
+
__all__ = [
|
| 43 |
+
# Main viewer
|
| 44 |
+
"PointCloudViewer",
|
| 45 |
+
# Quick visualization
|
| 46 |
+
"viser_wrapper",
|
| 47 |
+
# GLB export
|
| 48 |
+
"predictions_to_glb",
|
| 49 |
+
# Utilities
|
| 50 |
+
"CameraState",
|
| 51 |
+
"colorize",
|
| 52 |
+
"colorize_np",
|
| 53 |
+
"get_vertical_colorbar",
|
| 54 |
+
# Sky segmentation
|
| 55 |
+
"apply_sky_segmentation",
|
| 56 |
+
"segment_sky",
|
| 57 |
+
"download_skyseg_model",
|
| 58 |
+
"load_or_create_sky_masks",
|
| 59 |
+
]
|
lingbot_map/vis/glb_export.py
ADDED
|
@@ -0,0 +1,509 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
GLB 3D export utilities for GCT predictions.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import copy
|
| 13 |
+
from typing import Optional, Tuple
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import cv2
|
| 17 |
+
import matplotlib
|
| 18 |
+
from scipy.spatial.transform import Rotation
|
| 19 |
+
|
| 20 |
+
from lingbot_map.vis.sky_segmentation import (
|
| 21 |
+
_SKYSEG_INPUT_SIZE,
|
| 22 |
+
_SKYSEG_SOFT_THRESHOLD,
|
| 23 |
+
_mask_to_float,
|
| 24 |
+
_mask_to_uint8,
|
| 25 |
+
_result_map_to_non_sky_conf,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
import trimesh
|
| 30 |
+
except ImportError:
|
| 31 |
+
trimesh = None
|
| 32 |
+
print("trimesh not found. GLB export will not work.")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def predictions_to_glb(
|
| 36 |
+
predictions: dict,
|
| 37 |
+
conf_thres: float = 50.0,
|
| 38 |
+
filter_by_frames: str = "all",
|
| 39 |
+
mask_black_bg: bool = False,
|
| 40 |
+
mask_white_bg: bool = False,
|
| 41 |
+
show_cam: bool = True,
|
| 42 |
+
mask_sky: bool = False,
|
| 43 |
+
target_dir: Optional[str] = None,
|
| 44 |
+
prediction_mode: str = "Predicted Pointmap",
|
| 45 |
+
) -> "trimesh.Scene":
|
| 46 |
+
"""
|
| 47 |
+
Converts GCT predictions to a 3D scene represented as a GLB file.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
predictions: Dictionary containing model predictions with keys:
|
| 51 |
+
- world_points: 3D point coordinates (S, H, W, 3)
|
| 52 |
+
- world_points_conf: Confidence scores (S, H, W)
|
| 53 |
+
- images: Input images (S, H, W, 3) or (S, 3, H, W)
|
| 54 |
+
- extrinsic: Camera extrinsic matrices (S, 3, 4)
|
| 55 |
+
conf_thres: Percentage of low-confidence points to filter out
|
| 56 |
+
filter_by_frames: Frame filter specification ("all" or frame index)
|
| 57 |
+
mask_black_bg: Mask out black background pixels
|
| 58 |
+
mask_white_bg: Mask out white background pixels
|
| 59 |
+
show_cam: Include camera visualization
|
| 60 |
+
mask_sky: Apply sky segmentation mask
|
| 61 |
+
target_dir: Output directory for intermediate files
|
| 62 |
+
prediction_mode: "Predicted Pointmap" or "Predicted Depthmap"
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
trimesh.Scene: Processed 3D scene containing point cloud and cameras
|
| 66 |
+
|
| 67 |
+
Raises:
|
| 68 |
+
ValueError: If input predictions structure is invalid
|
| 69 |
+
ImportError: If trimesh is not available
|
| 70 |
+
"""
|
| 71 |
+
if trimesh is None:
|
| 72 |
+
raise ImportError("trimesh is required for GLB export. Install with: pip install trimesh")
|
| 73 |
+
|
| 74 |
+
if not isinstance(predictions, dict):
|
| 75 |
+
raise ValueError("predictions must be a dictionary")
|
| 76 |
+
|
| 77 |
+
if conf_thres is None:
|
| 78 |
+
conf_thres = 10.0
|
| 79 |
+
|
| 80 |
+
print("Building GLB scene")
|
| 81 |
+
|
| 82 |
+
# Parse frame filter
|
| 83 |
+
selected_frame_idx = None
|
| 84 |
+
if filter_by_frames != "all" and filter_by_frames != "All":
|
| 85 |
+
try:
|
| 86 |
+
selected_frame_idx = int(filter_by_frames.split(":")[0])
|
| 87 |
+
except (ValueError, IndexError):
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
# Select prediction source
|
| 91 |
+
if "Pointmap" in prediction_mode:
|
| 92 |
+
print("Using Pointmap Branch")
|
| 93 |
+
if "world_points" in predictions:
|
| 94 |
+
pred_world_points = predictions["world_points"]
|
| 95 |
+
pred_world_points_conf = predictions.get(
|
| 96 |
+
"world_points_conf", np.ones_like(pred_world_points[..., 0])
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
print("Warning: world_points not found, falling back to depth-based points")
|
| 100 |
+
pred_world_points = predictions["world_points_from_depth"]
|
| 101 |
+
pred_world_points_conf = predictions.get(
|
| 102 |
+
"depth_conf", np.ones_like(pred_world_points[..., 0])
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
print("Using Depthmap and Camera Branch")
|
| 106 |
+
pred_world_points = predictions["world_points_from_depth"]
|
| 107 |
+
pred_world_points_conf = predictions.get(
|
| 108 |
+
"depth_conf", np.ones_like(pred_world_points[..., 0])
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
images = predictions["images"]
|
| 112 |
+
camera_matrices = predictions["extrinsic"]
|
| 113 |
+
|
| 114 |
+
# Apply sky segmentation if enabled
|
| 115 |
+
if mask_sky and target_dir is not None:
|
| 116 |
+
pred_world_points_conf = _apply_sky_mask(
|
| 117 |
+
pred_world_points_conf, target_dir, images
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Apply frame filter
|
| 121 |
+
if selected_frame_idx is not None:
|
| 122 |
+
pred_world_points = pred_world_points[selected_frame_idx][None]
|
| 123 |
+
pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
|
| 124 |
+
images = images[selected_frame_idx][None]
|
| 125 |
+
camera_matrices = camera_matrices[selected_frame_idx][None]
|
| 126 |
+
|
| 127 |
+
# Prepare vertices and colors
|
| 128 |
+
vertices_3d = pred_world_points.reshape(-1, 3)
|
| 129 |
+
|
| 130 |
+
# Handle different image formats
|
| 131 |
+
if images.ndim == 4 and images.shape[1] == 3: # NCHW format
|
| 132 |
+
colors_rgb = np.transpose(images, (0, 2, 3, 1))
|
| 133 |
+
else:
|
| 134 |
+
colors_rgb = images
|
| 135 |
+
colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
|
| 136 |
+
|
| 137 |
+
# Apply confidence filtering
|
| 138 |
+
conf = pred_world_points_conf.reshape(-1)
|
| 139 |
+
conf_threshold = np.percentile(conf, conf_thres) if conf_thres > 0 else 0.0
|
| 140 |
+
conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
|
| 141 |
+
|
| 142 |
+
# Apply background masking
|
| 143 |
+
if mask_black_bg:
|
| 144 |
+
black_bg_mask = colors_rgb.sum(axis=1) >= 16
|
| 145 |
+
conf_mask = conf_mask & black_bg_mask
|
| 146 |
+
|
| 147 |
+
if mask_white_bg:
|
| 148 |
+
white_bg_mask = ~(
|
| 149 |
+
(colors_rgb[:, 0] > 240) &
|
| 150 |
+
(colors_rgb[:, 1] > 240) &
|
| 151 |
+
(colors_rgb[:, 2] > 240)
|
| 152 |
+
)
|
| 153 |
+
conf_mask = conf_mask & white_bg_mask
|
| 154 |
+
|
| 155 |
+
vertices_3d = vertices_3d[conf_mask]
|
| 156 |
+
colors_rgb = colors_rgb[conf_mask]
|
| 157 |
+
|
| 158 |
+
# Handle empty point cloud
|
| 159 |
+
if vertices_3d is None or np.asarray(vertices_3d).size == 0:
|
| 160 |
+
vertices_3d = np.array([[1, 0, 0]])
|
| 161 |
+
colors_rgb = np.array([[255, 255, 255]])
|
| 162 |
+
scene_scale = 1
|
| 163 |
+
else:
|
| 164 |
+
lower_percentile = np.percentile(vertices_3d, 5, axis=0)
|
| 165 |
+
upper_percentile = np.percentile(vertices_3d, 95, axis=0)
|
| 166 |
+
scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
|
| 167 |
+
|
| 168 |
+
colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
|
| 169 |
+
|
| 170 |
+
# Build scene
|
| 171 |
+
scene_3d = trimesh.Scene()
|
| 172 |
+
point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
|
| 173 |
+
scene_3d.add_geometry(point_cloud_data)
|
| 174 |
+
|
| 175 |
+
# Prepare camera matrices
|
| 176 |
+
num_cameras = len(camera_matrices)
|
| 177 |
+
extrinsics_matrices = np.zeros((num_cameras, 4, 4))
|
| 178 |
+
extrinsics_matrices[:, :3, :4] = camera_matrices
|
| 179 |
+
extrinsics_matrices[:, 3, 3] = 1
|
| 180 |
+
|
| 181 |
+
# Add cameras
|
| 182 |
+
if show_cam:
|
| 183 |
+
for i in range(num_cameras):
|
| 184 |
+
world_to_camera = extrinsics_matrices[i]
|
| 185 |
+
camera_to_world = np.linalg.inv(world_to_camera)
|
| 186 |
+
rgba_color = colormap(i / num_cameras)
|
| 187 |
+
current_color = tuple(int(255 * x) for x in rgba_color[:3])
|
| 188 |
+
integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
|
| 189 |
+
|
| 190 |
+
# Align scene
|
| 191 |
+
scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
|
| 192 |
+
|
| 193 |
+
print("GLB Scene built")
|
| 194 |
+
return scene_3d
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _apply_sky_mask(
|
| 198 |
+
conf: np.ndarray,
|
| 199 |
+
target_dir: str,
|
| 200 |
+
images: np.ndarray
|
| 201 |
+
) -> np.ndarray:
|
| 202 |
+
"""Apply sky segmentation mask to confidence scores."""
|
| 203 |
+
try:
|
| 204 |
+
import onnxruntime
|
| 205 |
+
except ImportError:
|
| 206 |
+
print("Warning: onnxruntime not available, skipping sky masking")
|
| 207 |
+
return conf
|
| 208 |
+
|
| 209 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 210 |
+
if not os.path.exists(target_dir_images):
|
| 211 |
+
print(f"Warning: Images directory not found at {target_dir_images}")
|
| 212 |
+
return conf
|
| 213 |
+
|
| 214 |
+
image_list = sorted(os.listdir(target_dir_images))
|
| 215 |
+
S, H, W = conf.shape if hasattr(conf, "shape") else (len(images), images.shape[1], images.shape[2])
|
| 216 |
+
|
| 217 |
+
skyseg_model_path = "skyseg.onnx"
|
| 218 |
+
if not os.path.exists(skyseg_model_path):
|
| 219 |
+
print("Downloading skyseg.onnx...")
|
| 220 |
+
download_file_from_url(
|
| 221 |
+
"https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
|
| 222 |
+
skyseg_model_path
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
|
| 226 |
+
sky_mask_list = []
|
| 227 |
+
|
| 228 |
+
    for i, image_name in enumerate(image_list[:S]):
        image_filepath = os.path.join(target_dir_images, image_name)
        mask_filepath = os.path.join(target_dir, "sky_masks", image_name)

        if os.path.exists(mask_filepath):
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        else:
            sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)

        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_LINEAR)

        sky_mask_list.append(_mask_to_float(sky_mask))

    sky_mask_array = np.array(sky_mask_list)
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    return conf * sky_mask_binary


def integrate_camera_into_scene(
    scene: "trimesh.Scene",
    transform: np.ndarray,
    face_colors: Tuple[int, int, int],
    scene_scale: float,
    frustum_thickness: float = 1.0,
):
    """
    Integrates a camera mesh into the 3D scene.

    Args:
        scene: The 3D scene to add the camera model
        transform: Transformation matrix for camera positioning
        face_colors: RGB color tuple for the camera
        scene_scale: Scale of the scene
        frustum_thickness: Multiplier for frustum edge thickness (>1 = thicker)
    """
    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1

    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    rot_45_degree[2, 3] = -cam_height

    opengl_transform = get_opengl_conversion_matrix()
    complete_transform = transform @ opengl_transform @ rot_45_degree
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)

    # Build thicker frustum by stacking rotated copies
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()

    shell_scales = [1.0, 0.95]
    shell_transforms = [np.eye(4), slight_rotation]
    # Add extra shells for thickness
    if frustum_thickness > 1.0:
        n_extra = max(1, int(frustum_thickness - 1))
        for k in range(1, n_extra + 1):
            # Progressively rotated and scaled copies
            angle = 2.0 + k * 2.0
            scale = 1.0 + k * 0.02
            rot = np.eye(4)
            rot[:3, :3] = Rotation.from_euler("z", angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot)
            rot_neg = np.eye(4)
            rot_neg[:3, :3] = Rotation.from_euler("z", -angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot_neg)

    vertices_parts = []
    for s, t_mat in zip(shell_scales, shell_transforms):
        vertices_parts.append(
            transform_points(t_mat, s * camera_cone_shape.vertices)
        )
    vertices_combined = np.concatenate(vertices_parts)
    vertices_transformed = transform_points(complete_transform, vertices_combined)

    mesh_faces = compute_camera_faces_multi(camera_cone_shape, len(shell_scales))
    camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)
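For reference, a minimal usage sketch of integrate_camera_into_scene (illustrative only, not from the committed file; it assumes the module's own imports such as trimesh and scipy's Rotation are in scope, and the pose, color, and output path below are made up):

import numpy as np
import trimesh

scene = trimesh.Scene()
c2w = np.eye(4)                          # camera-to-world pose: identity rotation
c2w[:3, 3] = np.array([0.0, 0.0, 1.0])   # camera placed 1 unit along +Z
integrate_camera_into_scene(
    scene,
    transform=c2w,
    face_colors=(255, 0, 0),             # red frustum
    scene_scale=1.0,                      # frustum size is a few percent of this
    frustum_thickness=3.0,                # extra shells make the edges look thicker
)
scene.export("camera_only.glb")          # hypothetical output path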

def apply_scene_alignment(
    scene_3d: "trimesh.Scene",
    extrinsics_matrices: np.ndarray
) -> "trimesh.Scene":
    """
    Aligns the 3D scene based on the extrinsics of the first camera.

    Args:
        scene_3d: The 3D scene to be aligned
        extrinsics_matrices: Camera extrinsic matrices

    Returns:
        Aligned 3D scene
    """
    opengl_conversion_matrix = get_opengl_conversion_matrix()

    align_rotation = np.eye(4)
    align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()

    initial_transformation = (
        np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation
    )
    scene_3d.apply_transform(initial_transformation)
    return scene_3d


def get_opengl_conversion_matrix() -> np.ndarray:
    """Returns the OpenGL conversion matrix (flips Y and Z axes)."""
    matrix = np.identity(4)
    matrix[1, 1] = -1
    matrix[2, 2] = -1
    return matrix


def transform_points(
    transformation: np.ndarray,
    points: np.ndarray,
    dim: Optional[int] = None
) -> np.ndarray:
    """
    Applies a 4x4 transformation to a set of points.

    Args:
        transformation: Transformation matrix
        points: Points to be transformed
        dim: Dimension for reshaping the result

    Returns:
        Transformed points
    """
    points = np.asarray(points)
    initial_shape = points.shape[:-1]
    dim = dim or points.shape[-1]

    transformation = transformation.swapaxes(-1, -2)
    points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]

    return points[..., :dim].reshape(*initial_shape, dim)
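As a quick sanity check of the convention transform_points uses (row-vector points times the transposed matrix, plus the translation row), a pure translation should simply shift every point. An illustrative snippet, assuming only numpy and the function above:

import numpy as np

T = np.eye(4)
T[:3, 3] = [1.0, 2.0, 3.0]              # translate by (1, 2, 3)
pts = np.array([[0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0]])
out = transform_points(T, pts)
print(out)                              # expected: [[1, 2, 3], [2, 3, 4]]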

def compute_camera_faces(cone_shape: "trimesh.Trimesh") -> np.ndarray:
    """Computes the faces for the camera mesh."""
    faces_list = []
    num_vertices_cone = len(cone_shape.vertices)

    for face in cone_shape.faces:
        if 0 in face:
            continue
        v1, v2, v3 = face
        v1_offset, v2_offset, v3_offset = face + num_vertices_cone
        v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone

        faces_list.extend([
            (v1, v2, v2_offset),
            (v1, v1_offset, v3),
            (v3_offset, v2, v3),
            (v1, v2, v2_offset_2),
            (v1, v1_offset_2, v3),
            (v3_offset_2, v2, v3),
        ])

    faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
    return np.array(faces_list)


def compute_camera_faces_multi(cone_shape: "trimesh.Trimesh", num_shells: int) -> np.ndarray:
    """Computes faces for a camera mesh with multiple shells (for thicker frustums).

    Connects each consecutive pair of vertex shells to form the frustum edges.
    """
    faces_list = []
    nv = len(cone_shape.vertices)

    for s in range(num_shells - 1):
        off_a = s * nv
        off_b = (s + 1) * nv
        for face in cone_shape.faces:
            if 0 in face:
                continue
            v1, v2, v3 = face
            faces_list.extend([
                (v1 + off_a, v2 + off_a, v2 + off_b),
                (v1 + off_a, v1 + off_b, v3 + off_a),
                (v3 + off_b, v2 + off_a, v3 + off_a),
            ])

    faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
    return np.array(faces_list)


def segment_sky(
    image_path: str,
    onnx_session,
    mask_filename: str
) -> np.ndarray:
    """
    Segments sky from an image using an ONNX model.

    Args:
        image_path: Path to input image
        onnx_session: ONNX runtime session with loaded model
        mask_filename: Path to save the output mask

    Returns:
        Continuous non-sky confidence map in [0, 1]
    """
    image = cv2.imread(image_path)
    result_map = run_skyseg(onnx_session, _SKYSEG_INPUT_SIZE, image)
    result_map_original = cv2.resize(
        result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR
    )
    output_mask = _result_map_to_non_sky_conf(result_map_original)

    os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
    cv2.imwrite(mask_filename, _mask_to_uint8(output_mask))
    return output_mask


def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray
) -> np.ndarray:
    """
    Runs sky segmentation inference using ONNX model.

    Args:
        onnx_session: ONNX runtime session
        input_size: Target size for model input (width, height)
        image: Input image in BGR format

    Returns:
        Segmentation mask
    """
    temp_image = copy.deepcopy(image)
    resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    x = x.transpose(2, 0, 1)
    x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    onnx_result = (onnx_result - min_value) / (max_value - min_value)
    onnx_result *= 255
    return onnx_result.astype("uint8")
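The two helpers above expect an onnxruntime session for the sky-segmentation model. A hedged usage sketch (the checkpoint and image paths are placeholders, and _SKYSEG_INPUT_SIZE is the module-level constant that run_skyseg receives):

import onnxruntime

# Hypothetical checkpoint path; the Space downloads its own skyseg ONNX model.
skyseg_session = onnxruntime.InferenceSession(
    "skyseg.onnx", providers=["CPUExecutionProvider"]
)
non_sky_conf = segment_sky(
    "images/frame_000.png", skyseg_session, "sky_masks/frame_000.png"
)
print(non_sky_conf.shape, non_sky_conf.min(), non_sky_conf.max())  # confidence in [0, 1]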

def download_file_from_url(url: str, filename: str):
    """Downloads a file from a URL, handling redirects."""
    import requests

    try:
        response = requests.get(url, allow_redirects=False)
        response.raise_for_status()

        if response.status_code == 302:
            redirect_url = response.headers["Location"]
            response = requests.get(redirect_url, stream=True)
            response.raise_for_status()
        else:
            print(f"Unexpected status code: {response.status_code}")
            return

        with open(filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Downloaded {filename} successfully.")

    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
lingbot_map/vis/point_cloud_viewer.py
ADDED
@@ -0,0 +1,1437 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Interactive 3D Point Cloud Viewer using Viser.

This module provides the PointCloudViewer class for visualizing 3D reconstruction results,
including point clouds, camera poses, and animated playback.
"""

import os
import time
import threading
import subprocess
import tempfile
import shutil
from typing import List, Optional, Dict, Any, Tuple

import numpy as np
import torch
import cv2
import matplotlib.cm as cm
from tqdm.auto import tqdm

import viser
import viser.transforms as tf

from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from lingbot_map.vis.utils import CameraState
from lingbot_map.vis.sky_segmentation import apply_sky_segmentation


class PointCloudViewer:
    """
    Interactive 3D point cloud viewer with camera visualization.

    Features:
    - Point cloud visualization with confidence-based filtering
    - Camera frustum visualization with gradient colors
    - Frame-by-frame playback animation (3D/4D modes)
    - Range-based and recent-N-frames visualization modes
    - Video export with FFmpeg

    Args:
        model: Optional model for interactive inference
        state_args: Optional state arguments
        pc_list: List of point clouds per frame
        color_list: List of colors per frame
        conf_list: List of confidence scores per frame
        cam_dict: Camera dictionary with focal, pp, R, t
        image_mask: Optional image mask
        edge_color_list: Optional edge colors
        device: Device for computation
        port: Viser server port
        show_camera: Whether to show camera frustums
        vis_threshold: Visibility threshold for filtering
        size: Image size
        downsample_factor: Point cloud downsample factor
        point_size: Initial point size
        pred_dict: Prediction dictionary (alternative to pc_list/color_list/conf_list)
        init_conf_threshold: Initial confidence threshold percentage
        use_point_map: Use point map instead of depth-based points
        mask_sky: Apply sky segmentation
        image_folder: Path to image folder (for sky segmentation)
    """
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
model=None,
|
| 73 |
+
state_args=None,
|
| 74 |
+
pc_list=None,
|
| 75 |
+
color_list=None,
|
| 76 |
+
conf_list=None,
|
| 77 |
+
cam_dict=None,
|
| 78 |
+
image_mask=None,
|
| 79 |
+
edge_color_list=None,
|
| 80 |
+
device: str = "cpu",
|
| 81 |
+
port: int = 8080,
|
| 82 |
+
show_camera: bool = True,
|
| 83 |
+
vis_threshold: float = 1.0,
|
| 84 |
+
size: int = 512,
|
| 85 |
+
downsample_factor: int = 10,
|
| 86 |
+
point_size: float = 0.00001,
|
| 87 |
+
pred_dict: Optional[Dict] = None,
|
| 88 |
+
init_conf_threshold: float = 50.0,
|
| 89 |
+
use_point_map: bool = False,
|
| 90 |
+
mask_sky: bool = False,
|
| 91 |
+
image_folder: Optional[str] = None,
|
| 92 |
+
sky_mask_dir: Optional[str] = None,
|
| 93 |
+
sky_mask_visualization_dir: Optional[str] = None,
|
| 94 |
+
depth_stride: int = 1,
|
| 95 |
+
):
|
| 96 |
+
self.model = model
|
| 97 |
+
self.size = size
|
| 98 |
+
self.state_args = state_args
|
| 99 |
+
self.server = viser.ViserServer(host="0.0.0.0", port=port)
|
| 100 |
+
self.server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
|
| 101 |
+
self.device = device
|
| 102 |
+
self.conf_list = conf_list
|
| 103 |
+
self.vis_threshold = vis_threshold
|
| 104 |
+
self.point_size = point_size
|
| 105 |
+
self.tt = lambda x: torch.from_numpy(x).float().to(device)
|
| 106 |
+
|
| 107 |
+
# Process the prediction dictionary to create pc_list, color_list, conf_list
|
| 108 |
+
if pred_dict is not None:
|
| 109 |
+
pc_list, color_list, conf_list, cam_dict = self._process_pred_dict(
|
| 110 |
+
pred_dict, use_point_map, mask_sky, image_folder,
|
| 111 |
+
sky_mask_dir=sky_mask_dir,
|
| 112 |
+
sky_mask_visualization_dir=sky_mask_visualization_dir,
|
| 113 |
+
depth_stride=depth_stride,
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
self.original_images = []
|
| 117 |
+
|
| 118 |
+
self.pcs, self.all_steps = self.read_data(
|
| 119 |
+
pc_list, color_list, conf_list, edge_color_list
|
| 120 |
+
)
|
| 121 |
+
self.cam_dict = cam_dict
|
| 122 |
+
self.num_frames = len(self.all_steps)
|
| 123 |
+
self.image_mask = image_mask
|
| 124 |
+
self.show_camera = show_camera
|
| 125 |
+
self.on_replay = False
|
| 126 |
+
self.vis_pts_list = []
|
| 127 |
+
self.traj_list = []
|
| 128 |
+
self.orig_img_list = [x[0] for x in color_list if len(x) > 0] if color_list else []
|
| 129 |
+
self.via_points = []
|
| 130 |
+
|
| 131 |
+
self._setup_gui()
|
| 132 |
+
self.server.on_client_connect(self._connect_client)
|
| 133 |
+
|
| 134 |
+
def _process_pred_dict(
|
| 135 |
+
self,
|
| 136 |
+
pred_dict: Dict,
|
| 137 |
+
use_point_map: bool,
|
| 138 |
+
mask_sky: bool,
|
| 139 |
+
image_folder: Optional[str],
|
| 140 |
+
sky_mask_dir: Optional[str] = None,
|
| 141 |
+
sky_mask_visualization_dir: Optional[str] = None,
|
| 142 |
+
depth_stride: int = 1,
|
| 143 |
+
) -> Tuple[List, List, List, Dict]:
|
| 144 |
+
"""Process prediction dictionary to extract visualization data.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
pred_dict: Model prediction dictionary.
|
| 148 |
+
use_point_map: Use point map instead of depth-based projection.
|
| 149 |
+
mask_sky: Apply sky segmentation to filter sky points.
|
| 150 |
+
image_folder: Path to images for sky segmentation.
|
| 151 |
+
sky_mask_dir: Directory for cached sky masks.
|
| 152 |
+
sky_mask_visualization_dir: Directory for sky mask visualization images.
|
| 153 |
+
depth_stride: Only project depth to point cloud every N frames.
|
| 154 |
+
Frames not projected will have empty point clouds but still
|
| 155 |
+
show camera frustums and images. 1 = every frame (default).
|
| 156 |
+
"""
|
| 157 |
+
images = pred_dict["images"] # (S, 3, H, W)
|
| 158 |
+
|
| 159 |
+
depth_map = pred_dict.get("depth") # (S, H, W, 1)
|
| 160 |
+
depth_conf = pred_dict.get("depth_conf") # (S, H, W)
|
| 161 |
+
|
| 162 |
+
extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4)
|
| 163 |
+
intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3)
|
| 164 |
+
|
| 165 |
+
# Compute world points from depth if not using the precomputed point map
|
| 166 |
+
if not use_point_map:
|
| 167 |
+
world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
|
| 168 |
+
conf = depth_conf
|
| 169 |
+
else:
|
| 170 |
+
world_points = pred_dict["world_points"] # (S, H, W, 3)
|
| 171 |
+
conf = pred_dict.get("world_points_conf", depth_conf) # (S, H, W)
|
| 172 |
+
|
| 173 |
+
# Apply sky segmentation if enabled
|
| 174 |
+
if mask_sky:
|
| 175 |
+
conf = apply_sky_segmentation(
|
| 176 |
+
conf, image_folder=image_folder, images=images,
|
| 177 |
+
sky_mask_dir=sky_mask_dir,
|
| 178 |
+
sky_mask_visualization_dir=sky_mask_visualization_dir,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Convert images from (S, 3, H, W) to (S, H, W, 3)
|
| 182 |
+
colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3)
|
| 183 |
+
S = world_points.shape[0]
|
| 184 |
+
|
| 185 |
+
# Store original images for camera frustum display
|
| 186 |
+
self.original_images = []
|
| 187 |
+
for i in range(S):
|
| 188 |
+
img = images[i] # shape (3, H, W)
|
| 189 |
+
img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
|
| 190 |
+
self.original_images.append(img)
|
| 191 |
+
|
| 192 |
+
# Create lists - apply depth_stride to skip frames for point projection
|
| 193 |
+
H, W = world_points.shape[1], world_points.shape[2]
|
| 194 |
+
pc_list = []
|
| 195 |
+
color_list = []
|
| 196 |
+
conf_list = []
|
| 197 |
+
skipped = 0
|
| 198 |
+
for i in range(S):
|
| 199 |
+
if depth_stride > 1 and i % depth_stride != 0:
|
| 200 |
+
# Empty point cloud for skipped frames
|
| 201 |
+
pc_list.append(np.zeros((0, 0, 3), dtype=np.float32))
|
| 202 |
+
color_list.append(np.zeros((0, 0, 3), dtype=np.float32))
|
| 203 |
+
conf_list.append(np.zeros((0, 0), dtype=np.float32))
|
| 204 |
+
skipped += 1
|
| 205 |
+
else:
|
| 206 |
+
pc_list.append(world_points[i])
|
| 207 |
+
color_list.append(colors[i])
|
| 208 |
+
if conf is not None:
|
| 209 |
+
conf_list.append(conf[i])
|
| 210 |
+
else:
|
| 211 |
+
conf_list.append(np.ones(world_points[i].shape[:2], dtype=np.float32))
|
| 212 |
+
|
| 213 |
+
if depth_stride > 1:
|
| 214 |
+
print(f' depth_stride={depth_stride}: projecting {S - skipped}/{S} frames, skipping {skipped}')
|
| 215 |
+
|
| 216 |
+
# Create camera dictionary (all frames keep cameras)
|
| 217 |
+
cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
|
| 218 |
+
cam_dict = {
|
| 219 |
+
"focal": [intrinsics_cam[i, 0, 0] for i in range(S)],
|
| 220 |
+
"pp": [(intrinsics_cam[i, 0, 2], intrinsics_cam[i, 1, 2]) for i in range(S)],
|
| 221 |
+
"R": [cam_to_world_mat[i, :3, :3] for i in range(S)],
|
| 222 |
+
"t": [cam_to_world_mat[i, :3, 3] for i in range(S)],
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
return pc_list, color_list, conf_list, cam_dict
|
| 226 |
+
|
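To make the expected pred_dict layout concrete, here is a minimal construction sketch with dummy data (shapes follow the inline comments above; this is illustrative only and assumes the rest of the class, e.g. read_data, accepts these defaults, so the resulting viewer just shows noise):

import numpy as np

S, H, W = 2, 64, 64
pred_dict = {
    "images": np.random.rand(S, 3, H, W).astype(np.float32),          # RGB in [0, 1]
    "depth": np.random.rand(S, H, W, 1).astype(np.float32),
    "depth_conf": np.ones((S, H, W), dtype=np.float32),
    "extrinsic": np.tile(np.eye(3, 4, dtype=np.float32), (S, 1, 1)),   # world-to-camera (S, 3, 4)
    "intrinsic": np.tile(np.array([[100.0, 0.0, W / 2],
                                   [0.0, 100.0, H / 2],
                                   [0.0, 0.0, 1.0]], dtype=np.float32), (S, 1, 1)),
}
viewer = PointCloudViewer(pred_dict=pred_dict, port=8080, point_size=0.01)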

    def _compute_scene_center_and_scale(self) -> Tuple[np.ndarray, float]:
        """Compute scene center and scale from camera positions and point clouds.

        Returns:
            Tuple of (center as 3D array, scale as float distance).
        """
        # Use camera positions as primary reference (more reliable than noisy points)
        if self.cam_dict is not None and "t" in self.cam_dict:
            cam_positions = np.array([self.cam_dict["t"][s] for s in self.all_steps])
            center = np.mean(cam_positions, axis=0)
            if len(cam_positions) > 1:
                extent = np.ptp(cam_positions, axis=0)  # range per axis
                scale = np.linalg.norm(extent)
            else:
                scale = 1.0
        else:
            # Fallback: use point cloud data
            all_pts = []
            for step in self.all_steps:
                pc = self.pcs[step]["pc"].reshape(-1, 3)
                # subsample for speed
                if len(pc) > 1000:
                    pc = pc[::len(pc) // 1000]
                all_pts.append(pc)
            all_pts = np.concatenate(all_pts, axis=0)
            center = np.median(all_pts, axis=0)
            extent = np.percentile(all_pts, 95, axis=0) - np.percentile(all_pts, 5, axis=0)
            scale = np.linalg.norm(extent)

        return center, max(scale, 0.1)

    def _reset_view_to_direction(
        self,
        direction: np.ndarray,
        up: np.ndarray = np.array([0.0, -1.0, 0.0]),
        distance_scale: float = 1.5,
        smooth: bool = True,
    ):
        """Reset the viewer camera to look at scene center from a given direction.

        Args:
            direction: Unit vector pointing FROM the scene center TO the camera.
            up: Up vector for the camera.
            distance_scale: Multiplier on scene scale for camera distance.
            smooth: Whether to smoothly transition.
        """
        center, scale = self._compute_scene_center_and_scale()
        distance = scale * distance_scale
        position = center + direction * distance

        for client in self.server.get_clients().values():
            if smooth:
                self._smooth_camera_transition(
                    client,
                    target_position=position,
                    target_look_at=center,
                    target_up=up,
                    duration=0.4,
                )
            else:
                client.camera.up_direction = tuple(up)
                client.camera.position = tuple(position)
                client.camera.look_at = tuple(center)

    def _setup_gui(self):
        """Setup GUI controls."""
        gui_reset_up = self.server.gui.add_button(
            "Reset up direction",
            hint="Set the camera control 'up' direction to the current camera's 'up'.",
        )

        @gui_reset_up.on_click
        def _(event: viser.GuiEvent) -> None:
            client = event.client
            assert client is not None
            client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array(
                [0.0, -1.0, 0.0]
            )

        # Video frame display controls - kept at top so the current frame is always visible
        with self.server.gui.add_folder("Video Display"):
            self.show_video_checkbox = self.server.gui.add_checkbox("Show Current Frame", initial_value=True)
            if hasattr(self, 'original_images') and len(self.original_images) > 0:
                self.current_frame_image = self.server.gui.add_image(
                    self.original_images[0], label="Current Frame"
                )
            else:
                self.current_frame_image = None

        # Preset view direction buttons
        with self.server.gui.add_folder("Reset View Direction"):
            btn_look_at_center = self.server.gui.add_button(
                "Look At Scene Center",
                hint="Reset orbit center to the scene center (fixes orbit after dragging).",
            )
            btn_overview = self.server.gui.add_button(
                "Overview",
                hint="Reset to a 3/4 overview of the scene.",
            )
            btn_front = self.server.gui.add_button(
                "Front (+Z)",
                hint="View scene from the front.",
            )
            btn_back = self.server.gui.add_button(
                "Back (-Z)",
                hint="View scene from the back.",
            )
            btn_top = self.server.gui.add_button(
                "Top (-Y)",
                hint="View scene from above (bird's eye).",
            )
            btn_left = self.server.gui.add_button(
                "Left (-X)",
                hint="View scene from the left.",
            )
            btn_right = self.server.gui.add_button(
                "Right (+X)",
                hint="View scene from the right.",
            )
            btn_first_cam = self.server.gui.add_button(
                "First Camera",
                hint="Reset to the first camera's viewpoint.",
            )

        @btn_look_at_center.on_click
        def _(_) -> None:
            center, _ = self._compute_scene_center_and_scale()
            for client in self.server.get_clients().values():
                client.camera.look_at = tuple(center)

        @btn_overview.on_click
        def _(_) -> None:
            d = np.array([0.5, -0.6, 0.6])
            self._reset_view_to_direction(d / np.linalg.norm(d))

        @btn_front.on_click
        def _(_) -> None:
            self._reset_view_to_direction(np.array([0.0, 0.0, 1.0]))

        @btn_back.on_click
        def _(_) -> None:
            self._reset_view_to_direction(np.array([0.0, 0.0, -1.0]))

        @btn_top.on_click
        def _(_) -> None:
            self._reset_view_to_direction(
                np.array([0.0, -1.0, 0.0]),
                up=np.array([0.0, 0.0, 1.0]),
            )

        @btn_left.on_click
        def _(_) -> None:
            self._reset_view_to_direction(np.array([-1.0, 0.0, 0.0]))

        @btn_right.on_click
        def _(_) -> None:
            self._reset_view_to_direction(np.array([1.0, 0.0, 0.0]))

        @btn_first_cam.on_click
        def _(_) -> None:
            self._move_to_camera(0, smooth=True)

        button3 = self.server.gui.add_button("4D (Only Show Current Frame)")
        button4 = self.server.gui.add_button("3D (Show All Frames)")
        self.is_render = False
        self.fourd = False

        @button3.on_click
        def _(event: viser.GuiEvent) -> None:
            self.fourd = True

        @button4.on_click
        def _(event: viser.GuiEvent) -> None:
            self.fourd = False

        self.focal_slider = self.server.gui.add_slider(
            "Focal Length", min=0.1, max=99999, step=1, initial_value=533
        )
        self.psize_slider = self.server.gui.add_slider(
            "Point Size", min=0.00001, max=0.1, step=0.00001, initial_value=self.point_size
        )
        self.camsize_slider = self.server.gui.add_slider(
            "Camera Size", min=0.01, max=0.5, step=0.01, initial_value=0.1
        )
        self.downsample_slider = self.server.gui.add_slider(
            "Downsample Factor", min=1, max=1000, step=1, initial_value=10
        )
        self.show_camera_checkbox = self.server.gui.add_checkbox(
            "Show Camera", initial_value=self.show_camera
        )
        self.vis_threshold_slider = self.server.gui.add_slider(
            "Visibility Threshold", min=1.0, max=5.0, step=0.01,
            initial_value=self.vis_threshold,
        )
        self.camera_downsample_slider = self.server.gui.add_slider(
            "Camera Downsample Factor", min=1, max=50, step=1, initial_value=1
        )

        # Screenshot controls
        with self.server.gui.add_folder("Screenshot"):
            self.screenshot_button = self.server.gui.add_button("Take Screenshot")
            self.screenshot_resolution = self.server.gui.add_dropdown(
                "Resolution",
                options=["1920x1080", "2560x1440", "3840x2160", "Current"],
                initial_value="1920x1080",
            )
            self.screenshot_path = self.server.gui.add_text(
                "Save Path", initial_value="screenshot.png"
            )
            self.screenshot_status = self.server.gui.add_text(
                "Status", initial_value="Ready"
            )

        @self.screenshot_button.on_click
        def _(event: viser.GuiEvent) -> None:
            self._take_screenshot(event.client)

        # GLB export controls
        with self.server.gui.add_folder("Export GLB"):
            self.glb_output_path = self.server.gui.add_text(
                "Output Path", initial_value="export.glb"
            )
            self.glb_show_cam_checkbox = self.server.gui.add_checkbox(
                "Include Cameras", initial_value=True,
            )
            self.glb_cam_scale_slider = self.server.gui.add_slider(
                "Camera Scale", min=0.01, max=5.0, step=0.01, initial_value=1.0,
                hint="Scale factor for camera size in GLB.",
            )
            self.glb_frustum_thickness_slider = self.server.gui.add_slider(
                "Frustum Thickness", min=1.0, max=10.0, step=0.5, initial_value=3.0,
                hint="Thickness multiplier for camera frustum edges.",
            )
            self.glb_trajectory_checkbox = self.server.gui.add_checkbox(
                "Show Trajectory", initial_value=True,
                hint="Connect cameras with a trajectory line.",
            )
            self.glb_trajectory_radius_slider = self.server.gui.add_slider(
                "Trajectory Radius", min=0.001, max=0.05, step=0.001, initial_value=0.005,
                hint="Radius of the trajectory tube.",
            )
            self.glb_mode_dropdown = self.server.gui.add_dropdown(
                "Export Mode",
                options=["Points", "Spheres"],
                initial_value="Points",
                hint="Points: raw (fast). Spheres: each point becomes a small sphere (prettier, slower).",
            )
            self.glb_sphere_radius_slider = self.server.gui.add_slider(
                "Sphere Radius", min=0.001, max=0.1, step=0.001, initial_value=0.005,
                hint="Radius of each sphere in Spheres mode.",
                disabled=True,
            )
            self.glb_max_sphere_pts_slider = self.server.gui.add_slider(
                "Max Sphere Points", min=10000, max=500000, step=10000, initial_value=100000,
                hint="Cap point count for Spheres mode to keep file size manageable.",
                disabled=True,
            )
            self.glb_opacity_slider = self.server.gui.add_slider(
                "Opacity", min=0.0, max=1.0, step=0.05, initial_value=1.0,
                hint="Point/sphere opacity (alpha). <1.0 = semi-transparent.",
            )
            self.glb_saturation_slider = self.server.gui.add_slider(
                "Saturation Boost", min=0.0, max=2.0, step=0.1, initial_value=1.0,
                hint="Color saturation multiplier. >1 = more vivid, <1 = washed out.",
            )
            self.glb_brightness_slider = self.server.gui.add_slider(
                "Brightness Boost", min=0.5, max=2.0, step=0.1, initial_value=1.0,
                hint="Color brightness multiplier.",
            )
            self.glb_export_button = self.server.gui.add_button(
                "Export GLB",
                hint="Export current filtered point clouds and cameras as GLB.",
            )
            self.glb_status = self.server.gui.add_text("Status", initial_value="Ready")

        @self.glb_mode_dropdown.on_update
        def _(_) -> None:
            is_sphere = self.glb_mode_dropdown.value == "Spheres"
            self.glb_sphere_radius_slider.disabled = not is_sphere
            self.glb_max_sphere_pts_slider.disabled = not is_sphere

        @self.glb_export_button.on_click
        def _(_) -> None:
            self._export_glb()

        # Video saving controls
        with self.server.gui.add_folder("Video Saving"):
            self.save_video_button = self.server.gui.add_button("Save Video", disabled=False)
            self.video_output_path = self.server.gui.add_text("Output Path", initial_value="output_pointcloud.mp4")
            self.video_save_fps = self.server.gui.add_slider("Video FPS", min=10, max=60, step=1, initial_value=30)
            self.video_resolution = self.server.gui.add_dropdown(
                "Resolution", options=["1920x1080", "1280x720", "3840x2160"], initial_value="1920x1080"
            )
            self.save_original_video_checkbox = self.server.gui.add_checkbox("Also Save Original Video", initial_value=True)
            self.video_status = self.server.gui.add_text("Status", initial_value="Ready to save")

        @self.save_video_button.on_click
        def _(_) -> None:
            self.save_video(
                output_path=self.video_output_path.value,
                fps=self.video_save_fps.value,
                resolution=self.video_resolution.value,
                save_original_video=self.save_original_video_checkbox.value
            )

        @self.show_video_checkbox.on_update
        def _(_) -> None:
            if self.current_frame_image is not None:
                self.current_frame_image.visible = self.show_video_checkbox.value

        self.pc_handles = []
        self.cam_handles = []

        @self.psize_slider.on_update
        def _(_) -> None:
            for handle in self.pc_handles:
                handle.point_size = self.psize_slider.value

        @self.camsize_slider.on_update
        def _(_) -> None:
            for handle in self.cam_handles:
                handle.scale = self.camsize_slider.value
                handle.line_thickness = 0.03 * handle.scale

        @self.downsample_slider.on_update
        def _(_) -> None:
            self._regenerate_point_clouds()

        @self.show_camera_checkbox.on_update
        def _(_) -> None:
            self.show_camera = self.show_camera_checkbox.value
            if self.show_camera:
                self._regenerate_cameras()
            else:
                for handle in self.cam_handles:
                    handle.visible = False

        @self.vis_threshold_slider.on_update
        def _(_) -> None:
            self.vis_threshold = self.vis_threshold_slider.value
            self._regenerate_point_clouds()

        @self.camera_downsample_slider.on_update
        def _(_) -> None:
            self._regenerate_cameras()

    def _regenerate_point_clouds(self):
        """Regenerate all point clouds with current settings."""
        if not hasattr(self, 'frame_nodes'):
            return

        for handle in self.pc_handles:
            try:
                handle.remove()
            except (KeyError, AttributeError):
                pass
        self.pc_handles.clear()
        self.vis_pts_list.clear()

        for i, step in enumerate(self.all_steps):
            pc = self.pcs[step]["pc"]
            color = self.pcs[step]["color"]
            conf = self.pcs[step]["conf"]
            edge_color = self.pcs[step].get("edge_color", None)

            pred_pts, pc_color = self.parse_pc_data(
                pc, color, conf, edge_color, set_border_color=True,
                downsample_factor=self.downsample_slider.value
            )

            self.vis_pts_list.append(pred_pts)
            handle = self.server.scene.add_point_cloud(
                name=f"/frames/{step}/pred_pts",
                points=pred_pts,
                colors=pc_color,
                point_size=self.psize_slider.value,
            )
            self.pc_handles.append(handle)

    def _regenerate_cameras(self):
        """Regenerate camera visualizations with current settings."""
        if not hasattr(self, 'frame_nodes'):
            return

        for handle in self.cam_handles:
            try:
                handle.remove()
            except (KeyError, AttributeError):
                pass
        self.cam_handles.clear()

        if self.show_camera:
            downsample_factor = int(self.camera_downsample_slider.value)
            for i, step in enumerate(self.all_steps):
                if i % downsample_factor == 0:
                    self.add_camera(step)

    def _export_glb(self):
        """Export current filtered point clouds and cameras as a GLB file."""
        try:
            import trimesh
        except ImportError:
            self.glb_status.value = "Error: pip install trimesh"
            return

        self.glb_status.value = "Collecting points..."
        print("Exporting GLB...")

        # Collect all currently visible, filtered points and colors
        all_points = []
        all_colors = []
        for step in self.all_steps:
            pc = self.pcs[step]["pc"]
            color = self.pcs[step]["color"]
            conf = self.pcs[step]["conf"]
            edge_color = self.pcs[step].get("edge_color", None)

            pts, cols = self.parse_pc_data(
                pc, color, conf, edge_color, set_border_color=False,
                downsample_factor=self.downsample_slider.value,
            )
            if len(pts) > 0:
                all_points.append(pts)
                if cols.dtype != np.uint8:
                    cols = (np.clip(cols, 0, 1) * 255).astype(np.uint8)
                all_colors.append(cols)

        if not all_points:
            self.glb_status.value = "Error: no points to export"
            return

        vertices = np.concatenate(all_points, axis=0)
        colors_rgb = np.concatenate(all_colors, axis=0)

        # --- Color enhancement ---
        colors_float = colors_rgb.astype(np.float32) / 255.0

        sat_boost = self.glb_saturation_slider.value
        if sat_boost != 1.0:
            gray = colors_float.mean(axis=1, keepdims=True)
            colors_float = gray + sat_boost * (colors_float - gray)

        bri_boost = self.glb_brightness_slider.value
        if bri_boost != 1.0:
            colors_float = colors_float * bri_boost

        colors_float = np.clip(colors_float, 0.0, 1.0)

        # --- Opacity ---
        # Simulate opacity by blending colors toward white (works in all viewers).
        # For Spheres mode, also set true alpha for viewers that support it.
        alpha = self.glb_opacity_slider.value
        if alpha < 1.0:
            bg = np.ones_like(colors_float)  # white background
            colors_float = colors_float * alpha + bg * (1.0 - alpha)
            colors_float = np.clip(colors_float, 0.0, 1.0)

        colors_u8 = (colors_float * 255).astype(np.uint8)
        colors_rgba = np.concatenate([
            colors_u8,
            np.full((len(colors_u8), 1), int(alpha * 255), dtype=np.uint8),
        ], axis=1)  # (N, 4)

        # Compute scene scale for camera sizing
        lo = np.percentile(vertices, 5, axis=0)
        hi = np.percentile(vertices, 95, axis=0)
        scene_scale = max(np.linalg.norm(hi - lo), 0.1)

        scene_3d = trimesh.Scene()

        # --- Export mode ---
        export_mode = self.glb_mode_dropdown.value
        if export_mode == "Spheres":
            self.glb_status.value = "Building spheres..."
            max_pts = int(self.glb_max_sphere_pts_slider.value)
            radius = self.glb_sphere_radius_slider.value

            # Subsample if too many points
            if len(vertices) > max_pts:
                idx = np.random.choice(len(vertices), max_pts, replace=False)
                idx.sort()
                vertices = vertices[idx]
                colors_rgba = colors_rgba[idx]

            sphere_template = trimesh.creation.icosphere(subdivisions=1, radius=radius)
            n_verts_per = len(sphere_template.vertices)
            n_faces_per = len(sphere_template.faces)

            all_verts = np.empty((len(vertices) * n_verts_per, 3), dtype=np.float32)
            all_faces = np.empty((len(vertices) * n_faces_per, 3), dtype=np.int64)
            all_face_colors = np.empty((len(vertices) * n_faces_per, 4), dtype=np.uint8)

            for i, (pt, rgba) in enumerate(zip(vertices, colors_rgba)):
                v_off = i * n_verts_per
                f_off = i * n_faces_per
                all_verts[v_off:v_off + n_verts_per] = sphere_template.vertices + pt
                all_faces[f_off:f_off + n_faces_per] = sphere_template.faces + v_off
                all_face_colors[f_off:f_off + n_faces_per] = rgba

            mesh = trimesh.Trimesh(vertices=all_verts, faces=all_faces)
            mesh.visual.face_colors = all_face_colors
            # Enable alpha blending in glTF material for true transparency
            if alpha < 1.0:
                mesh.visual.material.alphaMode = 'BLEND'
            scene_3d.add_geometry(mesh)
            print(f"Spheres mode: {len(vertices):,} spheres, {len(all_faces):,} faces")
        else:
            # Points mode (GLB viewers ignore alpha on points, so use blended RGB)
            scene_3d.add_geometry(trimesh.PointCloud(vertices=vertices, colors=colors_u8))

        # Add cameras and trajectory
        if self.glb_show_cam_checkbox.value and self.cam_dict is not None:
            from lingbot_map.vis.glb_export import integrate_camera_into_scene
            import matplotlib
            colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
            num_cameras = len(self.all_steps)
            cam_positions = []

            frustum_thickness = self.glb_frustum_thickness_slider.value
            effective_cam_scale = scene_scale * self.glb_cam_scale_slider.value

            for i, step in enumerate(self.all_steps):
                R = self.cam_dict["R"][step] if "R" in self.cam_dict else np.eye(3)
                t = self.cam_dict["t"][step] if "t" in self.cam_dict else np.zeros(3)

                c2w = np.eye(4)
                c2w[:3, :3] = R
                c2w[:3, 3] = t
                cam_positions.append(np.array(t, dtype=np.float64))

                rgba_c = colormap(i / max(num_cameras - 1, 1))
                cam_color = tuple(int(255 * x) for x in rgba_c[:3])
                integrate_camera_into_scene(
                    scene_3d, c2w, cam_color,
                    effective_cam_scale,
                    frustum_thickness=frustum_thickness,
                )

            # Add trajectory line as a tube connecting camera positions
            if self.glb_trajectory_checkbox.value and len(cam_positions) >= 2:
                traj_pts = np.array(cam_positions)
                traj_radius = self.glb_trajectory_radius_slider.value * self.glb_cam_scale_slider.value
                traj_mesh = self._build_trajectory_tube(
                    traj_pts, traj_radius, colormap, num_cameras
                )
                if traj_mesh is not None:
                    scene_3d.add_geometry(traj_mesh)

        # Align scene using first camera extrinsic
        if self.cam_dict is not None and len(self.all_steps) > 0:
            from lingbot_map.vis.glb_export import apply_scene_alignment
            step0 = self.all_steps[0]
            R0 = self.cam_dict["R"][step0] if "R" in self.cam_dict else np.eye(3)
            t0 = self.cam_dict["t"][step0] if "t" in self.cam_dict else np.zeros(3)
            c2w_0 = np.eye(4)
            c2w_0[:3, :3] = R0
            c2w_0[:3, 3] = t0
            w2c_0 = np.linalg.inv(c2w_0)
            extrinsics = np.expand_dims(w2c_0, 0)
            scene_3d = apply_scene_alignment(scene_3d, extrinsics)

        output_path = self.glb_output_path.value
        scene_3d.export(output_path)

        n_pts = len(vertices)
        mode_str = f"spheres r={self.glb_sphere_radius_slider.value}" if export_mode == "Spheres" else "points"
        self.glb_status.value = f"Saved: {output_path} ({n_pts:,} {mode_str})"
        print(f"GLB exported to {output_path} ({n_pts:,} {mode_str})")
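Once _export_glb has written the file, the result can be inspected offline with trimesh. An illustrative check (assuming the default export.glb output path; geometry names depend on how trimesh assigns them):

import trimesh

scene = trimesh.load("export.glb")
for name, geom in scene.geometry.items():
    if isinstance(geom, trimesh.PointCloud):
        print(name, "points:", len(geom.vertices))
    else:
        print(name, "faces:", len(geom.faces))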

    @staticmethod
    def _build_trajectory_tube(positions, radius, colormap, num_cameras):
        """Build a tube mesh following camera trajectory with per-segment color.

        Args:
            positions: (N, 3) camera positions.
            radius: Tube radius.
            colormap: Matplotlib colormap for gradient coloring.
            num_cameras: Total number of cameras (for color normalization).

        Returns:
            trimesh.Trimesh or None.
        """
        import trimesh

        segments = []
        for i in range(len(positions) - 1):
            p0, p1 = positions[i], positions[i + 1]
            seg_len = np.linalg.norm(p1 - p0)
            if seg_len < 1e-8:
                continue

            # Create cylinder along Z, then transform
            cyl = trimesh.creation.cylinder(radius=radius, height=seg_len, sections=8)

            # Direction vector
            direction = (p1 - p0) / seg_len
            mid = (p0 + p1) / 2.0

            # Build rotation: default cylinder is along Z
            z_axis = np.array([0.0, 0.0, 1.0])
            v = np.cross(z_axis, direction)
            c = np.dot(z_axis, direction)

            if np.linalg.norm(v) < 1e-8:
                rot = np.eye(3) if c > 0 else np.diag([1, -1, -1])
            else:
                vx = np.array([[0, -v[2], v[1]],
                               [v[2], 0, -v[0]],
                               [-v[1], v[0], 0]])
                rot = np.eye(3) + vx + vx @ vx / (1.0 + c)

            transform = np.eye(4)
            transform[:3, :3] = rot
            transform[:3, 3] = mid
            cyl.apply_transform(transform)

            # Color: midpoint index
            t_color = (i + 0.5) / max(num_cameras - 1, 1)
            rgba = colormap(t_color)
            color_rgb = tuple(int(255 * x) for x in rgba[:3])
            cyl.visual.face_colors[:, :3] = color_rgb
            segments.append(cyl)

        if not segments:
            return None
        return trimesh.util.concatenate(segments)
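The cylinder-alignment step above uses the standard cross-product construction R = I + [v]x + [v]x^2 / (1 + c), with v = z x d and c = z . d, to rotate the +Z axis onto the segment direction. A small numeric check of that identity (illustrative only, valid whenever the direction is not exactly -Z):

import numpy as np

direction = np.array([1.0, 1.0, 0.0])
direction /= np.linalg.norm(direction)

z_axis = np.array([0.0, 0.0, 1.0])
v = np.cross(z_axis, direction)
c = np.dot(z_axis, direction)
vx = np.array([[0, -v[2], v[1]],
               [v[2], 0, -v[0]],
               [-v[1], v[0], 0]])
rot = np.eye(3) + vx + vx @ vx / (1.0 + c)

assert np.allclose(rot @ z_axis, direction)   # +Z is rotated onto the segment direction
assert np.allclose(rot @ rot.T, np.eye(3))    # rot is a proper orthonormal rotation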

    def update_frame_visibility(self):
        """Show all frames up to the current timestep (or only the current one in 4D mode)."""
        if not hasattr(self, 'frame_nodes') or not hasattr(self, 'gui_timestep'):
            return

        current_timestep = self.gui_timestep.value
        for i, frame_node in enumerate(self.frame_nodes):
            frame_node.visible = (
                i <= current_timestep if not self.fourd else i == current_timestep
            )

    def _move_to_camera(self, frame_idx: int, smooth: bool = True):
        """Move viewer camera to match reconstructed camera at given frame."""
        if self.cam_dict is None:
            return

        step = self.all_steps[frame_idx] if frame_idx < len(self.all_steps) else self.all_steps[-1]

        R = self.cam_dict["R"][step] if "R" in self.cam_dict else np.eye(3)
        t = self.cam_dict["t"][step] if "t" in self.cam_dict else np.zeros(3)
        focal = self.cam_dict["focal"][step] if "focal" in self.cam_dict else 1.0
        pp = self.cam_dict["pp"][step] if "pp" in self.cam_dict else (1.0, 1.0)

        offset = 0.5
        viewing_dir = R[:, 2]  # camera Z axis in world frame
        position = t - viewing_dir * offset
        look_at = t + viewing_dir * 0.5  # look slightly ahead of camera

        fov = 2 * np.arctan(pp[0] / focal)
        up = -R[:, 1]  # camera -Y axis in world frame

        for client in self.server.get_clients().values():
            if smooth:
                self._smooth_camera_transition(
                    client,
                    target_position=position,
                    target_look_at=look_at,
                    target_up=up,
                    target_fov=fov,
                    duration=0.3,
                )
            else:
                client.camera.up_direction = tuple(up)
                client.camera.position = tuple(position)
                client.camera.look_at = tuple(look_at)
                if fov is not None:
                    client.camera.fov = fov

    def _smooth_camera_transition(
        self,
        client,
        target_position,
        target_look_at=None,
        target_up=None,
        target_fov=None,
        duration=0.3,
    ):
        """Smoothly transition camera to target pose using look_at based control.

        Args:
            client: Viser client handle.
            target_position: Target camera position (3,).
            target_look_at: Target look-at point (3,). If None, keeps current.
            target_up: Target up direction (3,). If None, keeps current.
            target_fov: Target FOV. If None, keeps current.
            duration: Transition duration in seconds.
        """
        def interpolate():
            num_steps = 15
            dt = duration / num_steps

            start_position = np.array(client.camera.position, dtype=np.float64)
            start_look_at = np.array(client.camera.look_at, dtype=np.float64)
            start_fov = client.camera.fov

            end_position = np.asarray(target_position, dtype=np.float64)
            end_look_at = np.asarray(target_look_at, dtype=np.float64) if target_look_at is not None else start_look_at

            # Set up direction once at the start (not interpolated to avoid flicker)
            if target_up is not None:
                client.camera.up_direction = tuple(np.asarray(target_up, dtype=np.float64))

            for i in range(num_steps + 1):
                alpha = i / num_steps
                # Smooth ease-in-out
                alpha_smooth = alpha * alpha * (3 - 2 * alpha)

                interp_pos = start_position + (end_position - start_position) * alpha_smooth
                interp_look = start_look_at + (end_look_at - start_look_at) * alpha_smooth

                # Set position first (this auto-moves look_at), then override look_at
                client.camera.position = tuple(interp_pos)
                client.camera.look_at = tuple(interp_look)

                if target_fov is not None:
                    interp_fov = start_fov + (target_fov - start_fov) * alpha_smooth
                    client.camera.fov = interp_fov

                time.sleep(dt)

        thread = threading.Thread(target=interpolate, daemon=True)
        thread.start()

    def _slerp(self, q1, q2, t):
        """Spherical linear interpolation between quaternions."""
        dot = np.dot(q1, q2)

        if abs(dot) > 0.9995:
            result = q1 + t * (q2 - q1)
            return result / np.linalg.norm(result)

        dot = np.clip(dot, -1.0, 1.0)
        theta_0 = np.arccos(dot)
        theta = theta_0 * t

        q2_orthogonal = q2 - q1 * dot
        q2_orthogonal = q2_orthogonal / np.linalg.norm(q2_orthogonal)

        return q1 * np.cos(theta) + q2_orthogonal * np.sin(theta)
| 974 |
+
def get_camera_state(self, client: viser.ClientHandle) -> CameraState:
|
| 975 |
+
"""Get current camera state from client."""
|
| 976 |
+
camera = client.camera
|
| 977 |
+
c2w = np.concatenate([
|
| 978 |
+
np.concatenate([tf.SO3(camera.wxyz).as_matrix(), camera.position[:, None]], 1),
|
| 979 |
+
[[0, 0, 0, 1]],
|
| 980 |
+
], 0)
|
| 981 |
+
return CameraState(fov=camera.fov, aspect=camera.aspect, c2w=c2w)
|
| 982 |
+
|
| 983 |
+
@staticmethod
|
| 984 |
+
def generate_pseudo_intrinsics(h: int, w: int) -> np.ndarray:
|
| 985 |
+
"""Generate pseudo intrinsics from image size."""
|
| 986 |
+
focal = (h**2 + w**2) ** 0.5
|
| 987 |
+
return np.array([[focal, 0, w // 2], [0, focal, h // 2], [0, 0, 1]]).astype(np.float32)
|
| 988 |
+
|
| 989 |
+
def _connect_client(self, client: viser.ClientHandle):
|
| 990 |
+
"""Setup client connection callbacks."""
|
| 991 |
+
wxyz_panel = client.gui.add_text("wxyz:", f"{client.camera.wxyz}")
|
| 992 |
+
position_panel = client.gui.add_text("position:", f"{client.camera.position}")
|
| 993 |
+
fov_panel = client.gui.add_text(
|
| 994 |
+
"fov:", f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}"
|
| 995 |
+
)
|
| 996 |
+
aspect_panel = client.gui.add_text("aspect:", "1.0")
|
| 997 |
+
|
| 998 |
+
@client.camera.on_update
|
| 999 |
+
def _(_: viser.CameraHandle):
|
| 1000 |
+
with self.server.atomic():
|
| 1001 |
+
wxyz_panel.value = f"{client.camera.wxyz}"
|
| 1002 |
+
position_panel.value = f"{client.camera.position}"
|
| 1003 |
+
fov_panel.value = f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}"
|
| 1004 |
+
aspect_panel.value = "1.0"
|
| 1005 |
+
|
| 1006 |
+
@staticmethod
|
| 1007 |
+
def set_color_border(image, border_width=5, color=[1, 0, 0]):
|
| 1008 |
+
"""Add colored border to image."""
|
| 1009 |
+
image[:border_width, :, 0] = color[0]
|
| 1010 |
+
image[:border_width, :, 1] = color[1]
|
| 1011 |
+
image[:border_width, :, 2] = color[2]
|
| 1012 |
+
image[-border_width:, :, 0] = color[0]
|
| 1013 |
+
image[-border_width:, :, 1] = color[1]
|
| 1014 |
+
image[-border_width:, :, 2] = color[2]
|
| 1015 |
+
image[:, :border_width, 0] = color[0]
|
| 1016 |
+
image[:, :border_width, 1] = color[1]
|
| 1017 |
+
image[:, :border_width, 2] = color[2]
|
| 1018 |
+
image[:, -border_width:, 0] = color[0]
|
| 1019 |
+
image[:, -border_width:, 1] = color[1]
|
| 1020 |
+
image[:, -border_width:, 2] = color[2]
|
| 1021 |
+
return image
|
| 1022 |
+
|
| 1023 |
+
def read_data(self, pc_list, color_list, conf_list, edge_color_list=None):
|
| 1024 |
+
"""Read and organize point cloud data."""
|
| 1025 |
+
pcs = {}
|
| 1026 |
+
step_list = []
|
| 1027 |
+
for i, pc in enumerate(pc_list):
|
| 1028 |
+
step = i
|
| 1029 |
+
pcs.update({
|
| 1030 |
+
step: {
|
| 1031 |
+
"pc": pc,
|
| 1032 |
+
"color": color_list[i],
|
| 1033 |
+
"conf": conf_list[i],
|
| 1034 |
+
"edge_color": (
|
| 1035 |
+
None if edge_color_list is None or edge_color_list[i] is None
|
| 1036 |
+
else edge_color_list[i]
|
| 1037 |
+
),
|
| 1038 |
+
}
|
| 1039 |
+
})
|
| 1040 |
+
step_list.append(step)
|
| 1041 |
+
|
| 1042 |
+
# Generate camera gradient colors
|
| 1043 |
+
num_cameras = len(pc_list)
|
| 1044 |
+
if num_cameras > 1:
|
| 1045 |
+
normalized_indices = np.array(list(range(num_cameras))) / (num_cameras - 1)
|
| 1046 |
+
else:
|
| 1047 |
+
normalized_indices = np.array([0.0])
|
| 1048 |
+
cmap = cm.get_cmap('viridis')
|
| 1049 |
+
self.camera_colors = cmap(normalized_indices)
|
| 1050 |
+
return pcs, step_list
|
| 1051 |
+
|
| 1052 |
+
def parse_pc_data(
|
| 1053 |
+
self,
|
| 1054 |
+
pc,
|
| 1055 |
+
color,
|
| 1056 |
+
conf=None,
|
| 1057 |
+
edge_color=[0.251, 0.702, 0.902],
|
| 1058 |
+
set_border_color=False,
|
| 1059 |
+
downsample_factor=1,
|
| 1060 |
+
):
|
| 1061 |
+
"""Parse and filter point cloud data."""
|
| 1062 |
+
pred_pts = pc.reshape(-1, 3)
|
| 1063 |
+
|
| 1064 |
+
if set_border_color and edge_color is not None:
|
| 1065 |
+
color = self.set_color_border(color[0], color=edge_color)
|
| 1066 |
+
if np.isnan(color).any():
|
| 1067 |
+
color = np.zeros((pred_pts.shape[0], 3))
|
| 1068 |
+
color[:, 2] = 1
|
| 1069 |
+
else:
|
| 1070 |
+
color = color.reshape(-1, 3)
|
| 1071 |
+
|
| 1072 |
+
# Remove NaN / Inf points
|
| 1073 |
+
valid = np.isfinite(pred_pts).all(axis=1)
|
| 1074 |
+
if not valid.all():
|
| 1075 |
+
pred_pts = pred_pts[valid]
|
| 1076 |
+
color = color[valid]
|
| 1077 |
+
if conf is not None:
|
| 1078 |
+
conf = conf.reshape(-1)[valid]
|
| 1079 |
+
|
| 1080 |
+
# Confidence threshold filter
|
| 1081 |
+
if conf is not None:
|
| 1082 |
+
conf_flat = conf.reshape(-1) if conf.ndim > 1 else conf
|
| 1083 |
+
mask = conf_flat > self.vis_threshold
|
| 1084 |
+
pred_pts = pred_pts[mask]
|
| 1085 |
+
color = color[mask]
|
| 1086 |
+
|
| 1087 |
+
if len(pred_pts) == 0:
|
| 1088 |
+
return pred_pts, color
|
| 1089 |
+
|
| 1090 |
+
# Downsample
|
| 1091 |
+
if downsample_factor > 1 and len(pred_pts) > 0:
|
| 1092 |
+
indices = np.arange(0, len(pred_pts), downsample_factor)
|
| 1093 |
+
pred_pts = pred_pts[indices]
|
| 1094 |
+
color = color[indices]
|
| 1095 |
+
|
| 1096 |
+
return pred_pts, color
|
| 1097 |
+
|
| 1098 |
+
def add_pc(self, step):
|
| 1099 |
+
"""Add point cloud for a frame."""
|
| 1100 |
+
pc = self.pcs[step]["pc"]
|
| 1101 |
+
color = self.pcs[step]["color"]
|
| 1102 |
+
conf = self.pcs[step]["conf"]
|
| 1103 |
+
edge_color = self.pcs[step].get("edge_color", None)
|
| 1104 |
+
|
| 1105 |
+
pred_pts, color = self.parse_pc_data(
|
| 1106 |
+
pc, color, conf, edge_color, set_border_color=True,
|
| 1107 |
+
downsample_factor=self.downsample_slider.value
|
| 1108 |
+
)
|
| 1109 |
+
|
| 1110 |
+
self.vis_pts_list.append(pred_pts)
|
| 1111 |
+
self.pc_handles.append(
|
| 1112 |
+
self.server.scene.add_point_cloud(
|
| 1113 |
+
name=f"/frames/{step}/pred_pts",
|
| 1114 |
+
points=pred_pts,
|
| 1115 |
+
colors=color,
|
| 1116 |
+
point_size=self.psize_slider.value,
|
| 1117 |
+
)
|
| 1118 |
+
)
|
| 1119 |
+
|
| 1120 |
+
def add_camera(self, step):
|
| 1121 |
+
"""Add camera visualization for a frame."""
|
| 1122 |
+
cam = self.cam_dict
|
| 1123 |
+
focal = cam["focal"][step] if cam and "focal" in cam else 1.0
|
| 1124 |
+
pp = cam["pp"][step] if cam and "pp" in cam else (1.0, 1.0)
|
| 1125 |
+
R = cam["R"][step] if cam and "R" in cam else np.eye(3)
|
| 1126 |
+
t = cam["t"][step] if cam and "t" in cam else np.zeros(3)
|
| 1127 |
+
|
| 1128 |
+
q = tf.SO3.from_matrix(R).wxyz
|
| 1129 |
+
fov = 2 * np.arctan(pp[0] / focal)
|
| 1130 |
+
aspect = pp[0] / pp[1]
|
| 1131 |
+
self.traj_list.append((q, t))
|
| 1132 |
+
|
| 1133 |
+
step_index = self.all_steps.index(step) if step in self.all_steps else 0
|
| 1134 |
+
camera_color = self.camera_colors[step_index]
|
| 1135 |
+
camera_color_rgb = tuple((camera_color[:3] * 255).astype(int))
|
| 1136 |
+
|
| 1137 |
+
self.server.scene.add_frame(
|
| 1138 |
+
f"/frames/{step}/camera_frame",
|
| 1139 |
+
wxyz=q,
|
| 1140 |
+
position=t,
|
| 1141 |
+
axes_length=0.05,
|
| 1142 |
+
axes_radius=0.002,
|
| 1143 |
+
origin_radius=0.002,
|
| 1144 |
+
)
|
| 1145 |
+
|
| 1146 |
+
frustum_handle = self.server.scene.add_camera_frustum(
|
| 1147 |
+
name=f"/frames/{step}/camera",
|
| 1148 |
+
fov=fov,
|
| 1149 |
+
aspect=aspect,
|
| 1150 |
+
wxyz=q,
|
| 1151 |
+
position=t,
|
| 1152 |
+
scale=0.03,
|
| 1153 |
+
color=camera_color_rgb,
|
| 1154 |
+
)
|
| 1155 |
+
|
| 1156 |
+
@frustum_handle.on_click
|
| 1157 |
+
def _(event) -> None:
|
| 1158 |
+
look_at_pt = t + R[:, 2] * 0.5 # look ahead along camera Z
|
| 1159 |
+
up_dir = -R[:, 1]
|
| 1160 |
+
for client in self.server.get_clients().values():
|
| 1161 |
+
client.camera.up_direction = tuple(up_dir)
|
| 1162 |
+
client.camera.position = tuple(t)
|
| 1163 |
+
client.camera.look_at = tuple(look_at_pt)
|
| 1164 |
+
|
| 1165 |
+
self.cam_handles.append(frustum_handle)
|
| 1166 |
+
|
| 1167 |
+
def animate(self):
|
| 1168 |
+
"""Setup and run animation controls."""
|
| 1169 |
+
with self.server.gui.add_folder("Playback"):
|
| 1170 |
+
self.gui_timestep = self.server.gui.add_slider(
|
| 1171 |
+
"Train Step", min=0, max=self.num_frames - 1, step=1, initial_value=0, disabled=False
|
| 1172 |
+
)
|
| 1173 |
+
gui_next_frame = self.server.gui.add_button("Next Step", disabled=False)
|
| 1174 |
+
gui_prev_frame = self.server.gui.add_button("Prev Step", disabled=False)
|
| 1175 |
+
gui_playing = self.server.gui.add_checkbox("Playing", True)
|
| 1176 |
+
gui_framerate = self.server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=20)
|
| 1177 |
+
gui_framerate_options = self.server.gui.add_button_group("FPS options", ("10", "20", "30", "60"))
|
| 1178 |
+
|
| 1179 |
+
@gui_next_frame.on_click
|
| 1180 |
+
def _(_) -> None:
|
| 1181 |
+
self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames
|
| 1182 |
+
|
| 1183 |
+
@gui_prev_frame.on_click
|
| 1184 |
+
def _(_) -> None:
|
| 1185 |
+
self.gui_timestep.value = (self.gui_timestep.value - 1) % self.num_frames
|
| 1186 |
+
|
| 1187 |
+
@gui_playing.on_update
|
| 1188 |
+
def _(_) -> None:
|
| 1189 |
+
self.gui_timestep.disabled = gui_playing.value
|
| 1190 |
+
gui_next_frame.disabled = gui_playing.value
|
| 1191 |
+
gui_prev_frame.disabled = gui_playing.value
|
| 1192 |
+
|
| 1193 |
+
@gui_framerate_options.on_click
|
| 1194 |
+
def _(_) -> None:
|
| 1195 |
+
gui_framerate.value = int(gui_framerate_options.value)
|
| 1196 |
+
|
| 1197 |
+
prev_timestep = self.gui_timestep.value
|
| 1198 |
+
|
| 1199 |
+
@self.gui_timestep.on_update
|
| 1200 |
+
def _(_) -> None:
|
| 1201 |
+
nonlocal prev_timestep
|
| 1202 |
+
current_timestep = self.gui_timestep.value
|
| 1203 |
+
|
| 1204 |
+
if self.current_frame_image is not None and hasattr(self, 'original_images'):
|
| 1205 |
+
if current_timestep < len(self.original_images):
|
| 1206 |
+
self.current_frame_image.image = self.original_images[current_timestep]
|
| 1207 |
+
|
| 1208 |
+
with self.server.atomic():
|
| 1209 |
+
self.frame_nodes[current_timestep].visible = True
|
| 1210 |
+
self.frame_nodes[prev_timestep].visible = False
|
| 1211 |
+
self.server.flush()
|
| 1212 |
+
|
| 1213 |
+
prev_timestep = current_timestep
|
| 1214 |
+
|
| 1215 |
+
self.server.scene.add_frame("/frames", show_axes=False)
|
| 1216 |
+
self.frame_nodes = []
|
| 1217 |
+
for i in range(self.num_frames):
|
| 1218 |
+
step = self.all_steps[i]
|
| 1219 |
+
self.frame_nodes.append(
|
| 1220 |
+
self.server.scene.add_frame(f"/frames/{step}", show_axes=False)
|
| 1221 |
+
)
|
| 1222 |
+
self.add_pc(step)
|
| 1223 |
+
if self.show_camera:
|
| 1224 |
+
downsample_factor = int(self.camera_downsample_slider.value)
|
| 1225 |
+
if i % downsample_factor == 0:
|
| 1226 |
+
self.add_camera(step)
|
| 1227 |
+
|
| 1228 |
+
prev_timestep = self.gui_timestep.value
|
| 1229 |
+
while True:
|
| 1230 |
+
if self.on_replay:
|
| 1231 |
+
pass
|
| 1232 |
+
else:
|
| 1233 |
+
if gui_playing.value:
|
| 1234 |
+
self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames
|
| 1235 |
+
self.update_frame_visibility()
|
| 1236 |
+
|
| 1237 |
+
time.sleep(1.0 / gui_framerate.value)
|
| 1238 |
+
|
| 1239 |
+
def _take_screenshot(self, client: Optional[Any] = None):
|
| 1240 |
+
"""Capture a screenshot from the current view and save to file.
|
| 1241 |
+
|
| 1242 |
+
Args:
|
| 1243 |
+
client: The viser client that triggered the action. If None,
|
| 1244 |
+
uses the first connected client.
|
| 1245 |
+
"""
|
| 1246 |
+
output_path = self.screenshot_path.value
|
| 1247 |
+
res_str = self.screenshot_resolution.value
|
| 1248 |
+
|
| 1249 |
+
# Resolve client
|
| 1250 |
+
if client is None:
|
| 1251 |
+
clients = list(self.server.get_clients().values())
|
| 1252 |
+
if not clients:
|
| 1253 |
+
self.screenshot_status.value = "Error: no client connected"
|
| 1254 |
+
return
|
| 1255 |
+
client = clients[0]
|
| 1256 |
+
|
| 1257 |
+
try:
|
| 1258 |
+
self.screenshot_status.value = "Capturing..."
|
| 1259 |
+
|
| 1260 |
+
if res_str == "Current":
|
| 1261 |
+
# Use default render size
|
| 1262 |
+
width, height = 1920, 1080
|
| 1263 |
+
else:
|
| 1264 |
+
width, height = map(int, res_str.split("x"))
|
| 1265 |
+
|
| 1266 |
+
render = client.camera.get_render(height=height, width=width)
|
| 1267 |
+
|
| 1268 |
+
if render is not None:
|
| 1269 |
+
frame = np.array(render)
|
| 1270 |
+
if frame.shape[2] == 4:
|
| 1271 |
+
frame = frame[:, :, :3]
|
| 1272 |
+
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 1273 |
+
cv2.imwrite(output_path, frame_bgr)
|
| 1274 |
+
self.screenshot_status.value = f"Saved: {output_path}"
|
| 1275 |
+
print(f"Screenshot saved to {output_path} ({width}x{height})")
|
| 1276 |
+
else:
|
| 1277 |
+
self.screenshot_status.value = "Error: render returned None"
|
| 1278 |
+
print("Screenshot failed: render returned None")
|
| 1279 |
+
|
| 1280 |
+
except Exception as e:
|
| 1281 |
+
self.screenshot_status.value = f"Error: {e}"
|
| 1282 |
+
print(f"Screenshot error: {e}")
|
| 1283 |
+
|
| 1284 |
+
def save_video(
|
| 1285 |
+
self,
|
| 1286 |
+
output_path: str = "output_pointcloud.mp4",
|
| 1287 |
+
fps: int = 30,
|
| 1288 |
+
resolution: str = "1920x1080",
|
| 1289 |
+
save_original_video: bool = True
|
| 1290 |
+
):
|
| 1291 |
+
"""Save point cloud animation as video."""
|
| 1292 |
+
try:
|
| 1293 |
+
if hasattr(self, 'video_status'):
|
| 1294 |
+
self.video_status.value = "Saving video..."
|
| 1295 |
+
print(f"Saving video to {output_path}...")
|
| 1296 |
+
|
| 1297 |
+
width, height = map(int, resolution.split('x'))
|
| 1298 |
+
temp_dir = tempfile.mkdtemp(prefix="viser_video_")
|
| 1299 |
+
print(f"Temporary directory: {temp_dir}")
|
| 1300 |
+
|
| 1301 |
+
print("Waiting for client connection...")
|
| 1302 |
+
timeout = 10
|
| 1303 |
+
start_time = time.time()
|
| 1304 |
+
while len(self.server.get_clients()) == 0:
|
| 1305 |
+
time.sleep(0.1)
|
| 1306 |
+
if time.time() - start_time > timeout:
|
| 1307 |
+
raise RuntimeError("No client connected. Please open the visualization in a browser first.")
|
| 1308 |
+
|
| 1309 |
+
print("Client connected. Starting to render frames...")
|
| 1310 |
+
clients = list(self.server.get_clients().values())
|
| 1311 |
+
client = clients[0]
|
| 1312 |
+
|
| 1313 |
+
if not hasattr(self, 'gui_timestep'):
|
| 1314 |
+
raise RuntimeError("Animation not initialized. Please ensure animate() is called before save_video().")
|
| 1315 |
+
|
| 1316 |
+
for i in tqdm(range(self.num_frames), desc="Rendering frames"):
|
| 1317 |
+
self.gui_timestep.value = i
|
| 1318 |
+
time.sleep(0.1)
|
| 1319 |
+
|
| 1320 |
+
try:
|
| 1321 |
+
screenshot = client.camera.get_render(height=height, width=width)
|
| 1322 |
+
if screenshot is not None:
|
| 1323 |
+
frame = np.array(screenshot)
|
| 1324 |
+
if frame.shape[2] == 4:
|
| 1325 |
+
frame = frame[:, :, :3]
|
| 1326 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 1327 |
+
frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
|
| 1328 |
+
cv2.imwrite(frame_path, frame)
|
| 1329 |
+
else:
|
| 1330 |
+
frame = self._render_frame_fallback(i, width, height)
|
| 1331 |
+
frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
|
| 1332 |
+
cv2.imwrite(frame_path, frame)
|
| 1333 |
+
except Exception as e:
|
| 1334 |
+
print(f"Warning: Error capturing frame {i}: {e}, using fallback")
|
| 1335 |
+
frame = self._render_frame_fallback(i, width, height)
|
| 1336 |
+
frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
|
| 1337 |
+
cv2.imwrite(frame_path, frame)
|
| 1338 |
+
|
| 1339 |
+
print("Encoding video with ffmpeg...")
|
| 1340 |
+
ffmpeg_cmd = [
|
| 1341 |
+
'ffmpeg', '-y', '-framerate', str(fps),
|
| 1342 |
+
'-i', os.path.join(temp_dir, 'frame_%06d.png'),
|
| 1343 |
+
'-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18',
|
| 1344 |
+
output_path
|
| 1345 |
+
]
|
| 1346 |
+
|
| 1347 |
+
result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
|
| 1348 |
+
|
| 1349 |
+
if result.returncode == 0:
|
| 1350 |
+
print(f"Point cloud video saved successfully to {output_path}")
|
| 1351 |
+
if hasattr(self, 'video_status'):
|
| 1352 |
+
self.video_status.value = f"Saved to {output_path}"
|
| 1353 |
+
else:
|
| 1354 |
+
print(f"FFmpeg error: {result.stderr}")
|
| 1355 |
+
if hasattr(self, 'video_status'):
|
| 1356 |
+
self.video_status.value = "Error: FFmpeg failed"
|
| 1357 |
+
|
| 1358 |
+
if save_original_video and hasattr(self, 'original_images') and len(self.original_images) > 0:
|
| 1359 |
+
self._save_original_video(output_path, fps, width, height)
|
| 1360 |
+
|
| 1361 |
+
shutil.rmtree(temp_dir)
|
| 1362 |
+
print("Temporary files cleaned up")
|
| 1363 |
+
|
| 1364 |
+
except Exception as e:
|
| 1365 |
+
print(f"Error saving video: {e}")
|
| 1366 |
+
import traceback
|
| 1367 |
+
traceback.print_exc()
|
| 1368 |
+
if hasattr(self, 'video_status'):
|
| 1369 |
+
self.video_status.value = f"Error: {str(e)}"
|
| 1370 |
+
|
| 1371 |
+
def _save_original_video(self, pointcloud_video_path: str, fps: int, width: int, height: int):
|
| 1372 |
+
"""Save original images as video."""
|
| 1373 |
+
base_path = os.path.splitext(pointcloud_video_path)[0]
|
| 1374 |
+
original_video_path = f"{base_path}_original.mp4"
|
| 1375 |
+
|
| 1376 |
+
print(f"Saving original images video to {original_video_path}...")
|
| 1377 |
+
|
| 1378 |
+
try:
|
| 1379 |
+
temp_dir = tempfile.mkdtemp(prefix="original_video_")
|
| 1380 |
+
|
| 1381 |
+
for i, img in enumerate(tqdm(self.original_images, desc="Saving original frames")):
|
| 1382 |
+
frame = cv2.resize(img, (width, height))
|
| 1383 |
+
if len(frame.shape) == 3 and frame.shape[2] == 3:
|
| 1384 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 1385 |
+
frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
|
| 1386 |
+
cv2.imwrite(frame_path, frame)
|
| 1387 |
+
|
| 1388 |
+
print("Encoding original video with ffmpeg...")
|
| 1389 |
+
ffmpeg_cmd = [
|
| 1390 |
+
'ffmpeg', '-y', '-framerate', str(fps),
|
| 1391 |
+
'-i', os.path.join(temp_dir, 'frame_%06d.png'),
|
| 1392 |
+
'-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18',
|
| 1393 |
+
original_video_path
|
| 1394 |
+
]
|
| 1395 |
+
|
| 1396 |
+
result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
|
| 1397 |
+
|
| 1398 |
+
if result.returncode == 0:
|
| 1399 |
+
print(f"Original video saved successfully to {original_video_path}")
|
| 1400 |
+
else:
|
| 1401 |
+
print(f"FFmpeg error for original video: {result.stderr}")
|
| 1402 |
+
|
| 1403 |
+
shutil.rmtree(temp_dir)
|
| 1404 |
+
|
| 1405 |
+
except Exception as e:
|
| 1406 |
+
print(f"Error saving original video: {e}")
|
| 1407 |
+
import traceback
|
| 1408 |
+
traceback.print_exc()
|
| 1409 |
+
|
| 1410 |
+
def _render_frame_fallback(self, frame_idx: int, width: int, height: int) -> np.ndarray:
|
| 1411 |
+
"""Fallback rendering when screenshot capture fails."""
|
| 1412 |
+
if hasattr(self, 'original_images') and frame_idx < len(self.original_images):
|
| 1413 |
+
frame = self.original_images[frame_idx].copy()
|
| 1414 |
+
frame = cv2.resize(frame, (width, height))
|
| 1415 |
+
cv2.putText(frame, f"Frame {frame_idx}", (10, 30),
|
| 1416 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
| 1417 |
+
return frame
|
| 1418 |
+
else:
|
| 1419 |
+
frame = np.zeros((height, width, 3), dtype=np.uint8)
|
| 1420 |
+
cv2.putText(frame, f"Frame {frame_idx} - No render available",
|
| 1421 |
+
(width//4, height//2),
|
| 1422 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
| 1423 |
+
return frame
|
| 1424 |
+
|
| 1425 |
+
def run(self, background_mode: bool = False):
|
| 1426 |
+
"""Run the viewer."""
|
| 1427 |
+
self.animate()
|
| 1428 |
+
if background_mode:
|
| 1429 |
+
def server_loop():
|
| 1430 |
+
while True:
|
| 1431 |
+
time.sleep(0.001)
|
| 1432 |
+
|
| 1433 |
+
thread = threading.Thread(target=server_loop, daemon=True)
|
| 1434 |
+
thread.start()
|
| 1435 |
+
else:
|
| 1436 |
+
while True:
|
| 1437 |
+
time.sleep(10.0)
|
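For reference, the ease-in-out weighting used in _smooth_camera_transition above is the standard smoothstep polynomial alpha * alpha * (3 - 2 * alpha). A minimal standalone sketch of the resulting interpolation schedule (the step count matches the viewer's 15 steps; the endpoint positions are purely illustrative, not from this repository):

import numpy as np

num_steps = 15
alphas = np.arange(num_steps + 1) / num_steps
# Smoothstep: 0 at the start, 1 at the end, zero slope at both endpoints.
alphas_smooth = alphas * alphas * (3 - 2 * alphas)

start = np.array([0.0, 0.0, 1.0])  # illustrative camera positions
end = np.array([0.5, 0.2, 0.4])
positions = start + (end - start)[None, :] * alphas_smooth[:, None]
print(positions[0], positions[-1])  # begins exactly at start, lands exactly on end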
lingbot_map/vis/sky_segmentation.py
ADDED
@@ -0,0 +1,457 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Sky segmentation utilities for filtering sky points from point clouds.
"""

import glob
import os
from typing import Optional, Tuple

import numpy as np
import cv2
from tqdm.auto import tqdm

try:
    import onnxruntime
except ImportError:
    onnxruntime = None
    print("onnxruntime not found. Sky segmentation may not work.")


_SKYSEG_INPUT_SIZE = (320, 320)
_SKYSEG_SOFT_THRESHOLD = 0.1
_SKYSEG_CACHE_VERSION = "imagenet_norm_softmap_inverted_v3"


def _get_cache_version_path(sky_mask_dir: str) -> str:
    return os.path.join(sky_mask_dir, ".skyseg_cache_version")


def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> None:
    """Ensure the sky mask cache directory exists and write the version stamp."""
    if sky_mask_dir is None:
        return
    os.makedirs(sky_mask_dir, exist_ok=True)
    version_path = _get_cache_version_path(sky_mask_dir)
    if not os.path.exists(version_path):
        with open(version_path, "w", encoding="utf-8") as f:
            f.write(_SKYSEG_CACHE_VERSION)


def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray,
) -> np.ndarray:
    """
    Run ONNX sky segmentation on a BGR image and return an 8-bit score map.
    """
    resize_image = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB).astype(np.float32)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    x = (x / 255.0 - mean) / std
    x = x.transpose(2, 0, 1)
    x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    denom = max(max_value - min_value, 1e-8)
    onnx_result = (onnx_result - min_value) / denom
    onnx_result *= 255.0
    return onnx_result.astype(np.uint8)


def _mask_to_float(mask: np.ndarray) -> np.ndarray:
    mask = mask.astype(np.float32)
    if mask.size == 0:
        return mask
    return np.clip(mask, 0.0, 1.0)


def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
    mask = np.asarray(mask)
    if mask.dtype == np.uint8:
        return mask
    mask = mask.astype(np.float32)
    if mask.size > 0 and mask.max() <= 1.0:
        mask = mask * 255.0
    return np.clip(mask, 0.0, 255.0).astype(np.uint8)


def _result_map_to_non_sky_conf(result_map: np.ndarray) -> np.ndarray:
    # The raw skyseg map is higher on sky and lower on non-sky.
    return 1.0 - _mask_to_float(result_map)


def segment_sky_from_array(
    image: np.ndarray,
    skyseg_session,
    target_h: int,
    target_w: int
) -> np.ndarray:
    """
    Segment sky from an image array using ONNX model.

    Args:
        image: Input image as numpy array (H, W, 3) or (3, H, W), values in [0, 1] or [0, 255]
        skyseg_session: ONNX runtime inference session
        target_h: Target output height
        target_w: Target output width

    Returns:
        Continuous non-sky confidence map in [0, 1].
    """
    image_rgb = _image_to_rgb_uint8(image)
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    result_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, image_bgr)
    result_map = cv2.resize(result_map, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    return _result_map_to_non_sky_conf(result_map)


def segment_sky(
    image_path: str,
    skyseg_session,
    output_path: Optional[str] = None
) -> np.ndarray:
    """
    Segment sky from an image using ONNX model.

    Args:
        image_path: Path to the input image
        skyseg_session: ONNX runtime inference session
        output_path: Optional path to save the mask

    Returns:
        Continuous non-sky confidence map in [0, 1].
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Failed to read image: {image_path}")

    result_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, image)
    result_map = cv2.resize(result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
    mask = _result_map_to_non_sky_conf(result_map)

    if output_path is not None:
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        cv2.imwrite(output_path, _mask_to_uint8(mask))

    return mask


def _list_image_files(image_folder: str) -> list[str]:
    image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
    return [f for f in image_files if os.path.splitext(f.lower())[1] in image_extensions]


def _image_to_rgb_uint8(image: np.ndarray) -> np.ndarray:
    if image.ndim == 3 and image.shape[0] == 3 and image.shape[-1] != 3:
        image = image.transpose(1, 2, 0)

    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"Expected image with shape (H, W, 3) or (3, H, W), got {image.shape}")

    if image.dtype != np.uint8:
        image = image.astype(np.float32)
        if image.max() <= 1.0:
            image = image * 255.0
        image = np.clip(image, 0.0, 255.0).astype(np.uint8)

    return image


def _get_mask_filename(image_paths: Optional[list[str]], index: int) -> str:
    if image_paths is not None and index < len(image_paths):
        return os.path.basename(image_paths[index])
    return f"frame_{index:06d}.png"


def _save_sky_mask_visualization(
    image: np.ndarray,
    sky_mask: np.ndarray,
    output_path: str,
) -> None:
    image_rgb = _image_to_rgb_uint8(image)
    if sky_mask.shape[:2] != image_rgb.shape[:2]:
        sky_mask = cv2.resize(
            sky_mask,
            (image_rgb.shape[1], image_rgb.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )

    mask_uint8 = _mask_to_uint8(sky_mask)
    mask_rgb = np.repeat(mask_uint8[..., None], 3, axis=2)
    overlay = image_rgb.astype(np.float32).copy()
    sky_pixels = _mask_to_float(sky_mask) <= _SKYSEG_SOFT_THRESHOLD
    overlay[sky_pixels] = overlay[sky_pixels] * 0.35 + np.array([255, 64, 64], dtype=np.float32) * 0.65
    overlay = np.clip(overlay, 0.0, 255.0).astype(np.uint8)

    panel = np.concatenate([image_rgb, mask_rgb, overlay], axis=1)
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    cv2.imwrite(output_path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))


def load_or_create_sky_masks(
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
    target_shape: Optional[Tuple[int, int]] = None,
    num_frames: Optional[int] = None,
) -> Optional[np.ndarray]:
    """
    Load cached sky masks or generate them with the ONNX model.

    Args:
        image_folder: Folder containing input images.
        image_paths: Optional explicit image file list, in the exact order to process.
        images: Optional image array with shape (S, 3, H, W) or (S, H, W, 3).
        skyseg_model_path: Path to the sky segmentation ONNX model.
        sky_mask_dir: Optional directory for cached raw masks.
        sky_mask_visualization_dir: Optional directory for side-by-side visualizations.
        target_shape: Optional output mask shape (H, W) after resizing.
        num_frames: Optional maximum number of frames to process.

    Returns:
        Sky masks with shape (S, H, W), or None if sky segmentation could not run.
    """
    if onnxruntime is None:
        print("Warning: onnxruntime not available, skipping sky segmentation")
        return None

    if image_folder is None and image_paths is None and images is None:
        print("Warning: Neither image_folder/image_paths nor images provided, skipping sky segmentation")
        return None

    if not os.path.exists(skyseg_model_path):
        print(f"Sky segmentation model not found at {skyseg_model_path}, downloading...")
        try:
            download_skyseg_model(skyseg_model_path)
        except Exception as e:
            print(f"Warning: Failed to download sky segmentation model: {e}")
            return None

    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_masks = []

    if sky_mask_visualization_dir is not None:
        os.makedirs(sky_mask_visualization_dir, exist_ok=True)
        print(f"Saving sky mask visualizations to {sky_mask_visualization_dir}")

    if images is not None:
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)

        num_images = images.shape[0]
        if num_frames is not None:
            num_images = min(num_images, num_frames)
        if image_paths is not None:
            image_paths = image_paths[:num_images]

        if sky_mask_dir is None and image_folder is not None:
            sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
        _prepare_sky_mask_cache(sky_mask_dir)

        print("Generating sky masks from image array...")
        for i in tqdm(range(num_images)):
            image_rgb = _image_to_rgb_uint8(images[i])
            image_h, image_w = image_rgb.shape[:2]
            image_name = _get_mask_filename(image_paths, i)
            mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None

            if mask_filepath is not None and os.path.exists(mask_filepath):
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                if sky_mask is not None and sky_mask.shape[:2] == (image_h, image_w):
                    # Reuse cached mask
                    pass
                else:
                    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
            else:
                sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                if mask_filepath is not None:
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))

            if sky_mask_visualization_dir is not None:
                _save_sky_mask_visualization(
                    image_rgb,
                    sky_mask,
                    os.path.join(sky_mask_visualization_dir, image_name),
                )

            if target_shape is not None and sky_mask.shape[:2] != target_shape:
                sky_mask = cv2.resize(
                    sky_mask,
                    (target_shape[1], target_shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )

            sky_masks.append(_mask_to_float(sky_mask))

    else:
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)

        if images is None and image_paths is not None:
            if len(image_paths) == 0:
                print("Warning: No image files provided, skipping sky segmentation")
                return None

            if num_frames is not None:
                image_paths = image_paths[:num_frames]

            if sky_mask_dir is None:
                if image_folder is None:
                    image_folder = os.path.dirname(image_paths[0])
                sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
            _prepare_sky_mask_cache(sky_mask_dir)

            print("Generating sky masks from image files...")
            for image_path in tqdm(image_paths):
                image_name = os.path.basename(image_path)
                mask_filepath = os.path.join(sky_mask_dir, image_name)

                if os.path.exists(mask_filepath):
                    sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                    if sky_mask is None:
                        print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
                        sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
                else:
                    sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)

                if sky_mask is None:
                    print(f"Warning: Failed to produce sky mask for {image_path}, skipping frame")
                    continue

                if sky_mask_visualization_dir is not None:
                    image_bgr = cv2.imread(image_path)
                    if image_bgr is not None:
                        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                        _save_sky_mask_visualization(
                            image_rgb,
                            sky_mask,
                            os.path.join(sky_mask_visualization_dir, image_name),
                        )

                if target_shape is not None and sky_mask.shape[:2] != target_shape:
                    sky_mask = cv2.resize(
                        sky_mask,
                        (target_shape[1], target_shape[0]),
                        interpolation=cv2.INTER_LINEAR,
                    )

                sky_masks.append(_mask_to_float(sky_mask))

    if len(sky_masks) == 0:
        print("Warning: No sky masks generated, skipping sky segmentation")
        return None

    try:
        return np.stack(sky_masks, axis=0)
    except ValueError:
        return np.array(sky_masks, dtype=object)


def apply_sky_segmentation(
    conf: np.ndarray,
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.

    Args:
        conf: Confidence scores with shape (S, H, W)
        image_folder: Path to the folder containing input images (optional if images provided)
        image_paths: Optional explicit image file list in processing order
        images: Image array with shape (S, 3, H, W) or (S, H, W, 3) (optional if image_folder provided)
        skyseg_model_path: Path to the sky segmentation ONNX model
        sky_mask_dir: Optional directory for cached raw masks
        sky_mask_visualization_dir: Optional directory for side-by-side mask visualization images

    Returns:
        Updated confidence scores with sky regions masked out
    """
    S, H, W = conf.shape

    sky_mask_array = load_or_create_sky_masks(
        image_folder=image_folder,
        image_paths=image_paths,
        images=images,
        skyseg_model_path=skyseg_model_path,
        sky_mask_dir=sky_mask_dir,
        sky_mask_visualization_dir=sky_mask_visualization_dir,
        target_shape=(H, W),
        num_frames=S,
    )
    if sky_mask_array is None:
        return conf

    if sky_mask_array.shape[0] < S:
        print(
            f"Warning: Only {sky_mask_array.shape[0]} sky masks generated for {S} frames; "
            "leaving the remaining frames unmasked"
        )
        padded = np.zeros((S, H, W), dtype=sky_mask_array.dtype)
        padded[: sky_mask_array.shape[0]] = sky_mask_array
        sky_mask_array = padded
    elif sky_mask_array.shape[0] > S:
        sky_mask_array = sky_mask_array[:S]

    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    conf = conf * sky_mask_binary

    print("Sky segmentation applied successfully")
    return conf


def download_skyseg_model(output_path: str = "skyseg.onnx") -> str:
    """
    Download sky segmentation model from HuggingFace.

    Args:
        output_path: Path to save the model

    Returns:
        Path to the downloaded model
    """
    import requests

    url = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"

    print(f"Downloading sky segmentation model from {url}...")
    response = requests.get(url, stream=True)
    response.raise_for_status()

    total_size = int(response.headers.get('content-length', 0))

    with open(output_path, 'wb') as f:
        with tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") as pbar:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                pbar.update(len(chunk))

    print(f"Model saved to {output_path}")
    return output_path
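A minimal usage sketch for apply_sky_segmentation with in-memory frames; the array shapes follow the docstring above, while the random data and the app_output/sky_masks cache path are illustrative assumptions:

import numpy as np
from lingbot_map.vis.sky_segmentation import apply_sky_segmentation

conf = np.random.rand(4, 240, 320).astype(np.float32)        # (S, H, W) confidence maps
images = np.random.rand(4, 3, 240, 320).astype(np.float32)   # (S, 3, H, W) RGB in [0, 1]

conf_masked = apply_sky_segmentation(
    conf,
    images=images,
    skyseg_model_path="skyseg.onnx",        # downloaded automatically if missing
    sky_mask_dir="app_output/sky_masks",    # assumed cache directory
)
print(conf_masked.shape)  # (4, 240, 320); confidence is zeroed where the mask marks sky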
lingbot_map/vis/utils.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Visualization utility functions for colorization and color bars.
"""

import dataclasses
from typing import Optional, Tuple

import numpy as np
import torch
import cv2
import matplotlib.cm as cm


@dataclasses.dataclass
class CameraState:
    """Camera state for rendering."""
    fov: float
    aspect: float
    c2w: np.ndarray

    def get_K(self, img_wh: Tuple[int, int]) -> np.ndarray:
        """Get camera intrinsic matrix from FOV and image size."""
        W, H = img_wh
        focal_length = H / 2.0 / np.tan(self.fov / 2.0)
        K = np.array([
            [focal_length, 0.0, W / 2.0],
            [0.0, focal_length, H / 2.0],
            [0.0, 0.0, 1.0],
        ])
        return K


def get_vertical_colorbar(
    h: int,
    vmin: float,
    vmax: float,
    cmap_name: str = "jet",
    label: Optional[str] = None,
    cbar_precision: int = 2
) -> np.ndarray:
    """
    Create a vertical colorbar image.

    Args:
        h: Height in pixels
        vmin: Minimum value
        vmax: Maximum value
        cmap_name: Colormap name
        label: Optional label for the colorbar
        cbar_precision: Decimal precision for tick labels

    Returns:
        Colorbar image as numpy array (H, W, 3)
    """
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import matplotlib as mpl

    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)

    ax = fig.add_subplot(111)
    cmap = cm.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )

    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        tick_label = [x[:-2] for x in tick_label]

    cb1.set_ticklabels(tick_label)
    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)

    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()

    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    im = im[:, :, :3].astype(np.float32) / 255.0

    if h != im.shape[0]:
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)

    return im


def colorize_np(
    x: np.ndarray,
    cmap_name: str = "jet",
    mask: Optional[np.ndarray] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False,
    cbar_precision: int = 2,
) -> np.ndarray:
    """
    Turn a grayscale image into a color image.

    Args:
        x: Input grayscale image [H, W]
        cmap_name: Colormap name
        mask: Optional mask image [H, W]
        range: Value range for scaling [min, max], automatic if None
        append_cbar: Whether to append colorbar
        cbar_in_image: Put colorbar inside image
        cbar_precision: Colorbar tick precision

    Returns:
        Colorized image [H, W, 3]
    """
    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        vmin = np.min(x[mask][np.nonzero(x[mask])])
        vmax = np.max(x[mask])
        x[np.logical_not(mask)] = vmin
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6

    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)

    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]

    if mask is not None:
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)

    cbar = get_vertical_colorbar(
        h=x.shape[0],
        vmin=vmin,
        vmax=vmax,
        cmap_name=cmap_name,
        cbar_precision=cbar_precision,
    )

    if append_cbar:
        if cbar_in_image:
            x_new[:, -cbar.shape[1]:, :] = cbar
        else:
            x_new = np.concatenate(
                (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
            )
        return x_new
    else:
        return x_new


def colorize(
    x: torch.Tensor,
    cmap_name: str = "jet",
    mask: Optional[torch.Tensor] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False
) -> torch.Tensor:
    """
    Turn a grayscale image into a color image (PyTorch tensor version).

    Args:
        x: Grayscale image tensor [H, W] or [B, H, W]
        cmap_name: Colormap name
        mask: Optional mask tensor [H, W] or [B, H, W]
        range: Value range for scaling
        append_cbar: Whether to append colorbar
        cbar_in_image: Put colorbar inside image

    Returns:
        Colorized tensor
    """
    device = x.device
    x = x.cpu().numpy()
    if mask is not None:
        mask = mask.cpu().numpy() > 0.99
        kernel = np.ones((3, 3), np.uint8)

    if x.ndim == 2:
        x = x[None]
        if mask is not None:
            mask = mask[None]

    out = []
    for x_ in x:
        if mask is not None:
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool)

        x_ = colorize_np(x_, cmap_name, mask, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
    out = torch.stack(out).squeeze(0)
    return out
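A minimal usage sketch for colorize_np on a single depth-like map; the random input and its value range are illustrative assumptions:

import numpy as np
from lingbot_map.vis.utils import colorize_np

depth = np.random.rand(240, 320).astype(np.float32)  # grayscale map in [0, 1]
vis = colorize_np(depth, cmap_name="jet", range=(0.0, 1.0), append_cbar=True)
print(vis.shape)  # (240, 320 + gap + colorbar width, 3), float RGB in [0, 1]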
lingbot_map/vis/viser_wrapper.py
ADDED
@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Quick visualization wrapper for GCT predictions using Viser.
"""

import time
import threading
from typing import List, Optional

import numpy as np
import viser
import viser.transforms as tf
from tqdm.auto import tqdm

from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from lingbot_map.vis.sky_segmentation import apply_sky_segmentation


def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: Optional[str] = None,
):
    """
    Visualize predicted 3D points and camera poses with viser.

    This is a simplified wrapper for quick visualization without the full
    PointCloudViewer controls.

    Args:
        pred_dict: Dictionary containing predictions with keys:
            - images: (S, 3, H, W) - Input images
            - world_points: (S, H, W, 3)
            - world_points_conf: (S, H, W)
            - depth: (S, H, W, 1)
            - depth_conf: (S, H, W)
            - extrinsic: (S, 3, 4)
            - intrinsic: (S, 3, 3)
        port: Port number for the viser server
        init_conf_threshold: Initial percentage of low-confidence points to filter out
        use_point_map: Whether to visualize world_points or use depth-based points
        background_mode: Whether to run the server in background thread
        mask_sky: Whether to apply sky segmentation to filter out sky points
        image_folder: Path to the folder containing input images (for sky segmentation)

    Returns:
        viser.ViserServer: The viser server instance
    """
    print(f"Starting viser server on port {port}")

    server = viser.ViserServer(host="0.0.0.0", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    # Unpack prediction dict
    images = pred_dict["images"]  # (S, 3, H, W)
    world_points_map = pred_dict["world_points"]  # (S, H, W, 3)
    conf_map = pred_dict["world_points_conf"]  # (S, H, W)

    depth_map = pred_dict["depth"]  # (S, H, W, 1)
    depth_conf = pred_dict["depth_conf"]  # (S, H, W)

    extrinsics_cam = pred_dict["extrinsic"]  # (S, 3, 4)
    intrinsics_cam = pred_dict["intrinsic"]  # (S, 3, 3)

    # Compute world points from depth if not using the precomputed point map
    if not use_point_map:
        world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
        conf = depth_conf
    else:
        world_points = world_points_map
        conf = conf_map

    # Apply sky segmentation if enabled
    if mask_sky and image_folder is not None:
        conf = apply_sky_segmentation(conf, image_folder)

    # Convert images from (S, 3, H, W) to (S, H, W, 3)
    colors = images.transpose(0, 2, 3, 1)  # now (S, H, W, 3)
    shape = world_points.shape
    S: int = shape[0]
    H: int = shape[1]
    W: int = shape[2]

    # Flatten
    points = world_points.reshape(-1, 3)
    colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
    conf_flat = conf.reshape(-1)

    # Random sample points if too many
    indices = None
    if points.shape[0] > 6000000:
        print(f"Too many points ({points.shape[0]}), randomly sampling 6M points")
        indices = np.random.choice(points.shape[0], size=6000000, replace=False)
        points = points[indices]
        colors_flat = colors_flat[indices]
        conf_flat = conf_flat[indices]

    cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
    cam_to_world = cam_to_world_mat[:, :3, :]

    # Compute scene center and recenter
    scene_center = np.mean(points, axis=0)
    points_centered = points - scene_center
    cam_to_world[..., -1] -= scene_center

    # Store frame indices for filtering
    frame_indices = (
        np.repeat(np.arange(S), H * W)[indices]
        if indices is not None
        else np.repeat(np.arange(S), H * W)
    )

    # Build the viser GUI
    gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True)
    gui_points_conf = server.gui.add_slider(
        "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold
    )
    gui_frame_selector = server.gui.add_dropdown(
        "Show Points from Frames",
        options=["All"] + [str(i) for i in range(S)],
        initial_value="All"
    )

    # Create the main point cloud
    init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
    init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1)
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points_centered[init_conf_mask],
        colors=colors_flat[init_conf_mask],
        point_size=0.0005,
        point_shape="circle",
    )

    frames: List[viser.FrameHandle] = []
    frustums: List[viser.CameraFrustumHandle] = []

    def visualize_frames(extrinsics, images_: np.ndarray) -> None:
        """Add camera frames and frustums to the scene."""
        for f in frames:
            f.remove()
        frames.clear()
        for fr in frustums:
            fr.remove()
        frustums.clear()

        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        for img_id in tqdm(range(S)):
            cam2world_3x4 = extrinsics[img_id]
            T_world_camera = tf.SE3.from_matrix(cam2world_3x4)

            frame_axis = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame_axis)
|
| 176 |
+
|
| 177 |
+
img = images_[img_id]
|
| 178 |
+
img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
|
| 179 |
+
h, w = img.shape[:2]
|
| 180 |
+
|
| 181 |
+
fy = 1.1 * h
|
| 182 |
+
fov = 2 * np.arctan2(h / 2, fy)
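            # Editor's note (added comment): intrinsics_cam is not used for the
            # frustum display; a focal length of fy = 1.1 * h pixels is assumed,
            # so the vertical FOV 2 * atan((h / 2) / fy) is roughly 49 degrees
            # for every camera, regardless of the true intrinsics.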
            frustum_cam = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum",
                fov=fov,
                aspect=w / h,
                scale=0.05,
                image=img,
                line_width=1.0
            )
            frustums.append(frustum_cam)
            attach_callback(frustum_cam, frame_axis)

    def update_point_cloud() -> None:
        """Update point cloud based on current GUI selections."""
        current_percentage = gui_points_conf.value
        threshold_val = np.percentile(conf_flat, current_percentage)
        print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")

        conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)

        if gui_frame_selector.value == "All":
            frame_mask = np.ones_like(conf_mask, dtype=bool)
        else:
            selected_idx = int(gui_frame_selector.value)
            frame_mask = frame_indices == selected_idx

        combined_mask = conf_mask & frame_mask
        point_cloud.points = points_centered[combined_mask]
        point_cloud.colors = colors_flat[combined_mask]

    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_show_frames.on_update
    def _(_) -> None:
        for f in frames:
            f.visible = gui_show_frames.value
        for fr in frustums:
            fr.visible = gui_show_frames.value

    # Add camera frames
    import torch
    if torch.is_tensor(cam_to_world):
        cam_to_world_np = cam_to_world.cpu().numpy()
    else:
        cam_to_world_np = cam_to_world
    visualize_frames(cam_to_world_np, images)

    print("Starting viser server...")
    if background_mode:
        def server_loop():
            while True:
                time.sleep(0.001)

        thread = threading.Thread(target=server_loop, daemon=True)
        thread.start()
    else:
        while True:
            time.sleep(0.01)

    return server
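A minimal usage sketch for the wrapper above (editor's addition, not part of this commit): predictions is a placeholder name for a dict assumed to already contain the keys documented in the docstring; nothing else is new here.

# Hypothetical example, assuming `predictions` holds images (S, 3, H, W),
# world_points, world_points_conf, depth, depth_conf, extrinsic, intrinsic.
from lingbot_map.vis.viser_wrapper import viser_wrapper

server = viser_wrapper(
    predictions,
    port=8080,
    init_conf_threshold=50.0,  # hide the 50% lowest-confidence points at startup
    use_point_map=False,       # unproject depth maps instead of using world_points directly
    background_mode=True,      # return immediately; the viewer keeps running in a daemon thread
    mask_sky=False,
)
print("Viser viewer available at http://localhost:8080")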
requirements.txt
ADDED
@@ -0,0 +1,14 @@
--extra-index-url https://download.pytorch.org/whl/cu128
torch==2.9.1
torchvision==0.24.1
gradio>=5.0,<6
spaces>=0.34.0
huggingface_hub>=0.30.0
einops>=0.8.0
safetensors>=0.5.0
opencv-python-headless>=4.10.0
tqdm>=4.66.0
scipy>=1.13.0
trimesh>=4.4.0
matplotlib>=3.8.0
Pillow>=10.0.0