dxm21 commited on
Commit
4d798a0
·
verified ·
1 Parent(s): 035398f

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +88 -0
  2. README.md +7 -14
  3. dpm/aggregator.py +24 -11
  4. dpm/decoder.py +24 -2
  5. dpm/model.py +110 -11
  6. gradio_demo.py +144 -70
  7. test.ipynb +101 -0
  8. vggt-low-vram/.gitattributes +2 -0
  9. vggt-low-vram/.gitignore +155 -0
  10. vggt-low-vram/CODE_OF_CONDUCT.md +80 -0
  11. vggt-low-vram/CONTRIBUTING.md +31 -0
  12. vggt-low-vram/LICENSE.txt +115 -0
  13. vggt-low-vram/README.md +398 -0
  14. vggt-low-vram/benchmark/benchmark.py +64 -0
  15. vggt-low-vram/benchmark/benchmark_baseline.py +65 -0
  16. vggt-low-vram/benchmark/plot_recon.py +247 -0
  17. vggt-low-vram/benchmark/run_benchmark.bash +28 -0
  18. vggt-low-vram/demo_colmap.py +337 -0
  19. vggt-low-vram/demo_gradio.py +684 -0
  20. vggt-low-vram/demo_viser.py +400 -0
  21. vggt-low-vram/docs/package.md +45 -0
  22. vggt-low-vram/examples/kitchen/images/00.png +3 -0
  23. vggt-low-vram/examples/kitchen/images/01.png +3 -0
  24. vggt-low-vram/examples/kitchen/images/02.png +3 -0
  25. vggt-low-vram/examples/kitchen/images/03.png +3 -0
  26. vggt-low-vram/examples/kitchen/images/04.png +3 -0
  27. vggt-low-vram/examples/kitchen/images/05.png +3 -0
  28. vggt-low-vram/examples/kitchen/images/06.png +3 -0
  29. vggt-low-vram/examples/kitchen/images/07.png +3 -0
  30. vggt-low-vram/examples/kitchen/images/08.png +3 -0
  31. vggt-low-vram/examples/kitchen/images/09.png +3 -0
  32. vggt-low-vram/examples/kitchen/images/10.png +3 -0
  33. vggt-low-vram/examples/kitchen/images/11.png +3 -0
  34. vggt-low-vram/examples/kitchen/images/12.png +3 -0
  35. vggt-low-vram/examples/kitchen/images/13.png +3 -0
  36. vggt-low-vram/examples/kitchen/images/14.png +3 -0
  37. vggt-low-vram/examples/kitchen/images/15.png +3 -0
  38. vggt-low-vram/examples/kitchen/images/16.png +3 -0
  39. vggt-low-vram/examples/kitchen/images/17.png +3 -0
  40. vggt-low-vram/examples/kitchen/images/18.png +3 -0
  41. vggt-low-vram/examples/kitchen/images/19.png +3 -0
  42. vggt-low-vram/examples/kitchen/images/20.png +3 -0
  43. vggt-low-vram/examples/kitchen/images/21.png +3 -0
  44. vggt-low-vram/examples/kitchen/images/22.png +3 -0
  45. vggt-low-vram/examples/kitchen/images/23.png +3 -0
  46. vggt-low-vram/examples/kitchen/images/24.png +3 -0
  47. vggt-low-vram/examples/llff_fern/images/000.png +3 -0
  48. vggt-low-vram/examples/llff_fern/images/001.png +3 -0
  49. vggt-low-vram/examples/llff_fern/images/002.png +3 -0
  50. vggt-low-vram/examples/llff_fern/images/003.png +3 -0
.gitattributes CHANGED
@@ -127,3 +127,91 @@ input_images_20260127_235812_906867/images/000001.png filter=lfs diff=lfs merge=
127
  input_images_20260127_235812_906867/images/000002.png filter=lfs diff=lfs merge=lfs -text
128
  input_images_20260127_235812_906867/images/000003.png filter=lfs diff=lfs merge=lfs -text
129
  4dgs-dpm/bin/ruff filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  input_images_20260127_235812_906867/images/000002.png filter=lfs diff=lfs merge=lfs -text
128
  input_images_20260127_235812_906867/images/000003.png filter=lfs diff=lfs merge=lfs -text
129
  4dgs-dpm/bin/ruff filter=lfs diff=lfs merge=lfs -text
130
+ vggt-low-vram/examples/kitchen/images/00.png filter=lfs diff=lfs merge=lfs -text
131
+ vggt-low-vram/examples/kitchen/images/01.png filter=lfs diff=lfs merge=lfs -text
132
+ vggt-low-vram/examples/kitchen/images/02.png filter=lfs diff=lfs merge=lfs -text
133
+ vggt-low-vram/examples/kitchen/images/03.png filter=lfs diff=lfs merge=lfs -text
134
+ vggt-low-vram/examples/kitchen/images/04.png filter=lfs diff=lfs merge=lfs -text
135
+ vggt-low-vram/examples/kitchen/images/05.png filter=lfs diff=lfs merge=lfs -text
136
+ vggt-low-vram/examples/kitchen/images/06.png filter=lfs diff=lfs merge=lfs -text
137
+ vggt-low-vram/examples/kitchen/images/07.png filter=lfs diff=lfs merge=lfs -text
138
+ vggt-low-vram/examples/kitchen/images/08.png filter=lfs diff=lfs merge=lfs -text
139
+ vggt-low-vram/examples/kitchen/images/09.png filter=lfs diff=lfs merge=lfs -text
140
+ vggt-low-vram/examples/kitchen/images/10.png filter=lfs diff=lfs merge=lfs -text
141
+ vggt-low-vram/examples/kitchen/images/11.png filter=lfs diff=lfs merge=lfs -text
142
+ vggt-low-vram/examples/kitchen/images/12.png filter=lfs diff=lfs merge=lfs -text
143
+ vggt-low-vram/examples/kitchen/images/13.png filter=lfs diff=lfs merge=lfs -text
144
+ vggt-low-vram/examples/kitchen/images/14.png filter=lfs diff=lfs merge=lfs -text
145
+ vggt-low-vram/examples/kitchen/images/15.png filter=lfs diff=lfs merge=lfs -text
146
+ vggt-low-vram/examples/kitchen/images/16.png filter=lfs diff=lfs merge=lfs -text
147
+ vggt-low-vram/examples/kitchen/images/17.png filter=lfs diff=lfs merge=lfs -text
148
+ vggt-low-vram/examples/kitchen/images/18.png filter=lfs diff=lfs merge=lfs -text
149
+ vggt-low-vram/examples/kitchen/images/19.png filter=lfs diff=lfs merge=lfs -text
150
+ vggt-low-vram/examples/kitchen/images/20.png filter=lfs diff=lfs merge=lfs -text
151
+ vggt-low-vram/examples/kitchen/images/21.png filter=lfs diff=lfs merge=lfs -text
152
+ vggt-low-vram/examples/kitchen/images/22.png filter=lfs diff=lfs merge=lfs -text
153
+ vggt-low-vram/examples/kitchen/images/23.png filter=lfs diff=lfs merge=lfs -text
154
+ vggt-low-vram/examples/kitchen/images/24.png filter=lfs diff=lfs merge=lfs -text
155
+ vggt-low-vram/examples/llff_fern/images/000.png filter=lfs diff=lfs merge=lfs -text
156
+ vggt-low-vram/examples/llff_fern/images/001.png filter=lfs diff=lfs merge=lfs -text
157
+ vggt-low-vram/examples/llff_fern/images/002.png filter=lfs diff=lfs merge=lfs -text
158
+ vggt-low-vram/examples/llff_fern/images/003.png filter=lfs diff=lfs merge=lfs -text
159
+ vggt-low-vram/examples/llff_fern/images/004.png filter=lfs diff=lfs merge=lfs -text
160
+ vggt-low-vram/examples/llff_fern/images/005.png filter=lfs diff=lfs merge=lfs -text
161
+ vggt-low-vram/examples/llff_fern/images/006.png filter=lfs diff=lfs merge=lfs -text
162
+ vggt-low-vram/examples/llff_fern/images/007.png filter=lfs diff=lfs merge=lfs -text
163
+ vggt-low-vram/examples/llff_fern/images/008.png filter=lfs diff=lfs merge=lfs -text
164
+ vggt-low-vram/examples/llff_fern/images/009.png filter=lfs diff=lfs merge=lfs -text
165
+ vggt-low-vram/examples/llff_fern/images/010.png filter=lfs diff=lfs merge=lfs -text
166
+ vggt-low-vram/examples/llff_fern/images/011.png filter=lfs diff=lfs merge=lfs -text
167
+ vggt-low-vram/examples/llff_fern/images/012.png filter=lfs diff=lfs merge=lfs -text
168
+ vggt-low-vram/examples/llff_fern/images/013.png filter=lfs diff=lfs merge=lfs -text
169
+ vggt-low-vram/examples/llff_fern/images/014.png filter=lfs diff=lfs merge=lfs -text
170
+ vggt-low-vram/examples/llff_fern/images/015.png filter=lfs diff=lfs merge=lfs -text
171
+ vggt-low-vram/examples/llff_fern/images/016.png filter=lfs diff=lfs merge=lfs -text
172
+ vggt-low-vram/examples/llff_fern/images/017.png filter=lfs diff=lfs merge=lfs -text
173
+ vggt-low-vram/examples/llff_fern/images/018.png filter=lfs diff=lfs merge=lfs -text
174
+ vggt-low-vram/examples/llff_fern/images/019.png filter=lfs diff=lfs merge=lfs -text
175
+ vggt-low-vram/examples/llff_flower/images/000.png filter=lfs diff=lfs merge=lfs -text
176
+ vggt-low-vram/examples/llff_flower/images/001.png filter=lfs diff=lfs merge=lfs -text
177
+ vggt-low-vram/examples/llff_flower/images/002.png filter=lfs diff=lfs merge=lfs -text
178
+ vggt-low-vram/examples/llff_flower/images/003.png filter=lfs diff=lfs merge=lfs -text
179
+ vggt-low-vram/examples/llff_flower/images/004.png filter=lfs diff=lfs merge=lfs -text
180
+ vggt-low-vram/examples/llff_flower/images/005.png filter=lfs diff=lfs merge=lfs -text
181
+ vggt-low-vram/examples/llff_flower/images/006.png filter=lfs diff=lfs merge=lfs -text
182
+ vggt-low-vram/examples/llff_flower/images/007.png filter=lfs diff=lfs merge=lfs -text
183
+ vggt-low-vram/examples/llff_flower/images/008.png filter=lfs diff=lfs merge=lfs -text
184
+ vggt-low-vram/examples/llff_flower/images/009.png filter=lfs diff=lfs merge=lfs -text
185
+ vggt-low-vram/examples/llff_flower/images/010.png filter=lfs diff=lfs merge=lfs -text
186
+ vggt-low-vram/examples/llff_flower/images/011.png filter=lfs diff=lfs merge=lfs -text
187
+ vggt-low-vram/examples/llff_flower/images/012.png filter=lfs diff=lfs merge=lfs -text
188
+ vggt-low-vram/examples/llff_flower/images/013.png filter=lfs diff=lfs merge=lfs -text
189
+ vggt-low-vram/examples/llff_flower/images/014.png filter=lfs diff=lfs merge=lfs -text
190
+ vggt-low-vram/examples/llff_flower/images/015.png filter=lfs diff=lfs merge=lfs -text
191
+ vggt-low-vram/examples/llff_flower/images/016.png filter=lfs diff=lfs merge=lfs -text
192
+ vggt-low-vram/examples/llff_flower/images/017.png filter=lfs diff=lfs merge=lfs -text
193
+ vggt-low-vram/examples/llff_flower/images/018.png filter=lfs diff=lfs merge=lfs -text
194
+ vggt-low-vram/examples/llff_flower/images/019.png filter=lfs diff=lfs merge=lfs -text
195
+ vggt-low-vram/examples/llff_flower/images/020.png filter=lfs diff=lfs merge=lfs -text
196
+ vggt-low-vram/examples/llff_flower/images/021.png filter=lfs diff=lfs merge=lfs -text
197
+ vggt-low-vram/examples/llff_flower/images/022.png filter=lfs diff=lfs merge=lfs -text
198
+ vggt-low-vram/examples/llff_flower/images/023.png filter=lfs diff=lfs merge=lfs -text
199
+ vggt-low-vram/examples/llff_flower/images/024.png filter=lfs diff=lfs merge=lfs -text
200
+ vggt-low-vram/examples/room/images/no_overlap_1.png filter=lfs diff=lfs merge=lfs -text
201
+ vggt-low-vram/examples/room/images/no_overlap_2.jpg filter=lfs diff=lfs merge=lfs -text
202
+ vggt-low-vram/examples/room/images/no_overlap_3.jpg filter=lfs diff=lfs merge=lfs -text
203
+ vggt-low-vram/examples/room/images/no_overlap_4.jpg filter=lfs diff=lfs merge=lfs -text
204
+ vggt-low-vram/examples/room/images/no_overlap_5.jpg filter=lfs diff=lfs merge=lfs -text
205
+ vggt-low-vram/examples/room/images/no_overlap_6.jpg filter=lfs diff=lfs merge=lfs -text
206
+ vggt-low-vram/examples/room/images/no_overlap_7.jpg filter=lfs diff=lfs merge=lfs -text
207
+ vggt-low-vram/examples/room/images/no_overlap_8.jpg filter=lfs diff=lfs merge=lfs -text
208
+ vggt-low-vram/examples/single_cartoon/images/model_was_never_trained_on_single_image_or_cartoon.jpg filter=lfs diff=lfs merge=lfs -text
209
+ vggt-low-vram/examples/single_oil_painting/images/model_was_never_trained_on_single_image_or_oil_painting.png filter=lfs diff=lfs merge=lfs -text
210
+ vggt-low-vram/examples/videos/Colosseum.mp4 filter=lfs diff=lfs merge=lfs -text
211
+ vggt-low-vram/examples/videos/fern.mp4 filter=lfs diff=lfs merge=lfs -text
212
+ vggt-low-vram/examples/videos/great_wall.mp4 filter=lfs diff=lfs merge=lfs -text
213
+ vggt-low-vram/examples/videos/kitchen.mp4 filter=lfs diff=lfs merge=lfs -text
214
+ vggt-low-vram/examples/videos/pyramid.mp4 filter=lfs diff=lfs merge=lfs -text
215
+ vggt-low-vram/examples/videos/room.mp4 filter=lfs diff=lfs merge=lfs -text
216
+ vggt-low-vram/examples/videos/single_cartoon.mp4 filter=lfs diff=lfs merge=lfs -text
217
+ vggt-low-vram/examples/videos/single_oil_painting.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -4,21 +4,19 @@ app_file: gradio_demo.py
4
  sdk: gradio
5
  sdk_version: 5.17.1
6
  ---
7
- <div align="center">
8
- <h1>V-DPM: 4D Video Reconstruction with Dynamic Point Maps</h1>
9
 
10
- <a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
11
- <a href="https://huggingface.co/spaces/edgarsucar/vdpm"><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
12
 
