dennny123 committed on
Commit 4700ca8 · verified · 1 Parent(s): 4800b16

Initial ZeroGPU Gradio Space for LingBot-Map

Files changed (42)
  1. .gitattributes +1 -0
  2. .gitignore +4 -0
  3. LICENSE.txt +201 -0
  4. README.md +45 -6
  5. app.py +630 -0
  6. assets/teaser.png +3 -0
  7. lingbot_map/__init__.py +0 -0
  8. lingbot_map/aggregator/__init__.py +2 -0
  9. lingbot_map/aggregator/base.py +608 -0
  10. lingbot_map/aggregator/stream.py +531 -0
  11. lingbot_map/heads/__init__.py +0 -0
  12. lingbot_map/heads/camera_head.py +458 -0
  13. lingbot_map/heads/dpt_head.py +679 -0
  14. lingbot_map/heads/head_act.py +125 -0
  15. lingbot_map/heads/utils.py +109 -0
  16. lingbot_map/layers/__init__.py +5 -0
  17. lingbot_map/layers/attention.py +766 -0
  18. lingbot_map/layers/block.py +514 -0
  19. lingbot_map/layers/drop_path.py +34 -0
  20. lingbot_map/layers/flashinfer_cache.py +640 -0
  21. lingbot_map/layers/layer_scale.py +22 -0
  22. lingbot_map/layers/mlp.py +40 -0
  23. lingbot_map/layers/patch_embed.py +85 -0
  24. lingbot_map/layers/rope.py +474 -0
  25. lingbot_map/layers/swiglu_ffn.py +67 -0
  26. lingbot_map/layers/vision_transformer.py +411 -0
  27. lingbot_map/models/__init__.py +0 -0
  28. lingbot_map/models/gct_base.py +359 -0
  29. lingbot_map/models/gct_stream.py +448 -0
  30. lingbot_map/models/gct_stream_window.py +1206 -0
  31. lingbot_map/utils/__init__.py +0 -0
  32. lingbot_map/utils/geometry.py +774 -0
  33. lingbot_map/utils/load_fn.py +243 -0
  34. lingbot_map/utils/pose_enc.py +331 -0
  35. lingbot_map/utils/rotation.py +132 -0
  36. lingbot_map/vis/__init__.py +59 -0
  37. lingbot_map/vis/glb_export.py +509 -0
  38. lingbot_map/vis/point_cloud_viewer.py +1437 -0
  39. lingbot_map/vis/sky_segmentation.py +457 -0
  40. lingbot_map/vis/utils.py +206 -0
  41. lingbot_map/vis/viser_wrapper.py +248 -0
  42. requirements.txt +14 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__/
+ .DS_Store
+ .gradio/
+ app_output/
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,51 @@
  ---
- title: Lingbot Map Zerogpu Demo
- emoji: 🏢
- colorFrom: green
- colorTo: gray
+ title: LingBot-Map ZeroGPU Demo
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
- sdk_version: 6.12.0
  app_file: app.py
+ python_version: 3.10.13
  pinned: false
+ license: apache-2.0
+ startup_duration_timeout: 1h
+ models:
+ - robbyant/lingbot-map
+ preload_from_hub:
+ - robbyant/lingbot-map lingbot-map.pt,lingbot-map-long.pt
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # LingBot-Map ZeroGPU Demo
+
+ Gradio Space wrapper around `Robbyant/lingbot-map` tuned for Hugging Face ZeroGPU:
+
+ - uses the upstream `lingbot_map` package directly
+ - downloads checkpoints from `robbyant/lingbot-map`
+ - runs the SDPA fallback path instead of FlashInfer
+ - caps inputs to short clips so the app fits a shared ZeroGPU workflow
+ - exports a browser-friendly `.glb` scene plus a zipped artifact bundle
+
+ ## Recommended Space Settings
+
+ 1. Create a new **Gradio** Space.
+ 2. In **Settings -> Hardware**, switch the Space to **ZeroGPU**.
+ 3. Keep the repo public or protected as needed.
+
+ ## Current Limits
+
+ - short demos only
+ - default frame cap: 24 frames
+ - model preview is exported as GLB, not the local `viser` server
+ - the app is optimized for `lingbot-map.pt` and `lingbot-map-long.pt`
+
+ ## Local Sanity Check
+
+ If you want to import the app locally without downloading the checkpoint at startup:
+
+ ```bash
+ LINGBOT_SPACE_SKIP_MODEL_LOAD=1 python app.py
+ ```
+
+ ## Upstream
+
+ - GitHub: https://github.com/Robbyant/lingbot-map
+ - Model: https://huggingface.co/robbyant/lingbot-map
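The `preload_from_hub` entry in the README above mirrors what `app.py` does at runtime with `hf_hub_download`. As a rough local sanity check (outside the Space, with `huggingface_hub` installed), the same two checkpoints can be fetched ahead of time; the printed path is wherever your local Hugging Face cache resolves them:

```python
# Optional pre-fetch of the checkpoints named in preload_from_hub; app.py downloads
# the same files on demand via hf_hub_download, so this only warms the local cache.
from huggingface_hub import hf_hub_download

for filename in ("lingbot-map.pt", "lingbot-map-long.pt"):
    local_path = hf_hub_download(repo_id="robbyant/lingbot-map", filename=filename)
    print(f"{filename} -> {local_path}")
```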
app.py ADDED
@@ -0,0 +1,630 @@
1
+ import contextlib
2
+ import gc
3
+ import json
4
+ import os
5
+ import shutil
6
+ import tempfile
7
+ import threading
8
+ import time
9
+ import zipfile
10
+ from pathlib import Path
11
+ from typing import Any, Iterable
12
+
13
+ import cv2
14
+ import gradio as gr
15
+ import numpy as np
16
+ import torch
17
+ from huggingface_hub import hf_hub_download
18
+ from PIL import Image, ImageDraw
19
+
20
+ try:
21
+ import spaces
22
+ except ImportError:
23
+ class _SpacesShim:
24
+ @staticmethod
25
+ def GPU(*decorator_args, **decorator_kwargs):
26
+ if decorator_args and callable(decorator_args[0]) and len(decorator_args) == 1 and not decorator_kwargs:
27
+ return decorator_args[0]
28
+
29
+ def _wrap(func):
30
+ return func
31
+
32
+ return _wrap
33
+
34
+ spaces = _SpacesShim()
35
+
36
+ from lingbot_map.models.gct_stream import GCTStream
37
+ from lingbot_map.utils.geometry import closed_form_inverse_se3_general
38
+ from lingbot_map.utils.load_fn import load_and_preprocess_images
39
+ from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
40
+ from lingbot_map.vis.glb_export import predictions_to_glb
41
+
42
+
43
+ ROOT = Path(__file__).resolve().parent
44
+ OUTPUT_ROOT = ROOT / "app_output"
45
+ OUTPUT_ROOT.mkdir(exist_ok=True)
46
+
47
+ HF_MODEL_REPO = "robbyant/lingbot-map"
48
+ MODEL_FILENAMES = {
49
+ "balanced": "lingbot-map.pt",
50
+ "long": "lingbot-map-long.pt",
51
+ "stage1": "lingbot-map-stage1.pt",
52
+ }
53
+ MODEL_LABELS = {
54
+ "balanced": "Balanced",
55
+ "long": "Long",
56
+ "stage1": "Stage-1",
57
+ }
58
+
59
+ IMAGE_SIZE = 518
60
+ PATCH_SIZE = 14
61
+ DEFAULT_FPS = 8
62
+ DEFAULT_MAX_FRAMES = 24
63
+ MAX_FRAMES_HARD_LIMIT = 24
64
+ DEFAULT_SCALE_FRAMES = 4
65
+ DEFAULT_KEYFRAME_INTERVAL = 2
66
+ DEFAULT_CONF_PERCENTILE = 50.0
67
+ DEFAULT_CAMERA_ITERATIONS = 1
68
+ IS_SPACE_RUNTIME = bool(os.getenv("SPACE_ID"))
69
+ SKIP_EAGER_MODEL_LOAD = os.getenv("LINGBOT_SPACE_SKIP_MODEL_LOAD") == "1"
70
+
71
+ MODEL_CACHE: dict[str, dict[str, Any]] = {}
72
+ MODEL_CACHE_LOCK = threading.Lock()
73
+ STARTUP_NOTES: list[str] = []
74
+
75
+
76
+ def _resolve_path(file_obj: Any) -> str:
77
+ if file_obj is None:
78
+ return ""
79
+ if isinstance(file_obj, str):
80
+ return file_obj
81
+ return getattr(file_obj, "name", "")
82
+
83
+
84
+ def _cleanup_old_runs(keep_last: int = 8) -> None:
85
+ run_dirs = sorted([p for p in OUTPUT_ROOT.iterdir() if p.is_dir()], key=lambda p: p.stat().st_mtime)
86
+ for stale_dir in run_dirs[:-keep_last]:
87
+ shutil.rmtree(stale_dir, ignore_errors=True)
88
+
89
+
90
+ def _pick_runtime_device() -> torch.device:
91
+ try:
92
+ torch.empty(1, device="cuda")
93
+ return torch.device("cuda")
94
+ except Exception:
95
+ return torch.device("cpu")
96
+
97
+
98
+ def _load_model_bundle(model_variant: str) -> dict[str, Any]:
99
+ with MODEL_CACHE_LOCK:
100
+ cached = MODEL_CACHE.get(model_variant)
101
+ if cached is not None:
102
+ return cached
103
+
104
+ if MODEL_CACHE:
105
+ MODEL_CACHE.clear()
106
+ gc.collect()
107
+ if torch.cuda.is_available():
108
+ torch.cuda.empty_cache()
109
+
110
+ device = _pick_runtime_device()
111
+ weight_name = MODEL_FILENAMES[model_variant]
112
+ weight_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=weight_name)
113
+
114
+ model = GCTStream(
115
+ img_size=IMAGE_SIZE,
116
+ patch_size=PATCH_SIZE,
117
+ enable_3d_rope=True,
118
+ max_frame_num=1024,
119
+ kv_cache_sliding_window=64,
120
+ kv_cache_scale_frames=8,
121
+ kv_cache_cross_frame_special=True,
122
+ kv_cache_include_scale_frames=True,
123
+ use_sdpa=True,
124
+ camera_num_iterations=DEFAULT_CAMERA_ITERATIONS,
125
+ )
126
+
127
+ checkpoint = torch.load(weight_path, map_location="cpu", weights_only=False)
128
+ state_dict = checkpoint.get("model", checkpoint)
129
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
130
+
131
+ model = model.to(device).eval()
132
+ inference_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
133
+ if device.type == "cuda" and getattr(model, "aggregator", None) is not None:
134
+ model.aggregator = model.aggregator.to(dtype=inference_dtype)
135
+
136
+ bundle = {
137
+ "model": model,
138
+ "device": device,
139
+ "dtype": inference_dtype,
140
+ "weight_name": weight_name,
141
+ "weight_path": str(weight_path),
142
+ "missing_keys": len(missing),
143
+ "unexpected_keys": len(unexpected),
144
+ }
145
+ MODEL_CACHE[model_variant] = bundle
146
+ return bundle
147
+
148
+
149
+ def _eager_load_default_model() -> None:
150
+ if not IS_SPACE_RUNTIME or SKIP_EAGER_MODEL_LOAD:
151
+ return
152
+ try:
153
+ bundle = _load_model_bundle("balanced")
154
+ STARTUP_NOTES.append(
155
+ f"Startup preload complete on `{bundle['device']}` with `{bundle['weight_name']}`."
156
+ )
157
+ except Exception as exc:
158
+ STARTUP_NOTES.append(f"Startup preload failed: {exc}")
159
+
160
+
161
+ def _copy_image_inputs(image_files: Iterable[Any], input_dir: Path, max_frames: int) -> list[str]:
162
+ paths = sorted(filter(None, (_resolve_path(item) for item in image_files)), key=lambda value: Path(value).name)
163
+ if not paths:
164
+ return []
165
+
166
+ copied = []
167
+ for idx, src_path in enumerate(paths[:max_frames]):
168
+ src = Path(src_path)
169
+ suffix = src.suffix.lower() or ".png"
170
+ dest = input_dir / f"{idx:06d}{suffix}"
171
+ shutil.copy2(src, dest)
172
+ copied.append(str(dest))
173
+ return copied
174
+
175
+
176
+ def _extract_video_frames(video_file: str, frames_dir: Path, fps: int, max_frames: int) -> tuple[list[str], dict[str, Any]]:
177
+ cap = cv2.VideoCapture(video_file)
178
+ if not cap.isOpened():
179
+ raise ValueError("Could not open the uploaded video.")
180
+
181
+ source_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
182
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
183
+ interval = max(1, round(source_fps / max(fps, 1)))
184
+
185
+ saved_paths = []
186
+ frame_idx = 0
187
+ while len(saved_paths) < max_frames:
188
+ ok, frame = cap.read()
189
+ if not ok:
190
+ break
191
+ if frame_idx % interval == 0:
192
+ output_path = frames_dir / f"{len(saved_paths):06d}.jpg"
193
+ cv2.imwrite(str(output_path), frame)
194
+ saved_paths.append(str(output_path))
195
+ frame_idx += 1
196
+
197
+ cap.release()
198
+
199
+ return saved_paths, {
200
+ "source_fps": round(source_fps, 2),
201
+ "sample_interval": interval,
202
+ "original_frame_count": total_frames,
203
+ }
204
+
205
+
206
+ def _prepare_inputs(image_files: list[Any], video_file: Any, fps: int, max_frames: int) -> tuple[torch.Tensor, list[str], Path, dict[str, Any]]:
207
+ _cleanup_old_runs()
208
+ work_dir = Path(tempfile.mkdtemp(prefix="lingbot-map-", dir=OUTPUT_ROOT))
209
+ input_dir = work_dir / "inputs"
210
+ input_dir.mkdir(parents=True, exist_ok=True)
211
+
212
+ image_paths = _copy_image_inputs(image_files or [], input_dir, max_frames=max_frames)
213
+ input_summary = {"input_mode": None}
214
+
215
+ if image_paths:
216
+ input_summary["input_mode"] = "images"
217
+ input_summary["source_fps"] = None
218
+ input_summary["sample_interval"] = None
219
+ input_summary["original_frame_count"] = len(image_paths)
220
+ else:
221
+ video_path = _resolve_path(video_file)
222
+ if not video_path:
223
+ raise ValueError("Upload either ordered images or a video.")
224
+ image_paths, video_summary = _extract_video_frames(video_path, input_dir, fps=fps, max_frames=max_frames)
225
+ input_summary["input_mode"] = "video"
226
+ input_summary.update(video_summary)
227
+
228
+ if len(image_paths) < 2:
229
+ raise ValueError("Provide at least 2 frames. The Space is tuned for short multi-frame reconstructions.")
230
+
231
+ images = load_and_preprocess_images(
232
+ image_paths,
233
+ mode="crop",
234
+ image_size=IMAGE_SIZE,
235
+ patch_size=PATCH_SIZE,
236
+ )
237
+ return images, image_paths, work_dir, input_summary
238
+
239
+
240
+ def _squeeze_single_batch(key: str, value: torch.Tensor) -> torch.Tensor:
241
+ batched_dims = {
242
+ "pose_enc": 3,
243
+ "depth": 5,
244
+ "depth_conf": 4,
245
+ "world_points": 5,
246
+ "world_points_conf": 4,
247
+ "extrinsic": 4,
248
+ "intrinsic": 4,
249
+ "images": 5,
250
+ }
251
+ expected_ndim = batched_dims.get(key)
252
+ if expected_ndim is None or value.ndim != expected_ndim or value.shape[0] != 1:
253
+ return value
254
+ return value[0]
255
+
256
+
257
+ def _postprocess_predictions(predictions: dict[str, Any], images: torch.Tensor) -> tuple[dict[str, Any], torch.Tensor]:
258
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
259
+ extrinsic_4x4 = torch.zeros((*extrinsic.shape[:-2], 4, 4), device=extrinsic.device, dtype=extrinsic.dtype)
260
+ extrinsic_4x4[..., :3, :4] = extrinsic
261
+ extrinsic_4x4[..., 3, 3] = 1.0
262
+ extrinsic_4x4 = closed_form_inverse_se3_general(extrinsic_4x4)
263
+
264
+ predictions["extrinsic"] = extrinsic_4x4[..., :3, :4]
265
+ predictions["intrinsic"] = intrinsic
266
+ predictions.pop("pose_enc_list", None)
267
+ predictions.pop("images", None)
268
+
269
+ for key, value in list(predictions.items()):
270
+ if isinstance(value, torch.Tensor):
271
+ predictions[key] = _squeeze_single_batch(key, value.detach().to("cpu"))
272
+
273
+ images_cpu = images.detach().to("cpu")
274
+ if torch.cuda.is_available():
275
+ torch.cuda.synchronize()
276
+ return predictions, images_cpu
277
+
278
+
279
+ def _prepare_for_visualization(predictions: dict[str, Any], images: torch.Tensor) -> dict[str, Any]:
280
+ vis_predictions = {}
281
+ for key, value in predictions.items():
282
+ if isinstance(value, torch.Tensor):
283
+ vis_predictions[key] = _squeeze_single_batch(key, value).detach().cpu().numpy()
284
+ else:
285
+ vis_predictions[key] = value
286
+ vis_predictions["images"] = _squeeze_single_batch("images", images).detach().cpu().numpy()
287
+ return vis_predictions
288
+
289
+
290
+ def _estimate_gpu_duration(images: torch.Tensor, model_variant: str, num_scale_frames: int, keyframe_interval: int) -> int:
291
+ frame_count = int(getattr(images, "shape", [DEFAULT_MAX_FRAMES])[0])
292
+ del model_variant, num_scale_frames, keyframe_interval
293
+ return min(180, max(60, 24 + frame_count * 4))
294
+
295
+
296
+ @spaces.GPU(duration=_estimate_gpu_duration)
297
+ def _run_inference(images: torch.Tensor, model_variant: str, num_scale_frames: int, keyframe_interval: int) -> tuple[dict[str, Any], torch.Tensor, dict[str, Any]]:
298
+ bundle = _load_model_bundle(model_variant)
299
+ model = bundle["model"]
300
+ device = bundle["device"]
301
+ dtype = bundle["dtype"]
302
+
303
+ if device.type == "cuda":
304
+ torch.cuda.empty_cache()
305
+ torch.cuda.reset_peak_memory_stats()
306
+
307
+ images = images.to(device)
308
+ output_device = torch.device("cpu")
309
+ autocast_context = (
310
+ torch.amp.autocast("cuda", dtype=dtype)
311
+ if device.type == "cuda"
312
+ else contextlib.nullcontext()
313
+ )
314
+
315
+ started_at = time.time()
316
+ with torch.no_grad():
317
+ with autocast_context:
318
+ predictions = model.inference_streaming(
319
+ images,
320
+ num_scale_frames=num_scale_frames,
321
+ keyframe_interval=keyframe_interval,
322
+ output_device=output_device,
323
+ )
324
+ inference_seconds = time.time() - started_at
325
+
326
+ images_for_post = predictions["images"]
327
+ del images
328
+ if device.type == "cuda":
329
+ torch.cuda.empty_cache()
330
+
331
+ predictions, images_cpu = _postprocess_predictions(predictions, images_for_post)
332
+ return predictions, images_cpu, {
333
+ "runtime_seconds": round(inference_seconds, 2),
334
+ "device": str(device),
335
+ "dtype": str(dtype),
336
+ "weight_name": bundle["weight_name"],
337
+ "weight_path": bundle["weight_path"],
338
+ "missing_keys": bundle["missing_keys"],
339
+ "unexpected_keys": bundle["unexpected_keys"],
340
+ "peak_memory_gb": round(torch.cuda.max_memory_allocated() / 1e9, 2) if device.type == "cuda" else None,
341
+ }
342
+
343
+
344
+ def _make_preview_strip(images: torch.Tensor, output_path: Path) -> str:
345
+ frames = images.detach().cpu()
346
+ count = frames.shape[0]
347
+ indices = sorted({int(round(i)) for i in np.linspace(0, count - 1, num=min(4, count))})
348
+
349
+ tiles = []
350
+ for idx in indices:
351
+ rgb = (frames[idx].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
352
+ tile = Image.fromarray(rgb).resize((320, 220))
353
+ tiles.append(tile)
354
+
355
+ banner = Image.new("RGB", (320 * len(tiles), 260), color=(245, 240, 228))
356
+ draw = ImageDraw.Draw(banner)
357
+ draw.text((18, 14), f"LingBot-Map preview | {count} frames", fill=(31, 41, 55))
358
+ draw.text((18, 38), "ZeroGPU demo export", fill=(87, 96, 110))
359
+
360
+ x_offset = 0
361
+ for tile in tiles:
362
+ banner.paste(tile, (x_offset, 72))
363
+ x_offset += tile.width
364
+
365
+ banner.save(output_path)
366
+ return str(output_path)
367
+
368
+
369
+ def _save_predictions_npz(predictions: dict[str, Any], output_path: Path) -> str:
370
+ arrays = {}
371
+ for key, value in predictions.items():
372
+ if isinstance(value, torch.Tensor):
373
+ arrays[key] = value.detach().cpu().numpy()
374
+ np.savez_compressed(output_path, **arrays)
375
+ return str(output_path)
376
+
377
+
378
+ def _count_confident_points(vis_predictions: dict[str, Any], conf_percentile: float) -> tuple[int, float]:
379
+ conf = vis_predictions.get("world_points_conf")
380
+ if conf is None:
381
+ return 0, 0.0
382
+ conf_flat = conf.reshape(-1)
383
+ threshold = np.percentile(conf_flat, conf_percentile) if conf_percentile > 0 else 0.0
384
+ kept = int(((conf_flat >= threshold) & (conf_flat > 1e-5)).sum())
385
+ return kept, float(threshold)
386
+
387
+
388
+ def _zip_outputs(work_dir: Path, paths: list[Path], output_name: str) -> str:
389
+ zip_path = work_dir / output_name
390
+ with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
391
+ for path in paths:
392
+ if path.exists():
393
+ zip_file.write(path, arcname=path.name)
394
+ return str(zip_path)
395
+
396
+
397
+ def _export_outputs(
398
+ work_dir: Path,
399
+ image_paths: list[str],
400
+ predictions: dict[str, Any],
401
+ images_cpu: torch.Tensor,
402
+ input_summary: dict[str, Any],
403
+ runtime_summary: dict[str, Any],
404
+ model_variant: str,
405
+ num_scale_frames: int,
406
+ keyframe_interval: int,
407
+ conf_percentile: float,
408
+ ) -> tuple[str, str, dict[str, Any]]:
409
+ vis_predictions = _prepare_for_visualization(predictions, images_cpu)
410
+
411
+ glb_path = work_dir / "lingbot-map-reconstruction.glb"
412
+ scene = predictions_to_glb(
413
+ vis_predictions,
414
+ conf_thres=conf_percentile,
415
+ show_cam=True,
416
+ target_dir=str(work_dir),
417
+ mask_sky=False,
418
+ )
419
+ scene.export(glb_path)
420
+
421
+ preview_path = Path(_make_preview_strip(images_cpu, work_dir / "preview.png"))
422
+ npz_path = Path(_save_predictions_npz(predictions, work_dir / "predictions.npz"))
423
+
424
+ points_kept, conf_threshold = _count_confident_points(vis_predictions, conf_percentile)
425
+ summary = {
426
+ "model_variant": MODEL_LABELS[model_variant],
427
+ "model_filename": MODEL_FILENAMES[model_variant],
428
+ "frames_used": len(image_paths),
429
+ "num_scale_frames": num_scale_frames,
430
+ "keyframe_interval": keyframe_interval,
431
+ "confidence_percentile": conf_percentile,
432
+ "confidence_threshold": round(conf_threshold, 4),
433
+ "points_kept_for_glb": points_kept,
434
+ "input_summary": input_summary,
435
+ "runtime_summary": runtime_summary,
436
+ }
437
+
438
+ summary_path = work_dir / "summary.json"
439
+ summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
440
+
441
+ artifact_path = _zip_outputs(
442
+ work_dir,
443
+ [glb_path, preview_path, npz_path, summary_path],
444
+ output_name="lingbot-map-results.zip",
445
+ )
446
+ return str(glb_path), artifact_path, summary
447
+
448
+
449
+ def _format_status(summary: dict[str, Any]) -> str:
450
+ runtime = summary["runtime_summary"]
451
+ input_summary = summary["input_summary"]
452
+ lines = [
453
+ "## Run Complete",
454
+ f"- Model: `{summary['model_filename']}`",
455
+ f"- Frames used: `{summary['frames_used']}`",
456
+ f"- Input mode: `{input_summary['input_mode']}`",
457
+ f"- Runtime: `{runtime['runtime_seconds']}s` on `{runtime['device']}`",
458
+ f"- GLB confidence percentile: `{summary['confidence_percentile']}`",
459
+ f"- Points kept for GLB: `{summary['points_kept_for_glb']}`",
460
+ ]
461
+ if runtime.get("peak_memory_gb") is not None:
462
+ lines.append(f"- Peak GPU memory: `{runtime['peak_memory_gb']} GB`")
463
+ if input_summary.get("sample_interval"):
464
+ lines.append(f"- Video sample interval: `every {input_summary['sample_interval']} frame(s)`")
465
+ return "\n".join(lines)
466
+
467
+
468
+ def reconstruct_scene(
469
+ image_files: list[Any],
470
+ video_file: Any,
471
+ model_variant: str,
472
+ fps: int,
473
+ max_frames: int,
474
+ num_scale_frames: int,
475
+ keyframe_interval: int,
476
+ conf_percentile: float,
477
+ ):
478
+ max_frames = max(2, min(int(max_frames), MAX_FRAMES_HARD_LIMIT))
479
+ num_scale_frames = max(1, int(num_scale_frames))
480
+ keyframe_interval = max(1, int(keyframe_interval))
481
+ conf_percentile = float(conf_percentile)
482
+
483
+ images, image_paths, work_dir, input_summary = _prepare_inputs(
484
+ image_files=image_files or [],
485
+ video_file=video_file,
486
+ fps=int(fps),
487
+ max_frames=max_frames,
488
+ )
489
+
490
+ num_scale_frames = min(num_scale_frames, int(images.shape[0]))
491
+ predictions, images_cpu, runtime_summary = _run_inference(
492
+ images,
493
+ model_variant=model_variant,
494
+ num_scale_frames=num_scale_frames,
495
+ keyframe_interval=keyframe_interval,
496
+ )
497
+
498
+ glb_path, artifact_path, summary = _export_outputs(
499
+ work_dir=work_dir,
500
+ image_paths=image_paths,
501
+ predictions=predictions,
502
+ images_cpu=images_cpu,
503
+ input_summary=input_summary,
504
+ runtime_summary=runtime_summary,
505
+ model_variant=model_variant,
506
+ num_scale_frames=num_scale_frames,
507
+ keyframe_interval=keyframe_interval,
508
+ conf_percentile=conf_percentile,
509
+ )
510
+
511
+ preview_path = str(work_dir / "preview.png")
512
+ status = _format_status(summary)
513
+ return glb_path, preview_path, artifact_path, summary, status
514
+
515
+
516
+ def _build_startup_markdown() -> str:
517
+ if not STARTUP_NOTES:
518
+ return (
519
+ "Short-form LingBot-Map Space for Hugging Face ZeroGPU. "
520
+ "It uses the upstream checkpoint files, SDPA inference, and exports a GLB scene plus a zipped results bundle."
521
+ )
522
+ return "\n".join([f"- {note}" for note in STARTUP_NOTES])
523
+
524
+
525
+ CSS = """
526
+ .shell {
527
+ max-width: 1180px;
528
+ margin: 0 auto;
529
+ }
530
+ .headline {
531
+ background: linear-gradient(135deg, #f3ead7 0%, #d6e6d4 100%);
532
+ border: 1px solid #d9ccb3;
533
+ border-radius: 20px;
534
+ padding: 20px 24px;
535
+ }
536
+ .headline h1 {
537
+ margin: 0 0 8px 0;
538
+ color: #14231a;
539
+ }
540
+ .headline p {
541
+ margin: 0;
542
+ color: #304437;
543
+ }
544
+ """
545
+
546
+
547
+ _eager_load_default_model()
548
+
549
+
550
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft(primary_hue="green", secondary_hue="amber"), title="LingBot-Map ZeroGPU Demo") as demo:
551
+ gr.Markdown("<div class='shell'>")
552
+ with gr.Row():
553
+ gr.Image(value=str(ROOT / "assets" / "teaser.png"), show_label=False, interactive=False, min_width=320)
554
+ gr.Markdown(
555
+ """
556
+ <div class="headline">
557
+ <h1>LingBot-Map ZeroGPU Demo</h1>
558
+ <p>Upload ordered images or a short video. The Space samples up to 24 frames, runs the SDPA fallback path, and exports a GLB scene plus a zipped artifact bundle.</p>
559
+ </div>
560
+ """
561
+ )
562
+
563
+ gr.Markdown(_build_startup_markdown())
564
+
565
+ with gr.Row():
566
+ with gr.Column(scale=1):
567
+ image_files = gr.File(
568
+ label="Ordered images",
569
+ file_count="multiple",
570
+ file_types=["image"],
571
+ type="filepath",
572
+ )
573
+ video_file = gr.File(
574
+ label="Or upload one video",
575
+ file_types=["video"],
576
+ type="filepath",
577
+ )
578
+ model_variant = gr.Dropdown(
579
+ choices=[("Balanced", "balanced"), ("Long", "long"), ("Stage-1", "stage1")],
580
+ value="balanced",
581
+ label="Checkpoint",
582
+ )
583
+ fps = gr.Slider(minimum=1, maximum=12, step=1, value=DEFAULT_FPS, label="Video sampling FPS")
584
+ max_frames = gr.Slider(minimum=2, maximum=MAX_FRAMES_HARD_LIMIT, step=1, value=DEFAULT_MAX_FRAMES, label="Max frames")
585
+ num_scale_frames = gr.Slider(minimum=1, maximum=8, step=1, value=DEFAULT_SCALE_FRAMES, label="Scale frames")
586
+ keyframe_interval = gr.Slider(minimum=1, maximum=8, step=1, value=DEFAULT_KEYFRAME_INTERVAL, label="Keyframe interval")
587
+ conf_percentile = gr.Slider(
588
+ minimum=0,
589
+ maximum=90,
590
+ step=5,
591
+ value=DEFAULT_CONF_PERCENTILE,
592
+ label="GLB confidence percentile",
593
+ )
594
+ run_button = gr.Button("Reconstruct Scene", variant="primary")
595
+
596
+ with gr.Column(scale=1):
597
+ model_preview = gr.Model3D(label="3D preview", clear_color=[0.97, 0.94, 0.88, 1.0])
598
+ preview_image = gr.Image(label="Preview strip", interactive=False)
599
+ artifact_file = gr.File(label="Download bundle")
600
+ summary_json = gr.JSON(label="Run summary")
601
+ status_markdown = gr.Markdown()
602
+
603
+ run_button.click(
604
+ fn=reconstruct_scene,
605
+ inputs=[
606
+ image_files,
607
+ video_file,
608
+ model_variant,
609
+ fps,
610
+ max_frames,
611
+ num_scale_frames,
612
+ keyframe_interval,
613
+ conf_percentile,
614
+ ],
615
+ outputs=[
616
+ model_preview,
617
+ preview_image,
618
+ artifact_file,
619
+ summary_json,
620
+ status_markdown,
621
+ ],
622
+ show_progress="full",
623
+ )
624
+ gr.Markdown("</div>")
625
+
626
+ demo.queue(default_concurrency_limit=1)
627
+
628
+
629
+ if __name__ == "__main__":
630
+ demo.launch()
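For a quick look at how the pieces of `app.py` fit together outside the Gradio UI, here is a minimal sketch that calls `reconstruct_scene` directly. It assumes the repo root is on `PYTHONPATH`, that at least two ordered frames exist at the placeholder paths, and that a GPU (or CPU fallback) is available; the first call downloads the selected checkpoint.

```python
# Minimal sketch: drive the pipeline defined in app.py without the Gradio UI.
# The frame paths below are placeholders; importing app also builds (but does not launch) the Blocks UI.
from app import reconstruct_scene

glb_path, preview_path, artifact_path, summary, status = reconstruct_scene(
    image_files=["frames/000000.jpg", "frames/000001.jpg", "frames/000002.jpg"],
    video_file=None,
    model_variant="balanced",      # "balanced" -> lingbot-map.pt (see MODEL_FILENAMES)
    fps=8,
    max_frames=24,
    num_scale_frames=4,
    keyframe_interval=2,
    conf_percentile=50.0,
)
print(status)          # Markdown run report
print(glb_path)        # browser-friendly .glb scene
print(artifact_path)   # zipped bundle: glb + preview.png + predictions.npz + summary.json
```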
assets/teaser.png ADDED

Git LFS Details

  • SHA256: eac66d9961c46782307c784bb0bad0c45af822f817f2ba697e16e1cce50968f2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
lingbot_map/__init__.py ADDED
File without changes
lingbot_map/aggregator/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .stream import AggregatorStream
+ from .base import AggregatorBase
lingbot_map/aggregator/base.py ADDED
@@ -0,0 +1,608 @@
1
+ """
2
+ AggregatorBase - Base class for all Aggregator implementations.
3
+
4
+ Provides shared functionality:
5
+ - Patch embedding (DINOv2)
6
+ - Special tokens (camera, register, scale)
7
+ - Block building
8
+ - Common forward pass structure
9
+
10
+ Subclasses implement mode-specific attention logic.
11
+ """
12
+
13
+ import logging
14
+ import torch
15
+ import torch.nn as nn
16
+ from abc import ABC, abstractmethod
17
+ from typing import Optional, Tuple, List
18
+
19
+ from lingbot_map.layers import PatchEmbed
20
+ from lingbot_map.layers.block import Block
21
+ from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter
22
+ from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
27
+ _RESNET_STD = [0.229, 0.224, 0.225]
28
+
29
+
30
+ def slice_expand_and_flatten(token, B, S, first_num_frame=1):
31
+ """
32
+ Helper function to slice, expand and flatten tokens.
33
+
34
+ Args:
35
+ token: Token tensor [1, 2, N, C] where first index is for first frames
36
+ B: Batch size
37
+ S: Sequence length
38
+ first_num_frame: Number of frames to use first token for
39
+
40
+ Returns:
41
+ Flattened tokens [B*S, N, C]
42
+ """
43
+ # token shape: [1, 2, N, C]
44
+ # Expand to [B, S, N, C]
45
+ if first_num_frame > 1:
46
+ # Use first token for first first_num_frame frames, second for rest
47
+ token_first = token[:, :1].expand(B, first_num_frame, -1, -1) # [B, first_num_frame, N, C]
48
+ token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1) # [B, S-first_num_frame, N, C]
49
+ token_expanded = torch.cat([token_first, token_rest], dim=1) # [B, S, N, C]
50
+ else:
51
+ # Use first token for first frame, second for rest
52
+ token_first = token[:, :1].expand(B, 1, -1, -1) # [B, 1, N, C]
53
+ token_rest = token[:, 1:].expand(B, S - 1, -1, -1) # [B, S-1, N, C]
54
+ token_expanded = torch.cat([token_first, token_rest], dim=1) # [B, S, N, C]
55
+
56
+ # Flatten to [B*S, N, C]
57
+ return token_expanded.reshape(B * S, -1, token.shape[-1])
58
+
59
+
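The `slice_expand_and_flatten` helper above broadcasts a two-row token bank (row 0 for the first frame or frames, row 1 for every later frame) across the batch and sequence before flattening. A quick shape check, assuming the `lingbot_map` package from this commit is importable:

```python
# Shape check for slice_expand_and_flatten (default first_num_frame=1):
# row 0 of the bank goes to frame 0, row 1 to frames 1..S-1, then flatten to [B*S, N, C].
import torch
from lingbot_map.aggregator.base import slice_expand_and_flatten

B, S, N, C = 2, 6, 5, 16
token_bank = torch.randn(1, 2, N, C)
flat = slice_expand_and_flatten(token_bank, B, S)
assert flat.shape == (B * S, N, C)  # (12, 5, 16)
```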
60
+ class AggregatorBase(nn.Module, ABC):
61
+ """
62
+ Base class for all Aggregator implementations.
63
+
64
+ Handles shared components:
65
+ - Patch embedding (DINOv2 or conv)
66
+ - Special tokens (camera, register, optionally scale)
67
+ - Block creation (frame + global)
68
+ - RoPE (2D rotary position embeddings)
69
+ - Common forward pass scaffolding
70
+
71
+ Subclasses must implement:
72
+ - _process_global_attention(): Mode-specific cross-frame attention logic
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ # Architecture parameters
78
+ img_size=518,
79
+ patch_size=14,
80
+ embed_dim=1024,
81
+ depth=24,
82
+ num_heads=16,
83
+ mlp_ratio=4.0,
84
+ num_register_tokens=4,
85
+ # Block configuration
86
+ block_fn=Block,
87
+ qkv_bias=True,
88
+ proj_bias=True,
89
+ ffn_bias=True,
90
+ qk_norm=True,
91
+ init_values=0.01,
92
+ # Patch embedding
93
+ patch_embed="dinov2_vitl14_reg",
94
+ pretrained_path=None,
95
+ # Attention pattern
96
+ aa_order=["frame", "global"],
97
+ aa_block_size=1,
98
+ # RoPE
99
+ rope_freq=100,
100
+ disable_global_rope=False,
101
+ # Gradient checkpointing
102
+ use_reentrant: bool = False,
103
+ use_gradient_checkpoint: bool = True,
104
+ ):
105
+ super().__init__()
106
+
107
+ # Store configuration
108
+ self.img_size = img_size
109
+ self.patch_size = patch_size
110
+ self.embed_dim = embed_dim
111
+ self.depth = depth
112
+ self.num_heads = num_heads
113
+ self.mlp_ratio = mlp_ratio
114
+ self.num_register_tokens = num_register_tokens
115
+ self.aa_order = aa_order
116
+ self.aa_block_size = aa_block_size
117
+ self.disable_global_rope = disable_global_rope
118
+ self.use_reentrant = use_reentrant
119
+ self.use_gradient_checkpoint = use_gradient_checkpoint
120
+ self.pretrained_path = pretrained_path
121
+ self.enable_ulysses_cp = False # CP disabled
122
+
123
+ print("pretrained_path:", self.pretrained_path)
124
+
125
+ # Validate depth
126
+ if self.depth % self.aa_block_size != 0:
127
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
128
+ self.aa_block_num = self.depth // self.aa_block_size
129
+
130
+ # Build patch embedding
131
+ self._build_patch_embed(
132
+ patch_embed=patch_embed,
133
+ img_size=img_size,
134
+ patch_size=patch_size,
135
+ num_register_tokens=num_register_tokens,
136
+ embed_dim=embed_dim,
137
+ pretrained_path=pretrained_path
138
+ )
139
+
140
+ # Initialize RoPE
141
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
142
+ self.position_getter = PositionGetter() if self.rope is not None else None
143
+
144
+ # Build blocks (frame + global)
145
+ self._build_blocks(
146
+ block_fn=block_fn,
147
+ depth=depth,
148
+ embed_dim=embed_dim,
149
+ num_heads=num_heads,
150
+ mlp_ratio=mlp_ratio,
151
+ qkv_bias=qkv_bias,
152
+ proj_bias=proj_bias,
153
+ ffn_bias=ffn_bias,
154
+ init_values=init_values,
155
+ qk_norm=qk_norm,
156
+ )
157
+
158
+ # Setup special tokens (camera, register, optionally scale)
159
+ self._setup_special_tokens()
160
+
161
+ # Register normalization constants
162
+ for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
163
+ self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)
164
+
165
+ # Initialize from DINO checkpoint if available
166
+ if hasattr(self, '_dino_checkpoint') and self._dino_checkpoint is not None:
167
+ self._init_blocks_from_dino(self._dino_checkpoint)
168
+ del self._dino_checkpoint # Free memory
169
+
170
+ def _build_patch_embed(
171
+ self,
172
+ patch_embed: str,
173
+ img_size: int,
174
+ patch_size: int,
175
+ num_register_tokens: int,
176
+ embed_dim: int,
177
+ pretrained_path: str,
178
+ interpolate_antialias=True,
179
+ interpolate_offset=0.0,
180
+ block_chunks=0,
181
+ init_values=1.0,
182
+ ):
183
+ """
184
+ Build patch embedding layer.
185
+
186
+ Supports:
187
+ - "conv": Simple convolutional patch embedding
188
+ - "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2)
189
+ """
190
+ if "conv" in patch_embed:
191
+ self.patch_embed = PatchEmbed(
192
+ img_size=img_size,
193
+ patch_size=patch_size,
194
+ in_chans=3,
195
+ embed_dim=embed_dim
196
+ )
197
+ self._dino_checkpoint = None
198
+
199
+ else:
200
+ vit_models = {
201
+ "dinov2_vitl14_reg": vit_large,
202
+ "dinov2_vitb14_reg": vit_base,
203
+ "dinov2_vits14_reg": vit_small,
204
+ "dinov2_vitg2_reg": vit_giant2,
205
+ }
206
+
207
+ if patch_embed not in vit_models:
208
+ raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}")
209
+
210
+ self.patch_embed = vit_models[patch_embed](
211
+ img_size=img_size,
212
+ patch_size=patch_size,
213
+ num_register_tokens=num_register_tokens,
214
+ interpolate_antialias=interpolate_antialias,
215
+ interpolate_offset=interpolate_offset,
216
+ block_chunks=block_chunks,
217
+ init_values=init_values,
218
+ )
219
+
220
+ # Load pretrained weights
221
+ try:
222
+ ckpt = torch.load(pretrained_path)
223
+ del ckpt['pos_embed']
224
+ logger.info("Loading pretrained weights for DINOv2")
225
+ missing, unexpected = self.patch_embed.load_state_dict(ckpt, strict=False)
226
+ logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
227
+
228
+ # Store checkpoint for block initialization
229
+ self._dino_checkpoint = ckpt
230
+ except Exception as e:
231
+ logger.warning(f"Failed to load pretrained weights: {e}")
232
+ self._dino_checkpoint = None
233
+
234
+ # Disable gradients for mask token
235
+ if hasattr(self.patch_embed, "mask_token"):
236
+ self.patch_embed.mask_token.requires_grad_(False)
237
+
238
+ @abstractmethod
239
+ def _build_blocks(
240
+ self,
241
+ block_fn,
242
+ depth: int,
243
+ embed_dim: int,
244
+ num_heads: int,
245
+ mlp_ratio: float,
246
+ qkv_bias: bool,
247
+ proj_bias: bool,
248
+ ffn_bias: bool,
249
+ init_values: float,
250
+ qk_norm: bool,
251
+ ):
252
+ """
253
+ Build frame_blocks and global_blocks.
254
+
255
+ Subclasses implement mode-specific block creation.
256
+
257
+ Must create:
258
+ - self.frame_blocks: nn.ModuleList of frame attention blocks
259
+ - self.global_blocks: nn.ModuleList of global attention blocks
260
+ """
261
+ pass
262
+
263
+ @abstractmethod
264
+ def _setup_special_tokens(self):
265
+ """
266
+ Setup camera token, register tokens, and optionally scale token.
267
+
268
+ Subclasses implement mode-specific token initialization.
269
+
270
+ Must create:
271
+ - self.camera_token
272
+ - self.register_token
273
+ - self.scale_token (optional, for causal mode)
274
+ - self.patch_start_idx
275
+ - self.num_special_tokens
276
+ """
277
+ pass
278
+
279
+ def _init_blocks_from_dino(self, dino_ckpt: dict):
280
+ """
281
+ Initialize frame_blocks and global_blocks from DINOv2 pretrained weights.
282
+
283
+ Args:
284
+ dino_ckpt: Checkpoint dictionary from DINOv2 model
285
+ """
286
+ logger.info("Initializing blocks from DINOv2 pretrained weights")
287
+
288
+ # Extract block keys
289
+ dino_block_keys = [k for k in dino_ckpt.keys() if k.startswith('blocks.')]
290
+ if not dino_block_keys:
291
+ logger.warning("No 'blocks' found in DINO checkpoint")
292
+ return
293
+
294
+ # Get block indices
295
+ block_indices = set()
296
+ for key in dino_block_keys:
297
+ parts = key.split('.')
298
+ if len(parts) > 1 and parts[1].isdigit():
299
+ block_indices.add(int(parts[1]))
300
+
301
+ num_dino_blocks = len(block_indices)
302
+ print(f"Found {num_dino_blocks} blocks in DINO checkpoint")
303
+
304
+ # Initialize frame_blocks
305
+ for i, block in enumerate(self.frame_blocks):
306
+ dino_block_idx = i % num_dino_blocks
307
+ block_state_dict = {}
308
+ prefix = f'blocks.{dino_block_idx}.'
309
+ for key, value in dino_ckpt.items():
310
+ if key.startswith(prefix):
311
+ new_key = key[len(prefix):]
312
+ block_state_dict[new_key] = value
313
+
314
+ if block_state_dict:
315
+ missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
316
+ if i == 0: # Only log for first block to avoid spam
317
+ print(f"Frame block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
318
+
319
+ # Initialize global_blocks
320
+ for i, block in enumerate(self.global_blocks):
321
+ dino_block_idx = i % num_dino_blocks
322
+ block_state_dict = {}
323
+ prefix = f'blocks.{dino_block_idx}.'
324
+ for key, value in dino_ckpt.items():
325
+ if key.startswith(prefix):
326
+ new_key = key[len(prefix):]
327
+ block_state_dict[new_key] = value
328
+
329
+ if block_state_dict:
330
+ missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
331
+ if i == 0: # Only log for first block to avoid spam
332
+ print(f"Global block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
333
+
334
+ logger.info("Successfully initialized blocks from DINOv2 weights")
335
+
336
+ def _embed_images(
337
+ self,
338
+ images: torch.Tensor,
339
+ num_frame_for_scale: Optional[int] = None,
340
+ ) -> Tuple[torch.Tensor, int, int, int, int, int]:
341
+ """
342
+ Embed images and prepare for attention processing.
343
+
344
+ Handles:
345
+ - Image normalization
346
+ - Patch embedding
347
+ - Special token concatenation
348
+ - Position embedding
349
+
350
+ Args:
351
+ images: Input images [B, S, 3, H, W] in range [0, 1]
352
+ num_frame_for_scale: Number of frames for scale estimation (passed to special tokens)
353
+
354
+ Returns:
355
+ (tokens, B, S, S, P, C):
356
+ tokens: Embedded tokens [B*S, P, C]
357
+ B: Batch size
358
+ S: Sequence length
359
+ S: Same as above (no CP slicing)
360
+ P: Number of tokens per frame (patches + special tokens)
361
+ C: Embedding dimension
362
+ """
363
+ B, S, C_in, H, W = images.shape
364
+
365
+ if C_in != 3:
366
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
367
+
368
+ # Normalize images
369
+ images = (images - self._resnet_mean) / self._resnet_std
370
+
371
+ # No CP slicing: S_local == S_global
372
+ S_local = S
373
+ S_global = S
374
+
375
+ # Reshape for patch embedding [B*S, C, H, W]
376
+ images = images.view(B * S, C_in, H, W)
377
+
378
+ # Patch embedding
379
+ patch_tokens = self.patch_embed(images)
380
+ if isinstance(patch_tokens, dict):
381
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
382
+
383
+ _, P_patch, C = patch_tokens.shape
384
+
385
+ # Prepare special tokens
386
+ special_tokens = self._prepare_special_tokens(
387
+ B, S_local, S_global, C,
388
+ num_frame_for_scale=num_frame_for_scale
389
+ )
390
+
391
+ # Concatenate special tokens + patch tokens
392
+ tokens = torch.cat([special_tokens, patch_tokens], dim=1)
393
+
394
+ _, P, C = tokens.shape
395
+
396
+ return tokens, B, S_local, S_global, P, C
397
+
398
+ @abstractmethod
399
+ def _prepare_special_tokens(self, B: int, S_local: int, S_global: int, C: int, **kwargs) -> torch.Tensor:
400
+ """
401
+ Prepare special tokens (camera, register, optionally scale).
402
+
403
+ Subclasses implement mode-specific token preparation.
404
+
405
+ Args:
406
+ B: Batch size
407
+ S_local: Local sequence length
408
+ S_global: Global sequence length
409
+ C: Embedding dimension
410
+ **kwargs: Mode-specific parameters (e.g., num_frame_for_scale for causal mode)
411
+
412
+ Returns:
413
+ Special tokens [B*S, N_special, C]
414
+ """
415
+ pass
416
+
417
+ def _get_positions(self, B: int, S: int, H: int, W: int, device) -> Optional[torch.Tensor]:
418
+ """
419
+ Get 2D position embeddings for RoPE.
420
+
421
+ Args:
422
+ B: Batch size
423
+ S: Sequence length
424
+ H: Image height
425
+ W: Image width
426
+ device: Device to create positions on
427
+
428
+ Returns:
429
+ Position tensor [B*S, P, 2] or None if rope is disabled
430
+ """
431
+ if self.rope is None:
432
+ return None
433
+
434
+ # Get patch positions
435
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
436
+
437
+ # Add offset for patch tokens (skip special tokens at pos=0)
438
+ if self.patch_start_idx > 0:
439
+ pos = pos + 1
440
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2, dtype=pos.dtype, device=device)
441
+ pos = torch.cat([pos_special, pos], dim=1)
442
+
443
+ return pos
444
+
445
+ def _process_frame_attention(
446
+ self,
447
+ tokens: torch.Tensor,
448
+ B: int,
449
+ S: int,
450
+ P: int,
451
+ C: int,
452
+ frame_idx: int,
453
+ pos: Optional[torch.Tensor] = None,
454
+ ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
455
+ """
456
+ Process frame attention blocks.
457
+
458
+ Frame attention operates independently per frame (no cross-frame communication).
459
+ Tokens stay in shape [B*S, P, C].
460
+
461
+ Args:
462
+ tokens: Input tokens [B*S, P, C]
463
+ B: Batch size
464
+ S: Sequence length
465
+ P: Tokens per frame
466
+ C: Embedding dimension
467
+ frame_idx: Current frame block index
468
+ pos: Position embeddings [B*S, P, 2]
469
+
470
+ Returns:
471
+ (tokens, frame_idx, intermediates):
472
+ tokens: Output tokens [B*S, P, C]
473
+ frame_idx: Updated frame block index
474
+ intermediates: List of intermediate outputs [B, S, P, C]
475
+ """
476
+ # Ensure correct shape
477
+ if tokens.shape != (B * S, P, C):
478
+ tokens = tokens.view(B * S, P, C)
479
+
480
+ if pos is not None and pos.shape != (B * S, P, 2):
481
+ pos = pos.view(B * S, P, 2)
482
+
483
+ intermediates = []
484
+
485
+ # Process blocks
486
+ for i in range(self.aa_block_size):
487
+ if self.training and self.use_gradient_checkpoint:
488
+ from torch.utils.checkpoint import checkpoint
489
+ tokens = checkpoint(
490
+ self.frame_blocks[frame_idx],
491
+ tokens,
492
+ pos,
493
+ False, # enable_ulysses_cp (always False)
494
+ use_reentrant=self.use_reentrant
495
+ )
496
+ else:
497
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos, enable_ulysses_cp=False)
498
+
499
+ frame_idx += 1
500
+ intermediates.append(tokens.view(B, S, P, C))
501
+
502
+ return tokens, frame_idx, intermediates
503
+
504
+ @abstractmethod
505
+ def _process_global_attention(
506
+ self,
507
+ tokens: torch.Tensor,
508
+ B: int,
509
+ S_local: int,
510
+ S_global: int,
511
+ P: int,
512
+ C: int,
513
+ global_idx: int,
514
+ pos: Optional[torch.Tensor] = None,
515
+ **kwargs
516
+ ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
517
+ """
518
+ Process global (cross-frame) attention blocks.
519
+
520
+ Subclasses implement mode-specific attention logic.
521
+
522
+ Args:
523
+ tokens: Input tokens
524
+ B: Batch size
525
+ S_local: Local sequence length
526
+ S_global: Global sequence length
527
+ P: Tokens per frame
528
+ C: Embedding dimension
529
+ global_idx: Current global block index
530
+ pos: Position embeddings
531
+ **kwargs: Mode-specific parameters
532
+
533
+ Returns:
534
+ (tokens, global_idx, intermediates):
535
+ tokens: Output tokens
536
+ global_idx: Updated global block index
537
+ intermediates: List of intermediate outputs
538
+ """
539
+ pass
540
+
541
+ def forward(
542
+ self,
543
+ images: torch.Tensor,
544
+ selected_idx: Optional[List[int]] = None,
545
+ # Mode-specific parameters
546
+ num_frame_for_scale: Optional[int] = None,
547
+ sliding_window_size: Optional[int] = None,
548
+ num_frame_per_block: int = 1,
549
+ ) -> Tuple[List[torch.Tensor], int]:
550
+ """
551
+ Forward pass.
552
+
553
+ Args:
554
+ images: Input images [B, S, 3, H, W] in range [0, 1]
555
+ selected_idx: Which block indices to output (None = all)
556
+ num_frame_for_scale: Number of frames for scale estimation (causal mode)
557
+ sliding_window_size: Sliding window size in blocks (causal mode)
558
+ num_frame_per_block: Number of frames per processing block (causal mode)
559
+
560
+ Returns:
561
+ (output_list, patch_start_idx):
562
+ output_list: List of block outputs [B, S, P, 2C]
563
+ patch_start_idx: Index where patch tokens start
564
+ """
565
+ B, S_input, _, H, W = images.shape
566
+
567
+ # Embed images
568
+ tokens, B, S_local, S_global, P, C = self._embed_images(
569
+ images,
570
+ num_frame_for_scale=num_frame_for_scale,
571
+ )
572
+
573
+ # Get position embeddings
574
+ pos_local = self._get_positions(B, S_local, H, W, device=images.device)
575
+ pos_global = self._get_positions(B, S_global, H, W, device=images.device)
576
+
577
+ # Alternating attention
578
+ frame_idx = 0
579
+ global_idx = 0
580
+ output_list = []
581
+
582
+ for block_group_idx in range(self.aa_block_num):
583
+ for attn_type in self.aa_order:
584
+ if attn_type == "frame":
585
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
586
+ tokens, B, S_local, P, C, frame_idx, pos=pos_local
587
+ )
588
+ elif attn_type == "global":
589
+ tokens, global_idx, global_intermediates = self._process_global_attention(
590
+ tokens, B, S_local, S_global, P, C, global_idx,
591
+ pos=pos_global,
592
+ num_frame_for_scale=num_frame_for_scale,
593
+ sliding_window_size=sliding_window_size,
594
+ num_frame_per_block=num_frame_per_block,
595
+ image_height=H,
596
+ image_width=W,
597
+ )
598
+ else:
599
+ raise ValueError(f"Unknown attention type: {attn_type}")
600
+
601
+ # Collect outputs
602
+ if selected_idx is None or block_group_idx in selected_idx:
603
+ for i in range(len(frame_intermediates)):
604
+ # Concatenate frame and global intermediates [B, S, P, 2C]
605
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
606
+ output_list.append(concat_inter)
607
+
608
+ return output_list, self.patch_start_idx
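As a usage sketch (illustrative shapes only, not tied to a real checkpoint), downstream heads split each output entry into the camera token and the dense patch tokens using `patch_start_idx`:

    import torch

    # Dummy stand-ins following the forward() docstring: B=1, S=2 frames,
    # 5 special tokens + 16 patch tokens per frame, embedding dim C=32.
    B, S, num_special, P_patch, C = 1, 2, 5, 16, 32
    patch_start_idx = num_special
    output_list = [torch.randn(B, S, num_special + P_patch, 2 * C) for _ in range(4)]

    last = output_list[-1]                        # [B, S, P, 2C]: frame + global features concatenated
    camera_tokens = last[:, :, 0]                 # [B, S, 2C], the camera token is the first special token
    patch_tokens = last[:, :, patch_start_idx:]   # [B, S, P_patch, 2C], consumed by DPT-style heads
    print(camera_tokens.shape, patch_tokens.shape)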
lingbot_map/aggregator/stream.py ADDED
@@ -0,0 +1,531 @@
1
+ """
2
+ AggregatorStream - Streaming causal aggregator with FlashInfer KV cache.
3
+
4
+ Provides:
5
+ - Temporal causal attention
6
+ - Sliding window support
7
+ - Scale token for scale estimation frames
8
+ - Streaming inference with FlashInfer paged KV cache
9
+ """
10
+
11
+ import logging
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing import Optional, Tuple, List
15
+
16
+ from lingbot_map.layers.block import Block, FlashInferBlock, SDPABlock
17
+ from lingbot_map.layers.rope import WanRotaryPosEmbed
18
+ from lingbot_map.aggregator.base import AggregatorBase, slice_expand_and_flatten
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class AggregatorStream(AggregatorBase):
24
+ """
25
+ Streaming causal aggregator with FlashInfer paged KV cache.
26
+
27
+ Features:
28
+ - Temporal causal attention (each frame only attends to past frames)
29
+ - Sliding window support to limit attention scope
30
+ - Scale token for scale estimation frames
31
+ - Streaming inference with FlashInfer KV cache
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ # Causal-specific parameters
37
+ sliding_window_size: int = -1,
38
+ num_frame_for_scale: int = 1,
39
+ num_random_frames: int = 0,
40
+ attend_to_special_tokens: bool = False,
41
+ attend_to_scale_frames: bool = False,
42
+ enable_3d_rope: bool = False,
43
+ max_frame_num: int = 1024,
44
+ # KV cache parameters
45
+ kv_cache_sliding_window: int = 64,
46
+ kv_cache_scale_frames: int = 8,
47
+ kv_cache_cross_frame_special: bool = True,
48
+ kv_cache_include_scale_frames: bool = True,
49
+ kv_cache_camera_only: bool = False,
50
+ # Base class parameters via **kwargs
51
+ **kwargs
52
+ ):
53
+ """
54
+ Initialize AggregatorStream.
55
+
56
+ Args:
57
+ sliding_window_size: Sliding window size in blocks (-1 for full causal)
58
+ num_frame_for_scale: Number of scale estimation frames
59
+ num_random_frames: Number of random frames for long-range dependencies
60
+ attend_to_special_tokens: Enable cross-frame special token attention
61
+ attend_to_scale_frames: Include scale frames in attention
62
+ enable_3d_rope: Enable 3D RoPE for temporal dimension in KV cache
63
+ max_frame_num: Maximum number of frames for 3D RoPE
64
+ kv_cache_sliding_window: Sliding window size for KV cache eviction
65
+ kv_cache_scale_frames: Number of scale frames to keep in KV cache
66
+ kv_cache_cross_frame_special: Keep special tokens from evicted frames
67
+ kv_cache_include_scale_frames: Include scale frames in KV cache
68
+ kv_cache_camera_only: Only keep camera tokens from evicted frames
69
+ **kwargs: Base class parameters
70
+ """
71
+ self.sliding_window_size = sliding_window_size
72
+ self.num_frame_for_scale = num_frame_for_scale
73
+ self.num_random_frames = num_random_frames
74
+ self.attend_to_special_tokens = attend_to_special_tokens
75
+ self.attend_to_scale_frames = attend_to_scale_frames
76
+ self.enable_3d_rope = enable_3d_rope
77
+ self.max_frame_num = max_frame_num
78
+ # KV cache parameters
79
+ self.kv_cache_sliding_window = kv_cache_sliding_window
80
+ self.kv_cache_scale_frames = kv_cache_scale_frames
81
+ self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
82
+ self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
83
+ self.kv_cache_camera_only = kv_cache_camera_only
84
+
85
+ # Pop kwargs that are passed but not needed by base class
86
+ kwargs.pop('enable_stream_inference', None)
87
+ use_flashinfer = kwargs.pop('use_flashinfer', True)
88
+ kwargs.pop('use_flexflash', None)
89
+ use_sdpa = kwargs.pop('use_sdpa', False)
90
+
91
+ # Backend selection: SDPA (no extra deps) or FlashInfer (paged KV cache)
92
+ self.use_sdpa = use_sdpa
93
+ self.use_flashinfer = not use_sdpa # FlashInfer is default unless SDPA requested
94
+
95
+ # Call parent __init__
96
+ super().__init__(**kwargs)
97
+
98
+ # Initialize KV cache
99
+ self._init_kv_cache()
100
+
101
+ # Initialize 3D RoPE if enabled
102
+ if self.enable_3d_rope:
103
+ self._init_3d_rope()
104
+
105
+ def _build_blocks(
106
+ self,
107
+ block_fn,
108
+ depth: int,
109
+ embed_dim: int,
110
+ num_heads: int,
111
+ mlp_ratio: float,
112
+ qkv_bias: bool,
113
+ proj_bias: bool,
114
+ ffn_bias: bool,
115
+ init_values: float,
116
+ qk_norm: bool,
117
+ ):
118
+ """Build frame and global blocks for streaming causal mode."""
119
+ block_params = dict(
120
+ dim=embed_dim,
121
+ num_heads=num_heads,
122
+ mlp_ratio=mlp_ratio,
123
+ qkv_bias=qkv_bias,
124
+ proj_bias=proj_bias,
125
+ ffn_bias=ffn_bias,
126
+ init_values=init_values,
127
+ qk_norm=qk_norm,
128
+ )
129
+
130
+ # Frame blocks: Standard Block + RoPE
131
+ self.frame_blocks = nn.ModuleList([
132
+ block_fn(**block_params, rope=self.rope)
133
+ for _ in range(depth)
134
+ ])
135
+
136
+ # Global blocks: FlashInferBlock (default) or SDPABlock (fallback)
137
+ GlobalBlockCls = SDPABlock if self.use_sdpa else FlashInferBlock
138
+ self.global_blocks = nn.ModuleList([
139
+ GlobalBlockCls(
140
+ **block_params,
141
+ rope=self.rope if not self.disable_global_rope else None,
142
+ kv_cache_sliding_window=self.kv_cache_sliding_window,
143
+ kv_cache_scale_frames=self.kv_cache_scale_frames,
144
+ kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
145
+ kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
146
+ kv_cache_camera_only=self.kv_cache_camera_only,
147
+ )
148
+ for _ in range(depth)
149
+ ])
150
+
151
+ def _setup_special_tokens(self):
152
+ """Setup camera, register, and scale tokens for causal mode."""
153
+ # Camera token
154
+ self.camera_token = nn.Parameter(
155
+ torch.randn(1, 2, 1, self.embed_dim)
156
+ )
157
+
158
+ # Register tokens
159
+ if self.num_register_tokens > 0:
160
+ self.register_token = nn.Parameter(
161
+ torch.randn(1, 2, self.num_register_tokens, self.embed_dim)
162
+ )
163
+
164
+ # Scale token (causal mode specific)
165
+ self.scale_token = nn.Parameter(
166
+ torch.ones(1, 2, 1, self.embed_dim)
167
+ )
168
+
169
+ # Initialize
170
+ nn.init.normal_(self.camera_token, std=1e-6)
171
+ if self.num_register_tokens > 0:
172
+ nn.init.normal_(self.register_token, std=1e-6)
173
+ nn.init.normal_(self.scale_token, std=1e-6)
174
+
175
+ # Token indexing (includes scale token)
176
+ self.patch_start_idx = 1 + self.num_register_tokens + 1 # camera + register + scale
177
+ self.num_special_tokens = 1 + self.num_register_tokens + 1
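Concretely, with, say, four register tokens the per-frame token order is camera, registers, scale, then patches (a worked example of the indexing above, not new behavior):

    num_register_tokens = 4
    patch_start_idx = 1 + num_register_tokens + 1   # camera + registers + scale -> 6
    # per-frame layout: [camera, reg0..reg3, scale, patch0, patch1, ...]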
178
+
179
+ def _init_kv_cache(self):
180
+ """Initialize KV cache for streaming inference."""
181
+ self.kv_cache_manager = None # FlashInfer (lazy-initialized)
182
+ self.kv_cache = {} # Dict-based cache for SDPA
183
+ self.total_frames_processed = 0
184
+ self._cached_pos3d = None
185
+
186
+ if self.use_sdpa:
187
+ # Dict-based KV cache for SDPA
188
+ if hasattr(self, 'depth'):
189
+ for i in range(self.depth):
190
+ self.kv_cache[f"k_{i}"] = None
191
+ self.kv_cache[f"v_{i}"] = None
192
+ self.kv_cache[f"k_{i}_special"] = None
193
+ self.kv_cache[f"v_{i}_special"] = None
194
+ logger.info(f"SDPA KV cache initialized with {self.depth} blocks")
195
+ else:
196
+ logger.info("FlashInfer KV cache will be lazily initialized on first forward")
197
+
198
+ def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None):
199
+ """Lazily initialize FlashInferKVCacheManager on first use.
200
+
201
+ Args:
202
+ device: Device for cache tensors.
203
+ dtype: Data type for cache tensors.
204
+ tokens_per_frame: Actual number of tokens per frame (patches + specials).
205
+ If None, falls back to assuming square images of self.img_size.
206
+ """
207
+ if self.kv_cache_manager is None:
208
+ from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager
209
+ num_heads = self.embed_dim // 64 # head_dim = 64 for ViT-L
210
+ head_dim = 64
211
+ if tokens_per_frame is None:
212
+ tokens_per_frame = (self.img_size // self.patch_size) ** 2 + self.num_special_tokens
213
+ # max_num_frames: scale + window + headroom
214
+ max_num_frames = self.kv_cache_scale_frames + self.kv_cache_sliding_window + 16
215
+ self.kv_cache_manager = FlashInferKVCacheManager(
216
+ num_blocks=self.depth,
217
+ max_num_frames=max_num_frames,
218
+ tokens_per_frame=tokens_per_frame,
219
+ num_heads=num_heads,
220
+ head_dim=head_dim,
221
+ dtype=dtype,
222
+ device=device,
223
+ num_special_tokens=self.num_special_tokens,
224
+ scale_frames=self.kv_cache_scale_frames,
225
+ sliding_window=self.kv_cache_sliding_window,
226
+ max_total_frames=self.max_frame_num + 100,
227
+ force_fp32=getattr(self, 'kv_cache_force_fp32', False),
228
+ fa3=getattr(self, 'kv_cache_fa3', False),
229
+ )
230
+ logger.info(
231
+ f"FlashInfer KV cache manager initialized: {self.depth} blocks, "
232
+ f"max_frames={max_num_frames}, tokens_per_frame={tokens_per_frame}"
233
+ )
234
+ return self.kv_cache_manager
235
+
236
+ def clean_kv_cache(self):
237
+ """Clean KV cache (call this when starting a new sequence)."""
238
+ if self.kv_cache_manager is not None:
239
+ self.kv_cache_manager.reset()
240
+ if self.kv_cache:
241
+ for key in list(self.kv_cache.keys()):
242
+ if key == "_skip_append":
243
+ self.kv_cache[key] = False
244
+ else:
245
+ self.kv_cache[key] = None
246
+ self.total_frames_processed = 0
247
+ self._cached_pos3d = None
248
+ logger.info("KV cache cleaned")
249
+
250
+ def _init_3d_rope(self):
251
+ """Initialize 3D RoPE for streaming inference."""
252
+ if not self.enable_3d_rope:
253
+ self.rope3d = None
254
+ return
255
+
256
+ num_heads = 16
257
+ head_dim = self.embed_dim // num_heads
258
+
259
+ self.rope3d = WanRotaryPosEmbed(
260
+ attention_head_dim=head_dim,
261
+ patch_size=(1, self.patch_size, self.patch_size),
262
+ max_seq_len=self.max_frame_num,
263
+ )
264
+ logger.info(f"3D RoPE initialized for max {self.max_frame_num} frames, head_dim={head_dim}")
265
+
266
+ def _get_3d_positions_streaming(self, num_frames, H, W, device, f_start, f_end):
267
+ """
268
+ Generate 3D RoPE positions for streaming mode with correct global frame indices.
269
+
270
+ Args:
271
+ num_frames: Number of frames in current batch
272
+ H, W: Image height and width
273
+ device: Device to create positions on
274
+ f_start: Global start frame index
275
+ f_end: Global end frame index
276
+
277
+ Returns:
278
+ pos3d: [1, 1, num_frames * P, head_dim//2] complex tensor
279
+ """
280
+ if self.rope3d is None:
281
+ return None
282
+
283
+ pph = H // self.patch_size
284
+ ppw = W // self.patch_size
285
+
286
+ pos3d = self.rope3d(
287
+ ppf=num_frames,
288
+ pph=pph,
289
+ ppw=ppw,
290
+ patch_start_idx=self.num_special_tokens,
291
+ device=device,
292
+ f_start=f_start,
293
+ f_end=f_end
294
+ )
295
+ return pos3d
296
+
297
+ def _prepare_special_tokens(
298
+ self,
299
+ B: int,
300
+ S_local: int,
301
+ S_global: int,
302
+ C: int,
303
+ num_frame_for_scale: Optional[int] = None,
304
+ ) -> torch.Tensor:
305
+ """
306
+ Prepare camera, register, and scale tokens.
307
+
308
+ Args:
309
+ B: Batch size
310
+ S_local: Local sequence length
311
+ S_global: Global sequence length
312
+ C: Embedding dimension
313
+ num_frame_for_scale: Number of frames for scale estimation
314
+
315
+ Returns:
316
+ Special tokens [B*S_global, N_special, C]
317
+ """
318
+ # Get effective num_frame_for_scale
319
+ scale_frames = self.num_frame_for_scale if num_frame_for_scale is None else num_frame_for_scale
320
+
321
+ # Check cache state for both backends
322
+ has_flashinfer_cache = self.kv_cache_manager is not None and self.kv_cache_manager.num_frames > 0
323
+ has_sdpa_cache = self.kv_cache is not None and self.kv_cache.get("k_0") is not None
324
+
325
+ # The streaming aggregator always runs in causal inference mode; the cache checks above only decide how tokens are expanded below
326
+ causal_inference = True
327
+
328
+ if causal_inference and has_flashinfer_cache:
329
+ S_cached = self.kv_cache_manager.num_frames
330
+ S_true = S_cached + S_global
331
+ elif causal_inference and has_sdpa_cache:
332
+ _, _, S_cached, _, _ = self.kv_cache["k_0"].shape
333
+ S_true = S_cached + S_global
334
+ else:
335
+ S_true = S_global
336
+
337
+ # Expand tokens based on mode
338
+ if causal_inference and S_true > S_global:
339
+ # Streaming mode: expand with S_true, then slice to get current frames
340
+ effective_scale_frames = min(scale_frames, S_true)
341
+
342
+ camera_token_full = slice_expand_and_flatten(self.camera_token, B, S_true)
343
+ camera_token = camera_token_full[-S_global:, :, :]
344
+
345
+ register_token_full = slice_expand_and_flatten(self.register_token, B, S_true)
346
+ register_token = register_token_full[-S_global:, :, :]
347
+ scale_token_full = slice_expand_and_flatten(
348
+ self.scale_token, B, S_true, first_num_frame=effective_scale_frames
349
+ )
350
+ scale_token = scale_token_full[-S_global:, :, :]
351
+ else:
352
+ # Batch mode or first inference: expand directly
353
+ effective_scale_frames = min(scale_frames, S_global)
354
+
355
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S_global)
356
+ register_token = slice_expand_and_flatten(self.register_token, B, S_global)
357
+ scale_token = slice_expand_and_flatten(
358
+ self.scale_token, B, S_global, first_num_frame=effective_scale_frames
359
+ )
360
+
361
+ special_tokens = torch.cat([camera_token, register_token, scale_token], dim=1)
362
+
363
+ # Verify shape
364
+ expected_shape = (B * S_global, self.num_special_tokens, C)
365
+ assert special_tokens.shape == expected_shape, \
366
+ f"Expected {expected_shape}, got {special_tokens.shape}"
367
+
368
+ return special_tokens
369
+
370
+ def _process_global_attention(
371
+ self,
372
+ tokens: torch.Tensor,
373
+ B: int,
374
+ S_local: int,
375
+ S_global: int,
376
+ P: int,
377
+ C: int,
378
+ global_idx: int,
379
+ pos: Optional[torch.Tensor] = None,
380
+ # Mode-specific parameters
381
+ num_frame_for_scale: Optional[int] = None,
382
+ sliding_window_size: Optional[int] = None,
383
+ num_frame_per_block: int = 1,
384
+ **kwargs,
385
+ ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
386
+ """
387
+ Process causal global attention via FlashInfer streaming path.
388
+
389
+ Args:
390
+ tokens: Input tokens
391
+ B: Batch size
392
+ S_local: Local sequence length
393
+ S_global: Global sequence length
394
+ P: Tokens per frame
395
+ C: Embedding dimension
396
+ global_idx: Current global block index
397
+ pos: Position embeddings
398
+ num_frame_for_scale: Number of frames for scale estimation
399
+ sliding_window_size: Sliding window size in blocks
400
+ num_frame_per_block: Number of frames per processing block
401
+
402
+ Returns:
403
+ (tokens, global_idx, intermediates)
404
+ """
405
+ # Extract image dimensions from kwargs for 3D RoPE
406
+ image_height = kwargs.get('image_height', self.img_size)
407
+ image_width = kwargs.get('image_width', self.img_size)
408
+
409
+ return self._process_causal_stream(
410
+ tokens, B, S_local, S_global, P, C, global_idx, pos,
411
+ num_frame_per_block, sliding_window_size, num_frame_for_scale,
412
+ image_height=image_height, image_width=image_width
413
+ )
414
+
415
+ def _process_causal_stream(
416
+ self,
417
+ tokens: torch.Tensor,
418
+ B: int,
419
+ S_local: int,
420
+ S_global: int,
421
+ P: int,
422
+ C: int,
423
+ global_idx: int,
424
+ pos: Optional[torch.Tensor] = None,
425
+ num_frame_per_block: int = 1,
426
+ sliding_window_size: Optional[int] = None,
427
+ num_frame_for_scale: Optional[int] = None,
428
+ image_height: Optional[int] = None,
429
+ image_width: Optional[int] = None,
430
+ ):
431
+ """
432
+ Causal attention for streaming inference using FlashInfer KV cache.
433
+
434
+ Args:
435
+ tokens: Input tokens [B*S_local, P, C]
436
+ B: Batch size
437
+ S_local: Local sequence length
438
+ S_global: Global sequence length
439
+ P: Tokens per frame (patch tokens plus special tokens)
440
+ C: Channel dimension
441
+ global_idx: Starting block index
442
+ pos: Position embeddings [B*S_global, P, 2]
443
+ num_frame_per_block: Number of frames per block
444
+ sliding_window_size: Sliding window size in blocks
445
+ num_frame_for_scale: Number of scale frames
446
+ image_height: Image height for 3D RoPE calculation
447
+ image_width: Image width for 3D RoPE calculation
448
+
449
+ Returns:
450
+ (tokens, global_idx, intermediates): Updated tokens, next block index, intermediate outputs
451
+ """
452
+ # Get effective parameters
453
+ scale_frames = num_frame_for_scale if num_frame_for_scale is not None else self.num_frame_for_scale
454
+
455
+ # Reshape tokens: [B*S_local, P, C] -> [B, S_local*P, C]
456
+ if tokens.shape != (B, S_local * P, C):
457
+ tokens = tokens.view(B, S_local, P, C).view(B, S_local * P, C)
458
+
459
+ # Calculate number of frames for block mask
460
+ num_frames = S_global
461
+ num_patches = P - self.num_special_tokens
462
+
463
+ # Check if this is the first block group
464
+ is_first_block_group = (global_idx < self.aa_block_size)
465
+
466
+ if self.enable_3d_rope and self.rope3d is not None:
467
+ if is_first_block_group:
468
+ f_start = self.total_frames_processed
469
+ f_end = self.total_frames_processed + S_global
470
+
471
+ H = image_height if image_height is not None else self.img_size
472
+ W = image_width if image_width is not None else self.img_size
473
+ pos3d = self._get_3d_positions_streaming(
474
+ S_global, H, W, tokens.device, f_start, f_end
475
+ )
476
+ self._cached_pos3d = pos3d
477
+ else:
478
+ pos3d = self._cached_pos3d
479
+ pos = pos3d
480
+ else:
481
+ # Reshape pos: [B*S_global, P, 2] -> [B, S_global*P, 2]
482
+ if pos is not None and pos.shape != (B, S_global * P, 2):
483
+ pos = pos.view(B, S_global, P, 2).view(B, S_global * P, 2)
484
+
485
+ intermediates = []
486
+
487
+ # Process blocks with KV cache
488
+ for _ in range(self.aa_block_size):
489
+ num_patches = P - self.num_special_tokens
490
+ if self.use_sdpa:
491
+ # SDPA: dict-based KV cache
492
+ tokens = self.global_blocks[global_idx](
493
+ tokens,
494
+ pos=pos,
495
+ enable_ulysses_cp=False,
496
+ num_patches=num_patches,
497
+ num_special=self.num_special_tokens,
498
+ num_frames=num_frames,
499
+ enable_3d_rope=self.enable_3d_rope,
500
+ kv_cache=self.kv_cache,
501
+ global_idx=global_idx,
502
+ num_frame_per_block=num_frame_per_block,
503
+ num_frame_for_scale=scale_frames,
504
+ num_register_tokens=self.num_register_tokens,
505
+ )
506
+ else:
507
+ # FlashInfer: paged KV cache manager
508
+ manager = self._get_flashinfer_manager(tokens.device, tokens.dtype, tokens_per_frame=P)
509
+ tokens = self.global_blocks[global_idx](
510
+ tokens,
511
+ pos=pos,
512
+ enable_ulysses_cp=False,
513
+ num_patches=num_patches,
514
+ num_special=self.num_special_tokens,
515
+ num_frames=num_frames,
516
+ enable_3d_rope=self.enable_3d_rope,
517
+ kv_cache=manager,
518
+ global_idx=global_idx,
519
+ num_frame_per_block=num_frame_per_block,
520
+ num_frame_for_scale=scale_frames,
521
+ num_register_tokens=self.num_register_tokens,
522
+ )
523
+
524
+ global_idx += 1
525
+ intermediates.append(tokens.view(B, S_local, P, C))
526
+
527
+ # Update total frames processed counter only on the first block group
528
+ if is_first_block_group and not (isinstance(self.kv_cache, dict) and self.kv_cache.get("_skip_append", False)):
529
+ self.total_frames_processed += S_global
530
+
531
+ return tokens, global_idx, intermediates
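As a rough mental model of the frame-causal structure handled by the KV cache (a sketch only; the FlashInfer kernels never materialise this mask), tokens from S frames with P tokens each form one flattened sequence of length S*P, and a token in frame i may attend only to frames <= i:

    import torch

    S, P = 3, 4
    frame_of_token = torch.arange(S).repeat_interleave(P)               # frame index of each flattened token
    causal_mask = frame_of_token[:, None] >= frame_of_token[None, :]    # True = may attend
    print(causal_mask.int())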
lingbot_map/heads/__init__.py ADDED
File without changes
lingbot_map/heads/camera_head.py ADDED
@@ -0,0 +1,458 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from lingbot_map.layers import Mlp
15
+ from lingbot_map.layers.block import Block
16
+ from lingbot_map.layers.block import CameraBlock
17
+ from lingbot_map.heads.head_act import activate_pose
18
+ from lingbot_map.layers.rope import WanRotaryPosEmbed
19
+ from functools import partial
20
+ from torch.utils.checkpoint import checkpoint
21
+
22
+
23
+ class CameraHead(nn.Module):
24
+ """
25
+ CameraHead predicts camera parameters from token representations using iterative refinement.
26
+
27
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ dim_in: int = 2048,
33
+ trunk_depth: int = 4,
34
+ pose_encoding_type: str = "absT_quaR_FoV",
35
+ num_heads: int = 16,
36
+ mlp_ratio: int = 4,
37
+ init_values: float = 0.01,
38
+ trans_act: str = "linear",
39
+ quat_act: str = "linear",
40
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
41
+ enable_ulysses_cp=False,
42
+ ):
43
+ super().__init__()
44
+
45
+ if pose_encoding_type == "absT_quaR_FoV":
46
+ self.target_dim = 9
47
+ else:
48
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
49
+
50
+ self.trans_act = trans_act
51
+ self.quat_act = quat_act
52
+ self.fl_act = fl_act
53
+ self.trunk_depth = trunk_depth
54
+
55
+ self.enable_ulysses_cp = enable_ulysses_cp
56
+
57
+ # Build the trunk using a sequence of transformer blocks.
58
+ self.trunk = nn.Sequential(
59
+ *[
60
+ Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
61
+ for _ in range(trunk_depth)
62
+ ]
63
+ )
64
+
65
+ # Normalizations for camera token and trunk output.
66
+ self.token_norm = nn.LayerNorm(dim_in)
67
+ self.trunk_norm = nn.LayerNorm(dim_in)
68
+
69
+ # Learnable empty camera pose token.
70
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
71
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
72
+
73
+ # Module for producing modulation parameters: shift, scale, and a gate.
74
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
75
+
76
+ # Adaptive layer normalization without affine parameters.
77
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
78
+ self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
79
+
80
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list:
81
+ """
82
+ Forward pass to predict camera parameters.
83
+
84
+ Args:
85
+ aggregated_tokens_list (list): List of token tensors from the network;
86
+ the last tensor is used for prediction.
87
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
88
+
89
+ Returns:
90
+ list: A list of predicted camera encodings (post-activation) from each iteration.
91
+ """
92
+ # Use tokens from the last block for camera prediction.
93
+ tokens = aggregated_tokens_list[-1]
94
+
95
+ # Extract the camera tokens
96
+ pose_tokens = tokens[:, :, 0]
97
+ pose_tokens = self.token_norm(pose_tokens)
98
+
99
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
100
+ return pred_pose_enc_list
101
+
102
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
103
+ """
104
+ Iteratively refine camera pose predictions.
105
+
106
+ Args:
107
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
108
+ num_iterations (int): Number of refinement iterations.
109
+
110
+ Returns:
111
+ list: List of activated camera encodings from each iteration.
112
+ """
113
+ B, S, C = pose_tokens.shape # S is expected to be 1.
114
+ pred_pose_enc = None
115
+ pred_pose_enc_list = []
116
+
117
+ for _ in range(num_iterations):
118
+ # Use a learned empty pose for the first iteration.
119
+ if pred_pose_enc is None:
120
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
121
+ else:
122
+ # Detach the previous prediction to avoid backprop through time.
123
+ pred_pose_enc = pred_pose_enc.detach()
124
+ module_input = self.embed_pose(pred_pose_enc)
125
+
126
+ # Generate modulation parameters and split them into shift, scale, and gate components.
127
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
128
+
129
+ # Adaptive layer normalization and modulation.
130
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
131
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
132
+
133
+ # Apply trunk blocks with enable_ulysses_cp
134
+ for block in self.trunk:
135
+ pose_tokens_modulated = block(pose_tokens_modulated, enable_ulysses_cp=self.enable_ulysses_cp)
136
+ # Compute the delta update for the pose encoding.
137
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
138
+
139
+ if pred_pose_enc is None:
140
+ pred_pose_enc = pred_pose_enc_delta
141
+ else:
142
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
143
+
144
+ # Apply final activation functions for translation, quaternion, and field-of-view.
145
+ activated_pose = activate_pose(
146
+ pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
147
+ )
148
+ pred_pose_enc_list.append(activated_pose)
149
+
150
+ return pred_pose_enc_list
151
+
152
+
153
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
154
+ """
155
+ Modulate the input tensor using scaling and shifting parameters.
156
+ """
157
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
158
+ return x * (1 + scale) + shift
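A quick sanity check of the modulation used above (assuming `modulate` from this file is in scope): zero shift and scale leave the input unchanged, since x * (1 + 0) + 0 == x.

    import torch

    x = torch.randn(1, 1, 4)
    shift, scale = torch.zeros_like(x), torch.zeros_like(x)
    assert torch.allclose(modulate(x, shift, scale), x)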
159
+
160
+
161
+ class CameraCausalHead(nn.Module):
162
+ """
163
+ CameraCausalHead predicts camera parameters from token representations using iterative refinement, with causal (streaming) attention and an optional per-iteration KV cache.
164
+
165
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
166
+ """
167
+
168
+ def __init__(
169
+ self,
170
+ dim_in: int = 2048,
171
+ trunk_depth: int = 4,
172
+ pose_encoding_type: str = "absT_quaR_FoV",
173
+ num_heads: int = 16,
174
+ mlp_ratio: int = 4,
175
+ init_values: float = 0.01,
176
+ trans_act: str = "linear",
177
+ quat_act: str = "linear",
178
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
179
+ num_iterations = 4,
180
+ elementwise_attn_output_gate: bool = False,
181
+ sliding_window_size: int = -1,
182
+ attend_to_scale_frames: bool = False,
183
+ num_random_frames: int = 0,
184
+ enable_ulysses_cp: bool = False,
185
+ attn_class: str = "flexflashattn_varlen",
186
+ # KV cache parameters
187
+ kv_cache_sliding_window: int = 64,
188
+ kv_cache_scale_frames: int = 8,
189
+ kv_cache_cross_frame_special: bool = True,
190
+ kv_cache_include_scale_frames: bool = True,
191
+ kv_cache_camera_only: bool = False,
192
+ # 3D RoPE parameters
193
+ enable_3d_rope: bool = False,
194
+ max_frame_num: int = 1024,
195
+ rope_theta: float = 10000.0,
196
+ ):
197
+ super().__init__()
198
+
199
+ if pose_encoding_type == "absT_quaR_FoV":
200
+ self.target_dim = 9
201
+ else:
202
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
203
+
204
+ self.trans_act = trans_act
205
+ self.quat_act = quat_act
206
+ self.fl_act = fl_act
207
+ self.trunk_depth = trunk_depth
208
+ self.sliding_window_size = sliding_window_size
209
+ self.enable_ulysses_cp = enable_ulysses_cp
210
+ self.num_heads = num_heads
211
+
212
+ # 3D RoPE for temporal position encoding
213
+ self.enable_3d_rope = enable_3d_rope
214
+ if enable_3d_rope:
215
+ head_dim = dim_in // num_heads
216
+ # For camera head: each frame has 1 token (frame_seqlen=1)
217
+ # patch_size is (max_frames, h=1, w=1) for 3D RoPE
218
+ # fhw_dim=None lets auto-calculation: h_dim=w_dim=2*(head_dim//6), t_dim=remainder
219
+ self.rope3d = WanRotaryPosEmbed(
220
+ attention_head_dim=head_dim,
221
+ patch_size=(max_frame_num, 1, 1),
222
+ theta=rope_theta,
223
+ fhw_dim=[40, 44, 44], # Explicit (t, h, w) dimension split; pass None to auto-calculate
224
+ )
225
+ else:
226
+ self.rope3d = None
227
+
228
+ # Build the trunk using a sequence of transformer blocks.
229
+ self.trunk = nn.Sequential(
230
+ *[
231
+ CameraBlock(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values, elementwise_attn_output_gate=elementwise_attn_output_gate, sliding_window_size=sliding_window_size, attend_to_scale_frames=attend_to_scale_frames, num_random_frames=num_random_frames, kv_cache_sliding_window=kv_cache_sliding_window, kv_cache_scale_frames=kv_cache_scale_frames, kv_cache_cross_frame_special=kv_cache_cross_frame_special, kv_cache_include_scale_frames=kv_cache_include_scale_frames, kv_cache_camera_only=kv_cache_camera_only)
232
+ for _ in range(trunk_depth)
233
+ ]
234
+ )
235
+
236
+ # Normalizations for camera token and trunk output.
237
+ self.token_norm = nn.LayerNorm(dim_in)
238
+ self.trunk_norm = nn.LayerNorm(dim_in)
239
+
240
+ # Learnable empty camera pose token.
241
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
242
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
243
+
244
+ # Module for producing modulation parameters: shift, scale, and a gate.
245
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
246
+
247
+ # Adaptive layer normalization without affine parameters.
248
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
249
+ self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
250
+
251
+ self.num_iterations = num_iterations
252
+
253
+ self.kv_cache = None
254
+ self.pos_cache = None
255
+ self.frame_idx = 0
256
+ self.cp_size = 1
257
+
258
+ ## Get cp size if enable ulysses cp
259
+ if self.enable_ulysses_cp:
260
+ from torchtitan.distributed.sequence_parallel import (
261
+ init_sequence_parallel,
262
+ get_ulysses_sequence_parallel_rank,
263
+ get_ulysses_sequence_parallel_world_size,
264
+ )
265
+
266
+ self.cp_size = get_ulysses_sequence_parallel_world_size()
267
+
268
+
269
+
270
+ def clean_kv_cache(self):
271
+ del self.kv_cache
272
+ self.kv_cache = None
273
+ self.frame_idx = 0
274
+
275
+ def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = None, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
276
+ """
277
+ Forward pass to predict camera parameters.
278
+
279
+ Args:
280
+ aggregated_tokens_list (list): List of token tensors from the network;
281
+ the last tensor is used for prediction.
282
+ num_iterations (int, optional): Number of iterative refinement steps.
283
+ If None, falls back to self.num_iterations (set at construction).
284
+ sliding_window_size (int, optional): Override the sliding window size for this forward pass.
285
+ If None, use the default self.sliding_window_size.
286
+
287
+ Returns:
288
+ list: A list of predicted camera encodings (post-activation) from each iteration.
289
+ """
290
+ if num_iterations is None:
291
+ num_iterations = self.num_iterations
292
+
293
+ # Use passed sliding_window_size if provided, otherwise use default
294
+ effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
295
+
296
+ # Use tokens from the last block for camera prediction.
297
+ tokens = aggregated_tokens_list[-1]
298
+
299
+ # Extract the camera tokens
300
+ pose_tokens = tokens[:, :, 0]
301
+ pose_tokens = self.token_norm(pose_tokens)
302
+
303
+ if causal_inference:
304
+ if self.kv_cache is None:
305
+ self.kv_cache = []
306
+ for i in range(num_iterations):
307
+ self.kv_cache.append({"_skip_append": False})
308
+ for j in range(self.trunk_depth):
309
+ self.kv_cache[i][f"k_{j}"] = None
310
+ self.kv_cache[i][f"v_{j}"] = None
311
+
312
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size)
313
+ return pred_pose_enc_list
314
+
315
+ def trunk_fn(self, pose_tokens: torch.Tensor, mask=None, num_iterations: int=4, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None) -> list:
316
+ """
317
+ Iteratively refine camera pose predictions.
318
+
319
+ Args:
320
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
321
+ num_iterations (int): Number of refinement iterations.
322
+ sliding_window_size (int, optional): Sliding window size to use.
323
+
324
+ Returns:
325
+ list: List of activated camera encodings from each iteration.
326
+ """
327
+ B, S, C = pose_tokens.shape
328
+ pred_pose_enc = None
329
+ pred_pose_enc_list = []
330
+
331
+ # Check if this is the first call (processing scale frames)
332
+ # Scale frames should use batch mode attention for numerical consistency
333
+ is_scale_frames = (self.kv_cache is not None and self.frame_idx == 0)
334
+
335
+ # Generate 3D RoPE positions if enabled
336
+ pos3d = None
337
+ if self.rope3d is not None:
338
+ # For camera tokens: shape [B, S, C] where each frame has 1 token
339
+ # Position for frame f is (f, 0, 0) - temporal varies, spatial fixed
340
+
341
+ # In streaming mode with KV cache, use frame_idx to track global position
342
+ # Otherwise, generate positions from 0
343
+ if self.kv_cache is not None:
344
+ f_start = self.frame_idx
345
+ f_end = self.frame_idx + S
346
+ else:
347
+ f_start = 0
348
+ f_end = None # Will use ppf as frame count
349
+
350
+ pos3d = self.rope3d(
351
+ ppf=S * self.cp_size, # Total frames (with CP)
352
+ pph=1, # height = 1 (camera token)
353
+ ppw=1, # width = 1 (camera token)
354
+ patch_start_idx=0, # No special tokens before
355
+ device=pose_tokens.device,
356
+ f_start=f_start,
357
+ f_end=f_end,
358
+ ) # Returns [1, 1, S*cp_size, head_dim//2] complex
359
+
360
+ for i in range(num_iterations):
361
+ # Use a learned empty pose for the first iteration.
362
+ if pred_pose_enc is None:
363
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
364
+ else:
365
+ # Detach the previous prediction to avoid backprop through time.
366
+ pred_pose_enc = pred_pose_enc.detach()
367
+ module_input = self.embed_pose(pred_pose_enc)
368
+
369
+ # Generate modulation parameters and split them into shift, scale, and gate components.
370
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
371
+
372
+ # Adaptive layer normalization and modulation.
373
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
374
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
375
+
376
+ for idx in range(self.trunk_depth):
377
+ pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, pos=pos3d, video_mask=mask, num_frames=S*self.cp_size, frame_seqlen=1, kv_cache=self.kv_cache[i] if self.kv_cache is not None else None, global_idx=idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=sliding_window_size, enable_ulysses_cp=self.enable_ulysses_cp, enable_3d_rope=self.enable_3d_rope, is_scale_frames=is_scale_frames)
378
+ # Compute the delta update for the pose encoding.
379
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
380
+
381
+ if pred_pose_enc is None:
382
+ pred_pose_enc = pred_pose_enc_delta
383
+ else:
384
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
385
+
386
+ # Apply final activation functions for translation, quaternion, and field-of-view.
387
+ activated_pose = activate_pose(
388
+ pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
389
+ )
390
+ pred_pose_enc_list.append(activated_pose)
391
+
392
+ # Update frame_idx for streaming mode (KV cache)
393
+ if self.kv_cache is not None:
394
+ self.frame_idx += S
395
+
396
+ return pred_pose_enc_list
397
+
398
+
399
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
400
+ """
401
+ Modulate the input tensor using scaling and shifting parameters.
402
+ """
403
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
404
+ return x * (1 + scale) + shift
405
+
406
+
407
+
408
+
409
+ class CameraDecoder(nn.Module):
410
+ def __init__(
411
+ self,
412
+ in_dim,
413
+ out_dim,
414
+ dec_embed_dim=512,
415
+ depth=5,
416
+ dec_num_heads=8,
417
+ mlp_ratio=4,
418
+ rope=None,
419
+ need_project=True,
420
+ use_checkpoint=False,
421
+ ):
422
+ super().__init__()
423
+
424
+ self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
425
+ self.use_checkpoint = use_checkpoint
426
+
427
+ self.blocks = nn.ModuleList([
428
+ Block(
429
+ dim=dec_embed_dim,
430
+ num_heads=dec_num_heads,
431
+ mlp_ratio=mlp_ratio,
432
+ qkv_bias=True,
433
+ proj_bias=True,
434
+ ffn_bias=True,
435
+ drop_path=0.0,
436
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
437
+ act_layer=nn.GELU,
438
+ ffn_layer=Mlp,
439
+ init_values=None,
440
+ qk_norm=False,
441
+ # attn_class=MemEffAttentionRope,
442
+ rope=rope
443
+ ) for _ in range(depth)])
444
+
445
+ self.linear_out = nn.Linear(dec_embed_dim, out_dim)
446
+
447
+ def forward(self, hidden, xpos=None):
448
+ hidden = self.projects(hidden)
449
+ B, V, P, C = hidden.shape
450
+ hidden = hidden.view(hidden.shape[0]*hidden.shape[1], hidden.shape[2], hidden.shape[3])
451
+ for i, blk in enumerate(self.blocks):
452
+ if self.use_checkpoint and self.training:
453
+ hidden = checkpoint(blk, hidden, pos=xpos, use_reentrant=False)
454
+ else:
455
+ hidden = blk(hidden, pos=xpos)
456
+ out = self.linear_out(hidden).view(B, V, P, -1)
457
+
458
+ return out
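An illustrative smoke test of `CameraDecoder` on random tokens (shapes only; it assumes the defaults here construct cleanly and that this module is importable as `lingbot_map.heads.camera_head`):

    import torch
    from lingbot_map.heads.camera_head import CameraDecoder

    decoder = CameraDecoder(in_dim=128, out_dim=9, dec_embed_dim=64, depth=2, dec_num_heads=4)
    hidden = torch.randn(2, 3, 10, 128)     # [B, V(iews), P(tokens), C]
    out = decoder(hidden)                   # -> [2, 3, 10, 9]
    print(out.shape)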
lingbot_map/heads/dpt_head.py ADDED
@@ -0,0 +1,679 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
+ class DPTHead(nn.Module):
22
+ """
23
+ DPT Head for dense prediction tasks.
24
+
25
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
26
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
27
+ backbone and produces dense predictions by fusing multi-scale features.
28
+
29
+ Args:
30
+ dim_in (int): Input dimension (channels).
31
+ patch_size (int, optional): Patch size. Default is 14.
32
+ output_dim (int, optional): Number of output channels. Default is 4.
33
+ activation (str, optional): Activation type. Default is "inv_log".
34
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
35
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
36
+ out_channels (List[int], optional): Output channels for each intermediate layer.
37
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
38
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
39
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
40
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dim_in: int,
46
+ patch_size: int = 14,
47
+ output_dim: int = 4,
48
+ activation: str = "inv_log",
49
+ conf_activation: str = "expp1",
50
+ features: int = 256,
51
+ out_channels: List[int] = [256, 512, 1024, 1024],
52
+ intermediate_layer_idx: List[int] = [0, 1, 2, 3],
53
+ pos_embed: bool = True,
54
+ feature_only: bool = False,
55
+ down_ratio: int = 1,
56
+ ) -> None:
57
+ super(DPTHead, self).__init__()
58
+ self.patch_size = patch_size
59
+ self.activation = activation
60
+ self.conf_activation = conf_activation
61
+ self.pos_embed = pos_embed
62
+ self.feature_only = feature_only
63
+ self.down_ratio = down_ratio
64
+ self.intermediate_layer_idx = intermediate_layer_idx
65
+
66
+ self.norm = nn.LayerNorm(dim_in)
67
+
68
+ # Projection layers for each output channel from tokens.
69
+ self.projects = nn.ModuleList(
70
+ [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
71
+ )
72
+
73
+ # Resize layers for upsampling feature maps.
74
+ self.resize_layers = nn.ModuleList(
75
+ [
76
+ nn.ConvTranspose2d(
77
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
78
+ ),
79
+ nn.ConvTranspose2d(
80
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
81
+ ),
82
+ nn.Identity(),
83
+ nn.Conv2d(
84
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
85
+ ),
86
+ ]
87
+ )
88
+
89
+ self.scratch = _make_scratch(out_channels, features, expand=False)
90
+
91
+ # Attach additional modules to scratch.
92
+ self.scratch.stem_transpose = None
93
+ self.scratch.refinenet1 = _make_fusion_block(features)
94
+ self.scratch.refinenet2 = _make_fusion_block(features)
95
+ self.scratch.refinenet3 = _make_fusion_block(features)
96
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
97
+
98
+ head_features_1 = features
99
+ head_features_2 = 32
100
+
101
+ if feature_only:
102
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
103
+ else:
104
+ self.scratch.output_conv1 = nn.Conv2d(
105
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
106
+ )
107
+ conv2_in_channels = head_features_1 // 2
108
+
109
+ self.scratch.output_conv2 = nn.Sequential(
110
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
111
+ nn.ReLU(inplace=True),
112
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
113
+ )
114
+
115
+ def forward(
116
+ self,
117
+ aggregated_tokens_list: List[torch.Tensor],
118
+ images: torch.Tensor,
119
+ patch_start_idx: int,
120
+ frames_chunk_size: int = 8,
121
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
122
+ """
123
+ Forward pass through the DPT head, supports processing by chunking frames.
124
+ Args:
125
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
126
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
127
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
128
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
129
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
130
+ If None or larger than S, all frames are processed at once. Default: 8.
131
+
132
+ Returns:
133
+ Tensor or Tuple[Tensor, Tensor]:
134
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
135
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
136
+ """
137
+ B, _, _, H, W = images.shape
138
+
139
+ S = aggregated_tokens_list[0].shape[1]
140
+
141
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
142
+ if frames_chunk_size is None or frames_chunk_size >= S:
143
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
144
+
145
+ # Otherwise, process frames in chunks to manage memory usage
146
+ assert frames_chunk_size > 0
147
+
148
+ # Process frames in batches
149
+ all_preds = []
150
+ all_conf = []
151
+
152
+ for frames_start_idx in range(0, S, frames_chunk_size):
153
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
154
+
155
+ # Process batch of frames
156
+ if self.feature_only:
157
+ chunk_output = self._forward_impl(
158
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
159
+ )
160
+ all_preds.append(chunk_output)
161
+ else:
162
+ chunk_preds, chunk_conf = self._forward_impl(
163
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
164
+ )
165
+ all_preds.append(chunk_preds)
166
+ all_conf.append(chunk_conf)
167
+
168
+ # Concatenate results along the sequence dimension
169
+ if self.feature_only:
170
+ return torch.cat(all_preds, dim=1)
171
+ else:
172
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
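For intuition on `frames_chunk_size` (an illustration of the loop above, not a benchmark): the head runs one pass per chunk and concatenates the results back along the frame dimension, so peak activation memory scales with the chunk size rather than with S.

    S, chunk = 20, 8
    bounds = [(s, min(s + chunk, S)) for s in range(0, S, chunk)]
    print(bounds)   # [(0, 8), (8, 16), (16, 20)] -> outputs concatenated along dim=1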
173
+
174
+ def _forward_impl(
175
+ self,
176
+ aggregated_tokens_list: List[torch.Tensor],
177
+ images: torch.Tensor,
178
+ patch_start_idx: int,
179
+ frames_start_idx: int = None,
180
+ frames_end_idx: int = None,
181
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
182
+ """
183
+ Implementation of the forward pass through the DPT head.
184
+
185
+ This method processes a specific chunk of frames from the sequence.
186
+
187
+ Args:
188
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
189
+ images (Tensor): Input images with shape [B, S, 3, H, W].
190
+ patch_start_idx (int): Starting index for patch tokens.
191
+ frames_start_idx (int, optional): Starting index for frames to process.
192
+ frames_end_idx (int, optional): Ending index for frames to process.
193
+
194
+ Returns:
195
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
196
+ """
197
+
198
+ B, _, _, H, W = images.shape
199
+
200
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
201
+
202
+ out = []
203
+ dpt_idx = 0
204
+
205
+ for layer_idx in self.intermediate_layer_idx:
206
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
207
+
208
+
209
+
210
+ if frames_start_idx is not None and frames_end_idx is not None:
211
+ x = x[:, frames_start_idx:frames_end_idx]
212
+
213
+ B, S = x.shape[0], x.shape[1]
214
+
215
+ x = x.reshape(B * S, -1, x.shape[-1])
216
+
217
+ x = self.norm(x)
218
+
219
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
220
+
221
+ x = self.projects[dpt_idx](x)
222
+ if self.pos_embed:
223
+ x = self._apply_pos_embed(x, W, H)
224
+ x = self.resize_layers[dpt_idx](x)
225
+
226
+ out.append(x)
227
+ dpt_idx += 1
228
+
229
+ # Fuse features from multiple layers.
230
+ out = self.scratch_forward(out)
231
+ # Interpolate fused output to match target image resolution.
232
+ out = custom_interpolate(
233
+ out,
234
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
235
+ mode="bilinear",
236
+ align_corners=True,
237
+ )
238
+
239
+ if self.pos_embed:
240
+ out = self._apply_pos_embed(out, W, H)
241
+
242
+ if self.feature_only:
243
+ return out.view(B, S, *out.shape[1:])
244
+
245
+ out = self.scratch.output_conv2(out)
246
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
247
+
248
+ preds = preds.view(B, S, *preds.shape[1:])
249
+ conf = conf.view(B, S, *conf.shape[1:])
250
+ return preds, conf
251
+
252
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
253
+ """
254
+ Apply positional embedding to tensor x.
255
+ """
256
+ patch_w = x.shape[-1]
257
+ patch_h = x.shape[-2]
258
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
259
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
260
+ pos_embed = pos_embed * ratio
261
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
262
+ return x + pos_embed
263
+
264
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
265
+ """
266
+ Forward pass through the fusion blocks.
267
+
268
+ Args:
269
+ features (List[Tensor]): List of feature maps from different layers.
270
+
271
+ Returns:
272
+ Tensor: Fused feature map.
273
+ """
274
+ layer_1, layer_2, layer_3, layer_4 = features
275
+
276
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
277
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
278
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
279
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
280
+
281
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
282
+ del layer_4_rn, layer_4
283
+
284
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
285
+ del layer_3_rn, layer_3
286
+
287
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
288
+ del layer_2_rn, layer_2
289
+
290
+ out = self.scratch.refinenet1(out, layer_1_rn)
291
+ del layer_1_rn, layer_1
292
+
293
+ out = self.scratch.output_conv1(out)
294
+ return out
295
+
296
+
297
+ ################################################################################
298
+ # Modules
299
+ ################################################################################
300
+
301
+
302
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
303
+ return FeatureFusionBlock(
304
+ features,
305
+ nn.ReLU(inplace=True),
306
+ deconv=False,
307
+ bn=False,
308
+ expand=False,
309
+ align_corners=True,
310
+ size=size,
311
+ has_residual=has_residual,
312
+ groups=groups,
313
+ )
314
+
315
+
316
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
317
+ scratch = nn.Module()
318
+ out_shape1 = out_shape
319
+ out_shape2 = out_shape
320
+ out_shape3 = out_shape
321
+ if len(in_shape) >= 4:
322
+ out_shape4 = out_shape
323
+
324
+ if expand:
325
+ out_shape1 = out_shape
326
+ out_shape2 = out_shape * 2
327
+ out_shape3 = out_shape * 4
328
+ if len(in_shape) >= 4:
329
+ out_shape4 = out_shape * 8
330
+
331
+ scratch.layer1_rn = nn.Conv2d(
332
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
333
+ )
334
+ scratch.layer2_rn = nn.Conv2d(
335
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
336
+ )
337
+ scratch.layer3_rn = nn.Conv2d(
338
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
339
+ )
340
+ if len(in_shape) >= 4:
341
+ scratch.layer4_rn = nn.Conv2d(
342
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
343
+ )
344
+ return scratch
345
+
346
+
347
+ class ResidualConvUnit(nn.Module):
348
+ """Residual convolution module."""
349
+
350
+ def __init__(self, features, activation, bn, groups=1):
351
+ """Init.
352
+
353
+ Args:
354
+ features (int): number of features
355
+ """
356
+ super().__init__()
357
+
358
+ self.bn = bn
359
+ self.groups = groups
360
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
361
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
362
+
363
+ self.norm1 = None
364
+ self.norm2 = None
365
+
366
+ self.activation = activation
367
+ self.skip_add = nn.quantized.FloatFunctional()
368
+
369
+ def forward(self, x):
370
+ """Forward pass.
371
+
372
+ Args:
373
+ x (tensor): input
374
+
375
+ Returns:
376
+ tensor: output
377
+ """
378
+
379
+ out = self.activation(x)
380
+ out = self.conv1(out)
381
+ if self.norm1 is not None:
382
+ out = self.norm1(out)
383
+
384
+ out = self.activation(out)
385
+ out = self.conv2(out)
386
+ if self.norm2 is not None:
387
+ out = self.norm2(out)
388
+
389
+ return self.skip_add.add(out, x)
390
+
391
+
392
+ class FeatureFusionBlock(nn.Module):
393
+ """Feature fusion block."""
394
+
395
+ def __init__(
396
+ self,
397
+ features,
398
+ activation,
399
+ deconv=False,
400
+ bn=False,
401
+ expand=False,
402
+ align_corners=True,
403
+ size=None,
404
+ has_residual=True,
405
+ groups=1,
406
+ ):
407
+ """Init.
408
+
409
+ Args:
410
+ features (int): number of features
411
+ """
412
+ super(FeatureFusionBlock, self).__init__()
413
+
414
+ self.deconv = deconv
415
+ self.align_corners = align_corners
416
+ self.groups = groups
417
+ self.expand = expand
418
+ out_features = features
419
+ if self.expand:
420
+ out_features = features // 2
421
+
422
+ self.out_conv = nn.Conv2d(
423
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
424
+ )
425
+
426
+ if has_residual:
427
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
428
+
429
+ self.has_residual = has_residual
430
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
431
+
432
+ self.skip_add = nn.quantized.FloatFunctional()
433
+ self.size = size
434
+
435
+ def forward(self, *xs, size=None):
436
+ """Forward pass.
437
+
438
+ Returns:
439
+ tensor: output
440
+ """
441
+ output = xs[0]
442
+
443
+ if self.has_residual:
444
+ res = self.resConfUnit1(xs[1])
445
+ output = self.skip_add.add(output, res)
446
+
447
+ output = self.resConfUnit2(output)
448
+
449
+ if (size is None) and (self.size is None):
450
+ modifier = {"scale_factor": 2}
451
+ elif size is None:
452
+ modifier = {"size": self.size}
453
+ else:
454
+ modifier = {"size": size}
455
+
456
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
457
+ output = self.out_conv(output)
458
+
459
+ return output
460
+
461
+
462
+ def custom_interpolate(
463
+ x: torch.Tensor,
464
+ size: Tuple[int, int] = None,
465
+ scale_factor: float = None,
466
+ mode: str = "bilinear",
467
+ align_corners: bool = True,
468
+ ) -> torch.Tensor:
469
+ """
470
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
471
+ """
472
+ if size is None:
473
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
474
+
475
+ INT_MAX = 1610612736
476
+
477
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
478
+
479
+ if input_elements > INT_MAX:
480
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
481
+ interpolated_chunks = [
482
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
483
+ ]
484
+ x = torch.cat(interpolated_chunks, dim=0)
485
+ return x.contiguous()
486
+ else:
487
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
488
+
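# Hedged usage sketch (not part of the commit): calling custom_interpolate above on a
# small tensor, assuming the lingbot_map package is importable; sizes here stay far below
# the INT_MAX threshold, so it falls through to a single nn.functional.interpolate call.
import torch
from lingbot_map.heads.dpt_head import custom_interpolate

x = torch.randn(2, 16, 32, 32)                                    # (B, C, H, W)
y = custom_interpolate(x, size=(64, 64), mode="bilinear", align_corners=True)
print(y.shape)                                                     # torch.Size([2, 16, 64, 64])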
489
+ class DPTHead_Update(nn.Module):
490
+ def __init__(
491
+ self,
492
+ in_channels,
493
+ features=256,
494
+ use_bn=False,
495
+ out_channels=[256, 512, 1024, 1024],
496
+ use_clstoken=False
497
+ ):
498
+ super(DPTHead_Update, self).__init__()
499
+
500
+ self.use_clstoken = use_clstoken
501
+
502
+ self.projects = nn.ModuleList([
503
+ nn.Conv2d(
504
+ in_channels=in_channels,
505
+ out_channels=out_channel,
506
+ kernel_size=1,
507
+ stride=1,
508
+ padding=0,
509
+ ) for out_channel in out_channels
510
+ ])
511
+
512
+ self.resize_layers = nn.ModuleList([
513
+ nn.ConvTranspose2d(
514
+ in_channels=out_channels[0],
515
+ out_channels=out_channels[0],
516
+ kernel_size=4,
517
+ stride=4,
518
+ padding=0),
519
+ nn.ConvTranspose2d(
520
+ in_channels=out_channels[1],
521
+ out_channels=out_channels[1],
522
+ kernel_size=2,
523
+ stride=2,
524
+ padding=0),
525
+ nn.Identity(),
526
+ nn.Conv2d(
527
+ in_channels=out_channels[3],
528
+ out_channels=out_channels[3],
529
+ kernel_size=3,
530
+ stride=2,
531
+ padding=1)
532
+ ])
533
+
534
+ if use_clstoken:
535
+ self.readout_projects = nn.ModuleList()
536
+ for _ in range(len(self.projects)):
537
+ self.readout_projects.append(
538
+ nn.Sequential(
539
+ nn.Linear(2 * in_channels, in_channels),
540
+ nn.GELU()))
541
+
542
+ self.scratch = _make_scratch(
543
+ out_channels,
544
+ features,
545
+ groups=1,
546
+ expand=False,
547
+ )
548
+
549
+ self.scratch.stem_transpose = None
550
+
551
+ self.scratch.refinenet1 = _make_fusion_block_slam(features, use_bn)
552
+ self.scratch.refinenet2 = _make_fusion_block_slam(features, use_bn)
553
+ self.scratch.refinenet3 = _make_fusion_block_slam(features, use_bn)
554
+ self.scratch.refinenet4 = _make_fusion_block_slam(features, use_bn)
555
+
556
+ head_features_1 = features
557
+ head_features_2 = 32
558
+
559
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
560
+ self.scratch.output_conv2 = nn.Sequential(
561
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
562
+ nn.ReLU(True),
563
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
564
+ nn.ReLU(True),
565
+ nn.Identity(),
566
+ )
567
+
568
+ def forward(self, out_features, patch_h, patch_w, return_intermediate=True):
569
+ out = []
570
+ for i, x in enumerate(out_features):
571
+ if self.use_clstoken:
572
+ x, cls_token = x[0], x[1]
573
+ readout = cls_token.unsqueeze(1).expand_as(x)
574
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
575
+
576
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
577
+
578
+ x = self.projects[i](x)
579
+ x = self.resize_layers[i](x)
580
+
581
+ out.append(x)
582
+
583
+ layer_1, layer_2, layer_3, layer_4 = out
584
+
585
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
586
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
587
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
588
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
589
+
590
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
591
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
592
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
593
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
594
+ out = self.scratch.output_conv1(path_1)
595
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
596
+ if return_intermediate:
597
+ return out, path_1, path_2, path_3, path_4
598
+ else:
599
+ out = self.scratch.output_conv2(out)
600
+ return out
601
+
602
+ def _make_fusion_block_slam(features, use_bn, size=None):
603
+ return FeatureFusionBlock_slam(
604
+ features,
605
+ nn.ReLU(False),
606
+ deconv=False,
607
+ bn=use_bn,
608
+ expand=False,
609
+ align_corners=True,
610
+ size=size,
611
+ )
612
+
613
+
614
+ class FeatureFusionBlock_slam(nn.Module):
615
+ """Feature fusion block.
616
+ """
617
+
618
+ def __init__(
619
+ self,
620
+ features,
621
+ activation,
622
+ deconv=False,
623
+ bn=False,
624
+ expand=False,
625
+ align_corners=True,
626
+ size=None
627
+ ):
628
+ """Init.
629
+
630
+ Args:
631
+ features (int): number of features
632
+ """
633
+ super(FeatureFusionBlock_slam, self).__init__()
634
+
635
+ self.deconv = deconv
636
+ self.align_corners = align_corners
637
+
638
+ self.groups = 1
639
+
640
+ self.expand = expand
641
+ out_features = features
642
+ if self.expand:
643
+ out_features = features // 2
644
+
645
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
646
+
647
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
648
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
649
+
650
+ self.skip_add = nn.quantized.FloatFunctional()
651
+
652
+ self.size = size
653
+
654
+ def forward(self, *xs, size=None):
655
+ """Forward pass.
656
+
657
+ Returns:
658
+ tensor: output
659
+ """
660
+ output = xs[0]
661
+
662
+ if len(xs) == 2:
663
+ res = self.resConfUnit1(xs[1])
664
+ output = self.skip_add.add(output, res)
665
+
666
+ output = self.resConfUnit2(output)
667
+
668
+ if (size is None) and (self.size is None):
669
+ modifier = {"scale_factor": 2}
670
+ elif size is None:
671
+ modifier = {"size": self.size}
672
+ else:
673
+ modifier = {"size": size}
674
+
675
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
676
+
677
+ output = self.out_conv(output)
678
+
679
+ return output
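# Hedged usage sketch (not part of the commit): feeding DPTHead_Update with dummy ViT
# features. All sizes are illustrative; assumes the lingbot_map package is importable.
import torch
from lingbot_map.heads.dpt_head import DPTHead_Update

head = DPTHead_Update(in_channels=384, features=64, out_channels=[48, 96, 192, 384])
patch_h, patch_w = 16, 16
feats = [torch.randn(1, patch_h * patch_w, 384) for _ in range(4)]   # 4 intermediate ViT layers
out, p1, p2, p3, p4 = head(feats, patch_h, patch_w, return_intermediate=True)
print(out.shape)   # (1, features // 2, patch_h * 14, patch_w * 14) = (1, 32, 224, 224)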
lingbot_map/heads/head_act.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
13
+ """
14
+ Activate pose parameters with specified activation functions.
15
+
16
+ Args:
17
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
18
+ trans_act: Activation type for translation component
19
+ quat_act: Activation type for quaternion component
20
+ fl_act: Activation type for focal length component
21
+
22
+ Returns:
23
+ Activated pose parameters tensor
24
+ """
25
+ T = pred_pose_enc[..., :3]
26
+ quat = pred_pose_enc[..., 3:7]
27
+ fl = pred_pose_enc[..., 7:] # or fov
28
+
29
+ T = base_pose_act(T, trans_act)
30
+ quat = base_pose_act(quat, quat_act)
31
+ fl = base_pose_act(fl, fl_act) # or fov
32
+
33
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
34
+
35
+ return pred_pose_enc
36
+
37
+
38
+ def base_pose_act(pose_enc, act_type="linear"):
39
+ """
40
+ Apply basic activation function to pose parameters.
41
+
42
+ Args:
43
+ pose_enc: Tensor containing encoded pose parameters
44
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
45
+
46
+ Returns:
47
+ Activated pose parameters
48
+ """
49
+ if act_type == "linear":
50
+ return pose_enc
51
+ elif act_type == "inv_log":
52
+ return inverse_log_transform(pose_enc)
53
+ elif act_type == "exp":
54
+ return torch.exp(pose_enc)
55
+ elif act_type == "relu":
56
+ return F.relu(pose_enc)
57
+ else:
58
+ raise ValueError(f"Unknown act_type: {act_type}")
59
+
60
+
61
+ def activate_head(out, activation="norm_exp", conf_activation="expp1"):
62
+ """
63
+ Process network output to extract 3D points and confidence values.
64
+
65
+ Args:
66
+ out: Network output tensor (B, C, H, W)
67
+ activation: Activation type for 3D points
68
+ conf_activation: Activation type for confidence values
69
+
70
+ Returns:
71
+ Tuple of (3D points tensor, confidence tensor)
72
+ """
73
+ # Move channels to the last dimension => (B, H, W, C)
74
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
75
+
76
+ # Split into xyz (first C-1 channels) and confidence (last channel)
77
+ xyz = fmap[:, :, :, :-1]
78
+ conf = fmap[:, :, :, -1]
79
+
80
+ if activation == "norm_exp":
81
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
82
+ xyz_normed = xyz / d
83
+ pts3d = xyz_normed * torch.expm1(d)
84
+ elif activation == "norm":
85
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
86
+ elif activation == "exp":
87
+ pts3d = torch.exp(xyz)
88
+ elif activation == "relu":
89
+ pts3d = F.relu(xyz)
90
+ elif activation == "inv_log":
91
+ pts3d = inverse_log_transform(xyz)
92
+ elif activation == "xy_inv_log":
93
+ xy, z = xyz.split([2, 1], dim=-1)
94
+ z = inverse_log_transform(z)
95
+ pts3d = torch.cat([xy * z, z], dim=-1)
96
+ elif activation == "sigmoid":
97
+ pts3d = torch.sigmoid(xyz)
98
+ elif activation == "linear":
99
+ pts3d = xyz
100
+ else:
101
+ raise ValueError(f"Unknown activation: {activation}")
102
+
103
+ if conf_activation == "expp1":
104
+ conf_out = 1 + conf.exp()
105
+ elif conf_activation == "expp0":
106
+ conf_out = conf.exp()
107
+ elif conf_activation == "sigmoid":
108
+ conf_out = torch.sigmoid(conf)
109
+ else:
110
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
111
+
112
+ return pts3d, conf_out
113
+
114
+
115
+ def inverse_log_transform(y):
116
+ """
117
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
118
+
119
+ Args:
120
+ y: Input tensor
121
+
122
+ Returns:
123
+ Transformed tensor
124
+ """
125
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
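# Hedged usage sketch (not part of the commit): what activate_head does to a raw
# 4-channel head output (3 xyz channels + 1 confidence). Assumes lingbot_map is importable.
import torch
from lingbot_map.heads.head_act import activate_head

out = torch.randn(2, 4, 8, 8)                                      # (B, C, H, W)
pts3d, conf = activate_head(out, activation="norm_exp", conf_activation="expp1")
print(pts3d.shape, conf.shape)                                      # (2, 8, 8, 3) and (2, 8, 8)
print(bool((conf >= 1.0).all()))                                    # expp1 keeps confidence above 1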
lingbot_map/heads/utils.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
12
+ """
13
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
14
+
15
+ Args:
16
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
17
+ embed_dim: Output channel dimension for embeddings
18
+
19
+ Returns:
20
+ Tensor of shape (H, W, embed_dim) with positional embeddings
21
+ """
22
+ H, W, grid_dim = pos_grid.shape
23
+ assert grid_dim == 2
24
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
25
+
26
+ # Process x and y coordinates separately
27
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
28
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
29
+
30
+ # Combine and reshape
31
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
32
+
33
+ return emb.view(H, W, embed_dim) # [H, W, D]
34
+
35
+
36
+ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
37
+ """
38
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
39
+
40
+ Args:
41
+ - embed_dim: The embedding dimension.
42
+ - pos: The position to generate the embedding from.
43
+
44
+ Returns:
45
+ - emb: The generated 1D positional embedding.
46
+ """
47
+ assert embed_dim % 2 == 0
48
+ device = pos.device
49
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
50
+ omega /= embed_dim / 2.0
51
+ omega = 1.0 / omega_0**omega # (D/2,)
52
+
53
+ pos = pos.reshape(-1) # (M,)
54
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
55
+
56
+ emb_sin = torch.sin(out) # (M, D/2)
57
+ emb_cos = torch.cos(out) # (M, D/2)
58
+
59
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
60
+ return emb.float()
61
+
62
+
63
+ # Inspired by https://github.com/microsoft/moge
64
+
65
+
66
+ def create_uv_grid(
67
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
68
+ ) -> torch.Tensor:
69
+ """
70
+ Create a normalized UV grid of shape (height, width, 2).
71
+
72
+ The grid spans horizontally and vertically according to an aspect ratio,
73
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
74
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
75
+
76
+ Args:
77
+ width (int): Number of points horizontally.
78
+ height (int): Number of points vertically.
79
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
80
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
81
+ device (torch.device, optional): Device on which the tensor is created.
82
+
83
+ Returns:
84
+ torch.Tensor: A (height, width, 2) tensor of UV coordinates.
85
+ """
86
+ # Derive aspect ratio if not explicitly provided
87
+ if aspect_ratio is None:
88
+ aspect_ratio = float(width) / float(height)
89
+
90
+ # Compute normalized spans for X and Y
91
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
92
+ span_x = aspect_ratio / diag_factor
93
+ span_y = 1.0 / diag_factor
94
+
95
+ # Establish the linspace boundaries
96
+ left_x = -span_x * (width - 1) / width
97
+ right_x = span_x * (width - 1) / width
98
+ top_y = -span_y * (height - 1) / height
99
+ bottom_y = span_y * (height - 1) / height
100
+
101
+ # Generate 1D coordinates
102
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
103
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
104
+
105
+ # Create 2D meshgrid (width x height) and stack into UV
106
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
107
+ uv_grid = torch.stack((uu, vv), dim=-1)
108
+
109
+ return uv_grid
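# Hedged usage sketch (not part of the commit): the UV-grid -> sinusoidal-embedding
# pipeline used by _apply_pos_embed in dpt_head.py. Assumes lingbot_map is importable.
import torch
from lingbot_map.heads.utils import create_uv_grid, position_grid_to_embed

uv = create_uv_grid(width=37, height=28, aspect_ratio=518 / 392)    # 2-D coordinates per patch
emb = position_grid_to_embed(uv, embed_dim=128)                     # 128-dim sin/cos embedding per patch
print(uv.shape, emb.shape)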
lingbot_map/layers/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from lingbot_map.layers.mlp import Mlp
2
+ from lingbot_map.layers.patch_embed import PatchEmbed
3
+ from lingbot_map.layers.block import Block
4
+ from lingbot_map.layers.swiglu_ffn import SwiGLUFFN as SwiGLUFFNFused
5
+ from lingbot_map.layers.attention import Attention as MemEffAttention
lingbot_map/layers/attention.py ADDED
@@ -0,0 +1,766 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import math
13
+ import warnings
14
+ import torch
15
+
16
+ from torch import Tensor
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+
20
+ from lingbot_map.layers.rope import apply_rotary_emb
21
+
22
+ from einops import rearrange
23
+
24
+ # FlashInfer imports (optional - for paged attention)
25
+ try:
26
+ import flashinfer
27
+ FLASHINFER_AVAILABLE = True
28
+ except ImportError:
29
+ FLASHINFER_AVAILABLE = False
30
+ print("flashinfer not available")
31
+
32
+ try:
33
+ from torchtitan.distributed.sequence_parallel import (
34
+ gather_seq_scatter_heads,
35
+ gather_heads_scatter_seq,
36
+ pad_tensor,
37
+ slice_input_tensor_scale_grad,
38
+ gather_outputs,
39
+ )
40
+ except ImportError:
41
+ print("torchtitan not available for ulysses cp")
42
+
43
+ def gather_seq_scatter_heads_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_dim: int, head_dim: int):
44
+ """Gather sequence dimension and scatter head dimension for Q, K, V tensors."""
45
+ q = gather_seq_scatter_heads(q, seq_dim, head_dim)
46
+ k = gather_seq_scatter_heads(k, seq_dim, head_dim)
47
+ v = gather_seq_scatter_heads(v, seq_dim, head_dim)
48
+ return q, k, v
49
+
50
+ from typing_extensions import List
51
+ from typing import Optional, Tuple
52
+
53
+
54
+ class Attention(nn.Module):
55
+ def __init__(
56
+ self,
57
+ dim: int,
58
+ num_heads: int = 8,
59
+ qkv_bias: bool = True,
60
+ proj_bias: bool = True,
61
+ attn_drop: float = 0.0,
62
+ proj_drop: float = 0.0,
63
+ norm_layer: nn.Module = nn.LayerNorm,
64
+ qk_norm: bool = False,
65
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
66
+ rope=None,
67
+ ) -> None:
68
+ super().__init__()
69
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
70
+ self.num_heads = num_heads
71
+ self.head_dim = dim // num_heads
72
+ self.scale = self.head_dim**-0.5
73
+ self.fused_attn = fused_attn
74
+
75
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
76
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
77
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
78
+ self.attn_drop = nn.Dropout(attn_drop)
79
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
80
+ self.proj_drop = nn.Dropout(proj_drop)
81
+ self.rope = rope
82
+
83
+ def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
84
+ B, N, C = x.shape
85
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
86
+ q, k, v = qkv.unbind(0)
87
+ q, k = self.q_norm(q), self.k_norm(k)
88
+
89
+ if enable_ulysses_cp:
90
+ q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
91
+
92
+ if self.rope is not None:
93
+ q = self.rope(q, pos)
94
+ k = self.rope(k, pos)
95
+
96
+ if self.fused_attn:
97
+ x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
98
+ else:
99
+ q = q * self.scale
100
+ attn = q @ k.transpose(-2, -1)
101
+ attn = attn.softmax(dim=-1)
102
+ attn = self.attn_drop(attn)
103
+ x = attn @ v
104
+
105
+ if enable_ulysses_cp:
106
+ x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
107
+
108
+ x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
109
+ x = self.proj(x)
110
+ x = self.proj_drop(x)
111
+ return x
112
+
113
+
114
+ class CausalAttention(nn.Module):
115
+ """
116
+ Causal self-attention module with KV cache support for streaming inference.
117
+ Used by CasualBlockCamera in camera_head.py.
118
+ """
119
+ def __init__(
120
+ self,
121
+ dim: int,
122
+ num_heads: int = 8,
123
+ qkv_bias: bool = True,
124
+ proj_bias: bool = True,
125
+ attn_drop: float = 0.0,
126
+ proj_drop: float = 0.0,
127
+ norm_layer: nn.Module = nn.LayerNorm,
128
+ qk_norm: bool = False,
129
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
130
+ rope=None,
131
+ elementwise_attn_output_gate=False,
132
+ # KV cache eviction parameters (matching build_attn_mask)
133
+ kv_cache_sliding_window: int = 64,
134
+ kv_cache_scale_frames: int = 8,
135
+ kv_cache_cross_frame_special: bool = True,
136
+ kv_cache_include_scale_frames: bool = True,
137
+ kv_cache_camera_only: bool = False, # If True, only cache camera token (no scale token)
138
+ ) -> None:
139
+ super().__init__()
140
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
141
+ self.num_heads = num_heads
142
+ self.head_dim = dim // num_heads
143
+ self.scale = self.head_dim**-0.5
144
+ self.fused_attn = fused_attn
145
+
146
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
147
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
148
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
149
+ self.attn_drop = nn.Dropout(attn_drop)
150
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
151
+ self.proj_drop = nn.Dropout(proj_drop)
152
+ self.rope = rope
153
+
154
+ self.gate_proj = nn.Linear(dim, dim, bias=True) if elementwise_attn_output_gate else None
155
+
156
+ # Store KV cache eviction parameters
157
+ self.kv_cache_sliding_window = kv_cache_sliding_window
158
+ self.kv_cache_scale_frames = kv_cache_scale_frames
159
+ self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
160
+ self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
161
+ self.kv_cache_camera_only = kv_cache_camera_only
162
+
163
+ def forward(self, x: Tensor, block_mask=None, pos=None, pos_kv=None, frame_seqlen=None, video_mask=None, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=1, num_frame_for_scale=-1, enable_3d_rope=False, sliding_window_size=-1, attend_to_scale_frames=False, num_random_frames=0, attend_to_special_tokens=False, num_register_tokens=4, enable_ulysses_cp=False, is_scale_frames=False) -> Tensor:
164
+ B, N, C = x.shape
165
+
166
+ # Calculate special token indices
167
+ camera_token_idx = 0
168
+ scale_token_idx = camera_token_idx + num_register_tokens + 1 # camera + register tokens + scale
169
+
170
+ # [3, B, num_heads, N, head_dim]
171
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
172
+ q, k, v = qkv.unbind(0)
173
+
174
+ if self.gate_proj is not None:
175
+ gate_score = self.gate_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
176
+ if kv_cache is None:
177
+ q, k = self.q_norm(q), self.k_norm(k)
178
+ if enable_ulysses_cp:
179
+ q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
180
+ N = q.shape[2] # Update N after gather
181
+ if self.rope is not None and not enable_3d_rope:
182
+ q = self.rope(q, pos)
183
+ k = self.rope(k, pos)
184
+ elif enable_3d_rope and pos is not None:
185
+ q = apply_rotary_emb(q, pos)
186
+ k = apply_rotary_emb(k, pos)
187
+
188
+ with torch.no_grad():
189
+ block_mask = block_mask.squeeze()[:q.shape[2], :k.shape[2]]
190
+ if block_mask.dim() == 2:
191
+ block_mask = block_mask.unsqueeze(0).unsqueeze(0) # [1, 1, N, N]
192
+ block_mask = block_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])
193
+
194
+ video_mask = video_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if video_mask is not None else torch.ones_like(block_mask, device=block_mask.device) # [1, 1, N, N]
195
+ video_mask = video_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])
196
+
197
+ mask = block_mask | ~video_mask
198
+
199
+ # Apply sliding window mask if sliding_window_size > 0
200
+ # sliding_window_size is in units of num_frame_per_block
201
+ if sliding_window_size > 0 and frame_seqlen is not None:
202
+ # Create sliding window mask: each frame can only attend to frames within the window
203
+ num_frames = N // frame_seqlen
204
+ sliding_mask = torch.zeros_like(mask, dtype=torch.bool)
205
+
206
+ for i in range(num_frames):
207
+ q_start = i * frame_seqlen
208
+ q_end = (i + 1) * frame_seqlen
209
+ # Calculate the window start: sliding_window_size is in units of num_frame_per_block
210
+ # So the actual window size in frames is sliding_window_size * num_frame_per_block
211
+ window_size_in_frames = sliding_window_size * num_frame_per_block
212
+ window_start_frame = max(0, i - window_size_in_frames + 1)
213
+ k_start = window_start_frame * frame_seqlen
214
+ k_end = (i + 1) * frame_seqlen # Can attend up to current frame (causal)
215
+ sliding_mask[:, :, q_start:q_end, k_start:k_end] = True
216
+
217
+ # Combine with existing mask: both masks need to allow attention
218
+ mask = mask & sliding_mask
219
+
220
+ # If attend_to_scale_frames is True, also allow attention to first num_frame_for_scale frames
221
+ if num_frame_for_scale > 0:
222
+ for i in range(num_frames):
223
+ q_start = i * frame_seqlen
224
+ q_end = (i + 1) * frame_seqlen
225
+ # Allow attending to first num_frame_for_scale frames (directly set to True, not depending on block_mask)
226
+ mask[:, :, q_start:q_end, :num_frame_for_scale * frame_seqlen] = True
227
+
228
+ ## global attention for the first num_frame_for_scale frames
229
+ if num_frame_for_scale > 0:
230
+ mask[:, :, :num_frame_for_scale * frame_seqlen, :num_frame_for_scale * frame_seqlen] = True
231
+
232
+ if self.fused_attn:
233
+ x = F.scaled_dot_product_attention(
234
+ q,
235
+ k,
236
+ v,
237
+ dropout_p=self.attn_drop.p if self.training else 0.0,
238
+ attn_mask=mask
239
+ )
240
+ else:
241
+ # Apply RoPE to current k before caching
242
+ q, k = self.q_norm(q), self.k_norm(k)
243
+
244
+ if self.rope is not None and not enable_3d_rope:
245
+ q = self.rope(q, pos)
246
+ k = self.rope(k, pos)
247
+ elif enable_3d_rope and pos is not None:
248
+ q = apply_rotary_emb(q, pos)
249
+ k = apply_rotary_emb(k, pos)
250
+
251
+ # Check if we should skip appending to cache (non-keyframe in keyframe mode)
252
+ skip_append = kv_cache.get("_skip_append", False)
253
+
254
+ k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
255
+ v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
256
+
257
+ if not skip_append:
258
+ # KEYFRAME: store in cache (original behavior)
259
+ if kv_cache[f"k_{global_idx}"] is None:
260
+ kv_cache[f"k_{global_idx}"] = k_reshaped
261
+ kv_cache[f"v_{global_idx}"] = v_reshaped
262
+ else:
263
+ num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
264
+ k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
265
+ v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
266
+ kv_cache[f"k_{global_idx}"] = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
267
+ kv_cache[f"v_{global_idx}"] = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
268
+
269
+ # Apply sliding window eviction BEFORE attention to match causal_3drope behavior
270
+ # This ensures current frame only attends to frames within the sliding window
271
+ self._apply_kv_cache_eviction_causal(kv_cache, global_idx, camera_token_idx, scale_token_idx)
272
+
273
+ # Retrieve full k, v from cache (already RoPE-applied, already evicted)
274
+ k = kv_cache[f"k_{global_idx}"].clone()
275
+ v = kv_cache[f"v_{global_idx}"].clone()
276
+ else:
277
+ # NON-KEYFRAME: attend to [cached + current] without storing in cache
278
+ if kv_cache[f"k_{global_idx}"] is not None:
279
+ k = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
280
+ v = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
281
+ else:
282
+ k = k_reshaped
283
+ v = v_reshaped
284
+ a, b, c, d, e = k.shape
285
+
286
+ k = k.reshape(a, b, c*d, e)
287
+ v = v.reshape(a, b, c*d, e)
288
+
289
+ # Prepend special tokens (camera + scale) from evicted frames if they exist
290
+ if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
291
+ special_k = kv_cache[f"k_{global_idx}_special"] # [B, H, num_evicted_frames, 2, D]
292
+ special_v = kv_cache[f"v_{global_idx}_special"]
293
+ sa, sb, sc, sd, se = special_k.shape
294
+ special_k = special_k.reshape(sa, sb, sc * sd, se) # [B, H, num_evicted*2, D]
295
+ special_v = special_v.reshape(sa, sb, sc * sd, se)
296
+
297
+ # Prepend special tokens (older frames first)
298
+ k = torch.cat([special_k, k], dim=2)
299
+ v = torch.cat([special_v, v], dim=2)
300
+
301
+ # Note: k from cache is already RoPE-applied, no need to apply again
302
+
303
+ if self.fused_attn:
304
+ # Use mask-based SDPA to ensure same kernel as batch mode
305
+ # The causal constraint is enforced by KV cache contents, not by mask
306
+ mask = torch.ones(B, 1, q.shape[2], k.shape[2], dtype=torch.bool, device=q.device)
307
+ x = F.scaled_dot_product_attention(
308
+ q,
309
+ k,
310
+ v,
311
+ dropout_p=self.attn_drop.p if self.training else 0.0,
312
+ attn_mask=mask,
313
+ )
314
+
315
+ if self.gate_proj is not None:
316
+ x = x * torch.sigmoid(gate_score)
317
+ if enable_ulysses_cp:
318
+ x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
319
+ # Use actual dimensions from attention output, not original input C
320
+ # x shape: [B, H, seq_len, head_dim] -> [B, seq_len, H*head_dim]
321
+ x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
322
+ x = self.proj(x)
323
+ x = self.proj_drop(x)
324
+ return x
325
+
326
+ def _apply_kv_cache_eviction_causal(self, kv_cache, global_idx, camera_token_idx, scale_token_idx):
327
+ """
328
+ Apply sliding window eviction to KV cache BEFORE attention.
329
+
330
+ This ensures current frame only attends to frames within the sliding window,
331
+ matching the behavior of causal_3drope's attention mask.
332
+ """
333
+ sliding_window_frames = self.kv_cache_sliding_window
334
+ scale_frames = self.kv_cache_scale_frames
335
+
336
+ if kv_cache[f"k_{global_idx}"].shape[3] > 1:
337
+ num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
338
+
339
+ if num_cached_frames > sliding_window_frames + scale_frames:
340
+ evict_start = scale_frames
341
+ evict_end = num_cached_frames - sliding_window_frames
342
+
343
+ if evict_end > evict_start:
344
+ evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
345
+ evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]
346
+
347
+ if self.kv_cache_cross_frame_special:
348
+ if self.kv_cache_camera_only:
349
+ # Only keep camera token
350
+ new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
351
+ new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
352
+ else:
353
+ # Keep ALL special tokens (camera + register + scale) to match attention_mask behavior
354
+ # Special tokens are in range [camera_token_idx, scale_token_idx+1)
355
+ new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
356
+ new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
357
+
358
+ if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
359
+ kv_cache[f"k_{global_idx}_special"] = new_special_k
360
+ kv_cache[f"v_{global_idx}_special"] = new_special_v
361
+ else:
362
+ kv_cache[f"k_{global_idx}_special"] = torch.cat(
363
+ [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
364
+ kv_cache[f"v_{global_idx}_special"] = torch.cat(
365
+ [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)
366
+
367
+ if self.kv_cache_include_scale_frames:
368
+ kv_cache[f"k_{global_idx}"] = torch.cat([
369
+ kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
370
+ kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
371
+ ], dim=2)
372
+ kv_cache[f"v_{global_idx}"] = torch.cat([
373
+ kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
374
+ kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
375
+ ], dim=2)
376
+ else:
377
+ kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
378
+ kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
379
+
380
+
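# Hedged aside (not part of the commit): the frame-index arithmetic behind
# _apply_kv_cache_eviction_causal above, with the default window sizes. Patch K/V of the
# evicted frames are dropped; only their special (camera/register/scale) tokens survive.
scale_frames, sliding_window = 8, 64
num_cached_frames = 80                                        # cache has grown past scale + window
evict_start = scale_frames                                    # the first `scale_frames` frames are kept
evict_end = num_cached_frames - sliding_window
evicted = list(range(evict_start, evict_end))                 # frames 8..15: patch K/V removed
kept = list(range(scale_frames)) + list(range(evict_end, num_cached_frames))
print(len(evicted), len(kept))                                # 8 evicted, 8 + 64 = 72 kept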
381
+ class FlashInferAttention(Attention):
382
+ """
383
+ FlashInfer variant of the GCT attention layer.
384
+ Uses FlashInferKVCacheManager for paged KV cache storage and
385
+ FlashInfer attention kernels (BatchPrefillWithPagedKVCacheWrapper).
386
+ Supports the same optimized token layout and KV cache streaming inference.
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ dim: int,
392
+ num_heads: int = 8,
393
+ qkv_bias: bool = True,
394
+ proj_bias: bool = True,
395
+ attn_drop: float = 0.0,
396
+ proj_drop: float = 0.0,
397
+ norm_layer: nn.Module = nn.LayerNorm,
398
+ qk_norm: bool = False,
399
+ fused_attn: bool = True,
400
+ rope=None,
401
+ # KV cache eviction parameters
402
+ kv_cache_sliding_window: int = 64,
403
+ kv_cache_scale_frames: int = 8,
404
+ kv_cache_cross_frame_special: bool = True,
405
+ kv_cache_include_scale_frames: bool = True,
406
+ kv_cache_camera_only: bool = False,
407
+ ) -> None:
408
+ if not FLASHINFER_AVAILABLE:
409
+ raise RuntimeError("FlashInfer is not available. Please install flashinfer.")
410
+
411
+ super().__init__(
412
+ dim=dim,
413
+ num_heads=num_heads,
414
+ qkv_bias=qkv_bias,
415
+ proj_bias=proj_bias,
416
+ attn_drop=attn_drop,
417
+ proj_drop=proj_drop,
418
+ norm_layer=norm_layer,
419
+ qk_norm=qk_norm,
420
+ fused_attn=fused_attn,
421
+ rope=rope,
422
+ )
423
+
424
+ # Store KV cache eviction parameters
425
+ self.kv_cache_sliding_window = kv_cache_sliding_window
426
+ self.kv_cache_scale_frames = kv_cache_scale_frames
427
+ self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
428
+ self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
429
+ self.kv_cache_camera_only = kv_cache_camera_only
430
+
431
+ def prepare_qkv(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
432
+ """Fused pre-attention ops for single-frame streaming (Phase 2).
433
+
434
+ Computes q/k/v from x, applies q_norm/k_norm/RoPE, and converts to
435
+ [tpf, H, D] format ready for append_frame + compute_attention.
436
+
437
+ Extracted as a method so torch.compile can capture all pre-attn ops as one
438
+ CUDA graph (qkv linear -> reshape -> unbind -> q_norm -> k_norm -> RoPE x2 ->
439
+ squeeze/permute/contiguous x3).
440
+ """
441
+ B, N, C = x.shape
442
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
443
+ q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim]
444
+ q, k = self.q_norm(q), self.k_norm(k)
445
+ if self.rope is not None and not enable_3d_rope:
446
+ q = self.rope(q, pos)
447
+ k = self.rope(k, pos)
448
+ elif self.rope is not None: # enable_3d_rope=True
449
+ q = apply_rotary_emb(q, pos)
450
+ k = apply_rotary_emb(k, pos)
451
+ # Convert to [tpf, H, D] format for FlashInfer (B=1 in streaming mode)
452
+ q_nhd = q.squeeze(0).permute(1, 0, 2).contiguous()
453
+ k_nhd = k.squeeze(0).permute(1, 0, 2).contiguous()
454
+ v_nhd = v.squeeze(0).permute(1, 0, 2).contiguous()
455
+ return q_nhd, k_nhd, v_nhd
456
+
457
+ def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
458
+ num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
459
+ # KV cache parameters (kv_cache is a FlashInferKVCacheManager or None)
460
+ kv_cache=None, global_idx=0, num_frame_per_block=1,
461
+ num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
462
+ """
463
+ Forward pass with FlashInfer paged KV cache and attention.
464
+
465
+ Args:
466
+ x: Input tensor [B, N, C]
467
+ kv_cache: FlashInferKVCacheManager instance or None (batch mode)
468
+ global_idx: Block index for per-block cache access
469
+ """
470
+ from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager
471
+
472
+ B, N, C = x.shape
473
+
474
+ # Detect if using optimized layout
475
+ using_optimized_layout = (num_patches is not None and num_special is not None
476
+ and num_frames is not None)
477
+
478
+ # ========== Batch Mode (no KV cache manager) ==========
479
+ if not isinstance(kv_cache, FlashInferKVCacheManager):
480
+ # [3, B, num_heads, N, head_dim]
481
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
482
+ q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim]
483
+ q, k = self.q_norm(q), self.k_norm(k)
484
+
485
+ if enable_ulysses_cp:
486
+ if using_optimized_layout:
487
+ boundary = num_frames * num_patches
488
+ q_patch, k_patch, v_patch = q[:, :, :boundary, :], k[:, :, :boundary, :], v[:, :, :boundary, :]
489
+ q_special, k_special, v_special = q[:, :, boundary:, :], k[:, :, boundary:, :], v[:, :, boundary:, :]
490
+ q_patch, k_patch, v_patch = gather_seq_scatter_heads_qkv(
491
+ q_patch, k_patch, v_patch, seq_dim=2, head_dim=1
492
+ )
493
+ q_special, k_special, v_special = gather_seq_scatter_heads_qkv(
494
+ q_special, k_special, v_special, seq_dim=2, head_dim=1
495
+ )
496
+ q = torch.cat([q_patch, q_special], dim=2)
497
+ k = torch.cat([k_patch, k_special], dim=2)
498
+ v = torch.cat([v_patch, v_special], dim=2)
499
+ else:
500
+ q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
501
+
502
+ if self.rope is not None and not enable_3d_rope:
503
+ q = self.rope(q, pos)
504
+ k = self.rope(k, pos)
505
+ elif self.rope is not None and enable_3d_rope:
506
+ q = apply_rotary_emb(q, pos)
507
+ k = apply_rotary_emb(k, pos)
508
+
509
+ # Batch mode: use SDPA for numerical consistency with SDPA variant
510
+ x = F.scaled_dot_product_attention(
511
+ q, k, v,
512
+ dropout_p=self.attn_drop.p if self.training else 0.0,
513
+ )
514
+
515
+ if enable_ulysses_cp:
516
+ if using_optimized_layout:
517
+ seq_global = x.shape[2]
518
+ seq_local = num_frames * (num_patches + num_special)
519
+ cp_size = seq_global // seq_local
520
+ boundary_global = num_frames * cp_size * num_patches
521
+ x_patch = x[:, :, :boundary_global, :]
522
+ x_special = x[:, :, boundary_global:, :]
523
+ x_patch = gather_heads_scatter_seq(x_patch, seq_dim=2, head_dim=1)
524
+ x_special = gather_heads_scatter_seq(x_special, seq_dim=2, head_dim=1)
525
+ x = torch.cat([x_patch, x_special], dim=2)
526
+ else:
527
+ x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
528
+
529
+ x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
530
+
531
+ # ========== Streaming Mode (with FlashInferKVCacheManager) ==========
532
+ else:
533
+ manager = kv_cache # FlashInferKVCacheManager
534
+
535
+ # Phase 1 (scale frames): num_frames > 1 — multi-frame batch
536
+ # Phase 2 (streaming): num_frames == 1 — single frame
537
+ is_multi_frame = (num_frames is not None and num_frames > 1)
538
+
539
+ if is_multi_frame:
540
+ # Phase 1: compute full self-attention via SDPA (all frames attend to each other),
541
+ # then append each frame's K/V to the paged cache one at a time.
542
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
543
+ q, k, v = qkv.unbind(0)
544
+ q, k = self.q_norm(q), self.k_norm(k)
545
+
546
+ # Apply RoPE before caching (RoPE baked into K before append)
547
+ if self.rope is not None and not enable_3d_rope:
548
+ q = self.rope(q, pos)
549
+ k = self.rope(k, pos)
550
+ elif self.rope is not None and enable_3d_rope:
551
+ q = apply_rotary_emb(q, pos)
552
+ k = apply_rotary_emb(k, pos)
553
+
554
+ x = F.scaled_dot_product_attention(
555
+ q, k, v,
556
+ dropout_p=self.attn_drop.p if self.training else 0.0,
557
+ )
558
+ x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
559
+
560
+ # Append each frame's K/V to the paged cache individually.
561
+ tpf = manager.tokens_per_frame
562
+ k_all = k.squeeze(0).permute(1, 0, 2) # [num_frames*tpf, H, D]
563
+ v_all = v.squeeze(0).permute(1, 0, 2)
564
+ for f_idx in range(num_frames):
565
+ s = f_idx * tpf
566
+ manager.append_frame(global_idx, k_all[s:s+tpf].contiguous(), v_all[s:s+tpf].contiguous())
567
+ manager.evict_frames(
568
+ block_idx=global_idx,
569
+ scale_frames=self.kv_cache_scale_frames,
570
+ sliding_window=self.kv_cache_sliding_window,
571
+ cross_frame_special=self.kv_cache_cross_frame_special,
572
+ include_scale_frames=self.kv_cache_include_scale_frames,
573
+ camera_only=self.kv_cache_camera_only,
574
+ num_register_tokens=num_register_tokens,
575
+ )
576
+ else:
577
+ # Phase 2: single-frame streaming via FlashInfer paged attention.
578
+ q_nhd, k_nhd, v_nhd = self.prepare_qkv(x, pos=pos, enable_3d_rope=enable_3d_rope)
579
+
580
+ # 1. Append to paged cache
581
+ manager.append_frame(global_idx, k_nhd, v_nhd)
582
+
583
+ # 2. Apply sliding window eviction
584
+ manager.evict_frames(
585
+ block_idx=global_idx,
586
+ scale_frames=self.kv_cache_scale_frames,
587
+ sliding_window=self.kv_cache_sliding_window,
588
+ cross_frame_special=self.kv_cache_cross_frame_special,
589
+ include_scale_frames=self.kv_cache_include_scale_frames,
590
+ camera_only=self.kv_cache_camera_only,
591
+ num_register_tokens=num_register_tokens,
592
+ )
593
+
594
+ # 3. Compute attention via FlashInfer BatchPrefillWithPagedKVCacheWrapper
595
+ x = manager.compute_attention(global_idx, q_nhd)
596
+
597
+ # Convert back: [tpf, H, D] -> [B, tpf, C].
598
+ x = x.reshape(B, q_nhd.shape[0], self.num_heads * self.head_dim)
599
+
600
+ x = self.proj(x)
601
+ x = self.proj_drop(x)
602
+ return x
603
+
604
+
605
+ class SDPAAttention(Attention):
606
+ """
607
+ SDPA variant for streaming inference.
608
+ Uses F.scaled_dot_product_attention with dict-based KV cache.
609
+ No FlashInfer dependency required — works on any CUDA GPU.
610
+ """
611
+
612
+ def __init__(
613
+ self,
614
+ dim: int,
615
+ num_heads: int = 8,
616
+ qkv_bias: bool = True,
617
+ proj_bias: bool = True,
618
+ attn_drop: float = 0.0,
619
+ proj_drop: float = 0.0,
620
+ norm_layer: nn.Module = nn.LayerNorm,
621
+ qk_norm: bool = False,
622
+ fused_attn: bool = True,
623
+ rope=None,
624
+ kv_cache_sliding_window: int = 64,
625
+ kv_cache_scale_frames: int = 8,
626
+ kv_cache_cross_frame_special: bool = True,
627
+ kv_cache_include_scale_frames: bool = True,
628
+ kv_cache_camera_only: bool = False,
629
+ ) -> None:
630
+ super().__init__(
631
+ dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
632
+ attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer,
633
+ qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
634
+ )
635
+ self.kv_cache_sliding_window = kv_cache_sliding_window
636
+ self.kv_cache_scale_frames = kv_cache_scale_frames
637
+ self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
638
+ self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
639
+ self.kv_cache_camera_only = kv_cache_camera_only
640
+
641
+ def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
642
+ num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
643
+ kv_cache=None, global_idx=0, num_frame_per_block=1,
644
+ num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
645
+ B, N, C = x.shape
646
+ using_optimized_layout = (num_patches is not None and num_special is not None
647
+ and num_frames is not None)
648
+
649
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
650
+ q, k, v = qkv.unbind(0)
651
+ q, k = self.q_norm(q), self.k_norm(k)
652
+
653
+ # ========== Batch Mode (no KV cache) ==========
654
+ if kv_cache is None:
655
+ if self.rope is not None and not enable_3d_rope:
656
+ q = self.rope(q, pos)
657
+ k = self.rope(k, pos)
658
+ elif self.rope is not None and enable_3d_rope:
659
+ q = apply_rotary_emb(q, pos)
660
+ k = apply_rotary_emb(k, pos)
661
+
662
+ x = F.scaled_dot_product_attention(
663
+ q, k, v,
664
+ dropout_p=self.attn_drop.p if self.training else 0.0,
665
+ )
666
+ x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
667
+
668
+ # ========== Streaming Mode (with KV cache dict) ==========
669
+ else:
670
+ if self.rope is not None and not enable_3d_rope:
671
+ q = self.rope(q, pos)
672
+ k = self.rope(k, pos)
673
+ elif self.rope is not None and enable_3d_rope:
674
+ q = apply_rotary_emb(q, pos)
675
+ k = apply_rotary_emb(k, pos)
676
+
677
+ camera_token_idx = 0
678
+ scale_token_idx = camera_token_idx + num_register_tokens + 1
679
+
680
+ if kv_cache[f"k_{global_idx}"] is None:
681
+ kv_cache[f"k_{global_idx}"] = k.view(B, self.num_heads, num_frame_per_block,
682
+ N // num_frame_per_block, self.head_dim)
683
+ kv_cache[f"v_{global_idx}"] = v.view(B, self.num_heads, num_frame_per_block,
684
+ N // num_frame_per_block, self.head_dim)
685
+ else:
686
+ num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
687
+ kv_cache[f"k_{global_idx}"] = torch.cat((
688
+ kv_cache[f"k_{global_idx}"],
689
+ k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
690
+ ), dim=2)
691
+ kv_cache[f"v_{global_idx}"] = torch.cat((
692
+ kv_cache[f"v_{global_idx}"],
693
+ v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
694
+ ), dim=2)
695
+
696
+ self._apply_kv_cache_eviction(
697
+ kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens
698
+ )
699
+
700
+ k_cached = kv_cache[f"k_{global_idx}"].clone()
701
+ v_cached = kv_cache[f"v_{global_idx}"].clone()
702
+ a, b, c, d, e = k_cached.shape
703
+ k_full = k_cached.reshape(a, b, c * d, e)
704
+ v_full = v_cached.reshape(a, b, c * d, e)
705
+
706
+ if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
707
+ special_k = kv_cache[f"k_{global_idx}_special"]
708
+ special_v = kv_cache[f"v_{global_idx}_special"]
709
+ sa, sb, sc, sd, se = special_k.shape
710
+ k_full = torch.cat([special_k.reshape(sa, sb, sc * sd, se), k_full], dim=2)
711
+ v_full = torch.cat([special_v.reshape(sa, sb, sc * sd, se), v_full], dim=2)
712
+
713
+ q_seq_len = q.shape[2]
714
+ x = F.scaled_dot_product_attention(
715
+ q, k_full, v_full,
716
+ dropout_p=self.attn_drop.p if self.training else 0.0,
717
+ )
718
+ x = x.transpose(1, 2).reshape(B, q_seq_len, self.num_heads * self.head_dim)
719
+
720
+ x = self.proj(x)
721
+ x = self.proj_drop(x)
722
+ return x
723
+
724
+ def _apply_kv_cache_eviction(self, kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens):
725
+ """Apply sliding window eviction to KV cache."""
726
+ sliding_window_frames = self.kv_cache_sliding_window
727
+ scale_frames = self.kv_cache_scale_frames
728
+
729
+ if kv_cache[f"k_{global_idx}"].shape[3] > 1:
730
+ num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
731
+ if num_cached_frames > sliding_window_frames + scale_frames:
732
+ evict_start = scale_frames
733
+ evict_end = num_cached_frames - sliding_window_frames
734
+ if evict_end > evict_start:
735
+ evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
736
+ evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]
737
+
738
+ if self.kv_cache_cross_frame_special:
739
+ if self.kv_cache_camera_only:
740
+ new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
741
+ new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
742
+ else:
743
+ new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
744
+ new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
745
+
746
+ if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
747
+ kv_cache[f"k_{global_idx}_special"] = new_special_k
748
+ kv_cache[f"v_{global_idx}_special"] = new_special_v
749
+ else:
750
+ kv_cache[f"k_{global_idx}_special"] = torch.cat(
751
+ [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
752
+ kv_cache[f"v_{global_idx}_special"] = torch.cat(
753
+ [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)
754
+
755
+ if self.kv_cache_include_scale_frames:
756
+ kv_cache[f"k_{global_idx}"] = torch.cat([
757
+ kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
758
+ kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
759
+ ], dim=2)
760
+ kv_cache[f"v_{global_idx}"] = torch.cat([
761
+ kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
762
+ kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
763
+ ], dim=2)
764
+ else:
765
+ kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
766
+ kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
lingbot_map/layers/block.py ADDED
@@ -0,0 +1,514 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+ import math
15
+
16
+ import torch
17
+ from torch import nn, Tensor
18
+
19
+ from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention
20
+ from functools import lru_cache, partial
21
+ from torch.nn.attention.flex_attention import BlockMask, create_mask
22
+ from .drop_path import DropPath
23
+ from .layer_scale import LayerScale
24
+ from .mlp import Mlp
25
+
26
+
27
+ class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int,
32
+ mlp_ratio: float = 4.0,
33
+ qkv_bias: bool = True,
34
+ proj_bias: bool = True,
35
+ ffn_bias: bool = True,
36
+ drop: float = 0.0,
37
+ attn_drop: float = 0.0,
38
+ init_values=None,
39
+ drop_path: float = 0.0,
40
+ act_layer: Callable[..., nn.Module] = nn.GELU,
41
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
42
+ attn_class: Callable[..., nn.Module] = Attention,
43
+ ffn_layer: Callable[..., nn.Module] = Mlp,
44
+ qk_norm: bool = False,
45
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
46
+ rope=None,
47
+ ) -> None:
48
+ super().__init__()
49
+
50
+ self.norm1 = norm_layer(dim)
51
+
52
+ self.attn = attn_class(
53
+ dim,
54
+ num_heads=num_heads,
55
+ qkv_bias=qkv_bias,
56
+ proj_bias=proj_bias,
57
+ attn_drop=attn_drop,
58
+ proj_drop=drop,
59
+ qk_norm=qk_norm,
60
+ fused_attn=fused_attn,
61
+ rope=rope,
62
+ )
63
+
64
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
65
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
66
+
67
+ self.norm2 = norm_layer(dim)
68
+ mlp_hidden_dim = int(dim * mlp_ratio)
69
+ self.mlp = ffn_layer(
70
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
71
+ )
72
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
73
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
74
+
75
+ self.sample_drop_ratio = drop_path
76
+
77
+ def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
78
+ num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
79
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
80
+ return self.ls1(self.attn(self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
81
+ num_patches=num_patches, num_special=num_special, num_frames=num_frames,
82
+ enable_3d_rope=enable_3d_rope))
83
+
84
+ def ffn_residual_func(x: Tensor) -> Tensor:
85
+ return self.ls2(self.mlp(self.norm2(x)))
86
+
87
+ if self.training and self.sample_drop_ratio > 0.1:
88
+ # the overhead is compensated only for a drop path rate larger than 0.1
89
+ x = drop_add_residual_stochastic_depth(
90
+ x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
91
+ )
92
+ x = drop_add_residual_stochastic_depth(
93
+ x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
94
+ )
95
+ elif self.training and self.sample_drop_ratio > 0.0:
96
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
97
+ x = x + self.drop_path2(ffn_residual_func(x))
98
+ else:
99
+ x = x + attn_residual_func(x, pos=pos)
100
+ x = x + ffn_residual_func(x)
101
+ return x
102
+
103
+
104
+ def drop_add_residual_stochastic_depth(
105
+ x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
106
+ ) -> Tensor:
107
+ # 1) extract subset using permutation
108
+ b, n, d = x.shape
109
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
110
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
111
+ x_subset = x[brange]
112
+
113
+ # 2) apply residual_func to get residual
114
+ if pos is not None:
115
+ # if necessary, apply rope to the subset
116
+ pos = pos[brange]
117
+ residual = residual_func(x_subset, pos=pos)
118
+ else:
119
+ residual = residual_func(x_subset)
120
+
121
+ x_flat = x.flatten(1)
122
+ residual = residual.flatten(1)
123
+
124
+ residual_scale_factor = b / sample_subset_size
125
+
126
+ # 3) add the residual
127
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
128
+ return x_plus_residual.view_as(x)
129
+
130
+
131
+ def get_branges_scales(x, sample_drop_ratio=0.0):
132
+ b, n, d = x.shape
133
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
134
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
135
+ residual_scale_factor = b / sample_subset_size
136
+ return brange, residual_scale_factor
137
+
138
+
139
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
140
+ if scaling_vector is None:
141
+ x_flat = x.flatten(1)
142
+ residual = residual.flatten(1)
143
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
144
+ else:
145
+ x_plus_residual = scaled_index_add(
146
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
147
+ )
148
+ return x_plus_residual
149
+
150
+
151
+ class FlashInferBlock(nn.Module):
152
+ """
153
+ FlashInfer variant of causal block for GCT.
154
+ Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels).
155
+ Supports optimized token layout and KV cache streaming inference.
156
+ """
157
+
158
+ def __init__(
159
+ self,
160
+ dim: int,
161
+ num_heads: int,
162
+ mlp_ratio: float = 4.0,
163
+ qkv_bias: bool = True,
164
+ proj_bias: bool = True,
165
+ ffn_bias: bool = True,
166
+ drop: float = 0.0,
167
+ attn_drop: float = 0.0,
168
+ init_values=None,
169
+ drop_path: float = 0.0,
170
+ act_layer: Callable[..., nn.Module] = nn.GELU,
171
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
172
+ ffn_layer: Callable[..., nn.Module] = Mlp,
173
+ qk_norm: bool = False,
174
+ rope=None,
175
+ kv_cache_sliding_window: int = 64,
176
+ kv_cache_scale_frames: int = 8,
177
+ kv_cache_cross_frame_special: bool = True,
178
+ kv_cache_include_scale_frames: bool = True,
179
+ kv_cache_camera_only: bool = False,
180
+ ) -> None:
181
+ super().__init__()
182
+
183
+ self.norm1 = norm_layer(dim)
184
+ self.attn = FlashInferAttention(
185
+ dim=dim,
186
+ num_heads=num_heads,
187
+ qk_norm=qk_norm,
188
+ qkv_bias=qkv_bias,
189
+ proj_bias=proj_bias,
190
+ attn_drop=attn_drop,
191
+ proj_drop=drop,
192
+ rope=rope,
193
+ kv_cache_sliding_window=kv_cache_sliding_window,
194
+ kv_cache_scale_frames=kv_cache_scale_frames,
195
+ kv_cache_cross_frame_special=kv_cache_cross_frame_special,
196
+ kv_cache_include_scale_frames=kv_cache_include_scale_frames,
197
+ kv_cache_camera_only=kv_cache_camera_only,
198
+ )
199
+
200
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
201
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
202
+
203
+ self.norm2 = norm_layer(dim)
204
+ mlp_hidden_dim = int(dim * mlp_ratio)
205
+ self.mlp = ffn_layer(
206
+ in_features=dim,
207
+ hidden_features=mlp_hidden_dim,
208
+ act_layer=act_layer,
209
+ drop=drop,
210
+ bias=ffn_bias
211
+ )
212
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
213
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
214
+
215
+ self.sample_drop_ratio = drop_path
216
+
217
+ def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
218
+ """Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit.
219
+
220
+ Extracted as a named method so torch.compile can capture norm1 + qkv-linear +
221
+ reshape + q_norm + k_norm + RoPE + format as a single CUDA graph.
222
+
223
+ Returns:
224
+ (q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim],
225
+ ready for manager.append_frame + manager.compute_attention.
226
+ """
227
+ return self.attn.prepare_qkv(self.norm1(x), pos=pos, enable_3d_rope=enable_3d_rope)
228
+
229
+ def forward(
230
+ self,
231
+ x: Tensor,
232
+ pos=None,
233
+ enable_ulysses_cp=False,
234
+ num_patches=None,
235
+ num_special=None,
236
+ num_frames=None,
237
+ enable_3d_rope=False,
238
+ kv_cache=None,
239
+ global_idx=0,
240
+ num_frame_per_block=1,
241
+ num_frame_for_scale=-1,
242
+ num_register_tokens=4,
243
+ ) -> Tensor:
244
+ # Phase 2 (streaming): single-frame FlashInfer paged attention.
245
+ # Handle inline so attn_pre (norm1+prepare_qkv) can be compiled as one CUDA graph.
246
+ is_streaming = (kv_cache is not None and (num_frames is None or num_frames <= 1))
247
+ if is_streaming:
248
+ manager = kv_cache
249
+ # Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format
250
+ q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope)
251
+ # Eager: write frame K/V to paged cache
252
+ manager.append_frame(global_idx, k_nhd, v_nhd)
253
+ # CPU-only: update eviction state (deque ops, no GPU kernel)
254
+ manager.evict_frames(
255
+ block_idx=global_idx,
256
+ scale_frames=self.attn.kv_cache_scale_frames,
257
+ sliding_window=self.attn.kv_cache_sliding_window,
258
+ cross_frame_special=self.attn.kv_cache_cross_frame_special,
259
+ include_scale_frames=self.attn.kv_cache_include_scale_frames,
260
+ camera_only=self.attn.kv_cache_camera_only,
261
+ num_register_tokens=num_register_tokens,
262
+ )
263
+ # Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper
264
+ attn_x = manager.compute_attention(global_idx, q_nhd)
265
+ # [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output)
266
+ attn_x = attn_x.reshape(x.shape[0], q_nhd.shape[0],
267
+ self.attn.num_heads * self.attn.head_dim)
268
+ # Compiled: output projection
269
+ attn_x = self.attn.proj(attn_x)
270
+ x = x + self.ls1(attn_x)
271
+ else:
272
+ # Phase 1 (multi-frame scale pass) or non-streaming training path
273
+ x = x + self.ls1(self.attn(
274
+ self.norm1(x),
275
+ pos=pos,
276
+ enable_ulysses_cp=enable_ulysses_cp,
277
+ num_patches=num_patches,
278
+ num_special=num_special,
279
+ num_frames=num_frames,
280
+ enable_3d_rope=enable_3d_rope,
281
+ kv_cache=kv_cache,
282
+ global_idx=global_idx,
283
+ num_frame_per_block=num_frame_per_block,
284
+ num_frame_for_scale=num_frame_for_scale,
285
+ num_register_tokens=num_register_tokens,
286
+ ))
287
+ x = self.ffn_residual(x)
288
+ return x
289
+
290
+ def ffn_residual(self, x: Tensor) -> Tensor:
291
+ """FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in.
292
+
293
+ Includes the residual add (x + ...) so torch.compile captures the entire
294
+ ffn branch as one CUDA graph.
295
+ """
296
+ return x + self.ls2(self.mlp(self.norm2(x)))
297
+
298
+
299
+ class CameraBlock(nn.Module):
300
+ def __init__(
301
+ self,
302
+ dim: int,
303
+ num_heads: int,
304
+ mlp_ratio: float = 4.0,
305
+ qkv_bias: bool = True,
306
+ proj_bias: bool = True,
307
+ ffn_bias: bool = True,
308
+ drop: float = 0.0,
309
+ attn_drop: float = 0.0,
310
+ init_values=None,
311
+ drop_path: float = 0.0,
312
+ act_layer: Callable[..., nn.Module] = nn.GELU,
313
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
314
+ attn_class: Callable[..., nn.Module] = Attention,
315
+ ffn_layer: Callable[..., nn.Module] = Mlp,
316
+ qk_norm: bool = False,
317
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
318
+ rope=None,
319
+ elementwise_attn_output_gate: bool = False,
320
+ sliding_window_size: int = -1,
321
+ attend_to_scale_frames: bool = False,
322
+ num_random_frames: int = 0,
323
+ # KV cache parameters
324
+ kv_cache_sliding_window: int = 64,
325
+ kv_cache_scale_frames: int = 8,
326
+ kv_cache_cross_frame_special: bool = True,
327
+ kv_cache_include_scale_frames: bool = True,
328
+ kv_cache_camera_only: bool = False,
329
+ ) -> None:
330
+ super().__init__()
331
+
332
+ self.norm1 = norm_layer(dim)
333
+ self.attn = CausalAttention(dim=dim, num_heads=num_heads,
334
+ qk_norm=qk_norm, qkv_bias=qkv_bias,
335
+ rope=rope, elementwise_attn_output_gate=elementwise_attn_output_gate,
336
+ kv_cache_sliding_window=kv_cache_sliding_window,
337
+ kv_cache_scale_frames=kv_cache_scale_frames,
338
+ kv_cache_cross_frame_special=kv_cache_cross_frame_special,
339
+ kv_cache_include_scale_frames=kv_cache_include_scale_frames,
340
+ kv_cache_camera_only=kv_cache_camera_only)
341
+
342
+ self.sliding_window_size = sliding_window_size
343
+ self.attend_to_scale_frames = attend_to_scale_frames
344
+ self.num_random_frames = num_random_frames
345
+
346
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
347
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
348
+
349
+ self.norm2 = norm_layer(dim)
350
+ mlp_hidden_dim = int(dim * mlp_ratio)
351
+ self.mlp = ffn_layer(
352
+ in_features=dim,
353
+ hidden_features=mlp_hidden_dim,
354
+ act_layer=act_layer,
355
+ drop=drop,
356
+ bias=ffn_bias
357
+ )
358
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
359
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
360
+
361
+ self.sample_drop_ratio = drop_path
362
+ self.masks = {}
363
+
364
+ @torch.no_grad()
365
+ def _prepare_blockwise_causal_attn_mask(self,
366
+ device: torch.device | str, num_frames: int = 21,
367
+ frame_seqlen: int = 1560, num_frame_per_block=1
368
+ ) -> BlockMask:
369
+ """
370
+ We divide the token sequence into the following format:
371
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
372
+ We use FlexAttention to construct the attention mask.
373
+ """
374
+ total_length = num_frames * frame_seqlen
375
+
376
+ # we do right padding to get to a multiple of 128
377
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
378
+
379
+ ends = torch.zeros(total_length + padded_length,
380
+ device=device, dtype=torch.long)
381
+
382
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
383
+ frame_indices = torch.arange(
384
+ start=0,
385
+ end=total_length,
386
+ step=frame_seqlen * num_frame_per_block,
387
+ device=device
388
+ )
389
+
390
+ for tmp in frame_indices:
391
+ ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
392
+ frame_seqlen * num_frame_per_block
393
+
394
+ def attention_mask(b, h, q_idx, kv_idx):
395
+ return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
396
+ # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
397
+
398
+ block_mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
399
+ KV_LEN=total_length + padded_length, device=device)
400
+
401
+ return block_mask
402
+
403
+ def forward(self, x: Tensor, pos=None, video_mask=None, num_frames=0, frame_seqlen=0, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=8, num_frame_for_scale=-1, sliding_window_size=None, enable_ulysses_cp=False, full_attention=False, enable_3d_rope=False, is_scale_frames=False) -> Tensor:
404
+ # Use passed sliding_window_size if provided, otherwise use self.sliding_window_size
405
+ effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
406
+
407
+ # Fast path for full attention (camera head) - skip mask computation
408
+ if full_attention:
409
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
410
+ return self.ls1(self.attn(self.norm1(x), pos=pos, full_attention=True, enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope))
411
+
412
+ def ffn_residual_func(x: Tensor) -> Tensor:
413
+ return self.ls2(self.mlp(self.norm2(x)))
414
+
415
+ if self.training and self.sample_drop_ratio > 0.0:
416
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
417
+ x = x + self.drop_path1(ffn_residual_func(x))
418
+ else:
419
+ x = x + attn_residual_func(x, pos=pos)
420
+ x = x + ffn_residual_func(x)
421
+ return x
422
+
423
+ mask_block = self._prepare_blockwise_causal_attn_mask(
424
+ device=x.device, num_frames=num_frames, frame_seqlen=frame_seqlen, num_frame_per_block=num_frame_per_block)
425
+
426
+
427
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
428
+ return self.ls1(self.attn(self.norm1(x), pos=pos, block_mask=mask_block, frame_seqlen=frame_seqlen, video_mask=video_mask, current_start=current_start, current_end=current_end, kv_cache=kv_cache, global_idx=global_idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size, attend_to_scale_frames=self.attend_to_scale_frames, num_random_frames=self.num_random_frames,
429
+ enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope, is_scale_frames=is_scale_frames))
430
+
431
+ def ffn_residual_func(x: Tensor) -> Tensor:
432
+ return self.ls2(self.mlp(self.norm2(x)))
433
+
434
+ if self.training and self.sample_drop_ratio > 0.0:
435
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
436
+ x = x + self.drop_path2(ffn_residual_func(x))
437
+ else:
438
+ x = x + attn_residual_func(x, pos=pos)
439
+ x = x + ffn_residual_func(x)
440
+ return x
441
+
442
+
443
+ class SDPABlock(nn.Module):
444
+ """
445
+ SDPA variant for streaming inference. Uses F.scaled_dot_product_attention
446
+ with dict-based KV cache. No FlashInfer dependency required.
447
+ """
448
+
449
+ def __init__(
450
+ self,
451
+ dim: int,
452
+ num_heads: int,
453
+ mlp_ratio: float = 4.0,
454
+ qkv_bias: bool = True,
455
+ proj_bias: bool = True,
456
+ ffn_bias: bool = True,
457
+ drop: float = 0.0,
458
+ attn_drop: float = 0.0,
459
+ init_values=None,
460
+ drop_path: float = 0.0,
461
+ act_layer: Callable[..., nn.Module] = nn.GELU,
462
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
463
+ ffn_layer: Callable[..., nn.Module] = Mlp,
464
+ qk_norm: bool = False,
465
+ rope=None,
466
+ kv_cache_sliding_window: int = 64,
467
+ kv_cache_scale_frames: int = 8,
468
+ kv_cache_cross_frame_special: bool = True,
469
+ kv_cache_include_scale_frames: bool = True,
470
+ kv_cache_camera_only: bool = False,
471
+ ) -> None:
472
+ super().__init__()
473
+ self.norm1 = norm_layer(dim)
474
+ self.attn = SDPAAttention(
475
+ dim=dim, num_heads=num_heads, qk_norm=qk_norm, qkv_bias=qkv_bias,
476
+ proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, rope=rope,
477
+ kv_cache_sliding_window=kv_cache_sliding_window,
478
+ kv_cache_scale_frames=kv_cache_scale_frames,
479
+ kv_cache_cross_frame_special=kv_cache_cross_frame_special,
480
+ kv_cache_include_scale_frames=kv_cache_include_scale_frames,
481
+ kv_cache_camera_only=kv_cache_camera_only,
482
+ )
483
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
484
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
485
+ self.norm2 = norm_layer(dim)
486
+ self.mlp = ffn_layer(in_features=dim, hidden_features=int(dim * mlp_ratio),
487
+ act_layer=act_layer, drop=drop, bias=ffn_bias)
488
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
489
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
490
+ self.sample_drop_ratio = drop_path
491
+
492
+ def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
493
+ num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
494
+ kv_cache=None, global_idx=0, num_frame_per_block=1,
495
+ num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
496
+ def attn_residual_func(x, pos=None):
497
+ return self.ls1(self.attn(
498
+ self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
499
+ num_patches=num_patches, num_special=num_special, num_frames=num_frames,
500
+ enable_3d_rope=enable_3d_rope, kv_cache=kv_cache, global_idx=global_idx,
501
+ num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale,
502
+ num_register_tokens=num_register_tokens,
503
+ ))
504
+
505
+ def ffn_residual_func(x):
506
+ return self.ls2(self.mlp(self.norm2(x)))
507
+
508
+ if self.training and self.sample_drop_ratio > 0.0:
509
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
510
+ x = x + self.drop_path1(ffn_residual_func(x))
511
+ else:
512
+ x = x + attn_residual_func(x, pos=pos)
513
+ x = x + ffn_residual_func(x)
514
+ return x
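A minimal usage sketch for the non-streaming Block above. The dimensions, the rope=None setting, and the absence of positional indices are illustrative assumptions (and rely on the default Attention class accepting pos=None); they are not the configuration used by this Space.

# Hypothetical smoke test: exercises the pre-norm residual structure of Block.
import torch
from lingbot_map.layers.block import Block

block = Block(dim=64, num_heads=4, mlp_ratio=4.0, init_values=1e-5, rope=None).eval()
x = torch.randn(2, 10, 64)            # [batch, tokens, dim]
with torch.no_grad():
    out = block(x)                    # x + ls1(attn(norm1(x))), then x + ls2(mlp(norm2(x)))
print(out.shape)                      # torch.Size([2, 10, 64])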
lingbot_map/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
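A quick numerical illustration of the stochastic-depth scaling implemented above: surviving samples are rescaled by 1/keep_prob, so the expected output matches the input, and the function is a no-op outside training. The batch size below is arbitrary.

# Illustrative only: drop_path preserves the expectation of the input.
import torch
from lingbot_map.layers.drop_path import drop_path

x = torch.ones(10000, 8)
y = drop_path(x, drop_prob=0.3, training=True)
print(y.mean().item())    # ≈ 1.0 (kept rows are scaled by 1/0.7, dropped rows are 0)
print(drop_path(x, drop_prob=0.3, training=False).mean().item())  # exactly 1.0 — identity at eval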
lingbot_map/layers/flashinfer_cache.py ADDED
@@ -0,0 +1,640 @@
1
+ """
2
+ FlashInfer KV Cache Manager — Two-Stream Paged Design.
3
+
4
+ Two logical streams sharing one physical page pool per layer:
5
+
6
+ Patch stream (recyclable):
7
+ - page_size = patches_per_frame (256 for 224×224; 972 for 504×378)
8
+ - Exactly 1 patch page per frame
9
+ - Scale frames → scale_patch_pages (never evicted, maxlen=scale_frames)
10
+ - Recent frames → live_window_patch_pages (evicted when > sliding_window)
11
+
12
+ Special stream (append-only, never recycled):
13
+ - num_special_tokens (6) special tokens per frame
14
+ - Packed continuously: one special page holds floor(page_size/6) frames
15
+ e.g. page_size=256 → 42 frames per special page, 4 slots wasted
16
+ - Specials written for EVERY frame (including scale + window), not just evicted ones.
17
+
18
+ Physical layout per block:
19
+ kv_caches[block_idx]: [max_num_pages, 2, page_size, H, D]
20
+ Pages 0 .. max_patch_pages-1 : patch page pool (recyclable)
21
+ Pages max_patch_pages .. max_pages-1: special page pool (append-only)
22
+ dim 1: 0=K 1=V
23
+
24
+ Attention computation:
25
+ visible = scale_patch_pages + live_window_patch_pages + all_special_pages
26
+ Special pages placed LAST → paged_kv_last_page_len naturally describes
27
+ the partial special-tail without a custom mask.
28
+
29
+ plan() is called ONCE per frame step (when block_idx == 0).
30
+ run() is called per layer, reusing the same plan. All layers at the
31
+ same frame step have identical page structures (same page IDs in same
32
+ positions), so reusing the plan across layers is correct.
33
+
34
+ Public API is drop-in compatible with the previous FlashInferKVCacheManager:
35
+ append_frame(block_idx, k, v)
36
+ evict_frames(block_idx, scale_frames, sliding_window, ...)
37
+ compute_attention(block_idx, q) -> out
38
+ reset()
39
+ """
40
+
41
+ import collections
42
+ import math
43
+ from typing import List
44
+
45
+ import torch
46
+ from torch import Tensor
47
+
48
+ try:
49
+ import flashinfer
50
+ FLASHINFER_AVAILABLE = True
51
+ except ImportError:
52
+ FLASHINFER_AVAILABLE = False
53
+
54
+
55
+ class FlashInferKVCacheManager:
56
+ """
57
+ Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only).
58
+
59
+ Args:
60
+ num_blocks: Number of Transformer blocks (one cache per block).
61
+ max_num_frames: Maximum frames held in the KV window at once
62
+ (scale_frames + sliding_window + headroom).
63
+ tokens_per_frame: Total tokens per frame = patches + specials (e.g. 262).
64
+ num_heads: Number of KV heads (= QO heads; MHA assumed).
65
+ head_dim: Head dimension (64 for ViT-L).
66
+ dtype: Storage dtype (bfloat16 / float16).
67
+ device: CUDA device.
68
+ num_special_tokens: Special tokens per frame: camera + register×N + scale (6).
69
+ scale_frames: Number of always-resident scale frames (8).
70
+ sliding_window: Sliding window size (64).
71
+ max_total_frames: Upper bound on total frames ever processed; used to
72
+ pre-allocate the special page pool (default 2048).
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ num_blocks: int,
78
+ max_num_frames: int,
79
+ tokens_per_frame: int,
80
+ num_heads: int,
81
+ head_dim: int,
82
+ dtype: torch.dtype,
83
+ device: torch.device,
84
+ num_special_tokens: int = 6,
85
+ scale_frames: int = 8,
86
+ sliding_window: int = 64,
87
+ max_total_frames: int = 2048,
88
+ force_fp32: bool = False,
89
+ fa3: bool = False,
90
+ ):
91
+ if not FLASHINFER_AVAILABLE:
92
+ raise RuntimeError("FlashInfer is not available. Please install flashinfer.")
93
+
94
+ self.num_blocks = num_blocks
95
+ self.num_special_tokens = num_special_tokens # 6
96
+ self.patches_per_frame = tokens_per_frame - num_special_tokens # 256 / 999 / ...
97
+ # Use exact page_size = patches_per_frame to eliminate zero-padded slots.
98
+ # FA2 (backend="fa2") supports non-power-of-2 page sizes.
99
+ # FA3 (sm90) requires power-of-2 page sizes; use next_power_of_2 when fa3=True.
100
+ p = self.patches_per_frame
101
+ if fa3:
102
+ # Round up to next power-of-2 for FA3 SM90 kernel requirement.
103
+ # e.g. 999 → 1024 (25 zero-padded slots per patch page)
104
+ self.page_size = 1 << (p - 1).bit_length()
105
+ else:
106
+ self.page_size = p # exact: no zero padding in patch pages
107
+ self.scale_frames = scale_frames # 8
108
+ self.sliding_window = sliding_window # 64
109
+ self.num_heads = num_heads
110
+ self.head_dim = head_dim
111
+ self.tokens_per_frame = tokens_per_frame
112
+
113
+ assert self.patches_per_frame > 0, (
114
+ f"tokens_per_frame={tokens_per_frame} <= num_special_tokens={num_special_tokens}"
115
+ )
116
+ assert self.page_size > 0
117
+
118
+ # force_fp32: bypass FlashInfer FA2 kernel (which only supports fp16/bf16) and
119
+ # instead gather paged K/V into a dense tensor and use F.scaled_dot_product_attention
120
+ # in fp32 for accuracy comparison. Storage dtype is also kept as fp32 in this mode.
121
+ self.force_fp32 = force_fp32
122
+ if force_fp32:
123
+ self.dtype = torch.float32
124
+ else:
125
+ if dtype == torch.float32:
126
+ dtype = torch.bfloat16
127
+ self.dtype = dtype
128
+ self.device = device
129
+
130
+ # ── Page pool sizing ─────────────────────────────────────────────────
131
+ # Patch: scale + window + 16 headroom (pages recycled → fixed count)
132
+ max_patch_pages = scale_frames + sliding_window + 16 # e.g. 88
133
+ # Special: enough for max_total_frames × 6 tokens, plus 16 headroom
134
+ max_special_pages = (
135
+ math.ceil(max_total_frames * num_special_tokens / self.page_size) + 16
136
+ )
137
+ self.max_patch_pages = max_patch_pages
138
+ self.max_num_pages = max_patch_pages + max_special_pages
139
+
140
+ # ── Physical paged KV caches ─────────────────────────────────────────
141
+ # Shape per block: [max_num_pages, 2, page_size, H, D] (NHD, K=dim0, V=dim1)
142
+ self.kv_caches: List[Tensor] = [
143
+ torch.zeros(
144
+ self.max_num_pages, 2, self.page_size, num_heads, head_dim,
145
+ dtype=dtype, device=device,
146
+ )
147
+ for _ in range(num_blocks)
148
+ ]
149
+
150
+ # ── Per-block state ──────────────────────────────────────────────────
151
+ # Patch pages (IDs 0 .. max_patch_pages-1)
152
+ self.scale_patch_pages: List[collections.deque] = [
153
+ collections.deque() for _ in range(num_blocks)
154
+ ]
155
+ self.live_window_patch_pages: List[collections.deque] = [
156
+ collections.deque() for _ in range(num_blocks)
157
+ ]
158
+ self.free_patch_pages: List[List[int]] = [
159
+ list(range(max_patch_pages)) for _ in range(num_blocks)
160
+ ]
161
+
162
+ # Special pages (IDs max_patch_pages .. max_num_pages-1)
163
+ self.all_special_pages: List[List[int]] = [[] for _ in range(num_blocks)]
164
+ self.free_special_pages: List[List[int]] = [
165
+ list(range(max_patch_pages, self.max_num_pages)) for _ in range(num_blocks)
166
+ ]
167
+ self.special_token_count: List[int] = [0] * num_blocks
168
+
169
+ # Frame counter per block (determines scale vs window routing)
170
+ self.frame_count: List[int] = [0] * num_blocks
171
+
172
+ # Deferred eviction support for flow-based keyframe selection.
173
+ # When True, evict_frames() becomes a no-op; caller must later call
174
+ # execute_deferred_eviction() or rollback_last_frame().
175
+ self._defer_eviction: bool = False
176
+
177
+ # ── FlashInfer wrapper ───────────────────────────────────────────────
178
+ # plan() is called once per frame step (block_idx == 0).
179
+ # run() is called per layer, reusing the same aux structures.
180
+ # backend: "fa2" (default) or "fa3" (SM90/H100, requires power-of-2 page_size).
181
+ # FA2 supports non-power-of-2 page sizes and avoids a FA3 NaN bug seen in
182
+ # FlashInfer 0.2.5 at 518×378 resolution.
183
+ _fi_backend = "fa3" if fa3 else "fa2"
184
+ self.workspace_buffer = torch.zeros(
185
+ 128 * 1024 * 1024, dtype=torch.uint8, device=device
186
+ )
187
+ self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
188
+ self.workspace_buffer,
189
+ kv_layout="NHD",
190
+ backend=_fi_backend,
191
+ )
192
+
193
+ # plan() inputs (indices/indptr built fresh each step; qo_indptr is fixed)
194
+ self._qo_indptr = torch.tensor(
195
+ [0, tokens_per_frame], dtype=torch.int32, device=device
196
+ )
197
+
198
+ # =========================================================================
199
+ # Public API (drop-in compatible with previous FlashInferKVCacheManager)
200
+ # =========================================================================
201
+
202
+ def append_frame(self, block_idx: int, k: Tensor, v: Tensor) -> None:
203
+ """
204
+ Append one frame's K/V tensors to the two-stream cache.
205
+
206
+ Token layout must be: [camera, reg0, ..., regN, scale, patch0, ..., patchP-1]
207
+ i.e. specials come first (matching stream.py's patch_start_idx convention).
208
+
209
+ Args:
210
+ block_idx: Block/layer index (0 … num_blocks-1).
211
+ k: [tokens_per_frame, H, D] NHD layout.
212
+ v: [tokens_per_frame, H, D] NHD layout.
213
+ """
214
+ n = self.num_special_tokens # 6
215
+ sp_k = k[:n].to(self.dtype) # [6, H, D]
216
+ patch_k = k[n:].to(self.dtype) # [256, H, D]
217
+ sp_v = v[:n].to(self.dtype)
218
+ patch_v = v[n:].to(self.dtype)
219
+
220
+ assert patch_k.shape[0] == self.patches_per_frame, (
221
+ f"block {block_idx}: expected {self.patches_per_frame} patch tokens, "
222
+ f"got {patch_k.shape[0]} (tokens_per_frame={k.shape[0]})"
223
+ )
224
+
225
+ self._write_patch_page(block_idx, patch_k, patch_v)
226
+ self._write_special_tokens(block_idx, sp_k, sp_v)
227
+ self.frame_count[block_idx] += 1
228
+
229
+ def evict_frames(
230
+ self,
231
+ block_idx: int,
232
+ scale_frames: int,
233
+ sliding_window: int,
234
+ cross_frame_special: bool = True,
235
+ include_scale_frames: bool = True,
236
+ camera_only: bool = False,
237
+ num_register_tokens: int = 4,
238
+ ) -> None:
239
+ """
240
+ Evict old window patch pages (recycle to free list).
241
+
242
+ Special pages are NEVER evicted.
243
+ Scale pages are NEVER evicted.
244
+ Only live_window_patch_pages beyond `sliding_window` are recycled.
245
+
246
+ When ``_defer_eviction`` is True, this method is a no-op. The caller
247
+ is expected to later call ``execute_deferred_eviction()`` (keep frame)
248
+ or ``rollback_last_frame()`` (discard frame).
249
+ """
250
+ if self._defer_eviction:
251
+ return
252
+ while len(self.live_window_patch_pages[block_idx]) > sliding_window:
253
+ old_page = self.live_window_patch_pages[block_idx].popleft()
254
+ self.free_patch_pages[block_idx].append(old_page)
255
+
256
+ def execute_deferred_eviction(
257
+ self,
258
+ block_idx: int,
259
+ scale_frames: int,
260
+ sliding_window: int,
261
+ **kwargs,
262
+ ) -> None:
263
+ """Run the eviction that was skipped while ``_defer_eviction`` was True."""
264
+ while len(self.live_window_patch_pages[block_idx]) > sliding_window:
265
+ old_page = self.live_window_patch_pages[block_idx].popleft()
266
+ self.free_patch_pages[block_idx].append(old_page)
267
+
268
+ def rollback_last_frame(self, block_idx: int) -> None:
269
+ """Undo the most recent ``append_frame()`` for *block_idx*.
270
+
271
+ This reverses all three sub-operations of ``append_frame``:
272
+ patch page allocation, special-token write, and frame_count increment.
273
+ It must be called **before** any eviction for that frame (i.e. while
274
+ ``_defer_eviction`` is True or before ``evict_frames`` is called).
275
+ """
276
+ assert self.frame_count[block_idx] > 0, (
277
+ f"block {block_idx}: cannot rollback, frame_count is 0"
278
+ )
279
+
280
+ # 1) Undo patch page ── pop from whichever deque it was routed to.
281
+ if self.frame_count[block_idx] > self.scale_frames:
282
+ page_id = self.live_window_patch_pages[block_idx].pop()
283
+ else:
284
+ page_id = self.scale_patch_pages[block_idx].pop()
285
+ self.free_patch_pages[block_idx].append(page_id)
286
+
287
+ # 2) Undo special tokens
288
+ n = self.num_special_tokens
289
+ new_count = self.special_token_count[block_idx] - n
290
+ assert new_count >= 0, (
291
+ f"block {block_idx}: special_token_count underflow "
292
+ f"({self.special_token_count[block_idx]} - {n})"
293
+ )
294
+ new_num_pages = math.ceil(new_count / self.page_size) if new_count > 0 else 0
295
+ while len(self.all_special_pages[block_idx]) > new_num_pages:
296
+ freed = self.all_special_pages[block_idx].pop()
297
+ self.free_special_pages[block_idx].append(freed)
298
+ self.special_token_count[block_idx] = new_count
299
+
300
+ # 3) Decrement frame count
301
+ self.frame_count[block_idx] -= 1
302
+
303
+ def _gather_kv(self, block_idx: int):
304
+ """
305
+ Gather all visible K and V tokens from the paged cache into dense tensors.
306
+
307
+ Used by force_fp32 mode to bypass the FlashInfer FA2 kernel (which only
308
+ supports fp16/bf16) and instead run F.scaled_dot_product_attention in fp32.
309
+
310
+ Returns:
311
+ k_flat: [kv_len, H, D] — all visible K tokens concatenated
312
+ v_flat: [kv_len, H, D] — all visible V tokens concatenated
313
+ """
314
+ visible = self.build_visible_page_table(block_idx)
315
+ last_len = self.compute_last_page_len(block_idx)
316
+ P = self.page_size
317
+
318
+ parts_k, parts_v = [], []
319
+ for i, pid in enumerate(visible):
320
+ n = last_len if (i == len(visible) - 1) else P
321
+ parts_k.append(self.kv_caches[block_idx][pid, 0, :n]) # [n, H, D]
322
+ parts_v.append(self.kv_caches[block_idx][pid, 1, :n])
323
+
324
+ k_flat = torch.cat(parts_k, dim=0) # [kv_len, H, D]
325
+ v_flat = torch.cat(parts_v, dim=0)
326
+ return k_flat, v_flat
327
+
328
+ def compute_attention(self, block_idx: int, q: Tensor) -> Tensor:
329
+ """
330
+ Compute cross-frame attention using FlashInfer BatchPrefillWithPagedKVCacheWrapper.
331
+
332
+ When self.force_fp32 is True, gathers all visible K/V into dense tensors
333
+ and uses F.scaled_dot_product_attention in fp32 instead of the FA2 kernel.
334
+ This is used for accuracy comparison since FlashInfer FA2 only supports fp16/bf16.
335
+
336
+ plan() is called once per frame step (when block_idx == 0).
337
+ All layers at the same step share the same visible page structure,
338
+ so the plan is reused by calling run() with each layer's kv_cache.
339
+
340
+ Args:
341
+ block_idx: Block/layer index.
342
+ q: [q_len, H, D] NHD layout (q_len = tokens_per_frame = 262).
343
+
344
+ Returns:
345
+ out: [q_len, H, D]
346
+ """
347
+ if self.frame_count[block_idx] == 0:
348
+ # No KV present yet (should not occur in normal usage after append_frame)
349
+ return torch.zeros_like(q)
350
+
351
+ if self.force_fp32:
352
+ # ── fp32 gather+SDPA path ─────────────────────────────────────────
353
+ # Gather visible K/V from paged cache and run SDPA in fp32.
354
+ # This bypasses the FlashInfer FA2 kernel (fp16/bf16 only) for accuracy.
355
+ # q_len, H, D → 1, H, q_len, D (SDPA expects BHsD layout)
356
+ import torch.nn.functional as F_nn
357
+ k_flat, v_flat = self._gather_kv(block_idx)
358
+ q_b = q.float().permute(1, 0, 2).unsqueeze(0) # [1, H, q_len, D]
359
+ k_b = k_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D]
360
+ v_b = v_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D]
361
+ out = F_nn.scaled_dot_product_attention(q_b, k_b, v_b)
362
+ return out.squeeze(0).permute(1, 0, 2).to(q.dtype) # [q_len, H, D]
363
+
364
+ if block_idx == 0:
365
+ # ── Plan once per frame step ──────────────────────────────────────
366
+ # Build visible page table from block 0's state.
367
+ # All blocks have identical page structures, so this plan is valid
368
+ # for all subsequent run() calls (block_idx = 1, 2, ...).
369
+ visible = self.build_visible_page_table(0)
370
+ last_len = self.compute_last_page_len(0)
371
+
372
+ assert visible, "visible page table is empty after append_frame"
373
+ assert 1 <= last_len <= self.page_size, (
374
+ f"block 0: last_page_len={last_len} out of [1, {self.page_size}]"
375
+ )
376
+
377
+ paged_kv_indices = torch.tensor(visible, dtype=torch.int32, device=self.device)
378
+ paged_kv_indptr = torch.tensor([0, len(visible)], dtype=torch.int32, device=self.device)
379
+ paged_kv_last_page_len = torch.tensor([last_len], dtype=torch.int32, device=self.device)
380
+
381
+ self.prefill_wrapper.plan(
382
+ self._qo_indptr,
383
+ paged_kv_indptr,
384
+ paged_kv_indices,
385
+ paged_kv_last_page_len,
386
+ num_qo_heads = self.num_heads,
387
+ num_kv_heads = self.num_heads,
388
+ head_dim_qk = self.head_dim,
389
+ page_size = self.page_size,
390
+ causal = False, # custom page ordering; no causal mask
391
+ pos_encoding_mode = "NONE", # RoPE applied externally before append
392
+ q_data_type = self.dtype,
393
+ )
394
+
395
+ # ── Run attention for this layer ──────────────────────────────────────
396
+ # Cast q to storage dtype (LayerNorm may upcast to float32 under autocast).
397
+ return self.prefill_wrapper.run(
398
+ q = q.to(self.dtype).contiguous(),
399
+ paged_kv_cache = self.kv_caches[block_idx],
400
+ ) # → [q_len, H, D]
401
+
402
+ def reset(self) -> None:
403
+ """Reset all per-block state for a new sequence."""
404
+ for i in range(self.num_blocks):
405
+ self.scale_patch_pages[i].clear()
406
+ self.live_window_patch_pages[i].clear()
407
+ self.all_special_pages[i].clear()
408
+ self.free_patch_pages[i] = list(range(self.max_patch_pages))
409
+ self.free_special_pages[i] = list(range(self.max_patch_pages, self.max_num_pages))
410
+ self.special_token_count[i] = 0
411
+ self.frame_count[i] = 0
412
+
413
+ # =========================================================================
414
+ # Helper methods
415
+ # =========================================================================
416
+
417
+ def build_visible_page_table(self, block_idx: int) -> List[int]:
418
+ """
419
+ Return page IDs in strict order: scale → window → special.
420
+
421
+ Placing special pages last means only the final page may be partially
422
+ full, so paged_kv_last_page_len = compute_last_page_len() is sufficient
423
+ without a custom attention mask.
424
+ """
425
+ return (
426
+ list(self.scale_patch_pages[block_idx]) +
427
+ list(self.live_window_patch_pages[block_idx]) +
428
+ list(self.all_special_pages[block_idx])
429
+ )
430
+
431
+ def compute_last_page_len(self, block_idx: int) -> int:
432
+ """
433
+ Valid token count in the last page of the visible sequence.
434
+
435
+ - No special pages → last page is a patch page.
436
+ Returns patches_per_frame (real tokens written),
437
+ which may be < page_size when page_size was rounded
438
+ up to a power of 2.
439
+ - Special tail partial → special_token_count % page_size.
440
+ - Special tail exactly full → page_size.
441
+ """
442
+ if not self.all_special_pages[block_idx]:
443
+ # Last page is a patch page. We wrote patches_per_frame tokens (0..P-1);
444
+ # positions P..page_size-1 are zero padding. Tell FlashInfer the true
445
+ # valid count so it doesn't read beyond the real tokens.
446
+ return self.patches_per_frame
447
+
448
+ tail = self.special_token_count[block_idx] % self.page_size
449
+ return self.page_size if tail == 0 else tail
450
+
451
+ # ── Internal write helpers ────────────────────────────────────────────────
452
+
453
+ def _write_patch_page(self, block_idx: int, patch_k: Tensor, patch_v: Tensor) -> int:
454
+ """
455
+ Allocate one free patch page and write patches_per_frame patch tokens.
456
+
457
+ Direct tensor assignment to kv_caches[block_idx][page_id, 0/1] avoids
458
+ the Python→C++/CUDA dispatch overhead of flashinfer.page.append_paged_kv_cache.
459
+ kv_caches layout: [max_num_pages, 2, page_size, H, D] (NHD, K=0, V=1).
460
+ patch_k/v fill exactly one full page (patches_per_frame == page_size).
461
+
462
+ Routes to scale_patch_pages if still filling scale quota,
463
+ otherwise to live_window_patch_pages.
464
+
465
+ Returns:
466
+ page_id: Physical page index used.
467
+ """
468
+ assert self.free_patch_pages[block_idx], (
469
+ f"block {block_idx}: patch page pool exhausted — "
470
+ f"scale={len(self.scale_patch_pages[block_idx])}, "
471
+ f"window={len(self.live_window_patch_pages[block_idx])}, "
472
+ f"free={len(self.free_patch_pages[block_idx])}"
473
+ )
474
+
475
+ page_id = self.free_patch_pages[block_idx].pop()
476
+
477
+ # Direct slice write: positions 0..patches_per_frame-1.
478
+ # When page_size == patches_per_frame (power-of-2 aligned, e.g. 256 for 224×224),
479
+ # this is equivalent to a full-page write. When page_size > patches_per_frame
480
+ # (rounded up for FA3 alignment, e.g. page_size=1024 for patches_per_frame=999),
481
+ # positions patches_per_frame..page_size-1 remain zero (kv_caches is zero-init).
482
+ P = self.patches_per_frame
483
+ self.kv_caches[block_idx][page_id, 0, :P] = patch_k # K
484
+ self.kv_caches[block_idx][page_id, 1, :P] = patch_v # V
485
+
486
+ if len(self.scale_patch_pages[block_idx]) < self.scale_frames:
487
+ self.scale_patch_pages[block_idx].append(page_id)
488
+ else:
489
+ self.live_window_patch_pages[block_idx].append(page_id)
490
+
491
+ return page_id
492
+
493
+ def _write_special_tokens(self, block_idx: int, sp_k: Tensor, sp_v: Tensor) -> None:
494
+ """
495
+ Append num_special_tokens (6) special tokens to the special stream.
496
+
497
+ Direct tensor slice assignment to kv_caches[block_idx][tail_page, 0/1,
498
+ tail_offset : tail_offset+write_n] avoids the Python→C++/CUDA dispatch
499
+ overhead of flashinfer.page.append_paged_kv_cache.
500
+
501
+ Handles page-boundary crossing: if 6 tokens straddle two pages, performs
502
+ two slice writes (rare — page_size=256 >> 6).
503
+ """
504
+ remaining = self.num_special_tokens # 6
505
+ written = 0
506
+
507
+ while remaining > 0:
508
+ tail_offset = self.special_token_count[block_idx] % self.page_size
509
+
510
+ if tail_offset == 0:
511
+ # Current tail page is full (or no page exists) — allocate a new one
512
+ assert self.free_special_pages[block_idx], (
513
+ f"block {block_idx}: special page pool exhausted at "
514
+ f"special_token_count={self.special_token_count[block_idx]}. "
515
+ f"Increase max_total_frames."
516
+ )
517
+ new_page = self.free_special_pages[block_idx].pop()
518
+ self.all_special_pages[block_idx].append(new_page)
519
+
520
+ tail_page = self.all_special_pages[block_idx][-1]
521
+ space = self.page_size - tail_offset # free slots in tail page
522
+ write_n = min(remaining, space)
523
+
524
+ # Direct slice write: kv_caches[block_idx][tail_page, 0/1, offset:offset+n]
525
+ # shape: [page_size, H, D]; slice [tail_offset:tail_offset+write_n, :, :]
526
+ end = tail_offset + write_n
527
+ self.kv_caches[block_idx][tail_page, 0, tail_offset:end] = sp_k[written:written + write_n]
528
+ self.kv_caches[block_idx][tail_page, 1, tail_offset:end] = sp_v[written:written + write_n]
529
+
530
+ self.special_token_count[block_idx] += write_n
531
+ written += write_n
532
+ remaining -= write_n
533
+
534
+ # ── Legacy property (used by stream.py) ──────────────────────────────────
535
+
536
+ @property
537
+ def num_frames(self) -> int:
538
+ """Number of frames appended to block 0 (representative)."""
539
+ return self.frame_count[0] if self.frame_count else 0
540
+
541
+
542
+ # =============================================================================
543
+ # Sanity check
544
+ # =============================================================================
545
+
546
+ def _sanity_check():
547
+ """
548
+ Minimal smoke test.
549
+ Run with: python -c "from lingbot_map.layers.flashinfer_cache import _sanity_check; _sanity_check()"
550
+ """
551
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
552
+ if not torch.cuda.is_available():
553
+ print("[sanity_check] CUDA not available — skipping.")
554
+ return
555
+
556
+ tokens_per_frame = 262 # 256 patch + 6 special (224×224)
557
+ num_special = 6
558
+ patches_per_frame = tokens_per_frame - num_special # 256
559
+ page_size = patches_per_frame # 256
560
+
561
+ mgr = FlashInferKVCacheManager(
562
+ num_blocks = 2,
563
+ max_num_frames = 88,
564
+ tokens_per_frame = tokens_per_frame,
565
+ num_heads = 16,
566
+ head_dim = 64,
567
+ dtype = torch.bfloat16,
568
+ device = device,
569
+ num_special_tokens = num_special,
570
+ scale_frames = 8,
571
+ sliding_window = 64,
572
+ max_total_frames = 200,
573
+ )
574
+
575
+ def make_kv():
576
+ k = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
577
+ v = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
578
+ return k, v
579
+
580
+ def make_q():
581
+ return torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
582
+
583
+ for block in range(2):
584
+ for t in range(100):
585
+ k, v = make_kv()
586
+ mgr.append_frame(block, k, v)
587
+ mgr.evict_frames(block, scale_frames=8, sliding_window=64)
588
+
589
+ # ── Page count checks ───────────────────────────────────────────────
590
+ n_scale = len(mgr.scale_patch_pages[block])
591
+ n_window = len(mgr.live_window_patch_pages[block])
592
+ n_spec = len(mgr.all_special_pages[block])
593
+ sp_count = mgr.special_token_count[block]
594
+
595
+ assert n_scale == 8, f"block {block}: scale pages = {n_scale}, expected 8"
596
+ assert n_window == 64, f"block {block}: window pages = {n_window}, expected 64"
597
+ # 100 frames × 6 specials = 600 tokens; ceil(600/256) = 3 pages
598
+ expected_spec_pages = math.ceil(100 * num_special / page_size)
599
+ assert n_spec == expected_spec_pages, (
600
+ f"block {block}: special pages = {n_spec}, expected {expected_spec_pages}"
601
+ )
602
+ assert sp_count == 100 * num_special, (
603
+ f"block {block}: special_token_count = {sp_count}, expected {100*num_special}"
604
+ )
605
+
606
+ # ── last_page_len ────────────────────────────────────────────────────
607
+ last_len = mgr.compute_last_page_len(block)
608
+ tail = sp_count % page_size
609
+ expected_len = page_size if tail == 0 else tail
610
+ assert last_len == expected_len, f"block {block}: last_len={last_len}, expected={expected_len}"
611
+
612
+ # ── visible page table order ─────────────────────────────────────────
613
+ visible = mgr.build_visible_page_table(block)
614
+ assert len(visible) == n_scale + n_window + n_spec, "visible page count mismatch"
615
+ for pid in visible[:n_scale + n_window]:
616
+ assert pid < mgr.max_patch_pages, f"patch page {pid} out of patch range"
617
+ for pid in visible[n_scale + n_window:]:
618
+ assert pid >= mgr.max_patch_pages, f"special page {pid} not in special range"
619
+
620
+ # ── forward pass: plan() once for block 0, run() for both blocks ─────
621
+ if block == 1:
622
+ # Simulate the actual calling pattern: plan on block 0, run on both
623
+ q0 = make_q()
624
+ out0 = mgr.compute_attention(0, q0) # triggers plan()
625
+ q1 = make_q()
626
+ out1 = mgr.compute_attention(1, q1) # reuses plan, different kv_cache
627
+ assert out0.shape == (tokens_per_frame, 16, 64)
628
+ assert out1.shape == (tokens_per_frame, 16, 64)
629
+
630
+ print(f"[block {block}] PASS: scale={n_scale}, window={n_window}, "
631
+ f"special_pages={n_spec}, special_tokens={sp_count}, "
632
+ f"last_page_len={last_len}")
633
+
634
+ mgr.reset()
635
+ assert mgr.frame_count[0] == 0
636
+ print("\n[sanity_check] All assertions passed.")
637
+
638
+
639
+ if __name__ == "__main__":
640
+ _sanity_check()
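The two-stream page accounting described in the module docstring can be checked with plain arithmetic, independent of CUDA or FlashInfer. The figures below reuse the 224×224 configuration from _sanity_check (256 patch + 6 special tokens per frame, 8 scale frames, a 64-frame sliding window, 100 frames processed).

# Pure-Python sketch of the page bookkeeping; no GPU or flashinfer import needed.
import math

tokens_per_frame  = 262                                # 256 patches + 6 specials
num_special       = 6
patches_per_frame = tokens_per_frame - num_special     # 256
page_size         = patches_per_frame                  # exact page size (fa2 backend)
scale_frames, sliding_window, frames = 8, 64, 100

# Patch stream: one page per frame; only scale + window pages stay resident.
resident_patch_pages = scale_frames + min(frames - scale_frames, sliding_window)  # 8 + 64 = 72

# Special stream: append-only, 6 tokens per frame packed back to back.
special_tokens = frames * num_special                   # 600
special_pages  = math.ceil(special_tokens / page_size)  # ceil(600 / 256) = 3
last_page_len  = special_tokens % page_size or page_size  # 600 % 256 = 88

print(resident_patch_pages, special_pages, last_page_len)  # 72 3 88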
lingbot_map/layers/layer_scale.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
17
+ super().__init__()
18
+ self.inplace = inplace
19
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
20
+
21
+ def forward(self, x: Tensor) -> Tensor:
22
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
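For completeness, a tiny sketch of LayerScale's behavior: it multiplies every feature by a learnable per-channel gamma initialized to init_values (the values below are illustrative).

import torch
from lingbot_map.layers.layer_scale import LayerScale

ls = LayerScale(dim=4, init_values=1e-5)
x = torch.ones(2, 3, 4)
print(ls(x))    # all entries equal 1e-5 at initialization; gamma is learned during training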
lingbot_map/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
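A minimal shape check for the Mlp above; out_features defaults to in_features, so the block maps tokens back to their input width (the dimensions below are illustrative).

import torch
from lingbot_map.layers.mlp import Mlp

mlp = Mlp(in_features=64, hidden_features=256)   # fc1: 64 → 256, GELU, fc2: 256 → 64
x = torch.randn(2, 10, 64)
print(mlp(x).shape)                              # torch.Size([2, 10, 64])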
lingbot_map/layers/patch_embed.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
51
+
52
+ self.img_size = image_HW
53
+ self.patch_size = patch_HW
54
+ self.patches_resolution = patch_grid_size
55
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
56
+
57
+ self.in_chans = in_chans
58
+ self.embed_dim = embed_dim
59
+
60
+ self.flatten_embedding = flatten_embedding
61
+
62
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
63
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
64
+
65
+ def forward(self, x: Tensor) -> Tensor:
66
+ _, _, H, W = x.shape
67
+ patch_H, patch_W = self.patch_size
68
+
69
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
70
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
71
+
72
+ x = self.proj(x) # B C H W
73
+ H, W = x.size(2), x.size(3)
74
+ x = x.flatten(2).transpose(1, 2) # B HW C
75
+ x = self.norm(x)
76
+ if not self.flatten_embedding:
77
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
78
+ return x
79
+
80
+ def flops(self) -> float:
81
+ Ho, Wo = self.patches_resolution
82
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
83
+ if self.norm is not None:
84
+ flops += Ho * Wo * self.embed_dim
85
+ return flops
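A small sketch of the (B,C,H,W) → (B,N,D) mapping documented above. The patch_size=14 and embed_dim=1024 values are assumptions inferred from the 256-patches-per-224×224-frame figure used elsewhere in this repo (224/14 = 16, 16×16 = 256), not confirmed model hyperparameters.

# Illustrative only: patch_size / embed_dim are assumed values.
import torch
from lingbot_map.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=1024)
x = torch.randn(1, 3, 224, 224)
print(embed(x).shape)        # torch.Size([1, 256, 1024])  — (B, N, D)
print(embed.num_patches)     # 256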
lingbot_map/layers/rope.py ADDED
@@ -0,0 +1,474 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # Implementation of 2D Rotary Position Embeddings (RoPE).
8
+
9
+ # This module provides a clean implementation of 2D Rotary Position Embeddings,
10
+ # which extends the original RoPE concept to handle 2D spatial positions.
11
+
12
+ # Inspired by:
13
+ # https://github.com/meta-llama/codellama/blob/main/llama/model.py
14
+ # https://github.com/naver-ai/rope-vit
15
+
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from typing import Dict, Tuple
22
+
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+
26
+ class PositionGetter:
27
+ """Generates and caches 2D spatial positions for patches in a grid.
28
+
29
+ This class efficiently manages the generation of spatial coordinates for patches
30
+ in a 2D grid, caching results to avoid redundant computations.
31
+
32
+ Attributes:
33
+ position_cache: Dictionary storing precomputed position tensors for different
34
+ grid dimensions.
35
+ """
36
+
37
+ def __init__(self):
38
+ """Initializes the position generator with an empty cache."""
39
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
40
+
41
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
42
+ """Generates spatial positions for a batch of patches.
43
+
44
+ Args:
45
+ batch_size: Number of samples in the batch.
46
+ height: Height of the grid in patches.
47
+ width: Width of the grid in patches.
48
+ device: Target device for the position tensor.
49
+
50
+ Returns:
51
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
52
+ for each position in the grid, repeated for each batch item.
53
+ """
54
+ if (height, width) not in self.position_cache:
55
+ y_coords = torch.arange(height, device=device)
56
+ x_coords = torch.arange(width, device=device)
57
+ positions = torch.cartesian_prod(y_coords, x_coords)
58
+ self.position_cache[height, width] = positions
59
+
60
+ cached_positions = self.position_cache[height, width]
61
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
62
+
63
+
64
+ class RotaryPositionEmbedding2D(nn.Module):
65
+ """2D Rotary Position Embedding implementation.
66
+
67
+ This module applies rotary position embeddings to input tokens based on their
68
+ 2D spatial positions. It handles the position-dependent rotation of features
69
+ separately for vertical and horizontal dimensions.
70
+
71
+ Args:
72
+ frequency: Base frequency for the position embeddings. Default: 100.0
73
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
74
+
75
+ Attributes:
76
+ base_frequency: Base frequency for computing position embeddings.
77
+ scaling_factor: Factor to scale the computed frequencies.
78
+ frequency_cache: Cache for storing precomputed frequency components.
79
+ """
80
+
81
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
82
+ """Initializes the 2D RoPE module."""
83
+ super().__init__()
84
+ self.base_frequency = frequency
85
+ self.scaling_factor = scaling_factor
86
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
87
+
88
+ def _compute_frequency_components(
89
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
90
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
91
+ """Computes frequency components for rotary embeddings.
92
+
93
+ Args:
94
+ dim: Feature dimension (must be even).
95
+ seq_len: Maximum sequence length.
96
+ device: Target device for computations.
97
+ dtype: Data type for the computed tensors.
98
+
99
+ Returns:
100
+ Tuple of (cosine, sine) tensors for frequency components.
101
+ """
102
+ cache_key = (dim, seq_len, device, dtype)
103
+ if cache_key not in self.frequency_cache:
104
+ # Compute frequency bands
105
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
106
+ inv_freq = 1.0 / (self.base_frequency**exponents)
107
+
108
+ # Generate position-dependent frequencies
109
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
110
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
111
+
112
+ # Compute and cache frequency components
113
+ angles = angles.to(dtype)
114
+ angles = torch.cat((angles, angles), dim=-1)
115
+ cos_components = angles.cos().to(dtype)
116
+ sin_components = angles.sin().to(dtype)
117
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
118
+
119
+ return self.frequency_cache[cache_key]
120
+
121
+ @staticmethod
122
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
123
+ """Performs feature rotation by splitting and recombining feature dimensions.
124
+
125
+ Args:
126
+ x: Input tensor to rotate.
127
+
128
+ Returns:
129
+ Rotated feature tensor.
130
+ """
131
+ feature_dim = x.shape[-1]
132
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
133
+ return torch.cat((-x2, x1), dim=-1)
134
+
135
+ def _apply_1d_rope(
136
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
137
+ ) -> torch.Tensor:
138
+ """Applies 1D rotary position embeddings along one dimension.
139
+
140
+ Args:
141
+ tokens: Input token features.
142
+ positions: Position indices.
143
+ cos_comp: Cosine components for rotation.
144
+ sin_comp: Sine components for rotation.
145
+
146
+ Returns:
147
+ Tokens with applied rotary position embeddings.
148
+ """
149
+ # Embed positions with frequency components
150
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
151
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
152
+
153
+ # Apply rotation
154
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
155
+
156
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
157
+ """Applies 2D rotary position embeddings to input tokens.
158
+
159
+ Args:
160
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
161
+ The feature dimension (dim) must be divisible by 4.
162
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
163
+ the y and x coordinates for each token.
164
+
165
+ Returns:
166
+ Tensor of same shape as input with applied 2D rotary position embeddings.
167
+
168
+ Raises:
169
+ AssertionError: If input dimensions are invalid or positions are malformed.
170
+ """
171
+ # Validate inputs
172
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
173
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
174
+
175
+ # Compute feature dimension for each spatial direction
176
+ feature_dim = tokens.size(-1) // 2
177
+
178
+ # Get frequency components
179
+ max_position = int(positions.max()) + 1
180
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
181
+
182
+ # Split features for vertical and horizontal processing
183
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
184
+
185
+ # Apply RoPE separately for each dimension
186
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
187
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
188
+
189
+ # Combine processed features
190
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
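To make the forward() contract above concrete, here is a minimal usage sketch (batch size, head count, grid size, and head_dim are made-up example values, not the model's configuration):

```python
import torch

# Minimal sketch of applying the 2D RoPE above to query features inside an attention layer.
rope = RotaryPositionEmbedding2D(frequency=100.0)
get_positions = PositionGetter()

B, n_heads, H, W, head_dim = 2, 4, 16, 16, 64        # head_dim must be divisible by 4
q = torch.randn(B, n_heads, H * W, head_dim)
positions = get_positions(B, H, W, device=q.device)  # (B, H*W, 2) y,x patch coordinates
q_rot = rope(q, positions)                           # same shape as q, positions encoded via rotation
assert q_rot.shape == q.shape
```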
191
+
192
+
193
+
194
+ def get_1d_rotary_pos_embed(
195
+ dim: int,
196
+ pos: Union[np.ndarray, int],
197
+ theta: float = 10000.0,
198
+ use_real=False,
199
+ linear_factor=1.0,
200
+ ntk_factor=1.0,
201
+ repeat_interleave_real=True,
202
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
203
+ ):
204
+ """
205
+ Computes the frequency tensor for 1D rotary position embeddings (RoPE).
206
+
207
+ Core idea of RoPE: encode positions with rotation matrices so that relative position relationships are preserved.
208
+ Formula: for position m and dimension i, the frequency is θ_i = θ^(-2i/d), where θ is the base frequency (default 10000).
209
+
210
+ Args:
211
+ dim: Feature dimension, must be even (dimensions are processed in pairs).
212
+ pos: Position indices, either an int (generates the sequence 0..pos-1) or an array of positions [S].
213
+ theta: Base frequency controlling the periodicity of the encoding (default 10000).
214
+ use_real: Whether to return the real form (separate cos and sin) instead of the complex form.
215
+ linear_factor: Linear scaling factor used for context extension.
216
+ ntk_factor: NTK-aware scaling factor for handling longer sequences.
217
+ repeat_interleave_real: When use_real=True, whether to interleave-repeat (used by some model architectures).
218
+ freqs_dtype: Data type of the frequency tensor.
219
+
220
+ Returns:
221
+ Complex form: a complex tensor of shape [S, D/2] representing e^(i*m*θ_j).
222
+ Real form: two tensors of shape [S, D] (cos and sin).
223
+ """
224
+ # Ensure the dimension is even (RoPE processes dimensions in pairs)
225
+ assert dim % 2 == 0
226
+
227
+ # Convert the positions to a torch tensor
228
+ if isinstance(pos, int):
229
+ pos = torch.arange(pos) # generate [0, 1, 2, ..., pos-1]
230
+ if isinstance(pos, np.ndarray):
231
+ pos = torch.from_numpy(pos) # [S]
232
+
233
+ # Apply NTK scaling (Neural Tangent Kernel, for sequence lengths unseen during training)
234
+ theta = theta * ntk_factor
235
+
236
+ # Step 1: compute the frequencies θ_i = 1 / (θ^(2i/d))
237
+ # where i ∈ {0, 2, 4, ..., dim-2} (only even indices, since dimensions are paired)
238
+ # Formula: freq_i = 1 / (theta^(2i/d) * linear_factor)
239
+ freqs = (
240
+ 1.0
241
+ / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
242
+ / linear_factor
243
+ ) # [D/2], one frequency per dimension pair
244
+
245
+ # Step 2: compute the position-frequency matrix
246
+ # Use the outer product: pos[m] * freqs[i] = m * θ_i
247
+ # Result: every combination of position m and frequency i
248
+ freqs = torch.outer(pos, freqs) # [S, D/2]
249
+
250
+ # Step 3: convert to the requested return format
251
+ if use_real and repeat_interleave_real:
252
+ # Variant 1: interleaved repeat (used by flux, hunyuan-dit, cogvideox, etc.)
253
+ # Interleave the cos and sin of each frequency: [cos_0, cos_0, cos_1, cos_1, ...]
254
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
255
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
256
+ return freqs_cos, freqs_sin
257
+ elif use_real:
258
+ # Variant 2: concatenated repeat (used by stable audio, allegro, etc.)
259
+ # Each output repeats its own values once, e.g. freqs_cos = [cos_0, cos_1, ..., cos_n, cos_0, cos_1, ..., cos_n]
260
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
261
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
262
+ return freqs_cos, freqs_sin
263
+ else:
264
+ # Variant 3: complex form (used by lumina, etc.)
265
+ # Euler's formula: e^(iθ) = cos(θ) + i*sin(θ)
266
+ # torch.polar(r, θ) returns r * e^(iθ); here r = 1, so this is e^(i*freqs)
267
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64: [S, D/2]
268
+ return freqs_cis
269
+
270
+
271
+ class WanRotaryPosEmbed(nn.Module):
272
+ """
273
+ 3D rotary position embedding (3D RoPE) module.
274
+
275
+ Core idea: extend RoPE to 3D (time, height, width) to provide position encodings for video or 3D data.
276
+ Each axis (t, h, w) uses RoPE independently, and the results are concatenated.
277
+
278
+ Formula:
279
+ For a 3D position (f, h, w) (frame, height, width):
280
+ - the frame axis uses dim_f feature dimensions
281
+ - the height axis uses dim_h feature dimensions
282
+ - the width axis uses dim_w feature dimensions
283
+ where dim_f + dim_h + dim_w = attention_head_dim
284
+ """
285
+ def __init__(
286
+ self,
287
+ attention_head_dim: int,
288
+ patch_size: Tuple[int, int, int],
289
+ max_seq_len: int = 1024,
290
+ theta: float = 10000.0,
291
+ fhw_dim: Optional[Tuple[int, int, int]] = [20, 22, 22],
292
+ ):
293
+ super().__init__()
294
+
295
+ self.attention_head_dim = attention_head_dim # total dimension of each attention head
296
+ self.patch_size = patch_size # patch size (patch_f, patch_h, patch_w)
297
+ self.max_seq_len = max_seq_len # maximum sequence length (used to precompute frequencies)
298
+
299
+ # Step 1: allocate feature dimensions to the three spatial axes
300
+ if fhw_dim is not None:
301
+ # If an explicit allocation is given, use it
302
+ assert attention_head_dim == sum(
303
+ fhw_dim
304
+ ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
305
+ t_dim, h_dim, w_dim = fhw_dim
306
+ else:
307
+ # Otherwise allocate automatically: h and w each take roughly one third, t takes the rest
308
+ # e.g., attention_head_dim=64 gives h_dim=w_dim=20 and t_dim=24
309
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
310
+ t_dim = attention_head_dim - h_dim - w_dim
311
+
312
+ # Save the dimension allocation for use in forward
313
+ self.fhw_dim = (t_dim, h_dim, w_dim)
314
+
315
+ # Step 2: precompute the frequencies for each axis
316
+ # Compute the RoPE frequencies for the temporal, height, and width axes separately
317
+ freqs = []
318
+ for dim in [t_dim, h_dim, w_dim]:
319
+ # Each axis calls 1D RoPE independently
320
+ # Returns complex-valued frequencies: [max_seq_len, dim//2]
321
+ freq = get_1d_rotary_pos_embed(
322
+ dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
323
+ )
324
+ freqs.append(freq)
325
+ # Concatenate the frequencies of the three axes along the last dim: [max_seq_len, (t_dim + h_dim + w_dim)//2]
326
+ self.freqs = torch.cat(freqs, dim=1)
327
+
328
+ def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
329
+ """
330
+ Forward pass: generates rotary position encodings for 3D input (video frames + patches).
331
+
332
+ Args:
333
+ - ppf (int): number of frames; used when f_end is None
334
+ - pph (int): number of patches per frame along the height
335
+ - ppw (int): number of patches per frame along the width
336
+ - patch_start_idx (int): number of special tokens per frame (placed before the patches)
337
+ - device: compute device (CPU/GPU)
338
+ - f_start (int): start frame index (for causal mode), defaults to 0
339
+ - f_end (Optional[int]): end frame index (for causal mode); if None, ppf is used as the frame count
340
+
341
+ Returns:
342
+ - freqs: complex frequency tensor of shape [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim//2]
343
+
344
+ Token ordering:
345
+ [frame0_special_token_0, ..., frame0_special_token_N,
346
+ frame0_patch_0, ..., frame0_patch_M,
347
+ frame1_special_token_0, ..., frame1_special_token_N,
348
+ frame1_patch_0, ..., frame1_patch_M,
349
+ ...]
350
+
351
+ Modes:
352
+ - Non-causal mode: f_end=None; ppf is used as the frame count, starting from position 0
353
+ - Causal mode: f_end is not None; frames in [f_start, f_end) are used and ppf is recomputed
354
+ """
355
+
356
+ # Step 1: move the precomputed frequencies to the target device and split them into the three axes
357
+ self.freqs = self.freqs.to(device)
358
+ # Get the actual dimension allocation
359
+ if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
360
+ t_dim, h_dim, w_dim = self.fhw_dim
361
+ else:
362
+ # Automatic allocation case
363
+ h_dim = w_dim = 2 * (self.attention_head_dim // 6)
364
+ t_dim = self.attention_head_dim - h_dim - w_dim
365
+
366
+ # Use the correct split sizes (half of each axis's dimension)
367
+ freqs = self.freqs.split_with_sizes(
368
+ [
369
+ t_dim // 2, # temporal axis
370
+ h_dim // 2, # height axis
371
+ w_dim // 2, # width axis
372
+ ],
373
+ dim=1,
374
+ )
375
+
376
+ # Handle causal mode: if f_end is given, recompute ppf and the frame range
377
+ if f_end is not None:
378
+ ppf = f_end - f_start
379
+ frame_slice = slice(f_start, f_end)
380
+ else:
381
+ # Non-causal mode: use ppf frames starting from frame 0
382
+ frame_slice = slice(0, ppf)
383
+
384
+ # Step 2: handle special tokens (if any)
385
+ ## For other tokens
386
+ if patch_start_idx > 0:
387
+ # 2.1 Generate position encodings for the special tokens
388
+ # Special tokens sit at the diagonal positions (f, i, i), so each one has a unique position
389
+ # camera: (f, 0, 0), register_0: (f, 1, 1), ..., scale: (f, 5, 5)
390
+ # Shape: (ppf, patch_start_idx, dim)
391
+ freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_f) varies along the frame axis
392
+ freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_h) height = 0, 1, 2, ...
393
+ freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_w) width = 0, 1, 2, ...
394
+ freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1) # (ppf, patch_start_idx, dim) concatenate the three axes
395
+ freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim)
396
+
397
+ # 2.2 Generate position encodings for the image patches
398
+ # Patches sit at (f, patch_start_idx+h, patch_start_idx+w); h and w are offset by patch_start_idx
399
+ # so patch positions do not collide with the special tokens, and h and w are treated symmetrically
400
+ # Shape: (ppf, pph, ppw, dim)
401
+ freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_f) frame axis
402
+ freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_h) height starts at patch_start_idx
403
+ freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_w) width starts at patch_start_idx
404
+ freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1) # (ppf, pph, ppw, dim) concatenate the three axes
405
+ freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1) # (ppf, pph * ppw, dim) flatten the spatial axes
406
+
407
+ # Step 3: combine special tokens and patches in the correct order
408
+ # Per-frame order: [special tokens, patches]
409
+ # Concatenate special tokens and patches for each frame along the second dimension
410
+ # Shape: (ppf, patch_start_idx + pph * ppw, dim)
411
+ freqs = torch.cat([freqs_special, freqs_patches], dim=1) # (ppf, patch_start_idx + pph * ppw, dim)
412
+
413
+ # Step 4: flatten to the final shape and add batch and head dimensions
414
+ # Flatten to get final shape: (ppf * (patch_start_idx + pph * ppw), dim)
415
+ freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
416
+ freqs = freqs.unsqueeze(0).unsqueeze(0) # (1, 1, ppf * (patch_start_idx + pph * ppw), dim) add batch and head dims
417
+ return freqs
418
+
419
+ # If there are no special tokens (patch_start_idx == 0), only handle the image patches
420
+ # All patches sit at (f, 0:pph, 0:ppw)
421
+ freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_f) frame axis
422
+ freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_h) height starts at 0
423
+ freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_w) width starts at 0
424
+ freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) # (1, 1, ppf * pph * ppw, dim)
425
+ return freqs
426
+
427
+ def apply_rotary_emb(x, freqs):
428
+ """
429
+ Applies rotary position embeddings to the input features.
430
+
431
+ Core idea: use complex multiplication to rotate feature pairs while preserving relative position information.
432
+
433
+ Math:
434
+ For a 2D vector [x1, x2], rotating by an angle θ can be expressed as a complex multiplication:
435
+ (x1 + ix2) * e^(iθ) = (x1 + ix2) * (cos(θ) + i*sin(θ))
436
+ = (x1*cos(θ) - x2*sin(θ)) + i*(x1*sin(θ) + x2*cos(θ))
437
+
438
+ which is equivalent to the rotation matrix:
439
+ [cos(θ) -sin(θ)] [x1]
440
+ [sin(θ) cos(θ)] [x2]
441
+
442
+ Args:
443
+ - x: input features [batch, heads, seq_len, head_dim]
444
+ - freqs: rotation frequencies (complex) [1, 1, seq_len, head_dim//2]
445
+
446
+ Returns:
447
+ - x_out: rotated features [batch, heads, seq_len, head_dim]
448
+
449
+ Steps:
450
+ 1. Treat every two consecutive features of x as one complex number (real, imag)
451
+ 2. Multiply by the precomputed complex frequencies e^(iθ)
452
+ 3. Convert back to the real representation
453
+ """
454
+ # Step 1: reshape to [..., head_dim//2, 2]; the last axis holds (real, imag)
456
+ # e.g., [b, h, seq, 64] -> [b, h, seq, 32, 2]
456
+ x_reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
457
+
458
+ # Step 2: view as complex numbers [b, h, seq, 32]
460
+ # each element is real + imag*i
460
+ x_complex = torch.view_as_complex(x_reshaped)
461
+
462
+ # Step 3: complex multiplication performs the rotation
464
+ # x_complex * freqs rotates every feature pair by its angle θ
465
+ # freqs is already in the form e^(iθ) = cos(θ) + i*sin(θ)
465
+ x_rotated = x_complex * freqs
466
+
467
+ # Step 4: convert back to the real representation [b, h, seq, 32, 2]
468
+ x_real = torch.view_as_real(x_rotated)
469
+
470
+ # Step 5: flatten the last two dims [b, h, seq, 64]
471
+ x_out = x_real.flatten(3)
472
+
473
+ # Step 6: cast back to the original dtype
474
+ return x_out.to(x.dtype)
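A minimal sketch tying get_1d_rotary_pos_embed and apply_rotary_emb together (sizes are illustrative; in the model these calls are driven by WanRotaryPosEmbed):

```python
import torch

# Sketch: build complex 1D RoPE frequencies and rotate a query tensor with them.
head_dim, seq_len = 64, 128
freqs = get_1d_rotary_pos_embed(head_dim, seq_len, use_real=False)  # complex [seq_len, head_dim//2]
freqs = freqs.unsqueeze(0).unsqueeze(0)                             # [1, 1, seq_len, head_dim//2]

q = torch.randn(2, 8, seq_len, head_dim)                            # [batch, heads, seq, head_dim]
q_rot = apply_rotary_emb(q, freqs)
assert q_rot.shape == q.shape and q_rot.dtype == q.dtype
```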
lingbot_map/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ # try:
39
+ # if XFORMERS_ENABLED:
40
+ # from xformers.ops import SwiGLU
41
+
42
+ # XFORMERS_AVAILABLE = True
43
+ # warnings.warn("xFormers is available (SwiGLU)")
44
+ # else:
45
+ # warnings.warn("xFormers is disabled (SwiGLU)")
46
+ # raise ImportError
47
+ # except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
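A short sketch of the SwiGLU feed-forward above. SwiGLUFFNFused shrinks the requested hidden width to 2/3 and rounds it up to a multiple of 8; the dimensions below are illustrative, not the model's configuration:

```python
import torch

# Sketch of the hidden-width rounding and the SwiGLU computation w3(silu(x @ W1) * (x @ W2)).
embed_dim, mlp_ratio = 1024, 4
requested_hidden = embed_dim * mlp_ratio                      # 4096
fused_hidden = (int(requested_hidden * 2 / 3) + 7) // 8 * 8   # 2736

ffn = SwiGLUFFNFused(in_features=embed_dim, hidden_features=requested_hidden)
x = torch.randn(2, 197, embed_dim)
out = ffn(x)
assert out.shape == x.shape
```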
lingbot_map/layers/vision_transformer.py ADDED
@@ -0,0 +1,411 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.checkpoint import checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention#, NestedTensorBlock as Block
20
+
21
+ # TODO: Check this
22
+ # We replace NestedTensorBlock with Block
23
+ from .block import Block
24
+
25
+ logger = logging.getLogger("dinov2")
26
+
27
+
28
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
29
+ if not depth_first and include_root:
30
+ fn(module=module, name=name)
31
+ for child_name, child_module in module.named_children():
32
+ child_name = ".".join((name, child_name)) if name else child_name
33
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
34
+ if depth_first and include_root:
35
+ fn(module=module, name=name)
36
+ return module
37
+
38
+
39
+ class BlockChunk(nn.ModuleList):
40
+ def forward(self, x):
41
+ for b in self:
42
+ x = b(x)
43
+ return x
44
+
45
+
46
+ class DinoVisionTransformer(nn.Module):
47
+ def __init__(
48
+ self,
49
+ img_size=224,
50
+ patch_size=16,
51
+ in_chans=3,
52
+ embed_dim=768,
53
+ depth=12,
54
+ num_heads=12,
55
+ mlp_ratio=4.0,
56
+ qkv_bias=True,
57
+ ffn_bias=True,
58
+ proj_bias=True,
59
+ drop_path_rate=0.0,
60
+ drop_path_uniform=False,
61
+ init_values=None, # for layerscale: None or 0 => no layerscale
62
+ embed_layer=PatchEmbed,
63
+ act_layer=nn.GELU,
64
+ block_fn=Block,
65
+ ffn_layer="mlp",
66
+ block_chunks=1,
67
+ num_register_tokens=0,
68
+ interpolate_antialias=False,
69
+ interpolate_offset=0.1,
70
+ drop_cls_token=False,
71
+ qk_norm=False,
72
+ ):
73
+ """
74
+ Args:
75
+ img_size (int, tuple): input image size
76
+ patch_size (int, tuple): patch size
77
+ in_chans (int): number of input channels
78
+ embed_dim (int): embedding dimension
79
+ depth (int): depth of transformer
80
+ num_heads (int): number of attention heads
81
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
82
+ qkv_bias (bool): enable bias for qkv if True
83
+ proj_bias (bool): enable bias for proj in attn if True
84
+ ffn_bias (bool): enable bias for ffn if True
85
+ drop_path_rate (float): stochastic depth rate
86
+ drop_path_uniform (bool): apply uniform drop rate across blocks
87
+ weight_init (str): weight init scheme
88
+ init_values (float): layer-scale init values
89
+ embed_layer (nn.Module): patch embedding layer
90
+ act_layer (nn.Module): MLP activation layer
91
+ block_fn (nn.Module): transformer block class
92
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
93
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
94
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
95
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
96
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
97
+ """
98
+ super().__init__()
99
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
100
+
101
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
102
+ self.num_tokens = 1 if not drop_cls_token else 0
103
+ self.n_blocks = depth
104
+ self.num_heads = num_heads
105
+ self.patch_size = patch_size
106
+ self.num_register_tokens = num_register_tokens
107
+ self.interpolate_antialias = interpolate_antialias
108
+ self.interpolate_offset = interpolate_offset
109
+ self.use_reentrant = False # hardcoded to False
110
+
111
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
112
+ num_patches = self.patch_embed.num_patches
113
+
114
+ self.drop_cls_token = drop_cls_token
115
+
116
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if not drop_cls_token else None
117
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
118
+ assert num_register_tokens >= 0
119
+ self.register_tokens = (
120
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
121
+ )
122
+
123
+ if drop_path_uniform is True:
124
+ dpr = [drop_path_rate] * depth
125
+ else:
126
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
127
+
128
+ if ffn_layer == "mlp":
129
+ logger.info("using MLP layer as FFN")
130
+ ffn_layer = Mlp
131
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
132
+ logger.info("using SwiGLU layer as FFN")
133
+ ffn_layer = SwiGLUFFNFused
134
+ elif ffn_layer == "identity":
135
+ logger.info("using Identity layer as FFN")
136
+
137
+ def f(*args, **kwargs):
138
+ return nn.Identity()
139
+
140
+ ffn_layer = f
141
+ else:
142
+ raise NotImplementedError
143
+
144
+ blocks_list = [
145
+ block_fn(
146
+ dim=embed_dim,
147
+ num_heads=num_heads,
148
+ mlp_ratio=mlp_ratio,
149
+ qkv_bias=qkv_bias,
150
+ proj_bias=proj_bias,
151
+ ffn_bias=ffn_bias,
152
+ drop_path=dpr[i],
153
+ norm_layer=norm_layer,
154
+ act_layer=act_layer,
155
+ ffn_layer=ffn_layer,
156
+ init_values=init_values,
157
+ qk_norm=qk_norm,
158
+ )
159
+ for i in range(depth)
160
+ ]
161
+ if block_chunks > 0:
162
+ self.chunked_blocks = True
163
+ chunked_blocks = []
164
+ chunksize = depth // block_chunks
165
+ for i in range(0, depth, chunksize):
166
+ # this is to keep the block index consistent if we chunk the block list
167
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
168
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
169
+ else:
170
+ self.chunked_blocks = False
171
+ self.blocks = nn.ModuleList(blocks_list)
172
+
173
+ self.norm = norm_layer(embed_dim)
174
+ self.head = nn.Identity()
175
+
176
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
177
+
178
+ self.init_weights()
179
+
180
+ def init_weights(self):
181
+ trunc_normal_(self.pos_embed, std=0.02)
182
+ nn.init.normal_(self.cls_token, std=1e-6) if self.cls_token is not None else None
183
+ if self.register_tokens is not None:
184
+ nn.init.normal_(self.register_tokens, std=1e-6)
185
+ named_apply(init_weights_vit_timm, self)
186
+
187
+ def interpolate_pos_encoding(self, x, w, h):
188
+ previous_dtype = x.dtype
189
+ npatch = x.shape[1] - 1
190
+ N = self.pos_embed.shape[1] - 1 if not self.drop_cls_token else self.pos_embed.shape[1]
191
+ if npatch == N and w == h:
192
+ return self.pos_embed
193
+ pos_embed = self.pos_embed.float()
194
+ if not self.drop_cls_token:
195
+ class_pos_embed = pos_embed[:, 0]
196
+ patch_pos_embed = pos_embed[:, 1:]
197
+ else:
198
+ patch_pos_embed = pos_embed
199
+ dim = x.shape[-1]
200
+ w0 = w // self.patch_size
201
+ h0 = h // self.patch_size
202
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
203
+ assert N == M * M
204
+ kwargs = {}
205
+ if self.interpolate_offset:
206
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
207
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
208
+ sx = float(w0 + self.interpolate_offset) / M
209
+ sy = float(h0 + self.interpolate_offset) / M
210
+ kwargs["scale_factor"] = (sx, sy)
211
+ else:
212
+ # Simply specify an output size instead of a scale factor
213
+ kwargs["size"] = (w0, h0)
214
+ patch_pos_embed = nn.functional.interpolate(
215
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
216
+ mode="bicubic",
217
+ antialias=self.interpolate_antialias,
218
+ **kwargs,
219
+ )
220
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
221
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
222
+ if not self.drop_cls_token:
223
+ x = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
224
+ else:
225
+ x = patch_pos_embed
226
+ return x.to(previous_dtype)
227
+
228
+ def prepare_tokens_with_masks(self, x, masks=None):
229
+ B, nc, w, h = x.shape
230
+ x = self.patch_embed(x)
231
+ if masks is not None:
232
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
233
+
234
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.cls_token is not None else x
235
+ x = x + self.interpolate_pos_encoding(x, w, h)
236
+
237
+ if self.register_tokens is not None:
238
+ x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
239
+
240
+ return x
241
+
242
+ def forward_features_list(self, x_list, masks_list):
243
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
244
+
245
+ for blk in self.blocks:
246
+ if self.training:
247
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
248
+ else:
249
+ x = blk(x)
250
+
251
+ all_x = x
252
+ output = []
253
+ for x, masks in zip(all_x, masks_list):
254
+ x_norm = self.norm(x)
255
+ output.append(
256
+ {
257
+ "x_norm_clstoken": x_norm[:, 0],
258
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
259
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
260
+ "x_prenorm": x,
261
+ "masks": masks,
262
+ }
263
+ )
264
+ return output
265
+
266
+ def forward_features(self, x, masks=None):
267
+ if isinstance(x, list):
268
+ return self.forward_features_list(x, masks)
269
+
270
+ x = self.prepare_tokens_with_masks(x, masks)
271
+
272
+ for blk in self.blocks:
273
+ if self.training:
274
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
275
+ else:
276
+ x = blk(x)
277
+
278
+ x_norm = self.norm(x)
279
+ return {
280
+ "x_norm_clstoken": x_norm[:, 0],
281
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
282
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
283
+ "x_prenorm": x,
284
+ "masks": masks,
285
+ }
286
+
287
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
288
+ x = self.prepare_tokens_with_masks(x)
289
+ # If n is an int, take the n last blocks. If it's a list, take them
290
+ output, total_block_len = [], len(self.blocks)
291
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
292
+ for i, blk in enumerate(self.blocks):
293
+ x = blk(x)
294
+ if i in blocks_to_take:
295
+ output.append(x)
296
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
297
+ return output
298
+
299
+ def _get_intermediate_layers_chunked(self, x, n=1):
300
+ x = self.prepare_tokens_with_masks(x)
301
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
302
+ # If n is an int, take the n last blocks. If it's a list, take them
303
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
304
+ for block_chunk in self.blocks:
305
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
306
+ x = blk(x)
307
+ if i in blocks_to_take:
308
+ output.append(x)
309
+ i += 1
310
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
311
+ return output
312
+
313
+ def get_intermediate_layers(
314
+ self,
315
+ x: torch.Tensor,
316
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
317
+ reshape: bool = False,
318
+ return_class_token: bool = False,
319
+ norm=True,
320
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
321
+ if self.chunked_blocks:
322
+ outputs = self._get_intermediate_layers_chunked(x, n)
323
+ else:
324
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
325
+ if norm:
326
+ outputs = [self.norm(out) for out in outputs]
327
+ class_tokens = [out[:, 0] for out in outputs]
328
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
329
+ if reshape:
330
+ B, _, w, h = x.shape
331
+ outputs = [
332
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
333
+ for out in outputs
334
+ ]
335
+ if return_class_token:
336
+ return tuple(zip(outputs, class_tokens))
337
+ return tuple(outputs)
338
+
339
+ def forward(self, *args, is_training=True, **kwargs):
340
+ ret = self.forward_features(*args, **kwargs)
341
+ if is_training:
342
+ return ret
343
+ else:
344
+ return self.head(ret["x_norm_clstoken"])
345
+
346
+
347
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
348
+ """ViT weight initialization, original timm impl (for reproducibility)"""
349
+ if isinstance(module, nn.Linear):
350
+ trunc_normal_(module.weight, std=0.02)
351
+ if module.bias is not None:
352
+ nn.init.zeros_(module.bias)
353
+
354
+
355
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
356
+ model = DinoVisionTransformer(
357
+ patch_size=patch_size,
358
+ embed_dim=384,
359
+ depth=12,
360
+ num_heads=6,
361
+ mlp_ratio=4,
362
+ block_fn=partial(Block, attn_class=MemEffAttention),
363
+ num_register_tokens=num_register_tokens,
364
+ **kwargs,
365
+ )
366
+ return model
367
+
368
+
369
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
370
+ model = DinoVisionTransformer(
371
+ patch_size=patch_size,
372
+ embed_dim=768,
373
+ depth=12,
374
+ num_heads=12,
375
+ mlp_ratio=4,
376
+ block_fn=partial(Block, attn_class=MemEffAttention),
377
+ num_register_tokens=num_register_tokens,
378
+ **kwargs,
379
+ )
380
+ return model
381
+
382
+
383
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
384
+ model = DinoVisionTransformer(
385
+ patch_size=patch_size,
386
+ embed_dim=1024,
387
+ depth=24,
388
+ num_heads=16,
389
+ mlp_ratio=4,
390
+ block_fn=partial(Block, attn_class=MemEffAttention),
391
+ num_register_tokens=num_register_tokens,
392
+ **kwargs,
393
+ )
394
+ return model
395
+
396
+
397
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
398
+ """
399
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
400
+ """
401
+ model = DinoVisionTransformer(
402
+ patch_size=patch_size,
403
+ embed_dim=1536,
404
+ depth=40,
405
+ num_heads=24,
406
+ mlp_ratio=4,
407
+ block_fn=partial(Block, attn_class=MemEffAttention),
408
+ num_register_tokens=num_register_tokens,
409
+ **kwargs,
410
+ )
411
+ return model
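A hedged usage sketch of the backbone factories above; the argument values (patch size, register tokens, input resolution, selected layers) are illustrative and may differ from what the aggregator actually uses:

```python
import torch

# Instantiate the DINOv2-style backbone and pull intermediate features.
vit = vit_large(patch_size=14, num_register_tokens=4, img_size=518, block_chunks=0)
vit.eval()

x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    feats = vit.get_intermediate_layers(x, n=[4, 11, 17, 23], reshape=True)
print([f.shape for f in feats])   # four maps of shape [1, 1024, 37, 37]
```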
lingbot_map/models/__init__.py ADDED
File without changes
lingbot_map/models/gct_base.py ADDED
@@ -0,0 +1,359 @@
1
+ """
2
+ GCTBase - Base class for GCT model implementations.
3
+
4
+ Provides shared functionality:
5
+ - Prediction heads (camera, depth, point)
6
+ - Forward pass structure
7
+ - Model hub mixin (PyTorchModelHubMixin)
8
+ """
9
+
10
+ import logging
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from abc import ABC, abstractmethod
15
+ from typing import Optional, Dict, Any, List, Union
16
+ from huggingface_hub import PyTorchModelHubMixin
17
+
18
+ from lingbot_map.heads.dpt_head import DPTHead
19
+ from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
20
+ from lingbot_map.utils.geometry import closed_form_inverse_se3
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class GCTBase(nn.Module, PyTorchModelHubMixin, ABC):
26
+ """
27
+ Base class for GCT model implementations.
28
+
29
+ Handles shared components:
30
+ - Prediction heads (camera, depth, point)
31
+ - Forward pass structure
32
+ - Input normalization
33
+
34
+ Subclasses must implement:
35
+ - _build_aggregator(): Create mode-specific aggregator
36
+ - _build_camera_head(): Create mode-specific camera head
+ - _aggregate_features(): Run the mode-specific aggregator
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ # Architecture parameters
42
+ img_size: int = 518,
43
+ patch_size: int = 14,
44
+ embed_dim: int = 1024,
45
+ patch_embed: str = 'dinov2_vitl14_reg',
46
+ disable_global_rope: bool = False,
47
+ # Head configuration
48
+ enable_camera: bool = True,
49
+ enable_point: bool = True,
50
+ enable_local_point: bool = False,
51
+ enable_depth: bool = True,
52
+ enable_track: bool = False,
53
+ # Camera head sliding window
54
+ enable_camera_sliding_window: bool = False,
55
+ # 3D RoPE
56
+ enable_3d_rope: bool = False,
57
+ # Context Parallelism (kept for checkpoint compatibility but not used)
58
+ enable_ulysses_cp: bool = False,
59
+ # Normalization
60
+ enable_normalize: bool = False,
61
+ # Prediction normalization
62
+ pred_normalization: bool = False,
63
+ pred_normalization_detach_scale: bool = False,
64
+ # Gradient checkpointing
65
+ use_gradient_checkpoint: bool = True,
66
+ ):
67
+ super().__init__()
68
+
69
+ # Store configuration
70
+ self.img_size = img_size
71
+ self.patch_size = patch_size
72
+ self.embed_dim = embed_dim
73
+ self.patch_embed = patch_embed
74
+ self.disable_global_rope = disable_global_rope
75
+
76
+ self.enable_ulysses_cp = False # CP disabled in standalone package
77
+ self.enable_normalize = enable_normalize
78
+ self.pred_normalization = pred_normalization
79
+ self.pred_normalization_detach_scale = pred_normalization_detach_scale
80
+ self.use_gradient_checkpoint = use_gradient_checkpoint
81
+
82
+ # Head flags
83
+ self.enable_camera = enable_camera
84
+ self.enable_point = enable_point
85
+ self.enable_local_point = enable_local_point
86
+ self.enable_depth = enable_depth
87
+ self.enable_track = enable_track
88
+ self.enable_camera_sliding_window = enable_camera_sliding_window
89
+ self.enable_3d_rope = enable_3d_rope
90
+
91
+ # Build aggregator (subclass-specific)
92
+ self.aggregator = self._build_aggregator()
93
+
94
+ # Build prediction heads (subclass-specific)
95
+ self.camera_head = self._build_camera_head() if enable_camera else None
96
+ self.point_head = self._build_point_head() if enable_point else None
97
+ self.local_point_head = self._build_local_point_head() if enable_local_point else None
98
+ self.depth_head = self._build_depth_head() if enable_depth else None
99
+
100
+ @abstractmethod
101
+ def _build_aggregator(self) -> nn.Module:
102
+ pass
103
+
104
+ @abstractmethod
105
+ def _build_camera_head(self) -> nn.Module:
106
+ pass
107
+
108
+ def _build_depth_head(self) -> nn.Module:
109
+ return DPTHead(
110
+ dim_in=2 * self.embed_dim,
111
+ patch_size=self.patch_size,
112
+ output_dim=2,
113
+ activation="exp",
114
+ conf_activation="expp1"
115
+ )
116
+
117
+ def _build_point_head(self) -> nn.Module:
118
+ return DPTHead(
119
+ dim_in=2 * self.embed_dim,
120
+ patch_size=self.patch_size,
121
+ output_dim=4,
122
+ activation="inv_log",
123
+ conf_activation="expp1"
124
+ )
125
+
126
+ def _build_local_point_head(self) -> nn.Module:
127
+ return DPTHead(
128
+ dim_in=2 * self.embed_dim,
129
+ patch_size=self.patch_size,
130
+ output_dim=4,
131
+ activation="inv_log",
132
+ conf_activation="expp1"
133
+ )
134
+
135
+ def _normalize_input(self, images: torch.Tensor, query_points=None):
136
+ if len(images.shape) == 4:
137
+ images = images.unsqueeze(0)
138
+ if query_points is not None and len(query_points.shape) == 2:
139
+ query_points = query_points.unsqueeze(0)
140
+ return images, query_points
141
+
142
+ @abstractmethod
143
+ def _aggregate_features(
144
+ self,
145
+ images: torch.Tensor,
146
+ num_frame_for_scale: Optional[int] = None,
147
+ sliding_window_size: Optional[int] = None,
148
+ num_frame_per_block: int = 1,
149
+ view_graphs: Optional[torch.Tensor] = None,
150
+ causal_graphs: Optional[Union[torch.Tensor, List[np.ndarray]]] = None,
151
+ ordered_video: Optional[torch.Tensor] = None,
152
+ is_cp_sliced: bool = False,
153
+ ) -> tuple:
154
+ pass
155
+
156
+ def _predict_camera(
157
+ self,
158
+ aggregated_tokens_list: list,
159
+ mask: Optional[torch.Tensor] = None,
160
+ causal_inference: bool = False,
161
+ num_frame_for_scale: Optional[int] = None,
162
+ sliding_window_size: Optional[int] = None,
163
+ num_frame_per_block: int = 1,
164
+ gather_outputs: bool = True,
165
+ ) -> Dict[str, torch.Tensor]:
166
+ if self.camera_head is None:
167
+ return {}
168
+
169
+ aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
170
+
171
+ camera_sliding_window = sliding_window_size if self.enable_camera_sliding_window else -1
172
+
173
+ with torch.amp.autocast('cuda', enabled=False):
174
+ pose_enc_list = self.camera_head(
175
+ aggregated_tokens_list_fp32,
176
+ mask=mask,
177
+ causal_inference=causal_inference,
178
+ num_frame_for_scale=num_frame_for_scale if num_frame_for_scale is not None else -1,
179
+ sliding_window_size=camera_sliding_window,
180
+ num_frame_per_block=num_frame_per_block,
181
+ )
182
+
183
+ return {
184
+ "pose_enc": pose_enc_list[-1],
185
+ "pose_enc_list": pose_enc_list,
186
+ }
187
+
188
+ def _predict_depth(
189
+ self,
190
+ aggregated_tokens_list: list,
191
+ images: torch.Tensor,
192
+ patch_start_idx: int,
193
+ gather_outputs: bool = True,
194
+ ) -> Dict[str, torch.Tensor]:
195
+ if self.depth_head is None:
196
+ return {}
197
+
198
+ aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
199
+ images_fp32 = images.float()
200
+
201
+ with torch.amp.autocast('cuda', enabled=False):
202
+ depth, depth_conf = self.depth_head(
203
+ aggregated_tokens_list_fp32,
204
+ images=images_fp32,
205
+ patch_start_idx=patch_start_idx
206
+ )
207
+
208
+ return {"depth": depth, "depth_conf": depth_conf}
209
+
210
+ def _predict_points(
211
+ self,
212
+ aggregated_tokens_list: list,
213
+ images: torch.Tensor,
214
+ patch_start_idx: int,
215
+ gather_outputs: bool = True,
216
+ ) -> Dict[str, torch.Tensor]:
217
+ if self.point_head is None:
218
+ return {}
219
+
220
+ aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
221
+ images_fp32 = images.float()
222
+
223
+ with torch.amp.autocast('cuda', enabled=False):
224
+ pts3d, pts3d_conf = self.point_head(
225
+ aggregated_tokens_list_fp32,
226
+ images=images_fp32,
227
+ patch_start_idx=patch_start_idx
228
+ )
229
+
230
+ return {"world_points": pts3d, "world_points_conf": pts3d_conf}
231
+
232
+ def _predict_local_points(
233
+ self,
234
+ aggregated_tokens_list: list,
235
+ images: torch.Tensor,
236
+ patch_start_idx: int,
237
+ gather_outputs: bool = True,
238
+ ) -> Dict[str, torch.Tensor]:
239
+ if self.local_point_head is None:
240
+ return {}
241
+
242
+ aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
243
+ images_fp32 = images.float()
244
+
245
+ with torch.amp.autocast('cuda', enabled=False):
246
+ pts3d, pts3d_conf = self.local_point_head(
247
+ aggregated_tokens_list_fp32,
248
+ images=images_fp32,
249
+ patch_start_idx=patch_start_idx
250
+ )
251
+
252
+ return {"cam_points": pts3d, "cam_points_conf": pts3d_conf}
253
+
254
+ def _unproject_depth_to_world(
255
+ self,
256
+ depth: torch.Tensor,
257
+ pose_enc: torch.Tensor,
258
+ ) -> torch.Tensor:
259
+ B, S, H, W, _ = depth.shape
260
+ device = depth.device
261
+ dtype = depth.dtype
262
+
263
+ image_size_hw = (H, W)
264
+ extrinsics, intrinsics = pose_encoding_to_extri_intri(
265
+ pose_enc, image_size_hw=image_size_hw, build_intrinsics=True
266
+ )
267
+
268
+ extrinsics_flat = extrinsics.view(B * S, 3, 4)
269
+ extrinsics_4x4 = torch.zeros(B * S, 4, 4, device=device, dtype=dtype)
270
+ extrinsics_4x4[:, :3, :] = extrinsics_flat
271
+ extrinsics_4x4[:, 3, 3] = 1.0
272
+ c2w = closed_form_inverse_se3(extrinsics_4x4).view(B, S, 4, 4)
273
+
274
+ y_grid, x_grid = torch.meshgrid(
275
+ torch.arange(H, device=device, dtype=dtype),
276
+ torch.arange(W, device=device, dtype=dtype),
277
+ indexing='ij'
278
+ )
279
+ pixel_coords = torch.stack([x_grid, y_grid, torch.ones_like(x_grid)], dim=-1)
280
+
281
+ intrinsics_inv = torch.inverse(intrinsics)
282
+ camera_coords = torch.einsum('bsij,hwj->bshwi', intrinsics_inv, pixel_coords)
283
+ camera_points = camera_coords * depth
284
+
285
+ ones = torch.ones_like(camera_points[..., :1])
286
+ camera_points_h = torch.cat([camera_points, ones], dim=-1)
287
+ world_points_h = torch.einsum('bsij,bshwj->bshwi', c2w, camera_points_h)
288
+
289
+ return world_points_h[..., :3]
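The unprojection above is the standard pinhole model; written out for a single pixel it reads X_cam = depth * K^{-1} [u, v, 1]^T followed by X_world = c2w [X_cam; 1]. A tiny sketch with placeholder intrinsics and pose:

```python
import torch

# Single-pixel version of the unprojection performed above (illustrative K, pose, and depth).
K = torch.tensor([[500.0,   0.0, 256.0],
                  [  0.0, 500.0, 256.0],
                  [  0.0,   0.0,   1.0]])
c2w = torch.eye(4)                               # camera-to-world pose (identity here)
u, v, depth = 300.0, 200.0, 2.5

ray = torch.linalg.inv(K) @ torch.tensor([u, v, 1.0])
x_cam = depth * ray                               # 3D point in camera coordinates
x_world = (c2w @ torch.cat([x_cam, torch.ones(1)]))[:3]
```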
290
+
291
+ def forward(
292
+ self,
293
+ images: torch.Tensor,
294
+ query_points: Optional[torch.Tensor] = None,
295
+ num_frame_for_scale: Optional[int] = None,
296
+ sliding_window_size: Optional[int] = None,
297
+ num_frame_per_block: int = 1,
298
+ mask: Optional[torch.Tensor] = None,
299
+ causal_inference: bool = False,
300
+ ordered_video: Optional[torch.Tensor] = None,
301
+ gather_outputs: bool = True,
302
+ point_masks: Optional[torch.Tensor] = None,
303
+ **kwargs,
304
+ ) -> Dict[str, torch.Tensor]:
305
+ """
306
+ Forward pass of the GCT model.
307
+
308
+ Args:
309
+ images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
310
+ query_points: Optional query points [N, 2] or [B, N, 2]
311
+
312
+ Returns:
313
+ Dictionary containing predictions:
314
+ - pose_enc: Camera pose encoding [B, S, 9]
315
+ - depth: Depth maps [B, S, H, W, 1]
316
+ - depth_conf: Depth confidence [B, S, H, W]
317
+ - world_points: 3D world coordinates [B, S, H, W, 3]
318
+ - world_points_conf: Point confidence [B, S, H, W]
319
+ """
320
+ images, query_points = self._normalize_input(images, query_points)
321
+
322
+ aggregated_tokens_list, patch_start_idx = self._aggregate_features(
323
+ images,
324
+ num_frame_for_scale=num_frame_for_scale,
325
+ sliding_window_size=sliding_window_size,
326
+ num_frame_per_block=num_frame_per_block,
327
+ )
328
+
329
+ predictions = {}
330
+
331
+ predictions.update(self._predict_camera(
332
+ aggregated_tokens_list,
333
+ mask=ordered_video,
334
+ causal_inference=causal_inference,
335
+ num_frame_for_scale=num_frame_for_scale,
336
+ sliding_window_size=sliding_window_size,
337
+ num_frame_per_block=num_frame_per_block,
338
+ gather_outputs=gather_outputs,
339
+ ))
340
+
341
+ predictions.update(self._predict_depth(
342
+ aggregated_tokens_list, images, patch_start_idx,
343
+ gather_outputs=gather_outputs,
344
+ ))
345
+
346
+ predictions.update(self._predict_points(
347
+ aggregated_tokens_list, images, patch_start_idx,
348
+ gather_outputs=gather_outputs,
349
+ ))
350
+
351
+ predictions.update(self._predict_local_points(
352
+ aggregated_tokens_list, images, patch_start_idx,
353
+ gather_outputs=gather_outputs,
354
+ ))
355
+
356
+ if not self.training:
357
+ predictions["images"] = images
358
+
359
+ return predictions
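A hedged usage sketch of the forward() contract above; the concrete subclass, device, and frame count are placeholders (GCTStream, a concrete implementation, is defined in the next file):

```python
import torch

# Sketch of consuming the prediction dictionary documented in forward() above (hypothetical setup).
model = GCTStream(img_size=518, patch_size=14, embed_dim=1024).cuda().eval()
images = torch.rand(8, 3, 518, 518, device="cuda")       # S frames in [0, 1]

with torch.no_grad():
    preds = model(images)

pose_enc = preds["pose_enc"]                              # [B, S, 9] camera pose encoding
depth, depth_conf = preds["depth"], preds["depth_conf"]   # [B, S, H, W, 1], [B, S, H, W]
world_points = preds["world_points"]                      # [B, S, H, W, 3]
```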
lingbot_map/models/gct_stream.py ADDED
@@ -0,0 +1,448 @@
1
+ """
2
+ GCTStream - Streaming GCT with KV cache for online inference.
3
+
4
+ Provides streaming inference functionality:
5
+ - Temporal causal attention with KV cache
6
+ - Sliding window support
7
+ - Efficient frame-by-frame processing
8
+ - 3D RoPE support for temporal consistency
9
+ """
10
+
11
+ import logging
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing import Optional, Dict, Any, List
15
+ from tqdm.auto import tqdm
16
+
17
+ from lingbot_map.heads.camera_head import CameraCausalHead
18
+ from lingbot_map.models.gct_base import GCTBase
19
+ from lingbot_map.aggregator.stream import AggregatorStream
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class GCTStream(GCTBase):
25
+ """
26
+ Streaming GCT model with KV cache for efficient online inference.
27
+
28
+ Features:
29
+ - AggregatorStream with KV cache support (FlashInfer backend)
30
+ - CameraCausalHead for pose refinement
31
+ - Sliding window attention for memory efficiency
32
+ - Frame-by-frame streaming inference
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ # Architecture parameters
38
+ img_size: int = 518,
39
+ patch_size: int = 14,
40
+ embed_dim: int = 1024,
41
+ patch_embed: str = 'dinov2_vitl14_reg',
42
+ pretrained_path: str = '',
43
+ disable_global_rope: bool = False,
44
+ # Head configuration
45
+ enable_camera: bool = True,
46
+ enable_point: bool = True,
47
+ enable_local_point: bool = False,
48
+ enable_depth: bool = True,
49
+ enable_track: bool = False,
50
+ # Normalization
51
+ enable_normalize: bool = False,
52
+ # Prediction normalization
53
+ pred_normalization: bool = False,
54
+ # Stream-specific parameters
55
+ sliding_window_size: int = -1,
56
+ num_frame_for_scale: int = 1,
57
+ num_random_frames: int = 0,
58
+ attend_to_special_tokens: bool = False,
59
+ attend_to_scale_frames: bool = False,
60
+ enable_stream_inference: bool = True, # Default to True for streaming
61
+ enable_3d_rope: bool = False,
62
+ max_frame_num: int = 1024,
63
+ # Camera head 3D RoPE (separate from aggregator 3D RoPE)
64
+ enable_camera_3d_rope: bool = False,
65
+ camera_rope_theta: float = 10000.0,
66
+ # Scale token configuration (kept for checkpoint compat, ignored)
67
+ use_scale_token: bool = True,
68
+ # KV cache parameters
69
+ kv_cache_sliding_window: int = 64,
70
+ kv_cache_scale_frames: int = 8,
71
+ kv_cache_cross_frame_special: bool = True,
72
+ kv_cache_include_scale_frames: bool = True,
73
+ kv_cache_camera_only: bool = False,
74
+ # Backend selection
75
+ use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer
76
+ # Gradient checkpointing
77
+ use_gradient_checkpoint: bool = True,
78
+ # Camera head iterative refinement (lower = faster inference; default 4)
79
+ camera_num_iterations: int = 4,
80
+ ):
81
+ """
82
+ Initialize GCTStream.
83
+
84
+ Args:
85
+ img_size: Input image size
86
+ patch_size: Patch size for embedding
87
+ embed_dim: Embedding dimension
88
+ patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.)
89
+ pretrained_path: Path to pretrained DINOv2 weights
90
+ disable_global_rope: Disable RoPE in global attention
91
+ enable_camera/point/depth/track: Enable prediction heads
92
+ enable_normalize: Enable normalization
93
+ sliding_window_size: Sliding window size in blocks (-1 for full causal)
94
+ num_frame_for_scale: Number of scale estimation frames
95
+ num_random_frames: Number of random frames for long-range dependencies
96
+ attend_to_special_tokens: Enable cross-frame special token attention
97
+ attend_to_scale_frames: Whether to attend to scale frames
98
+ enable_stream_inference: Enable streaming inference with KV cache
99
+ enable_3d_rope: Enable 3D RoPE for temporal consistency
100
+ max_frame_num: Maximum number of frames for 3D RoPE
101
+ use_scale_token: Kept for checkpoint compatibility, ignored
102
+ kv_cache_sliding_window: Sliding window size for KV cache eviction
103
+ kv_cache_scale_frames: Number of scale frames to keep in KV cache
104
+ kv_cache_cross_frame_special: Keep special tokens from evicted frames
105
+ kv_cache_include_scale_frames: Include scale frames in KV cache
106
+ kv_cache_camera_only: Only keep camera tokens from evicted frames
107
+ """
108
+ # Store stream-specific parameters before calling super().__init__()
109
+ self.pretrained_path = pretrained_path
110
+ self.sliding_window_size = sliding_window_size
111
+ self.num_frame_for_scale = num_frame_for_scale
112
+ self.num_random_frames = num_random_frames
113
+ self.attend_to_special_tokens = attend_to_special_tokens
114
+ self.attend_to_scale_frames = attend_to_scale_frames
115
+ self.enable_stream_inference = enable_stream_inference
116
+ self.enable_3d_rope = enable_3d_rope
117
+ self.max_frame_num = max_frame_num
118
+ # Camera head 3D RoPE settings
119
+ self.enable_camera_3d_rope = enable_camera_3d_rope
120
+ self.camera_rope_theta = camera_rope_theta
121
+ # KV cache parameters
122
+ self.kv_cache_sliding_window = kv_cache_sliding_window
123
+ self.kv_cache_scale_frames = kv_cache_scale_frames
124
+ self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
125
+ self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
126
+ self.kv_cache_camera_only = kv_cache_camera_only
127
+ self.use_sdpa = use_sdpa
128
+ self.camera_num_iterations = camera_num_iterations
129
+
130
+ # Call base class __init__ (will call _build_aggregator)
131
+ super().__init__(
132
+ img_size=img_size,
133
+ patch_size=patch_size,
134
+ embed_dim=embed_dim,
135
+ patch_embed=patch_embed,
136
+ disable_global_rope=disable_global_rope,
137
+ enable_camera=enable_camera,
138
+ enable_point=enable_point,
139
+ enable_local_point=enable_local_point,
140
+ enable_depth=enable_depth,
141
+ enable_track=enable_track,
142
+ enable_normalize=enable_normalize,
143
+ pred_normalization=pred_normalization,
144
+ enable_3d_rope=enable_3d_rope,
145
+ use_gradient_checkpoint=use_gradient_checkpoint,
146
+ )
147
+
148
+ def _build_aggregator(self) -> nn.Module:
149
+ """
150
+ Build streaming aggregator with KV cache support (FlashInfer backend).
151
+
152
+ Returns:
153
+ AggregatorStream module
154
+ """
155
+ return AggregatorStream(
156
+ img_size=self.img_size,
157
+ patch_size=self.patch_size,
158
+ embed_dim=self.embed_dim,
159
+ patch_embed=self.patch_embed,
160
+ pretrained_path=self.pretrained_path,
161
+ disable_global_rope=self.disable_global_rope,
162
+ sliding_window_size=self.sliding_window_size,
163
+ num_frame_for_scale=self.num_frame_for_scale,
164
+ num_random_frames=self.num_random_frames,
165
+ attend_to_special_tokens=self.attend_to_special_tokens,
166
+ attend_to_scale_frames=self.attend_to_scale_frames,
167
+ enable_stream_inference=self.enable_stream_inference,
168
+ enable_3d_rope=self.enable_3d_rope,
169
+ max_frame_num=self.max_frame_num,
170
+ # Backend: FlashInfer (default) or SDPA (fallback)
171
+ use_flashinfer=not self.use_sdpa,
172
+ use_sdpa=self.use_sdpa,
173
+ kv_cache_sliding_window=self.kv_cache_sliding_window,
174
+ kv_cache_scale_frames=self.kv_cache_scale_frames,
175
+ kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
176
+ kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
177
+ kv_cache_camera_only=self.kv_cache_camera_only,
178
+ use_gradient_checkpoint=self.use_gradient_checkpoint,
179
+ )
180
+
181
+ def _build_camera_head(self) -> nn.Module:
182
+ """
183
+ Build causal camera head for streaming inference.
184
+
185
+ Returns:
186
+ CameraCausalHead module or None
187
+ """
188
+ return CameraCausalHead(
189
+ dim_in=2 * self.embed_dim,
190
+ sliding_window_size=self.sliding_window_size,
191
+ attend_to_scale_frames=self.attend_to_scale_frames,
192
+ num_iterations=self.camera_num_iterations,
193
+ # KV cache parameters
194
+ kv_cache_sliding_window=self.kv_cache_sliding_window,
195
+ kv_cache_scale_frames=self.kv_cache_scale_frames,
196
+ kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
197
+ kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
198
+ kv_cache_camera_only=self.kv_cache_camera_only,
199
+ # Camera head 3D RoPE parameters
200
+ enable_3d_rope=self.enable_camera_3d_rope,
201
+ max_frame_num=self.max_frame_num,
202
+ rope_theta=self.camera_rope_theta,
203
+ )
204
+
205
+ def _aggregate_features(
206
+ self,
207
+ images: torch.Tensor,
208
+ num_frame_for_scale: Optional[int] = None,
209
+ sliding_window_size: Optional[int] = None,
210
+ num_frame_per_block: int = 1,
211
+ **kwargs,
212
+ ) -> tuple:
213
+ """
214
+ Run aggregator to get multi-scale features.
215
+
216
+ Args:
217
+ images: Input images [B, S, 3, H, W]
218
+ num_frame_for_scale: Number of frames for scale estimation
219
+ sliding_window_size: Override sliding window size
220
+ num_frame_per_block: Number of frames per block
221
+
222
+ Returns:
223
+ (aggregated_tokens_list, patch_start_idx)
224
+ """
225
+ aggregated_tokens_list, patch_start_idx = self.aggregator(
226
+ images,
227
+ selected_idx=[4, 11, 17, 23],
228
+ num_frame_for_scale=num_frame_for_scale,
229
+ sliding_window_size=sliding_window_size,
230
+ num_frame_per_block=num_frame_per_block,
231
+ )
232
+ return aggregated_tokens_list, patch_start_idx
233
+
234
+ def clean_kv_cache(self):
235
+ """
236
+ Clean KV cache in aggregator.
237
+
238
+ Call this method when starting a new video sequence to clear
239
+ cached key-value pairs from previous sequences.
240
+ """
241
+ if hasattr(self.aggregator, 'clean_kv_cache'):
242
+ self.aggregator.clean_kv_cache()
243
+ else:
244
+ logger.warning("Aggregator does not support KV cache cleaning")
245
+ if hasattr(self.camera_head, 'clean_kv_cache'):
246
+ self.camera_head.clean_kv_cache()
247
+ else:
248
+ logger.warning("Camera head does not support KV cache cleaning")
249
+
250
+ def _set_skip_append(self, skip: bool):
251
+ """Set _skip_append flag on all KV caches (aggregator + camera head).
252
+
253
+ When skip=True, attention layers will attend to [cached_kv + current_kv]
254
+ but will NOT store the current frame's KV in cache. This is used for
255
+ non-keyframe processing in keyframe-based streaming inference.
256
+
257
+ Args:
258
+ skip: If True, subsequent forward passes will not append KV to cache.
259
+ """
260
+ if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None:
261
+ self.aggregator.kv_cache["_skip_append"] = skip
262
+ if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
263
+ for cache_dict in self.camera_head.kv_cache:
264
+ cache_dict["_skip_append"] = skip
265
+
266
+ def get_kv_cache_info(self) -> Dict[str, Any]:
267
+ """
268
+ Get information about current KV cache state.
269
+
270
+ Returns:
271
+ Dictionary with cache statistics:
272
+ - num_cached_blocks: Number of blocks with cached KV
273
+ - cache_memory_mb: Approximate memory usage in MB
274
+ """
275
+ if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None:
276
+ return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}
277
+
278
+ kv_cache = self.aggregator.kv_cache
279
+ num_cached = sum(1 for k in kv_cache.keys() if k.startswith('k_') and not k.endswith('_special'))
280
+
281
+ # Estimate memory usage
282
+ total_elements = 0
283
+ for _, v in kv_cache.items():
284
+ if v is not None and torch.is_tensor(v):
285
+ total_elements += v.numel()
286
+
287
+ # Assume bfloat16 (2 bytes per element)
288
+ cache_memory_mb = (total_elements * 2) / (1024 * 1024)
289
+
290
+ return {
291
+ "num_cached_blocks": num_cached,
292
+ "cache_memory_mb": round(cache_memory_mb, 2)
293
+ }
294
+
295
+ @torch.no_grad()
296
+ def inference_streaming(
297
+ self,
298
+ images: torch.Tensor,
299
+ num_scale_frames: Optional[int] = None,
300
+ keyframe_interval: int = 1,
301
+ output_device: Optional[torch.device] = None,
302
+ ) -> Dict[str, torch.Tensor]:
303
+ """
304
+ Streaming inference: process scale frames first, then frame-by-frame.
305
+
306
+ This method enables efficient online inference by:
307
+ 1. Processing initial scale frames together (bidirectional attention via scale token)
308
+ 2. Processing remaining frames one-by-one with KV cache (causal streaming)
309
+
310
+ Keyframe mode (keyframe_interval > 1):
311
+ - Every keyframe_interval-th frame (after scale frames) is a keyframe
312
+ - Keyframes: KV is stored in cache (normal behavior)
313
+ - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard)
314
+ - All frames produce full predictions regardless of keyframe status
315
+ - Reduces KV cache memory growth to ~1/keyframe_interval of the every-frame rate
316
+
317
+ Args:
318
+ images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
319
+ num_scale_frames: Number of initial frames for scale estimation.
320
+ If None, uses self.num_frame_for_scale.
321
+ keyframe_interval: Every N-th frame (after scale frames) is a keyframe
322
+ whose KV persists in cache. 1 = every frame is a
323
+ keyframe (default, same as original behavior).
324
+ output_device: Device to store output predictions on. If None, keeps on
325
+ the same device as the model. Set to torch.device('cpu')
326
+ to offload predictions per-frame and avoid GPU OOM on
327
+ long sequences.
328
+
329
+ Returns:
330
+ Dictionary containing predictions for all frames:
331
+ - pose_enc: [B, S, 9]
332
+ - depth: [B, S, H, W, 1]
333
+ - depth_conf: [B, S, H, W]
334
+ - world_points: [B, S, H, W, 3]
335
+ - world_points_conf: [B, S, H, W]
336
+ """
337
+ # Normalize input shape
338
+ if len(images.shape) == 4:
339
+ images = images.unsqueeze(0)
340
+ B, S, C, H, W = images.shape
341
+
342
+ # Determine number of scale frames
343
+ scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale
344
+ scale_frames = min(scale_frames, S) # Cap to available frames
345
+
346
+ # Helper to move tensor to output device
347
+ def _to_out(t: torch.Tensor) -> torch.Tensor:
348
+ if output_device is not None:
349
+ return t.to(output_device)
350
+ return t
351
+
352
+ # Clean KV caches before starting new sequence
353
+ self.clean_kv_cache()
354
+
355
+ # Phase 1: Process scale frames together
356
+ # These frames get bidirectional attention among themselves via scale token
357
+ logger.info(f'Processing {scale_frames} scale frames...')
358
+ scale_images = images[:, :scale_frames]
359
+ scale_output = self.forward(
360
+ scale_images,
361
+ num_frame_for_scale=scale_frames,
362
+ num_frame_per_block=scale_frames, # Process all scale frames as one block
363
+ causal_inference=True,
364
+ )
365
+
366
+ # Initialize output lists with scale frame predictions (offload if needed)
367
+ all_pose_enc = [_to_out(scale_output["pose_enc"])]
368
+ all_depth = [_to_out(scale_output["depth"])] if "depth" in scale_output else []
369
+ all_depth_conf = [_to_out(scale_output["depth_conf"])] if "depth_conf" in scale_output else []
370
+ all_world_points = [_to_out(scale_output["world_points"])] if "world_points" in scale_output else []
371
+ all_world_points_conf = [_to_out(scale_output["world_points_conf"])] if "world_points_conf" in scale_output else []
372
+ del scale_output
373
+
374
+ # Phase 2: Process remaining frames one-by-one
375
+ pbar = tqdm(
376
+ range(scale_frames, S),
377
+ desc='Streaming inference',
378
+ initial=scale_frames,
379
+ total=S,
380
+ )
381
+ for i in pbar:
382
+ frame_image = images[:, i:i+1]
383
+
384
+ # Determine if this frame is a keyframe
385
+ is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0)
386
+
387
+ if not is_keyframe:
388
+ self._set_skip_append(True)
389
+
390
+ frame_output = self.forward(
391
+ frame_image,
392
+ num_frame_for_scale=scale_frames, # Keep same for scale token logic
393
+ num_frame_per_block=1, # Single frame per block
394
+ causal_inference=True,
395
+ )
396
+
397
+ if not is_keyframe:
398
+ self._set_skip_append(False)
399
+
400
+ all_pose_enc.append(_to_out(frame_output["pose_enc"]))
401
+ if "depth" in frame_output:
402
+ all_depth.append(_to_out(frame_output["depth"]))
403
+ if "depth_conf" in frame_output:
404
+ all_depth_conf.append(_to_out(frame_output["depth_conf"]))
405
+ if "world_points" in frame_output:
406
+ all_world_points.append(_to_out(frame_output["world_points"]))
407
+ if "world_points_conf" in frame_output:
408
+ all_world_points_conf.append(_to_out(frame_output["world_points_conf"]))
409
+ del frame_output
410
+
411
+ # Free GPU memory before concatenation
412
+ if output_device is not None:
413
+ # Move images to output device, then free GPU copy
414
+ images_out = _to_out(images)
415
+ del images
416
+ # Clean KV cache (no longer needed after inference)
417
+ self.clean_kv_cache()
418
+ if torch.cuda.is_available():
419
+ torch.cuda.empty_cache()
420
+ else:
421
+ images_out = images
422
+
423
+ # Concatenate all predictions along sequence dimension
424
+ predictions = {
425
+ "pose_enc": torch.cat(all_pose_enc, dim=1),
426
+ }
427
+ del all_pose_enc
428
+ if all_depth:
429
+ predictions["depth"] = torch.cat(all_depth, dim=1)
430
+ del all_depth
431
+ if all_depth_conf:
432
+ predictions["depth_conf"] = torch.cat(all_depth_conf, dim=1)
433
+ del all_depth_conf
434
+ if all_world_points:
435
+ predictions["world_points"] = torch.cat(all_world_points, dim=1)
436
+ del all_world_points
437
+ if all_world_points_conf:
438
+ predictions["world_points_conf"] = torch.cat(all_world_points_conf, dim=1)
439
+ del all_world_points_conf
440
+
441
+ # Store images for visualization
442
+ predictions["images"] = images_out
443
+
444
+ # Apply prediction normalization if enabled
445
+ if self.pred_normalization:
446
+ predictions = self._normalize_predictions(predictions)
447
+
448
+ return predictions
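For reference, a minimal usage sketch of the streaming entry point above. The import path, default constructor arguments, and tensor sizes are assumptions for illustration; in practice a trained checkpoint would be loaded before inference.

import torch
from lingbot_map.models.gct_stream import GCTStream  # assumed import path

model = GCTStream().eval().to("cuda")                  # assumed: default arguments are acceptable
frames = torch.rand(48, 3, 518, 518, device="cuda")    # [S, 3, H, W] video in [0, 1]

with torch.no_grad():
    preds = model.inference_streaming(
        frames,
        num_scale_frames=4,                 # Phase 1: processed together as one block
        keyframe_interval=3,                # only every 3rd Phase-2 frame keeps its KV
        output_device=torch.device("cpu"),  # offload per-frame outputs for long videos
    )

print(preds["pose_enc"].shape)  # expected [1, 48, 9]
print(preds["depth"].shape)     # expected [1, 48, H, W, 1] when the depth head is enabled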
lingbot_map/models/gct_stream_window.py ADDED
@@ -0,0 +1,1206 @@
1
+ """
2
+ GCTStream - Streaming GCT with KV cache, windowed inference, and flow-based keyframe selection.
3
+
4
+ Provides streaming inference functionality:
5
+ - Temporal causal attention with KV cache
6
+ - Sliding window support
7
+ - Efficient frame-by-frame processing
8
+ - 3D RoPE support for temporal consistency
9
+ """
10
+
11
+ import logging
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing import Optional, Dict, Any, List
15
+ from tqdm.auto import tqdm
16
+
17
+ from lingbot_map.utils.rotation import quat_to_mat, mat_to_quat
18
+
19
+ from lingbot_map.heads.camera_head import CameraCausalHead
20
+ from lingbot_map.models.gct_base import GCTBase
21
+ from lingbot_map.aggregator.stream import AggregatorStream
22
+ from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
23
+ from lingbot_map.utils.geometry import closed_form_inverse_se3
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @torch.no_grad()
29
+ def _compute_flow_magnitude(
30
+ cur_pose_enc: torch.Tensor,
31
+ kf_pose_enc: torch.Tensor,
32
+ cur_depth: torch.Tensor,
33
+ image_size_hw: tuple,
34
+ stride: int = 8,
35
+ ) -> float:
36
+ """Compute mean optical flow magnitude induced by camera motion.
37
+
38
+ Projects current frame pixels into the last keyframe camera using the
39
+ current depth map and both frames' poses, then returns the average
40
+ pixel displacement (L2 norm of flow) over valid pixels.
41
+
42
+ Args:
43
+ cur_pose_enc: Current frame pose encoding [B, 1, 9].
44
+ kf_pose_enc: Last keyframe pose encoding [B, 1, 9].
45
+ cur_depth: Current frame depth map [B, 1, H, W, 1].
46
+ image_size_hw: (H, W) of the depth map.
47
+ stride: Subsampling stride for efficiency.
48
+
49
+ Returns:
50
+ Mean flow magnitude in pixels (scalar float).
51
+ """
52
+ H, W = image_size_hw
53
+ device = cur_pose_enc.device
54
+ dtype = cur_depth.dtype
55
+
56
+ cur_ext, cur_intr = pose_encoding_to_extri_intri(
57
+ cur_pose_enc, image_size_hw=image_size_hw
58
+ )
59
+ kf_ext, kf_intr = pose_encoding_to_extri_intri(
60
+ kf_pose_enc, image_size_hw=image_size_hw
61
+ )
62
+ B = cur_ext.shape[0]
63
+
64
+ cur_ext = cur_ext[:, 0]
65
+ cur_intr = cur_intr[:, 0]
66
+ kf_ext = kf_ext[:, 0]
67
+ kf_intr = kf_intr[:, 0]
68
+
69
+ depth = cur_depth[:, 0, ::stride, ::stride, 0].to(dtype)
70
+ Hs, Ws = depth.shape[1], depth.shape[2]
71
+
72
+ v_coords = torch.arange(0, H, stride, device=device, dtype=dtype)
73
+ u_coords = torch.arange(0, W, stride, device=device, dtype=dtype)
74
+ v_grid, u_grid = torch.meshgrid(v_coords, u_coords, indexing='ij')
75
+ ones = torch.ones_like(u_grid)
76
+ pixel_coords = torch.stack([u_grid, v_grid, ones], dim=-1)
77
+
78
+ intr_inv = torch.inverse(cur_intr)
79
+ cam_coords = torch.einsum('bij,hwj->bhwi', intr_inv, pixel_coords)
80
+ cam_pts = cam_coords * depth.unsqueeze(-1)
81
+
82
+ c2w = torch.zeros(B, 4, 4, device=device, dtype=dtype)
83
+ c2w[:, :3, :] = cur_ext
84
+ c2w[:, 3, 3] = 1.0
85
+
86
+ ones_hw = torch.ones(B, Hs, Ws, 1, device=device, dtype=dtype)
87
+ cam_pts_h = torch.cat([cam_pts, ones_hw], dim=-1)
88
+ world_pts = torch.einsum('bij,bhwj->bhwi', c2w, cam_pts_h)[..., :3]
89
+
90
+ kf_c2w = torch.zeros(B, 4, 4, device=device, dtype=dtype)
91
+ kf_c2w[:, :3, :] = kf_ext
92
+ kf_c2w[:, 3, 3] = 1.0
93
+ kf_w2c = closed_form_inverse_se3(kf_c2w)
94
+ world_pts_h = torch.cat([world_pts, ones_hw], dim=-1)
95
+ kf_cam_pts = torch.einsum('bij,bhwj->bhwi', kf_w2c, world_pts_h)[..., :3]
96
+
97
+ z = kf_cam_pts[..., 2:3].clamp(min=1e-6)
98
+ kf_cam_norm = kf_cam_pts / z
99
+ kf_pixels = torch.einsum('bij,bhwj->bhwi', kf_intr, kf_cam_norm)[..., :2]
100
+
101
+ orig_pixels = torch.stack([u_grid, v_grid], dim=-1).unsqueeze(0).expand(B, -1, -1, -1)
102
+
103
+ flow = kf_pixels - orig_pixels
104
+ valid = (depth > 1e-6) & (kf_cam_pts[..., 2] > 1e-6)
105
+
106
+ flow_mag = flow.norm(dim=-1)
107
+ valid_count = valid.float().sum()
108
+ if valid_count < 1:
109
+ return 0.0
110
+
111
+ mean_mag = (flow_mag * valid.float()).sum() / valid_count
112
+ return mean_mag.item()
113
+
114
+
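Written as a formula, the flow magnitude above is the mean reprojection displacement of the current frame's pixels into the last keyframe's camera. The notation below is introduced only for illustration (the function itself operates on pose encodings and depth tensors): with current intrinsics and camera-to-world pose (K_c, T_c), keyframe intrinsics and pose (K_k, T_k), depth d(u) at pixel u, and perspective division \pi([x, y, z]^\top) = [x/z, y/z]^\top,

\begin{aligned}
X_w(u) &= T_c \,\big( d(u)\, K_c^{-1} \tilde{u} \big) \\
u' &= \pi\!\big( K_k\, T_k^{-1}\, X_w(u) \big) \\
\text{flow\_mag} &= \frac{1}{|\mathcal{V}|} \sum_{u \in \mathcal{V}} \lVert u' - u \rVert_2
\end{aligned}

where \mathcal{V} is the set of stride-subsampled pixels with positive depth in both cameras.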
115
+ class GCTStream(GCTBase):
116
+ """
117
+ Streaming GCT model with KV cache for efficient online inference.
118
+
119
+ Features:
120
+ - AggregatorStream with KV cache support (FlashInfer backend)
121
+ - CameraCausalHead for pose refinement
122
+ - Sliding window attention for memory efficiency
123
+ - Frame-by-frame streaming inference
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ # Architecture parameters
129
+ img_size: int = 518,
130
+ patch_size: int = 14,
131
+ embed_dim: int = 1024,
132
+ patch_embed: str = 'dinov2_vitl14_reg',
133
+ pretrained_path: str = '',
134
+ disable_global_rope: bool = False,
135
+ # Head configuration
136
+ enable_camera: bool = True,
137
+ enable_point: bool = True,
138
+ enable_local_point: bool = False,
139
+ enable_depth: bool = True,
140
+ enable_track: bool = False,
141
+ # Normalization
142
+ enable_normalize: bool = False,
143
+ # Prediction normalization
144
+ pred_normalization: bool = False,
145
+ # Stream-specific parameters
146
+ sliding_window_size: int = -1,
147
+ num_frame_for_scale: int = 1,
148
+ num_random_frames: int = 0,
149
+ attend_to_special_tokens: bool = False,
150
+ attend_to_scale_frames: bool = False,
151
+ enable_stream_inference: bool = True, # Default to True for streaming
152
+ enable_3d_rope: bool = False,
153
+ max_frame_num: int = 1024,
154
+ # Camera head 3D RoPE (separate from aggregator 3D RoPE)
155
+ enable_camera_3d_rope: bool = False,
156
+ camera_rope_theta: float = 10000.0,
157
+ # Scale token configuration (kept for checkpoint compat, ignored)
158
+ use_scale_token: bool = True,
159
+ # KV cache parameters
160
+ kv_cache_sliding_window: int = 64,
161
+ kv_cache_scale_frames: int = 8,
162
+ kv_cache_cross_frame_special: bool = True,
163
+ kv_cache_include_scale_frames: bool = True,
164
+ kv_cache_camera_only: bool = False,
165
+ # Backend selection
166
+ use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer
167
+ # Gradient checkpointing
168
+ use_gradient_checkpoint: bool = True,
169
+ # Camera head iterative refinement (lower = faster inference; default 4)
170
+ camera_num_iterations: int = 4,
171
+ ):
172
+ """
173
+ Initialize GCTStream.
174
+
175
+ Args:
176
+ img_size: Input image size
177
+ patch_size: Patch size for embedding
178
+ embed_dim: Embedding dimension
179
+ patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.)
180
+ pretrained_path: Path to pretrained DINOv2 weights
181
+ disable_global_rope: Disable RoPE in global attention
182
+ enable_camera/point/depth/track: Enable prediction heads
183
+ enable_normalize: Enable normalization
184
+ sliding_window_size: Sliding window size in blocks (-1 for full causal)
185
+ num_frame_for_scale: Number of scale estimation frames
186
+ num_random_frames: Number of random frames for long-range dependencies
187
+ attend_to_special_tokens: Enable cross-frame special token attention
188
+ attend_to_scale_frames: Whether to attend to scale frames
189
+ enable_stream_inference: Enable streaming inference with KV cache
190
+ enable_3d_rope: Enable 3D RoPE for temporal consistency
191
+ max_frame_num: Maximum number of frames for 3D RoPE
192
+ use_scale_token: Kept for checkpoint compatibility, ignored
193
+ kv_cache_sliding_window: Sliding window size for KV cache eviction
194
+ kv_cache_scale_frames: Number of scale frames to keep in KV cache
195
+ kv_cache_cross_frame_special: Keep special tokens from evicted frames
196
+ kv_cache_include_scale_frames: Include scale frames in KV cache
197
+ kv_cache_camera_only: Only keep camera tokens from evicted frames
198
+ """
199
+ # Store stream-specific parameters before calling super().__init__()
200
+ self.pretrained_path = pretrained_path
201
+ self.sliding_window_size = sliding_window_size
202
+ self.num_frame_for_scale = num_frame_for_scale
203
+ self.num_random_frames = num_random_frames
204
+ self.attend_to_special_tokens = attend_to_special_tokens
205
+ self.attend_to_scale_frames = attend_to_scale_frames
206
+ self.enable_stream_inference = enable_stream_inference
207
+ self.enable_3d_rope = enable_3d_rope
208
+ self.max_frame_num = max_frame_num
209
+ # Camera head 3D RoPE settings
210
+ self.enable_camera_3d_rope = enable_camera_3d_rope
211
+ self.camera_rope_theta = camera_rope_theta
212
+ # KV cache parameters
213
+ self.kv_cache_sliding_window = kv_cache_sliding_window
214
+ self.kv_cache_scale_frames = kv_cache_scale_frames
215
+ self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
216
+ self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
217
+ self.kv_cache_camera_only = kv_cache_camera_only
218
+ self.use_sdpa = use_sdpa
219
+ self.camera_num_iterations = camera_num_iterations
220
+
221
+ # Call base class __init__ (will call _build_aggregator)
222
+ super().__init__(
223
+ img_size=img_size,
224
+ patch_size=patch_size,
225
+ embed_dim=embed_dim,
226
+ patch_embed=patch_embed,
227
+ disable_global_rope=disable_global_rope,
228
+ enable_camera=enable_camera,
229
+ enable_point=enable_point,
230
+ enable_local_point=enable_local_point,
231
+ enable_depth=enable_depth,
232
+ enable_track=enable_track,
233
+ enable_normalize=enable_normalize,
234
+ pred_normalization=pred_normalization,
235
+ enable_3d_rope=enable_3d_rope,
236
+ use_gradient_checkpoint=use_gradient_checkpoint,
237
+ )
238
+
239
+ def _build_aggregator(self) -> nn.Module:
240
+ """
241
+ Build streaming aggregator with KV cache support (FlashInfer backend).
242
+
243
+ Returns:
244
+ AggregatorStream module
245
+ """
246
+ return AggregatorStream(
247
+ img_size=self.img_size,
248
+ patch_size=self.patch_size,
249
+ embed_dim=self.embed_dim,
250
+ patch_embed=self.patch_embed,
251
+ pretrained_path=self.pretrained_path,
252
+ disable_global_rope=self.disable_global_rope,
253
+ sliding_window_size=self.sliding_window_size,
254
+ num_frame_for_scale=self.num_frame_for_scale,
255
+ num_random_frames=self.num_random_frames,
256
+ attend_to_special_tokens=self.attend_to_special_tokens,
257
+ attend_to_scale_frames=self.attend_to_scale_frames,
258
+ enable_stream_inference=self.enable_stream_inference,
259
+ enable_3d_rope=self.enable_3d_rope,
260
+ max_frame_num=self.max_frame_num,
261
+ # Backend: FlashInfer (default) or SDPA (fallback)
262
+ use_flashinfer=not self.use_sdpa,
263
+ use_sdpa=self.use_sdpa,
264
+ kv_cache_sliding_window=self.kv_cache_sliding_window,
265
+ kv_cache_scale_frames=self.kv_cache_scale_frames,
266
+ kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
267
+ kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
268
+ kv_cache_camera_only=self.kv_cache_camera_only,
269
+ use_gradient_checkpoint=self.use_gradient_checkpoint,
270
+ )
271
+
272
+ def _build_camera_head(self) -> nn.Module:
273
+ """
274
+ Build causal camera head for streaming inference.
275
+
276
+ Returns:
277
+ CameraCausalHead module or None
278
+ """
279
+ return CameraCausalHead(
280
+ dim_in=2 * self.embed_dim,
281
+ sliding_window_size=self.sliding_window_size,
282
+ attend_to_scale_frames=self.attend_to_scale_frames,
283
+ num_iterations=self.camera_num_iterations,
284
+ # KV cache parameters
285
+ kv_cache_sliding_window=self.kv_cache_sliding_window,
286
+ kv_cache_scale_frames=self.kv_cache_scale_frames,
287
+ kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
288
+ kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
289
+ kv_cache_camera_only=self.kv_cache_camera_only,
290
+ # Camera head 3D RoPE parameters
291
+ enable_3d_rope=self.enable_camera_3d_rope,
292
+ max_frame_num=self.max_frame_num,
293
+ rope_theta=self.camera_rope_theta,
294
+ )
295
+
296
+ def _aggregate_features(
297
+ self,
298
+ images: torch.Tensor,
299
+ num_frame_for_scale: Optional[int] = None,
300
+ sliding_window_size: Optional[int] = None,
301
+ num_frame_per_block: int = 1,
302
+ **kwargs,
303
+ ) -> tuple:
304
+ """
305
+ Run aggregator to get multi-scale features.
306
+
307
+ Args:
308
+ images: Input images [B, S, 3, H, W]
309
+ num_frame_for_scale: Number of frames for scale estimation
310
+ sliding_window_size: Override sliding window size
311
+ num_frame_per_block: Number of frames per block
312
+
313
+ Returns:
314
+ (aggregated_tokens_list, patch_start_idx)
315
+ """
316
+ aggregated_tokens_list, patch_start_idx = self.aggregator(
317
+ images,
318
+ selected_idx=[4, 11, 17, 23],
319
+ num_frame_for_scale=num_frame_for_scale,
320
+ sliding_window_size=sliding_window_size,
321
+ num_frame_per_block=num_frame_per_block,
322
+ )
323
+ return aggregated_tokens_list, patch_start_idx
324
+
325
+ def clean_kv_cache(self):
326
+ """
327
+ Clean KV cache in aggregator.
328
+
329
+ Call this method when starting a new video sequence to clear
330
+ cached key-value pairs from previous sequences.
331
+ """
332
+ if hasattr(self.aggregator, 'clean_kv_cache'):
333
+ self.aggregator.clean_kv_cache()
334
+ else:
335
+ logger.warning("Aggregator does not support KV cache cleaning")
336
+ if hasattr(self.camera_head, 'clean_kv_cache'):
337
+ self.camera_head.clean_kv_cache()
338
+ else:
339
+ logger.warning("Camera head does not support KV cache cleaning")
340
+
341
+ def _set_skip_append(self, skip: bool):
342
+ """Set _skip_append flag on all KV caches (aggregator + camera head).
343
+
344
+ When skip=True, attention layers will attend to [cached_kv + current_kv]
345
+ but will NOT store the current frame's KV in cache. This is used for
346
+ non-keyframe processing in keyframe-based streaming inference.
347
+
348
+ Args:
349
+ skip: If True, subsequent forward passes will not append KV to cache.
350
+ """
351
+ if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None:
352
+ self.aggregator.kv_cache["_skip_append"] = skip
353
+ # FlashInfer manager
354
+ if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
355
+ self.aggregator.kv_cache_manager._skip_append = skip
356
+ if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
357
+ for cache_dict in self.camera_head.kv_cache:
358
+ cache_dict["_skip_append"] = skip
359
+
360
+ # ── Flow-based keyframe helpers ────────────────────────────────────────
361
+
362
+ def _set_defer_eviction(self, defer: bool):
363
+ """Set defer-eviction flag on FlashInfer manager and SDPA caches.
364
+
365
+ While True, eviction is suppressed so that rollback can cleanly undo
366
+ the most recent append without having to restore evicted frames.
367
+ """
368
+ # FlashInfer manager
369
+ if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
370
+ self.aggregator.kv_cache_manager._defer_eviction = defer
371
+ # SDPA aggregator cache (dict)
372
+ if hasattr(self.aggregator, 'kv_cache') and isinstance(self.aggregator.kv_cache, dict):
373
+ self.aggregator.kv_cache["_defer_eviction"] = defer
374
+ # Camera head SDPA caches
375
+ if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
376
+ for cache_dict in self.camera_head.kv_cache:
377
+ cache_dict["_defer_eviction"] = defer
378
+
379
+ def _rollback_last_frame(self):
380
+ """Rollback the most recent frame from all caches.
381
+
382
+ Undoes append_frame on FlashInfer manager (all blocks), trims the
383
+ camera head SDPA cache, and decrements the aggregator frame counter.
384
+ Must be called while eviction is still deferred.
385
+ """
386
+ # FlashInfer manager — rollback each transformer block
387
+ if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
388
+ mgr = self.aggregator.kv_cache_manager
389
+ for block_idx in range(mgr.num_blocks):
390
+ mgr.rollback_last_frame(block_idx)
391
+
392
+ # SDPA aggregator cache — trim last frame along dim=2
393
+ if hasattr(self.aggregator, 'kv_cache') and isinstance(self.aggregator.kv_cache, dict):
394
+ kv = self.aggregator.kv_cache
395
+ for key in list(kv.keys()):
396
+ if key.startswith(("k_", "v_")) and kv[key] is not None and torch.is_tensor(kv[key]):
397
+ if kv[key].dim() >= 3 and kv[key].shape[2] > 1:
398
+ kv[key] = kv[key][:, :, :-1]
399
+ elif kv[key].dim() >= 3:
400
+ kv[key] = None
401
+
402
+ # Camera head
403
+ if self.camera_head is not None and hasattr(self.camera_head, 'rollback_last_frame'):
404
+ self.camera_head.rollback_last_frame()
405
+
406
+ # Aggregator frame counter (used for 3D RoPE temporal positions)
407
+ self.aggregator.total_frames_processed -= 1
408
+
409
+ def _execute_deferred_eviction(self):
410
+ """Execute the eviction that was deferred during the last forward pass."""
411
+ # FlashInfer manager
412
+ if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None:
413
+ mgr = self.aggregator.kv_cache_manager
414
+ for block_idx in range(mgr.num_blocks):
415
+ mgr.execute_deferred_eviction(
416
+ block_idx,
417
+ scale_frames=self.kv_cache_scale_frames,
418
+ sliding_window=self.kv_cache_sliding_window,
419
+ )
420
+
421
+ def get_kv_cache_info(self) -> Dict[str, Any]:
422
+ """
423
+ Get information about current KV cache state.
424
+
425
+ Returns:
426
+ Dictionary with cache statistics:
427
+ - num_cached_blocks: Number of blocks with cached KV
428
+ - cache_memory_mb: Approximate memory usage in MB
429
+ """
430
+ if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None:
431
+ return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}
432
+
433
+ kv_cache = self.aggregator.kv_cache
434
+ num_cached = sum(1 for k in kv_cache.keys() if k.startswith('k_') and not k.endswith('_special'))
435
+
436
+ # Estimate memory usage
437
+ total_elements = 0
438
+ for _, v in kv_cache.items():
439
+ if v is not None and torch.is_tensor(v):
440
+ total_elements += v.numel()
441
+
442
+ # Assume bfloat16 (2 bytes per element)
443
+ cache_memory_mb = (total_elements * 2) / (1024 * 1024)
444
+
445
+ return {
446
+ "num_cached_blocks": num_cached,
447
+ "cache_memory_mb": round(cache_memory_mb, 2)
448
+ }
449
+
450
+ @torch.no_grad()
451
+ def inference_streaming(
452
+ self,
453
+ images: torch.Tensor,
454
+ num_scale_frames: Optional[int] = None,
455
+ keyframe_interval: int = 1,
456
+ output_device: Optional[torch.device] = None,
457
+ flow_threshold: float = 0.0,
458
+ max_non_keyframe_gap: int = 30,
459
+ ) -> Dict[str, torch.Tensor]:
460
+ """
461
+ Streaming inference: process scale frames first, then frame-by-frame.
462
+
463
+ This method enables efficient online inference by:
464
+ 1. Processing initial scale frames together (bidirectional attention via scale token)
465
+ 2. Processing remaining frames one-by-one with KV cache (causal streaming)
466
+
467
+ Keyframe mode (keyframe_interval > 1):
468
+ - Every keyframe_interval-th frame (after scale frames) is a keyframe
469
+ - Keyframes: KV is stored in cache (normal behavior)
470
+ - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard)
471
+ - All frames produce full predictions regardless of keyframe status
472
+ - Reduces KV cache memory growth by ~1/keyframe_interval
473
+
474
+ Flow-based keyframe mode (flow_threshold > 0):
475
+ - Takes precedence over keyframe_interval
476
+ - Computes optical flow magnitude between current frame and last keyframe
477
+ - Frame becomes keyframe if flow exceeds threshold or gap exceeds max_non_keyframe_gap
478
+ - Uses defer-eviction + rollback for non-keyframes
479
+
480
+ Args:
481
+ images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
482
+ num_scale_frames: Number of initial frames for scale estimation.
483
+ If None, uses self.num_frame_for_scale.
484
+ keyframe_interval: Every N-th frame (after scale frames) is a keyframe
485
+ whose KV persists in cache. 1 = every frame is a
486
+ keyframe (default, same as original behavior).
487
+ output_device: Device to store output predictions on. If None, keeps on
488
+ the same device as the model. Set to torch.device('cpu')
489
+ to offload predictions per-frame and avoid GPU OOM on
490
+ long sequences.
491
+ flow_threshold: Mean flow magnitude threshold (pixels) for flow-based
492
+ keyframe selection. >0 enables flow-based mode (takes precedence
493
+ over keyframe_interval).
494
+ max_non_keyframe_gap: Max consecutive non-keyframe frames before
495
+ forcing a keyframe (flow mode only).
496
+
497
+ Returns:
498
+ Dictionary containing predictions for all frames:
499
+ - pose_enc: [B, S, 9]
500
+ - depth: [B, S, H, W, 1]
501
+ - depth_conf: [B, S, H, W]
502
+ - world_points: [B, S, H, W, 3]
503
+ - world_points_conf: [B, S, H, W]
504
+ """
505
+ # Normalize input shape
506
+ if len(images.shape) == 4:
507
+ images = images.unsqueeze(0)
508
+ B, S, C, H, W = images.shape
509
+
510
+ # Determine number of scale frames
511
+ scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale
512
+ scale_frames = min(scale_frames, S) # Cap to available frames
513
+
514
+ # Helper to move tensor to output device
515
+ def _to_out(t: torch.Tensor) -> torch.Tensor:
516
+ if output_device is not None:
517
+ return t.to(output_device)
518
+ return t
519
+
520
+ # Clean KV caches before starting new sequence
521
+ self.clean_kv_cache()
522
+
523
+ # Phase 1: Process scale frames together
524
+ # These frames get bidirectional attention among themselves via scale token
525
+ logger.info(f'Processing {scale_frames} scale frames...')
526
+ scale_images = images[:, :scale_frames]
527
+ scale_output = self.forward(
528
+ scale_images,
529
+ num_frame_for_scale=scale_frames,
530
+ num_frame_per_block=scale_frames, # Process all scale frames as one block
531
+ causal_inference=True,
532
+ )
533
+
534
+ # Initialize output lists with scale frame predictions (offload if needed)
535
+ all_pose_enc = [_to_out(scale_output["pose_enc"])]
536
+ all_depth = [_to_out(scale_output["depth"])] if "depth" in scale_output else []
537
+ all_depth_conf = [_to_out(scale_output["depth_conf"])] if "depth_conf" in scale_output else []
538
+ all_world_points = [_to_out(scale_output["world_points"])] if "world_points" in scale_output else []
539
+ all_world_points_conf = [_to_out(scale_output["world_points_conf"])] if "world_points_conf" in scale_output else []
540
+ del scale_output
541
+
542
+ # Phase 2: Process remaining frames one-by-one
543
+ use_flow_keyframe = flow_threshold > 0.0
544
+
545
+ # Flow state: last keyframe = last scale frame
546
+ if use_flow_keyframe:
547
+ last_kf_pose_enc = all_pose_enc[0][:, -1:] # last scale frame
548
+ last_kf_idx = scale_frames - 1
549
+
550
+ pbar = tqdm(
551
+ range(scale_frames, S),
552
+ desc='Streaming inference',
553
+ initial=scale_frames,
554
+ total=S,
555
+ )
556
+ for i in pbar:
557
+ frame_image = images[:, i:i+1]
558
+
559
+ if use_flow_keyframe:
560
+ # Flow-based: defer eviction, forward, then decide
561
+ self._set_defer_eviction(True)
562
+
563
+ frame_output = self.forward(
564
+ frame_image,
565
+ num_frame_for_scale=scale_frames,
566
+ num_frame_per_block=1,
567
+ causal_inference=True,
568
+ )
569
+
570
+ self._set_defer_eviction(False)
571
+
572
+ # Compute flow to decide keyframe
573
+ cur_depth = frame_output.get("depth", None)
574
+ if cur_depth is not None:
575
+ H_pred, W_pred = cur_depth.shape[2], cur_depth.shape[3]
576
+ flow_mag = _compute_flow_magnitude(
577
+ frame_output["pose_enc"], last_kf_pose_enc,
578
+ cur_depth, (H_pred, W_pred),
579
+ )
580
+ else:
581
+ flow_mag = flow_threshold + 1.0
582
+
583
+ frames_since_kf = i - last_kf_idx
584
+ is_keyframe = (
585
+ (i == scale_frames) # first streaming frame
586
+ or (flow_mag > flow_threshold)
587
+ or (frames_since_kf >= max_non_keyframe_gap)
588
+ )
589
+
590
+ if is_keyframe:
591
+ self._execute_deferred_eviction()
592
+ last_kf_pose_enc = frame_output["pose_enc"]
593
+ last_kf_idx = i
594
+ else:
595
+ self._rollback_last_frame()
596
+ else:
597
+ # Fixed-interval keyframe mode
598
+ is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0)
599
+
600
+ if not is_keyframe:
601
+ self._set_skip_append(True)
602
+
603
+ frame_output = self.forward(
604
+ frame_image,
605
+ num_frame_for_scale=scale_frames,
606
+ num_frame_per_block=1,
607
+ causal_inference=True,
608
+ )
609
+
610
+ if not is_keyframe:
611
+ self._set_skip_append(False)
612
+
613
+ all_pose_enc.append(_to_out(frame_output["pose_enc"]))
614
+ if "depth" in frame_output:
615
+ all_depth.append(_to_out(frame_output["depth"]))
616
+ if "depth_conf" in frame_output:
617
+ all_depth_conf.append(_to_out(frame_output["depth_conf"]))
618
+ if "world_points" in frame_output:
619
+ all_world_points.append(_to_out(frame_output["world_points"]))
620
+ if "world_points_conf" in frame_output:
621
+ all_world_points_conf.append(_to_out(frame_output["world_points_conf"]))
622
+ del frame_output
623
+
624
+ # Free GPU memory before concatenation
625
+ if output_device is not None:
626
+ # Move images to output device, then free GPU copy
627
+ images_out = _to_out(images)
628
+ del images
629
+ # Clean KV cache (no longer needed after inference)
630
+ self.clean_kv_cache()
631
+ if torch.cuda.is_available():
632
+ torch.cuda.empty_cache()
633
+ else:
634
+ images_out = images
635
+
636
+ # Concatenate all predictions along sequence dimension
637
+ predictions = {
638
+ "pose_enc": torch.cat(all_pose_enc, dim=1),
639
+ }
640
+ del all_pose_enc
641
+ if all_depth:
642
+ predictions["depth"] = torch.cat(all_depth, dim=1)
643
+ del all_depth
644
+ if all_depth_conf:
645
+ predictions["depth_conf"] = torch.cat(all_depth_conf, dim=1)
646
+ del all_depth_conf
647
+ if all_world_points:
648
+ predictions["world_points"] = torch.cat(all_world_points, dim=1)
649
+ del all_world_points
650
+ if all_world_points_conf:
651
+ predictions["world_points_conf"] = torch.cat(all_world_points_conf, dim=1)
652
+ del all_world_points_conf
653
+
654
+ # Store images for visualization
655
+ predictions["images"] = images_out
656
+
657
+ # Apply prediction normalization if enabled
658
+ if self.pred_normalization:
659
+ predictions = self._normalize_predictions(predictions)
660
+
661
+ return predictions
662
+
663
+ # ══════════════════════════════════════════════════════════════════════
664
+ # Window stitching & cross-window alignment
665
+ # ══════════════════════════════════════════════════════════════════════
666
+
667
+ _FRAME_AXIS_KEYS = frozenset({
668
+ "pose_enc", "depth", "depth_conf",
669
+ "world_points", "world_points_conf",
670
+ "frame_type", "is_keyframe",
671
+ })
672
+
673
+ def _stitch_windows(
674
+ self,
675
+ windows: List[Dict],
676
+ window_size: int,
677
+ overlap: int,
678
+ ) -> Dict:
679
+ """Concatenate per-window predictions while de-duplicating overlaps.
680
+
681
+ For each temporal key the method builds a slice table first — every
682
+ window contributes ``[0, effective_end)`` frames where
683
+ ``effective_end = total_frames - overlap`` for non-final windows.
684
+ Non-temporal entries simply keep the latest available value.
685
+ """
686
+ if len(windows) == 0:
687
+ return {}
688
+ if len(windows) == 1:
689
+ return windows[0]
690
+
691
+ n_win = len(windows)
692
+ all_keys = list(windows[0].keys())
693
+ stitched: Dict = {}
694
+
695
+ for key in all_keys:
696
+ values = [w.get(key) for w in windows]
697
+ if all(v is None for v in values):
698
+ continue
699
+
700
+ # Non-temporal entries: take latest
701
+ if key not in self._FRAME_AXIS_KEYS:
702
+ stitched[key] = next(v for v in reversed(values) if v is not None)
703
+ continue
704
+
705
+ # Build slice table: (start, end) for each window's contribution
706
+ slices = []
707
+ for wi, tensor in enumerate(values):
708
+ if tensor is None:
709
+ slices.append(None)
710
+ continue
711
+ total = tensor.shape[1]
712
+ is_last = (wi == n_win - 1)
713
+ end = total if is_last else max(total - overlap, 0)
714
+ slices.append((0, end) if end > 0 else None)
715
+
716
+ parts = [
717
+ values[i][:, s:e]
718
+ for i, s_e in enumerate(slices)
719
+ if s_e is not None
720
+ for s, e in [s_e]
721
+ ]
722
+ if parts:
723
+ stitched[key] = torch.cat(parts, dim=1)
724
+ else:
725
+ fallback = next((v for v in reversed(values) if v is not None), None)
726
+ if fallback is not None:
727
+ stitched[key] = fallback
728
+
729
+ return stitched
730
+
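A toy sketch of the de-duplication rule implemented above, using made-up window lengths: with three windows of 16, 16 and 10 frames and an overlap of 4, each non-final window contributes its first total - overlap frames and the final window contributes everything, so 12 + 12 + 10 = 34 unique frames survive stitching.

import torch

overlap = 4
windows = [torch.zeros(1, 16, 9), torch.zeros(1, 16, 9), torch.zeros(1, 10, 9)]

parts = []
for wi, t in enumerate(windows):
    is_last = wi == len(windows) - 1
    end = t.shape[1] if is_last else max(t.shape[1] - overlap, 0)
    parts.append(t[:, :end])          # keep [0, effective_end) of each window

stitched = torch.cat(parts, dim=1)
print(stitched.shape)                 # torch.Size([1, 34, 9])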
731
+ @staticmethod
732
+ def _depth_ratio_scale(
733
+ anchor_depth: torch.Tensor,
734
+ target_depth: torch.Tensor,
735
+ batch_size: int,
736
+ device: torch.device,
737
+ ) -> torch.Tensor:
738
+ """Estimate per-batch scale as the median depth ratio anchor/target."""
739
+ a = anchor_depth.to(torch.float32).reshape(batch_size, -1)
740
+ t = target_depth.to(torch.float32).reshape(batch_size, -1)
741
+ ok = torch.isfinite(a) & torch.isfinite(t) & (t.abs() > torch.finfo(torch.float32).eps)
742
+
743
+ scales = []
744
+ for b in range(batch_size):
745
+ m = ok[b]
746
+ if m.any():
747
+ scales.append((a[b, m] / t[b, m]).median())
748
+ else:
749
+ scales.append(torch.tensor(1.0, device=device, dtype=torch.float32))
750
+ return torch.stack(scales).clamp(min=1e-3, max=1e3)
751
+
752
+ @staticmethod
753
+ def _pairwise_alignment(
754
+ prev_pred: Dict,
755
+ curr_pred: Dict,
756
+ overlap: int,
757
+ batch_size: int,
758
+ device: torch.device,
759
+ dtype: torch.dtype,
760
+ ):
761
+ """Compute (scale, R, t) that maps *curr* into *prev*'s coordinate frame.
762
+
763
+ Uses the first overlap frame of *curr* and the corresponding trailing
764
+ frame of *prev* to establish the similarity transform.
765
+ """
766
+ unit_s = torch.ones(batch_size, device=device, dtype=dtype)
767
+ eye_R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1).clone()
768
+ zero_t = torch.zeros(batch_size, 3, device=device, dtype=dtype)
769
+
770
+ if overlap <= 0:
771
+ return unit_s, eye_R, zero_t
772
+
773
+ pe_prev = prev_pred.get("pose_enc")
774
+ pe_curr = curr_pred.get("pose_enc")
775
+ if pe_prev is None or pe_curr is None:
776
+ return unit_s, eye_R, zero_t
777
+
778
+ idx_a = max(pe_prev.shape[1] - overlap, 0)
779
+
780
+ # Decompose C2W: center ([:3]) + quaternion ([3:7])
781
+ Ra = quat_to_mat(pe_prev[:, idx_a, 3:7]) # (B, 3, 3)
782
+ ca = pe_prev[:, idx_a, :3] # (B, 3)
783
+ Rb = quat_to_mat(pe_curr[:, 0, 3:7])
784
+ cb = pe_curr[:, 0, :3]
785
+
786
+ R_ab = torch.bmm(Ra, Rb.transpose(1, 2)) # Ra = R_ab @ Rb
787
+
788
+ # Scale from depth
789
+ s_ab = unit_s.clone()
790
+ da = prev_pred.get("depth")
791
+ db = curr_pred.get("depth")
792
+ if (da is not None and db is not None
793
+ and da.shape[1] > idx_a and db.shape[1] > 0):
794
+ s_ab = GCTStream._depth_ratio_scale(
795
+ da[:, idx_a, ..., 0], db[:, 0, ..., 0],
796
+ batch_size, device,
797
+ ).to(dtype)
798
+
799
+ # ca = s_ab * R_ab @ cb + t_ab => t_ab = ca - s_ab * R_ab @ cb
800
+ t_ab = ca - s_ab.unsqueeze(-1) * torch.bmm(R_ab, cb.unsqueeze(-1)).squeeze(-1)
801
+
802
+ return s_ab, R_ab.to(dtype), t_ab.to(dtype)
803
+
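In closed form, the pairwise alignment above estimates a similarity transform from a single overlap frame pair. With camera-to-world rotation and center (R_a, c_a) from the previous window's trailing overlap frame, (R_b, c_b) from the current window's first frame, and the per-pixel depths d_a, d_b of those two frames (symbols introduced here for illustration):

\begin{aligned}
R_{ab} &= R_a R_b^{\top}, \\
s_{ab} &= \operatorname{median}\!\left( d_a / d_b \right) \quad \text{clamped to } [10^{-3},\, 10^{3}], \\
t_{ab} &= c_a - s_{ab}\, R_{ab}\, c_b,
\end{aligned}

and _warp_predictions then maps every point and camera center of the current window via p' = s_{ab} R_{ab} p + t_{ab} while scaling depths by s_{ab}.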
804
+ @staticmethod
805
+ def _warp_predictions(
806
+ pred: Dict,
807
+ R: torch.Tensor,
808
+ t: torch.Tensor,
809
+ s: torch.Tensor,
810
+ batch_size: int,
811
+ ) -> Dict:
812
+ """Apply a similarity transform (s, R, t) to one window's predictions."""
813
+ warped: Dict = {}
814
+
815
+ # Pose encoding: center + quaternion + intrinsics
816
+ pe = pred.get("pose_enc")
817
+ if pe is not None:
818
+ nf = pe.shape[1]
819
+ local_rot = quat_to_mat(pe[:, :, 3:7])
820
+ local_ctr = pe[:, :, :3]
821
+
822
+ R_exp = R[:, None].expand(-1, nf, -1, -1)
823
+ new_rot = torch.matmul(R_exp, local_rot)
824
+ new_ctr = (
825
+ s.view(batch_size, 1, 1) * torch.matmul(R_exp, local_ctr.unsqueeze(-1)).squeeze(-1)
826
+ + t.view(batch_size, 1, 3)
827
+ )
828
+ out_pe = pe.clone()
829
+ out_pe[:, :, :3] = new_ctr
830
+ out_pe[:, :, 3:7] = mat_to_quat(new_rot)
831
+ warped["pose_enc"] = out_pe
832
+ else:
833
+ warped["pose_enc"] = None
834
+
835
+ # Depth: scale by s
836
+ d = pred.get("depth")
837
+ if d is not None:
838
+ warped["depth"] = d * s.view(batch_size, 1, 1, 1, 1)
839
+ else:
840
+ warped["depth"] = None
841
+
842
+ # World points: p_global = s * R @ p_local + t
843
+ wp = pred.get("world_points")
844
+ if wp is not None:
845
+ b, nf, h, w, _ = wp.shape
846
+ flat = wp.reshape(b, nf * h * w, 3)
847
+ transformed = torch.bmm(flat, R.transpose(1, 2)) * s.view(b, 1, 1)
848
+ transformed = transformed + t[:, None, :]
849
+ warped["world_points"] = transformed.reshape(b, nf, h, w, 3)
850
+ else:
851
+ warped["world_points"] = None
852
+
853
+ # Pass through all other keys untouched
854
+ for k, v in pred.items():
855
+ if k not in warped:
856
+ warped[k] = v
857
+
858
+ return warped
859
+
860
+ def _align_and_stitch_windows(
861
+ self,
862
+ windows: List[Dict],
863
+ scale_mode: str = 'median',
864
+ ) -> Dict:
865
+ """Bring all windows into the first window's coordinate frame, then stitch.
866
+
867
+ Iterates over consecutive window pairs, estimates the pairwise
868
+ scaled alignment, warps each window, and finally concatenates
869
+ via :meth:`_stitch_windows`.
870
+ """
871
+ if len(windows) == 0:
872
+ return {}
873
+ if len(windows) == 1:
874
+ out = windows[0].copy()
875
+ out["alignment_mode"] = "scaled"
876
+ return out
877
+
878
+ # Discover batch / device / dtype from any available tensor
879
+ ref = next(
880
+ v
881
+ for w in windows
882
+ for k in ("pose_enc", "world_points", "depth")
883
+ if (v := w.get(k)) is not None
884
+ )
885
+ dev, dt, nb = ref.device, ref.dtype, ref.shape[0]
886
+
887
+ overlap = getattr(self, "_last_overlap_size", 0)
888
+ win_sz = getattr(self, "_last_window_size", -1)
889
+
890
+ warped_windows: List[Dict] = []
891
+ per_window_scales: List[torch.Tensor] = []
892
+ per_window_transforms: List[torch.Tensor] = []
893
+
894
+ for idx, raw in enumerate(windows):
895
+ if idx == 0:
896
+ s_rel = torch.ones(nb, device=dev, dtype=dt)
897
+ R_rel = torch.eye(3, device=dev, dtype=dt).unsqueeze(0).expand(nb, -1, -1).clone()
898
+ t_rel = torch.zeros(nb, 3, device=dev, dtype=dt)
899
+ else:
900
+ s_rel, R_rel, t_rel = self._pairwise_alignment(
901
+ warped_windows[-1], raw, overlap, nb, dev, dt,
902
+ )
903
+
904
+ per_window_scales.append(s_rel.clone())
905
+ T = torch.eye(4, device=dev, dtype=dt).unsqueeze(0).expand(nb, -1, -1).clone()
906
+ T[:, :3, :3] = R_rel
907
+ T[:, :3, 3] = t_rel
908
+ per_window_transforms.append(T)
909
+
910
+ warped_windows.append(
911
+ self._warp_predictions(raw, R_rel, t_rel, s_rel, nb)
912
+ )
913
+
914
+ merged = self._stitch_windows(warped_windows, win_sz, overlap)
915
+
916
+ # Attach alignment metadata
917
+ if per_window_scales:
918
+ merged["chunk_scales"] = torch.stack(per_window_scales, dim=1)
919
+ if per_window_transforms:
920
+ merged["chunk_transforms"] = torch.stack(per_window_transforms, dim=1)
921
+ merged["alignment_mode"] = "scaled"
922
+ return merged
923
+
924
+ @torch.no_grad()
925
+ def inference_windowed(
926
+ self,
927
+ images: torch.Tensor,
928
+ window_size: int = 16,
929
+ overlap_size: Optional[int] = None,
930
+ num_scale_frames: Optional[int] = None,
931
+ scale_mode: str = 'median',
932
+ output_device: Optional[torch.device] = None,
933
+ keyframe_interval: int = 1,
934
+ flow_threshold: float = 0.0,
935
+ max_non_keyframe_gap: int = 30,
936
+ ) -> Dict[str, torch.Tensor]:
937
+ """
938
+ Windowed inference with keyframe detection and cross-window alignment.
939
+
940
+ Each window is processed independently with a fresh KV cache.
941
+ Overlap frames between windows are the next window's scale frames
942
+ (bidirectional attention), ensuring the highest quality predictions
943
+ at alignment boundaries.
944
+
945
+ ``window_size`` counts **keyframes** (frames stored in KV cache),
946
+ including scale frames. When ``keyframe_interval > 1``, each window
947
+ covers more actual frames than ``window_size``:
948
+
949
+ actual_frames = scale_frames + (window_size - scale_frames) * keyframe_interval
950
+
951
+ Args:
952
+ images: Input images [S, 3, H, W] or [B, S, 3, H, W] in [0, 1].
953
+ window_size: Number of **keyframes** per window (including scale
954
+ frames). Directly controls KV cache memory.
955
+ overlap_size: Number of overlapping frames between windows.
956
+ Defaults to ``num_scale_frames`` (overlap = scale frames).
957
+ num_scale_frames: Number of frames used as scale reference within
958
+ each window. Defaults to ``self.num_frame_for_scale``.
959
+ scale_mode: Scale estimation strategy for alignment.
960
+ output_device: Device to store per-window outputs.
961
+ keyframe_interval: Every N-th Phase 2 frame is a keyframe whose
962
+ KV persists in cache. 1 = every frame (default).
963
+ flow_threshold: Mean flow magnitude threshold (pixels) for
964
+ flow-based keyframe selection. >0 enables flow-based mode
965
+ (takes precedence over ``keyframe_interval``).
966
+ max_non_keyframe_gap: Max consecutive non-keyframe frames before
967
+ forcing a keyframe (flow mode only).
968
+
969
+ Returns:
970
+ Merged prediction dict with all frames.
971
+ """
972
+ use_flow_keyframe = flow_threshold > 0.0
973
+
974
+ # Normalize input shape
975
+ if len(images.shape) == 4:
976
+ images = images.unsqueeze(0)
977
+ B, S, C, H, W = images.shape
978
+
979
+ ws = (num_scale_frames if num_scale_frames is not None
980
+ else self.num_frame_for_scale)
981
+ ws = min(ws, S)
982
+
983
+ # overlap = scale_frames by default
984
+ eff_overlap = min(overlap_size if overlap_size is not None else ws,
985
+ S - 1) if S > 1 else 0
986
+
987
+ def _to_out(t: torch.Tensor) -> torch.Tensor:
988
+ return t.to(output_device) if output_device is not None else t
989
+
990
+ def _collect_frame(out, w_lists):
991
+ w_lists['pose_enc'].append(_to_out(out["pose_enc"]))
992
+ if "depth" in out:
993
+ w_lists['depth'].append(_to_out(out["depth"]))
994
+ if "depth_conf" in out:
995
+ w_lists['depth_conf'].append(_to_out(out["depth_conf"]))
996
+ if "world_points" in out:
997
+ w_lists['world_points'].append(_to_out(out["world_points"]))
998
+ if "world_points_conf" in out:
999
+ w_lists['world_pts_conf'].append(_to_out(out["world_points_conf"]))
1000
+
1001
+ def _make_window_pred(w_lists):
1002
+ pred: Dict = {"pose_enc": torch.cat(w_lists['pose_enc'], dim=1)}
1003
+ if w_lists['depth']:
1004
+ pred["depth"] = torch.cat(w_lists['depth'], dim=1)
1005
+ if w_lists['depth_conf']:
1006
+ pred["depth_conf"] = torch.cat(w_lists['depth_conf'], dim=1)
1007
+ if w_lists['world_points']:
1008
+ pred["world_points"] = torch.cat(w_lists['world_points'], dim=1)
1009
+ if w_lists['world_pts_conf']:
1010
+ pred["world_points_conf"] = torch.cat(w_lists['world_pts_conf'], dim=1)
1011
+ # Frame type: 0=scale, 1=keyframe, 2=non-keyframe
1012
+ ft = torch.tensor(w_lists['frame_type'], dtype=torch.uint8).unsqueeze(0) # [1, T]
1013
+ pred["frame_type"] = ft
1014
+ pred["is_keyframe"] = (ft != 2) # scale + keyframe = True
1015
+ return pred
1016
+
1017
+ def _new_lists():
1018
+ return {
1019
+ 'pose_enc': [], 'depth': [], 'depth_conf': [],
1020
+ 'world_points': [], 'world_pts_conf': [],
1021
+ 'frame_type': [], # list of ints: 0=scale, 1=keyframe, 2=non-keyframe
1022
+ }
1023
+
1024
+ # ================================================================
1025
+ # Flow-based mode: dynamic windows (can't precompute window list)
1026
+ # ================================================================
1027
+ if use_flow_keyframe:
1028
+ all_window_predictions: List[Dict] = []
1029
+ cursor = 0
1030
+ window_idx = 0
1031
+ pbar = tqdm(total=S, desc='Windowed inference (flow)', initial=0)
1032
+
1033
+ while cursor < S:
1034
+ window_start = cursor
1035
+ window_scale = min(ws, S - cursor)
1036
+
1037
+ # Fresh KV cache
1038
+ self.clean_kv_cache()
1039
+
1040
+ # ---------- Phase 1: scale frames ----------
1041
+ scale_images = images[:, cursor:cursor + window_scale]
1042
+ scale_out = self.forward(
1043
+ scale_images,
1044
+ num_frame_for_scale=window_scale,
1045
+ num_frame_per_block=window_scale,
1046
+ causal_inference=True,
1047
+ )
1048
+ w_lists = _new_lists()
1049
+ _collect_frame(scale_out, w_lists)
1050
+ w_lists['frame_type'].extend([0] * window_scale) # scale frames
1051
+
1052
+ # Flow state: last keyframe = last scale frame
1053
+ last_kf_pose_enc = scale_out["pose_enc"][:, -1:]
1054
+ last_kf_local_idx = window_scale - 1
1055
+ del scale_out
1056
+
1057
+ cursor += window_scale
1058
+ pbar.update(window_scale)
1059
+
1060
+ # ---------- Phase 2: stream until enough keyframes ----------
1061
+ target_kf = window_size - window_scale # keyframes to collect
1062
+ kf_count = 0
1063
+
1064
+ while cursor < S and kf_count < target_kf:
1065
+ frame_image = images[:, cursor:cursor + 1]
1066
+
1067
+ self._set_defer_eviction(True)
1068
+ frame_out = self.forward(
1069
+ frame_image,
1070
+ num_frame_for_scale=window_scale,
1071
+ num_frame_per_block=1,
1072
+ causal_inference=True,
1073
+ )
1074
+ self._set_defer_eviction(False)
1075
+
1076
+ # Compute flow
1077
+ cur_depth = frame_out.get("depth", None)
1078
+ if cur_depth is not None:
1079
+ H_pred, W_pred = cur_depth.shape[2], cur_depth.shape[3]
1080
+ flow_mag = _compute_flow_magnitude(
1081
+ frame_out["pose_enc"], last_kf_pose_enc,
1082
+ cur_depth, (H_pred, W_pred),
1083
+ )
1084
+ else:
1085
+ flow_mag = flow_threshold + 1.0
1086
+
1087
+ local_idx = window_scale + (cursor - window_start - window_scale)
1088
+ frames_since_kf = local_idx - last_kf_local_idx
1089
+ is_keyframe = (
1090
+ (kf_count == 0) # first streaming frame
1091
+ or (flow_mag > flow_threshold)
1092
+ or (frames_since_kf >= max_non_keyframe_gap)
1093
+ )
1094
+
1095
+ if is_keyframe:
1096
+ self._execute_deferred_eviction()
1097
+ last_kf_pose_enc = frame_out["pose_enc"]
1098
+ last_kf_local_idx = local_idx
1099
+ kf_count += 1
1100
+ w_lists['frame_type'].append(1) # keyframe
1101
+ else:
1102
+ self._rollback_last_frame()
1103
+ w_lists['frame_type'].append(2) # non-keyframe
1104
+
1105
+ _collect_frame(frame_out, w_lists)
1106
+ del frame_out
1107
+ cursor += 1
1108
+ pbar.update(1)
1109
+
1110
+ all_window_predictions.append(_make_window_pred(w_lists))
1111
+ window_idx += 1
1112
+
1113
+ # Next window starts overlap_size frames back (= scale frames)
1114
+ if cursor < S:
1115
+ cursor = max(cursor - eff_overlap, window_start + window_scale)
1116
+
1117
+ pbar.close()
1118
+
1119
+ # ================================================================
1120
+ # Fixed-interval / default mode: precomputable windows
1121
+ # ================================================================
1122
+ else:
1123
+ # Compute actual frames per window
1124
+ phase2_kf = max(window_size - ws, 0)
1125
+ kf_int = max(keyframe_interval, 1)
1126
+ phase2_frames = phase2_kf * kf_int
1127
+ actual_window_frames = ws + phase2_frames
1128
+
1129
+ eff_window = min(actual_window_frames, S)
1130
+ step = max(eff_window - eff_overlap, 1)
1131
+
1132
+ # Build window list
1133
+ if eff_window >= S:
1134
+ windows = [(0, S)]
1135
+ else:
1136
+ windows = []
1137
+ for start_idx in range(0, S, step):
1138
+ end_idx = min(start_idx + eff_window, S)
1139
+ if end_idx - start_idx >= eff_overlap or end_idx == S:
1140
+ windows.append((start_idx, end_idx))
1141
+ if end_idx == S:
1142
+ break
1143
+
1144
+ all_window_predictions: List[Dict] = []
1145
+ for start, end in tqdm(windows, desc='Windowed inference'):
1146
+ window_images = images[:, start:end]
1147
+ window_len = end - start
1148
+
1149
+ # Fresh KV cache
1150
+ self.clean_kv_cache()
1151
+
1152
+ window_scale = min(ws, window_len)
1153
+
1154
+ # ---------- Phase 1: scale frames ----------
1155
+ scale_out = self.forward(
1156
+ window_images[:, :window_scale],
1157
+ num_frame_for_scale=window_scale,
1158
+ num_frame_per_block=window_scale,
1159
+ causal_inference=True,
1160
+ )
1161
+ w_lists = _new_lists()
1162
+ _collect_frame(scale_out, w_lists)
1163
+ w_lists['frame_type'].extend([0] * window_scale) # scale frames
1164
+ del scale_out
1165
+
1166
+ # ---------- Phase 2: stream remaining frames ----------
1167
+ for i in range(window_scale, window_len):
1168
+ is_keyframe = (
1169
+ kf_int <= 1
1170
+ or ((i - window_scale) % kf_int == 0)
1171
+ )
1172
+
1173
+ if not is_keyframe:
1174
+ self._set_skip_append(True)
1175
+
1176
+ frame_out = self.forward(
1177
+ window_images[:, i:i + 1],
1178
+ num_frame_for_scale=window_scale,
1179
+ num_frame_per_block=1,
1180
+ causal_inference=True,
1181
+ )
1182
+
1183
+ if not is_keyframe:
1184
+ self._set_skip_append(False)
1185
+
1186
+ _collect_frame(frame_out, w_lists)
1187
+ w_lists['frame_type'].append(1 if is_keyframe else 2)
1188
+ del frame_out
1189
+
1190
+ all_window_predictions.append(_make_window_pred(w_lists))
1191
+
1192
+ # Store for merge helpers
1193
+ self._last_window_size = eff_overlap # not used directly, but kept for compat
1194
+ self._last_overlap_size = eff_overlap
1195
+
1196
+ # Align and stitch windows
1197
+ predictions = self._align_and_stitch_windows(
1198
+ all_window_predictions, scale_mode=scale_mode
1199
+ )
1200
+
1201
+ predictions["images"] = _to_out(images)
1202
+
1203
+ if self.pred_normalization:
1204
+ predictions = self._normalize_predictions(predictions)
1205
+
1206
+ return predictions
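A minimal sketch of windowed inference with flow-based keyframe selection, for reference. The import path, constructor defaults, and tensor sizes are assumptions for illustration; a trained checkpoint would normally be loaded first.

import torch
from lingbot_map.models.gct_stream_window import GCTStream  # assumed import path

model = GCTStream().eval().to("cuda")
video = torch.rand(200, 3, 518, 518, device="cuda")   # long sequence in [0, 1]

with torch.no_grad():
    preds = model.inference_windowed(
        video,
        window_size=16,             # keyframes kept in the KV cache per window
        num_scale_frames=4,         # also the default overlap between windows
        flow_threshold=8.0,         # pixels; >0 enables flow-based keyframe selection
        max_non_keyframe_gap=30,    # force a keyframe after 30 skipped frames
        output_device=torch.device("cpu"),
    )

print(preds["pose_enc"].shape)     # expected [1, 200, 9], all windows in the first window's frame
print(preds["is_keyframe"][0])     # boolean mask: scale frames and keyframes are True
print(preds["alignment_mode"])     # 'scaled'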
lingbot_map/utils/__init__.py ADDED
File without changes
lingbot_map/utils/geometry.py ADDED
@@ -0,0 +1,774 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+ from scipy.spatial.transform import Rotation as R
11
+
12
+ from scipy.spatial.transform import Rotation
13
+ try:
14
+ from lietorch import SE3, Sim3
15
+ except ImportError:
16
+ SE3 = Sim3 = None
17
+ import torch.nn.functional as F
18
+
19
+ try:
20
+ from lingbot_map.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion
21
+ except ImportError:
22
+ apply_distortion = iterative_undistortion = single_undistortion = None
23
+
24
+
25
+ def unproject_depth_map_to_point_map(
26
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
27
+ ) -> np.ndarray:
28
+ """
29
+ Unproject a batch of depth maps to 3D world coordinates.
30
+
31
+ Args:
32
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
33
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
34
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
35
+
36
+ Returns:
37
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
38
+ """
39
+ if isinstance(depth_map, torch.Tensor):
40
+ depth_map = depth_map.cpu().numpy()
41
+ if isinstance(extrinsics_cam, torch.Tensor):
42
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
43
+ if isinstance(intrinsics_cam, torch.Tensor):
44
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
45
+
46
+ world_points_list = []
47
+ for frame_idx in range(depth_map.shape[0]):
48
+ cur_world_points, _, _ = depth_to_world_coords_points(
49
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
50
+ )
51
+ world_points_list.append(cur_world_points)
52
+ world_points_array = np.stack(world_points_list, axis=0)
53
+
54
+ return world_points_array
55
+
56
+
57
+ def depth_to_world_coords_points(
58
+ depth_map: np.ndarray,
59
+ extrinsic: np.ndarray,
60
+ intrinsic: np.ndarray,
61
+ eps=1e-8,
62
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
63
+ """
64
+ Convert a depth map to world coordinates.
65
+
66
+ Args:
67
+ depth_map (np.ndarray): Depth map of shape (H, W).
68
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
69
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
70
+
71
+ Returns:
72
+ tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
73
+ """
74
+ if depth_map is None:
75
+ return None, None, None
76
+
77
+ # Valid depth mask
78
+ point_mask = depth_map > eps
79
+
80
+ # Convert depth map to camera coordinates
81
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
82
+
83
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
84
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
85
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
86
+
87
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
88
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
89
+
90
+ # Apply the rotation and translation to the camera coordinates
91
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
92
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
93
+
94
+ return world_coords_points, cam_coords_points, point_mask
95
+
96
+
97
+ def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
98
+ """
99
+ Convert a depth map to camera coordinates.
100
+
101
+ Args:
102
+ depth_map (np.ndarray): Depth map of shape (H, W).
103
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
104
+
105
+ Returns:
106
+ np.ndarray: Camera coordinates (H, W, 3)
107
+ """
108
+ H, W = depth_map.shape
109
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
110
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
111
+
112
+ # Intrinsic parameters
113
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
114
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
115
+
116
+ # Generate grid of pixel coordinates
117
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
118
+
119
+ # Unproject to camera coordinates
120
+ x_cam = (u - cu) * depth_map / fu
121
+ y_cam = (v - cv) * depth_map / fv
122
+ z_cam = depth_map
123
+
124
+ # Stack to form camera coordinates
125
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
126
+
127
+ return cam_coords
128
+
129
+
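A minimal usage sketch for the two depth helpers above (the depth map, intrinsic, and extrinsic values are made-up assumptions, not repository data):

import numpy as np
from lingbot_map.utils.geometry import depth_to_world_coords_points

depth = np.full((4, 4), 2.0, dtype=np.float32)           # (H, W) depth map
extrinsic = np.hstack([np.eye(3), np.zeros((3, 1))])     # (3, 4) cam-from-world [R|t]
intrinsic = np.array([[2.0, 0.0, 2.0],
                      [0.0, 2.0, 2.0],
                      [0.0, 0.0, 1.0]], dtype=np.float32)

world_pts, cam_pts, mask = depth_to_world_coords_points(depth, extrinsic, intrinsic)
print(world_pts.shape, cam_pts.shape, mask.shape)         # (4, 4, 3) (4, 4, 3) (4, 4)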
130
+ def closed_form_inverse_se3(se3, R=None, T=None):
131
+ """
132
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
133
+
134
+ If `R` and `T` are provided, they must correspond to the rotation and translation
135
+ components of `se3`. Otherwise, they will be extracted from `se3`.
136
+
137
+ Args:
138
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
139
+ R (optional): Nx3x3 array or tensor of rotation matrices.
140
+ T (optional): Nx3x1 array or tensor of translation vectors.
141
+
142
+ Returns:
143
+ Inverted SE3 matrices with the same type and device as `se3`.
144
+
145
+ Shapes:
146
+ se3: (N, 4, 4)
147
+ R: (N, 3, 3)
148
+ T: (N, 3, 1)
149
+ """
150
+ # Check if se3 is a numpy array or a torch tensor
151
+ is_numpy = isinstance(se3, np.ndarray)
152
+
153
+ # Validate shapes
154
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
155
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
156
+
157
+ # Extract R and T if not provided
158
+ if R is None:
159
+ R = se3[:, :3, :3] # (N,3,3)
160
+ if T is None:
161
+ T = se3[:, :3, 3:] # (N,3,1)
162
+
163
+ # Transpose R
164
+ if is_numpy:
165
+ # Compute the transpose of the rotation for NumPy
166
+ R_transposed = np.transpose(R, (0, 2, 1))
167
+ # -R^T t for NumPy
168
+ top_right = -np.matmul(R_transposed, T)
169
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
170
+ else:
171
+ R_transposed = R.transpose(1, 2) # (N,3,3)
172
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
173
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
174
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
175
+
176
+ inverted_matrix[:, :3, :3] = R_transposed
177
+ inverted_matrix[:, :3, 3:] = top_right
178
+
179
+ return inverted_matrix
180
+
181
+ def closed_form_inverse_se3_general(se3, R=None, T=None):
182
+ """
183
+ Compute the inverse of SE3 matrices with arbitrary batch dimensions.
184
+ se3: (..., 4, 4) or (..., 3, 4)
185
+ """
186
+ batch_shape = se3.shape[:-2]
187
+ if R is None:
188
+ R = se3[..., :3, :3]
189
+ if T is None:
190
+ T = se3[..., :3, 3:]
191
+ R_transposed = R.transpose(-2, -1)
192
+ top_right = -R_transposed @ T
193
+ # Build identity matrices to hold the result
194
+ eye = torch.eye(4, 4, dtype=R.dtype, device=R.device)
195
+ inverted_matrix = eye.expand(*batch_shape, 4, 4).clone()
196
+ inverted_matrix[..., :3, :3] = R_transposed
197
+ inverted_matrix[..., :3, 3:] = top_right
198
+ return inverted_matrix
199
+
200
+
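A quick sanity sketch for the batched inverse above (synthetic poses with identity rotations; illustrative only):

import torch
from lingbot_map.utils.geometry import closed_form_inverse_se3_general

poses = torch.eye(4).repeat(2, 5, 1, 1)            # batch dims (2, 5), 4x4 SE3 matrices
poses[..., :3, 3] = torch.randn(2, 5, 3)           # random translations
inv = closed_form_inverse_se3_general(poses)
print(torch.allclose(poses @ inv, torch.eye(4).expand_as(poses), atol=1e-6))  # True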
201
+ # TODO: this code can be further cleaned up
202
+
203
+
204
+ def project_world_points_to_camera_points_batch(world_points, cam_extrinsics):
205
+ """
206
+ Transforms 3D points to 2D using extrinsic and intrinsic parameters.
207
+ Args:
208
+ world_points (torch.Tensor): 3D points of shape BxSxHxWx3.
209
+ cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4.
210
+ Returns:
211
+ """
212
+ # TODO: merge this into project_world_points_to_cam
213
+
214
+ # device = world_points.device
215
+ # with torch.autocast(device_type=device.type, enabled=False):
216
+ ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1)
217
+ world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4)
218
+
219
+ # extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4)
220
+ extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3)
221
+
222
+ # world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1)
223
+ world_points_h_exp = world_points_h.unsqueeze(-1)
224
+
225
+ # Now perform the matrix multiplication
226
+ # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1)
227
+ camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1)
228
+
229
+ return camera_points
230
+
231
+
232
+
233
+ def project_world_points_to_cam(
234
+ world_points,
235
+ cam_extrinsics,
236
+ cam_intrinsics=None,
237
+ distortion_params=None,
238
+ default=0,
239
+ only_points_cam=False,
240
+ ):
241
+ """
242
+ Transforms 3D points to 2D using extrinsic and intrinsic parameters.
243
+ Args:
244
+ world_points (torch.Tensor): 3D points of shape Px3.
245
+ cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
246
+ cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
247
+ distortion_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion.
248
+ Returns:
249
+ tuple: 2D image points of shape BxNx2 (or None if only_points_cam=True) and camera-frame points of shape Bx3xN.
250
+ """
251
+ device = world_points.device
252
+ # with torch.autocast(device_type=device.type, dtype=torch.double):
253
+ with torch.autocast(device_type=device.type, enabled=False):
254
+ N = world_points.shape[0] # Number of points
255
+ B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras
256
+ world_points_homogeneous = torch.cat(
257
+ [world_points, torch.ones_like(world_points[..., 0:1])], dim=1
258
+ ) # Nx4
259
+ # Reshape for batch processing
260
+ world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand(
261
+ B, -1, -1
262
+ ) # BxNx4
263
+
264
+ # Step 1: Apply extrinsic parameters
265
+ # Transform 3D points to camera coordinate system for all cameras
266
+ cam_points = torch.bmm(
267
+ cam_extrinsics, world_points_homogeneous.transpose(-1, -2)
268
+ )
269
+
270
+ if only_points_cam:
271
+ return None, cam_points
272
+
273
+ # Step 2: Apply intrinsic parameters and (optional) distortion
274
+ image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default)
275
+
276
+ return image_points, cam_points
277
+
278
+
279
+
280
+ def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0):
281
+ """
282
+ Applies intrinsic parameters and optional distortion to the given 3D points.
283
+
284
+ Args:
285
+ cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
286
+ cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
287
+ distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
288
+ default (float, optional): Default value to replace NaNs in the output.
289
+
290
+ Returns:
291
+ pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
292
+ """
293
+
294
+ # Normalized device coordinates (NDC)
295
+ cam_points = cam_points / cam_points[:, 2:3, :]
296
+ ndc_xy = cam_points[:, :2, :]
297
+
298
+ # Apply distortion if distortion_params are provided
299
+ if distortion_params is not None:
300
+ x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1])
301
+ distorted_xy = torch.stack([x_distorted, y_distorted], dim=1)
302
+ else:
303
+ distorted_xy = ndc_xy
304
+
305
+ # Prepare cam_points for batch matrix multiplication
306
+ cam_coords_homo = torch.cat(
307
+ (distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1
308
+ ) # Bx3xN
309
+ # Apply intrinsic parameters using batch matrix multiplication
310
+ pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN
311
+
312
+ # Extract x and y coordinates
313
+ pixel_coords = pixel_coords[:, :2, :] # Bx2xN
314
+
315
+ # Replace NaNs with default value
316
+ pixel_coords = torch.nan_to_num(pixel_coords, nan=default)
317
+
318
+ return pixel_coords.transpose(1, 2) # BxNx2
319
+
320
+
321
+
322
+
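For orientation, a small usage sketch of the projection helpers above (the camera count, intrinsics, and point cloud are illustrative assumptions):

import torch
from lingbot_map.utils.geometry import project_world_points_to_cam

world_points = torch.rand(100, 3)                        # P x 3 points in the world frame
world_points[:, 2] += 1.0                                # keep points in front of the cameras
extrinsics = torch.eye(4)[:3].unsqueeze(0).repeat(2, 1, 1)   # two cameras, 3x4 [R|t]
intrinsics = torch.tensor([[[50.0, 0.0, 32.0],
                            [0.0, 50.0, 24.0],
                            [0.0, 0.0, 1.0]]]).repeat(2, 1, 1)

image_points, cam_points = project_world_points_to_cam(world_points, extrinsics, intrinsics)
print(image_points.shape, cam_points.shape)              # (2, 100, 2) and (2, 3, 100)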
323
+ def cam_from_img(pred_tracks, intrinsics, extra_params=None):
324
+ """
325
+ Normalize predicted tracks based on camera intrinsics.
326
+ Args:
327
+ intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3].
328
+ pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2].
329
+ extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
330
+ Returns:
331
+ torch.Tensor: Normalized tracks tensor.
332
+ """
333
+
334
+ # We don't want to do intrinsics_inv = torch.inverse(intrinsics) here
335
+ # otherwise we can use something like
336
+ # tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2))
337
+
338
+ principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2)
339
+ focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2)
340
+ tracks_normalized = (pred_tracks - principal_point) / focal_length
341
+
342
+ if extra_params is not None:
343
+ # Apply iterative undistortion
344
+ try:
345
+ tracks_normalized = iterative_undistortion(
346
+ extra_params, tracks_normalized
347
+ )
348
+ except Exception:
349
+ tracks_normalized = single_undistortion(
350
+ extra_params, tracks_normalized
351
+ )
352
+
353
+ return tracks_normalized
354
+
355
+ ## Droid SLAM Part
356
+
357
+ MIN_DEPTH = 0.2
358
+
359
+ def extract_intrinsics(intrinsics):
360
+ return intrinsics[...,None,None,:].unbind(dim=-1)
361
+
362
+ def projective_transform(
363
+ poses, depths, intrinsics, ii, jj, jacobian=False, return_depth=False
364
+ ):
365
+ """map points from ii->jj"""
366
+
367
+ # inverse project (pinhole)
368
+ X0, Jz = iproj(depths[:, ii], intrinsics[:, ii], jacobian=jacobian)
369
+
370
+ # transform
371
+ Gij = poses[:, jj] * poses[:, ii].inv()
372
+
373
+ # Gij.data[:, ii == jj] = torch.as_tensor(
374
+ # [-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device="cuda"
375
+ # )
376
+ X1, Ja = actp(Gij, X0, jacobian=jacobian)
377
+
378
+ # project (pinhole)
379
+ x1, Jp = proj(X1, intrinsics[:, jj], jacobian=jacobian, return_depth=return_depth)
380
+
381
+ # exclude points too close to camera
382
+ valid = ((X1[..., 2] > MIN_DEPTH) & (X0[..., 2] > MIN_DEPTH)).float()
383
+ valid = valid.unsqueeze(-1)
384
+
385
+ if jacobian:
386
+ # Ji transforms according to dual adjoint
387
+ Jj = torch.matmul(Jp, Ja)
388
+ Ji = -Gij[:, :, None, None, None].adjT(Jj)
389
+
390
+ Jz = Gij[:, :, None, None] * Jz
391
+ Jz = torch.matmul(Jp, Jz.unsqueeze(-1))
392
+
393
+ return x1, valid, (Ji, Jj, Jz)
394
+
395
+ return x1, valid
396
+
397
+
398
+ def induced_flow(poses, disps, intrinsics, ii, jj):
399
+ """optical flow induced by camera motion"""
400
+
401
+ ht, wd = disps.shape[2:]
402
+ y, x = torch.meshgrid(
403
+ torch.arange(ht, device=disps.device, dtype=torch.float),
404
+ torch.arange(wd, device=disps.device, dtype=torch.float),
405
+ indexing="ij",
406
+ )
407
+
408
+ coords0 = torch.stack([x, y], dim=-1)
409
+ coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj, False)
410
+
411
+ return coords1[..., :2] - coords0, valid
412
+
413
+ def all_pairs_distance_matrix(poses, beta=2.5):
414
+ """ compute distance matrix between all pairs of poses """
415
+ poses = np.array(poses, dtype=np.float32)
416
+ poses[:,:3] *= beta # scale to balance rot + trans
417
+ poses = SE3(torch.from_numpy(poses))
418
+
419
+ r = (poses[:,None].inv() * poses[None,:]).log()
420
+ return r.norm(dim=-1).cpu().numpy()
421
+
422
+ def pose_matrix_to_quaternion(pose):
423
+ """ convert 4x4 pose matrix to (t, q) """
424
+ q = Rotation.from_matrix(pose[..., :3, :3]).as_quat()
425
+ return np.concatenate([pose[..., :3, 3], q], axis=-1)
426
+
427
+ def compute_distance_matrix_flow(poses, disps, intrinsics):
428
+ """ compute flow magnitude between all pairs of frames """
429
+ if not isinstance(poses, SE3):
430
+ poses = torch.from_numpy(poses).float().cuda()[None]
431
+ poses = SE3(poses).inv()
432
+
433
+ disps = torch.from_numpy(disps).float().cuda()[None]
434
+ intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
435
+
436
+ N = poses.shape[1]
437
+
438
+ ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N), indexing="ij")
439
+ ii = ii.reshape(-1).cuda()
440
+ jj = jj.reshape(-1).cuda()
441
+
442
+ MAX_FLOW = 100.0
443
+ matrix = np.zeros((N, N), dtype=np.float32)
444
+
445
+ s = 2048
446
+ for i in range(0, ii.shape[0], s):
447
+ flow1, val1 = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
448
+ flow2, val2 = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s])
449
+
450
+ flow = torch.stack([flow1, flow2], dim=2)
451
+ val = torch.stack([val1, val2], dim=2)
452
+
453
+ mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
454
+ mag = mag.view(mag.shape[1], -1)
455
+ val = val.view(val.shape[1], -1)
456
+
457
+ mag = (mag * val).mean(-1) / val.mean(-1)
458
+ mag[val.mean(-1) < 0.7] = np.inf
459
+
460
+ i1 = ii[i:i+s].cpu().numpy()
461
+ j1 = jj[i:i+s].cpu().numpy()
462
+ matrix[i1, j1] = mag.cpu().numpy()
463
+
464
+ return matrix
465
+
466
+
467
+ def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4):
468
+ """ compute flow magnitude between all pairs of frames """
469
+ # if not isinstance(poses, SE3):
470
+ # poses = torch.from_numpy(poses).float().cuda()[None]
471
+ # poses = SE3(poses).inv()
472
+
473
+ # disps = torch.from_numpy(disps).float().cuda()[None]
474
+ # intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
475
+
476
+ N = poses.shape[1]
477
+
478
+ ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N), indexing="ij")
479
+ ii = ii.reshape(-1)
480
+ jj = jj.reshape(-1)
481
+
482
+ MAX_FLOW = 128.0
483
+ matrix = np.zeros((N, N), dtype=np.float32)
484
+
485
+ s = 2048
486
+ for i in range(0, ii.shape[0], s):
487
+ flow1a, val1a = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True)
488
+ flow1b, val1b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
489
+ flow2a, val2a = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True)
490
+ flow2b, val2b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
491
+
492
+ flow1 = flow1a + beta * flow1b
493
+ val1 = val1a * val2b
494
+
495
+ flow2 = flow2a + beta * flow2b
496
+ val2 = val2a * val2b
497
+
498
+ flow = torch.stack([flow1, flow2], dim=2)
499
+ val = torch.stack([val1, val2], dim=2)
500
+
501
+ mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
502
+ mag = mag.view(mag.shape[1], -1)
503
+ val = val.view(val.shape[1], -1)
504
+
505
+ mag = (mag * val).mean(-1) / val.mean(-1)
506
+ mag[val.mean(-1) < 0.8] = np.inf
507
+
508
+ i1 = ii[i:i+s].cpu().numpy()
509
+ j1 = jj[i:i+s].cpu().numpy()
510
+ matrix[i1, j1] = mag.cpu().numpy()
511
+
512
+ return matrix
513
+
514
+ def coords_grid(ht, wd, **kwargs):
515
+ y, x = torch.meshgrid(
516
+ torch.arange(ht, dtype=torch.float, **kwargs),
517
+ torch.arange(wd, dtype=torch.float, **kwargs),
518
+ indexing="ij",
519
+ )
520
+
521
+ return torch.stack([x, y], dim=-1)
522
+
523
+
524
+ def iproj(disps, intrinsics, jacobian=False):
525
+ """pinhole camera inverse projection"""
526
+ ht, wd = disps.shape[2:]
527
+ fx, fy, cx, cy = extract_intrinsics(intrinsics)
528
+
529
+ y, x = torch.meshgrid(
530
+ torch.arange(ht, device=disps.device, dtype=torch.float),
531
+ torch.arange(wd, device=disps.device, dtype=torch.float),
532
+ indexing="ij",
533
+ )
534
+
535
+ i = torch.ones_like(disps)
536
+ X = (x - cx) / fx
537
+ Y = (y - cy) / fy
538
+ pts = torch.stack([X, Y, i, disps], dim=-1)
539
+
540
+ if jacobian:
541
+ J = torch.zeros_like(pts)
542
+ J[..., -1] = 1.0
543
+ return pts, J
544
+
545
+ return pts, None
546
+
547
+
548
+ def proj(Xs, intrinsics, jacobian=False, return_depth=False):
549
+ """pinhole camera projection"""
550
+ fx, fy, cx, cy = extract_intrinsics(intrinsics)
551
+ X, Y, Z, D = Xs.unbind(dim=-1)
552
+
553
+ Z = torch.where(Z < 0.5 * MIN_DEPTH, torch.ones_like(Z), Z)
554
+ d = 1.0 / Z
555
+
556
+ x = fx * (X * d) + cx
557
+ y = fy * (Y * d) + cy
558
+ if return_depth:
559
+ coords = torch.stack([x, y, D * d], dim=-1)
560
+ else:
561
+ coords = torch.stack([x, y], dim=-1)
562
+
563
+ if jacobian:
564
+ B, N, H, W = d.shape
565
+ o = torch.zeros_like(d)
566
+ proj_jac = torch.stack(
567
+ [
568
+ fx * d,
569
+ o,
570
+ -fx * X * d * d,
571
+ o,
572
+ o,
573
+ fy * d,
574
+ -fy * Y * d * d,
575
+ o,
576
+ # o, o, -D*d*d, d,
577
+ ],
578
+ dim=-1,
579
+ ).view(B, N, H, W, 2, 4)
580
+
581
+ return coords, proj_jac
582
+
583
+ return coords, None
584
+
585
+
586
+ def actp(Gij, X0, jacobian=False):
587
+ """action on point cloud"""
588
+ X1 = Gij[:, :, None, None] * X0
589
+
590
+ if jacobian:
591
+ X, Y, Z, d = X1.unbind(dim=-1)
592
+ o = torch.zeros_like(d)
593
+ B, N, H, W = d.shape
594
+
595
+ if isinstance(Gij, SE3):
596
+ Ja = torch.stack(
597
+ [
598
+ d,
599
+ o,
600
+ o,
601
+ o,
602
+ Z,
603
+ -Y,
604
+ o,
605
+ d,
606
+ o,
607
+ -Z,
608
+ o,
609
+ X,
610
+ o,
611
+ o,
612
+ d,
613
+ Y,
614
+ -X,
615
+ o,
616
+ o,
617
+ o,
618
+ o,
619
+ o,
620
+ o,
621
+ o,
622
+ ],
623
+ dim=-1,
624
+ ).view(B, N, H, W, 4, 6)
625
+
626
+ elif isinstance(Gij, Sim3):
627
+ Ja = torch.stack(
628
+ [
629
+ d,
630
+ o,
631
+ o,
632
+ o,
633
+ Z,
634
+ -Y,
635
+ X,
636
+ o,
637
+ d,
638
+ o,
639
+ -Z,
640
+ o,
641
+ X,
642
+ Y,
643
+ o,
644
+ o,
645
+ d,
646
+ Y,
647
+ -X,
648
+ o,
649
+ Z,
650
+ o,
651
+ o,
652
+ o,
653
+ o,
654
+ o,
655
+ o,
656
+ o,
657
+ ],
658
+ dim=-1,
659
+ ).view(B, N, H, W, 4, 7)
660
+
661
+ return X1, Ja
662
+
663
+ return X1, None
664
+
665
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
666
+ """
667
+ Returns torch.sqrt(torch.max(0, x))
668
+ but with a zero subgradient where x is 0.
669
+ """
670
+ ret = torch.zeros_like(x)
671
+ positive_mask = x > 0
672
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
673
+ return ret
674
+
675
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
676
+ """
677
+ Convert rotations given as rotation matrices to quaternions.
678
+
679
+ Args:
680
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
681
+
682
+ Returns:
683
+ quaternions with real part first, as tensor of shape (..., 4).
684
+ """
685
+ if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
686
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
687
+
688
+ batch_dim = matrix.shape[:-2]
689
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
690
+ matrix.reshape(batch_dim + (9,)), dim=-1
691
+ )
692
+
693
+ q_abs = _sqrt_positive_part(
694
+ torch.stack(
695
+ [
696
+ 1.0 + m00 + m11 + m22,
697
+ 1.0 + m00 - m11 - m22,
698
+ 1.0 - m00 + m11 - m22,
699
+ 1.0 - m00 - m11 + m22,
700
+ ],
701
+ dim=-1,
702
+ )
703
+ )
704
+
705
+ quat_by_rijk = torch.stack(
706
+ [
707
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
708
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
709
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
710
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
711
+ ],
712
+ dim=-2,
713
+ )
714
+
715
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
716
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
717
+
718
+ out = quat_candidates[
719
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
720
+ ].reshape(batch_dim + (4,))
721
+ return standardize_quaternion(out)
722
+
723
+
724
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
725
+ """
726
+ Convert a unit quaternion to a standard form: one in which the real
727
+ part is non negative.
728
+
729
+ Args:
730
+ quaternions: Quaternions with real part first,
731
+ as tensor of shape (..., 4).
732
+
733
+ Returns:
734
+ Standardized quaternions as tensor of shape (..., 4).
735
+ """
736
+ quaternions = F.normalize(quaternions, p=2, dim=-1)
737
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
738
+
739
+ def umeyama(X, Y):
740
+ """
741
+ Estimates the Sim(3) transformation between `X` and `Y` point sets.
742
+
743
+ Estimates c, R and t such that c * R @ X + t ~ Y.
744
+
745
+ Parameters
746
+ ----------
747
+ X : numpy.array
748
+ (m, n) shaped numpy array. m is the dimension of the points,
749
+ n is the number of points in the point set.
750
+ Y : numpy.array
751
+ (m, n) shaped numpy array. Indexes should be consistent with `X`.
752
+ That is, Y[:, i] must be the point corresponding to X[:, i].
753
+
754
+ Returns
755
+ -------
756
+ c : float
757
+ Scale factor.
758
+ R : numpy.array
759
+ (3, 3) shaped rotation matrix.
760
+ t : numpy.array
761
+ (3, 1) shaped translation vector.
762
+ """
763
+ mu_x = X.mean(axis=1).reshape(-1, 1)
764
+ mu_y = Y.mean(axis=1).reshape(-1, 1)
765
+ var_x = np.square(X - mu_x).sum(axis=0).mean()
766
+ cov_xy = ((Y - mu_y) @ (X - mu_x).T) / X.shape[1]
767
+ U, D, VH = np.linalg.svd(cov_xy)
768
+ S = np.eye(X.shape[0])
769
+ if np.linalg.det(U) * np.linalg.det(VH) < 0:
770
+ S[-1, -1] = -1
771
+ c = np.trace(np.diag(D) @ S) / var_x
772
+ R = U @ S @ VH
773
+ t = mu_y - c * R @ mu_x
774
+ return c, R, t
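A short self-contained check of the Sim(3) estimator above (synthetic points; the scale, rotation, and translation are made-up values):

import numpy as np
from lingbot_map.utils.geometry import umeyama

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 50))                    # source points, shape (m, n)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))    # random orthonormal basis
if np.linalg.det(Q) < 0:                        # make it a proper rotation
    Q[:, 0] *= -1
c_true, t_true = 2.0, np.array([[1.0], [-2.0], [0.5]])
Y = c_true * Q @ X + t_true

c, R, t = umeyama(X, Y)
print(np.allclose(c, c_true), np.allclose(R, Q), np.allclose(t, t_true))  # True True True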
lingbot_map/utils/load_fn.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from concurrent.futures import ThreadPoolExecutor
8
+
9
+ import torch
10
+ from PIL import Image
11
+ from torchvision import transforms as TF
12
+ from tqdm.auto import tqdm
13
+ import numpy as np
14
+
15
+
16
+ def load_and_preprocess_images_square(image_path_list, target_size=1024):
17
+ """
18
+ Load and preprocess images by center padding to square and resizing to target size.
19
+ Also returns the position information of original pixels after transformation.
20
+
21
+ Args:
22
+ image_path_list (list): List of paths to image files
23
+ target_size (int, optional): Target size for both width and height. Defaults to 1024.
24
+
25
+ Returns:
26
+ tuple: (
27
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
28
+ torch.Tensor: Array of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image
29
+ )
30
+
31
+ Raises:
32
+ ValueError: If the input list is empty
33
+ """
34
+ # Check for empty list
35
+ if len(image_path_list) == 0:
36
+ raise ValueError("At least 1 image is required")
37
+
38
+ images = []
39
+ original_coords = [] # Renamed from position_info to be more descriptive
40
+ to_tensor = TF.ToTensor()
41
+
42
+ for image_path in image_path_list:
43
+ # Open image
44
+ img = Image.open(image_path)
45
+
46
+ # If there's an alpha channel, blend onto white background
47
+ if img.mode == "RGBA":
48
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
49
+ img = Image.alpha_composite(background, img)
50
+
51
+ # Convert to RGB
52
+ img = img.convert("RGB")
53
+
54
+ # Get original dimensions
55
+ width, height = img.size
56
+
57
+ # Make the image square by padding the shorter dimension
58
+ max_dim = max(width, height)
59
+
60
+ # Calculate padding
61
+ left = (max_dim - width) // 2
62
+ top = (max_dim - height) // 2
63
+
64
+ # Calculate scale factor for resizing
65
+ scale = target_size / max_dim
66
+
67
+ # Calculate final coordinates of original image in target space
68
+ x1 = left * scale
69
+ y1 = top * scale
70
+ x2 = (left + width) * scale
71
+ y2 = (top + height) * scale
72
+
73
+ # Store original image coordinates and scale
74
+ original_coords.append(np.array([x1, y1, x2, y2, width, height]))
75
+
76
+ # Create a new black square image and paste original
77
+ square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
78
+ square_img.paste(img, (left, top))
79
+
80
+ # Resize to target size
81
+ square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)
82
+
83
+ # Convert to tensor
84
+ img_tensor = to_tensor(square_img)
85
+ images.append(img_tensor)
86
+
87
+ # Stack all images
88
+ images = torch.stack(images)
89
+ original_coords = torch.from_numpy(np.array(original_coords)).float()
90
+
91
+ # Add additional dimension if single image to ensure correct shape
92
+ if len(image_path_list) == 1:
93
+ if images.dim() == 3:
94
+ images = images.unsqueeze(0)
95
+ original_coords = original_coords.unsqueeze(0)
96
+
97
+ return images, original_coords
98
+
99
+
100
+ def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16):
101
+ """
102
+ A quick start function to load and preprocess images for model input.
103
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
104
+
105
+ Args:
106
+ image_path_list (list): List of paths to image files
107
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
108
+ - "crop" (default): Sets width to 518px and center crops height if needed.
109
+ - "pad": Preserves all pixels by making the largest dimension 518px
110
+ and padding the smaller dimension to reach a square shape.
111
+
112
+ Returns:
113
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
114
+
115
+ Raises:
116
+ ValueError: If the input list is empty or if mode is invalid
117
+
118
+ Notes:
119
+ - Images with different dimensions will be padded with white (value=1.0)
120
+ - A warning is printed when images have different shapes
121
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
122
+ and height is center-cropped if larger than image_size
123
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
124
+ and the smaller dimension is padded to reach a square shape (image_size x image_size)
125
+ - Dimensions are adjusted to be divisible by patch_size for compatibility with model requirements
126
+ """
127
+ # Check for empty list
128
+ if len(image_path_list) == 0:
129
+ raise ValueError("At least 1 image is required")
130
+
131
+
132
+
133
+ # Validate mode
134
+ if mode not in ["crop", "pad"]:
135
+ raise ValueError("Mode must be either 'crop' or 'pad'")
136
+
137
+ target_size = image_size
138
+ to_tensor = TF.ToTensor()
139
+
140
+ def _load_one(idx_path):
141
+ i, image_path = idx_path
142
+ img = Image.open(image_path)
143
+ if img.mode == "RGBA":
144
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
145
+ img = Image.alpha_composite(background, img)
146
+ img = img.convert("RGB")
147
+
148
+ width, height = img.size
149
+
150
+ fx_val = fy_val = cx_val = cy_val = None
151
+ if fx is not None:
152
+ fx_val = fx[i] * width
153
+ fy_val = fy[i] * height
154
+ cx_val = cx[i] * width
155
+ cy_val = cy[i] * height
156
+
157
+ if mode == "pad":
158
+ if width >= height:
159
+ new_width = target_size
160
+ new_height = round(height * (new_width / width) / patch_size) * patch_size
161
+ else:
162
+ new_height = target_size
163
+ new_width = round(width * (new_height / height) / patch_size) * patch_size
164
+ else: # crop
165
+ new_width = target_size
166
+ new_height = round(height * (new_width / width) / patch_size) * patch_size
167
+
168
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
169
+ img = to_tensor(img)
170
+
171
+ if mode == "crop" and new_height > target_size:
172
+ start_y = (new_height - target_size) // 2
173
+ img = img[:, start_y : start_y + target_size, :]
174
+
175
+ if fx is not None:
176
+ fx_val = fx_val * new_width / width
177
+ fy_val = fy_val * new_height / height
178
+ cx_val = img.shape[2] / 2
179
+ cy_val = img.shape[1] / 2
180
+
181
+ if mode == "pad":
182
+ h_padding = target_size - img.shape[1]
183
+ w_padding = target_size - img.shape[2]
184
+ if h_padding > 0 or w_padding > 0:
185
+ pad_top = h_padding // 2
186
+ pad_bottom = h_padding - pad_top
187
+ pad_left = w_padding // 2
188
+ pad_right = w_padding - pad_left
189
+ img = torch.nn.functional.pad(
190
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
191
+ )
192
+
193
+ return i, img, (fx_val, fy_val, cx_val, cy_val)
194
+
195
+ # Parallel load with progress bar
196
+ num_workers = min(16, len(image_path_list))
197
+ results = [None] * len(image_path_list)
198
+ with ThreadPoolExecutor(max_workers=num_workers) as pool:
199
+ futures = pool.map(_load_one, enumerate(image_path_list))
200
+ for i, img, calib in tqdm(futures, total=len(image_path_list), desc="Loading images"):
201
+ results[i] = img
202
+ if fx is not None:
203
+ fx[i], fy[i], cx[i], cy[i] = calib
204
+
205
+ images = results
206
+ shapes = set((img.shape[1], img.shape[2]) for img in images)
207
+
208
+ # Check if we have different shapes
209
+ # In theory our model can also work well with different shapes
210
+ if len(shapes) > 1:
211
+ print(f"Warning: Found images with different shapes: {shapes}")
212
+ # Find maximum dimensions
213
+ max_height = max(shape[0] for shape in shapes)
214
+ max_width = max(shape[1] for shape in shapes)
215
+
216
+ # Pad images if necessary
217
+ padded_images = []
218
+ for img in images:
219
+ h_padding = max_height - img.shape[1]
220
+ w_padding = max_width - img.shape[2]
221
+
222
+ if h_padding > 0 or w_padding > 0:
223
+ pad_top = h_padding // 2
224
+ pad_bottom = h_padding - pad_top
225
+ pad_left = w_padding // 2
226
+ pad_right = w_padding - pad_left
227
+
228
+ img = torch.nn.functional.pad(
229
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
230
+ )
231
+ padded_images.append(img)
232
+ images = padded_images
233
+
234
+ images = torch.stack(images) # concatenate images
235
+
236
+ # Ensure correct shape when single image
237
+ if len(image_path_list) == 1:
238
+ # Verify shape is (1, C, H, W)
239
+ if images.dim() == 3:
240
+ images = images.unsqueeze(0)
241
+ if fx is not None:
242
+ return images, fx, fy, cx, cy
243
+ return images
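A hedged usage sketch for the loader above; the file paths are placeholders, not files that ship with the Space:

from lingbot_map.utils.load_fn import load_and_preprocess_images

paths = ["frames/000000.jpg", "frames/000001.jpg", "frames/000002.jpg"]  # hypothetical paths
images = load_and_preprocess_images(paths, mode="crop", image_size=512, patch_size=16)
print(images.shape)  # (3, 3, H, 512) with H a multiple of patch_size and at most 512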
lingbot_map/utils/pose_enc.py ADDED
@@ -0,0 +1,331 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from .rotation import quat_to_mat, mat_to_quat
9
+ import os
10
+ import torch
11
+ import numpy as np
12
+ import gzip
13
+ import json
14
+ import random
15
+ import logging
16
+ import warnings
17
+
18
+ from lingbot_map.utils.geometry import closed_form_inverse_se3, closed_form_inverse_se3_general
19
+
20
+
21
+ def extri_intri_to_pose_encoding(
22
+ extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512)
23
+ ):
24
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
25
+
26
+ This function transforms camera parameters into a unified pose encoding format,
27
+ which can be used for various downstream tasks like pose prediction or representation.
28
+
29
+ Args:
30
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
31
+ where B is batch size and S is sequence length.
32
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
33
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
34
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
35
+ Defined in pixels, with format:
36
+ [[fx, 0, cx],
37
+ [0, fy, cy],
38
+ [0, 0, 1]]
39
+ where fx, fy are focal lengths and (cx, cy) is the principal point
40
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
41
+ Required for computing field of view values. For example: (256, 512).
42
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
43
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
44
+
45
+ Returns:
46
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
47
+ For "absT_quaR_FoV" type, the 9 dimensions are:
48
+ - [:3] = absolute translation vector T (3D)
49
+ - [3:7] = rotation as quaternion quat (4D)
50
+ - [7:] = field of view (2D)
51
+ """
52
+
53
+ # extrinsics: BxSx3x4
54
+ # intrinsics: BxSx3x3
55
+
56
+ if pose_encoding_type == "absT_quaR_FoV":
57
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
58
+ T = extrinsics[:, :, :3, 3] # BxSx3
59
+
60
+ quat = mat_to_quat(R)
61
+ # Note the order of h and w here
62
+ H, W = image_size_hw
63
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
64
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
65
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
66
+ else:
67
+ raise NotImplementedError
68
+
69
+ return pose_encoding
70
+
71
+
72
+ def pose_encoding_to_extri_intri(
73
+ pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512)
74
+ ):
75
+ """Convert a pose encoding back to camera extrinsics and intrinsics.
76
+
77
+ This function performs the inverse operation of extri_intri_to_pose_encoding,
78
+ reconstructing the full camera parameters from the compact encoding.
79
+
80
+ Args:
81
+ pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
82
+ where B is batch size and S is sequence length.
83
+ For "absT_quaR_FoV" type, the 9 dimensions are:
84
+ - [:3] = absolute translation vector T (3D)
85
+ - [3:7] = rotation as quaternion quat (4D)
86
+ - [7:] = field of view (2D)
87
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
88
+ Required for reconstructing intrinsics from field of view values.
89
+ For example: (256, 512).
90
+ pose_encoding_type (str): Type of pose encoding used. Currently only
91
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
92
+ build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
93
+ If False, only extrinsics are returned and intrinsics will be None.
94
+
95
+ Returns:
96
+ tuple: (extrinsics, intrinsics)
97
+ - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
98
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
99
+ transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
100
+ a 3x1 translation vector.
101
+ - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
102
+ or None if build_intrinsics is False. Defined in pixels, with format:
103
+ [[fx, 0, cx],
104
+ [0, fy, cy],
105
+ [0, 0, 1]]
106
+ where fx, fy are focal lengths and (cx, cy) is the principal point,
107
+ assumed to be at the center of the image (W/2, H/2).
108
+ """
109
+
110
+ intrinsics = None
111
+
112
+ if pose_encoding_type == "absT_quaR_FoV":
113
+ T = pose_encoding[..., :3]
114
+ quat = pose_encoding[..., 3:7]
115
+ fov_h = pose_encoding[..., 7]
116
+ fov_w = pose_encoding[..., 8]
117
+
118
+ R = quat_to_mat(quat)
119
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
120
+
121
+ if build_intrinsics:
122
+ H, W = image_size_hw
123
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
124
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
125
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
126
+ intrinsics[..., 0, 0] = fx
127
+ intrinsics[..., 1, 1] = fy
128
+ intrinsics[..., 0, 2] = W / 2
129
+ intrinsics[..., 1, 2] = H / 2
130
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
131
+ elif pose_encoding_type == "absT_quaR":
132
+ T = pose_encoding[..., :3]
133
+ quat = pose_encoding[..., 3:7]
134
+
135
+ R = quat_to_mat(quat)
136
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
137
+
138
+ intrinsics = None
139
+
140
+ return extrinsics, intrinsics
141
+
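A minimal round-trip sketch for the two encoders above (identity poses and an assumed 300-pixel focal length, purely for illustration):

import torch
from lingbot_map.utils.pose_enc import (
    extri_intri_to_pose_encoding,
    pose_encoding_to_extri_intri,
)

B, S, H, W = 1, 4, 256, 512
extrinsics = torch.eye(4)[:3].expand(B, S, 3, 4).clone()        # cam-from-world [R|t]
intrinsics = torch.tensor([[300.0, 0.0, W / 2],
                           [0.0, 300.0, H / 2],
                           [0.0, 0.0, 1.0]]).expand(B, S, 3, 3).clone()

enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))
extr_back, intr_back = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))
print(enc.shape)                                                 # torch.Size([1, 4, 9])
print(torch.allclose(extr_back, extrinsics, atol=1e-5),
      torch.allclose(intr_back, intrinsics, atol=1e-3))          # True True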
142
+ def convert_pt3d_RT_to_opencv(Rot, Trans):
143
+ """
144
+ Convert PyTorch3D extrinsic matrices to OpenCV convention.
145
+
146
+ Args:
147
+ Rot: 3D rotation matrix in PyTorch3D format
148
+ Trans: 3D translation vector in PyTorch3D format
149
+
150
+ Returns:
151
+ extri_opencv: 3x4 extrinsic matrix in OpenCV format
152
+ """
153
+ rot_pt3d = np.array(Rot)
154
+ trans_pt3d = np.array(Trans)
155
+
156
+ trans_pt3d[:2] *= -1
157
+ rot_pt3d[:, :2] *= -1
158
+ rot_pt3d = rot_pt3d.transpose(1, 0)
159
+ extri_opencv = np.hstack((rot_pt3d, trans_pt3d[:, None]))
160
+ return extri_opencv
161
+
162
+
163
+ def build_pair_index(N, B=1):
164
+ """
165
+ Build indices for all possible pairs of frames.
166
+
167
+ Args:
168
+ N: Number of frames
169
+ B: Batch size
170
+
171
+ Returns:
172
+ i1, i2: Indices for all possible pairs
173
+ """
174
+ i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)
175
+ i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]]
176
+ return i1, i2
177
+
178
+
179
+ def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15):
180
+ """
181
+ Calculate rotation angle error between ground truth and predicted rotations.
182
+
183
+ Args:
184
+ rot_gt: Ground truth rotation matrices
185
+ rot_pred: Predicted rotation matrices
186
+ batch_size: Batch size for reshaping the result
187
+ eps: Small value to avoid numerical issues
188
+
189
+ Returns:
190
+ Rotation angle error in degrees
191
+ """
192
+ q_pred = mat_to_quat(rot_pred)
193
+ q_gt = mat_to_quat(rot_gt)
194
+
195
+ loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps)
196
+ err_q = torch.arccos(1 - 2 * loss_q)
197
+
198
+ rel_rangle_deg = err_q * 180 / np.pi
199
+
200
+ if batch_size is not None:
201
+ rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1)
202
+
203
+ return rel_rangle_deg
204
+
205
+
206
+ def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True):
207
+ """
208
+ Calculate translation angle error between ground truth and predicted translations.
209
+
210
+ Args:
211
+ tvec_gt: Ground truth translation vectors
212
+ tvec_pred: Predicted translation vectors
213
+ batch_size: Batch size for reshaping the result
214
+ ambiguity: Whether to handle direction ambiguity
215
+
216
+ Returns:
217
+ Translation angle error in degrees
218
+ """
219
+ rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred)
220
+ rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi
221
+
222
+ if ambiguity:
223
+ rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs())
224
+
225
+ if batch_size is not None:
226
+ rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1)
227
+
228
+ return rel_tangle_deg
229
+
230
+
231
+ def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6):
232
+ """
233
+ Normalize the translation vectors and compute the angle between them.
234
+
235
+ Args:
236
+ t_gt: Ground truth translation vectors
237
+ t: Predicted translation vectors
238
+ eps: Small value to avoid division by zero
239
+ default_err: Default error value for invalid cases
240
+
241
+ Returns:
242
+ Angular error between translation vectors in radians
243
+ """
244
+ t_norm = torch.norm(t, dim=1, keepdim=True)
245
+ t = t / (t_norm + eps)
246
+
247
+ t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True)
248
+ t_gt = t_gt / (t_gt_norm + eps)
249
+
250
+ loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps)
251
+ err_t = torch.acos(torch.sqrt(1 - loss_t))
252
+
253
+ err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err
254
+ return err_t
255
+
256
+
257
+ def calculate_auc_np(r_error, t_error, max_threshold=30):
258
+ """
259
+ Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy.
260
+
261
+ Args:
262
+ r_error: numpy array representing R error values (Degree)
263
+ t_error: numpy array representing T error values (Degree)
264
+ max_threshold: Maximum threshold value for binning the histogram
265
+
266
+ Returns:
267
+ AUC value and the normalized histogram
268
+ """
269
+ error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1)
270
+ max_errors = np.max(error_matrix, axis=1)
271
+ bins = np.arange(max_threshold + 1)
272
+ histogram, _ = np.histogram(max_errors, bins=bins)
273
+ num_pairs = float(len(max_errors))
274
+ normalized_histogram = histogram.astype(float) / num_pairs
275
+ return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
276
+
277
+
278
+ def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames):
279
+ """
280
+ Compute rotation and translation errors between predicted and ground truth poses.
281
+ This function assumes the input poses are world-to-camera (w2c) transformations.
282
+
283
+ Args:
284
+ pred_se3: Predicted SE(3) transformations (w2c), shape (N, 4, 4)
285
+ gt_se3: Ground truth SE(3) transformations (w2c), shape (N, 4, 4)
286
+ num_frames: Number of frames (N)
287
+
288
+ Returns:
289
+ Rotation and translation angle errors in degrees
290
+ """
291
+ pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames)
292
+
293
+ relative_pose_gt = gt_se3[pair_idx_i1].bmm(
294
+ closed_form_inverse_se3(gt_se3[pair_idx_i2])
295
+ )
296
+ relative_pose_pred = pred_se3[pair_idx_i1].bmm(
297
+ closed_form_inverse_se3(pred_se3[pair_idx_i2])
298
+ )
299
+
300
+ rel_rangle_deg = rotation_angle(
301
+ relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3]
302
+ )
303
+ rel_tangle_deg = translation_angle(
304
+ relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3]
305
+ )
306
+
307
+ return rel_rangle_deg, rel_tangle_deg
308
+
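For context, a tiny sketch of the relative-pose evaluation helpers above, using synthetic world-to-camera poses so the errors are essentially zero (the values are assumptions, not benchmark data):

import torch
from lingbot_map.utils.pose_enc import se3_to_relative_pose_error, calculate_auc_np

N = 4
gt = torch.eye(4).repeat(N, 1, 1)                       # (N, 4, 4) w2c poses
gt[:, 0, 3] = torch.arange(N, dtype=torch.float32)      # distinct translations along x
pred = gt.clone()                                       # perfect predictions

r_err, t_err = se3_to_relative_pose_error(pred, gt, num_frames=N)
auc, _ = calculate_auc_np(r_err.numpy(), t_err.numpy(), max_threshold=30)
print(r_err.shape, auc)                                 # torch.Size([6]) 1.0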
309
+
310
+ def colmap_to_opencv_intrinsics(K):
311
+ """
312
+ Modify camera intrinsics to follow a different convention.
313
+ Coordinates of the center of the top-left pixels are by default:
314
+ - (0.5, 0.5) in Colmap
315
+ - (0,0) in OpenCV
316
+ """
317
+ K = K.copy()
318
+ K[..., 0, 2] -= 0.5
319
+ K[..., 1, 2] -= 0.5
320
+ return K
321
+
322
+ def read_camera_parameters(filename):
323
+ with open(filename) as f:
324
+ lines = f.readlines()
325
+ lines = [line.rstrip() for line in lines]
326
+ # extrinsics: line [1,5), 4x4 matrix
327
+ extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
328
+ # intrinsics: line [7-10), 3x3 matrix
329
+ intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
330
+
331
+ return intrinsics, extrinsics
lingbot_map/utils/rotation.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
8
+
9
+ import torch
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+
13
+
14
+ def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Quaternion Order: XYZW or say ijkr, scalar-last
17
+
18
+ Convert rotations given as quaternions to rotation matrices.
19
+ Args:
20
+ quaternions: quaternions with real part last,
21
+ as tensor of shape (..., 4).
22
+
23
+ Returns:
24
+ Rotation matrices as tensor of shape (..., 3, 3).
25
+ """
26
+ i, j, k, r = torch.unbind(quaternions, -1)
27
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
28
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
29
+
30
+ o = torch.stack(
31
+ (
32
+ 1 - two_s * (j * j + k * k),
33
+ two_s * (i * j - k * r),
34
+ two_s * (i * k + j * r),
35
+ two_s * (i * j + k * r),
36
+ 1 - two_s * (i * i + k * k),
37
+ two_s * (j * k - i * r),
38
+ two_s * (i * k - j * r),
39
+ two_s * (j * k + i * r),
40
+ 1 - two_s * (i * i + j * j),
41
+ ),
42
+ -1,
43
+ )
44
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
45
+
46
+
47
+ def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Convert rotations given as rotation matrices to quaternions.
50
+
51
+ Args:
52
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
53
+
54
+ Returns:
55
+ quaternions with real part last, as tensor of shape (..., 4).
56
+ Quaternion Order: XYZW or say ijkr, scalar-last
57
+ """
58
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
59
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
60
+
61
+ batch_dim = matrix.shape[:-2]
62
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
63
+
64
+ q_abs = _sqrt_positive_part(
65
+ torch.stack(
66
+ [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
67
+ )
68
+ )
69
+
70
+ # we produce the desired quaternion multiplied by each of r, i, j, k
71
+ quat_by_rijk = torch.stack(
72
+ [
73
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
74
+ # `int`.
75
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
76
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
77
+ # `int`.
78
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
79
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
80
+ # `int`.
81
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
82
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
83
+ # `int`.
84
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
85
+ ],
86
+ dim=-2,
87
+ )
88
+
89
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
90
+ # the candidate won't be picked.
91
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
92
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
93
+
94
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
95
+ # forall i; we pick the best-conditioned one (with the largest denominator)
96
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
97
+
98
+ # Convert from rijk to ijkr
99
+ out = out[..., [1, 2, 3, 0]]
100
+
101
+ out = standardize_quaternion(out)
102
+
103
+ return out
104
+
105
+
106
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
107
+ """
108
+ Returns torch.sqrt(torch.max(0, x))
109
+ but with a zero subgradient where x is 0.
110
+ """
111
+ ret = torch.zeros_like(x)
112
+ positive_mask = x > 0
113
+ if torch.is_grad_enabled():
114
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
115
+ else:
116
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
117
+ return ret
118
+
119
+
120
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
121
+ """
122
+ Convert a unit quaternion to a standard form: one in which the real
123
+ part is non negative.
124
+
125
+ Args:
126
+ quaternions: Quaternions with real part last,
127
+ as tensor of shape (..., 4).
128
+
129
+ Returns:
130
+ Standardized quaternions as tensor of shape (..., 4).
131
+ """
132
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
lingbot_map/vis/__init__.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ GCT Visualization Module
9
+
10
+ This module provides visualization utilities for 3D reconstruction results:
11
+ - PointCloudViewer: Interactive point cloud viewer with camera visualization
12
+ - viser_wrapper: Quick visualization wrapper for predictions
13
+ - predictions_to_glb: Export predictions to GLB 3D format
14
+ - Colorization and utility functions
15
+
16
+ Usage:
17
+ from lingbot_map.vis import PointCloudViewer, viser_wrapper, predictions_to_glb
18
+
19
+ # Interactive visualization
20
+ viewer = PointCloudViewer(pred_dict=predictions, port=8080)
21
+ viewer.run()
22
+
23
+ # Quick visualization
24
+ viser_wrapper(predictions, port=8080)
25
+
26
+ # Export to GLB
27
+ scene = predictions_to_glb(predictions)
28
+ scene.export("output.glb")
29
+ """
30
+
31
+ from lingbot_map.vis.point_cloud_viewer import PointCloudViewer
32
+ from lingbot_map.vis.viser_wrapper import viser_wrapper
33
+ from lingbot_map.vis.utils import CameraState, colorize, colorize_np, get_vertical_colorbar
34
+ from lingbot_map.vis.sky_segmentation import (
35
+ apply_sky_segmentation,
36
+ download_skyseg_model,
37
+ load_or_create_sky_masks,
38
+ segment_sky,
39
+ )
40
+ from lingbot_map.vis.glb_export import predictions_to_glb
41
+
42
+ __all__ = [
43
+ # Main viewer
44
+ "PointCloudViewer",
45
+ # Quick visualization
46
+ "viser_wrapper",
47
+ # GLB export
48
+ "predictions_to_glb",
49
+ # Utilities
50
+ "CameraState",
51
+ "colorize",
52
+ "colorize_np",
53
+ "get_vertical_colorbar",
54
+ # Sky segmentation
55
+ "apply_sky_segmentation",
56
+ "segment_sky",
57
+ "download_skyseg_model",
58
+ "load_or_create_sky_masks",
59
+ ]
lingbot_map/vis/glb_export.py ADDED
@@ -0,0 +1,509 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ GLB 3D export utilities for GCT predictions.
9
+ """
10
+
11
+ import os
12
+ import copy
13
+ from typing import Optional, Tuple
14
+
15
+ import numpy as np
16
+ import cv2
17
+ import matplotlib
18
+ from scipy.spatial.transform import Rotation
19
+
20
+ from lingbot_map.vis.sky_segmentation import (
21
+ _SKYSEG_INPUT_SIZE,
22
+ _SKYSEG_SOFT_THRESHOLD,
23
+ _mask_to_float,
24
+ _mask_to_uint8,
25
+ _result_map_to_non_sky_conf,
26
+ )
27
+
28
+ try:
29
+ import trimesh
30
+ except ImportError:
31
+ trimesh = None
32
+ print("trimesh not found. GLB export will not work.")
33
+
34
+
35
+ def predictions_to_glb(
36
+ predictions: dict,
37
+ conf_thres: float = 50.0,
38
+ filter_by_frames: str = "all",
39
+ mask_black_bg: bool = False,
40
+ mask_white_bg: bool = False,
41
+ show_cam: bool = True,
42
+ mask_sky: bool = False,
43
+ target_dir: Optional[str] = None,
44
+ prediction_mode: str = "Predicted Pointmap",
45
+ ) -> "trimesh.Scene":
46
+ """
47
+ Converts GCT predictions to a 3D scene represented as a GLB file.
48
+
49
+ Args:
50
+ predictions: Dictionary containing model predictions with keys:
51
+ - world_points: 3D point coordinates (S, H, W, 3)
52
+ - world_points_conf: Confidence scores (S, H, W)
53
+ - images: Input images (S, H, W, 3) or (S, 3, H, W)
54
+ - extrinsic: Camera extrinsic matrices (S, 3, 4)
55
+ conf_thres: Percentage of low-confidence points to filter out
56
+ filter_by_frames: Frame filter specification ("all" or frame index)
57
+ mask_black_bg: Mask out black background pixels
58
+ mask_white_bg: Mask out white background pixels
59
+ show_cam: Include camera visualization
60
+ mask_sky: Apply sky segmentation mask
61
+ target_dir: Output directory for intermediate files
62
+ prediction_mode: "Predicted Pointmap" or "Predicted Depthmap"
63
+
64
+ Returns:
65
+ trimesh.Scene: Processed 3D scene containing point cloud and cameras
66
+
67
+ Raises:
68
+ ValueError: If input predictions structure is invalid
69
+ ImportError: If trimesh is not available
70
+ """
71
+ if trimesh is None:
72
+ raise ImportError("trimesh is required for GLB export. Install with: pip install trimesh")
73
+
74
+ if not isinstance(predictions, dict):
75
+ raise ValueError("predictions must be a dictionary")
76
+
77
+ if conf_thres is None:
78
+ conf_thres = 10.0
79
+
80
+ print("Building GLB scene")
81
+
82
+ # Parse frame filter
83
+ selected_frame_idx = None
84
+ if filter_by_frames != "all" and filter_by_frames != "All":
85
+ try:
86
+ selected_frame_idx = int(filter_by_frames.split(":")[0])
87
+ except (ValueError, IndexError):
88
+ pass
89
+
90
+ # Select prediction source
91
+ if "Pointmap" in prediction_mode:
92
+ print("Using Pointmap Branch")
93
+ if "world_points" in predictions:
94
+ pred_world_points = predictions["world_points"]
95
+ pred_world_points_conf = predictions.get(
96
+ "world_points_conf", np.ones_like(pred_world_points[..., 0])
97
+ )
98
+ else:
99
+ print("Warning: world_points not found, falling back to depth-based points")
100
+ pred_world_points = predictions["world_points_from_depth"]
101
+ pred_world_points_conf = predictions.get(
102
+ "depth_conf", np.ones_like(pred_world_points[..., 0])
103
+ )
104
+ else:
105
+ print("Using Depthmap and Camera Branch")
106
+ pred_world_points = predictions["world_points_from_depth"]
107
+ pred_world_points_conf = predictions.get(
108
+ "depth_conf", np.ones_like(pred_world_points[..., 0])
109
+ )
110
+
111
+ images = predictions["images"]
112
+ camera_matrices = predictions["extrinsic"]
113
+
114
+ # Apply sky segmentation if enabled
115
+ if mask_sky and target_dir is not None:
116
+ pred_world_points_conf = _apply_sky_mask(
117
+ pred_world_points_conf, target_dir, images
118
+ )
119
+
120
+ # Apply frame filter
121
+ if selected_frame_idx is not None:
122
+ pred_world_points = pred_world_points[selected_frame_idx][None]
123
+ pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
124
+ images = images[selected_frame_idx][None]
125
+ camera_matrices = camera_matrices[selected_frame_idx][None]
126
+
127
+ # Prepare vertices and colors
128
+ vertices_3d = pred_world_points.reshape(-1, 3)
129
+
130
+ # Handle different image formats
131
+ if images.ndim == 4 and images.shape[1] == 3: # NCHW format
132
+ colors_rgb = np.transpose(images, (0, 2, 3, 1))
133
+ else:
134
+ colors_rgb = images
135
+ colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
136
+
137
+ # Apply confidence filtering
138
+ conf = pred_world_points_conf.reshape(-1)
139
+ conf_threshold = np.percentile(conf, conf_thres) if conf_thres > 0 else 0.0
140
+ conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
141
+
142
+ # Apply background masking
143
+ if mask_black_bg:
144
+ black_bg_mask = colors_rgb.sum(axis=1) >= 16
145
+ conf_mask = conf_mask & black_bg_mask
146
+
147
+ if mask_white_bg:
148
+ white_bg_mask = ~(
149
+ (colors_rgb[:, 0] > 240) &
150
+ (colors_rgb[:, 1] > 240) &
151
+ (colors_rgb[:, 2] > 240)
152
+ )
153
+ conf_mask = conf_mask & white_bg_mask
154
+
155
+ vertices_3d = vertices_3d[conf_mask]
156
+ colors_rgb = colors_rgb[conf_mask]
157
+
158
+ # Handle empty point cloud
159
+ if vertices_3d is None or np.asarray(vertices_3d).size == 0:
160
+ vertices_3d = np.array([[1, 0, 0]])
161
+ colors_rgb = np.array([[255, 255, 255]])
162
+ scene_scale = 1
163
+ else:
164
+ lower_percentile = np.percentile(vertices_3d, 5, axis=0)
165
+ upper_percentile = np.percentile(vertices_3d, 95, axis=0)
166
+ scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
167
+
168
+ colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
169
+
170
+ # Build scene
171
+ scene_3d = trimesh.Scene()
172
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
173
+ scene_3d.add_geometry(point_cloud_data)
174
+
175
+ # Prepare camera matrices
176
+ num_cameras = len(camera_matrices)
177
+ extrinsics_matrices = np.zeros((num_cameras, 4, 4))
178
+ extrinsics_matrices[:, :3, :4] = camera_matrices
179
+ extrinsics_matrices[:, 3, 3] = 1
180
+
181
+ # Add cameras
182
+ if show_cam:
183
+ for i in range(num_cameras):
184
+ world_to_camera = extrinsics_matrices[i]
185
+ camera_to_world = np.linalg.inv(world_to_camera)
186
+ rgba_color = colormap(i / num_cameras)
187
+ current_color = tuple(int(255 * x) for x in rgba_color[:3])
188
+ integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
189
+
190
+ # Align scene
191
+ scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
192
+
193
+ print("GLB Scene built")
194
+ return scene_3d
195
+
196
+
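A minimal, self-contained usage sketch for predictions_to_glb (the synthetic arrays, shapes and output filename below are illustrative placeholders, not model output):

import numpy as np

S, H, W = 2, 4, 6                                    # two tiny 4x6 frames
demo_predictions = {
    "world_points": np.random.rand(S, H, W, 3).astype(np.float32),
    "world_points_conf": np.ones((S, H, W), dtype=np.float32),
    "images": np.random.rand(S, H, W, 3).astype(np.float32),            # HWC, values in [0, 1]
    "extrinsic": np.tile(np.eye(4)[:3], (S, 1, 1)).astype(np.float32),  # (S, 3, 4) world-to-camera
}
scene = predictions_to_glb(demo_predictions, conf_thres=0.0, show_cam=True)
scene.export("demo_scene.glb")                       # illustrative output path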
197
+ def _apply_sky_mask(
198
+ conf: np.ndarray,
199
+ target_dir: str,
200
+ images: np.ndarray
201
+ ) -> np.ndarray:
202
+ """Apply sky segmentation mask to confidence scores."""
203
+ try:
204
+ import onnxruntime
205
+ except ImportError:
206
+ print("Warning: onnxruntime not available, skipping sky masking")
207
+ return conf
208
+
209
+ target_dir_images = os.path.join(target_dir, "images")
210
+ if not os.path.exists(target_dir_images):
211
+ print(f"Warning: Images directory not found at {target_dir_images}")
212
+ return conf
213
+
214
+ image_list = sorted(os.listdir(target_dir_images))
215
+ S, H, W = conf.shape if hasattr(conf, "shape") else (len(images), images.shape[1], images.shape[2])
216
+
217
+ skyseg_model_path = "skyseg.onnx"
218
+ if not os.path.exists(skyseg_model_path):
219
+ print("Downloading skyseg.onnx...")
220
+ download_file_from_url(
221
+ "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
222
+ skyseg_model_path
223
+ )
224
+
225
+ skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
226
+ sky_mask_list = []
227
+
228
+ for i, image_name in enumerate(image_list[:S]):
229
+ image_filepath = os.path.join(target_dir_images, image_name)
230
+ mask_filepath = os.path.join(target_dir, "sky_masks", image_name)
231
+
232
+ if os.path.exists(mask_filepath):
233
+ sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
234
+ else:
235
+ sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)
236
+
237
+ if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
238
+ sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_LINEAR)
239
+
240
+ sky_mask_list.append(_mask_to_float(sky_mask))
241
+
242
+ sky_mask_array = np.array(sky_mask_list)
243
+ sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
244
+ return conf * sky_mask_binary
245
+
246
+
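To make the returned value above concrete: segment_sky produces a non-sky confidence map close to 1 for non-sky pixels and close to 0 for sky, so after thresholding a point with confidence 0.8 that falls in the sky ends up with 0.8 * 0 = 0.0 and is later dropped by the conf > 1e-5 check in predictions_to_glb, while non-sky points keep their original confidence.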
247
+ def integrate_camera_into_scene(
248
+ scene: "trimesh.Scene",
249
+ transform: np.ndarray,
250
+ face_colors: Tuple[int, int, int],
251
+ scene_scale: float,
252
+ frustum_thickness: float = 1.0,
253
+ ):
254
+ """
255
+ Integrates a camera mesh into the 3D scene.
256
+
257
+ Args:
258
+ scene: The 3D scene to add the camera model to
259
+ transform: Transformation matrix for camera positioning
260
+ face_colors: RGB color tuple for the camera
261
+ scene_scale: Scale of the scene
262
+ frustum_thickness: Multiplier for frustum edge thickness (>1 = thicker)
263
+ """
264
+ cam_width = scene_scale * 0.05
265
+ cam_height = scene_scale * 0.1
266
+
267
+ rot_45_degree = np.eye(4)
268
+ rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
269
+ rot_45_degree[2, 3] = -cam_height
270
+
271
+ opengl_transform = get_opengl_conversion_matrix()
272
+ complete_transform = transform @ opengl_transform @ rot_45_degree
273
+ camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
274
+
275
+ # Build thicker frustum by stacking rotated copies
276
+ slight_rotation = np.eye(4)
277
+ slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
278
+
279
+ shell_scales = [1.0, 0.95]
280
+ shell_transforms = [np.eye(4), slight_rotation]
281
+ # Add extra shells for thickness
282
+ if frustum_thickness > 1.0:
283
+ n_extra = max(1, int(frustum_thickness - 1))
284
+ for k in range(1, n_extra + 1):
285
+ # Progressively rotated and scaled copies
286
+ angle = 2.0 + k * 2.0
287
+ scale = 1.0 + k * 0.02
288
+ rot = np.eye(4)
289
+ rot[:3, :3] = Rotation.from_euler("z", angle, degrees=True).as_matrix()
290
+ shell_scales.append(scale)
291
+ shell_transforms.append(rot)
292
+ rot_neg = np.eye(4)
293
+ rot_neg[:3, :3] = Rotation.from_euler("z", -angle, degrees=True).as_matrix()
294
+ shell_scales.append(scale)
295
+ shell_transforms.append(rot_neg)
296
+
297
+ vertices_parts = []
298
+ for s, t_mat in zip(shell_scales, shell_transforms):
299
+ vertices_parts.append(
300
+ transform_points(t_mat, s * camera_cone_shape.vertices)
301
+ )
302
+ vertices_combined = np.concatenate(vertices_parts)
303
+ vertices_transformed = transform_points(complete_transform, vertices_combined)
304
+
305
+ mesh_faces = compute_camera_faces_multi(camera_cone_shape, len(shell_scales))
306
+ camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
307
+ camera_mesh.visual.face_colors[:, :3] = face_colors
308
+ scene.add_geometry(camera_mesh)
309
+
310
+
311
+ def apply_scene_alignment(
312
+ scene_3d: "trimesh.Scene",
313
+ extrinsics_matrices: np.ndarray
314
+ ) -> "trimesh.Scene":
315
+ """
316
+ Aligns the 3D scene based on the extrinsics of the first camera.
317
+
318
+ Args:
319
+ scene_3d: The 3D scene to be aligned
320
+ extrinsics_matrices: Camera extrinsic matrices
321
+
322
+ Returns:
323
+ Aligned 3D scene
324
+ """
325
+ opengl_conversion_matrix = get_opengl_conversion_matrix()
326
+
327
+ align_rotation = np.eye(4)
328
+ align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()
329
+
330
+ initial_transformation = (
331
+ np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation
332
+ )
333
+ scene_3d.apply_transform(initial_transformation)
334
+ return scene_3d
335
+
336
+
337
+ def get_opengl_conversion_matrix() -> np.ndarray:
338
+ """Returns the OpenGL conversion matrix (flips Y and Z axes)."""
339
+ matrix = np.identity(4)
340
+ matrix[1, 1] = -1
341
+ matrix[2, 2] = -1
342
+ return matrix
343
+
344
+
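Written out, the matrix above is diag(1, -1, -1, 1): the usual flip between the OpenCV-style camera frame (x right, y down, z forward) and the OpenGL-style frame (x right, y up, z toward the viewer), applied above when positioning the camera cones.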
345
+ def transform_points(
346
+ transformation: np.ndarray,
347
+ points: np.ndarray,
348
+ dim: Optional[int] = None
349
+ ) -> np.ndarray:
350
+ """
351
+ Applies a 4x4 transformation to a set of points.
352
+
353
+ Args:
354
+ transformation: Transformation matrix
355
+ points: Points to be transformed
356
+ dim: Dimension for reshaping the result
357
+
358
+ Returns:
359
+ Transformed points
360
+ """
361
+ points = np.asarray(points)
362
+ initial_shape = points.shape[:-1]
363
+ dim = dim or points.shape[-1]
364
+
365
+ transformation = transformation.swapaxes(-1, -2)
366
+ points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
367
+
368
+ return points[..., :dim].reshape(*initial_shape, dim)
369
+
370
+
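A quick worked example of the homogeneous-transform convention implemented above (values chosen for illustration):

import numpy as np

T = np.eye(4)
T[:3, 3] = [1.0, 2.0, 3.0]                  # pure translation by (1, 2, 3)
pts = np.array([[0.0, 0.0, 0.0],
                [1.0, 1.0, 1.0]])
out = transform_points(T, pts)
# out -> [[1., 2., 3.],
#         [2., 3., 4.]]   (each point shifted by the translation)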
371
+ def compute_camera_faces(cone_shape: "trimesh.Trimesh") -> np.ndarray:
372
+ """Computes the faces for the camera mesh."""
373
+ faces_list = []
374
+ num_vertices_cone = len(cone_shape.vertices)
375
+
376
+ for face in cone_shape.faces:
377
+ if 0 in face:
378
+ continue
379
+ v1, v2, v3 = face
380
+ v1_offset, v2_offset, v3_offset = face + num_vertices_cone
381
+ v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
382
+
383
+ faces_list.extend([
384
+ (v1, v2, v2_offset),
385
+ (v1, v1_offset, v3),
386
+ (v3_offset, v2, v3),
387
+ (v1, v2, v2_offset_2),
388
+ (v1, v1_offset_2, v3),
389
+ (v3_offset_2, v2, v3),
390
+ ])
391
+
392
+ faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
393
+ return np.array(faces_list)
394
+
395
+
396
+ def compute_camera_faces_multi(cone_shape: "trimesh.Trimesh", num_shells: int) -> np.ndarray:
397
+ """Computes faces for a camera mesh with multiple shells (for thicker frustums).
398
+
399
+ Connects each consecutive pair of vertex shells to form the frustum edges.
400
+ """
401
+ faces_list = []
402
+ nv = len(cone_shape.vertices)
403
+
404
+ for s in range(num_shells - 1):
405
+ off_a = s * nv
406
+ off_b = (s + 1) * nv
407
+ for face in cone_shape.faces:
408
+ if 0 in face:
409
+ continue
410
+ v1, v2, v3 = face
411
+ faces_list.extend([
412
+ (v1 + off_a, v2 + off_a, v2 + off_b),
413
+ (v1 + off_a, v1 + off_b, v3 + off_a),
414
+ (v3 + off_b, v2 + off_a, v3 + off_a),
415
+ ])
416
+
417
+ faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
418
+ return np.array(faces_list)
419
+
420
+
421
+ def segment_sky(
422
+ image_path: str,
423
+ onnx_session,
424
+ mask_filename: str
425
+ ) -> np.ndarray:
426
+ """
427
+ Segments sky from an image using an ONNX model.
428
+
429
+ Args:
430
+ image_path: Path to input image
431
+ onnx_session: ONNX runtime session with loaded model
432
+ mask_filename: Path to save the output mask
433
+
434
+ Returns:
435
+ Continuous non-sky confidence map in [0, 1]
436
+ """
437
+ image = cv2.imread(image_path)
438
+ result_map = run_skyseg(onnx_session, _SKYSEG_INPUT_SIZE, image)
439
+ result_map_original = cv2.resize(
440
+ result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR
441
+ )
442
+ output_mask = _result_map_to_non_sky_conf(result_map_original)
443
+
444
+ os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
445
+ cv2.imwrite(mask_filename, _mask_to_uint8(output_mask))
446
+ return output_mask
447
+
448
+
449
+ def run_skyseg(
450
+ onnx_session,
451
+ input_size: Tuple[int, int],
452
+ image: np.ndarray
453
+ ) -> np.ndarray:
454
+ """
455
+ Runs sky segmentation inference using ONNX model.
456
+
457
+ Args:
458
+ onnx_session: ONNX runtime session
459
+ input_size: Target size for model input (width, height)
460
+ image: Input image in BGR format
461
+
462
+ Returns:
463
+ Segmentation mask
464
+ """
465
+ temp_image = copy.deepcopy(image)
466
+ resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
467
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
468
+ x = np.array(x, dtype=np.float32)
469
+ mean = [0.485, 0.456, 0.406]
470
+ std = [0.229, 0.224, 0.225]
471
+ x = (x / 255 - mean) / std
472
+ x = x.transpose(2, 0, 1)
473
+ # resize_image is (H, W, 3) with H = input_size[1] and W = input_size[0]; keep that layout in NCHW
+ x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")
474
+
475
+ input_name = onnx_session.get_inputs()[0].name
476
+ output_name = onnx_session.get_outputs()[0].name
477
+ onnx_result = onnx_session.run([output_name], {input_name: x})
478
+
479
+ onnx_result = np.array(onnx_result).squeeze()
480
+ min_value = np.min(onnx_result)
481
+ max_value = np.max(onnx_result)
482
+ onnx_result = (onnx_result - min_value) / (max_value - min_value)
483
+ onnx_result *= 255
484
+ return onnx_result.astype("uint8")
485
+
486
+
487
+ def download_file_from_url(url: str, filename: str):
488
+ """Downloads a file from a URL, handling redirects."""
489
+ import requests
490
+
491
+ try:
+ # Do not auto-follow redirects so the final target can be streamed explicitly.
+ response = requests.get(url, allow_redirects=False, stream=True)
+ response.raise_for_status()
+
+ # Hugging Face "resolve" URLs normally answer with a 302 redirect to a CDN;
+ # a direct 200 response is downloadable as-is, so only other codes are skipped.
+ if response.status_code == 302:
+ redirect_url = response.headers["Location"]
+ response = requests.get(redirect_url, stream=True)
+ response.raise_for_status()
+ elif response.status_code != 200:
+ print(f"Unexpected status code: {response.status_code}")
+ return
502
+
503
+ with open(filename, "wb") as f:
504
+ for chunk in response.iter_content(chunk_size=8192):
505
+ f.write(chunk)
506
+ print(f"Downloaded {filename} successfully.")
507
+
508
+ except requests.exceptions.RequestException as e:
509
+ print(f"Error downloading file: {e}")
lingbot_map/vis/point_cloud_viewer.py ADDED
@@ -0,0 +1,1437 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Interactive 3D Point Cloud Viewer using Viser.
9
+
10
+ This module provides the PointCloudViewer class for visualizing 3D reconstruction results,
11
+ including point clouds, camera poses, and animated playback.
12
+ """
13
+
14
+ import os
15
+ import time
16
+ import threading
17
+ import subprocess
18
+ import tempfile
19
+ import shutil
20
+ from typing import List, Optional, Dict, Any, Tuple
21
+
22
+ import numpy as np
23
+ import torch
24
+ import cv2
25
+ import matplotlib
+ import matplotlib.cm as cm
26
+ from tqdm.auto import tqdm
27
+
28
+ import viser
29
+ import viser.transforms as tf
30
+
31
+ from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
32
+ from lingbot_map.vis.utils import CameraState
33
+ from lingbot_map.vis.sky_segmentation import apply_sky_segmentation
34
+
35
+
36
+ class PointCloudViewer:
37
+ """
38
+ Interactive 3D point cloud viewer with camera visualization.
39
+
40
+ Features:
41
+ - Point cloud visualization with confidence-based filtering
42
+ - Camera frustum visualization with gradient colors
43
+ - Frame-by-frame playback animation (3D/4D modes)
44
+ - Range-based and recent-N-frames visualization modes
45
+ - Video export with FFmpeg
46
+
47
+ Args:
48
+ model: Optional model for interactive inference
49
+ state_args: Optional state arguments
50
+ pc_list: List of point clouds per frame
51
+ color_list: List of colors per frame
52
+ conf_list: List of confidence scores per frame
53
+ cam_dict: Camera dictionary with focal, pp, R, t
54
+ image_mask: Optional image mask
55
+ edge_color_list: Optional edge colors
56
+ device: Device for computation
57
+ port: Viser server port
58
+ show_camera: Whether to show camera frustums
59
+ vis_threshold: Visibility threshold for filtering
60
+ size: Image size
61
+ downsample_factor: Point cloud downsample factor
62
+ point_size: Initial point size
63
+ pred_dict: Prediction dictionary (alternative to pc_list/color_list/conf_list)
64
+ init_conf_threshold: Initial confidence threshold percentage
65
+ use_point_map: Use point map instead of depth-based points
66
+ mask_sky: Apply sky segmentation
67
+ image_folder: Path to image folder (for sky segmentation)
68
+ """
69
+
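A minimal usage sketch mirroring the package docstring (the pred_dict below is a placeholder for actual model output; with the default use_point_map=False it needs the keys read by _process_pred_dict):

# pred_dict keys: "images" (S, 3, H, W), "depth" (S, H, W, 1), "depth_conf" (S, H, W),
# "extrinsic" (S, 3, 4), "intrinsic" (S, 3, 3)
viewer = PointCloudViewer(pred_dict=pred_dict, port=8080, init_conf_threshold=50.0)
viewer.run()   # serves the interactive scene on http://localhost:8080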
70
+ def __init__(
71
+ self,
72
+ model=None,
73
+ state_args=None,
74
+ pc_list=None,
75
+ color_list=None,
76
+ conf_list=None,
77
+ cam_dict=None,
78
+ image_mask=None,
79
+ edge_color_list=None,
80
+ device: str = "cpu",
81
+ port: int = 8080,
82
+ show_camera: bool = True,
83
+ vis_threshold: float = 1.0,
84
+ size: int = 512,
85
+ downsample_factor: int = 10,
86
+ point_size: float = 0.00001,
87
+ pred_dict: Optional[Dict] = None,
88
+ init_conf_threshold: float = 50.0,
89
+ use_point_map: bool = False,
90
+ mask_sky: bool = False,
91
+ image_folder: Optional[str] = None,
92
+ sky_mask_dir: Optional[str] = None,
93
+ sky_mask_visualization_dir: Optional[str] = None,
94
+ depth_stride: int = 1,
95
+ ):
96
+ self.model = model
97
+ self.size = size
98
+ self.state_args = state_args
99
+ self.server = viser.ViserServer(host="0.0.0.0", port=port)
100
+ self.server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
101
+ self.device = device
102
+ self.conf_list = conf_list
103
+ self.vis_threshold = vis_threshold
104
+ self.point_size = point_size
105
+ self.tt = lambda x: torch.from_numpy(x).float().to(device)
106
+
107
+ # Process the prediction dictionary to create pc_list, color_list, conf_list
108
+ if pred_dict is not None:
109
+ pc_list, color_list, conf_list, cam_dict = self._process_pred_dict(
110
+ pred_dict, use_point_map, mask_sky, image_folder,
111
+ sky_mask_dir=sky_mask_dir,
112
+ sky_mask_visualization_dir=sky_mask_visualization_dir,
113
+ depth_stride=depth_stride,
114
+ )
115
+ else:
116
+ self.original_images = []
117
+
118
+ self.pcs, self.all_steps = self.read_data(
119
+ pc_list, color_list, conf_list, edge_color_list
120
+ )
121
+ self.cam_dict = cam_dict
122
+ self.num_frames = len(self.all_steps)
123
+ self.image_mask = image_mask
124
+ self.show_camera = show_camera
125
+ self.on_replay = False
126
+ self.vis_pts_list = []
127
+ self.traj_list = []
128
+ self.orig_img_list = [x[0] for x in color_list if len(x) > 0] if color_list else []
129
+ self.via_points = []
130
+
131
+ self._setup_gui()
132
+ self.server.on_client_connect(self._connect_client)
133
+
134
+ def _process_pred_dict(
135
+ self,
136
+ pred_dict: Dict,
137
+ use_point_map: bool,
138
+ mask_sky: bool,
139
+ image_folder: Optional[str],
140
+ sky_mask_dir: Optional[str] = None,
141
+ sky_mask_visualization_dir: Optional[str] = None,
142
+ depth_stride: int = 1,
143
+ ) -> Tuple[List, List, List, Dict]:
144
+ """Process prediction dictionary to extract visualization data.
145
+
146
+ Args:
147
+ pred_dict: Model prediction dictionary.
148
+ use_point_map: Use point map instead of depth-based projection.
149
+ mask_sky: Apply sky segmentation to filter sky points.
150
+ image_folder: Path to images for sky segmentation.
151
+ sky_mask_dir: Directory for cached sky masks.
152
+ sky_mask_visualization_dir: Directory for sky mask visualization images.
153
+ depth_stride: Only project depth to point cloud every N frames.
154
+ Frames not projected will have empty point clouds but still
155
+ show camera frustums and images. 1 = every frame (default).
156
+ """
157
+ images = pred_dict["images"] # (S, 3, H, W)
158
+
159
+ depth_map = pred_dict.get("depth") # (S, H, W, 1)
160
+ depth_conf = pred_dict.get("depth_conf") # (S, H, W)
161
+
162
+ extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4)
163
+ intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3)
164
+
165
+ # Compute world points from depth if not using the precomputed point map
166
+ if not use_point_map:
167
+ world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
168
+ conf = depth_conf
169
+ else:
170
+ world_points = pred_dict["world_points"] # (S, H, W, 3)
171
+ conf = pred_dict.get("world_points_conf", depth_conf) # (S, H, W)
172
+
173
+ # Apply sky segmentation if enabled
174
+ if mask_sky:
175
+ conf = apply_sky_segmentation(
176
+ conf, image_folder=image_folder, images=images,
177
+ sky_mask_dir=sky_mask_dir,
178
+ sky_mask_visualization_dir=sky_mask_visualization_dir,
179
+ )
180
+
181
+ # Convert images from (S, 3, H, W) to (S, H, W, 3)
182
+ colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3)
183
+ S = world_points.shape[0]
184
+
185
+ # Store original images for camera frustum display
186
+ self.original_images = []
187
+ for i in range(S):
188
+ img = images[i] # shape (3, H, W)
189
+ img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
190
+ self.original_images.append(img)
191
+
192
+ # Create lists - apply depth_stride to skip frames for point projection
193
+ H, W = world_points.shape[1], world_points.shape[2]
194
+ pc_list = []
195
+ color_list = []
196
+ conf_list = []
197
+ skipped = 0
198
+ for i in range(S):
199
+ if depth_stride > 1 and i % depth_stride != 0:
200
+ # Empty point cloud for skipped frames
201
+ pc_list.append(np.zeros((0, 0, 3), dtype=np.float32))
202
+ color_list.append(np.zeros((0, 0, 3), dtype=np.float32))
203
+ conf_list.append(np.zeros((0, 0), dtype=np.float32))
204
+ skipped += 1
205
+ else:
206
+ pc_list.append(world_points[i])
207
+ color_list.append(colors[i])
208
+ if conf is not None:
209
+ conf_list.append(conf[i])
210
+ else:
211
+ conf_list.append(np.ones(world_points[i].shape[:2], dtype=np.float32))
212
+
213
+ if depth_stride > 1:
214
+ print(f' depth_stride={depth_stride}: projecting {S - skipped}/{S} frames, skipping {skipped}')
215
+
216
+ # Create camera dictionary (all frames keep cameras)
217
+ cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
218
+ cam_dict = {
219
+ "focal": [intrinsics_cam[i, 0, 0] for i in range(S)],
220
+ "pp": [(intrinsics_cam[i, 0, 2], intrinsics_cam[i, 1, 2]) for i in range(S)],
221
+ "R": [cam_to_world_mat[i, :3, :3] for i in range(S)],
222
+ "t": [cam_to_world_mat[i, :3, 3] for i in range(S)],
223
+ }
224
+
225
+ return pc_list, color_list, conf_list, cam_dict
226
+
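A concrete illustration of the depth_stride behaviour documented above (illustrative numbers):

S, depth_stride = 10, 3
projected = [i for i in range(S) if i % depth_stride == 0]
# projected == [0, 3, 6, 9]; the other six frames keep their camera frustums
# and images but get empty point clouds.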
227
+ def _compute_scene_center_and_scale(self) -> Tuple[np.ndarray, float]:
228
+ """Compute scene center and scale from camera positions and point clouds.
229
+
230
+ Returns:
231
+ Tuple of (center as 3D array, scale as float distance).
232
+ """
233
+ # Use camera positions as primary reference (more reliable than noisy points)
234
+ if self.cam_dict is not None and "t" in self.cam_dict:
235
+ cam_positions = np.array([self.cam_dict["t"][s] for s in self.all_steps])
236
+ center = np.mean(cam_positions, axis=0)
237
+ if len(cam_positions) > 1:
238
+ extent = np.ptp(cam_positions, axis=0) # range per axis
239
+ scale = np.linalg.norm(extent)
240
+ else:
241
+ scale = 1.0
242
+ else:
243
+ # Fallback: use point cloud data
244
+ all_pts = []
245
+ for step in self.all_steps:
246
+ pc = self.pcs[step]["pc"].reshape(-1, 3)
247
+ # subsample for speed
248
+ if len(pc) > 1000:
249
+ pc = pc[::len(pc) // 1000]
250
+ all_pts.append(pc)
251
+ all_pts = np.concatenate(all_pts, axis=0)
252
+ center = np.median(all_pts, axis=0)
253
+ extent = np.percentile(all_pts, 95, axis=0) - np.percentile(all_pts, 5, axis=0)
254
+ scale = np.linalg.norm(extent)
255
+
256
+ return center, max(scale, 0.1)
257
+
258
+ def _reset_view_to_direction(
259
+ self,
260
+ direction: np.ndarray,
261
+ up: np.ndarray = np.array([0.0, -1.0, 0.0]),
262
+ distance_scale: float = 1.5,
263
+ smooth: bool = True,
264
+ ):
265
+ """Reset the viewer camera to look at scene center from a given direction.
266
+
267
+ Args:
268
+ direction: Unit vector pointing FROM the scene center TO the camera.
269
+ up: Up vector for the camera.
270
+ distance_scale: Multiplier on scene scale for camera distance.
271
+ smooth: Whether to smoothly transition.
272
+ """
273
+ center, scale = self._compute_scene_center_and_scale()
274
+ distance = scale * distance_scale
275
+ position = center + direction * distance
276
+
277
+ for client in self.server.get_clients().values():
278
+ if smooth:
279
+ self._smooth_camera_transition(
280
+ client,
281
+ target_position=position,
282
+ target_look_at=center,
283
+ target_up=up,
284
+ duration=0.4,
285
+ )
286
+ else:
287
+ client.camera.up_direction = tuple(up)
288
+ client.camera.position = tuple(position)
289
+ client.camera.look_at = tuple(center)
290
+
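As a worked example of the placement above (illustrative numbers): with the scene center at the origin, a scene scale of 2.0 and the default distance_scale of 1.5, the Front (+Z) preset puts the viewer camera at (0, 0, 3), looking back at the origin with the requested up vector.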
291
+ def _setup_gui(self):
292
+ """Setup GUI controls."""
293
+ gui_reset_up = self.server.gui.add_button(
294
+ "Reset up direction",
295
+ hint="Set the camera control 'up' direction to the current camera's 'up'.",
296
+ )
297
+
298
+ @gui_reset_up.on_click
299
+ def _(event: viser.GuiEvent) -> None:
300
+ client = event.client
301
+ assert client is not None
302
+ client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array(
303
+ [0.0, -1.0, 0.0]
304
+ )
305
+
306
+ # Video frame display controls — kept at top so the current frame is always visible
307
+ with self.server.gui.add_folder("Video Display"):
308
+ self.show_video_checkbox = self.server.gui.add_checkbox("Show Current Frame", initial_value=True)
309
+ if hasattr(self, 'original_images') and len(self.original_images) > 0:
310
+ self.current_frame_image = self.server.gui.add_image(
311
+ self.original_images[0], label="Current Frame"
312
+ )
313
+ else:
314
+ self.current_frame_image = None
315
+
316
+ # Preset view direction buttons
317
+ with self.server.gui.add_folder("Reset View Direction"):
318
+ btn_look_at_center = self.server.gui.add_button(
319
+ "Look At Scene Center",
320
+ hint="Reset orbit center to the scene center (fixes orbit after dragging).",
321
+ )
322
+ btn_overview = self.server.gui.add_button(
323
+ "Overview",
324
+ hint="Reset to a 3/4 overview of the scene.",
325
+ )
326
+ btn_front = self.server.gui.add_button(
327
+ "Front (+Z)",
328
+ hint="View scene from the front.",
329
+ )
330
+ btn_back = self.server.gui.add_button(
331
+ "Back (-Z)",
332
+ hint="View scene from the back.",
333
+ )
334
+ btn_top = self.server.gui.add_button(
335
+ "Top (-Y)",
336
+ hint="View scene from above (bird's eye).",
337
+ )
338
+ btn_left = self.server.gui.add_button(
339
+ "Left (-X)",
340
+ hint="View scene from the left.",
341
+ )
342
+ btn_right = self.server.gui.add_button(
343
+ "Right (+X)",
344
+ hint="View scene from the right.",
345
+ )
346
+ btn_first_cam = self.server.gui.add_button(
347
+ "First Camera",
348
+ hint="Reset to the first camera's viewpoint.",
349
+ )
350
+
351
+ @btn_look_at_center.on_click
352
+ def _(_) -> None:
353
+ center, _ = self._compute_scene_center_and_scale()
354
+ for client in self.server.get_clients().values():
355
+ client.camera.look_at = tuple(center)
356
+
357
+ @btn_overview.on_click
358
+ def _(_) -> None:
359
+ d = np.array([0.5, -0.6, 0.6])
360
+ self._reset_view_to_direction(d / np.linalg.norm(d))
361
+
362
+ @btn_front.on_click
363
+ def _(_) -> None:
364
+ self._reset_view_to_direction(np.array([0.0, 0.0, 1.0]))
365
+
366
+ @btn_back.on_click
367
+ def _(_) -> None:
368
+ self._reset_view_to_direction(np.array([0.0, 0.0, -1.0]))
369
+
370
+ @btn_top.on_click
371
+ def _(_) -> None:
372
+ self._reset_view_to_direction(
373
+ np.array([0.0, -1.0, 0.0]),
374
+ up=np.array([0.0, 0.0, 1.0]),
375
+ )
376
+
377
+ @btn_left.on_click
378
+ def _(_) -> None:
379
+ self._reset_view_to_direction(np.array([-1.0, 0.0, 0.0]))
380
+
381
+ @btn_right.on_click
382
+ def _(_) -> None:
383
+ self._reset_view_to_direction(np.array([1.0, 0.0, 0.0]))
384
+
385
+ @btn_first_cam.on_click
386
+ def _(_) -> None:
387
+ self._move_to_camera(0, smooth=True)
388
+
389
+ button3 = self.server.gui.add_button("4D (Only Show Current Frame)")
390
+ button4 = self.server.gui.add_button("3D (Show All Frames)")
391
+ self.is_render = False
392
+ self.fourd = False
393
+
394
+ @button3.on_click
395
+ def _(event: viser.GuiEvent) -> None:
396
+ self.fourd = True
397
+
398
+ @button4.on_click
399
+ def _(event: viser.GuiEvent) -> None:
400
+ self.fourd = False
401
+
402
+ self.focal_slider = self.server.gui.add_slider(
403
+ "Focal Length", min=0.1, max=99999, step=1, initial_value=533
404
+ )
405
+ self.psize_slider = self.server.gui.add_slider(
406
+ "Point Size", min=0.00001, max=0.1, step=0.00001, initial_value=self.point_size
407
+ )
408
+ self.camsize_slider = self.server.gui.add_slider(
409
+ "Camera Size", min=0.01, max=0.5, step=0.01, initial_value=0.1
410
+ )
411
+ self.downsample_slider = self.server.gui.add_slider(
412
+ "Downsample Factor", min=1, max=1000, step=1, initial_value=10
413
+ )
414
+ self.show_camera_checkbox = self.server.gui.add_checkbox(
415
+ "Show Camera", initial_value=self.show_camera
416
+ )
417
+ self.vis_threshold_slider = self.server.gui.add_slider(
418
+ "Visibility Threshold", min=1.0, max=5.0, step=0.01,
419
+ initial_value=self.vis_threshold,
420
+ )
421
+ self.camera_downsample_slider = self.server.gui.add_slider(
422
+ "Camera Downsample Factor", min=1, max=50, step=1, initial_value=1
423
+ )
424
+
425
+ # Screenshot controls
426
+ with self.server.gui.add_folder("Screenshot"):
427
+ self.screenshot_button = self.server.gui.add_button("Take Screenshot")
428
+ self.screenshot_resolution = self.server.gui.add_dropdown(
429
+ "Resolution",
430
+ options=["1920x1080", "2560x1440", "3840x2160", "Current"],
431
+ initial_value="1920x1080",
432
+ )
433
+ self.screenshot_path = self.server.gui.add_text(
434
+ "Save Path", initial_value="screenshot.png"
435
+ )
436
+ self.screenshot_status = self.server.gui.add_text(
437
+ "Status", initial_value="Ready"
438
+ )
439
+
440
+ @self.screenshot_button.on_click
441
+ def _(event: viser.GuiEvent) -> None:
442
+ self._take_screenshot(event.client)
443
+
444
+ # GLB export controls
445
+ with self.server.gui.add_folder("Export GLB"):
446
+ self.glb_output_path = self.server.gui.add_text(
447
+ "Output Path", initial_value="export.glb"
448
+ )
449
+ self.glb_show_cam_checkbox = self.server.gui.add_checkbox(
450
+ "Include Cameras", initial_value=True,
451
+ )
452
+ self.glb_cam_scale_slider = self.server.gui.add_slider(
453
+ "Camera Scale", min=0.01, max=5.0, step=0.01, initial_value=1.0,
454
+ hint="Scale factor for camera size in GLB.",
455
+ )
456
+ self.glb_frustum_thickness_slider = self.server.gui.add_slider(
457
+ "Frustum Thickness", min=1.0, max=10.0, step=0.5, initial_value=3.0,
458
+ hint="Thickness multiplier for camera frustum edges.",
459
+ )
460
+ self.glb_trajectory_checkbox = self.server.gui.add_checkbox(
461
+ "Show Trajectory", initial_value=True,
462
+ hint="Connect cameras with a trajectory line.",
463
+ )
464
+ self.glb_trajectory_radius_slider = self.server.gui.add_slider(
465
+ "Trajectory Radius", min=0.001, max=0.05, step=0.001, initial_value=0.005,
466
+ hint="Radius of the trajectory tube.",
467
+ )
468
+ self.glb_mode_dropdown = self.server.gui.add_dropdown(
469
+ "Export Mode",
470
+ options=["Points", "Spheres"],
471
+ initial_value="Points",
472
+ hint="Points: raw (fast). Spheres: each point becomes a small sphere (prettier, slower).",
473
+ )
474
+ self.glb_sphere_radius_slider = self.server.gui.add_slider(
475
+ "Sphere Radius", min=0.001, max=0.1, step=0.001, initial_value=0.005,
476
+ hint="Radius of each sphere in Spheres mode.",
477
+ disabled=True,
478
+ )
479
+ self.glb_max_sphere_pts_slider = self.server.gui.add_slider(
480
+ "Max Sphere Points", min=10000, max=500000, step=10000, initial_value=100000,
481
+ hint="Cap point count for Spheres mode to keep file size manageable.",
482
+ disabled=True,
483
+ )
484
+ self.glb_opacity_slider = self.server.gui.add_slider(
485
+ "Opacity", min=0.0, max=1.0, step=0.05, initial_value=1.0,
486
+ hint="Point/sphere opacity (alpha). <1.0 = semi-transparent.",
487
+ )
488
+ self.glb_saturation_slider = self.server.gui.add_slider(
489
+ "Saturation Boost", min=0.0, max=2.0, step=0.1, initial_value=1.0,
490
+ hint="Color saturation multiplier. >1 = more vivid, <1 = washed out.",
491
+ )
492
+ self.glb_brightness_slider = self.server.gui.add_slider(
493
+ "Brightness Boost", min=0.5, max=2.0, step=0.1, initial_value=1.0,
494
+ hint="Color brightness multiplier.",
495
+ )
496
+ self.glb_export_button = self.server.gui.add_button(
497
+ "Export GLB",
498
+ hint="Export current filtered point clouds and cameras as GLB.",
499
+ )
500
+ self.glb_status = self.server.gui.add_text("Status", initial_value="Ready")
501
+
502
+ @self.glb_mode_dropdown.on_update
503
+ def _(_) -> None:
504
+ is_sphere = self.glb_mode_dropdown.value == "Spheres"
505
+ self.glb_sphere_radius_slider.disabled = not is_sphere
506
+ self.glb_max_sphere_pts_slider.disabled = not is_sphere
507
+
508
+ @self.glb_export_button.on_click
509
+ def _(_) -> None:
510
+ self._export_glb()
511
+
512
+ # Video saving controls
513
+ with self.server.gui.add_folder("Video Saving"):
514
+ self.save_video_button = self.server.gui.add_button("Save Video", disabled=False)
515
+ self.video_output_path = self.server.gui.add_text("Output Path", initial_value="output_pointcloud.mp4")
516
+ self.video_save_fps = self.server.gui.add_slider("Video FPS", min=10, max=60, step=1, initial_value=30)
517
+ self.video_resolution = self.server.gui.add_dropdown(
518
+ "Resolution", options=["1920x1080", "1280x720", "3840x2160"], initial_value="1920x1080"
519
+ )
520
+ self.save_original_video_checkbox = self.server.gui.add_checkbox("Also Save Original Video", initial_value=True)
521
+ self.video_status = self.server.gui.add_text("Status", initial_value="Ready to save")
522
+
523
+ @self.save_video_button.on_click
524
+ def _(_) -> None:
525
+ self.save_video(
526
+ output_path=self.video_output_path.value,
527
+ fps=self.video_save_fps.value,
528
+ resolution=self.video_resolution.value,
529
+ save_original_video=self.save_original_video_checkbox.value
530
+ )
531
+
532
+ @self.show_video_checkbox.on_update
533
+ def _(_) -> None:
534
+ if self.current_frame_image is not None:
535
+ self.current_frame_image.visible = self.show_video_checkbox.value
536
+
537
+ self.pc_handles = []
538
+ self.cam_handles = []
539
+
540
+ @self.psize_slider.on_update
541
+ def _(_) -> None:
542
+ for handle in self.pc_handles:
543
+ handle.point_size = self.psize_slider.value
544
+
545
+ @self.camsize_slider.on_update
546
+ def _(_) -> None:
547
+ for handle in self.cam_handles:
548
+ handle.scale = self.camsize_slider.value
549
+ handle.line_thickness = 0.03 * handle.scale
550
+
551
+ @self.downsample_slider.on_update
552
+ def _(_) -> None:
553
+ self._regenerate_point_clouds()
554
+
555
+ @self.show_camera_checkbox.on_update
556
+ def _(_) -> None:
557
+ self.show_camera = self.show_camera_checkbox.value
558
+ if self.show_camera:
559
+ self._regenerate_cameras()
560
+ else:
561
+ for handle in self.cam_handles:
562
+ handle.visible = False
563
+
564
+ @self.vis_threshold_slider.on_update
565
+ def _(_) -> None:
566
+ self.vis_threshold = self.vis_threshold_slider.value
567
+ self._regenerate_point_clouds()
568
+
569
+ @self.camera_downsample_slider.on_update
570
+ def _(_) -> None:
571
+ self._regenerate_cameras()
572
+
573
+ def _regenerate_point_clouds(self):
574
+ """Regenerate all point clouds with current settings."""
575
+ if not hasattr(self, 'frame_nodes'):
576
+ return
577
+
578
+ for handle in self.pc_handles:
579
+ try:
580
+ handle.remove()
581
+ except (KeyError, AttributeError):
582
+ pass
583
+ self.pc_handles.clear()
584
+ self.vis_pts_list.clear()
585
+
586
+ for i, step in enumerate(self.all_steps):
587
+ pc = self.pcs[step]["pc"]
588
+ color = self.pcs[step]["color"]
589
+ conf = self.pcs[step]["conf"]
590
+ edge_color = self.pcs[step].get("edge_color", None)
591
+
592
+ pred_pts, pc_color = self.parse_pc_data(
593
+ pc, color, conf, edge_color, set_border_color=True,
594
+ downsample_factor=self.downsample_slider.value
595
+ )
596
+
597
+ self.vis_pts_list.append(pred_pts)
598
+ handle = self.server.scene.add_point_cloud(
599
+ name=f"/frames/{step}/pred_pts",
600
+ points=pred_pts,
601
+ colors=pc_color,
602
+ point_size=self.psize_slider.value,
603
+ )
604
+ self.pc_handles.append(handle)
605
+
606
+ def _regenerate_cameras(self):
607
+ """Regenerate camera visualizations with current settings."""
608
+ if not hasattr(self, 'frame_nodes'):
609
+ return
610
+
611
+ for handle in self.cam_handles:
612
+ try:
613
+ handle.remove()
614
+ except (KeyError, AttributeError):
615
+ pass
616
+ self.cam_handles.clear()
617
+
618
+ if self.show_camera:
619
+ downsample_factor = int(self.camera_downsample_slider.value)
620
+ for i, step in enumerate(self.all_steps):
621
+ if i % downsample_factor == 0:
622
+ self.add_camera(step)
623
+
624
+ def _export_glb(self):
625
+ """Export current filtered point clouds and cameras as a GLB file."""
626
+ try:
627
+ import trimesh
628
+ except ImportError:
629
+ self.glb_status.value = "Error: pip install trimesh"
630
+ return
631
+
632
+ self.glb_status.value = "Collecting points..."
633
+ print("Exporting GLB...")
634
+
635
+ # Collect all currently visible, filtered points and colors
636
+ all_points = []
637
+ all_colors = []
638
+ for step in self.all_steps:
639
+ pc = self.pcs[step]["pc"]
640
+ color = self.pcs[step]["color"]
641
+ conf = self.pcs[step]["conf"]
642
+ edge_color = self.pcs[step].get("edge_color", None)
643
+
644
+ pts, cols = self.parse_pc_data(
645
+ pc, color, conf, edge_color, set_border_color=False,
646
+ downsample_factor=self.downsample_slider.value,
647
+ )
648
+ if len(pts) > 0:
649
+ all_points.append(pts)
650
+ if cols.dtype != np.uint8:
651
+ cols = (np.clip(cols, 0, 1) * 255).astype(np.uint8)
652
+ all_colors.append(cols)
653
+
654
+ if not all_points:
655
+ self.glb_status.value = "Error: no points to export"
656
+ return
657
+
658
+ vertices = np.concatenate(all_points, axis=0)
659
+ colors_rgb = np.concatenate(all_colors, axis=0)
660
+
661
+ # --- Color enhancement ---
662
+ colors_float = colors_rgb.astype(np.float32) / 255.0
663
+
664
+ sat_boost = self.glb_saturation_slider.value
665
+ if sat_boost != 1.0:
666
+ gray = colors_float.mean(axis=1, keepdims=True)
667
+ colors_float = gray + sat_boost * (colors_float - gray)
668
+
669
+ bri_boost = self.glb_brightness_slider.value
670
+ if bri_boost != 1.0:
671
+ colors_float = colors_float * bri_boost
672
+
673
+ colors_float = np.clip(colors_float, 0.0, 1.0)
674
+
675
+ # --- Opacity ---
676
+ # Simulate opacity by blending colors toward white (works in all viewers).
677
+ # For Spheres mode, also set true alpha for viewers that support it.
678
+ alpha = self.glb_opacity_slider.value
679
+ if alpha < 1.0:
680
+ bg = np.ones_like(colors_float) # white background
681
+ colors_float = colors_float * alpha + bg * (1.0 - alpha)
682
+ colors_float = np.clip(colors_float, 0.0, 1.0)
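+ # e.g. alpha = 0.5 turns pure red (1, 0, 0) into (1.0, 0.5, 0.5): the standard
+ # composite-over-white formula c' = alpha * c + (1 - alpha) * 1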
683
+
684
+ colors_u8 = (colors_float * 255).astype(np.uint8)
685
+ colors_rgba = np.concatenate([
686
+ colors_u8,
687
+ np.full((len(colors_u8), 1), int(alpha * 255), dtype=np.uint8),
688
+ ], axis=1) # (N, 4)
689
+
690
+ # Compute scene scale for camera sizing
691
+ lo = np.percentile(vertices, 5, axis=0)
692
+ hi = np.percentile(vertices, 95, axis=0)
693
+ scene_scale = max(np.linalg.norm(hi - lo), 0.1)
694
+
695
+ scene_3d = trimesh.Scene()
696
+
697
+ # --- Export mode ---
698
+ export_mode = self.glb_mode_dropdown.value
699
+ if export_mode == "Spheres":
700
+ self.glb_status.value = "Building spheres..."
701
+ max_pts = int(self.glb_max_sphere_pts_slider.value)
702
+ radius = self.glb_sphere_radius_slider.value
703
+
704
+ # Subsample if too many points
705
+ if len(vertices) > max_pts:
706
+ idx = np.random.choice(len(vertices), max_pts, replace=False)
707
+ idx.sort()
708
+ vertices = vertices[idx]
709
+ colors_rgba = colors_rgba[idx]
710
+
711
+ sphere_template = trimesh.creation.icosphere(subdivisions=1, radius=radius)
712
+ n_verts_per = len(sphere_template.vertices)
713
+ n_faces_per = len(sphere_template.faces)
714
+
715
+ all_verts = np.empty((len(vertices) * n_verts_per, 3), dtype=np.float32)
716
+ all_faces = np.empty((len(vertices) * n_faces_per, 3), dtype=np.int64)
717
+ all_face_colors = np.empty((len(vertices) * n_faces_per, 4), dtype=np.uint8)
718
+
719
+ for i, (pt, rgba) in enumerate(zip(vertices, colors_rgba)):
720
+ v_off = i * n_verts_per
721
+ f_off = i * n_faces_per
722
+ all_verts[v_off:v_off + n_verts_per] = sphere_template.vertices + pt
723
+ all_faces[f_off:f_off + n_faces_per] = sphere_template.faces + v_off
724
+ all_face_colors[f_off:f_off + n_faces_per] = rgba
725
+
726
+ mesh = trimesh.Trimesh(vertices=all_verts, faces=all_faces)
727
+ mesh.visual.face_colors = all_face_colors
728
+ # Enable alpha blending in the glTF material for true transparency where a
+ # material exists; with per-face colors trimesh may use ColorVisuals, which has
+ # no material object, so guard the attribute access.
+ if alpha < 1.0 and getattr(mesh.visual, "material", None) is not None:
+ mesh.visual.material.alphaMode = 'BLEND'
731
+ scene_3d.add_geometry(mesh)
732
+ print(f"Spheres mode: {len(vertices):,} spheres, {len(all_faces):,} faces")
733
+ else:
734
+ # Points mode (GLB viewers ignore alpha on points, so use blended RGB)
735
+ scene_3d.add_geometry(trimesh.PointCloud(vertices=vertices, colors=colors_u8))
736
+
737
+ # Add cameras and trajectory
738
+ if self.glb_show_cam_checkbox.value and self.cam_dict is not None:
739
+ from lingbot_map.vis.glb_export import integrate_camera_into_scene
740
+ import matplotlib
741
+ colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
742
+ num_cameras = len(self.all_steps)
743
+ cam_positions = []
744
+
745
+ frustum_thickness = self.glb_frustum_thickness_slider.value
746
+ effective_cam_scale = scene_scale * self.glb_cam_scale_slider.value
747
+
748
+ for i, step in enumerate(self.all_steps):
749
+ R = self.cam_dict["R"][step] if "R" in self.cam_dict else np.eye(3)
750
+ t = self.cam_dict["t"][step] if "t" in self.cam_dict else np.zeros(3)
751
+
752
+ c2w = np.eye(4)
753
+ c2w[:3, :3] = R
754
+ c2w[:3, 3] = t
755
+ cam_positions.append(np.array(t, dtype=np.float64))
756
+
757
+ rgba_c = colormap(i / max(num_cameras - 1, 1))
758
+ cam_color = tuple(int(255 * x) for x in rgba_c[:3])
759
+ integrate_camera_into_scene(
760
+ scene_3d, c2w, cam_color,
761
+ effective_cam_scale,
762
+ frustum_thickness=frustum_thickness,
763
+ )
764
+
765
+ # Add trajectory line as a tube connecting camera positions
766
+ if self.glb_trajectory_checkbox.value and len(cam_positions) >= 2:
767
+ traj_pts = np.array(cam_positions)
768
+ traj_radius = self.glb_trajectory_radius_slider.value * self.glb_cam_scale_slider.value
769
+ traj_mesh = self._build_trajectory_tube(
770
+ traj_pts, traj_radius, colormap, num_cameras
771
+ )
772
+ if traj_mesh is not None:
773
+ scene_3d.add_geometry(traj_mesh)
774
+
775
+ # Align scene using first camera extrinsic
776
+ if self.cam_dict is not None and len(self.all_steps) > 0:
777
+ from lingbot_map.vis.glb_export import apply_scene_alignment
778
+ step0 = self.all_steps[0]
779
+ R0 = self.cam_dict["R"][step0] if "R" in self.cam_dict else np.eye(3)
780
+ t0 = self.cam_dict["t"][step0] if "t" in self.cam_dict else np.zeros(3)
781
+ c2w_0 = np.eye(4)
782
+ c2w_0[:3, :3] = R0
783
+ c2w_0[:3, 3] = t0
784
+ w2c_0 = np.linalg.inv(c2w_0)
785
+ extrinsics = np.expand_dims(w2c_0, 0)
786
+ scene_3d = apply_scene_alignment(scene_3d, extrinsics)
787
+
788
+ output_path = self.glb_output_path.value
789
+ scene_3d.export(output_path)
790
+
791
+ n_pts = len(vertices)
792
+ mode_str = f"spheres r={self.glb_sphere_radius_slider.value}" if export_mode == "Spheres" else "points"
793
+ self.glb_status.value = f"Saved: {output_path} ({n_pts:,} {mode_str})"
794
+ print(f"GLB exported to {output_path} ({n_pts:,} {mode_str})")
795
+
796
+ @staticmethod
797
+ def _build_trajectory_tube(positions, radius, colormap, num_cameras):
798
+ """Build a tube mesh following camera trajectory with per-segment color.
799
+
800
+ Args:
801
+ positions: (N, 3) camera positions.
802
+ radius: Tube radius.
803
+ colormap: Matplotlib colormap for gradient coloring.
804
+ num_cameras: Total number of cameras (for color normalization).
805
+
806
+ Returns:
807
+ trimesh.Trimesh or None.
808
+ """
809
+ import trimesh
810
+
811
+ segments = []
812
+ for i in range(len(positions) - 1):
813
+ p0, p1 = positions[i], positions[i + 1]
814
+ seg_len = np.linalg.norm(p1 - p0)
815
+ if seg_len < 1e-8:
816
+ continue
817
+
818
+ # Create cylinder along Z, then transform
819
+ cyl = trimesh.creation.cylinder(radius=radius, height=seg_len, sections=8)
820
+
821
+ # Direction vector
822
+ direction = (p1 - p0) / seg_len
823
+ mid = (p0 + p1) / 2.0
824
+
825
+ # Build rotation: default cylinder is along Z
826
+ z_axis = np.array([0.0, 0.0, 1.0])
827
+ v = np.cross(z_axis, direction)
828
+ c = np.dot(z_axis, direction)
829
+
830
+ if np.linalg.norm(v) < 1e-8:
831
+ rot = np.eye(3) if c > 0 else np.diag([1, -1, -1])
832
+ else:
833
+ vx = np.array([[0, -v[2], v[1]],
834
+ [v[2], 0, -v[0]],
835
+ [-v[1], v[0], 0]])
836
+ rot = np.eye(3) + vx + vx @ vx / (1.0 + c)
837
+
838
+ transform = np.eye(4)
839
+ transform[:3, :3] = rot
840
+ transform[:3, 3] = mid
841
+ cyl.apply_transform(transform)
842
+
843
+ # Color: midpoint index
844
+ t_color = (i + 0.5) / max(num_cameras - 1, 1)
845
+ rgba = colormap(t_color)
846
+ color_rgb = tuple(int(255 * x) for x in rgba[:3])
847
+ cyl.visual.face_colors[:, :3] = color_rgb
848
+ segments.append(cyl)
849
+
850
+ if not segments:
851
+ return None
852
+ return trimesh.util.concatenate(segments)
853
+
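The rotation assembled above is the standard Rodrigues-style construction R = I + [v]x + [v]x^2 / (1 + c) with v = z x d and c = z . d, which maps the cylinder's +Z axis onto the segment direction d. A quick numeric check with an illustrative direction:

import numpy as np

d = np.array([1.0, 0.0, 0.0])                        # map +Z onto +X
z = np.array([0.0, 0.0, 1.0])
v, c = np.cross(z, d), np.dot(z, d)
vx = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
R = np.eye(3) + vx + vx @ vx / (1.0 + c)
# R @ z -> [1., 0., 0.]   (the cylinder axis now points along the segment)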
854
+ def update_frame_visibility(self):
855
+ """Show all frames up to the current timestep (or only the current one in 4D mode)."""
856
+ if not hasattr(self, 'frame_nodes') or not hasattr(self, 'gui_timestep'):
857
+ return
858
+
859
+ current_timestep = self.gui_timestep.value
860
+ for i, frame_node in enumerate(self.frame_nodes):
861
+ frame_node.visible = (
862
+ i <= current_timestep if not self.fourd else i == current_timestep
863
+ )
864
+
865
+ def _move_to_camera(self, frame_idx: int, smooth: bool = True):
866
+ """Move viewer camera to match reconstructed camera at given frame."""
867
+ if self.cam_dict is None:
868
+ return
869
+
870
+ step = self.all_steps[frame_idx] if frame_idx < len(self.all_steps) else self.all_steps[-1]
871
+
872
+ R = self.cam_dict["R"][step] if "R" in self.cam_dict else np.eye(3)
873
+ t = self.cam_dict["t"][step] if "t" in self.cam_dict else np.zeros(3)
874
+ focal = self.cam_dict["focal"][step] if "focal" in self.cam_dict else 1.0
875
+ pp = self.cam_dict["pp"][step] if "pp" in self.cam_dict else (1.0, 1.0)
876
+
877
+ offset = 0.5
878
+ viewing_dir = R[:, 2] # camera Z axis in world frame
879
+ position = t - viewing_dir * offset
880
+ look_at = t + viewing_dir * 0.5 # look slightly ahead of camera
881
+
882
+ fov = 2 * np.arctan(pp[0] / focal)
883
+ up = -R[:, 1] # camera -Y axis in world frame
884
+
885
+ for client in self.server.get_clients().values():
886
+ if smooth:
887
+ self._smooth_camera_transition(
888
+ client,
889
+ target_position=position,
890
+ target_look_at=look_at,
891
+ target_up=up,
892
+ target_fov=fov,
893
+ duration=0.3,
894
+ )
895
+ else:
896
+ client.camera.up_direction = tuple(up)
897
+ client.camera.position = tuple(position)
898
+ client.camera.look_at = tuple(look_at)
899
+ if fov is not None:
900
+ client.camera.fov = fov
901
+
902
+ def _smooth_camera_transition(
903
+ self,
904
+ client,
905
+ target_position,
906
+ target_look_at=None,
907
+ target_up=None,
908
+ target_fov=None,
909
+ duration=0.3,
910
+ ):
911
+ """Smoothly transition camera to target pose using look_at based control.
912
+
913
+ Args:
914
+ client: Viser client handle.
915
+ target_position: Target camera position (3,).
916
+ target_look_at: Target look-at point (3,). If None, keeps current.
917
+ target_up: Target up direction (3,). If None, keeps current.
918
+ target_fov: Target FOV. If None, keeps current.
919
+ duration: Transition duration in seconds.
920
+ """
921
+ def interpolate():
922
+ num_steps = 15
923
+ dt = duration / num_steps
924
+
925
+ start_position = np.array(client.camera.position, dtype=np.float64)
926
+ start_look_at = np.array(client.camera.look_at, dtype=np.float64)
927
+ start_fov = client.camera.fov
928
+
929
+ end_position = np.asarray(target_position, dtype=np.float64)
930
+ end_look_at = np.asarray(target_look_at, dtype=np.float64) if target_look_at is not None else start_look_at
931
+
932
+ # Set up direction once at the start (not interpolated to avoid flicker)
933
+ if target_up is not None:
934
+ client.camera.up_direction = tuple(np.asarray(target_up, dtype=np.float64))
935
+
936
+ for i in range(num_steps + 1):
937
+ alpha = i / num_steps
938
+ # Smooth ease-in-out (smoothstep: 3a^2 - 2a^3, zero velocity at both endpoints)
939
+ alpha_smooth = alpha * alpha * (3 - 2 * alpha)
940
+
941
+ interp_pos = start_position + (end_position - start_position) * alpha_smooth
942
+ interp_look = start_look_at + (end_look_at - start_look_at) * alpha_smooth
943
+
944
+ # Set position first (this auto-moves look_at), then override look_at
945
+ client.camera.position = tuple(interp_pos)
946
+ client.camera.look_at = tuple(interp_look)
947
+
948
+ if target_fov is not None:
949
+ interp_fov = start_fov + (target_fov - start_fov) * alpha_smooth
950
+ client.camera.fov = interp_fov
951
+
952
+ time.sleep(dt)
953
+
954
+ thread = threading.Thread(target=interpolate, daemon=True)
955
+ thread.start()
956
+
957
+ def _slerp(self, q1, q2, t):
958
+ """Spherical linear interpolation between quaternions."""
959
+ dot = np.dot(q1, q2)
960
+
961
+ if abs(dot) > 0.9995:
962
+ result = q1 + t * (q2 - q1)
963
+ return result / np.linalg.norm(result)
964
+
965
+ dot = np.clip(dot, -1.0, 1.0)
966
+ theta_0 = np.arccos(dot)
967
+ theta = theta_0 * t
968
+
969
+ q2_orthogonal = q2 - q1 * dot
970
+ q2_orthogonal = q2_orthogonal / np.linalg.norm(q2_orthogonal)
971
+
972
+ return q1 * np.cos(theta) + q2_orthogonal * np.sin(theta)
973
+
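A worked example for _slerp (wxyz quaternions, values chosen for illustration): halfway between the identity and a 90-degree rotation about z is a 45-degree rotation about z.

import numpy as np

q_identity = np.array([1.0, 0.0, 0.0, 0.0])                          # wxyz
q_z90 = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])   # 90 deg about z
# _slerp(q_identity, q_z90, 0.5) ~= [0.9239, 0, 0, 0.3827]           # 45 deg about z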
974
+ def get_camera_state(self, client: viser.ClientHandle) -> CameraState:
975
+ """Get current camera state from client."""
976
+ camera = client.camera
977
+ c2w = np.concatenate([
978
+ np.concatenate([tf.SO3(camera.wxyz).as_matrix(), camera.position[:, None]], 1),
979
+ [[0, 0, 0, 1]],
980
+ ], 0)
981
+ return CameraState(fov=camera.fov, aspect=camera.aspect, c2w=c2w)
982
+
983
+ @staticmethod
984
+ def generate_pseudo_intrinsics(h: int, w: int) -> np.ndarray:
985
+ """Generate pseudo intrinsics from image size."""
986
+ focal = (h**2 + w**2) ** 0.5
987
+ return np.array([[focal, 0, w // 2], [0, focal, h // 2], [0, 0, 1]]).astype(np.float32)
988
+
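For instance, a 512 x 512 image gets the image diagonal as its pseudo focal, sqrt(512^2 + 512^2) ~= 724.1 px, with the principal point at the image center:

K = PointCloudViewer.generate_pseudo_intrinsics(512, 512)
# K ~= [[724.1,   0.0, 256.0],
#       [  0.0, 724.1, 256.0],
#       [  0.0,   0.0,   1.0]]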
989
+ def _connect_client(self, client: viser.ClientHandle):
990
+ """Setup client connection callbacks."""
991
+ wxyz_panel = client.gui.add_text("wxyz:", f"{client.camera.wxyz}")
992
+ position_panel = client.gui.add_text("position:", f"{client.camera.position}")
993
+ fov_panel = client.gui.add_text(
994
+ "fov:", f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}"
995
+ )
996
+ aspect_panel = client.gui.add_text("aspect:", "1.0")
997
+
998
+ @client.camera.on_update
999
+ def _(_: viser.CameraHandle):
1000
+ with self.server.atomic():
1001
+ wxyz_panel.value = f"{client.camera.wxyz}"
1002
+ position_panel.value = f"{client.camera.position}"
1003
+ fov_panel.value = f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}"
1004
+ aspect_panel.value = "1.0"
1005
+
1006
+ @staticmethod
1007
+ def set_color_border(image, border_width=5, color=[1, 0, 0]):
1008
+ """Add colored border to image."""
1009
+ image[:border_width, :, 0] = color[0]
1010
+ image[:border_width, :, 1] = color[1]
1011
+ image[:border_width, :, 2] = color[2]
1012
+ image[-border_width:, :, 0] = color[0]
1013
+ image[-border_width:, :, 1] = color[1]
1014
+ image[-border_width:, :, 2] = color[2]
1015
+ image[:, :border_width, 0] = color[0]
1016
+ image[:, :border_width, 1] = color[1]
1017
+ image[:, :border_width, 2] = color[2]
1018
+ image[:, -border_width:, 0] = color[0]
1019
+ image[:, -border_width:, 1] = color[1]
1020
+ image[:, -border_width:, 2] = color[2]
1021
+ return image
1022
+
1023
+ def read_data(self, pc_list, color_list, conf_list, edge_color_list=None):
1024
+ """Read and organize point cloud data."""
1025
+ pcs = {}
1026
+ step_list = []
1027
+ for i, pc in enumerate(pc_list):
1028
+ step = i
1029
+ pcs.update({
1030
+ step: {
1031
+ "pc": pc,
1032
+ "color": color_list[i],
1033
+ "conf": conf_list[i],
1034
+ "edge_color": (
1035
+ None if edge_color_list is None or edge_color_list[i] is None
1036
+ else edge_color_list[i]
1037
+ ),
1038
+ }
1039
+ })
1040
+ step_list.append(step)
1041
+
1042
+ # Generate camera gradient colors
1043
+ num_cameras = len(pc_list)
1044
+ if num_cameras > 1:
1045
+ normalized_indices = np.array(list(range(num_cameras))) / (num_cameras - 1)
1046
+ else:
1047
+ normalized_indices = np.array([0.0])
1048
+ cmap = cm.get_cmap('viridis')
1049
+ self.camera_colors = cmap(normalized_indices)
1050
+ return pcs, step_list
1051
+
1052
+ def parse_pc_data(
1053
+ self,
1054
+ pc,
1055
+ color,
1056
+ conf=None,
1057
+ edge_color=[0.251, 0.702, 0.902],
1058
+ set_border_color=False,
1059
+ downsample_factor=1,
1060
+ ):
1061
+ """Parse and filter point cloud data."""
1062
+ pred_pts = pc.reshape(-1, 3)
1063
+
1064
+ if set_border_color and edge_color is not None:
1065
+ color = self.set_color_border(color[0], color=edge_color)
1066
+ if np.isnan(color).any():
1067
+ color = np.zeros((pred_pts.shape[0], 3))
1068
+ color[:, 2] = 1
1069
+ else:
1070
+ color = color.reshape(-1, 3)
1071
+
1072
+ # Remove NaN / Inf points
1073
+ valid = np.isfinite(pred_pts).all(axis=1)
1074
+ if not valid.all():
1075
+ pred_pts = pred_pts[valid]
1076
+ color = color[valid]
1077
+ if conf is not None:
1078
+ conf = conf.reshape(-1)[valid]
1079
+
1080
+ # Confidence threshold filter
1081
+ if conf is not None:
1082
+ conf_flat = conf.reshape(-1) if conf.ndim > 1 else conf
1083
+ mask = conf_flat > self.vis_threshold
1084
+ pred_pts = pred_pts[mask]
1085
+ color = color[mask]
1086
+
1087
+ if len(pred_pts) == 0:
1088
+ return pred_pts, color
1089
+
1090
+ # Downsample
1091
+ if downsample_factor > 1 and len(pred_pts) > 0:
1092
+ indices = np.arange(0, len(pred_pts), downsample_factor)
1093
+ pred_pts = pred_pts[indices]
1094
+ color = color[indices]
1095
+
1096
+ return pred_pts, color
1097
+
1098
+ def add_pc(self, step):
1099
+ """Add point cloud for a frame."""
1100
+ pc = self.pcs[step]["pc"]
1101
+ color = self.pcs[step]["color"]
1102
+ conf = self.pcs[step]["conf"]
1103
+ edge_color = self.pcs[step].get("edge_color", None)
1104
+
1105
+ pred_pts, color = self.parse_pc_data(
1106
+ pc, color, conf, edge_color, set_border_color=True,
1107
+ downsample_factor=self.downsample_slider.value
1108
+ )
1109
+
1110
+ self.vis_pts_list.append(pred_pts)
1111
+ self.pc_handles.append(
1112
+ self.server.scene.add_point_cloud(
1113
+ name=f"/frames/{step}/pred_pts",
1114
+ points=pred_pts,
1115
+ colors=color,
1116
+ point_size=self.psize_slider.value,
1117
+ )
1118
+ )
1119
+
1120
+ def add_camera(self, step):
1121
+ """Add camera visualization for a frame."""
1122
+ cam = self.cam_dict
1123
+ focal = cam["focal"][step] if cam and "focal" in cam else 1.0
1124
+ pp = cam["pp"][step] if cam and "pp" in cam else (1.0, 1.0)
1125
+ R = cam["R"][step] if cam and "R" in cam else np.eye(3)
1126
+ t = cam["t"][step] if cam and "t" in cam else np.zeros(3)
1127
+
1128
+ q = tf.SO3.from_matrix(R).wxyz
1129
+ fov = 2 * np.arctan(pp[0] / focal)
1130
+ aspect = pp[0] / pp[1]
1131
+ self.traj_list.append((q, t))
1132
+
1133
+ step_index = self.all_steps.index(step) if step in self.all_steps else 0
1134
+ camera_color = self.camera_colors[step_index]
1135
+ camera_color_rgb = tuple((camera_color[:3] * 255).astype(int))
1136
+
1137
+ self.server.scene.add_frame(
1138
+ f"/frames/{step}/camera_frame",
1139
+ wxyz=q,
1140
+ position=t,
1141
+ axes_length=0.05,
1142
+ axes_radius=0.002,
1143
+ origin_radius=0.002,
1144
+ )
1145
+
1146
+ frustum_handle = self.server.scene.add_camera_frustum(
1147
+ name=f"/frames/{step}/camera",
1148
+ fov=fov,
1149
+ aspect=aspect,
1150
+ wxyz=q,
1151
+ position=t,
1152
+ scale=0.03,
1153
+ color=camera_color_rgb,
1154
+ )
1155
+
1156
+ @frustum_handle.on_click
1157
+ def _(event) -> None:
1158
+ look_at_pt = t + R[:, 2] * 0.5 # look ahead along camera Z
1159
+ up_dir = -R[:, 1]
1160
+ for client in self.server.get_clients().values():
1161
+ client.camera.up_direction = tuple(up_dir)
1162
+ client.camera.position = tuple(t)
1163
+ client.camera.look_at = tuple(look_at_pt)
1164
+
1165
+ self.cam_handles.append(frustum_handle)
1166
+
1167
+ def animate(self):
1168
+ """Setup and run animation controls."""
1169
+ with self.server.gui.add_folder("Playback"):
1170
+ self.gui_timestep = self.server.gui.add_slider(
1171
+ "Train Step", min=0, max=self.num_frames - 1, step=1, initial_value=0, disabled=False
1172
+ )
1173
+ gui_next_frame = self.server.gui.add_button("Next Step", disabled=False)
1174
+ gui_prev_frame = self.server.gui.add_button("Prev Step", disabled=False)
1175
+ gui_playing = self.server.gui.add_checkbox("Playing", True)
1176
+ gui_framerate = self.server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=20)
1177
+ gui_framerate_options = self.server.gui.add_button_group("FPS options", ("10", "20", "30", "60"))
1178
+
1179
+ @gui_next_frame.on_click
1180
+ def _(_) -> None:
1181
+ self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames
1182
+
1183
+ @gui_prev_frame.on_click
1184
+ def _(_) -> None:
1185
+ self.gui_timestep.value = (self.gui_timestep.value - 1) % self.num_frames
1186
+
1187
+ @gui_playing.on_update
1188
+ def _(_) -> None:
1189
+ self.gui_timestep.disabled = gui_playing.value
1190
+ gui_next_frame.disabled = gui_playing.value
1191
+ gui_prev_frame.disabled = gui_playing.value
1192
+
1193
+ @gui_framerate_options.on_click
1194
+ def _(_) -> None:
1195
+ gui_framerate.value = int(gui_framerate_options.value)
1196
+
1197
+ prev_timestep = self.gui_timestep.value
1198
+
1199
+ @self.gui_timestep.on_update
1200
+ def _(_) -> None:
1201
+ nonlocal prev_timestep
1202
+ current_timestep = self.gui_timestep.value
1203
+
1204
+ if self.current_frame_image is not None and hasattr(self, 'original_images'):
1205
+ if current_timestep < len(self.original_images):
1206
+ self.current_frame_image.image = self.original_images[current_timestep]
1207
+
1208
+ with self.server.atomic():
1209
+ self.frame_nodes[current_timestep].visible = True
1210
+ self.frame_nodes[prev_timestep].visible = False
1211
+ self.server.flush()
1212
+
1213
+ prev_timestep = current_timestep
1214
+
1215
+ self.server.scene.add_frame("/frames", show_axes=False)
1216
+ self.frame_nodes = []
1217
+ for i in range(self.num_frames):
1218
+ step = self.all_steps[i]
1219
+ self.frame_nodes.append(
1220
+ self.server.scene.add_frame(f"/frames/{step}", show_axes=False)
1221
+ )
1222
+ self.add_pc(step)
1223
+ if self.show_camera:
1224
+ downsample_factor = int(self.camera_downsample_slider.value)
1225
+ if i % downsample_factor == 0:
1226
+ self.add_camera(step)
1227
+
1228
+ prev_timestep = self.gui_timestep.value
1229
+ while True:
1230
+ if self.on_replay:
1231
+ pass
1232
+ else:
1233
+ if gui_playing.value:
1234
+ self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames
1235
+ self.update_frame_visibility()
1236
+
1237
+ time.sleep(1.0 / gui_framerate.value)
1238
+
1239
+ def _take_screenshot(self, client: Optional[Any] = None):
1240
+ """Capture a screenshot from the current view and save to file.
1241
+
1242
+ Args:
1243
+ client: The viser client that triggered the action. If None,
1244
+ uses the first connected client.
1245
+ """
1246
+ output_path = self.screenshot_path.value
1247
+ res_str = self.screenshot_resolution.value
1248
+
1249
+ # Resolve client
1250
+ if client is None:
1251
+ clients = list(self.server.get_clients().values())
1252
+ if not clients:
1253
+ self.screenshot_status.value = "Error: no client connected"
1254
+ return
1255
+ client = clients[0]
1256
+
1257
+ try:
1258
+ self.screenshot_status.value = "Capturing..."
1259
+
1260
+ if res_str == "Current":
1261
+ # Use default render size
1262
+ width, height = 1920, 1080
1263
+ else:
1264
+ width, height = map(int, res_str.split("x"))
1265
+
1266
+ render = client.camera.get_render(height=height, width=width)
1267
+
1268
+ if render is not None:
1269
+ frame = np.array(render)
1270
+ if frame.shape[2] == 4:
1271
+ frame = frame[:, :, :3]
1272
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
1273
+ cv2.imwrite(output_path, frame_bgr)
1274
+ self.screenshot_status.value = f"Saved: {output_path}"
1275
+ print(f"Screenshot saved to {output_path} ({width}x{height})")
1276
+ else:
1277
+ self.screenshot_status.value = "Error: render returned None"
1278
+ print("Screenshot failed: render returned None")
1279
+
1280
+ except Exception as e:
1281
+ self.screenshot_status.value = f"Error: {e}"
1282
+ print(f"Screenshot error: {e}")
1283
+
1284
+ def save_video(
1285
+ self,
1286
+ output_path: str = "output_pointcloud.mp4",
1287
+ fps: int = 30,
1288
+ resolution: str = "1920x1080",
1289
+ save_original_video: bool = True
1290
+ ):
1291
+ """Save point cloud animation as video."""
1292
+ try:
1293
+ if hasattr(self, 'video_status'):
1294
+ self.video_status.value = "Saving video..."
1295
+ print(f"Saving video to {output_path}...")
1296
+
1297
+ width, height = map(int, resolution.split('x'))
1298
+ temp_dir = tempfile.mkdtemp(prefix="viser_video_")
1299
+ print(f"Temporary directory: {temp_dir}")
1300
+
1301
+ print("Waiting for client connection...")
1302
+ timeout = 10
1303
+ start_time = time.time()
1304
+ while len(self.server.get_clients()) == 0:
1305
+ time.sleep(0.1)
1306
+ if time.time() - start_time > timeout:
1307
+ raise RuntimeError("No client connected. Please open the visualization in a browser first.")
1308
+
1309
+ print("Client connected. Starting to render frames...")
1310
+ clients = list(self.server.get_clients().values())
1311
+ client = clients[0]
1312
+
1313
+ if not hasattr(self, 'gui_timestep'):
1314
+ raise RuntimeError("Animation not initialized. Please ensure animate() is called before save_video().")
1315
+
1316
+ for i in tqdm(range(self.num_frames), desc="Rendering frames"):
1317
+ self.gui_timestep.value = i
1318
+ time.sleep(0.1)
1319
+
1320
+ try:
1321
+ screenshot = client.camera.get_render(height=height, width=width)
1322
+ if screenshot is not None:
1323
+ frame = np.array(screenshot)
1324
+ if frame.shape[2] == 4:
1325
+ frame = frame[:, :, :3]
1326
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
1327
+ frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
1328
+ cv2.imwrite(frame_path, frame)
1329
+ else:
1330
+ frame = self._render_frame_fallback(i, width, height)
1331
+ frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
1332
+ cv2.imwrite(frame_path, frame)
1333
+ except Exception as e:
1334
+ print(f"Warning: Error capturing frame {i}: {e}, using fallback")
1335
+ frame = self._render_frame_fallback(i, width, height)
1336
+ frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
1337
+ cv2.imwrite(frame_path, frame)
1338
+
1339
+ print("Encoding video with ffmpeg...")
1340
+ ffmpeg_cmd = [
1341
+ 'ffmpeg', '-y', '-framerate', str(fps),
1342
+ '-i', os.path.join(temp_dir, 'frame_%06d.png'),
1343
+ '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18',
1344
+ output_path
1345
+ ]
1346
+
1347
+ result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
1348
+
1349
+ if result.returncode == 0:
1350
+ print(f"Point cloud video saved successfully to {output_path}")
1351
+ if hasattr(self, 'video_status'):
1352
+ self.video_status.value = f"Saved to {output_path}"
1353
+ else:
1354
+ print(f"FFmpeg error: {result.stderr}")
1355
+ if hasattr(self, 'video_status'):
1356
+ self.video_status.value = "Error: FFmpeg failed"
1357
+
1358
+ if save_original_video and hasattr(self, 'original_images') and len(self.original_images) > 0:
1359
+ self._save_original_video(output_path, fps, width, height)
1360
+
1361
+ shutil.rmtree(temp_dir)
1362
+ print("Temporary files cleaned up")
1363
+
1364
+ except Exception as e:
1365
+ print(f"Error saving video: {e}")
1366
+ import traceback
1367
+ traceback.print_exc()
1368
+ if hasattr(self, 'video_status'):
1369
+ self.video_status.value = f"Error: {str(e)}"
1370
+
1371
+ def _save_original_video(self, pointcloud_video_path: str, fps: int, width: int, height: int):
1372
+ """Save original images as video."""
1373
+ base_path = os.path.splitext(pointcloud_video_path)[0]
1374
+ original_video_path = f"{base_path}_original.mp4"
1375
+
1376
+ print(f"Saving original images video to {original_video_path}...")
1377
+
1378
+ try:
1379
+ temp_dir = tempfile.mkdtemp(prefix="original_video_")
1380
+
1381
+ for i, img in enumerate(tqdm(self.original_images, desc="Saving original frames")):
1382
+ frame = cv2.resize(img, (width, height))
1383
+ if len(frame.shape) == 3 and frame.shape[2] == 3:
1384
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
1385
+ frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
1386
+ cv2.imwrite(frame_path, frame)
1387
+
1388
+ print("Encoding original video with ffmpeg...")
1389
+ ffmpeg_cmd = [
1390
+ 'ffmpeg', '-y', '-framerate', str(fps),
1391
+ '-i', os.path.join(temp_dir, 'frame_%06d.png'),
1392
+ '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18',
1393
+ original_video_path
1394
+ ]
1395
+
1396
+ result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
1397
+
1398
+ if result.returncode == 0:
1399
+ print(f"Original video saved successfully to {original_video_path}")
1400
+ else:
1401
+ print(f"FFmpeg error for original video: {result.stderr}")
1402
+
1403
+ shutil.rmtree(temp_dir)
1404
+
1405
+ except Exception as e:
1406
+ print(f"Error saving original video: {e}")
1407
+ import traceback
1408
+ traceback.print_exc()
1409
+
1410
+ def _render_frame_fallback(self, frame_idx: int, width: int, height: int) -> np.ndarray:
1411
+ """Fallback rendering when screenshot capture fails."""
1412
+ if hasattr(self, 'original_images') and frame_idx < len(self.original_images):
1413
+ frame = self.original_images[frame_idx].copy()
1414
+ frame = cv2.resize(frame, (width, height))
1415
+ cv2.putText(frame, f"Frame {frame_idx}", (10, 30),
1416
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
1417
+ return frame
1418
+ else:
1419
+ frame = np.zeros((height, width, 3), dtype=np.uint8)
1420
+ cv2.putText(frame, f"Frame {frame_idx} - No render available",
1421
+ (width//4, height//2),
1422
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
1423
+ return frame
1424
+
1425
+ def run(self, background_mode: bool = False):
1426
+ """Run the viewer."""
1427
+ self.animate()
1428
+ if background_mode:
1429
+ def server_loop():
1430
+ while True:
1431
+ time.sleep(0.001)
1432
+
1433
+ thread = threading.Thread(target=server_loop, daemon=True)
1434
+ thread.start()
1435
+ else:
1436
+ while True:
1437
+ time.sleep(10.0)
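
The fly-to transition and `_slerp` above boil down to a cubic ease-in-out curve plus quaternion spherical interpolation. A minimal standalone sketch of both, useful for checking the interpolation behaviour outside the viewer (the function names here are illustrative, not part of the module):

```python
import numpy as np

def smoothstep(alpha: float) -> float:
    # Cubic ease-in-out used for the camera transition: zero slope at both ends.
    return alpha * alpha * (3 - 2 * alpha)

def slerp(q1: np.ndarray, q2: np.ndarray, t: float) -> np.ndarray:
    # Spherical linear interpolation between unit quaternions (wxyz order assumed).
    dot = np.dot(q1, q2)
    if abs(dot) > 0.9995:
        result = q1 + t * (q2 - q1)  # nearly parallel: plain lerp, then renormalize
        return result / np.linalg.norm(result)
    dot = np.clip(dot, -1.0, 1.0)
    theta = np.arccos(dot) * t
    q2_orth = q2 - q1 * dot
    q2_orth = q2_orth / np.linalg.norm(q2_orth)
    return q1 * np.cos(theta) + q2_orth * np.sin(theta)

identity = np.array([1.0, 0.0, 0.0, 0.0])
quarter_turn_x = np.array([np.cos(np.pi / 4), np.sin(np.pi / 4), 0.0, 0.0])  # 90 deg about x
for i in range(4):
    a = smoothstep(i / 3)
    print(round(a, 3), np.round(slerp(identity, quarter_turn_x, a), 3))
```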
lingbot_map/vis/sky_segmentation.py ADDED
@@ -0,0 +1,457 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Sky segmentation utilities for filtering sky points from point clouds.
9
+ """
10
+
11
+ import glob
12
+ import os
13
+ from typing import Optional, Tuple
14
+
15
+ import numpy as np
16
+ import cv2
17
+ from tqdm.auto import tqdm
18
+
19
+ try:
20
+ import onnxruntime
21
+ except ImportError:
22
+ onnxruntime = None
23
+ print("onnxruntime not found. Sky segmentation may not work.")
24
+
25
+
26
+ _SKYSEG_INPUT_SIZE = (320, 320)
27
+ _SKYSEG_SOFT_THRESHOLD = 0.1
28
+ _SKYSEG_CACHE_VERSION = "imagenet_norm_softmap_inverted_v3"
29
+
30
+
31
+ def _get_cache_version_path(sky_mask_dir: str) -> str:
32
+ return os.path.join(sky_mask_dir, ".skyseg_cache_version")
33
+
34
+
35
+ def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> None:
36
+ """Ensure the sky mask cache directory exists and write the version stamp."""
37
+ if sky_mask_dir is None:
38
+ return
39
+ os.makedirs(sky_mask_dir, exist_ok=True)
40
+ version_path = _get_cache_version_path(sky_mask_dir)
41
+ if not os.path.exists(version_path):
42
+ with open(version_path, "w", encoding="utf-8") as f:
43
+ f.write(_SKYSEG_CACHE_VERSION)
44
+
45
+
46
+ def run_skyseg(
47
+ onnx_session,
48
+ input_size: Tuple[int, int],
49
+ image: np.ndarray,
50
+ ) -> np.ndarray:
51
+ """
52
+ Run ONNX sky segmentation on a BGR image and return an 8-bit score map.
53
+ """
54
+ resize_image = cv2.resize(image, dsize=(input_size[0], input_size[1]))
55
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB).astype(np.float32)
56
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
57
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
58
+ x = (x / 255.0 - mean) / std
59
+ x = x.transpose(2, 0, 1)
60
+ x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")
61
+
62
+ input_name = onnx_session.get_inputs()[0].name
63
+ output_name = onnx_session.get_outputs()[0].name
64
+ onnx_result = onnx_session.run([output_name], {input_name: x})
65
+
66
+ onnx_result = np.array(onnx_result).squeeze()
67
+ min_value = np.min(onnx_result)
68
+ max_value = np.max(onnx_result)
69
+ denom = max(max_value - min_value, 1e-8)
70
+ onnx_result = (onnx_result - min_value) / denom
71
+ onnx_result *= 255.0
72
+ return onnx_result.astype(np.uint8)
73
+
74
+
75
+ def _mask_to_float(mask: np.ndarray) -> np.ndarray:
76
+ mask = np.asarray(mask).astype(np.float32)
+ if mask.size == 0:
+ return mask
+ if mask.max() > 1.0:
+ # Raw skyseg maps and cached masks are stored as uint8 in [0, 255]; rescale to [0, 1].
+ mask = mask / 255.0
+ return np.clip(mask, 0.0, 1.0)
80
+
81
+
82
+ def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
83
+ mask = np.asarray(mask)
84
+ if mask.dtype == np.uint8:
85
+ return mask
86
+ mask = mask.astype(np.float32)
87
+ if mask.size > 0 and mask.max() <= 1.0:
88
+ mask = mask * 255.0
89
+ return np.clip(mask, 0.0, 255.0).astype(np.uint8)
90
+
91
+
92
+ def _result_map_to_non_sky_conf(result_map: np.ndarray) -> np.ndarray:
93
+ # The raw skyseg map is higher on sky and lower on non-sky.
94
+ return 1.0 - _mask_to_float(result_map)
95
+
96
+
97
+ def segment_sky_from_array(
98
+ image: np.ndarray,
99
+ skyseg_session,
100
+ target_h: int,
101
+ target_w: int
102
+ ) -> np.ndarray:
103
+ """
104
+ Segment sky from an image array using ONNX model.
105
+
106
+ Args:
107
+ image: Input image as numpy array (H, W, 3) or (3, H, W), values in [0, 1] or [0, 255]
108
+ skyseg_session: ONNX runtime inference session
109
+ target_h: Target output height
110
+ target_w: Target output width
111
+
112
+ Returns:
113
+ Continuous non-sky confidence map in [0, 1].
114
+ """
115
+ image_rgb = _image_to_rgb_uint8(image)
116
+ image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
117
+ result_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, image_bgr)
118
+ result_map = cv2.resize(result_map, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
119
+ return _result_map_to_non_sky_conf(result_map)
120
+
121
+
122
+ def segment_sky(
123
+ image_path: str,
124
+ skyseg_session,
125
+ output_path: Optional[str] = None
126
+ ) -> np.ndarray:
127
+ """
128
+ Segment sky from an image using ONNX model.
129
+
130
+ Args:
131
+ image_path: Path to the input image
132
+ skyseg_session: ONNX runtime inference session
133
+ output_path: Optional path to save the mask
134
+
135
+ Returns:
136
+ Continuous non-sky confidence map in [0, 1].
137
+ """
138
+ image = cv2.imread(image_path)
139
+ if image is None:
140
+ raise ValueError(f"Failed to read image: {image_path}")
141
+
142
+ result_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, image)
143
+ result_map = cv2.resize(result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
144
+ mask = _result_map_to_non_sky_conf(result_map)
145
+
146
+ if output_path is not None:
147
+ output_dir = os.path.dirname(output_path)
148
+ if output_dir:
149
+ os.makedirs(output_dir, exist_ok=True)
150
+ cv2.imwrite(output_path, _mask_to_uint8(mask))
151
+
152
+ return mask
153
+
154
+
155
+ def _list_image_files(image_folder: str) -> list[str]:
156
+ image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
157
+ image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
158
+ return [f for f in image_files if os.path.splitext(f.lower())[1] in image_extensions]
159
+
160
+
161
+ def _image_to_rgb_uint8(image: np.ndarray) -> np.ndarray:
162
+ if image.ndim == 3 and image.shape[0] == 3 and image.shape[-1] != 3:
163
+ image = image.transpose(1, 2, 0)
164
+
165
+ if image.ndim != 3 or image.shape[2] != 3:
166
+ raise ValueError(f"Expected image with shape (H, W, 3) or (3, H, W), got {image.shape}")
167
+
168
+ if image.dtype != np.uint8:
169
+ image = image.astype(np.float32)
170
+ if image.max() <= 1.0:
171
+ image = image * 255.0
172
+ image = np.clip(image, 0.0, 255.0).astype(np.uint8)
173
+
174
+ return image
175
+
176
+
177
+ def _get_mask_filename(image_paths: Optional[list[str]], index: int) -> str:
178
+ if image_paths is not None and index < len(image_paths):
179
+ return os.path.basename(image_paths[index])
180
+ return f"frame_{index:06d}.png"
181
+
182
+
183
+ def _save_sky_mask_visualization(
184
+ image: np.ndarray,
185
+ sky_mask: np.ndarray,
186
+ output_path: str,
187
+ ) -> None:
188
+ image_rgb = _image_to_rgb_uint8(image)
189
+ if sky_mask.shape[:2] != image_rgb.shape[:2]:
190
+ sky_mask = cv2.resize(
191
+ sky_mask,
192
+ (image_rgb.shape[1], image_rgb.shape[0]),
193
+ interpolation=cv2.INTER_NEAREST,
194
+ )
195
+
196
+ mask_uint8 = _mask_to_uint8(sky_mask)
197
+ mask_rgb = np.repeat(mask_uint8[..., None], 3, axis=2)
198
+ overlay = image_rgb.astype(np.float32).copy()
199
+ sky_pixels = _mask_to_float(sky_mask) <= _SKYSEG_SOFT_THRESHOLD
200
+ overlay[sky_pixels] = overlay[sky_pixels] * 0.35 + np.array([255, 64, 64], dtype=np.float32) * 0.65
201
+ overlay = np.clip(overlay, 0.0, 255.0).astype(np.uint8)
202
+
203
+ panel = np.concatenate([image_rgb, mask_rgb, overlay], axis=1)
204
+ output_dir = os.path.dirname(output_path)
205
+ if output_dir:
206
+ os.makedirs(output_dir, exist_ok=True)
207
+ cv2.imwrite(output_path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
208
+
209
+
210
+ def load_or_create_sky_masks(
211
+ image_folder: Optional[str] = None,
212
+ image_paths: Optional[list[str]] = None,
213
+ images: Optional[np.ndarray] = None,
214
+ skyseg_model_path: str = "skyseg.onnx",
215
+ sky_mask_dir: Optional[str] = None,
216
+ sky_mask_visualization_dir: Optional[str] = None,
217
+ target_shape: Optional[Tuple[int, int]] = None,
218
+ num_frames: Optional[int] = None,
219
+ ) -> Optional[np.ndarray]:
220
+ """
221
+ Load cached sky masks or generate them with the ONNX model.
222
+
223
+ Args:
224
+ image_folder: Folder containing input images.
225
+ image_paths: Optional explicit image file list, in the exact order to process.
226
+ images: Optional image array with shape (S, 3, H, W) or (S, H, W, 3).
227
+ skyseg_model_path: Path to the sky segmentation ONNX model.
228
+ sky_mask_dir: Optional directory for cached raw masks.
229
+ sky_mask_visualization_dir: Optional directory for side-by-side visualizations.
230
+ target_shape: Optional output mask shape (H, W) after resizing.
231
+ num_frames: Optional maximum number of frames to process.
232
+
233
+ Returns:
234
+ Sky masks with shape (S, H, W), or None if sky segmentation could not run.
235
+ """
236
+ if onnxruntime is None:
237
+ print("Warning: onnxruntime not available, skipping sky segmentation")
238
+ return None
239
+
240
+ if image_folder is None and image_paths is None and images is None:
241
+ print("Warning: Neither image_folder/image_paths nor images provided, skipping sky segmentation")
242
+ return None
243
+
244
+ if not os.path.exists(skyseg_model_path):
245
+ print(f"Sky segmentation model not found at {skyseg_model_path}, downloading...")
246
+ try:
247
+ download_skyseg_model(skyseg_model_path)
248
+ except Exception as e:
249
+ print(f"Warning: Failed to download sky segmentation model: {e}")
250
+ return None
251
+
252
+ skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
253
+ sky_masks = []
254
+
255
+ if sky_mask_visualization_dir is not None:
256
+ os.makedirs(sky_mask_visualization_dir, exist_ok=True)
257
+ print(f"Saving sky mask visualizations to {sky_mask_visualization_dir}")
258
+
259
+ if images is not None:
260
+ if image_paths is None and image_folder is not None:
261
+ image_paths = _list_image_files(image_folder)
262
+
263
+ num_images = images.shape[0]
264
+ if num_frames is not None:
265
+ num_images = min(num_images, num_frames)
266
+ if image_paths is not None:
267
+ image_paths = image_paths[:num_images]
268
+
269
+ if sky_mask_dir is None and image_folder is not None:
270
+ sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
271
+ _prepare_sky_mask_cache(sky_mask_dir)
272
+
273
+ print("Generating sky masks from image array...")
274
+ for i in tqdm(range(num_images)):
275
+ image_rgb = _image_to_rgb_uint8(images[i])
276
+ image_h, image_w = image_rgb.shape[:2]
277
+ image_name = _get_mask_filename(image_paths, i)
278
+ mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None
279
+
280
+ if mask_filepath is not None and os.path.exists(mask_filepath):
281
+ sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
282
+ if sky_mask is not None and sky_mask.shape[:2] == (image_h, image_w):
283
+ # Reuse cached mask
284
+ pass
285
+ else:
286
+ sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
287
+ cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
288
+ else:
289
+ sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
290
+ if mask_filepath is not None:
291
+ cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
292
+
293
+ if sky_mask_visualization_dir is not None:
294
+ _save_sky_mask_visualization(
295
+ image_rgb,
296
+ sky_mask,
297
+ os.path.join(sky_mask_visualization_dir, image_name),
298
+ )
299
+
300
+ if target_shape is not None and sky_mask.shape[:2] != target_shape:
301
+ sky_mask = cv2.resize(
302
+ sky_mask,
303
+ (target_shape[1], target_shape[0]),
304
+ interpolation=cv2.INTER_LINEAR,
305
+ )
306
+
307
+ sky_masks.append(_mask_to_float(sky_mask))
308
+
309
+ else:
310
+ if image_paths is None and image_folder is not None:
311
+ image_paths = _list_image_files(image_folder)
312
+
313
+ if images is None and image_paths is not None:
314
+ if len(image_paths) == 0:
315
+ print("Warning: No image files provided, skipping sky segmentation")
316
+ return None
317
+
318
+ if num_frames is not None:
319
+ image_paths = image_paths[:num_frames]
320
+
321
+ if sky_mask_dir is None:
322
+ if image_folder is None:
323
+ image_folder = os.path.dirname(image_paths[0])
324
+ sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
325
+ _prepare_sky_mask_cache(sky_mask_dir)
326
+
327
+ print("Generating sky masks from image files...")
328
+ for image_path in tqdm(image_paths):
329
+ image_name = os.path.basename(image_path)
330
+ mask_filepath = os.path.join(sky_mask_dir, image_name)
331
+
332
+ if os.path.exists(mask_filepath):
333
+ sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
334
+ if sky_mask is None:
335
+ print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
336
+ sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
337
+ else:
338
+ sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
339
+
340
+ if sky_mask is None:
341
+ print(f"Warning: Failed to produce sky mask for {image_path}, skipping frame")
342
+ continue
343
+
344
+ if sky_mask_visualization_dir is not None:
345
+ image_bgr = cv2.imread(image_path)
346
+ if image_bgr is not None:
347
+ image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
348
+ _save_sky_mask_visualization(
349
+ image_rgb,
350
+ sky_mask,
351
+ os.path.join(sky_mask_visualization_dir, image_name),
352
+ )
353
+
354
+ if target_shape is not None and sky_mask.shape[:2] != target_shape:
355
+ sky_mask = cv2.resize(
356
+ sky_mask,
357
+ (target_shape[1], target_shape[0]),
358
+ interpolation=cv2.INTER_LINEAR,
359
+ )
360
+
361
+ sky_masks.append(_mask_to_float(sky_mask))
362
+
363
+ if len(sky_masks) == 0:
364
+ print("Warning: No sky masks generated, skipping sky segmentation")
365
+ return None
366
+
367
+ try:
368
+ return np.stack(sky_masks, axis=0)
369
+ except ValueError:
370
+ return np.array(sky_masks, dtype=object)
371
+
372
+
373
+ def apply_sky_segmentation(
374
+ conf: np.ndarray,
375
+ image_folder: Optional[str] = None,
376
+ image_paths: Optional[list[str]] = None,
377
+ images: Optional[np.ndarray] = None,
378
+ skyseg_model_path: str = "skyseg.onnx",
379
+ sky_mask_dir: Optional[str] = None,
380
+ sky_mask_visualization_dir: Optional[str] = None,
381
+ ) -> np.ndarray:
382
+ """
383
+ Apply sky segmentation to confidence scores.
384
+
385
+ Args:
386
+ conf: Confidence scores with shape (S, H, W)
387
+ image_folder: Path to the folder containing input images (optional if images provided)
388
+ image_paths: Optional explicit image file list in processing order
389
+ images: Image array with shape (S, 3, H, W) or (S, H, W, 3) (optional if image_folder provided)
390
+ skyseg_model_path: Path to the sky segmentation ONNX model
391
+ sky_mask_dir: Optional directory for cached raw masks
392
+ sky_mask_visualization_dir: Optional directory for side-by-side mask visualization images
393
+
394
+ Returns:
395
+ Updated confidence scores with sky regions masked out
396
+ """
397
+ S, H, W = conf.shape
398
+
399
+ sky_mask_array = load_or_create_sky_masks(
400
+ image_folder=image_folder,
401
+ image_paths=image_paths,
402
+ images=images,
403
+ skyseg_model_path=skyseg_model_path,
404
+ sky_mask_dir=sky_mask_dir,
405
+ sky_mask_visualization_dir=sky_mask_visualization_dir,
406
+ target_shape=(H, W),
407
+ num_frames=S,
408
+ )
409
+ if sky_mask_array is None:
410
+ return conf
411
+
412
+ if sky_mask_array.shape[0] < S:
413
+ print(
414
+ f"Warning: Only {sky_mask_array.shape[0]} sky masks generated for {S} frames; "
415
+ "leaving the remaining frames unmasked"
416
+ )
417
+ padded = np.ones((S, H, W), dtype=sky_mask_array.dtype)  # ones keep the extra frames unmasked
418
+ padded[: sky_mask_array.shape[0]] = sky_mask_array
419
+ sky_mask_array = padded
420
+ elif sky_mask_array.shape[0] > S:
421
+ sky_mask_array = sky_mask_array[:S]
422
+
423
+ sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
424
+ conf = conf * sky_mask_binary
425
+
426
+ print("Sky segmentation applied successfully")
427
+ return conf
428
+
429
+
430
+ def download_skyseg_model(output_path: str = "skyseg.onnx") -> str:
431
+ """
432
+ Download sky segmentation model from HuggingFace.
433
+
434
+ Args:
435
+ output_path: Path to save the model
436
+
437
+ Returns:
438
+ Path to the downloaded model
439
+ """
440
+ import requests
441
+
442
+ url = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"
443
+
444
+ print(f"Downloading sky segmentation model from {url}...")
445
+ response = requests.get(url, stream=True)
446
+ response.raise_for_status()
447
+
448
+ total_size = int(response.headers.get('content-length', 0))
449
+
450
+ with open(output_path, 'wb') as f:
451
+ with tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") as pbar:
452
+ for chunk in response.iter_content(chunk_size=8192):
453
+ f.write(chunk)
454
+ pbar.update(len(chunk))
455
+
456
+ print(f"Model saved to {output_path}")
457
+ return output_path
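
`apply_sky_segmentation` gates point confidences with the continuous non-sky maps: anything at or below the soft threshold is treated as sky and zeroed. A minimal sketch of that gating step, assuming (S, H, W) arrays and the same 0.1 threshold as `_SKYSEG_SOFT_THRESHOLD`:

```python
import numpy as np

SOFT_THRESHOLD = 0.1  # mirrors _SKYSEG_SOFT_THRESHOLD above

def mask_sky_confidence(conf: np.ndarray, non_sky_conf: np.ndarray) -> np.ndarray:
    # conf and non_sky_conf have shape (S, H, W); pixels whose non-sky score is
    # at or below the soft threshold are treated as sky and their confidence is zeroed.
    keep = (non_sky_conf > SOFT_THRESHOLD).astype(np.float32)
    return conf * keep

conf = np.random.rand(2, 4, 6).astype(np.float32)
non_sky = np.ones((2, 4, 6), dtype=np.float32)
non_sky[:, 0, :] = 0.05  # pretend the top row of every frame is sky
print(mask_sky_confidence(conf, non_sky)[0, 0])  # all zeros
```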
lingbot_map/vis/utils.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Visualization utility functions for colorization and color bars.
9
+ """
10
+
11
+ import dataclasses
12
+ from typing import Optional, Tuple
13
+
14
+ import numpy as np
15
+ import torch
16
+ import cv2
17
+ import matplotlib.cm as cm
18
+
19
+
20
+ @dataclasses.dataclass
21
+ class CameraState:
22
+ """Camera state for rendering."""
23
+ fov: float
24
+ aspect: float
25
+ c2w: np.ndarray
26
+
27
+ def get_K(self, img_wh: Tuple[int, int]) -> np.ndarray:
28
+ """Get camera intrinsic matrix from FOV and image size."""
29
+ W, H = img_wh
30
+ focal_length = H / 2.0 / np.tan(self.fov / 2.0)
31
+ K = np.array([
32
+ [focal_length, 0.0, W / 2.0],
33
+ [0.0, focal_length, H / 2.0],
34
+ [0.0, 0.0, 1.0],
35
+ ])
36
+ return K
37
+
38
+
39
+ def get_vertical_colorbar(
40
+ h: int,
41
+ vmin: float,
42
+ vmax: float,
43
+ cmap_name: str = "jet",
44
+ label: Optional[str] = None,
45
+ cbar_precision: int = 2
46
+ ) -> np.ndarray:
47
+ """
48
+ Create a vertical colorbar image.
49
+
50
+ Args:
51
+ h: Height in pixels
52
+ vmin: Minimum value
53
+ vmax: Maximum value
54
+ cmap_name: Colormap name
55
+ label: Optional label for the colorbar
56
+ cbar_precision: Decimal precision for tick labels
57
+
58
+ Returns:
59
+ Colorbar image as numpy array (H, W, 3)
60
+ """
61
+ from matplotlib.figure import Figure
62
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
63
+ import matplotlib as mpl
64
+
65
+ fig = Figure(figsize=(2, 8), dpi=100)
66
+ fig.subplots_adjust(right=1.5)
67
+ canvas = FigureCanvasAgg(fig)
68
+
69
+ ax = fig.add_subplot(111)
70
+ cmap = cm.get_cmap(cmap_name)
71
+ norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
72
+
73
+ tick_cnt = 6
74
+ tick_loc = np.linspace(vmin, vmax, tick_cnt)
75
+ cb1 = mpl.colorbar.ColorbarBase(
76
+ ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
77
+ )
78
+
79
+ tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
80
+ if cbar_precision == 0:
81
+ tick_label = [x[:-2] for x in tick_label]
82
+
83
+ cb1.set_ticklabels(tick_label)
84
+ cb1.ax.tick_params(labelsize=18, rotation=0)
85
+ if label is not None:
86
+ cb1.set_label(label)
87
+
88
+ canvas.draw()
89
+ s, (width, height) = canvas.print_to_buffer()
90
+
91
+ im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
92
+ im = im[:, :, :3].astype(np.float32) / 255.0
93
+
94
+ if h != im.shape[0]:
95
+ w = int(im.shape[1] / im.shape[0] * h)
96
+ im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
97
+
98
+ return im
99
+
100
+
101
+ def colorize_np(
102
+ x: np.ndarray,
103
+ cmap_name: str = "jet",
104
+ mask: Optional[np.ndarray] = None,
105
+ range: Optional[Tuple[float, float]] = None,
106
+ append_cbar: bool = False,
107
+ cbar_in_image: bool = False,
108
+ cbar_precision: int = 2,
109
+ ) -> np.ndarray:
110
+ """
111
+ Turn a grayscale image into a color image.
112
+
113
+ Args:
114
+ x: Input grayscale image [H, W]
115
+ cmap_name: Colormap name
116
+ mask: Optional mask image [H, W]
117
+ range: Value range for scaling [min, max], automatic if None
118
+ append_cbar: Whether to append colorbar
119
+ cbar_in_image: Put colorbar inside image
120
+ cbar_precision: Colorbar tick precision
121
+
122
+ Returns:
123
+ Colorized image [H, W, 3]
124
+ """
125
+ if range is not None:
126
+ vmin, vmax = range
127
+ elif mask is not None:
128
+ vmin = np.min(x[mask][np.nonzero(x[mask])])
129
+ vmax = np.max(x[mask])
130
+ x[np.logical_not(mask)] = vmin
131
+ else:
132
+ vmin, vmax = np.percentile(x, (1, 100))
133
+ vmax += 1e-6
134
+
135
+ x = np.clip(x, vmin, vmax)
136
+ x = (x - vmin) / (vmax - vmin)
137
+
138
+ cmap = cm.get_cmap(cmap_name)
139
+ x_new = cmap(x)[:, :, :3]
140
+
141
+ if mask is not None:
142
+ mask = np.float32(mask[:, :, np.newaxis])
143
+ x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)
144
+
145
+ cbar = get_vertical_colorbar(
146
+ h=x.shape[0],
147
+ vmin=vmin,
148
+ vmax=vmax,
149
+ cmap_name=cmap_name,
150
+ cbar_precision=cbar_precision,
151
+ )
152
+
153
+ if append_cbar:
154
+ if cbar_in_image:
155
+ x_new[:, -cbar.shape[1]:, :] = cbar
156
+ else:
157
+ x_new = np.concatenate(
158
+ (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
159
+ )
160
+ return x_new
161
+ else:
162
+ return x_new
163
+
164
+
165
+ def colorize(
166
+ x: torch.Tensor,
167
+ cmap_name: str = "jet",
168
+ mask: Optional[torch.Tensor] = None,
169
+ range: Optional[Tuple[float, float]] = None,
170
+ append_cbar: bool = False,
171
+ cbar_in_image: bool = False
172
+ ) -> torch.Tensor:
173
+ """
174
+ Turn a grayscale image into a color image (PyTorch tensor version).
175
+
176
+ Args:
177
+ x: Grayscale image tensor [H, W] or [B, H, W]
178
+ cmap_name: Colormap name
179
+ mask: Optional mask tensor [H, W] or [B, H, W]
180
+ range: Value range for scaling
181
+ append_cbar: Whether to append colorbar
182
+ cbar_in_image: Put colorbar inside image
183
+
184
+ Returns:
185
+ Colorized tensor
186
+ """
187
+ device = x.device
188
+ x = x.cpu().numpy()
189
+ if mask is not None:
190
+ mask = mask.cpu().numpy() > 0.99
191
+ kernel = np.ones((3, 3), np.uint8)
192
+
193
+ if x.ndim == 2:
194
+ x = x[None]
195
+ if mask is not None:
196
+ mask = mask[None]
197
+
198
+ out = []
199
+ for x_ in x:
200
+ if mask is not None:
201
+ mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool)
202
+
203
+ x_ = colorize_np(x_, cmap_name, mask, range, append_cbar, cbar_in_image)
204
+ out.append(torch.from_numpy(x_).to(device).float())
205
+ out = torch.stack(out).squeeze(0)
206
+ return out
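
A short usage sketch for `colorize_np` on a synthetic depth map (the values and sizes are made up; the import path matches this repository layout). Note that `cm.get_cmap` was removed in matplotlib 3.9, so this module needs matplotlib older than 3.9 as written:

```python
import numpy as np
from lingbot_map.vis.utils import colorize_np

depth = np.linspace(0.5, 5.0, 64 * 64, dtype=np.float32).reshape(64, 64)
rgb = colorize_np(depth, cmap_name="jet", range=(0.5, 5.0), append_cbar=True)
print(rgb.shape)  # (64, 64 + gap + colorbar width, 3), float values in [0, 1]
```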
lingbot_map/vis/viser_wrapper.py ADDED
@@ -0,0 +1,248 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Quick visualization wrapper for GCT predictions using Viser.
9
+ """
10
+
11
+ import time
12
+ import threading
13
+ from typing import List, Optional
14
+
15
+ import numpy as np
16
+ import viser
17
+ import viser.transforms as tf
18
+ from tqdm.auto import tqdm
19
+
20
+ from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
21
+ from lingbot_map.vis.sky_segmentation import apply_sky_segmentation
22
+
23
+
24
+ def viser_wrapper(
25
+ pred_dict: dict,
26
+ port: int = 8080,
27
+ init_conf_threshold: float = 50.0,
28
+ use_point_map: bool = False,
29
+ background_mode: bool = False,
30
+ mask_sky: bool = False,
31
+ image_folder: Optional[str] = None,
32
+ ):
33
+ """
34
+ Visualize predicted 3D points and camera poses with viser.
35
+
36
+ This is a simplified wrapper for quick visualization without the full
37
+ PointCloudViewer controls.
38
+
39
+ Args:
40
+ pred_dict: Dictionary containing predictions with keys:
41
+ - images: (S, 3, H, W) - Input images
42
+ - world_points: (S, H, W, 3)
43
+ - world_points_conf: (S, H, W)
44
+ - depth: (S, H, W, 1)
45
+ - depth_conf: (S, H, W)
46
+ - extrinsic: (S, 3, 4)
47
+ - intrinsic: (S, 3, 3)
48
+ port: Port number for the viser server
49
+ init_conf_threshold: Initial percentage of low-confidence points to filter out
50
+ use_point_map: Whether to visualize world_points or use depth-based points
51
+ background_mode: Whether to run the server in background thread
52
+ mask_sky: Whether to apply sky segmentation to filter out sky points
53
+ image_folder: Path to the folder containing input images (for sky segmentation)
54
+
55
+ Returns:
56
+ viser.ViserServer: The viser server instance
57
+ """
58
+ print(f"Starting viser server on port {port}")
59
+
60
+ server = viser.ViserServer(host="0.0.0.0", port=port)
61
+ server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
62
+
63
+ # Unpack prediction dict
64
+ images = pred_dict["images"] # (S, 3, H, W)
65
+ world_points_map = pred_dict["world_points"] # (S, H, W, 3)
66
+ conf_map = pred_dict["world_points_conf"] # (S, H, W)
67
+
68
+ depth_map = pred_dict["depth"] # (S, H, W, 1)
69
+ depth_conf = pred_dict["depth_conf"] # (S, H, W)
70
+
71
+ extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4)
72
+ intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3)
73
+
74
+ # Compute world points from depth if not using the precomputed point map
75
+ if not use_point_map:
76
+ world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
77
+ conf = depth_conf
78
+ else:
79
+ world_points = world_points_map
80
+ conf = conf_map
81
+
82
+ # Apply sky segmentation if enabled
83
+ if mask_sky and image_folder is not None:
84
+ conf = apply_sky_segmentation(conf, image_folder)
85
+
86
+ # Convert images from (S, 3, H, W) to (S, H, W, 3)
87
+ colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3)
88
+ shape = world_points.shape
89
+ S: int = shape[0]
90
+ H: int = shape[1]
91
+ W: int = shape[2]
92
+
93
+ # Flatten
94
+ points = world_points.reshape(-1, 3)
95
+ colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
96
+ conf_flat = conf.reshape(-1)
97
+
98
+ # Random sample points if too many
99
+ indices = None
100
+ if points.shape[0] > 6000000:
101
+ print(f"Too many points ({points.shape[0]}), randomly sampling 6M points")
102
+ indices = np.random.choice(points.shape[0], size=6000000, replace=False)
103
+ points = points[indices]
104
+ colors_flat = colors_flat[indices]
105
+ conf_flat = conf_flat[indices]
106
+
107
+ cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
108
+ cam_to_world = cam_to_world_mat[:, :3, :]
109
+
110
+ # Compute scene center and recenter
111
+ scene_center = np.mean(points, axis=0)
112
+ points_centered = points - scene_center
113
+ cam_to_world[..., -1] -= scene_center
114
+
115
+ # Store frame indices for filtering
116
+ frame_indices = (
117
+ np.repeat(np.arange(S), H * W)[indices]
118
+ if indices is not None
119
+ else np.repeat(np.arange(S), H * W)
120
+ )
121
+
122
+ # Build the viser GUI
123
+ gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True)
124
+ gui_points_conf = server.gui.add_slider(
125
+ "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold
126
+ )
127
+ gui_frame_selector = server.gui.add_dropdown(
128
+ "Show Points from Frames",
129
+ options=["All"] + [str(i) for i in range(S)],
130
+ initial_value="All"
131
+ )
132
+
133
+ # Create the main point cloud
134
+ init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
135
+ init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1)
136
+ point_cloud = server.scene.add_point_cloud(
137
+ name="viser_pcd",
138
+ points=points_centered[init_conf_mask],
139
+ colors=colors_flat[init_conf_mask],
140
+ point_size=0.0005,
141
+ point_shape="circle",
142
+ )
143
+
144
+ frames: List[viser.FrameHandle] = []
145
+ frustums: List[viser.CameraFrustumHandle] = []
146
+
147
+ def visualize_frames(extrinsics, images_: np.ndarray) -> None:
148
+ """Add camera frames and frustums to the scene."""
149
+ for f in frames:
150
+ f.remove()
151
+ frames.clear()
152
+ for fr in frustums:
153
+ fr.remove()
154
+ frustums.clear()
155
+
156
+ def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
157
+ @frustum.on_click
158
+ def _(_) -> None:
159
+ for client in server.get_clients().values():
160
+ client.camera.wxyz = frame.wxyz
161
+ client.camera.position = frame.position
162
+
163
+ for img_id in tqdm(range(S)):
164
+ cam2world_3x4 = extrinsics[img_id]
165
+ T_world_camera = tf.SE3.from_matrix(cam2world_3x4)
166
+
167
+ frame_axis = server.scene.add_frame(
168
+ f"frame_{img_id}",
169
+ wxyz=T_world_camera.rotation().wxyz,
170
+ position=T_world_camera.translation(),
171
+ axes_length=0.05,
172
+ axes_radius=0.002,
173
+ origin_radius=0.002,
174
+ )
175
+ frames.append(frame_axis)
176
+
177
+ img = images_[img_id]
178
+ img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
179
+ h, w = img.shape[:2]
180
+
181
+ fy = 1.1 * h
182
+ fov = 2 * np.arctan2(h / 2, fy)
183
+
184
+ frustum_cam = server.scene.add_camera_frustum(
185
+ f"frame_{img_id}/frustum",
186
+ fov=fov,
187
+ aspect=w / h,
188
+ scale=0.05,
189
+ image=img,
190
+ line_width=1.0
191
+ )
192
+ frustums.append(frustum_cam)
193
+ attach_callback(frustum_cam, frame_axis)
194
+
195
+ def update_point_cloud() -> None:
196
+ """Update point cloud based on current GUI selections."""
197
+ current_percentage = gui_points_conf.value
198
+ threshold_val = np.percentile(conf_flat, current_percentage)
199
+ print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")
200
+
201
+ conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)
202
+
203
+ if gui_frame_selector.value == "All":
204
+ frame_mask = np.ones_like(conf_mask, dtype=bool)
205
+ else:
206
+ selected_idx = int(gui_frame_selector.value)
207
+ frame_mask = frame_indices == selected_idx
208
+
209
+ combined_mask = conf_mask & frame_mask
210
+ point_cloud.points = points_centered[combined_mask]
211
+ point_cloud.colors = colors_flat[combined_mask]
212
+
213
+ @gui_points_conf.on_update
214
+ def _(_) -> None:
215
+ update_point_cloud()
216
+
217
+ @gui_frame_selector.on_update
218
+ def _(_) -> None:
219
+ update_point_cloud()
220
+
221
+ @gui_show_frames.on_update
222
+ def _(_) -> None:
223
+ for f in frames:
224
+ f.visible = gui_show_frames.value
225
+ for fr in frustums:
226
+ fr.visible = gui_show_frames.value
227
+
228
+ # Add camera frames
229
+ import torch
230
+ if torch.is_tensor(cam_to_world):
231
+ cam_to_world_np = cam_to_world.cpu().numpy()
232
+ else:
233
+ cam_to_world_np = cam_to_world
234
+ visualize_frames(cam_to_world_np, images)
235
+
236
+ print("Starting viser server...")
237
+ if background_mode:
238
+ def server_loop():
239
+ while True:
240
+ time.sleep(0.001)
241
+
242
+ thread = threading.Thread(target=server_loop, daemon=True)
243
+ thread.start()
244
+ else:
245
+ while True:
246
+ time.sleep(0.01)
247
+
248
+ return server
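
The "Confidence Percent" slider above filters by percentile rather than by absolute confidence, so a value of 50 keeps roughly half of the points regardless of the confidence scale. A minimal sketch of that filter (the function name is illustrative):

```python
import numpy as np

def confidence_mask(conf_flat: np.ndarray, percent: float) -> np.ndarray:
    # Keep points at or above the `percent`-th percentile of confidence,
    # and always drop near-zero confidences (e.g. sky-masked points).
    threshold_val = np.percentile(conf_flat, percent)
    return (conf_flat >= threshold_val) & (conf_flat > 1e-5)

conf = np.random.rand(10_000).astype(np.float32)
mask = confidence_mask(conf, 50.0)
print(f"{mask.mean():.2f} of points kept")  # roughly 0.50
```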
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu128
2
+ torch==2.9.1
3
+ torchvision==0.24.1
4
+ gradio>=5.0,<6
5
+ spaces>=0.34.0
6
+ huggingface_hub>=0.30.0
7
+ einops>=0.8.0
8
+ safetensors>=0.5.0
9
+ opencv-python-headless>=4.10.0
10
+ tqdm>=4.66.0
11
+ scipy>=1.13.0
12
+ trimesh>=4.4.0
13
+ matplotlib>=3.8.0
14
+ Pillow>=10.0.0
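
After installing the pinned requirements, a quick import check (purely illustrative) confirms that the CUDA build of PyTorch and the Gradio/Spaces stack are present:

```python
import torch, torchvision, gradio, spaces  # noqa: F401

print("torch:", torch.__version__, "| torchvision:", torchvision.__version__)
print("gradio:", gradio.__version__, "| CUDA available:", torch.cuda.is_available())
```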