13
- **[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**
14
 
 
15
 
16
- [Edgar Sucar](https://edgarsucar.github.io/)\*, [Eldar Insafutdinov](https://eldar.insafutdinov.com/)\*, [Zihang Lai](https://scholar.google.com/citations?user=31eXgMYAAAAJ), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/)
17
- </div>
18
 
19
- ## Setup
 
 
 
20
 
21
- First, clone the repository and setup a virtual environment with [uv](https://github.com/astral-sh/uv):
22
 
23
  ```bash
24
  git clone git@github.com:eldar/vdpm.git
@@ -33,11 +31,6 @@ uv pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pyt
33
  uv pip install -r requirements.txt
34
  ```
35
 
36
- ## Viser demo
37
- ```bash
38
- python visualise.py ++vis.input_video=examples/videos/camel.mp4
39
- ```
40
-
41
  ## Gradio demo
42
  ```bash
43
  python gradio_demo.py
 
4
  sdk: gradio
5
  sdk_version: 5.17.1
6
  ---
 
 
7
 
8
+ ## V-DPM Low-VRAM Multi-View Inference
 
9
 
10
+ Multi-view version of V-DPM with low-VRAM optimizations to improve inference speed on consumer graphics cards and to allow longer sequences to finish. Tested on RTX 3070 Ti, H200, and RTX PRO 6000.
11
 
12
+ First, clone the repository and setup a virtual environment with [uv](https://github.com/astral-sh/uv):
13
 
 
 
14
 
15
+ The original model's aggregator stores the intermediate outputs of all 24 attention blocks, but only 4 of them are used by the prediction heads. The aggregator now returns only those 4.
16
+ Unused intermediate tensors are deleted with `del` to free memory for subsequent operations.
17
+ `@torch.compile` is applied to selected functions (e.g., the MLP with GELU, LayerNorm).
18
+ `torch.cuda.empty_cache()` is called where it helps reduce peak memory usage.
19
 
 
20
 
21
  ```bash
22
  git clone git@github.com:eldar/vdpm.git
 
31
  uv pip install -r requirements.txt
32
  ```
33
 
 
 
 
 
 
34
  ## Gradio demo
35
  ```bash
36
  python gradio_demo.py
dpm/aggregator.py CHANGED
@@ -245,9 +245,12 @@ class Aggregator(nn.Module):
245
 
246
  frame_idx = 0
247
  global_idx = 0
248
- output_list = []
 
 
249
 
250
- for _ in range(self.aa_block_num):
 
251
  for attn_type in self.aa_order:
252
  if attn_type == "frame":
253
  tokens, frame_idx, frame_intermediates = self._process_frame_attention(
@@ -260,15 +263,25 @@ class Aggregator(nn.Module):
260
  else:
261
  raise ValueError(f"Unknown attention type: {attn_type}")
262
 
263
- for i in range(len(frame_intermediates)):
264
- # concat frame and global intermediates, [B x S x P x 2C]
265
- concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
266
- output_list.append(concat_inter)
267
-
268
- del concat_inter
269
- del frame_intermediates
270
- del global_intermediates
271
- return output_list, self.patch_start_idx
 
 
 
 
 
 
 
 
 
 
272
 
273
  def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
274
  """
 
245
 
246
  frame_idx = 0
247
  global_idx = 0
248
+ # Only store intermediates for layers actually used by prediction heads
249
+ used_layer_idx = {4, 11, 17, 23}#{4, 11, 17, 23}
250
+ output_dict = {}
251
 
252
+ max_used_layer = max(used_layer_idx)
253
+ for block_iter in range(self.aa_block_num):
254
  for attn_type in self.aa_order:
255
  if attn_type == "frame":
256
  tokens, frame_idx, frame_intermediates = self._process_frame_attention(
 
263
  else:
264
  raise ValueError(f"Unknown attention type: {attn_type}")
265
 
266
+ if block_iter in used_layer_idx:
267
+ for i in range(len(frame_intermediates)):
268
+ # concat frame and global intermediates, [B x S x P x 2C]
269
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
270
+ # Store merged intermediates in FP16 to reduce GPU memory usage.
271
+ output_dict[block_iter] = concat_inter.half()
272
+
273
+ # Stop after the highest used layer to skip unnecessary blocks
274
+ if block_iter >= max_used_layer:
275
+ break
276
+
277
+ # Clean up
278
+ if 'frame_intermediates' in locals():
279
+ del frame_intermediates
280
+ if 'global_intermediates' in locals():
281
+ del global_intermediates
282
+ # Alias for CameraHead which indexes with [-1]
283
+ output_dict[-1] = output_dict[max_used_layer]
284
+ return output_dict, self.patch_start_idx
285
 
286
  def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
287
  """
dpm/decoder.py CHANGED
@@ -167,7 +167,7 @@ class Decoder(nn.Module):
167
  self,
168
  cfg,
169
  dim_in: int,
170
- intermediate_layer_idx: List[int] = [4, 11, 17, 23],
171
  patch_size=14,
172
  embed_dim=1024,
173
  depth=2,
@@ -279,13 +279,35 @@ class Decoder(nn.Module):
279
 
280
  self.use_reentrant = False # hardcoded to False
281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  def get_condition_tokens(
283
  self,
284
  aggregated_tokens_list: List[torch.Tensor],
285
  cond_view_idxs: torch.Tensor
286
  ):
287
  # Use tokens from the last block for conditioning
288
- tokens_last = aggregated_tokens_list[-1] # [B S N_tok D]
289
  # Extract the camera tokens
290
  cond_token_idx = 1
291
  camera_tokens = tokens_last[:, :, [cond_token_idx]] # [B S D]
 
167
  self,
168
  cfg,
169
  dim_in: int,
170
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
171
  patch_size=14,
172
  embed_dim=1024,
173
  depth=2,
 
279
 
280
  self.use_reentrant = False # hardcoded to False
281
 
282
+ # Track whether we have compiled the attention blocks for faster inference.
283
+ # We postpone compilation until after weights are loaded to keep state_dict names stable.
284
+ self._compiled = False
285
+
286
+
287
+ def compile_blocks(self):
288
+ """Compile decoder attention blocks for faster inference (Linux/macOS only)."""
289
+ import sys
290
+ if self._compiled or self.old_decoder:
291
+ return
292
+ if sys.platform.startswith("win"):
293
+ print("[vdpm] torch.compile is not supported on Windows; running in eager mode.")
294
+ self._compiled = True
295
+ return
296
+ for module_list in self.frame_blocks:
297
+ for i in range(len(module_list)):
298
+ module_list[i] = torch.compile(module_list[i])
299
+ for module_list in self.global_blocks:
300
+ for i in range(len(module_list)):
301
+ module_list[i] = torch.compile(module_list[i])
302
+ self._compiled = True
303
+
304
  def get_condition_tokens(
305
  self,
306
  aggregated_tokens_list: List[torch.Tensor],
307
  cond_view_idxs: torch.Tensor
308
  ):
309
  # Use tokens from the last block for conditioning
310
+ tokens_last = aggregated_tokens_list[self.intermediate_layer_idx[-1]] # [B S N_tok D]
311
  # Extract the camera tokens
312
  cond_token_idx = 1
313
  camera_tokens = tokens_last[:, :, [cond_token_idx]] # [B S D]
dpm/model.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import torch.nn as nn
3
 
@@ -37,6 +38,7 @@ class VDPM(nn.Module):
37
  self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
38
 
39
  self.camera_head = CameraHead(dim_in=2 * embed_dim)
 
40
  self.set_freeze()
41
 
42
  def set_freeze(self):
@@ -64,9 +66,10 @@ class VDPM(nn.Module):
64
  aggregated_tokens_list, images, patch_start_idx
65
  )
66
 
67
- padded_decoded_tokens = [None] * len(aggregated_tokens_list)
68
- for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
69
- padded_decoded_tokens[layer_idx] = decoded_tokens[idx]
 
70
  pts3d_dyn, pts3d_dyn_conf = self.point_head(
71
  padded_decoded_tokens, images, patch_start_idx
72
  )
@@ -91,19 +94,37 @@ class VDPM(nn.Module):
91
  images=None,
92
  num_timesteps=None
93
  ):
94
- autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
 
 
 
 
 
 
 
 
 
 
95
 
96
  if images is None:
97
  images = torch.stack([view["img"] for view in views], dim=1)
98
 
 
 
 
 
 
99
  with autocast_amp:
100
  aggregated_tokens_list, patch_start_idx = self.aggregator(images)
 
 
 
 
 
101
  S = images.shape[1]
102
 
103
  # Determine number of timesteps to query
104
  if num_timesteps is None:
105
- # Default to S if not specified (legacy behavior)
106
- # But if views has indices, try to infer max time
107
  if views is not None and "view_idxs" in views[0]:
108
  try:
109
  all_idxs = torch.cat([v["view_idxs"][:, 1] for v in views])
@@ -116,32 +137,110 @@ class VDPM(nn.Module):
116
  predictions = dict()
117
  pointmaps = []
118
  ones = torch.ones(1, S, dtype=torch.int64)
 
119
  for time_ in range(num_timesteps):
120
  cond_view_idxs = ones * time_
121
 
 
 
 
122
  with autocast_amp:
123
  decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)
124
- padded_decoded_tokens = [None] * len(aggregated_tokens_list)
125
- for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
126
- padded_decoded_tokens[layer_idx] = decoded_tokens[idx]
127
 
 
 
 
 
 
 
 
 
 
 
128
  pts3d, pts3d_conf = self.point_head(
129
  padded_decoded_tokens, images, patch_start_idx
130
  )
131
 
 
 
 
132
  pointmaps.append(dict(
133
  pts3d=pts3d,
134
  conf=pts3d_conf
135
  ))
136
 
 
 
 
137
  pose_enc_list = self.camera_head(aggregated_tokens_list)
138
- predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  predictions["pose_enc_list"] = pose_enc_list
140
  predictions["pointmaps"] = pointmaps
 
 
 
 
 
 
 
141
  return predictions
142
 
143
  def load_state_dict(self, ckpt, is_VGGT_static=False, **kw):
144
  # don't load these VGGT heads as not needed
145
  exclude = ["depth_head", "track_head"]
146
  ckpt = {k:v for k, v in ckpt.items() if k.split('.')[0] not in exclude}
147
- return super().load_state_dict(ckpt, **kw)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
  import torch
3
  import torch.nn as nn
4
 
 
38
  self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
39
 
40
  self.camera_head = CameraHead(dim_in=2 * embed_dim)
41
+ self.profile = False
42
  self.set_freeze()
43
 
44
  def set_freeze(self):
 
66
  aggregated_tokens_list, images, patch_start_idx
67
  )
68
 
69
+ padded_decoded_tokens = {
70
+ layer_idx: decoded_tokens[idx]
71
+ for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx)
72
+ }
73
  pts3d_dyn, pts3d_dyn_conf = self.point_head(
74
  padded_decoded_tokens, images, patch_start_idx
75
  )
 
94
  images=None,
95
  num_timesteps=None
96
  ):
97
+ profile = self.profile and torch.cuda.is_available()
98
+ if profile:
99
+ ev = lambda: torch.cuda.Event(enable_timing=True)
100
+ e_start, e_agg, e_cam_end = ev(), ev(), ev()
101
+ e_dec_starts, e_dec_ends = [], []
102
+ e_head_starts, e_head_ends = [], []
103
+ e_cam_start = ev()
104
+ mem_before = torch.cuda.memory_allocated() / 1024**3
105
+ e_start.record()
106
+
107
+ autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.float16)
108
 
109
  if images is None:
110
  images = torch.stack([view["img"] for view in views], dim=1)
111
 
112
+ # If not profiling per-stage, measure a single total inference time (minimal overhead)
113
+ if not profile:
114
+ torch.cuda.synchronize()
115
+ _t_start = time.time()
116
+
117
  with autocast_amp:
118
  aggregated_tokens_list, patch_start_idx = self.aggregator(images)
119
+
120
+ if profile:
121
+ e_agg.record()
122
+ mem_after_agg = torch.cuda.memory_allocated() / 1024**3
123
+
124
  S = images.shape[1]
125
 
126
  # Determine number of timesteps to query
127
  if num_timesteps is None:
 
 
128
  if views is not None and "view_idxs" in views[0]:
129
  try:
130
  all_idxs = torch.cat([v["view_idxs"][:, 1] for v in views])
 
137
  predictions = dict()
138
  pointmaps = []
139
  ones = torch.ones(1, S, dtype=torch.int64)
140
+
141
  for time_ in range(num_timesteps):
142
  cond_view_idxs = ones * time_
143
 
144
+ if profile:
145
+ e_ds = ev(); e_ds.record(); e_dec_starts.append(e_ds)
146
+
147
  with autocast_amp:
148
  decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)
 
 
 
149
 
150
+ if profile:
151
+ e_de = ev(); e_de.record(); e_dec_ends.append(e_de)
152
+
153
+ padded_decoded_tokens = {
154
+ layer_idx: decoded_tokens[idx]
155
+ for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx)
156
+ }
157
+
158
+ if profile:
159
+ e_hs = ev(); e_hs.record(); e_head_starts.append(e_hs)
160
  pts3d, pts3d_conf = self.point_head(
161
  padded_decoded_tokens, images, patch_start_idx
162
  )
163
 
164
+ if profile:
165
+ e_he = ev(); e_he.record(); e_head_ends.append(e_he)
166
+
167
  pointmaps.append(dict(
168
  pts3d=pts3d,
169
  conf=pts3d_conf
170
  ))
171
 
172
+ if profile:
173
+ e_cam_start.record()
174
+
175
  pose_enc_list = self.camera_head(aggregated_tokens_list)
176
+
177
+ if profile:
178
+ e_cam_end.record()
179
+ torch.cuda.synchronize() # single sync at the very end
180
+ mem_peak = torch.cuda.max_memory_allocated() / 1024**3
181
+
182
+ t_agg = e_start.elapsed_time(e_agg) / 1000
183
+ t_dec = sum(s.elapsed_time(e) / 1000 for s, e in zip(e_dec_starts, e_dec_ends))
184
+ t_head = sum(s.elapsed_time(e) / 1000 for s, e in zip(e_head_starts, e_head_ends))
185
+ t_cam = e_cam_start.elapsed_time(e_cam_end) / 1000
186
+ t_total = e_start.elapsed_time(e_cam_end) / 1000
187
+
188
+ print(f" [PROFILE] Aggregator: {t_agg:.3f}s | VRAM: {mem_before:.2f} -> {mem_after_agg:.2f} GB (+{mem_after_agg - mem_before:.2f})")
189
+ print(f" [PROFILE] Stored layers: {sorted(k for k in aggregated_tokens_list if k >= 0)}")
190
+ print(f" [PROFILE] Decoder: {t_dec:.3f}s ({num_timesteps} timesteps, {t_dec/max(num_timesteps,1)*1000:.0f}ms each)")
191
+ print(f" [PROFILE] Point Head: {t_head:.3f}s ({num_timesteps} timesteps, {t_head/max(num_timesteps,1)*1000:.0f}ms each)")
192
+ print(f" [PROFILE] Camera Head:{t_cam:.3f}s")
193
+ print(f" [PROFILE] Total: {t_total:.3f}s | Peak VRAM: {mem_peak:.2f} GB")
194
+ print(f" [PROFILE] Breakdown: Agg {t_agg/t_total*100:.0f}% | Dec {t_dec/t_total*100:.0f}% | PtHead {t_head/t_total*100:.0f}% | CamHead {t_cam/t_total*100:.0f}%")
195
+
196
+ predictions["pose_enc"] = pose_enc_list[-1]
197
  predictions["pose_enc_list"] = pose_enc_list
198
  predictions["pointmaps"] = pointmaps
199
+
200
+ if not profile:
201
+ # single final sync and lightweight wall-clock timing
202
+ torch.cuda.synchronize()
203
+ t_total = time.time() - _t_start
204
+ print(f" [PROFILE] Total inference time: {t_total:.3f}s")
205
+
206
  return predictions
207
 
208
  def load_state_dict(self, ckpt, is_VGGT_static=False, **kw):
209
  # don't load these VGGT heads as not needed
210
  exclude = ["depth_head", "track_head"]
211
  ckpt = {k:v for k, v in ckpt.items() if k.split('.')[0] not in exclude}
212
+
213
+ res = super().load_state_dict(ckpt, **kw)
214
+
215
+ # Compile decoder blocks after weights are loaded so state_dict keys match the checkpoint.
216
+ if hasattr(self, "decoder") and hasattr(self.decoder, "compile_blocks"):
217
+ self.decoder.compile_blocks()
218
+
219
+ return res
220
+
221
+ def to_fp16(self, keep_norm_fp32: bool = False):
222
+ """Convert model parameters and buffers to FP16 for inference.
223
+
224
+ Args:
225
+ keep_norm_fp32 (bool): If True, keep normalization layers (LayerNorm/BatchNorm)
226
+ in FP32 for numerical stability. If False, convert everything to FP16.
227
+ """
228
+ # Convert whole model to half first
229
+ self.half()
230
+
231
+ if keep_norm_fp32:
232
+ for m in self.modules():
233
+ if isinstance(m, (torch.nn.LayerNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.SyncBatchNorm)):
234
+ m.float()
235
+
236
+ # Ensure any stored dtype-sensitive tensors are converted appropriately
237
+ try:
238
+ # camera/register/time tokens are Parameters and are handled by self.half(),
239
+ # but ensure any other buffers are also cast
240
+ for name, buf in list(self._buffers.items()):
241
+ if isinstance(buf, torch.Tensor):
242
+ self.register_buffer(name, buf.half(), persistent=(getattr(buf, 'persistent', False)))
243
+ except Exception:
244
+ pass
245
+
246
+ return self
gradio_demo.py CHANGED
@@ -27,7 +27,11 @@ from dpm.model import VDPM
27
  from vggt.utils.load_fn import load_and_preprocess_images
28
  from util.depth import write_depth_to_png
29
 
30
-
 
 
 
 
31
  # ============================================================================
32
  # MEMORY OPTIMIZATION SETTINGS FOR 8GB GPUs (RTX 3070 Ti, 3060 Ti, etc.)
33
  # ============================================================================
@@ -65,17 +69,17 @@ if device == "cuda":
65
 
66
  print(f"\u2713 GPU Detected: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)")
67
 
68
- if vram_gb >= 22: # A10G (24GB), A100 (40/80GB), RTX 3090/4090 (24GB)
69
  MAX_FRAMES = 80
70
  print(f" -> High VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
71
- elif vram_gb >= 14: # T4 (16GB), 4080 (16GB)
72
- MAX_FRAMES = 16
73
  print(f" -> Medium VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
74
  elif vram_gb >= 7.5: # RTX 3070 Ti, 2080, etc (8GB)
75
- MAX_FRAMES = 8
76
  print(f" -> 8GB VRAM detected. Set MAX_FRAMES to {MAX_FRAMES}")
77
  else:
78
- MAX_FRAMES = 5
79
  print(f" -> Low VRAM (<8GB). Keeping MAX_FRAMES at {MAX_FRAMES} to prevent OOM")
80
  print(f"\u2713 TF32 enabled for faster matrix operations")
81
 
@@ -84,7 +88,12 @@ if device == "cuda":
84
  def load_cfg_from_cli() -> "omegaconf.DictConfig":
85
  if GlobalHydra.instance().is_initialized():
86
  GlobalHydra.instance().clear()
87
- overrides = sys.argv[1:]
 
 
 
 
 
88
  with initialize(config_path="configs"):
89
  return compose(config_name="visualise", overrides=overrides)
90
 
@@ -121,15 +130,23 @@ def load_model(cfg) -> VDPM:
121
 
122
  # Option 1: Use FP16/BF16 for all model weights (simple, ~2x memory/speed boost)
123
  if USE_HALF_PRECISION and not USE_QUANTIZATION:
124
- # Use BF16 on high-end GPUs (A100/H100/H200) for better stability
125
- if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
126
- print("Converting model to BF16 precision (more stable than FP16)...")
 
127
  model = model.to(torch.bfloat16)
128
- print("✓ Model converted to BF16: ~2x speed boost, no CUBLAS errors")
129
  else:
130
  print("Converting model to FP16 precision...")
131
- model = model.half()
132
- print("✓ Model converted to FP16: ~2x memory reduction (3GB -> 1.5GB)")
 
 
 
 
 
 
 
133
 
134
  # Option 2: Apply INT8 dynamic quantization (more aggressive, ~3-4x reduction)
135
  if USE_QUANTIZATION:
@@ -150,14 +167,8 @@ def load_model(cfg) -> VDPM:
150
  model = model.to(device)
151
 
152
  # Enable torch.compile for faster inference (PyTorch 2.0+)
153
- # Note: Disable compile if using quantization as they may conflict
154
- if not USE_QUANTIZATION:
155
- try:
156
- print("Compiling model with torch.compile for faster inference...")
157
- model = torch.compile(model, mode="reduce-overhead")
158
- print("✓ Model compilation successful")
159
- except Exception as e:
160
- print(f"Warning: torch.compile not available or failed: {e}")
161
 
162
  return model
163
 
@@ -241,7 +252,7 @@ def compute_normals_from_pointmap(point_map: np.ndarray) -> tuple[np.ndarray, np
241
  tangent_x_np = tangent_x_np.reshape(T, V, H, W, 3)
242
  tangent_y_np = tangent_y_np.reshape(T, V, H, W, 3)
243
 
244
- return normals_np.astype(np.float16), tangent_x_np.astype(np.float16), tangent_y_np.astype(np.float16)
245
 
246
 
247
  def compute_smooth_normals(normals: np.ndarray, kernel_size: int = 5) -> np.ndarray:
@@ -291,7 +302,7 @@ def compute_smooth_normals(normals: np.ndarray, kernel_size: int = 5) -> np.ndar
291
  if normals.ndim == 5:
292
  result = result.reshape(T, V, H, W, 3)
293
 
294
- return result.astype(np.float16)
295
 
296
 
297
  def compute_optical_flow(world_points: np.ndarray, extrinsics: np.ndarray = None, intrinsics: np.ndarray = None, num_views: int = 1) -> np.ndarray:
@@ -405,6 +416,18 @@ def write_optical_flow_to_png(outpath: str, flow: np.ndarray, max_flow: float =
405
  PIL.Image.fromarray(rgb).save(outpath)
406
 
407
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  def create_output_zip(target_dir: str) -> str:
409
  """Create a zip file containing all outputs for download.
410
 
@@ -422,12 +445,14 @@ def create_output_zip(target_dir: str) -> str:
422
  "tracks.npz",
423
  "poses.npz",
424
  "depths.npz",
425
- "normals.npz",
426
- "optical_flow.npz",
427
- "depths", # directory
428
- "normals", # directory
429
- "flow", # directory
430
- "images", # directory
 
 
431
  "meta.json",
432
  ]
433
 
@@ -1142,7 +1167,7 @@ def run_model(target_dir: str, model: VDPM, frame_id_arg=0, use_temporal_trackin
1142
  extrinsics, intrinsics = None, None
1143
 
1144
  # ============================================================================
1145
- # COMPUTE AND SAVE NORMALS
1146
  # ============================================================================
1147
  print("Computing surface normals from point maps...")
1148
  normals, tangent_x, tangent_y = compute_normals_from_pointmap(world_points_full)
@@ -1151,17 +1176,15 @@ def run_model(target_dir: str, model: VDPM, frame_id_arg=0, use_temporal_trackin
1151
  print("Computing smooth normals...")
1152
  smooth_normals = compute_smooth_normals(normals, kernel_size=7)
1153
 
1154
- # Save normals as npz
1155
- normals_npz_path = os.path.join(target_dir, "normals.npz")
1156
- print(f"Saving normals to {normals_npz_path}")
1157
  np.savez_compressed(
1158
- normals_npz_path,
1159
- normals=normals,
1160
- smooth_normals=smooth_normals,
1161
- tangent_x=tangent_x,
1162
- tangent_y=tangent_y,
1163
- num_views=num_views,
1164
- num_timesteps=num_timesteps
1165
  )
1166
 
1167
  # Save individual normal images as PNGs
@@ -1185,39 +1208,77 @@ def run_model(target_dir: str, model: VDPM, frame_id_arg=0, use_temporal_trackin
1185
  print(f"✓ Saved {T_norm * V_norm * 2} normal images (raw + smooth)")
1186
 
1187
  # ============================================================================
1188
- # COMPUTE AND SAVE OPTICAL FLOW
1189
  # ============================================================================
1190
- print("Computing optical flow from point trajectories...")
1191
- optical_flow = compute_optical_flow(world_points_full, extrinsics, intrinsics, num_views)
1192
-
1193
- # Save optical flow as npz
1194
- flow_npz_path = os.path.join(target_dir, "optical_flow.npz")
1195
- print(f"Saving optical flow to {flow_npz_path}")
1196
- np.savez_compressed(
1197
- flow_npz_path,
1198
- optical_flow=optical_flow,
1199
- num_views=num_views,
1200
- num_timesteps=num_timesteps
1201
- )
1202
 
1203
- # Save individual flow images as PNGs
1204
- flow_dir = os.path.join(target_dir, "flow")
1205
- os.makedirs(flow_dir, exist_ok=True)
1206
- print(f"Saving flow images to {flow_dir}/")
1207
 
1208
- # Compute global max flow magnitude for consistent visualization
1209
- max_flow_magnitude = np.sqrt((optical_flow**2).sum(axis=-1)).max()
1210
- if max_flow_magnitude == 0:
1211
- max_flow_magnitude = 1.0
1212
 
1213
- T_flow, V_flow = optical_flow.shape[:2]
1214
- for t in range(T_flow):
1215
- for v in range(V_flow):
1216
- flow_map = optical_flow[t, v]
1217
- png_path = os.path.join(flow_dir, f"flow_t{t:04d}_v{v:02d}.png")
1218
- write_optical_flow_to_png(png_path, flow_map, max_flow=max_flow_magnitude)
1219
 
1220
- print(f"✓ Saved {T_flow * V_flow} optical flow images")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1221
 
1222
  # (Moved saving logic to the end of function to capture all viz variables)
1223
 
@@ -1336,13 +1397,26 @@ def run_model(target_dir: str, model: VDPM, frame_id_arg=0, use_temporal_trackin
1336
 
1337
 
1338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1339
  # Save Results for Download (Final Format)
1340
  output_path = os.path.join(target_dir, "output_4d.npz")
1341
  save_dict = {
1342
  "world_points": world_points_s,
1343
  "world_points_conf": world_points_conf_s,
1344
- "world_points_tracks": world_points_tracks, # (S, T, H, W, 3) for tracking
1345
- "world_points_conf_tracks": world_points_conf_tracks, # (S, T, H, W)
1346
  "images": img_np_viz,
1347
  "images_raw": img_np[:, :, ::2, ::2], # Original images subsampled
1348
  "num_views": num_views,
 
27
  from vggt.utils.load_fn import load_and_preprocess_images
28
  from util.depth import write_depth_to_png
29
 
30
+ import torch
31
+ import torch._dynamo
32
+ torch._dynamo.config.suppress_errors = True
33
+ # This disables the specific fused GELU that's crashing
34
+ torch.set_float32_matmul_precision('high')
35
  # ============================================================================
36
  # MEMORY OPTIMIZATION SETTINGS FOR 8GB GPUs (RTX 3070 Ti, 3060 Ti, etc.)
37
  # ============================================================================
 
69
 
70
  print(f"\u2713 GPU Detected: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)")
71
 
72
+ if vram_gb >= 30: # A10G (24GB), A100 (40/80GB), RTX 3090/4090 (24GB)
73
  MAX_FRAMES = 80
74
  print(f" -> High VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
75
+ elif vram_gb >= 15: # T4 (16GB), 4080 (16GB)
76
+ MAX_FRAMES = 32
77
  print(f" -> Medium VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
78
  elif vram_gb >= 7.5: # RTX 3070 Ti, 2080, etc (8GB)
79
+ MAX_FRAMES = 24
80
  print(f" -> 8GB VRAM detected. Set MAX_FRAMES to {MAX_FRAMES}")
81
  else:
82
+ MAX_FRAMES = 8
83
  print(f" -> Low VRAM (<8GB). Keeping MAX_FRAMES at {MAX_FRAMES} to prevent OOM")
84
  print(f"\u2713 TF32 enabled for faster matrix operations")
85
 
 
88
  def load_cfg_from_cli() -> "omegaconf.DictConfig":
89
  if GlobalHydra.instance().is_initialized():
90
  GlobalHydra.instance().clear()
91
+ # In notebooks or some interactive shells, sys.argv contains kernel launch flags
92
+ # like `--f=...` which Hydra cannot parse as overrides. Filter those out.
93
+ if GlobalHydra.instance().is_initialized():
94
+ GlobalHydra.instance().clear()
95
+ raw_overrides = sys.argv[1:]
96
+ overrides = [o for o in raw_overrides if not (o.startswith("--f=") or "kernel" in o)]
97
  with initialize(config_path="configs"):
98
  return compose(config_name="visualise", overrides=overrides)
99
 
 
130
 
131
  # Option 1: Use FP16/BF16 for all model weights (simple, ~2x memory/speed boost)
132
  if USE_HALF_PRECISION and not USE_QUANTIZATION:
133
+ # Use BF16 only on Hopper+ (compute >= 9) where BF16 throughput matches FP16
134
+ # On Ampere (compute 8.x, e.g. 3070Ti), FP16 tensor cores are ~2x faster than BF16
135
+ if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
136
+ print("Converting model to BF16 precision (Hopper+ GPU detected)...")
137
  model = model.to(torch.bfloat16)
138
+ print("✓ Model converted to BF16")
139
  else:
140
  print("Converting model to FP16 precision...")
141
+ # Convert model to full FP16 (with helper that can preserve norm layers if desired)
142
+ try:
143
+ model = model.to_fp16(keep_norm_fp32=False)
144
+ print("✓ Model converted to FP16 via to_fp16()")
145
+ except Exception:
146
+ # Fallback to half() if helper not available
147
+ model = model.half()
148
+ print("✓ Model converted to FP16 via half() (fallback)")
149
+
150
 
151
  # Option 2: Apply INT8 dynamic quantization (more aggressive, ~3-4x reduction)
152
  if USE_QUANTIZATION:
 
167
  model = model.to(device)
168
 
169
  # Enable torch.compile for faster inference (PyTorch 2.0+)
170
+ # torch.compile is handled per-block in decoder.compile_blocks()
171
+ # called during load_state_dict above
 
 
 
 
 
 
172
 
173
  return model
174
 
 
252
  tangent_x_np = tangent_x_np.reshape(T, V, H, W, 3)
253
  tangent_y_np = tangent_y_np.reshape(T, V, H, W, 3)
254
 
255
+ return normals_np.astype(np.float32), tangent_x_np.astype(np.float32), tangent_y_np.astype(np.float32)
256
 
257
 
258
  def compute_smooth_normals(normals: np.ndarray, kernel_size: int = 5) -> np.ndarray:
 
302
  if normals.ndim == 5:
303
  result = result.reshape(T, V, H, W, 3)
304
 
305
+ return result.astype(np.float32)
306
 
307
 
308
  def compute_optical_flow(world_points: np.ndarray, extrinsics: np.ndarray = None, intrinsics: np.ndarray = None, num_views: int = 1) -> np.ndarray:
 
416
  PIL.Image.fromarray(rgb).save(outpath)
417
 
418
 
419
+ def _flow_to_image_3d(flow_hw3: np.ndarray) -> np.ndarray:
420
+ """Convert a 3D flow map (H, W, 3) to an RGB image.
421
+
422
+ Maps each XYZ component to R, G, B scaled by the global max magnitude.
423
+ """
424
+ mag = np.linalg.norm(flow_hw3, axis=-1, keepdims=True)
425
+ max_mag = float(mag.max()) + 1e-8
426
+ norm = flow_hw3 / max_mag
427
+ rgb = np.clip((norm + 1.0) * 0.5, 0.0, 1.0)
428
+ return (rgb * 255).astype(np.uint8)
429
+
430
+
431
  def create_output_zip(target_dir: str) -> str:
432
  """Create a zip file containing all outputs for download.
433
 
 
445
  "tracks.npz",
446
  "poses.npz",
447
  "depths.npz",
448
+ "depth_normals.npz",
449
+ "scene_flow.npz",
450
+ "angular_flow.npz",
451
+ "depths", # directory
452
+ "normals", # directory
453
+ "scene_flow", # directory
454
+ "angular_flow", # directory
455
+ "images", # directory
456
  "meta.json",
457
  ]
458
 
 
1167
  extrinsics, intrinsics = None, None
1168
 
1169
  # ============================================================================
1170
+ # COMPUTE AND SAVE NORMALS (depth_normals.npz — matches trainer format)
1171
  # ============================================================================
1172
  print("Computing surface normals from point maps...")
1173
  normals, tangent_x, tangent_y = compute_normals_from_pointmap(world_points_full)
 
1176
  print("Computing smooth normals...")
1177
  smooth_normals = compute_smooth_normals(normals, kernel_size=7)
1178
 
1179
+ # Save as depth_normals.npz (the name the trainer/viewer expect)
1180
+ depth_normals_path = os.path.join(target_dir, "depth_normals.npz")
1181
+ print(f"Saving normals to {depth_normals_path}")
1182
  np.savez_compressed(
1183
+ depth_normals_path,
1184
+ depth_normals=normals.astype(np.float16), # key must be 'depth_normals'
1185
+ smooth_normals=smooth_normals.astype(np.float16),
1186
+ tangent_x=tangent_x.astype(np.float16),
1187
+ tangent_y=tangent_y.astype(np.float16),
 
 
1188
  )
1189
 
1190
  # Save individual normal images as PNGs
 
1208
  print(f"✓ Saved {T_norm * V_norm * 2} normal images (raw + smooth)")
1209
 
1210
  # ============================================================================
1211
+ # COMPUTE AND SAVE SCENE FLOW + ANGULAR FLOW (matches trainer format)
1212
  # ============================================================================
1213
+ # The trainer expects:
1214
+ # scene_flow.npz — keys "tXXXX_vYY" → (H, W, 3) float32
1215
+ # angular_flow.npz — keys "tXXXX_vYY" → (H, W, 9) float32
1216
+ # Computed from world_points_raw which has the full pairwise DPM output:
1217
+ # world_points_raw[t, s, h, w] = P_s(t, π₀)
1218
+ # tracks[frame_idx, ref_idx] style indexing
1219
+ # ============================================================================
1220
+ print("Computing scene flow and angular flow from pairwise DPM output...")
 
 
 
 
1221
 
1222
+ sf_npz_dict = {}
1223
+ af_npz_dict = {}
 
 
1224
 
1225
+ sf_dir = os.path.join(target_dir, "scene_flow")
1226
+ af_dir = os.path.join(target_dir, "angular_flow")
1227
+ os.makedirs(sf_dir, exist_ok=True)
1228
+ os.makedirs(af_dir, exist_ok=True)
1229
 
1230
+ # world_points_raw: (T_query, S_source, H, W, 3)
1231
+ # For multi-view: S = num_views * num_timesteps, interleaved as
1232
+ # [v0_t0, v1_t0, ..., v0_t1, v1_t1, ...]
1233
+ # frame_idx = t * num_views + v (same as the trainer's convention)
1234
+ H_raw, W_raw = world_points_raw.shape[2:4]
 
1235
 
1236
+ sf_count = 0
1237
+ for t in range(num_timesteps - 1):
1238
+ for v in range(num_views):
1239
+ frame_idx = t * num_views + v
1240
+ next_frame_idx = (t + 1) * num_views + v
1241
+
1242
+ # P(t): points at current time (frame queries itself)
1243
+ P_t = world_points_raw[frame_idx, frame_idx].astype(np.float32) # (H, W, 3)
1244
+ # P(t+1): where frame t's points are at time t+1
1245
+ P_t1 = world_points_raw[frame_idx, next_frame_idx].astype(np.float32) # (H, W, 3)
1246
+
1247
+ scene_flow = np.nan_to_num(P_t1 - P_t, nan=0.0, posinf=0.0, neginf=0.0)
1248
+
1249
+ key = f"t{t:04d}_v{v:02d}"
1250
+ sf_npz_dict[key] = scene_flow.astype(np.float32)
1251
+
1252
+ # Angular flow: normal difference + tangent frame difference
1253
+ n_t, tx_t, ty_t = compute_normals_from_pointmap(P_t[np.newaxis])
1254
+ n_t1, tx_t1, ty_t1 = compute_normals_from_pointmap(P_t1[np.newaxis])
1255
+
1256
+ delta_n = np.nan_to_num(n_t1[0] - n_t[0], nan=0.0)
1257
+ delta_tx = np.nan_to_num(tx_t1[0] - tx_t[0], nan=0.0)
1258
+ delta_ty = np.nan_to_num(ty_t1[0] - ty_t[0], nan=0.0)
1259
+ angular_flow = np.concatenate([delta_n, delta_tx, delta_ty], axis=-1) # (H, W, 9)
1260
+ af_npz_dict[key] = angular_flow.astype(np.float32)
1261
+
1262
+ # Save debug images
1263
+ try:
1264
+ sf_img = _flow_to_image_3d(scene_flow)
1265
+ PIL.Image.fromarray(sf_img).save(os.path.join(sf_dir, f"{key}.png"))
1266
+ af_img = _flow_to_image_3d(delta_n)
1267
+ PIL.Image.fromarray(af_img).save(os.path.join(af_dir, f"{key}.png"))
1268
+ except Exception:
1269
+ pass
1270
+
1271
+ sf_count += 1
1272
+ if sf_count <= 4:
1273
+ mag = np.linalg.norm(scene_flow, axis=-1)
1274
+ print(f" [{key}] scene flow: mean={mag.mean():.6f}, max={mag.max():.6f}")
1275
+
1276
+ sf_npz_path = os.path.join(target_dir, "scene_flow.npz")
1277
+ af_npz_path = os.path.join(target_dir, "angular_flow.npz")
1278
+ np.savez_compressed(sf_npz_path, **sf_npz_dict)
1279
+ np.savez_compressed(af_npz_path, **af_npz_dict)
1280
+ print(f"✓ Saved {sf_count} scene flow entries to {sf_npz_path}")
1281
+ print(f"✓ Saved {sf_count} angular flow entries (9ch) to {af_npz_path}")
1282
 
1283
  # (Moved saving logic to the end of function to capture all viz variables)
1284
 
 
1397
 
1398
 
1399
 
1400
+ # ================================================================
1401
+ # BUILD world_points_tracks in (TV, TV, H, W, 3) pairwise format
1402
+ # ================================================================
1403
+ # The trainer's _ensure_flow_targets expects:
1404
+ # tracks[frame_idx, ref_idx] where frame_idx = t * V + v
1405
+ # world_points_raw is (T_query, S_source, H, W, 3) — this IS the
1406
+ # pairwise DPM output with T_query = S_source = T * V.
1407
+ # We just need to save it at full resolution (no subsampling).
1408
+ TV = num_timesteps * num_views
1409
+ assert world_points_raw.shape[0] == TV and world_points_raw.shape[1] == TV, \
1410
+ f"Expected ({TV}, {TV}, H, W, 3) but got {world_points_raw.shape}"
1411
+ world_points_tracks_full = world_points_raw # (TV, TV, H, W, 3)
1412
+
1413
  # Save Results for Download (Final Format)
1414
  output_path = os.path.join(target_dir, "output_4d.npz")
1415
  save_dict = {
1416
  "world_points": world_points_s,
1417
  "world_points_conf": world_points_conf_s,
1418
+ "world_points_tracks": world_points_tracks_full, # (TV, TV, H, W, 3) pairwise
1419
+ "world_points_conf_tracks": world_points_conf_tracks, # (S, T, H, W) for full_sample
1420
  "images": img_np_viz,
1421
  "images_raw": img_np[:, :, ::2, ::2], # Original images subsampled
1422
  "num_views": num_views,
test.ipynb ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "bd12eb72",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "allocated MB: 0.0\n",
14
+ "reserved MB: 0.0\n"
15
+ ]
16
+ },
17
+ {
18
+ "ename": "AttributeError",
19
+ "evalue": "module 'dpm.model' has no attribute 'parameters'",
20
+ "output_type": "error",
21
+ "traceback": [
22
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
23
+ "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
24
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 12\u001b[39m\n\u001b[32m 9\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mallocated MB:\u001b[39m\u001b[33m\"\u001b[39m, torch.cuda.memory_allocated()/\u001b[32m1024\u001b[39m**\u001b[32m2\u001b[39m)\n\u001b[32m 10\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mreserved MB:\u001b[39m\u001b[33m\"\u001b[39m, torch.cuda.memory_reserved()/\u001b[32m1024\u001b[39m**\u001b[32m2\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m12\u001b[39m params_bytes = \u001b[38;5;28msum\u001b[39m(p.numel()*p.element_size() \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mparameters\u001b[49m())\n\u001b[32m 13\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mparams MB:\u001b[39m\u001b[33m\"\u001b[39m, params_bytes/\u001b[32m1024\u001b[39m**\u001b[32m2\u001b[39m)\n\u001b[32m 15\u001b[39m dtype_bytes = {}\n",
25
+ "\u001b[31mAttributeError\u001b[39m: module 'dpm.model' has no attribute 'parameters'"
26
+ ]
27
+ }
28
+ ],
29
+ "source": [
30
+ "import torch\n",
31
+ "# Load model using the project's loader (may download if missing)\n",
32
+ "from gradio_demo import load_cfg_from_cli, load_model\n",
33
+ "cfg = load_cfg_from_cli()\n",
34
+ "vdpm_model = load_model(cfg)\n",
35
+ "vdpm_model.eval()\n",
36
+ "\n",
37
+ "# Lightweight diagnostics: VRAM and dtype breakdown\n",
38
+ "import gc\n",
39
+ "from collections import Counter\n",
40
+ "gc.collect()\n",
41
+ "torch.cuda.synchronize()\n",
42
+ "torch.cuda.empty_cache()\n",
43
+ "torch.cuda.synchronize()\n",
44
+ "print('allocated MB:', torch.cuda.memory_allocated()/1024**2)\n",
45
+ "print('reserved MB:', torch.cuda.memory_reserved()/1024**2)\n",
46
+ "# Ensure we have an nn.Module instance (some imports may shadow the name 'model')\n",
47
+ "import types\n",
48
+ "if isinstance(vdpm_model, types.ModuleType):\n",
49
+ " print('Warning: loader returned a module, not an instance. Trying to instantiate via load_model again.')\n",
50
+ " vdpm_model = load_model(cfg)\n",
51
+ "assert hasattr(vdpm_model, 'parameters'), f'Loaded object has no parameters: {type(vdpm_model)}'\n",
52
+ "params_bytes = sum(p.numel()*p.element_size() for p in vdpm_model.parameters())\n",
53
+ "print('params MB:', params_bytes/1024**2)\n",
54
+ "dtype_bytes = {}\n",
55
+ "for p in vdpm_model.parameters():\n",
56
+ " k = str(p.dtype)\n",
57
+ " dtype_bytes[k] = dtype_bytes.get(k,0) + p.numel()*p.element_size()\n",
58
+ "print('param bytes by dtype (MB):', {k:v/1024**2 for k,v in dtype_bytes.items()})\n",
59
+ "print('buffer dtype counts:', Counter(getattr(b,'dtype',None) for b in vdpm_model.buffers()))\n",
60
+ "\n",
61
+ "# Small aggregator inspect (one small dummy batch)\n",
62
+ "B, S, H, W = 1, min(4, 1 if not torch.cuda.is_available() else 4), 518, 518\n",
63
+ "dummy = torch.rand(B, S, 3, H, W, device='cuda' if torch.cuda.is_available() else 'cpu')\n",
64
+ "agg, patch_start = vdpm_model.aggregator(dummy)\n",
65
+ "for k, t in sorted(agg.items()):\n",
66
+ " print(k, t.device, t.dtype, tuple(t.shape), f\"{t.numel()*t.element_size()/1024**2:.2f} MB\")\n",
67
+ "\n",
68
+ "# Warmup + one timed inference\n",
69
+ "import time\n",
70
+ "for _ in range(2):\n",
71
+ " _ = vdpm_model.inference([{'img': torch.rand(1,3,518,518, device='cuda' if torch.cuda.is_available() else 'cpu')}])\n",
72
+ "torch.cuda.synchronize()\n",
73
+ "t0 = time.time()\n",
74
+ "_ = vdpm_model.inference([{'img': torch.rand(1,3,518,518, device='cuda' if torch.cuda.is_available() else 'cpu')}])\n",
75
+ "torch.cuda.synchronize()\n",
76
+ "print('inference time (s):', time.time() - t0)"
77
+ ]
78
+ }
79
+ ],
80
+ "metadata": {
81
+ "kernelspec": {
82
+ "display_name": "4dgs-dpm",
83
+ "language": "python",
84
+ "name": "python3"
85
+ },
86
+ "language_info": {
87
+ "codemirror_mode": {
88
+ "name": "ipython",
89
+ "version": 3
90
+ },
91
+ "file_extension": ".py",
92
+ "mimetype": "text/x-python",
93
+ "name": "python",
94
+ "nbconvert_exporter": "python",
95
+ "pygments_lexer": "ipython3",
96
+ "version": "3.12.12"
97
+ }
98
+ },
99
+ "nbformat": 4,
100
+ "nbformat_minor": 5
101
+ }
vggt-low-vram/.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # SCM syntax highlighting & preventing 3-way merges
2
+ pixi.lock merge=binary linguist-language=YAML linguist-generated=true
vggt-low-vram/.gitignore ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .hydra/
2
+ output/
3
+ ckpt/
4
+ .gradio/
5
+ input_images_*
6
+ examples/*/sparse/
7
+ examples/*/outputs
8
+ examples/*/transforms.json
9
+ examples/*/sparse_pc.ply
10
+ # Byte-compiled / optimized / DLL files
11
+ __pycache__/
12
+ **/__pycache__/
13
+ *.py[cod]
14
+ *$py.class
15
+
16
+ # C extensions
17
+ *.so
18
+
19
+ # Distribution / packaging
20
+ .Python
21
+ build/
22
+ develop-eggs/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ wheels/
33
+ pip-wheel-metadata/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+ cover/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ target/
87
+
88
+ # Jupyter Notebook
89
+ .ipynb_checkpoints
90
+
91
+ # IPython
92
+ profile_default/
93
+ ipython_config.py
94
+
95
+ # pyenv
96
+ .python-version
97
+
98
+ # pipenv
99
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
101
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
102
+ # install all needed dependencies.
103
+ #Pipfile.lock
104
+
105
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106
+ __pypackages__/
107
+
108
+ # Celery stuff
109
+ celerybeat-schedule
110
+ celerybeat.pid
111
+
112
+ # SageMath parsed files
113
+ *.sage.py
114
+
115
+ # Environments
116
+ .env
117
+ .venv
118
+ env/
119
+ venv/
120
+ ENV/
121
+ env.bak/
122
+ venv.bak/
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
141
+
142
+ # pytype static type analyzer
143
+ .pytype/
144
+
145
+ # Profiling data
146
+ .prof
147
+
148
+ # Folder specific to your needs
149
+ **/tmp/
150
+ **/outputs/skyseg.onnx
151
+ skyseg.onnx
152
+
153
+ # pixi environments
154
+ .pixi
155
+ *.egg-info
vggt-low-vram/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@meta.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
vggt-low-vram/CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to vggt
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to vggt, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
vggt-low-vram/LICENSE.txt ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ VGGT License
2
+
3
+ v1 Last Updated: July 29, 2025
4
+
5
+ “Acceptable Use Policy” means the Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
6
+
7
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
8
+
9
+
10
+ “Documentation” means the specifications, manuals and documentation accompanying
11
+ Research Materials distributed by Meta.
12
+
13
+
14
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
15
+
16
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
17
+ “Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
18
+
19
+ By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
20
+
21
+
22
+ 1. License Rights and Redistribution.
23
+
24
+
25
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
26
+
27
+ b. Redistribution and Use.
28
+
29
+
30
+ i. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
31
+
32
+
33
+ ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
34
+
35
+
36
+ iii. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
37
+ 2. User Support. Your use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
38
+
39
+
40
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
41
+
42
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
43
+
44
+ 5. Intellectual Property.
45
+
46
+
47
+ a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
48
+
49
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
50
+
51
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.
52
+
53
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
54
+
55
+
56
+ 8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
57
+
58
+
59
+ Acceptable Use Policy
60
+
61
+ Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
62
+
63
+ As part of this mission, Meta makes certain research materials available for use in accordance with this Agreement (including the Acceptable Use Policy). Meta is committed to promoting the safe and responsible use of such research materials.
64
+
65
+ Prohibited Uses
66
+
67
+ You agree you will not use, or allow others to use, Research Materials to:
68
+
69
+ Violate the law or others’ rights, including to:
70
+ Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
71
+ Violence or terrorism
72
+ Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
73
+ Human trafficking, exploitation, and sexual violence
74
+ The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
75
+ Sexual solicitation
76
+ Any other criminal activity
77
+
78
+ Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
79
+
80
+ Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
81
+
82
+ Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
83
+
84
+ Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
85
+
86
+ Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using Research Materials
87
+
88
+ Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
89
+
90
+ 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
91
+
92
+ Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
93
+
94
+ Guns and illegal weapons (including weapon development)
95
+
96
+ Illegal drugs and regulated/controlled substances
97
+ Operation of critical infrastructure, transportation technologies, or heavy machinery
98
+
99
+ Self-harm or harm to others, including suicide, cutting, and eating disorders
100
+ Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
101
+
102
+ 3. Intentionally deceive or mislead others, including use of Research Materials related to the following:
103
+
104
+ Generating, promoting, or furthering fraud or the creation or promotion of disinformation
105
+ Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
106
+
107
+ Generating, promoting, or further distributing spam
108
+
109
+ Impersonating another individual without consent, authorization, or legal right
110
+
111
+ Representing that outputs of research materials or outputs from technology using Research Materials are human-generated
112
+
113
+ Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
114
+
115
+ 4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
vggt-low-vram/README.md ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Low-VRAM VGGT Inference
2
+
3
+ End-to-end 3D reconstruction model [VGGT](https://github.com/facebookresearch/vggt) optimized for VRAM usage. This fork is for **inference only**.
4
+
5
+ This optimized version uses **6 times less VRAM** than original without significant slow down. We have successfully run reconstruction for **150 images** with 8GB VRAM, or **1100 images** with 32GB VRAM.
6
+
7
+ ## Main optimizations
8
+ - Original model's aggregator stores intermediate outputs of all 24 attention blocks, but only 4 of them are used by prediction heads. Made it return only those 4.
9
+ - `torch.cuda.amp.autocast` doesn't seem to be smart. A lot of large tensors are still stored in FP32. Do mixed precision manually instead.
10
+ - For a lot of modules, `self.training` is True even with `torch.no_grad()`. Removed all branches that involve `self.training` that potentially lead to overhead.
11
+ - `del` unused intermediate tensors to free memory for subsequent code
12
+ - `@torch.compile` some functions (e.g. MLP with GELU, LayerNorm)
13
+ - `torch.cuda.empty_cache()` when helpful
14
+
15
+ ## Benchmark
16
+
17
+ Each benchmark runs a full forward through embedding, aggregator, and camera/depth/point heads. See `benchmark/` for code used for benchmark.
18
+
19
+ ### Ours results ([commit](https://github.com/harry7557558/vggt-low-vram/commit/100f7b5813c35561a425b1dd32f9d8bef10063fb)):
20
+
21
+ | | example room (8) | example kitchen (25) | Mip-NeRF 360 stump (125) | Mip-NeRF 360 room (311) | Zip-NeRF nyc (990) | IMC-PT bdbg-gate (1363) | Zip-NeRF london (1874) |
22
+ | :------- | :------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |
23
+ | VRAM | 3.27&nbsp;GB | 3.50&nbsp;GB | 5.83&nbsp;GB | 10.97&nbsp;GB | 25.37&nbsp;GB | 58.03&nbsp;GB | 45.92&nbsp;GB |
24
+ | RTX&nbsp;5070<br>laptop&nbsp;(8GB) | 1.97&nbsp;s | 4.80&nbsp;s | 48.47&nbsp;s | - | - | - | - |
25
+ | RTX&nbsp;4090<br>(24GB) | 2.65&nbsp;s | 4.31&nbsp;s | 16.49&nbsp;s | 66.42&nbsp;s | - | - | - |
26
+ | RTX&nbsp;5090<br>(32GB) | 0.97&nbsp;s | 1.61&nbsp;s | 9.97&nbsp;s | 44.06&nbsp;s | 275.91&nbsp;s | - | - |
27
+ | RTX&nbsp;A6000<br>(48GB) | 1.40&nbsp;s | 3.47&nbsp;s | 21.71&nbsp;s | 103.45&nbsp;s | 687.31&nbsp;s | - | - |
28
+ | A100&nbsp;SXM4<br>(80GB) | 2.88&nbsp;s | 4.10&nbsp;s | 15.36&nbsp;s | 62.86&nbsp;s | 376.65&nbsp;s | 2163.30&nbsp;s | 1326.58&nbsp;s |
29
+ | H100&nbsp;NVL<br>(94GB) | 1.10&nbsp;s | 1.67&nbsp;s | 8.52&nbsp;s | 42.41&nbsp;s | 288.55&nbsp;s | 1733.15&nbsp;s | 1052.18&nbsp;s |
30
+
31
+ ### Baseline results ([commit](https://github.com/facebookresearch/vggt/commit/8492456ce358ee9a4fe3274e36d73106b640fb5c)):
32
+
33
+ | | example room (8) | example kitchen (25) | Mip-NeRF 360 stump (125) | Mip-NeRF 360 room (311) | Zip-NeRF nyc (990) | IMC-PT bdbg-gate (1363) | Zip-NeRF london (1874) |
34
+ | :------- | :------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |
35
+ | VRAM | 9.72&nbsp;GB | 12.34&nbsp;GB | 31.52&nbsp;GB | 68.95&nbsp;GB | - | - | - |
36
+ | RTX&nbsp;5070<br>laptop&nbsp;(8GB) | - | - | - | - | - | - | - |
37
+ | RTX&nbsp;4090<br>(24GB) | 0.86&nbsp;s | 1.80&nbsp;s | - | - | - | - | - |
38
+ | RTX&nbsp;5090<br>(32GB) | 0.57&nbsp;s | 1.22&nbsp;s | - | - | - | - | - |
39
+ | RTX&nbsp;A6000<br>(48GB) | 0.92&nbsp;s | 2.22&nbsp;s | 19.77&nbsp;s | - | - | - | - |
40
+ | A100&nbsp;SXM4<br>(80GB) | 1.09&nbsp;s | 1.74&nbsp;s | 11.39&nbsp;s | 54.54&nbsp;s | - | - | - |
41
+ | H100&nbsp;NVL<br>(94GB) | 0.53&nbsp;s | 1.00&nbsp;s | 7.99&nbsp;s | 41.15&nbsp;s | - | - | - |
42
+
43
+
44
+ ### Baseline vs. Ours<br/>
45
+ <!-- (time is ratio of total time of runs that are successful with both methods) -->
46
+
47
+ Ours uses several times less VRAM. Ours is consistently slower, but the time difference is less significant for larger datasets.
48
+
49
+ | # images | 8 | 25 | 125 | 311 |
50
+ | :------- | :------: | :-------: | :-------: | :-------: |
51
+ | VRAM | 3.0&nbsp;x | 3.5&nbsp;x | 5.4&nbsp;x | 6.3&nbsp;x |
52
+ | Time | 0.44&nbsp;x | 0.53&nbsp;x | 0.86&nbsp;x | 0.91&nbsp;x |
53
+
54
+ ### Additional details
55
+
56
+ Hardware and platform:
57
+ - RTX 5070 is my local device (CUDA 12.8, PyTorch 2.7.1, Python 3.12, Ubuntu 24.04).
58
+ - Rest of GPUs are provided by https://cloud.vast.ai/ with "PyTorch (Vast)" image without further modification (CUDA 12.4/12.8, PyTorch 2.5.1/2.7.1, Python 3.10/3.12).
59
+
60
+ ## To Use
61
+
62
+ This fork can be installed in the same way as the original VGGT (i.e. follow the original instruction but change the git clone link). It runs the same checkpoints.
63
+
64
+ `demo_gradio.py` and `demo_viser.py` should work as before. `demo_colmap.py` may also work, although not thoroughly tested.
65
+
66
+ See `benchmark/benchmark.py` for a full example. For basic usage:
67
+
68
+ ```py
69
+ import torch
70
+ from vggt.models.vggt import VGGT
71
+ from vggt.utils.load_fn import load_and_preprocess_images
72
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
73
+
74
+ dtype = torch.bfloat16 # or torch.float16 or torch.float32
75
+ model = VGGT.from_pretrained("facebook/VGGT-1B").cuda().to(dtype)
76
+ images = load_and_preprocess_images(["...list of image paths"]).cuda().to(dtype)
77
+ with torch.no_grad():
78
+ predictions = model(images, verbose=True) # will compute in (mostly) dtype and output in FP32
79
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions['pose_enc'], images.shape[-2:])
80
+ ```
81
+
82
+ Some notes:
83
+
84
+ - Model weights and input images must have the same dtype. Model inference will run computation and store intermediate results largely in `images.dtype`.
85
+
86
+ - Do not use `autocast`. This fork does mixed precision manually in the FP16/BF16 case, doing so can lead to some intermediate tensors being stored in FP32 and increase VRAM usage.
87
+
88
+
89
+ ## Breaking Changes
90
+ - Precision loss caused by FP32 -> FP16/BF16
91
+ - You can always make it use FP32 and still enjoy significant VRAM reduction
92
+ - Notes in "To Use" section above
93
+ - Occasional compatibility issues introduced by `@torch.compile`
94
+ - Does not support training
95
+
96
+ ### Not (yet) tested, feel free to issue/PR if you are experiencing problems:
97
+ - Point tracking
98
+ - Different OS/Python/PyTorch/CUDA versions from above Benchmark section
99
+ - Multi GPU, or running on CPU
100
+
101
+ <div><br/></div>
102
+
103
+ ----
104
+ # ==== Original README below ====
105
+
106
+ <div><br/></div>
107
+
108
+
109
+ <div align="center">
110
+ <h1>VGGT: Visual Geometry Grounded Transformer</h1>
111
+
112
+ <a href="https://jytime.github.io/data/VGGT_CVPR25.pdf" target="_blank" rel="noopener noreferrer">
113
+ <img src="https://img.shields.io/badge/Paper-VGGT" alt="Paper PDF">
114
+ </a>
115
+ <a href="https://arxiv.org/abs/2503.11651"><img src="https://img.shields.io/badge/arXiv-2503.11651-b31b1b" alt="arXiv"></a>
116
+ <a href="https://vgg-t.github.io/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
117
+ <a href='https://huggingface.co/spaces/facebook/vggt'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
118
+
119
+
120
+ **[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**; **[Meta AI](https://ai.facebook.com/research/)**
121
+
122
+
123
+ [Jianyuan Wang](https://jytime.github.io/), [Minghao Chen](https://silent-chen.github.io/), [Nikita Karaev](https://nikitakaraevv.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/), [David Novotny](https://d-novotny.github.io/)
124
+ </div>
125
+
126
+ ```bibtex
127
+ @inproceedings{wang2025vggt,
128
+ title={VGGT: Visual Geometry Grounded Transformer},
129
+ author={Wang, Jianyuan and Chen, Minghao and Karaev, Nikita and Vedaldi, Andrea and Rupprecht, Christian and Novotny, David},
130
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
131
+ year={2025}
132
+ }
133
+ ```
134
+
135
+ ## Updates
136
+
137
+ - [July 29, 2025] We've updated the license for VGGT to permit **commercial use** (excluding military applications). All code in this repository is now under a commercial-use-friendly license. However, only the newly released checkpoint [**VGGT-1B-Commercial**](https://huggingface.co/facebook/VGGT-1B-Commercial) is licensed for commercial usage — the original checkpoint remains non-commercial. Full license details are available [here](https://github.com/facebookresearch/vggt/blob/main/LICENSE.txt). Access to the checkpoint requires completing an application form, which is processed by a system similar to LLaMA's approval workflow, automatically. The new checkpoint delivers similar performance to the original model. Please submit an issue if you notice a significant performance discrepancy.
138
+
139
+
140
+
141
+ - [July 6, 2025] Training code is now available in the `training` folder, including an example to finetune VGGT on a custom dataset.
142
+
143
+
144
+ - [June 13, 2025] Honored to receive the Best Paper Award at CVPR 2025! Apologies if I’m slow to respond to queries or GitHub issues these days. If you’re interested, our oral presentation is available [here](https://docs.google.com/presentation/d/1JVuPnuZx6RgAy-U5Ezobg73XpBi7FrOh/edit?usp=sharing&ouid=107115712143490405606&rtpof=true&sd=true). Another long presentation can be found [here](https://docs.google.com/presentation/d/1aSv0e5PmH1mnwn2MowlJIajFUYZkjqgw/edit?usp=sharing&ouid=107115712143490405606&rtpof=true&sd=true) (Note: it’s shared in .pptx format with animations — quite large, but feel free to use it as a template if helpful.)
145
+
146
+
147
+ - [June 2, 2025] Added a script to run VGGT and save predictions in COLMAP format, with bundle adjustment support optional. The saved COLMAP files can be directly used with [gsplat](https://github.com/nerfstudio-project/gsplat) or other NeRF/Gaussian splatting libraries.
148
+
149
+
150
+ - [May 3, 2025] Evaluation code for reproducing our camera pose estimation results on Co3D is now available in the [evaluation](https://github.com/facebookresearch/vggt/tree/evaluation) branch.
151
+
152
+
153
+ ## Overview
154
+
155
+ Visual Geometry Grounded Transformer (VGGT, CVPR 2025) is a feed-forward neural network that directly infers all key 3D attributes of a scene, including extrinsic and intrinsic camera parameters, point maps, depth maps, and 3D point tracks, **from one, a few, or hundreds of its views, within seconds**.
156
+
157
+
158
+ ## Quick Start
159
+
160
+ First, clone this repository to your local machine, and install the dependencies (torch, torchvision, numpy, Pillow, and huggingface_hub).
161
+
162
+ ```bash
163
+ git clone git@github.com:facebookresearch/vggt.git
164
+ cd vggt
165
+ pip install -r requirements.txt
166
+ ```
167
+
168
+ Alternatively, you can install VGGT as a package (<a href="docs/package.md">click here</a> for details).
169
+
170
+
171
+ Now, try the model with just a few lines of code:
172
+
173
+ ```python
174
+ import torch
175
+ from vggt.models.vggt import VGGT
176
+ from vggt.utils.load_fn import load_and_preprocess_images
177
+
178
+ device = "cuda" if torch.cuda.is_available() else "cpu"
179
+ # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
180
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
181
+
182
+ # Initialize the model and load the pretrained weights.
183
+ # This will automatically download the model weights the first time it's run, which may take a while.
184
+ model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
185
+
186
+ # Load and preprocess example images (replace with your own image paths)
187
+ image_names = ["path/to/imageA.png", "path/to/imageB.png", "path/to/imageC.png"]
188
+ images = load_and_preprocess_images(image_names).to(device)
189
+
190
+ with torch.no_grad():
191
+ with torch.cuda.amp.autocast(dtype=dtype):
192
+ # Predict attributes including cameras, depth maps, and point maps.
193
+ predictions = model(images)
194
+ ```
195
+
196
+ The model weights will be automatically downloaded from Hugging Face. If you encounter issues such as slow loading, you can manually download them [here](https://huggingface.co/facebook/VGGT-1B/blob/main/model.pt) and load, or:
197
+
198
+ ```python
199
+ model = VGGT()
200
+ _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
201
+ model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
202
+ ```
203
+
204
+ ## Detailed Usage
205
+
206
+ <details>
207
+ <summary>Click to expand</summary>
208
+
209
+ You can also optionally choose which attributes (branches) to predict, as shown below. This achieves the same result as the example above. This example uses a batch size of 1 (processing a single scene), but it naturally works for multiple scenes.
210
+
211
+ ```python
212
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
213
+ from vggt.utils.geometry import unproject_depth_map_to_point_map
214
+
215
+ with torch.no_grad():
216
+ with torch.cuda.amp.autocast(dtype=dtype):
217
+ images = images[None] # add batch dimension
218
+ aggregated_tokens_list, ps_idx = model.aggregator(images)
219
+
220
+ # Predict Cameras
221
+ pose_enc = model.camera_head(aggregated_tokens_list)[-1]
222
+ # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
223
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
224
+
225
+ # Predict Depth Maps
226
+ depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)
227
+
228
+ # Predict Point Maps
229
+ point_map, point_conf = model.point_head(aggregated_tokens_list, images, ps_idx)
230
+
231
+ # Construct 3D Points from Depth Maps and Cameras
232
+ # which usually leads to more accurate 3D points than point map branch
233
+ point_map_by_unprojection = unproject_depth_map_to_point_map(depth_map.squeeze(0),
234
+ extrinsic.squeeze(0),
235
+ intrinsic.squeeze(0))
236
+
237
+ # Predict Tracks
238
+ # choose your own points to track, with shape (N, 2) for one scene
239
+ query_points = torch.FloatTensor([[100.0, 200.0],
240
+ [60.72, 259.94]]).to(device)
241
+ track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None])
242
+ ```
243
+
244
+
245
+ Furthermore, if certain pixels in the input frames are unwanted (e.g., reflective surfaces, sky, or water), you can simply mask them by setting the corresponding pixel values to 0 or 1. Precise segmentation masks aren't necessary - simple bounding box masks work effectively (check this [issue](https://github.com/facebookresearch/vggt/issues/47) for an example).
246
+
247
+ </details>
248
+
249
+
250
+ ## Interactive Demo
251
+
252
+ We provide multiple ways to visualize your 3D reconstructions. Before using these visualization tools, install the required dependencies:
253
+
254
+ ```bash
255
+ pip install -r requirements_demo.txt
256
+ ```
257
+
258
+ ### Interactive 3D Visualization
259
+
260
+ **Please note:** VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, independent of VGGT's processing time. The visualization is slow especially when the number of images is large.
261
+
262
+
263
+ #### Gradio Web Interface
264
+
265
+ Our Gradio-based interface allows you to upload images/videos, run reconstruction, and interactively explore the 3D scene in your browser. You can launch this in your local machine or try it on [Hugging Face](https://huggingface.co/spaces/facebook/vggt).
266
+
267
+
268
+ ```bash
269
+ python demo_gradio.py
270
+ ```
271
+
272
+ <details>
273
+ <summary>Click to preview the Gradio interactive interface</summary>
274
+
275
+ ![Gradio Web Interface Preview](https://jytime.github.io/data/vggt_hf_demo_screen.png)
276
+ </details>
277
+
278
+
279
+ #### Viser 3D Viewer
280
+
281
+ Run the following command to run reconstruction and visualize the point clouds in viser. Note this script requires a path to a folder containing images. It assumes only image files under the folder. You can set `--use_point_map` to use the point cloud from the point map branch, instead of the depth-based point cloud.
282
+
283
+ ```bash
284
+ python demo_viser.py --image_folder path/to/your/images/folder
285
+ ```
286
+
287
+ ## Exporting to COLMAP Format
288
+
289
+ We also support exporting VGGT's predictions directly to COLMAP format, by:
290
+
291
+ ```bash
292
+ # Feedforward prediction only
293
+ python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/
294
+
295
+ # With bundle adjustment
296
+ python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba
297
+
298
+ # Run with bundle adjustment using reduced parameters for faster processing
299
+ # Reduces max_query_pts from 4096 (default) to 2048 and query_frame_num from 8 (default) to 5
300
+ # Trade-off: Faster execution but potentially less robust reconstruction in complex scenes (you may consider setting query_frame_num equal to your total number of images)
301
+ # See demo_colmap.py for additional bundle adjustment configuration options
302
+ python demo_colmap.py --scene_dir=/YOUR/SCENE_DIR/ --use_ba --max_query_pts=2048 --query_frame_num=5
303
+ ```
304
+
305
+ Please ensure that the images are stored in `/YOUR/SCENE_DIR/images/`. This folder should contain only the images. Check the examples folder for the desired data structure.
306
+
307
+ The reconstruction result (camera parameters and 3D points) will be automatically saved under `/YOUR/SCENE_DIR/sparse/` in the COLMAP format, such as:
308
+
309
+ ```
310
+ SCENE_DIR/
311
+ ├── images/
312
+ └── sparse/
313
+ ├── cameras.bin
314
+ ├── images.bin
315
+ └── points3D.bin
316
+ ```
317
+
318
+ ## Integration with Gaussian Splatting
319
+
320
+
321
+ The exported COLMAP files can be directly used with [gsplat](https://github.com/nerfstudio-project/gsplat) for Gaussian Splatting training. Install `gsplat` following their official instructions (we recommend `gsplat==1.3.0`):
322
+
323
+ An example command to train the model is:
324
+ ```
325
+ cd gsplat
326
+ python examples/simple_trainer.py default --data_factor 1 --data_dir /YOUR/SCENE_DIR/ --result_dir /YOUR/RESULT_DIR/
327
+ ```
328
+
329
+
330
+
331
+ ## Zero-shot Single-view Reconstruction
332
+
333
+ Our model shows surprisingly good performance on single-view reconstruction, although it was never trained for this task. The model does not need to duplicate the single-view image to a pair, instead, it can directly infer the 3D structure from the tokens of the single view image. Feel free to try it with our demos above, which naturally works for single-view reconstruction.
334
+
335
+
336
+ We did not quantitatively test monocular depth estimation performance ourselves, but [@kabouzeid](https://github.com/kabouzeid) generously provided a comparison of VGGT to recent methods [here](https://github.com/facebookresearch/vggt/issues/36). VGGT shows competitive or better results compared to state-of-the-art monocular approaches such as DepthAnything v2 or MoGe, despite never being explicitly trained for single-view tasks.
337
+
338
+
339
+
340
+ ## Runtime and GPU Memory
341
+
342
+ We benchmark the runtime and GPU memory usage of VGGT's aggregator on a single NVIDIA H100 GPU across various input sizes.
343
+
344
+ | **Input Frames** | 1 | 2 | 4 | 8 | 10 | 20 | 50 | 100 | 200 |
345
+ |:----------------:|:-:|:-:|:-:|:-:|:--:|:--:|:--:|:---:|:---:|
346
+ | **Time (s)** | 0.04 | 0.05 | 0.07 | 0.11 | 0.14 | 0.31 | 1.04 | 3.12 | 8.75 |
347
+ | **Memory (GB)** | 1.88 | 2.07 | 2.45 | 3.23 | 3.63 | 5.58 | 11.41 | 21.15 | 40.63 |
348
+
349
+ Note that these results were obtained using Flash Attention 3, which is faster than the default Flash Attention 2 implementation while maintaining almost the same memory usage. Feel free to compile Flash Attention 3 from source to get better performance.
350
+
351
+
352
+ ## Research Progression
353
+
354
+ Our work builds upon a series of previous research projects. If you're interested in understanding how our research evolved, check out our previous works:
355
+
356
+
357
+ <table border="0" cellspacing="0" cellpadding="0">
358
+ <tr>
359
+ <td align="left">
360
+ <a href="https://github.com/jytime/Deep-SfM-Revisited">Deep SfM Revisited</a>
361
+ </td>
362
+ <td style="white-space: pre;">──┐</td>
363
+ <td></td>
364
+ </tr>
365
+ <tr>
366
+ <td align="left">
367
+ <a href="https://github.com/facebookresearch/PoseDiffusion">PoseDiffusion</a>
368
+ </td>
369
+ <td style="white-space: pre;">─────►</td>
370
+ <td>
371
+ <a href="https://github.com/facebookresearch/vggsfm">VGGSfM</a> ──►
372
+ <a href="https://github.com/facebookresearch/vggt">VGGT</a>
373
+ </td>
374
+ </tr>
375
+ <tr>
376
+ <td align="left">
377
+ <a href="https://github.com/facebookresearch/co-tracker">CoTracker</a>
378
+ </td>
379
+ <td style="white-space: pre;">──┘</td>
380
+ <td></td>
381
+ </tr>
382
+ </table>
383
+
384
+
385
+ ## Acknowledgements
386
+
387
+ Thanks to these great repositories: [PoseDiffusion](https://github.com/facebookresearch/PoseDiffusion), [VGGSfM](https://github.com/facebookresearch/vggsfm), [CoTracker](https://github.com/facebookresearch/co-tracker), [DINOv2](https://github.com/facebookresearch/dinov2), [Dust3r](https://github.com/naver/dust3r), [Moge](https://github.com/microsoft/moge), [PyTorch3D](https://github.com/facebookresearch/pytorch3d), [Sky Segmentation](https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing), [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2), [Metric3D](https://github.com/YvanYin/Metric3D) and many other inspiring works in the community.
388
+
389
+ ## Checklist
390
+
391
+ - [x] Release the training code
392
+ - [ ] Release VGGT-500M and VGGT-200M
393
+
394
+
395
+ ## License
396
+ See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.
397
+
398
+ Please note that only this [model checkpoint](https://huggingface.co/facebook/VGGT-1B-Commercial) allows commercial usage. This new checkpoint achieves the same performance level (might be slightly better) as the original one, e.g., AUC@30: 90.37 vs. 89.98 on the Co3D dataset.
vggt-low-vram/benchmark/benchmark.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from vggt.models.vggt import VGGT
3
+ from vggt.utils.load_fn import load_and_preprocess_images
4
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
5
+ from time import perf_counter
6
+ import os
7
+ from typing import List
8
+
9
+
10
def main(image_list: List[str], plot: bool):
    """Benchmark low-VRAM VGGT inference: report VRAM usage and wall-clock time.

    Args:
        image_list: Paths of the input images, processed together as one scene.
        plot: If True, visualize the reconstruction with ``plot_recon`` afterwards.

    Raises:
        RuntimeError: If no CUDA device is available (the benchmark is GPU-only).
    """
    # get_device_capability() raises on CPU-only hosts; fail early with a clear message.
    if not torch.cuda.is_available():
        raise RuntimeError("This benchmark requires a CUDA device.")
    device = "cuda"
    # bfloat16 needs compute capability >= 8 (Ampere+); otherwise fall back to float16.
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    # dtype = torch.float32

    print("Loading model")
    model = VGGT.from_pretrained("facebook/VGGT-1B").to(device).to(dtype)

    print(f"Loading {len(image_list)} images")
    images = load_and_preprocess_images(image_list).to(device).to(dtype)

    torch.cuda.synchronize()
    mem = torch.cuda.memory_allocated() / (1024**3)
    print(f"Current VRAM usage (model weights + images): {mem:.2f} GiB")

    torch.cuda.reset_peak_memory_stats()
    time0 = perf_counter()

    with torch.no_grad():
        # verbose=True is a low-VRAM-fork extension of VGGT.forward.
        predictions = model(images, verbose=True)
        extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions['pose_enc'], images.shape[-2:])

    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"Peak inference VRAM (including model/images): {mem:.2f} GiB")
    dt = perf_counter() - time0
    print(f"Inference time: {dt:.2f} s")

    if not plot:
        return

    # Imported lazily so the benchmark has no matplotlib dependency unless plotting.
    from plot_recon import plot_recon

    plot_recon(
        extrinsic.float().cpu().numpy()[0],
        intrinsic.float().cpu().numpy()[0],
        predictions["world_points"].float().cpu().numpy()[0],
        images.float().cpu().numpy(),
        frustum_size=0.05,
        point_subsample=5000,
    )
53
+
54
+
55
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_dir", type=str, required=True)
    parser.add_argument("--plot", action="store_true")
    args = parser.parse_args()

    # Keep only regular files: os.listdir may also return subdirectories,
    # which would break the image loader.
    images = [
        os.path.join(args.image_dir, f)
        for f in sorted(os.listdir(args.image_dir))
        if os.path.isfile(os.path.join(args.image_dir, f))
    ]
    main(images, args.plot)
64
+
vggt-low-vram/benchmark/benchmark_baseline.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from vggt.models.vggt import VGGT
3
+ from vggt.utils.load_fn import load_and_preprocess_images
4
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
5
+ from time import perf_counter
6
+ import os
7
+ from typing import List
8
+
9
+
10
def main(image_list: List[str], plot: bool):
    """Baseline benchmark: upstream-style VGGT inference (fp32 weights + autocast).

    Unlike the low-VRAM variant, the model and images stay in float32 and mixed
    precision is applied via autocast, matching the original VGGT demo.

    Args:
        image_list: Paths of the input images, processed together as one scene.
        plot: If True, visualize the reconstruction with ``plot_recon`` afterwards.

    Raises:
        RuntimeError: If no CUDA device is available (the benchmark is GPU-only).
    """
    # get_device_capability() raises on CPU-only hosts; fail early with a clear message.
    if not torch.cuda.is_available():
        raise RuntimeError("This benchmark requires a CUDA device.")
    device = "cuda"
    # bfloat16 needs compute capability >= 8 (Ampere+); otherwise fall back to float16.
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    # dtype = torch.float32

    print("Loading model")
    model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)

    print(f"Loading {len(image_list)} images")
    images = load_and_preprocess_images(image_list).to(device)

    torch.cuda.synchronize()
    mem = torch.cuda.memory_allocated() / (1024**3)
    print(f"Current VRAM usage (model weights + images): {mem:.2f} GiB")

    torch.cuda.reset_peak_memory_stats()
    time0 = perf_counter()

    with torch.no_grad():
        # torch.cuda.amp.autocast is deprecated; torch.autocast is the supported API.
        with torch.autocast(device_type="cuda", dtype=dtype):
            predictions = model(images)
            extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions['pose_enc'], images.shape[-2:])

    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"Peak inference VRAM (including model/images): {mem:.2f} GiB")
    dt = perf_counter() - time0
    print(f"Inference time: {dt:.2f} s")

    if not plot:
        return

    # Imported lazily so the benchmark has no matplotlib dependency unless plotting.
    from plot_recon import plot_recon

    plot_recon(
        extrinsic.float().cpu().numpy()[0],
        intrinsic.float().cpu().numpy()[0],
        predictions["world_points"].float().cpu().numpy()[0],
        images.float().cpu().numpy(),
        frustum_size=0.05,
        point_subsample=5000,
    )
54
+
55
+
56
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_dir", type=str, required=True)
    parser.add_argument("--plot", action="store_true")
    args = parser.parse_args()

    # Keep only regular files: os.listdir may also return subdirectories,
    # which would break the image loader.
    images = [
        os.path.join(args.image_dir, f)
        for f in sorted(os.listdir(args.image_dir))
        if os.path.isfile(os.path.join(args.image_dir, f))
    ]
    main(images, args.plot)
65
+
vggt-low-vram/benchmark/plot_recon.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # thanks Claude
2
+
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from mpl_toolkits.mplot3d import Axes3D
6
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
7
+
8
+
9
def plot_recon(camera_poses, camera_matrices, world_points, rgb_images,
               frustum_size=1.0, point_subsample=1000, camera_colors=None,
               figsize=(12, 10), show_plot=True):
    """
    Plot 3D reconstruction results with camera frustums and a colored point cloud.

    Parameters:
    -----------
    camera_poses : np.ndarray, shape (N, 3, 4)
        World-from-camera extrinsics in [R|t] form (OpenCV convention:
        camera-from-world; the camera center is recovered as -R.T @ t).
    camera_matrices : np.ndarray, shape (N, 3, 3)
        Camera intrinsic matrices K in OpenCV convention.
    world_points : np.ndarray, shape (N, H, W, 3)
        Dense world-space points for each camera view.
    rgb_images : np.ndarray, shape (N, 3, H, W) or (N, H, W, 3)
        Input RGB images corresponding to each camera; values in [0, 1] or [0, 255].
    frustum_size : float, default=1.0
        Size scaling factor for the drawn camera frustums.
    point_subsample : int, default=1000
        Maximum number of points drawn from the point cloud.
    camera_colors : list or None
        Colors for each camera frustum. If None, a default colormap is used.
    figsize : tuple, default=(12, 10)
        Figure size.
    show_plot : bool, default=True
        Whether to call plt.show().

    Returns:
    --------
    fig : matplotlib.figure.Figure
    ax : matplotlib 3D axes object
    """
    # Accept channel-first images by converting to channel-last.
    if rgb_images.shape[1] == 3:  # (N, 3, H, W) format
        rgb_images = rgb_images.transpose(0, 2, 3, 1)  # -> (N, H, W, 3)

    N, H, W = world_points.shape[:3]

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    # ---- Point cloud -------------------------------------------------------
    valid_points = []
    colors = []

    for i in range(N):
        points_3d = world_points[i].reshape(-1, 3)
        rgb_vals = rgb_images[i].reshape(-1, 3)

        # Heuristic validity filter: drop all-zero points and far outliers.
        valid_mask = np.all(np.abs(points_3d) < 1000, axis=1) & np.any(points_3d != 0, axis=1)

        if np.sum(valid_mask) > 0:
            valid_points.append(points_3d[valid_mask])
            colors.append(rgb_vals[valid_mask])

    if valid_points:
        all_points = np.vstack(valid_points)
        all_colors = np.vstack(colors)

        # Random subsample keeps rendering fast for dense maps.
        if len(all_points) > point_subsample:
            indices = np.random.choice(len(all_points), point_subsample, replace=False)
            all_points = all_points[indices]
            all_colors = all_colors[indices]

        # Normalize 0-255 colors to [0, 1] for matplotlib.
        if all_colors.max() > 1.0:
            all_colors = all_colors / 255.0

        ax.scatter(all_points[:, 0], all_points[:, 1], all_points[:, 2],
                   c=all_colors, s=1, alpha=0.6)

    # ---- Camera frustums ---------------------------------------------------
    if camera_colors is None:
        cmap = plt.cm.tab10
        camera_colors = [cmap(i % 10) for i in range(N)]

    for i in range(N):
        pose = camera_poses[i]      # (3, 4) [R|t]
        K = camera_matrices[i]      # (3, 3)

        R = pose[:, :3]
        t = pose[:, 3]

        # Camera center in world coordinates (camera-from-world convention).
        camera_center = -R.T @ t

        # Image corners in homogeneous pixel coordinates.
        corners_2d = np.array([
            [0, 0, 1],
            [W - 1, 0, 1],
            [W - 1, H - 1, 1],
            [0, H - 1, 1],
        ]).T  # (3, 4)

        # Backproject corners to rays in the camera frame, scale, lift to world.
        K_inv = np.linalg.inv(K)
        corners_normalized = K_inv @ corners_2d
        corners_cam = corners_normalized * frustum_size
        corners_world = R.T @ corners_cam + camera_center[:, np.newaxis]

        # 5 vertices: apex (camera center) + 4 far-plane corners.
        frustum_vertices = np.concatenate([
            camera_center.reshape(1, 3),
            corners_world.T,
        ], 0)

        # 4 triangular side faces plus the rectangular far plane.
        faces = [
            [0, 1, 2],
            [0, 2, 3],
            [0, 3, 4],
            [0, 4, 1],
            [1, 2, 3, 4],
        ]

        frustum_collection = [frustum_vertices[face] for face in faces[:-1]]
        frustum_collection.append(frustum_vertices[faces[-1]])

        poly3d = Poly3DCollection(frustum_collection,
                                  facecolors=camera_colors[i],
                                  alpha=0.3,
                                  edgecolors='black',
                                  linewidths=0.5)
        ax.add_collection3d(poly3d)

        # Label each camera with its index next to its center.
        ax.text(camera_center[0], camera_center[1], camera_center[2],
                f' {i}', fontsize=8)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('3D Reconstruction')  # plain string: no placeholders needed

    # Equal aspect ratio so the scene is not distorted.
    if valid_points:
        max_range = np.max(all_points.max(axis=0) - all_points.min(axis=0)) / 2.0
        mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5
        mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5
        mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5

        ax.set_xlim(mid_x - max_range, mid_x + max_range)
        ax.set_ylim(mid_y - max_range, mid_y + max_range)
        ax.set_zlim(mid_z - max_range, mid_z + max_range)

    if show_plot:
        plt.show()

    return fig, ax
184
+
185
+
186
+ # Example usage function
187
def example_usage():
    """
    Example of how to use the plot_recon function on synthetic data.
    """
    N = 5          # number of cameras
    H, W = 480, 640  # image dimensions

    # Synthetic camera poses: cameras arranged on a circle, looking inward.
    camera_poses = []
    for i in range(N):
        angle = 2 * np.pi * i / N

        R = np.array([
            [np.cos(angle), 0, np.sin(angle)],
            [0, 1, 0],
            [-np.sin(angle), 0, np.cos(angle)],
        ])

        t = np.array([3 * np.cos(angle), 0, 3 * np.sin(angle)])

        camera_poses.append(np.column_stack([R, t]))

    camera_poses = np.array(camera_poses)

    # One shared pinhole intrinsic matrix, replicated per camera.
    focal_length = 500
    cx, cy = W // 2, H // 2
    K = np.array([
        [focal_length, 0, cx],
        [0, focal_length, cy],
        [0, 0, 1],
    ])
    camera_matrices = np.tile(K[np.newaxis], (N, 1, 1))

    # Random point cloud and random RGB images.
    world_points = np.random.randn(N, H, W, 3) * 2
    rgb_images = np.random.rand(N, H, W, 3)

    # Fixed: the plotting function is named plot_recon, not plot_3d_reconstruction
    # (the old name raised NameError).
    fig, ax = plot_recon(
        camera_poses=camera_poses,
        camera_matrices=camera_matrices,
        world_points=world_points,
        rgb_images=rgb_images,
        frustum_size=0.5,
        point_subsample=5000,
    )

    return fig, ax
242
+
243
+
244
if __name__ == "__main__":
    # Demonstrate the plotting helper on synthetic data.
    example_usage()
247
+
vggt-low-vram/benchmark/run_benchmark.bash ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Run the VGGT low-VRAM benchmark over a list of datasets.
# The first three entries are warmup runs (CUDA/compile caches), repeated below.

datasets=(
    # warmup runs
    "./examples/single_cartoon/images"  # warmup run
    "./examples/room/images"            # warmup run
    "./examples/kitchen/images"         # warmup run

    # examples that come with the original repository
    "./examples/single_cartoon/images"  # 1 image
    "./examples/room/images"            # 8 images
    "./examples/kitchen/images"         # 25 images

    # larger public benchmark datasets
    "../vggt_low_vram_benchmark/360_v2_stump_images_4"            # 125 images
    "../vggt_low_vram_benchmark/tnt_family"                       # 152 images
    "../vggt_low_vram_benchmark/360_v2_room_images_4"             # 311 images
    "../vggt_low_vram_benchmark/zipnerf_nyc_undistorted_images_2" # 990 images
    "../vggt_low_vram_benchmark/imc_pt_brandenburg_gate"          # 1363 images
    "../vggt_low_vram_benchmark/zipnerf_london_undistorted_images_2" # 1874 images
)

for dataset in "${datasets[@]}"; do
    echo "Running $dataset"
    # Quote "$dataset" so paths with spaces or glob characters survive word splitting.
    python benchmark/benchmark.py --image_dir "$dataset" #--plot
    # python benchmark/benchmark_baseline.py --image_dir "$dataset" #--plot
    echo ""
done
vggt-low-vram/demo_colmap.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import random
8
+ import numpy as np
9
+ import glob
10
+ import os
11
+ import copy
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ # Configure CUDA settings
16
+ torch.backends.cudnn.enabled = True
17
+ torch.backends.cudnn.benchmark = True
18
+ torch.backends.cudnn.deterministic = False
19
+
20
+ import argparse
21
+ from pathlib import Path
22
+ import trimesh
23
+ import pycolmap
24
+
25
+
26
+ from vggt.models.vggt import VGGT
27
+ from vggt.utils.load_fn import load_and_preprocess_images_square
28
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
29
+ from vggt.utils.geometry import unproject_depth_map_to_point_map
30
+ from vggt.utils.helper import create_pixel_coordinate_grid, randomly_limit_trues
31
+ from vggt.dependency.track_predict import predict_tracks
32
+ from vggt.dependency.np_to_pycolmap import batch_np_matrix_to_pycolmap, batch_np_matrix_to_pycolmap_wo_track
33
+
34
+
35
+ # TODO: add support for masks
36
+ # TODO: add iterative BA
37
+ # TODO: add support for radial distortion, which needs extra_params
38
+ # TODO: test with more cases
39
+ # TODO: test different camera types
40
+
41
+
42
def parse_args():
    """Parse CLI arguments for the VGGT -> COLMAP demo.

    Returns:
        argparse.Namespace with scene_dir, seed, use_ba and the BA parameters.
    """
    parser = argparse.ArgumentParser(description="VGGT Demo")
    parser.add_argument("--scene_dir", type=str, required=True, help="Directory containing the scene images")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--use_ba", action="store_true", default=False, help="Use BA for reconstruction")
    ######### BA parameters #########
    parser.add_argument(
        "--max_reproj_error", type=float, default=8.0, help="Maximum reprojection error for reconstruction"
    )
    parser.add_argument("--shared_camera", action="store_true", default=False, help="Use shared camera for all images")
    parser.add_argument("--camera_type", type=str, default="SIMPLE_PINHOLE", help="Camera type for reconstruction")
    parser.add_argument("--vis_thresh", type=float, default=0.2, help="Visibility threshold for tracks")
    parser.add_argument("--query_frame_num", type=int, default=8, help="Number of frames to query")
    parser.add_argument("--max_query_pts", type=int, default=4096, help="Maximum number of query points")
    # Fixed: the old `action="store_true", default=True` made this flag impossible
    # to disable. BooleanOptionalAction keeps --fine_tracking working and adds
    # --no-fine_tracking to turn it off.
    parser.add_argument(
        "--fine_tracking",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Use fine tracking (slower but more accurate)",
    )
    parser.add_argument(
        "--conf_thres_value", type=float, default=5.0, help="Confidence threshold value for depth filtering (wo BA)"
    )
    return parser.parse_args()
63
+
64
+
65
+ def run_VGGT(model, images, device, dtype, resolution=518):
66
+ # images: [B, 3, H, W]
67
+
68
+ assert len(images.shape) == 4
69
+ assert images.shape[1] == 3
70
+
71
+ # hard-coded to use 518 for VGGT
72
+ images = F.interpolate(images, size=(resolution, resolution), mode="bilinear", align_corners=False)
73
+ images = images.to(device, dtype)
74
+
75
+ with torch.no_grad():
76
+ images = images[None] # add batch dimension
77
+ aggregated_tokens_list, ps_idx = model.aggregator(images, verbose=True)
78
+
79
+ # Predict Cameras
80
+ pose_enc = model.camera_head(aggregated_tokens_list)[-1]
81
+ # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
82
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
83
+ # Predict Depth Maps
84
+ depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)
85
+
86
+ extrinsic = extrinsic.squeeze(0).cpu().numpy()
87
+ intrinsic = intrinsic.squeeze(0).cpu().numpy()
88
+ depth_map = depth_map.squeeze(0).cpu().numpy()
89
+ depth_conf = depth_conf.squeeze(0).cpu().numpy()
90
+ return extrinsic, intrinsic, depth_map, depth_conf
91
+
92
+
93
def demo_fn(args):
    """End-to-end demo: VGGT inference -> (optional BA) -> COLMAP export.

    Reads images from ``args.scene_dir/images``, estimates cameras and depth
    with VGGT, optionally refines with bundle adjustment, and writes a COLMAP
    sparse reconstruction plus a ``points.ply`` preview to
    ``args.scene_dir/sparse``.

    Returns:
        True on success.

    Raises:
        ValueError: If no images are found, or BA yields no reconstruction.
    """
    # Print configuration
    print("Arguments:", vars(args))

    # Set seed for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)  # for multi-GPU
    print(f"Setting seed as: {args.seed}")

    # Set device and dtype. Fixed: get_device_capability() was previously called
    # unconditionally and crashed on CPU-only hosts before the device fallback.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        # bfloat16 needs compute capability >= 8 (Ampere+).
        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    else:
        dtype = torch.float32
    print(f"Using device: {device}")
    print(f"Using dtype: {dtype}")

    # Load VGGT for camera and depth estimation.
    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
    model.eval()
    model = model.to(dtype=dtype, device=device)
    print("Model loaded")

    # Collect image paths.
    image_dir = os.path.join(args.scene_dir, "images")
    image_path_list = glob.glob(os.path.join(image_dir, "*"))
    if len(image_path_list) == 0:
        raise ValueError(f"No images found in {image_dir}")
    base_image_path_list = [os.path.basename(path) for path in image_path_list]

    # Load images at 1024 for tracking; VGGT itself runs at 518.
    vggt_fixed_resolution = 518
    img_load_resolution = 1024

    images, original_coords = load_and_preprocess_images_square(image_path_list, img_load_resolution)
    print(f"Loaded {len(images)} images from {image_dir}")

    # VGGT forward pass at 518x518, then lift depth to a world point map.
    extrinsic, intrinsic, depth_map, depth_conf = run_VGGT(model, images, device, dtype, vggt_fixed_resolution)
    points_3d = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic)

    del model  # free memory before tracking/BA
    torch.cuda.empty_cache()

    images = images.to(device, dtype)
    original_coords = original_coords.to(device)

    if args.use_ba:
        image_size = np.array(images.shape[-2:])
        scale = img_load_resolution / vggt_fixed_resolution
        shared_camera = args.shared_camera

        # TODO: use VGGT tracker
        with torch.inference_mode():
            # Using the VGGSfM tracker instead of the VGGT tracker for efficiency;
            # the VGGT tracker needs multiple backbone runs per query frame
            # (training artifact, to be fixed in VGGT v2). Any other track source
            # (COLMAP, CoTracker, chained 2D matches) works here too.
            pred_tracks, pred_vis_scores, pred_confs, points_3d, points_rgb = predict_tracks(
                images,
                conf=depth_conf,
                points_3d=points_3d,
                masks=None,
                max_query_pts=args.max_query_pts,
                query_frame_num=args.query_frame_num,
                keypoint_extractor="aliked+sp",
                fine_tracking=args.fine_tracking,
            )

            torch.cuda.empty_cache()

        # Rescale the intrinsics from the 518 inference size to the 1024 load size.
        intrinsic[:, :2, :] *= scale
        track_mask = pred_vis_scores > args.vis_thresh

        # TODO: radial distortion, iterative BA, masks
        reconstruction, valid_track_mask = batch_np_matrix_to_pycolmap(
            points_3d,
            extrinsic,
            intrinsic,
            pred_tracks,
            image_size,
            masks=track_mask,
            max_reproj_error=args.max_reproj_error,
            shared_camera=shared_camera,
            camera_type=args.camera_type,
            points_rgb=points_rgb,
        )

        if reconstruction is None:
            raise ValueError("No reconstruction can be built with BA")

        # Bundle adjustment refinement.
        ba_options = pycolmap.BundleAdjustmentOptions()
        pycolmap.bundle_adjustment(reconstruction, ba_options)

        reconstruction_resolution = img_load_resolution
    else:
        conf_thres_value = args.conf_thres_value
        max_points_for_colmap = 100000  # cap on exported 3D points
        shared_camera = False  # feedforward mode does not support a shared camera
        camera_type = "PINHOLE"  # feedforward mode only supports PINHOLE

        image_size = np.array([vggt_fixed_resolution, vggt_fixed_resolution])
        num_frames, height, width, _ = points_3d.shape
        images = images.float()

        # Sample per-point colors at the inference resolution.
        points_rgb = F.interpolate(
            images, size=(vggt_fixed_resolution, vggt_fixed_resolution), mode="bilinear", align_corners=False
        )
        points_rgb = (points_rgb.cpu().numpy() * 255).astype(np.uint8)
        points_rgb = points_rgb.transpose(0, 2, 3, 1)

        # (S, H, W, 3): pixel x, y coordinates plus frame index per point.
        points_xyf = create_pixel_coordinate_grid(num_frames, height, width)

        # Keep confident points only, and at most max_points_for_colmap of them.
        conf_mask = depth_conf >= conf_thres_value
        conf_mask = randomly_limit_trues(conf_mask, max_points_for_colmap)

        points_3d = points_3d[conf_mask]
        points_xyf = points_xyf[conf_mask]
        points_rgb = points_rgb[conf_mask]

        print("Converting to COLMAP format")
        reconstruction = batch_np_matrix_to_pycolmap_wo_track(
            points_3d,
            points_xyf,
            points_rgb,
            extrinsic,
            intrinsic,
            image_size,
            shared_camera=shared_camera,
            camera_type=camera_type,
        )

        reconstruction_resolution = vggt_fixed_resolution

    # Restore original filenames and rescale cameras/points to full resolution.
    reconstruction = rename_colmap_recons_and_rescale_camera(
        reconstruction,
        base_image_path_list,
        original_coords.cpu().numpy(),
        img_size=reconstruction_resolution,
        shift_point2d_to_original_res=True,
        shared_camera=shared_camera,
    )

    print(f"Saving reconstruction to {args.scene_dir}/sparse")
    sparse_reconstruction_dir = os.path.join(args.scene_dir, "sparse")
    os.makedirs(sparse_reconstruction_dir, exist_ok=True)
    reconstruction.write(sparse_reconstruction_dir)

    # Save a point cloud alongside for quick visualization.
    trimesh.PointCloud(points_3d, colors=points_rgb).export(os.path.join(args.scene_dir, "sparse/points.ply"))

    return True
+ return True
259
+
260
+
261
+ def rename_colmap_recons_and_rescale_camera(
262
+ reconstruction, image_paths, original_coords, img_size, shift_point2d_to_original_res=False, shared_camera=False
263
+ ):
264
+ rescale_camera = True
265
+
266
+ for pyimageid in reconstruction.images:
267
+ # Reshaped the padded&resized image to the original size
268
+ # Rename the images to the original names
269
+ pyimage = reconstruction.images[pyimageid]
270
+ pycamera = reconstruction.cameras[pyimage.camera_id]
271
+ pyimage.name = image_paths[pyimageid - 1]
272
+
273
+ if rescale_camera:
274
+ # Rescale the camera parameters
275
+ pred_params = copy.deepcopy(pycamera.params)
276
+
277
+ real_image_size = original_coords[pyimageid - 1, -2:]
278
+ resize_ratio = max(real_image_size) / img_size
279
+ pred_params = pred_params * resize_ratio
280
+ real_pp = real_image_size / 2
281
+ pred_params[-2:] = real_pp # center of the image
282
+
283
+ pycamera.params = pred_params
284
+ pycamera.width = real_image_size[0]
285
+ pycamera.height = real_image_size[1]
286
+
287
+ if shift_point2d_to_original_res:
288
+ # Also shift the point2D to original resolution
289
+ top_left = original_coords[pyimageid - 1, :2]
290
+
291
+ for point2D in pyimage.points2D:
292
+ point2D.xy = (point2D.xy - top_left) * resize_ratio
293
+
294
+ if shared_camera:
295
+ # If shared_camera, all images share the same camera
296
+ # no need to rescale any more
297
+ rescale_camera = False
298
+
299
+ return reconstruction
300
+
301
+
302
if __name__ == "__main__":
    # Inference only: no gradients needed anywhere in the demo.
    cli_args = parse_args()
    with torch.no_grad():
        demo_fn(cli_args)
306
+
307
+
308
+ # Work in Progress (WIP)
309
+
310
+ """
311
+ VGGT Runner Script
312
+ =================
313
+
314
+ A script to run the VGGT model for 3D reconstruction from image sequences.
315
+
316
+ Directory Structure
317
+ ------------------
318
+ Input:
319
+ input_folder/
320
+ └── images/ # Source images for reconstruction
321
+
322
+ Output:
323
+ output_folder/
324
+ ├── images/
325
+ ├── sparse/ # Reconstruction results
326
+ │ ├── cameras.bin # Camera parameters (COLMAP format)
327
+ │ ├── images.bin # Pose for each image (COLMAP format)
328
+ │ ├── points3D.bin # 3D points (COLMAP format)
329
+ │ └── points.ply # Point cloud visualization file
330
+ └── visuals/ # Visualization outputs TODO
331
+
332
+ Key Features
333
+ -----------
334
+ • Dual-mode Support: Run reconstructions using either VGGT or VGGT+BA
335
+ • Resolution Preservation: Maintains original image resolution in camera parameters and tracks
336
+ • COLMAP Compatibility: Exports results in standard COLMAP sparse reconstruction format
337
+ """
vggt-low-vram/demo_gradio.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import cv2
9
+ import torch
10
+ import numpy as np
11
+ import gradio as gr
12
+ import sys
13
+ import shutil
14
+ from datetime import datetime
15
+ import glob
16
+ import gc
17
+ import time
18
+
19
+ sys.path.append("vggt/")
20
+
21
+ from visual_util import predictions_to_glb
22
+ from vggt.models.vggt import VGGT
23
+ from vggt.utils.load_fn import load_and_preprocess_images
24
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
25
+ from vggt.utils.geometry import unproject_depth_map_to_point_map
26
+
27
# Pick compute device and dtype. Fixed: torch.cuda.get_device_capability() was
# called unconditionally and raised at import time on CPU-only hosts.
if torch.cuda.is_available():
    device = "cuda"
    # bfloat16 needs compute capability >= 8 (Ampere+); otherwise use float16.
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    device = "cpu"
    dtype = torch.float32

print("Initializing and loading VGGT model...")
# model = VGGT.from_pretrained("facebook/VGGT-1B")  # another way to load the model

model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
model.eval()
model = model.to(dtype=dtype, device=device)
38
+
39
+
40
+ # -------------------------------------------------------------------------
41
+ # 1) Core model inference
42
+ # -------------------------------------------------------------------------
43
def run_model(target_dir, model) -> dict:
    """
    Run the VGGT model on images in 'target_dir/images' and return predictions.

    Args:
        target_dir: Folder containing an 'images' subdirectory of input frames.
        model: Loaded VGGT model (module-level `dtype` is used for inputs).

    Returns:
        dict of numpy predictions (batch dimension removed), including
        'extrinsic', 'intrinsic' and 'world_points_from_depth'.

    Raises:
        ValueError: If CUDA is unavailable or no images are found.
    """
    print(f"Processing images from {target_dir}")

    # Guard first, then fix the device — the old code computed a cpu fallback
    # it could never use before raising.
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")
    device = "cuda"

    # Load and preprocess images (sorted for a deterministic frame order).
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    print(f"Found {len(image_names)} images")
    if len(image_names) == 0:
        raise ValueError("No images found. Check your upload.")

    images = load_and_preprocess_images(image_names).to(dtype=dtype, device=device)
    print(f"Preprocessed images shape: {images.shape}")

    print("Running inference...")
    with torch.no_grad():
        predictions = model(images, verbose=True)

    # Convert the pose encoding to OpenCV-convention camera matrices.
    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    # Move tensors to numpy and drop the batch dimension.
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            predictions[key] = predictions[key].cpu().numpy().squeeze(0)
    predictions['pose_enc_list'] = None  # drop intermediate pose encodings

    # Generate world points by unprojecting the predicted depth map.
    print("Computing world points from depth map...")
    depth_map = predictions["depth"]  # (S, H, W, 1)
    world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
    predictions["world_points_from_depth"] = world_points

    # Release cached GPU memory before returning to the UI.
    torch.cuda.empty_cache()
    return predictions
91
+
92
+
93
+ # -------------------------------------------------------------------------
94
+ # 2) Handle uploaded video/images --> produce target_dir + images
95
+ # -------------------------------------------------------------------------
96
def handle_uploads(input_video, input_images):
    """
    Create a fresh ``target_dir`` with an ``images`` subfolder and fill it with
    user-uploaded images and/or frames extracted from an uploaded video.

    Args:
        input_video: Gradio video value (path string or dict with "name"), or None.
        input_images: Iterable of Gradio file values (paths or dicts), or None.

    Returns:
        (target_dir, image_paths): the new folder and sorted list of image paths.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        for file_data in input_images:
            # Gradio may hand back either a plain path or a dict with "name".
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            # Sample ~1 frame/sec. Guard against CAP_PROP_FPS returning 0
            # (broken/unreadable streams), which would otherwise cause a
            # ZeroDivisionError in the modulo below.
            frame_interval = max(int(fps * 1), 1)

            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            vs.release()  # always free the capture handle

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths
159
+
160
+
161
+ # -------------------------------------------------------------------------
162
+ # 3) Update gallery on upload
163
+ # -------------------------------------------------------------------------
164
def update_gallery_on_upload(input_video, input_images):
    """
    React to a change in the upload widgets: stage the files into a fresh
    target directory and surface them in the gallery.

    Returns:
        A 4-tuple (reconstruction, target_dir, image_paths, log_message);
        all None when nothing has been uploaded yet.
    """
    if input_video or input_images:
        target_dir, image_paths = handle_uploads(input_video, input_images)
        return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
    return None, None, None, None
174
+
175
+
176
+ # -------------------------------------------------------------------------
177
+ # 4) Reconstruction: uses the target_dir plus any viz parameters
178
+ # -------------------------------------------------------------------------
179
def gradio_demo(
    target_dir,
    conf_thres=3.0,
    frame_filter="All",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    prediction_mode="Pointmap Regression",
):
    """
    Perform reconstruction using the already-created target_dir/images.

    Args:
        target_dir: Folder produced by handle_uploads (contains "images/").
        conf_thres: Confidence percentile threshold for point filtering.
        frame_filter: "All" or an "index: filename" entry to restrict frames.
        mask_black_bg / mask_white_bg: Drop near-black / near-white points.
        show_cam: Render camera frusta in the exported GLB.
        mask_sky: Apply sky segmentation to drop sky points.
        prediction_mode: Which prediction branch to visualize.

    Returns:
        (glb_path, log_message, frame_filter_dropdown) on success; a 4-tuple
        of error placeholders when target_dir is invalid.
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        return None, "No valid target directory found. Please upload first.", None, None

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare frame_filter dropdown entries as "index: filename".
    # (Previously the f-string dropped the filename, leaving every label blank.)
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        predictions = run_model(target_dir, model)

    # Save predictions so later visualization tweaks can reuse them.
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name that encodes the visualization parameters, so each
    # parameter combination gets its own cached file.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."

    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
247
+
248
+
249
+ # -------------------------------------------------------------------------
250
+ # 5) Helper functions for UI resets + re-visualization
251
+ # -------------------------------------------------------------------------
252
def clear_fields():
    """Reset the 3D viewer pane; the ClearButton wipes the other widgets."""
    return None
257
+
258
+
259
def update_log():
    """Return the interim status message shown while reconstruction runs."""
    return "Loading and Reconstructing..."
264
+
265
+
266
def update_visualization(
    target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
):
    """
    Rebuild (or reuse) the GLB for the current visualization parameters from
    the predictions previously saved by gradio_demo, and return it for the
    3D viewer. If is_example == "True", skip.

    Returns:
        (glb_path_or_None, log_message)
    """
    # Example clicks only prefill the UI; there is nothing to re-render yet.
    if is_example == "True":
        return None, "No reconstruction available. Please click the Reconstruct button first."

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No reconstruction available. Please click the Reconstruct button first."

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    key_list = [
        "pose_enc",
        "depth",
        "depth_conf",
        "world_points",
        "world_points_conf",
        "images",
        "extrinsic",
        "intrinsic",
        "world_points_from_depth",
    ]

    # NpzFile keeps the archive open; use a context manager so the file
    # handle is closed promptly instead of leaking until GC.
    with np.load(predictions_path) as loaded:
        predictions = {key: np.array(loaded[key]) for key in key_list}

    # Same parameter-encoded name as gradio_demo, so files are reused.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            show_cam=show_cam,
            mask_sky=mask_sky,
            target_dir=target_dir,
            prediction_mode=prediction_mode,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"
320
+
321
+
322
+ # -------------------------------------------------------------------------
323
+ # Example images
324
+ # -------------------------------------------------------------------------
325
+
326
+ great_wall_video = "examples/videos/great_wall.mp4"
327
+ colosseum_video = "examples/videos/Colosseum.mp4"
328
+ room_video = "examples/videos/room.mp4"
329
+ kitchen_video = "examples/videos/kitchen.mp4"
330
+ fern_video = "examples/videos/fern.mp4"
331
+ single_cartoon_video = "examples/videos/single_cartoon.mp4"
332
+ single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
333
+ pyramid_video = "examples/videos/pyramid.mp4"
334
+
335
+
336
+ # -------------------------------------------------------------------------
337
+ # 6) Build Gradio UI
338
+ # -------------------------------------------------------------------------
339
+ theme = gr.themes.Ocean()
340
+ theme.set(
341
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
342
+ checkbox_label_text_color_selected="*button_primary_text_color",
343
+ )
344
+
345
+ with gr.Blocks(
346
+ theme=theme,
347
+ css="""
348
+ .custom-log * {
349
+ font-style: italic;
350
+ font-size: 22px !important;
351
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
352
+ -webkit-background-clip: text;
353
+ background-clip: text;
354
+ font-weight: bold !important;
355
+ color: transparent !important;
356
+ text-align: center !important;
357
+ }
358
+
359
+ .example-log * {
360
+ font-style: italic;
361
+ font-size: 16px !important;
362
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
363
+ -webkit-background-clip: text;
364
+ background-clip: text;
365
+ color: transparent !important;
366
+ }
367
+
368
+ #my_radio .wrap {
369
+ display: flex;
370
+ flex-wrap: nowrap;
371
+ justify-content: center;
372
+ align-items: center;
373
+ }
374
+
375
+ #my_radio .wrap label {
376
+ display: flex;
377
+ width: 50%;
378
+ justify-content: center;
379
+ align-items: center;
380
+ margin: 0;
381
+ padding: 10px 0;
382
+ box-sizing: border-box;
383
+ }
384
+ """,
385
+ ) as demo:
386
+ # Instead of gr.State, we use a hidden Textbox:
387
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
388
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
389
+
390
+ gr.HTML(
391
+ """
392
+ <h1>🏛️ VGGT: Visual Geometry Grounded Transformer</h1>
393
+ <p>
394
+ <a href="https://github.com/facebookresearch/vggt">🐙 GitHub Repository</a> |
395
+ <a href="#">Project Page</a>
396
+ </p>
397
+
398
+ <div style="font-size: 16px; line-height: 1.5;">
399
+ <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
400
+
401
+ <h3>Getting Started:</h3>
402
+ <ol>
403
+ <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
404
+ <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
405
+ <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
406
+ <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
407
+ <li>
408
+ <strong>Adjust Visualization (Optional):</strong>
409
+ After reconstruction, you can fine-tune the visualization using the options below
410
+ <details style="display:inline;">
411
+ <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
412
+ <ul>
413
+ <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
414
+ <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
415
+ <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
416
+ <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
417
+ <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
418
+ </ul>
419
+ </details>
420
+ </li>
421
+ </ol>
422
+ <p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of VGGT's processing time. </span></p>
423
+ </div>
424
+ """
425
+ )
426
+
427
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
428
+
429
+ with gr.Row():
430
+ with gr.Column(scale=2):
431
+ input_video = gr.Video(label="Upload Video", interactive=True)
432
+ input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
433
+
434
+ image_gallery = gr.Gallery(
435
+ label="Preview",
436
+ columns=4,
437
+ height="300px",
438
+ show_download_button=True,
439
+ object_fit="contain",
440
+ preview=True,
441
+ )
442
+
443
+ with gr.Column(scale=4):
444
+ with gr.Column():
445
+ gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
446
+ log_output = gr.Markdown(
447
+ "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
448
+ )
449
+ reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
450
+
451
+ with gr.Row():
452
+ submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
453
+ clear_btn = gr.ClearButton(
454
+ [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
455
+ scale=1,
456
+ )
457
+
458
+ with gr.Row():
459
+ prediction_mode = gr.Radio(
460
+ ["Depthmap and Camera Branch", "Pointmap Branch"],
461
+ label="Select a Prediction Mode",
462
+ value="Depthmap and Camera Branch",
463
+ scale=1,
464
+ elem_id="my_radio",
465
+ )
466
+
467
+ with gr.Row():
468
+ conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)")
469
+ frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
470
+ with gr.Column():
471
+ show_cam = gr.Checkbox(label="Show Camera", value=True)
472
+ mask_sky = gr.Checkbox(label="Filter Sky", value=False)
473
+ mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
474
+ mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
475
+
476
+ # ---------------------- Examples section ----------------------
477
+ examples = [
478
+ [colosseum_video, "22", None, 20.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
479
+ [pyramid_video, "30", None, 35.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
480
+ [single_cartoon_video, "1", None, 15.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
481
+ [single_oil_painting_video, "1", None, 20.0, False, False, True, True, "Depthmap and Camera Branch", "True"],
482
+ [room_video, "8", None, 5.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
483
+ [kitchen_video, "25", None, 50.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
484
+ [fern_video, "20", None, 45.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
485
+ ]
486
+
487
+ def example_pipeline(
488
+ input_video,
489
+ num_images_str,
490
+ input_images,
491
+ conf_thres,
492
+ mask_black_bg,
493
+ mask_white_bg,
494
+ show_cam,
495
+ mask_sky,
496
+ prediction_mode,
497
+ is_example_str,
498
+ ):
499
+ """
500
+ 1) Copy example images to new target_dir
501
+ 2) Reconstruct
502
+ 3) Return model3D + logs + new_dir + updated dropdown + gallery
503
+ We do NOT return is_example. It's just an input.
504
+ """
505
+ target_dir, image_paths = handle_uploads(input_video, input_images)
506
+ # Always use "All" for frame_filter in examples
507
+ frame_filter = "All"
508
+ glbfile, log_msg, dropdown = gradio_demo(
509
+ target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode
510
+ )
511
+ return glbfile, log_msg, target_dir, dropdown, image_paths
512
+
513
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
514
+
515
+ gr.Examples(
516
+ examples=examples,
517
+ inputs=[
518
+ input_video,
519
+ num_images,
520
+ input_images,
521
+ conf_thres,
522
+ mask_black_bg,
523
+ mask_white_bg,
524
+ show_cam,
525
+ mask_sky,
526
+ prediction_mode,
527
+ is_example,
528
+ ],
529
+ outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
530
+ fn=example_pipeline,
531
+ cache_examples=False,
532
+ examples_per_page=50,
533
+ )
534
+
535
+ # -------------------------------------------------------------------------
536
+ # "Reconstruct" button logic:
537
+ # - Clear fields
538
+ # - Update log
539
+ # - gradio_demo(...) with the existing target_dir
540
+ # - Then set is_example = "False"
541
+ # -------------------------------------------------------------------------
542
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
543
+ fn=update_log, inputs=[], outputs=[log_output]
544
+ ).then(
545
+ fn=gradio_demo,
546
+ inputs=[
547
+ target_dir_output,
548
+ conf_thres,
549
+ frame_filter,
550
+ mask_black_bg,
551
+ mask_white_bg,
552
+ show_cam,
553
+ mask_sky,
554
+ prediction_mode,
555
+ ],
556
+ outputs=[reconstruction_output, log_output, frame_filter],
557
+ ).then(
558
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
559
+ )
560
+
561
+ # -------------------------------------------------------------------------
562
+ # Real-time Visualization Updates
563
+ # -------------------------------------------------------------------------
564
+ conf_thres.change(
565
+ update_visualization,
566
+ [
567
+ target_dir_output,
568
+ conf_thres,
569
+ frame_filter,
570
+ mask_black_bg,
571
+ mask_white_bg,
572
+ show_cam,
573
+ mask_sky,
574
+ prediction_mode,
575
+ is_example,
576
+ ],
577
+ [reconstruction_output, log_output],
578
+ )
579
+ frame_filter.change(
580
+ update_visualization,
581
+ [
582
+ target_dir_output,
583
+ conf_thres,
584
+ frame_filter,
585
+ mask_black_bg,
586
+ mask_white_bg,
587
+ show_cam,
588
+ mask_sky,
589
+ prediction_mode,
590
+ is_example,
591
+ ],
592
+ [reconstruction_output, log_output],
593
+ )
594
+ mask_black_bg.change(
595
+ update_visualization,
596
+ [
597
+ target_dir_output,
598
+ conf_thres,
599
+ frame_filter,
600
+ mask_black_bg,
601
+ mask_white_bg,
602
+ show_cam,
603
+ mask_sky,
604
+ prediction_mode,
605
+ is_example,
606
+ ],
607
+ [reconstruction_output, log_output],
608
+ )
609
+ mask_white_bg.change(
610
+ update_visualization,
611
+ [
612
+ target_dir_output,
613
+ conf_thres,
614
+ frame_filter,
615
+ mask_black_bg,
616
+ mask_white_bg,
617
+ show_cam,
618
+ mask_sky,
619
+ prediction_mode,
620
+ is_example,
621
+ ],
622
+ [reconstruction_output, log_output],
623
+ )
624
+ show_cam.change(
625
+ update_visualization,
626
+ [
627
+ target_dir_output,
628
+ conf_thres,
629
+ frame_filter,
630
+ mask_black_bg,
631
+ mask_white_bg,
632
+ show_cam,
633
+ mask_sky,
634
+ prediction_mode,
635
+ is_example,
636
+ ],
637
+ [reconstruction_output, log_output],
638
+ )
639
+ mask_sky.change(
640
+ update_visualization,
641
+ [
642
+ target_dir_output,
643
+ conf_thres,
644
+ frame_filter,
645
+ mask_black_bg,
646
+ mask_white_bg,
647
+ show_cam,
648
+ mask_sky,
649
+ prediction_mode,
650
+ is_example,
651
+ ],
652
+ [reconstruction_output, log_output],
653
+ )
654
+ prediction_mode.change(
655
+ update_visualization,
656
+ [
657
+ target_dir_output,
658
+ conf_thres,
659
+ frame_filter,
660
+ mask_black_bg,
661
+ mask_white_bg,
662
+ show_cam,
663
+ mask_sky,
664
+ prediction_mode,
665
+ is_example,
666
+ ],
667
+ [reconstruction_output, log_output],
668
+ )
669
+
670
+ # -------------------------------------------------------------------------
671
+ # Auto-update gallery whenever user uploads or changes their files
672
+ # -------------------------------------------------------------------------
673
+ input_video.change(
674
+ fn=update_gallery_on_upload,
675
+ inputs=[input_video, input_images],
676
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
677
+ )
678
+ input_images.change(
679
+ fn=update_gallery_on_upload,
680
+ inputs=[input_video, input_images],
681
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
682
+ )
683
+
684
+ demo.queue(max_size=20).launch(show_error=True, share=True)
vggt-low-vram/demo_viser.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import glob
9
+ import time
10
+ import threading
11
+ import argparse
12
+ from typing import List, Optional
13
+
14
+ import numpy as np
15
+ import torch
16
+ from tqdm.auto import tqdm
17
+ import viser
18
+ import viser.transforms as viser_tf
19
+ import cv2
20
+
21
+
22
+ try:
23
+ import onnxruntime
24
+ except ImportError:
25
+ print("onnxruntime not found. Sky segmentation may not work.")
26
+
27
+ from visual_util import segment_sky, download_file_from_url
28
+ from vggt.models.vggt import VGGT
29
+ from vggt.utils.load_fn import load_and_preprocess_images
30
+ from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
31
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
32
+
33
+
34
def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,  # represents percentage (e.g., 50 means filter lowest 50%)
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: str = None,
):
    """
    Visualize predicted 3D points and camera poses with viser.

    Args:
        pred_dict (dict):
            {
                "images": (S, 3, H, W) - Input images,
                "world_points": (S, H, W, 3),
                "world_points_conf": (S, H, W),
                "depth": (S, H, W, 1),
                "depth_conf": (S, H, W),
                "extrinsic": (S, 3, 4),
                "intrinsic": (S, 3, 3),
            }
        port (int): Port number for the viser server.
        init_conf_threshold (float): Initial percentage of low-confidence points to filter out.
        use_point_map (bool): Whether to visualize world_points or use depth-based points.
        background_mode (bool): Whether to run the server in background thread.
        mask_sky (bool): Whether to apply sky segmentation to filter out sky points.
        image_folder (str): Path to the folder containing input images.

    Returns:
        The viser server. Only reachable in background mode; the foreground
        path blocks forever serving clients.
    """
    print(f"Starting viser server on port {port}")

    server = viser.ViserServer(host="0.0.0.0", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    # Unpack prediction dict
    images = pred_dict["images"]  # (S, 3, H, W)
    world_points_map = pred_dict["world_points"]  # (S, H, W, 3)
    conf_map = pred_dict["world_points_conf"]  # (S, H, W)

    depth_map = pred_dict["depth"]  # (S, H, W, 1)
    depth_conf = pred_dict["depth_conf"]  # (S, H, W)

    extrinsics_cam = pred_dict["extrinsic"]  # (S, 3, 4)
    intrinsics_cam = pred_dict["intrinsic"]  # (S, 3, 3)

    # Compute world points from depth if not using the precomputed point map
    if not use_point_map:
        world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
        conf = depth_conf
    else:
        world_points = world_points_map
        conf = conf_map

    # Apply sky segmentation if enabled
    if mask_sky and image_folder is not None:
        conf = apply_sky_segmentation(conf, image_folder)

    # Convert images from (S, 3, H, W) to (S, H, W, 3), then flatten
    # everything into one big point cloud.
    colors = images.transpose(0, 2, 3, 1)  # now (S, H, W, 3)
    S, H, W, _ = world_points.shape

    points = world_points.reshape(-1, 3)
    colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
    conf_flat = conf.reshape(-1)

    cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)  # shape (S, 4, 4) typically
    # For convenience, we store only the (3, 4) portion.
    cam_to_world = cam_to_world_mat[:, :3, :]

    # Recenter the scene: shift both points and camera translations so the
    # centroid sits at the origin.
    scene_center = np.mean(points, axis=0)
    points_centered = points - scene_center
    cam_to_world[..., -1] -= scene_center

    # Store frame indices so we can filter by frame
    frame_indices = np.repeat(np.arange(S), H * W)

    # Build the viser GUI
    gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True)

    # The slider represents the percentage of points to filter out.
    gui_points_conf = server.gui.add_slider(
        "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold
    )

    gui_frame_selector = server.gui.add_dropdown(
        "Show Points from Frames", options=["All"] + [str(i) for i in range(S)], initial_value="All"
    )

    # Create the main point cloud handle.
    # The threshold is the given percentile; the small absolute floor (1e-5)
    # matches update_point_cloud() below. (It was previously 0.1 here, which
    # made the initial view filter differently from any slider update.)
    init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
    init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 1e-5)
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points_centered[init_conf_mask],
        colors=colors_flat[init_conf_mask],
        point_size=0.001,
        point_shape="circle",
    )

    # We will store references to frames & frustums so we can toggle visibility
    frames: List[viser.FrameHandle] = []
    frustums: List[viser.CameraFrustumHandle] = []

    def visualize_frames(extrinsics: np.ndarray, images_: np.ndarray) -> None:
        """
        Add camera frames and frustums to the scene.
        extrinsics: (S, 3, 4)
        images_: (S, 3, H, W)
        """
        # Clear any existing frames or frustums
        for f in frames:
            f.remove()
        frames.clear()
        for fr in frustums:
            fr.remove()
        frustums.clear()

        # Attach a callback that sets the viewpoint to the chosen camera.
        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        img_ids = range(S)
        for img_id in tqdm(img_ids):
            cam2world_3x4 = extrinsics[img_id]
            T_world_camera = viser_tf.SE3.from_matrix(cam2world_3x4)

            # Add a small frame axis
            frame_axis = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame_axis)

            # Convert the image for the frustum
            img = images_[img_id]  # shape (3, H, W)
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            h, w = img.shape[:2]

            # If you want correct FOV from intrinsics, do something like:
            # fx = intrinsics_cam[img_id, 0, 0]
            # fov = 2 * np.arctan2(h/2, fx)
            # For demonstration, we pick a simple approximate FOV:
            fy = 1.1 * h
            fov = 2 * np.arctan2(h / 2, fy)

            # Add the frustum
            frustum_cam = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum", fov=fov, aspect=w / h, scale=0.05, image=img, line_width=1.0
            )
            frustums.append(frustum_cam)
            attach_callback(frustum_cam, frame_axis)

    def update_point_cloud() -> None:
        """Update the point cloud based on current GUI selections."""
        # Compute the threshold value based on the current percentage.
        current_percentage = gui_points_conf.value
        threshold_val = np.percentile(conf_flat, current_percentage)

        print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")

        conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)

        if gui_frame_selector.value == "All":
            frame_mask = np.ones_like(conf_mask, dtype=bool)
        else:
            selected_idx = int(gui_frame_selector.value)
            frame_mask = frame_indices == selected_idx

        combined_mask = conf_mask & frame_mask
        point_cloud.points = points_centered[combined_mask]
        point_cloud.colors = colors_flat[combined_mask]

    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_show_frames.on_update
    def _(_) -> None:
        """Toggle visibility of camera frames and frustums."""
        for f in frames:
            f.visible = gui_show_frames.value
        for fr in frustums:
            fr.visible = gui_show_frames.value

    # Add the camera frames to the scene
    visualize_frames(cam_to_world, images)

    print("Starting viser server...")
    # If background_mode is True, spawn a daemon thread so the main thread can continue.
    if background_mode:

        def server_loop():
            while True:
                time.sleep(0.001)

        thread = threading.Thread(target=server_loop, daemon=True)
        thread.start()
    else:
        while True:
            time.sleep(0.01)

    return server
253
+
254
+
255
+ # Helper functions for sky segmentation
256
+
257
+
258
def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.

    Args:
        conf (np.ndarray): Confidence scores with shape (S, H, W)
        image_folder (str): Path to the folder containing input images

    Returns:
        np.ndarray: Updated confidence scores with sky regions masked out
    """
    S, H, W = conf.shape
    sky_masks_dir = image_folder.rstrip("/") + "_sky_masks"
    os.makedirs(sky_masks_dir, exist_ok=True)

    # Download skyseg.onnx if it doesn't exist
    if not os.path.exists("skyseg.onnx"):
        print("Downloading skyseg.onnx...")
        download_file_from_url("https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx")

    skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
    image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
    sky_mask_list = []

    print("Generating sky masks...")
    # Only the first S images correspond to the confidence batch.
    for image_path in tqdm(image_files[:S]):
        image_name = os.path.basename(image_path)
        mask_filepath = os.path.join(sky_masks_dir, image_name)

        # Reuse a cached mask when one exists; otherwise run segmentation.
        if os.path.exists(mask_filepath):
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        else:
            sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)

        # Resize mask to match H×W if needed
        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H))

        sky_mask_list.append(sky_mask)

    # Stack to (S, H, W) and zero confidence wherever the mask says "sky".
    sky_mask_array = np.array(sky_mask_list)
    sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
    conf = conf * sky_mask_binary

    print("Sky segmentation applied successfully")
    return conf
306
+
307
+
308
# Command-line interface for the viser demo. Defined at module level so that
# main() can call parser.parse_args() without re-building the parser.
parser = argparse.ArgumentParser(description="VGGT demo with viser for 3D visualization")
parser.add_argument(
    "--image_folder",
    type=str,
    default="examples/kitchen/images/",
    help="Path to folder containing images",
)
parser.add_argument(
    "--conf_threshold",
    type=float,
    default=25.0,
    help="Initial percentage of low-confidence points to filter out",
)
parser.add_argument("--port", type=int, default=8080, help="Port number for the viser server")
# Boolean switches (all default to False).
parser.add_argument("--use_point_map", action="store_true", help="Use point map instead of depth-based points")
parser.add_argument("--background_mode", action="store_true", help="Run the viser server in background mode")
parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")
319
+
320
+
321
def main():
    """
    Main function for the VGGT demo with viser for 3D visualization.

    This function:
    1. Loads the VGGT model
    2. Processes input images from the specified folder
    3. Runs inference to generate 3D points and camera poses
    4. Optionally applies sky segmentation to filter out sky points
    5. Visualizes the results using viser

    Command-line arguments:
    --image_folder: Path to folder containing input images
    --use_point_map: Use point map instead of depth-based points
    --background_mode: Run the viser server in background mode
    --port: Port number for the viser server
    --conf_threshold: Initial percentage of low-confidence points to filter out
    --mask_sky: Apply sky segmentation to filter out sky points
    """
    args = parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        # bfloat16 requires compute capability >= 8 (Ampere or newer).
        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    else:
        # BUGFIX: torch.cuda.get_device_capability() raises when CUDA is
        # unavailable, so only query it on GPU; use float32 on CPU where
        # half-precision ops are poorly supported.
        dtype = torch.float32
    print(f"Using device: {device}")

    print("Initializing and loading VGGT model...")
    # model = VGGT.from_pretrained("facebook/VGGT-1B")

    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))

    model.eval()
    model = model.to(dtype=dtype, device=device)

    # Use the provided image folder path. Sort the paths so frame order is
    # deterministic (glob order is filesystem-dependent).
    print(f"Loading images from {args.image_folder}...")
    image_names = sorted(glob.glob(os.path.join(args.image_folder, "*")))
    print(f"Found {len(image_names)} images")

    images = load_and_preprocess_images(image_names).to(dtype=dtype, device=device)
    print(f"Preprocessed images shape: {images.shape}")

    print("Running inference...")
    with torch.no_grad():
        predictions = model(images, verbose=True)

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    print("Processing model outputs...")
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension and convert to numpy

    if args.use_point_map:
        print("Visualizing 3D points from point map")
    else:
        print("Visualizing 3D points by unprojecting depth map by cameras")

    if args.mask_sky:
        print("Sky segmentation enabled - will filter out sky points")

    print("Starting viser visualization...")

    viser_server = viser_wrapper(
        predictions,
        port=args.port,
        init_conf_threshold=args.conf_threshold,
        use_point_map=args.use_point_map,
        background_mode=args.background_mode,
        mask_sky=args.mask_sky,
        image_folder=args.image_folder,
    )
    print("Visualization complete")
397
+
398
+
399
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
vggt-low-vram/docs/package.md ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Alternative Installation Methods
2
+
3
+ This document explains how to install VGGT as a package using different package managers.
4
+
5
+ ## Prerequisites
6
+
7
+ Before installing VGGT as a package, you need to install PyTorch and torchvision. We don't list these as dependencies to avoid CUDA version mismatches. Install them first; for example:
8
+
9
+ ```bash
10
+ # install pytorch 2.3.1 with cuda 12.1
11
+ pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121
12
+ ```
13
+
14
+ ## Installation Options
15
+
16
+ ### Install with pip
17
+
18
+ The simplest way to install VGGT is using pip:
19
+
20
+ ```bash
21
+ pip install -e .
22
+ ```
23
+
24
+ ### Install and run with pixi
25
+
26
+ [Pixi](https://pixi.sh) is a package management tool for creating reproducible environments.
27
+
28
+ 1. First, [download and install pixi](https://pixi.sh/latest/get_started/)
29
+ 2. Then run:
30
+
31
+ ```bash
32
+ pixi run -e demo python demo_gradio.py
33
+ ```
34
+
35
+ ### Install and run with uv
36
+
37
+ [uv](https://docs.astral.sh/uv/) is a fast Python package installer and resolver.
38
+
39
+ 1. First, [install uv](https://docs.astral.sh/uv/getting-started/installation/)
40
+ 2. Then run:
41
+
42
+ ```bash
43
+ uv run --extra demo demo_gradio.py
44
+ ```
45
+
vggt-low-vram/examples/kitchen/images/00.png ADDED

Git LFS Details

  • SHA256: 54527a575988094058cdc1975b421c48e0f446726473d0ac21ea55ecb24e96a7
  • Pointer size: 131 Bytes
  • Size of remote file: 691 kB
vggt-low-vram/examples/kitchen/images/01.png ADDED

Git LFS Details

  • SHA256: 0ad4c6d74c16661ed427f8100124aaf53e7fd0577b32c362f13559dfad7027a7
  • Pointer size: 131 Bytes
  • Size of remote file: 726 kB
vggt-low-vram/examples/kitchen/images/02.png ADDED

Git LFS Details

  • SHA256: 596bd54d26f889fc80cedee81d95dda709fa134d86ac199b6509337e413246d5
  • Pointer size: 131 Bytes
  • Size of remote file: 789 kB
vggt-low-vram/examples/kitchen/images/03.png ADDED

Git LFS Details

  • SHA256: 78193756310d9abaf81fa28902cf0b284260a0a916b085a7c08a4723eead1dd6
  • Pointer size: 131 Bytes
  • Size of remote file: 828 kB
vggt-low-vram/examples/kitchen/images/04.png ADDED

Git LFS Details

  • SHA256: ca551254002a318228e19e46982813f3e489828796e98547ff632043f3002f9d
  • Pointer size: 131 Bytes
  • Size of remote file: 724 kB
vggt-low-vram/examples/kitchen/images/05.png ADDED

Git LFS Details

  • SHA256: a8dcd116d782d32b404d7e4aa69f462abbd048a0d8727440ec37f18cc4548ee4
  • Pointer size: 131 Bytes
  • Size of remote file: 759 kB
vggt-low-vram/examples/kitchen/images/06.png ADDED

Git LFS Details

  • SHA256: 2fcc2b871c6fef6f3a3e0f06a3ffc1f0eee3e40afa2461f7c7c665057decb3e6
  • Pointer size: 131 Bytes
  • Size of remote file: 674 kB
vggt-low-vram/examples/kitchen/images/07.png ADDED

Git LFS Details

  • SHA256: 28d21898de0e6370790839a40f7f45d84fbb3e6ff5809f0a0e14bd01bdef730e
  • Pointer size: 131 Bytes
  • Size of remote file: 856 kB
vggt-low-vram/examples/kitchen/images/08.png ADDED

Git LFS Details

  • SHA256: 0137a2bb3eb3e691d8d8b1f8884a9c8f99748888b1db770091d7acdf35fe8efa
  • Pointer size: 131 Bytes
  • Size of remote file: 677 kB
vggt-low-vram/examples/kitchen/images/09.png ADDED

Git LFS Details

  • SHA256: 1ab59c1ef85d8169b404463f01b7ae4d287da12677126b68a3dce407ca2b9077
  • Pointer size: 131 Bytes
  • Size of remote file: 797 kB
vggt-low-vram/examples/kitchen/images/10.png ADDED

Git LFS Details

  • SHA256: f180cbf110bc65b89ad616328ad7d076dc3901a18def4b1337a134cdf65233a0
  • Pointer size: 131 Bytes
  • Size of remote file: 730 kB
vggt-low-vram/examples/kitchen/images/11.png ADDED

Git LFS Details

  • SHA256: 781196eadae8d907928e877e073289c0998e2b9e513d4f7580e147d15d1ae571
  • Pointer size: 131 Bytes
  • Size of remote file: 799 kB
vggt-low-vram/examples/kitchen/images/12.png ADDED

Git LFS Details

  • SHA256: dd59b24dc8962ba0fc7fbb37b53a6d76fec9730c74e7e3235a06902b250e7d44
  • Pointer size: 131 Bytes
  • Size of remote file: 707 kB
vggt-low-vram/examples/kitchen/images/13.png ADDED

Git LFS Details

  • SHA256: b4cd39f22c766477bad741ff37a1ee5f71aecde8bb6762d869b4c9dca1ceacfb
  • Pointer size: 131 Bytes
  • Size of remote file: 755 kB
vggt-low-vram/examples/kitchen/images/14.png ADDED

Git LFS Details

  • SHA256: 5df1f398efc144271e342d7b65447e022a100b93b3850a755fbc66aff5fca0f2
  • Pointer size: 131 Bytes
  • Size of remote file: 642 kB
vggt-low-vram/examples/kitchen/images/15.png ADDED

Git LFS Details

  • SHA256: 325262829ddb11d1c7df1a8f1fef79a297332dad51870ab0d40a73f1dd6869b1
  • Pointer size: 131 Bytes
  • Size of remote file: 639 kB
vggt-low-vram/examples/kitchen/images/16.png ADDED

Git LFS Details

  • SHA256: 9779a78d72fc25f2118a270f060afeacbcef149a4f012119ff041effa8727cbf
  • Pointer size: 131 Bytes
  • Size of remote file: 754 kB
vggt-low-vram/examples/kitchen/images/17.png ADDED

Git LFS Details

  • SHA256: 2549f4f505ea021eebe0bf579b969b6c162d2dee18b0c8e9d7a3c043d200e45b
  • Pointer size: 131 Bytes
  • Size of remote file: 774 kB
vggt-low-vram/examples/kitchen/images/18.png ADDED

Git LFS Details

  • SHA256: e1c21131c4732756d5774dd732af86c1d39dea96fd2d613afd570633b3a76ef6
  • Pointer size: 131 Bytes
  • Size of remote file: 829 kB
vggt-low-vram/examples/kitchen/images/19.png ADDED

Git LFS Details

  • SHA256: d17680e77c6cb326eb4604e29f9e532db34769ca20b938e944ab53e8bd3798e2
  • Pointer size: 131 Bytes
  • Size of remote file: 678 kB
vggt-low-vram/examples/kitchen/images/20.png ADDED

Git LFS Details

  • SHA256: 5e9c835a0e0c1bc162a8bff6b93677c58cb53afaadca260b0ca2a388565b4cc2
  • Pointer size: 131 Bytes
  • Size of remote file: 718 kB
vggt-low-vram/examples/kitchen/images/21.png ADDED

Git LFS Details

  • SHA256: 0747b2d1b44ef538a9aa40a067881ef9d3ed5cacbf954c926a2bdf5f29c114e6
  • Pointer size: 131 Bytes
  • Size of remote file: 787 kB
vggt-low-vram/examples/kitchen/images/22.png ADDED

Git LFS Details

  • SHA256: 77a0014d7c7d5802ce23cda4e102759274fd8f4c150271a3b61cbb2fe33b69b6
  • Pointer size: 131 Bytes
  • Size of remote file: 675 kB
vggt-low-vram/examples/kitchen/images/23.png ADDED

Git LFS Details

  • SHA256: 1a9415e9b8f08ff298829ffac779bb1e8dedccb3bf36060d59a7da2a35c4f790
  • Pointer size: 131 Bytes
  • Size of remote file: 652 kB
vggt-low-vram/examples/kitchen/images/24.png ADDED

Git LFS Details

  • SHA256: 5199003307466bf4706a0898f139bf3590946f255d08c6b11d5aa9eede54c83a
  • Pointer size: 131 Bytes
  • Size of remote file: 800 kB
vggt-low-vram/examples/llff_fern/images/000.png ADDED

Git LFS Details

  • SHA256: 47f447d31a84d53494045087cbb8a40b877a68a76f549af14f6bb6f490a5b05d
  • Pointer size: 131 Bytes
  • Size of remote file: 671 kB
vggt-low-vram/examples/llff_fern/images/001.png ADDED

Git LFS Details

  • SHA256: 05402df1d7247e794768461571c188737dcae5fcb34400990f5751244a3e41c0
  • Pointer size: 131 Bytes
  • Size of remote file: 666 kB
vggt-low-vram/examples/llff_fern/images/002.png ADDED

Git LFS Details

  • SHA256: e17135aa9b506fac24a9529ee56c37ef5a52c55498998d3de64cf3e46210dccc
  • Pointer size: 131 Bytes
  • Size of remote file: 652 kB
vggt-low-vram/examples/llff_fern/images/003.png ADDED

Git LFS Details

  • SHA256: 3285c7cc6b4b75703a68f510072c5eca81cff9b983044426cbe2ca27d4e526c5
  • Pointer size: 131 Bytes
  • Size of remote file: 653 kB