Diffusers
Safetensors
zeyuren2002 commited on
Commit
ea3c0ad
·
verified ·
1 Parent(s): 87a49e9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Depth-Anything-3/da3_streaming/loop_utils/__init__.py +15 -0
  2. Depth-Anything-3/da3_streaming/loop_utils/alignment_torch.py +395 -0
  3. Depth-Anything-3/da3_streaming/loop_utils/alignment_triton.py +543 -0
  4. Depth-Anything-3/da3_streaming/loop_utils/config_utils.py +66 -0
  5. Depth-Anything-3/da3_streaming/loop_utils/logging_utils.py +32 -0
  6. Depth-Anything-3/da3_streaming/loop_utils/loop_detector.py +391 -0
  7. Depth-Anything-3/da3_streaming/loop_utils/loop_refinement.py +268 -0
  8. Depth-Anything-3/da3_streaming/loop_utils/sim3loop.py +399 -0
  9. Depth-Anything-3/da3_streaming/loop_utils/sim3utils.py +1261 -0
  10. Depth-Anything-3/da3_streaming/scripts/download_weights.sh +20 -0
  11. Depth-Anything-3/docs/API.md +465 -0
  12. Depth-Anything-3/docs/BENCHMARK.md +484 -0
  13. Depth-Anything-3/docs/CLI.md +654 -0
  14. Depth-Anything-3/docs/funcs/ref_view_strategy.md +183 -0
  15. Depth-Anything-3/notebooks/da3.ipynb +0 -0
  16. Depth-Anything-3/src/depth_anything_3/api.py +446 -0
  17. Depth-Anything-3/src/depth_anything_3/app/css_and_html.py +594 -0
  18. Depth-Anything-3/src/depth_anything_3/app/gradio_app.py +724 -0
  19. Depth-Anything-3/src/depth_anything_3/app/modules/__init__.py +43 -0
  20. Depth-Anything-3/src/depth_anything_3/app/modules/event_handlers.py +619 -0
  21. Depth-Anything-3/src/depth_anything_3/app/modules/file_handlers.py +304 -0
  22. Depth-Anything-3/src/depth_anything_3/app/modules/model_inference.py +260 -0
  23. Depth-Anything-3/src/depth_anything_3/app/modules/ui_components.py +477 -0
  24. Depth-Anything-3/src/depth_anything_3/app/modules/utils.py +207 -0
  25. Depth-Anything-3/src/depth_anything_3/app/modules/visualization.py +434 -0
  26. Depth-Anything-3/src/depth_anything_3/bench/__init__.py +45 -0
  27. Depth-Anything-3/src/depth_anything_3/bench/configs/eval_bench.yaml +98 -0
  28. Depth-Anything-3/src/depth_anything_3/bench/dataset.py +136 -0
  29. Depth-Anything-3/src/depth_anything_3/bench/datasets/__init__.py +21 -0
  30. Depth-Anything-3/src/depth_anything_3/bench/datasets/dtu.py +681 -0
  31. Depth-Anything-3/src/depth_anything_3/bench/datasets/dtu64.py +182 -0
  32. Depth-Anything-3/src/depth_anything_3/bench/datasets/eth3d.py +594 -0
  33. Depth-Anything-3/src/depth_anything_3/bench/datasets/hiroom.py +440 -0
  34. Depth-Anything-3/src/depth_anything_3/bench/datasets/scannetpp.py +591 -0
  35. Depth-Anything-3/src/depth_anything_3/bench/datasets/sevenscenes.py +449 -0
  36. Depth-Anything-3/src/depth_anything_3/bench/evaluator.py +752 -0
  37. Depth-Anything-3/src/depth_anything_3/bench/print_metrics.py +618 -0
  38. Depth-Anything-3/src/depth_anything_3/bench/registries.py +85 -0
  39. Depth-Anything-3/src/depth_anything_3/bench/utils.py +525 -0
  40. Depth-Anything-3/src/depth_anything_3/cfg.py +144 -0
  41. Depth-Anything-3/src/depth_anything_3/cli.py +803 -0
  42. Depth-Anything-3/src/depth_anything_3/configs/da3-base.yaml +45 -0
  43. Depth-Anything-3/src/depth_anything_3/configs/da3-giant.yaml +71 -0
  44. Depth-Anything-3/src/depth_anything_3/configs/da3-large.yaml +45 -0
  45. Depth-Anything-3/src/depth_anything_3/configs/da3-small.yaml +45 -0
  46. Depth-Anything-3/src/depth_anything_3/configs/da3metric-large.yaml +28 -0
  47. Depth-Anything-3/src/depth_anything_3/configs/da3mono-large.yaml +28 -0
  48. Depth-Anything-3/src/depth_anything_3/configs/da3nested-giant-large.yaml +10 -0
  49. Depth-Anything-3/src/depth_anything_3/model/__init__.py +20 -0
  50. Depth-Anything-3/src/depth_anything_3/model/cam_dec.py +45 -0
Depth-Anything-3/da3_streaming/loop_utils/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
Depth-Anything-3/da3_streaming/loop_utils/alignment_torch.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+
21
+ def weighted_estimate_se3_torch(source_points, target_points, weights):
22
+ source_points = torch.from_numpy(source_points).cuda().float()
23
+ target_points = torch.from_numpy(target_points).cuda().float()
24
+ weights = torch.from_numpy(weights).cuda().float()
25
+
26
+ total_weight = torch.sum(weights)
27
+ if total_weight < 1e-6:
28
+ return (
29
+ 1.0,
30
+ np.zeros(3, dtype=np.float32),
31
+ np.zeros(3, dtype=np.float32),
32
+ np.zeros((3, 3), dtype=np.float32),
33
+ )
34
+
35
+ normalized_weights = weights / total_weight
36
+
37
+ mu_src = torch.sum(normalized_weights[:, None] * source_points, dim=0)
38
+ mu_tgt = torch.sum(normalized_weights[:, None] * target_points, dim=0)
39
+
40
+ src_centered = source_points - mu_src
41
+ tgt_centered = target_points - mu_tgt
42
+
43
+ weighted_src = src_centered * torch.sqrt(normalized_weights)[:, None]
44
+ weighted_tgt = tgt_centered * torch.sqrt(normalized_weights)[:, None]
45
+
46
+ H = weighted_src.T @ weighted_tgt
47
+
48
+ return 1.0, mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy()
49
+
50
+
51
+ def weighted_estimate_sim3_torch(source_points, target_points, weights):
52
+
53
+ source_points = torch.from_numpy(source_points).cuda().float()
54
+ target_points = torch.from_numpy(target_points).cuda().float()
55
+ weights = torch.from_numpy(weights).cuda().float()
56
+
57
+ total_weight = torch.sum(weights)
58
+ if total_weight < 1e-6:
59
+ return (
60
+ -1.0,
61
+ np.zeros(3, dtype=np.float32),
62
+ np.zeros(3, dtype=np.float32),
63
+ np.zeros((3, 3), dtype=np.float32),
64
+ )
65
+
66
+ normalized_weights = weights / total_weight
67
+
68
+ mu_src = torch.sum(normalized_weights[:, None] * source_points, dim=0)
69
+ mu_tgt = torch.sum(normalized_weights[:, None] * target_points, dim=0)
70
+
71
+ src_centered = source_points - mu_src
72
+ tgt_centered = target_points - mu_tgt
73
+
74
+ scale_src = torch.sqrt(torch.sum(normalized_weights * torch.sum(src_centered**2, dim=1)))
75
+ scale_tgt = torch.sqrt(torch.sum(normalized_weights * torch.sum(tgt_centered**2, dim=1)))
76
+ s = scale_tgt / scale_src
77
+
78
+ weighted_src = (s * src_centered) * torch.sqrt(normalized_weights)[:, None]
79
+ weighted_tgt = tgt_centered * torch.sqrt(normalized_weights)[:, None]
80
+
81
+ H = weighted_src.T @ weighted_tgt
82
+
83
+ return s.cpu().numpy(), mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy()
84
+
85
+
86
+ def weighted_estimate_sim3_numba_torch(source_points, target_points, weights, align_method="sim3"):
87
+
88
+ if align_method == "sim3":
89
+ s, mu_src, mu_tgt, H = weighted_estimate_sim3_torch(source_points, target_points, weights)
90
+ elif align_method == "se3" or align_method == "scale+se3":
91
+ s, mu_src, mu_tgt, H = weighted_estimate_se3_torch(source_points, target_points, weights)
92
+
93
+ if s < 0:
94
+ raise ValueError("Total weight too small for meaningful estimation")
95
+
96
+ H_torch = torch.from_numpy(H).cuda().float()
97
+ U, _, Vt = torch.linalg.svd(H_torch)
98
+
99
+ U = U.cpu().numpy()
100
+ Vt = Vt.cpu().numpy()
101
+
102
+ R = Vt.T @ U.T
103
+ if np.linalg.det(R) < 0:
104
+ Vt[2, :] *= -1
105
+ R = Vt.T @ U.T
106
+
107
+ mu_src = mu_src.astype(np.float32)
108
+ mu_tgt = mu_tgt.astype(np.float32)
109
+ R = R.astype(np.float32)
110
+
111
+ if align_method == "se3" or align_method == "scale+se3":
112
+ t = mu_tgt - R @ mu_src
113
+ else:
114
+ t = mu_tgt - s * R @ mu_src
115
+
116
+ return s, R, t.astype(np.float32)
117
+
118
+
119
+ def huber_loss_torch(r, delta):
120
+
121
+ r_torch = torch.from_numpy(r).cuda().float()
122
+ delta_torch = torch.tensor(delta, device="cuda", dtype=torch.float32)
123
+
124
+ abs_r = torch.abs(r_torch)
125
+ result = torch.where(
126
+ abs_r <= delta_torch, 0.5 * r_torch**2, delta_torch * (abs_r - 0.5 * delta_torch)
127
+ )
128
+
129
+ return result.cpu().numpy()
130
+
131
+
132
+ def compute_residuals_torch(tgt, transformed):
133
+
134
+ tgt_torch = torch.from_numpy(tgt).cuda().float()
135
+ transformed_torch = torch.from_numpy(transformed).cuda().float()
136
+
137
+ residuals = torch.sqrt(torch.sum((tgt_torch - transformed_torch) ** 2, dim=1))
138
+ return residuals.cpu().numpy()
139
+
140
+
141
+ def compute_huber_weights_torch(residuals, delta):
142
+
143
+ residuals_torch = torch.from_numpy(residuals).cuda().float()
144
+ delta_torch = torch.tensor(delta, device="cuda", dtype=torch.float32)
145
+
146
+ weights = torch.ones_like(residuals_torch)
147
+ mask = residuals_torch > delta_torch
148
+ weights[mask] = delta_torch / residuals_torch[mask]
149
+
150
+ return weights.cpu().numpy()
151
+
152
+
153
+ def apply_transformation_torch(src, s, R, t):
154
+
155
+ src_torch = torch.from_numpy(src).cuda().float()
156
+ R_torch = torch.from_numpy(R).cuda().float()
157
+ t_torch = torch.from_numpy(t).cuda().float()
158
+ s_torch = torch.tensor(s, device="cuda", dtype=torch.float32)
159
+
160
+ transformed = s_torch * (src_torch @ R_torch.T) + t_torch
161
+ return transformed.cpu().numpy()
162
+
163
+
164
+ def robust_weighted_estimate_sim3_torch(
165
+ src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
166
+ ):
167
+
168
+ src = src.astype(np.float32)
169
+ tgt = tgt.astype(np.float32)
170
+ init_weights = init_weights.astype(np.float32)
171
+
172
+ s, R, t = weighted_estimate_sim3_numba_torch(src, tgt, init_weights, align_method=align_method)
173
+
174
+ prev_error = float("inf")
175
+
176
+ for iter in range(max_iters):
177
+ transformed = apply_transformation_torch(src, s, R, t)
178
+ residuals = compute_residuals_torch(tgt, transformed)
179
+
180
+ print(f"Iter {iter}: Mean residual = {np.mean(residuals):.6f}")
181
+
182
+ huber_weights = compute_huber_weights_torch(residuals, delta)
183
+ combined_weights = init_weights * huber_weights
184
+ combined_weights /= np.sum(combined_weights) + 1e-12
185
+
186
+ s_new, R_new, t_new = weighted_estimate_sim3_numba_torch(
187
+ src, tgt, combined_weights, align_method=align_method
188
+ )
189
+
190
+ param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
191
+ rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
192
+
193
+ current_error = np.sum(huber_loss_torch(residuals, delta) * init_weights)
194
+
195
+ if (param_change < tol and rot_angle < np.radians(0.1)) or (
196
+ abs(prev_error - current_error) < tol * prev_error
197
+ ):
198
+ print(f"Converged at iteration {iter}")
199
+ break
200
+
201
+ s, R, t = s_new, R_new, t_new
202
+ prev_error = current_error
203
+
204
+ return s, R, t
205
+
206
+
207
+ def apply_sim3_direct_torch(point_maps, s, R, t, device=None):
208
+ """
209
+ PyTorch SIM3
210
+ point_maps: (b, h, w, 3) numpy array
211
+ s: scalar or (b,) array
212
+ R: (3, 3) or (b, 3, 3) numpy array
213
+ t: (3,) or (b, 3) numpy array
214
+ """
215
+ if isinstance(point_maps, np.ndarray):
216
+ point_maps_torch = torch.from_numpy(point_maps).float()
217
+ R_torch = torch.from_numpy(R).float()
218
+ t_torch = torch.from_numpy(t).float()
219
+ s_torch = torch.tensor(s).float() if np.isscalar(s) else torch.from_numpy(s).float()
220
+ else:
221
+ point_maps_torch = point_maps
222
+ R_torch = R
223
+ t_torch = t
224
+ s_torch = s
225
+
226
+ if device is not None:
227
+ point_maps_torch = point_maps_torch.to(device)
228
+ R_torch = R_torch.to(device)
229
+ t_torch = t_torch.to(device)
230
+ s_torch = s_torch.to(device)
231
+
232
+ b, h, w, c = point_maps_torch.shape
233
+
234
+ points_flat = point_maps_torch.reshape(b, -1, 3) # (b, h*w, 3)
235
+
236
+ if R_torch.dim() == 2:
237
+ R_torch = R_torch.unsqueeze(0).expand(b, 3, 3) # (b, 3, 3)
238
+
239
+ if t_torch.dim() == 1:
240
+ t_torch = t_torch.unsqueeze(0).expand(b, 3) # (b, 3)
241
+
242
+ if s_torch.dim() == 0:
243
+ s_torch = s_torch.unsqueeze(0).expand(b) # (b,)
244
+
245
+ rotated_flat = torch.bmm(points_flat, R_torch.transpose(1, 2)) # (b, h*w, 3)
246
+
247
+ transformed_flat = s_torch[:, None, None] * rotated_flat + t_torch[:, None, :]
248
+
249
+ transformed = transformed_flat.reshape(b, h, w, 3)
250
+
251
+ if isinstance(point_maps, np.ndarray):
252
+ return transformed.cpu().numpy()
253
+ return transformed
254
+
255
+
256
+ def depth_to_point_cloud_optimized_torch(depth, intrinsics, extrinsics, device=None):
257
+
258
+ input_is_numpy = isinstance(depth, np.ndarray)
259
+
260
+ if input_is_numpy:
261
+ depth_tensor = torch.from_numpy(depth).float()
262
+ intrinsics_tensor = torch.from_numpy(intrinsics).float()
263
+ extrinsics_tensor = torch.from_numpy(extrinsics).float()
264
+ else:
265
+ depth_tensor = depth
266
+ intrinsics_tensor = intrinsics
267
+ extrinsics_tensor = extrinsics
268
+
269
+ if device is not None:
270
+ depth_tensor = depth_tensor.to(device)
271
+ intrinsics_tensor = intrinsics_tensor.to(device)
272
+ extrinsics_tensor = extrinsics_tensor.to(device)
273
+
274
+ N, H, W = depth_tensor.shape
275
+ device = depth_tensor.device
276
+
277
+ u = torch.arange(W, device=device, dtype=torch.float32).view(1, 1, W)
278
+ v = torch.arange(H, device=device, dtype=torch.float32).view(1, H, 1)
279
+
280
+ u_expanded = u.expand(N, H, W)
281
+ v_expanded = v.expand(N, H, W)
282
+
283
+ ones = torch.ones((N, H, W), device=device)
284
+ pixel_coords = torch.stack([u_expanded, v_expanded, ones], dim=-1) # [N, H, W, 3]
285
+
286
+ intrinsics_inv = torch.inverse(intrinsics_tensor) # [N, 3, 3]
287
+
288
+ camera_coords = torch.einsum("nij,nhwj->nhwi", intrinsics_inv, pixel_coords)
289
+
290
+ camera_coords = camera_coords * depth_tensor.unsqueeze(-1) # [N, H, W, 3]
291
+
292
+ camera_coords_homo = torch.cat(
293
+ [camera_coords, torch.ones((N, H, W, 1), device=device)], dim=-1
294
+ )
295
+
296
+ extrinsics_4x4 = torch.zeros(N, 4, 4, device=device)
297
+ extrinsics_4x4[:, :3, :4] = extrinsics_tensor
298
+ extrinsics_4x4[:, 3, 3] = 1.0
299
+
300
+ c2w = torch.inverse(extrinsics_4x4) # [N, 4, 4]
301
+
302
+ world_coords_homo = torch.einsum("nij,nhwj->nhwi", c2w, camera_coords_homo)
303
+ point_cloud_world = world_coords_homo[..., :3] # [N, H, W, 3]
304
+
305
+ if input_is_numpy:
306
+ return point_cloud_world.cpu().numpy()
307
+ return point_cloud_world
308
+
309
+
310
+ def warmup_torch():
311
+
312
+ print("\nWarming up PyTorch alignment...")
313
+
314
+ src = np.random.randn(100000, 3).astype(np.float32)
315
+ tgt = np.random.randn(100000, 3).astype(np.float32)
316
+ weights = np.ones(100000, dtype=np.float32)
317
+ residuals = np.abs(np.random.randn(100000).astype(np.float32))
318
+ R = np.eye(3, dtype=np.float32)
319
+ t = np.zeros(3, dtype=np.float32)
320
+ s = np.float32(1.0)
321
+ delta = np.float32(1.0)
322
+
323
+ try:
324
+ _ = weighted_estimate_sim3_torch(src, tgt, weights)
325
+ print(" - weighted_estimate_sim3_torch warmed up.")
326
+ except Exception as e:
327
+ print(" ! Failed to warm up weighted_estimate_sim3_torch:", e)
328
+
329
+ try:
330
+ _ = weighted_estimate_se3_torch(src, tgt, weights)
331
+ print(" - weighted_estimate_se3_torch warmed up.")
332
+ except Exception as e:
333
+ print(" ! Failed to warm up weighted_estimate_se3_torch:", e)
334
+
335
+ try:
336
+ _ = huber_loss_torch(residuals, delta)
337
+ print(" - huber_loss_torch warmed up.")
338
+ except Exception as e:
339
+ print(" ! Failed to warm up huber_loss_torch:", e)
340
+
341
+ try:
342
+ _ = compute_huber_weights_torch(residuals, delta)
343
+ print(" - compute_huber_weights_torch warmed up.")
344
+ except Exception as e:
345
+ print(" ! Failed to warm up compute_huber_weights_torch:", e)
346
+
347
+ try:
348
+ _ = compute_residuals_torch(tgt, src)
349
+ print(" - compute_residuals_torch warmed up.")
350
+ except Exception as e:
351
+ print(" ! Failed to warm up compute_residuals_torch:", e)
352
+
353
+ try:
354
+ _ = apply_transformation_torch(src, s, R, t)
355
+ print(" - apply_transformation_torch warmed up.")
356
+ except Exception as e:
357
+ print(" ! Failed to warm up apply_transformation_torch:", e)
358
+
359
+ print("PyTorch warm-up complete.\n")
360
+
361
+
362
+ def print_gpu_memory():
363
+ if torch.cuda.is_available():
364
+ allocated = torch.cuda.memory_allocated() / 1024**3 # GB
365
+ cached = torch.cuda.memory_reserved() / 1024**3 # GB
366
+ print(f"GPU Memory Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")
367
+
368
+
369
+ if __name__ == "__main__":
370
+
371
+ warmup_torch()
372
+
373
+ n_points = 7_500_000
374
+ src = np.random.randn(n_points, 3).astype(np.float32)
375
+
376
+ true_R = np.array([[0.866, -0.5, 0], [0.5, 0.866, 0], [0, 0, 1]], dtype=np.float32)
377
+ true_t = np.array([1.0, 2.0, 0.5], dtype=np.float32)
378
+ true_s = 1.2
379
+
380
+ tgt = true_s * (src @ true_R.T) + true_t
381
+ tgt += 0.01 * np.random.randn(*tgt.shape).astype(np.float32)
382
+
383
+ weights = np.ones(n_points, dtype=np.float32)
384
+
385
+ print_gpu_memory()
386
+
387
+ s, R, t = robust_weighted_estimate_sim3_torch(
388
+ src, tgt, weights, delta=0.1, max_iters=5, align_method="sim3"
389
+ )
390
+
391
+ print(f"\nEstimated scale: {s:.6f}")
392
+ print(f"Estimated rotation:\n{R}")
393
+ print(f"Estimated translation: {t}")
394
+
395
+ print_gpu_memory()
Depth-Anything-3/da3_streaming/loop_utils/alignment_triton.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
16
+
17
+ import numpy as np
18
+ import torch
19
+ import triton
20
+ import triton.language as tl
21
+
22
+
23
+ @triton.jit
24
+ def apply_transformation_residual_kernel(
25
+ src_ptr, # [n, 3]
26
+ tgt_ptr, # [n, 3]
27
+ transformed_ptr, # [n, 3]
28
+ residuals_ptr, # [n]
29
+ s,
30
+ R00,
31
+ R01,
32
+ R02,
33
+ R10,
34
+ R11,
35
+ R12,
36
+ R20,
37
+ R21,
38
+ R22,
39
+ t0,
40
+ t1,
41
+ t2,
42
+ n_points,
43
+ BLOCK_SIZE: tl.constexpr,
44
+ ):
45
+ pid = tl.program_id(0)
46
+ offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
47
+ mask = offsets < n_points
48
+
49
+ src_x = tl.load(src_ptr + offsets * 3 + 0, mask=mask)
50
+ src_y = tl.load(src_ptr + offsets * 3 + 1, mask=mask)
51
+ src_z = tl.load(src_ptr + offsets * 3 + 2, mask=mask)
52
+
53
+ tgt_x = tl.load(tgt_ptr + offsets * 3 + 0, mask=mask)
54
+ tgt_y = tl.load(tgt_ptr + offsets * 3 + 1, mask=mask)
55
+ tgt_z = tl.load(tgt_ptr + offsets * 3 + 2, mask=mask)
56
+
57
+ # transformed = s * (R @ p) + t
58
+ transformed_x = s * (R00 * src_x + R01 * src_y + R02 * src_z) + t0
59
+ transformed_y = s * (R10 * src_x + R11 * src_y + R12 * src_z) + t1
60
+ transformed_z = s * (R20 * src_x + R21 * src_y + R22 * src_z) + t2
61
+
62
+ tl.store(transformed_ptr + offsets * 3 + 0, transformed_x, mask=mask)
63
+ tl.store(transformed_ptr + offsets * 3 + 1, transformed_y, mask=mask)
64
+ tl.store(transformed_ptr + offsets * 3 + 2, transformed_z, mask=mask)
65
+
66
+ dx = tgt_x - transformed_x
67
+ dy = tgt_y - transformed_y
68
+ dz = tgt_z - transformed_z
69
+ residual = tl.sqrt(dx * dx + dy * dy + dz * dz)
70
+ tl.store(residuals_ptr + offsets, residual, mask=mask)
71
+
72
+
73
+ @triton.jit
74
+ def weighted_covariance_kernel(
75
+ src_ptr, # [n, 3]
76
+ tgt_ptr, # [n, 3]
77
+ weights_ptr, # [n]
78
+ mu_src0,
79
+ mu_src1,
80
+ mu_src2,
81
+ mu_tgt0,
82
+ mu_tgt1,
83
+ mu_tgt2,
84
+ H_ptr, # [3, 3]
85
+ n_points,
86
+ BLOCK_SIZE: tl.constexpr,
87
+ ):
88
+ pid = tl.program_id(0)
89
+ offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
90
+ mask = offsets < n_points
91
+
92
+ w = tl.load(weights_ptr + offsets, mask=mask)
93
+ src_x = tl.load(src_ptr + offsets * 3 + 0, mask=mask)
94
+ src_y = tl.load(src_ptr + offsets * 3 + 1, mask=mask)
95
+ src_z = tl.load(src_ptr + offsets * 3 + 2, mask=mask)
96
+ tgt_x = tl.load(tgt_ptr + offsets * 3 + 0, mask=mask)
97
+ tgt_y = tl.load(tgt_ptr + offsets * 3 + 1, mask=mask)
98
+ tgt_z = tl.load(tgt_ptr + offsets * 3 + 2, mask=mask)
99
+
100
+ src_centered_x = src_x - mu_src0
101
+ src_centered_y = src_y - mu_src1
102
+ src_centered_z = src_z - mu_src2
103
+
104
+ tgt_centered_x = tgt_x - mu_tgt0
105
+ tgt_centered_y = tgt_y - mu_tgt1
106
+ tgt_centered_z = tgt_z - mu_tgt2
107
+
108
+ sqrt_w = tl.sqrt(w)
109
+ weighted_src_x = src_centered_x * sqrt_w
110
+ weighted_src_y = src_centered_y * sqrt_w
111
+ weighted_src_z = src_centered_z * sqrt_w
112
+
113
+ weighted_tgt_x = tgt_centered_x * sqrt_w
114
+ weighted_tgt_y = tgt_centered_y * sqrt_w
115
+ weighted_tgt_z = tgt_centered_z * sqrt_w
116
+
117
+ h00 = weighted_src_x * weighted_tgt_x
118
+ h01 = weighted_src_x * weighted_tgt_y
119
+ h02 = weighted_src_x * weighted_tgt_z
120
+
121
+ h10 = weighted_src_y * weighted_tgt_x
122
+ h11 = weighted_src_y * weighted_tgt_y
123
+ h12 = weighted_src_y * weighted_tgt_z
124
+
125
+ h20 = weighted_src_z * weighted_tgt_x
126
+ h21 = weighted_src_z * weighted_tgt_y
127
+ h22 = weighted_src_z * weighted_tgt_z
128
+
129
+ tl.atomic_add(H_ptr + 0, tl.sum(h00, axis=0))
130
+ tl.atomic_add(H_ptr + 1, tl.sum(h01, axis=0))
131
+ tl.atomic_add(H_ptr + 2, tl.sum(h02, axis=0))
132
+
133
+ tl.atomic_add(H_ptr + 3, tl.sum(h10, axis=0))
134
+ tl.atomic_add(H_ptr + 4, tl.sum(h11, axis=0))
135
+ tl.atomic_add(H_ptr + 5, tl.sum(h12, axis=0))
136
+
137
+ tl.atomic_add(H_ptr + 6, tl.sum(h20, axis=0))
138
+ tl.atomic_add(H_ptr + 7, tl.sum(h21, axis=0))
139
+ tl.atomic_add(H_ptr + 8, tl.sum(h22, axis=0))
140
+
141
+
142
+ @triton.jit
143
+ def compute_huber_weights_kernel(
144
+ residuals_ptr,
145
+ weights_ptr,
146
+ delta,
147
+ n_points,
148
+ BLOCK_SIZE: tl.constexpr,
149
+ ):
150
+ pid = tl.program_id(0)
151
+ offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
152
+ mask = offsets < n_points
153
+
154
+ r = tl.load(residuals_ptr + offsets, mask=mask)
155
+
156
+ weight = tl.where(r > delta, delta / r, 1.0)
157
+
158
+ tl.store(weights_ptr + offsets, weight, mask=mask)
159
+
160
+
161
+ @triton.jit
162
+ def weighted_mean_kernel(
163
+ points_ptr, # [n, 3]
164
+ weights_ptr, # [n]
165
+ mean_ptr, # [sum(w*x), sum(w*y), sum(w*z), sum(w)]
166
+ n_points,
167
+ BLOCK_SIZE: tl.constexpr,
168
+ ):
169
+ pid = tl.program_id(0)
170
+ offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
171
+ mask = offsets < n_points
172
+
173
+ w = tl.load(weights_ptr + offsets, mask=mask)
174
+ x = tl.load(points_ptr + offsets * 3 + 0, mask=mask)
175
+ y = tl.load(points_ptr + offsets * 3 + 1, mask=mask)
176
+ z = tl.load(points_ptr + offsets * 3 + 2, mask=mask)
177
+
178
+ wx = w * x
179
+ wy = w * y
180
+ wz = w * z
181
+
182
+ tl.atomic_add(mean_ptr + 0, tl.sum(wx, axis=0))
183
+ tl.atomic_add(mean_ptr + 1, tl.sum(wy, axis=0))
184
+ tl.atomic_add(mean_ptr + 2, tl.sum(wz, axis=0))
185
+ tl.atomic_add(mean_ptr + 3, tl.sum(w, axis=0))
186
+
187
+
188
+ def apply_transformation_residual_triton(src, tgt, s, R, t):
189
+ n_points = src.shape[0]
190
+
191
+ transformed = torch.empty_like(src)
192
+ residuals = torch.empty(n_points, device=src.device, dtype=src.dtype)
193
+
194
+ BLOCK_SIZE = 256
195
+ grid = (triton.cdiv(n_points, BLOCK_SIZE),)
196
+
197
+ R_flat = R.contiguous().view(-1)
198
+ t_flat = t.contiguous().view(-1)
199
+
200
+ apply_transformation_residual_kernel[grid](
201
+ src,
202
+ tgt,
203
+ transformed,
204
+ residuals,
205
+ float(s),
206
+ float(R_flat[0]),
207
+ float(R_flat[1]),
208
+ float(R_flat[2]),
209
+ float(R_flat[3]),
210
+ float(R_flat[4]),
211
+ float(R_flat[5]),
212
+ float(R_flat[6]),
213
+ float(R_flat[7]),
214
+ float(R_flat[8]),
215
+ float(t_flat[0]),
216
+ float(t_flat[1]),
217
+ float(t_flat[2]),
218
+ n_points,
219
+ BLOCK_SIZE=BLOCK_SIZE,
220
+ )
221
+
222
+ return transformed, residuals
223
+
224
+
225
+ def compute_weighted_mean_triton(points, weights):
226
+ n_points = points.shape[0]
227
+
228
+ # [sum(w*x), sum(w*y), sum(w*z), sum(w)]
229
+ mean_buffer = torch.zeros(4, device=points.device, dtype=points.dtype)
230
+
231
+ BLOCK_SIZE = 256
232
+ grid = (triton.cdiv(n_points, BLOCK_SIZE),)
233
+
234
+ weighted_mean_kernel[grid](points, weights, mean_buffer, n_points, BLOCK_SIZE=BLOCK_SIZE)
235
+
236
+ total_weight = mean_buffer[3]
237
+ if total_weight > 1e-12:
238
+ mean = mean_buffer[:3] / total_weight
239
+ else:
240
+ mean = torch.zeros(3, device=points.device, dtype=points.dtype)
241
+
242
+ return mean, total_weight
243
+
244
+
245
+ def compute_weighted_covariance_triton(src, tgt, weights, mu_src, mu_tgt):
246
+ n_points = src.shape[0]
247
+
248
+ H = torch.zeros(9, device=src.device, dtype=src.dtype)
249
+
250
+ BLOCK_SIZE = 256
251
+ grid = (triton.cdiv(n_points, BLOCK_SIZE),)
252
+
253
+ mu_src_flat = mu_src.contiguous().view(-1)
254
+ mu_tgt_flat = mu_tgt.contiguous().view(-1)
255
+
256
+ weighted_covariance_kernel[grid](
257
+ src,
258
+ tgt,
259
+ weights,
260
+ float(mu_src_flat[0]),
261
+ float(mu_src_flat[1]),
262
+ float(mu_src_flat[2]),
263
+ float(mu_tgt_flat[0]),
264
+ float(mu_tgt_flat[1]),
265
+ float(mu_tgt_flat[2]),
266
+ H,
267
+ n_points,
268
+ BLOCK_SIZE=BLOCK_SIZE,
269
+ )
270
+
271
+ return H.reshape(3, 3)
272
+
273
+
274
+ def compute_huber_weights_triton(residuals, delta):
275
+ n_points = residuals.shape[0]
276
+ weights = torch.empty_like(residuals)
277
+
278
+ BLOCK_SIZE = 256
279
+ grid = (triton.cdiv(n_points, BLOCK_SIZE),)
280
+
281
+ compute_huber_weights_kernel[grid](
282
+ residuals, weights, float(delta), n_points, BLOCK_SIZE=BLOCK_SIZE
283
+ )
284
+
285
+ return weights
286
+
287
+
288
+ def weighted_estimate_se3_triton(source_points, target_points, weights):
289
+
290
+ source_points = torch.from_numpy(source_points).cuda().float()
291
+ target_points = torch.from_numpy(target_points).cuda().float()
292
+ weights = torch.from_numpy(weights).cuda().float()
293
+
294
+ total_weight = torch.sum(weights)
295
+ if total_weight < 1e-6:
296
+ return (
297
+ 1.0,
298
+ np.zeros(3, dtype=np.float32),
299
+ np.zeros(3, dtype=np.float32),
300
+ np.zeros((3, 3), dtype=np.float32),
301
+ )
302
+
303
+ normalized_weights = weights / total_weight
304
+
305
+ mu_src, _ = compute_weighted_mean_triton(source_points, normalized_weights)
306
+ mu_tgt, _ = compute_weighted_mean_triton(target_points, normalized_weights)
307
+
308
+ H = compute_weighted_covariance_triton(
309
+ source_points, target_points, normalized_weights, mu_src, mu_tgt
310
+ )
311
+
312
+ return 1.0, mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy()
313
+
314
+
315
+ def weighted_estimate_sim3_triton(source_points, target_points, weights):
316
+
317
+ source_points = torch.from_numpy(source_points).cuda().float()
318
+ target_points = torch.from_numpy(target_points).cuda().float()
319
+ weights = torch.from_numpy(weights).cuda().float()
320
+
321
+ total_weight = torch.sum(weights)
322
+ if total_weight < 1e-6:
323
+ return (
324
+ -1.0,
325
+ np.zeros(3, dtype=np.float32),
326
+ np.zeros(3, dtype=np.float32),
327
+ np.zeros((3, 3), dtype=np.float32),
328
+ )
329
+
330
+ normalized_weights = weights / total_weight
331
+
332
+ mu_src, _ = compute_weighted_mean_triton(source_points, normalized_weights)
333
+ mu_tgt, _ = compute_weighted_mean_triton(target_points, normalized_weights)
334
+
335
+ src_centered = source_points - mu_src
336
+ tgt_centered = target_points - mu_tgt
337
+
338
+ scale_src = torch.sqrt(torch.sum(normalized_weights * torch.sum(src_centered**2, dim=1)))
339
+ scale_tgt = torch.sqrt(torch.sum(normalized_weights * torch.sum(tgt_centered**2, dim=1)))
340
+ s = scale_tgt / scale_src
341
+
342
+ weighted_src = s * src_centered
343
+ H = compute_weighted_covariance_triton(
344
+ weighted_src,
345
+ tgt_centered,
346
+ normalized_weights,
347
+ torch.zeros_like(mu_src),
348
+ torch.zeros_like(mu_tgt),
349
+ )
350
+
351
+ return s.cpu().numpy(), mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy()
352
+
353
+
354
+ def weighted_estimate_sim3_numba_triton(
355
+ source_points, target_points, weights, align_method="sim3"
356
+ ):
357
+
358
+ if align_method == "sim3":
359
+ s, mu_src, mu_tgt, H = weighted_estimate_sim3_triton(source_points, target_points, weights)
360
+ elif align_method == "se3" or align_method == "scale+se3":
361
+ s, mu_src, mu_tgt, H = weighted_estimate_se3_triton(source_points, target_points, weights)
362
+
363
+ if s < 0:
364
+ raise ValueError("Total weight too small for meaningful estimation")
365
+
366
+ H_torch = torch.from_numpy(H).cuda().float()
367
+ U, _, Vt = torch.linalg.svd(H_torch)
368
+
369
+ U = U.cpu().numpy()
370
+ Vt = Vt.cpu().numpy()
371
+
372
+ R = Vt.T @ U.T
373
+ if np.linalg.det(R) < 0:
374
+ Vt[2, :] *= -1
375
+ R = Vt.T @ U.T
376
+
377
+ mu_src = mu_src.astype(np.float32)
378
+ mu_tgt = mu_tgt.astype(np.float32)
379
+ R = R.astype(np.float32)
380
+
381
+ if align_method == "se3" or align_method == "scale+se3":
382
+ t = mu_tgt - R @ mu_src
383
+ else:
384
+ t = mu_tgt - s * R @ mu_src
385
+
386
+ return s, R, t.astype(np.float32)
387
+
388
+
389
+ def robust_weighted_estimate_sim3_triton(
390
+ src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
391
+ ):
392
+
393
+ src = src.astype(np.float32)
394
+ tgt = tgt.astype(np.float32)
395
+ init_weights = init_weights.astype(np.float32)
396
+
397
+ src_torch = torch.from_numpy(src).cuda().float()
398
+ tgt_torch = torch.from_numpy(tgt).cuda().float()
399
+ init_weights_torch = torch.from_numpy(init_weights).cuda().float()
400
+
401
+ s, R, t = weighted_estimate_sim3_numba_triton(
402
+ src, tgt, init_weights, align_method=align_method
403
+ )
404
+
405
+ R_torch = torch.from_numpy(R).cuda().float()
406
+ t_torch = torch.from_numpy(t).cuda().float()
407
+ s_torch = torch.tensor(s, device="cuda", dtype=torch.float32)
408
+
409
+ prev_error = float("inf")
410
+
411
+ for iter in range(max_iters):
412
+ transformed, residuals = apply_transformation_residual_triton(
413
+ src_torch, tgt_torch, s_torch, R_torch, t_torch
414
+ )
415
+
416
+ mean_residual = torch.mean(residuals).cpu().numpy()
417
+ print(f"Iter {iter}: Mean residual = {mean_residual:.6f}")
418
+
419
+ huber_weights = compute_huber_weights_triton(residuals, delta)
420
+
421
+ combined_weights = init_weights_torch * huber_weights
422
+ combined_weights_sum = torch.sum(combined_weights)
423
+ if combined_weights_sum > 1e-12:
424
+ combined_weights /= combined_weights_sum
425
+ else:
426
+ combined_weights = init_weights_torch / torch.sum(init_weights_torch)
427
+
428
+ combined_weights_np = combined_weights.cpu().numpy()
429
+ s_new, R_new, t_new = weighted_estimate_sim3_numba_triton(
430
+ src, tgt, combined_weights_np, align_method=align_method
431
+ )
432
+
433
+ param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
434
+ rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
435
+
436
+ residuals_np = residuals.cpu().numpy()
437
+ huber_loss_values = np.where(
438
+ residuals_np <= delta, 0.5 * residuals_np**2, delta * (residuals_np - 0.5 * delta)
439
+ )
440
+ current_error = np.sum(huber_loss_values * init_weights)
441
+
442
+ if (param_change < tol and rot_angle < np.radians(0.1)) or (
443
+ abs(prev_error - current_error) < tol * prev_error
444
+ ):
445
+ print(f"Converged at iteration {iter}")
446
+ break
447
+
448
+ s, R, t = s_new, R_new, t_new
449
+ s_torch = torch.tensor(s, device="cuda", dtype=torch.float32)
450
+ R_torch = torch.from_numpy(R).cuda().float()
451
+ t_torch = torch.from_numpy(t).cuda().float()
452
+ prev_error = current_error
453
+
454
+ return s, R, t
455
+
456
+
457
+ def warmup_triton():
458
+ print("\nWarming up Triton functions...")
459
+
460
+ n_points = 10000
461
+ src = np.random.randn(n_points, 3).astype(np.float32)
462
+ tgt = np.random.randn(n_points, 3).astype(np.float32)
463
+ weights = np.ones(n_points, dtype=np.float32)
464
+
465
+ src_torch = torch.from_numpy(src).cuda().float()
466
+ tgt_torch = torch.from_numpy(tgt).cuda().float()
467
+ weights_torch = torch.from_numpy(weights).cuda().float()
468
+
469
+ R = np.eye(3, dtype=np.float32)
470
+ t = np.zeros(3, dtype=np.float32)
471
+ s = np.float32(1.0)
472
+ delta = np.float32(0.1)
473
+
474
+ R_torch = torch.from_numpy(R).cuda().float()
475
+ t_torch = torch.from_numpy(t).cuda().float()
476
+ s_torch = torch.tensor(s, device="cuda", dtype=torch.float32)
477
+
478
+ try:
479
+ _, _ = apply_transformation_residual_triton(
480
+ src_torch, tgt_torch, s_torch, R_torch, t_torch
481
+ )
482
+ print(" - apply_transformation_residual_triton warmed up.")
483
+ except Exception as e:
484
+ print(f" ! Failed to warm up apply_transformation_residual_triton: {e}")
485
+
486
+ try:
487
+ _, _ = compute_weighted_mean_triton(src_torch, weights_torch)
488
+ print(" - compute_weighted_mean_triton warmed up.")
489
+ except Exception as e:
490
+ print(f" ! Failed to warm up compute_weighted_mean_triton: {e}")
491
+
492
+ try:
493
+ mu_src, _ = compute_weighted_mean_triton(src_torch, weights_torch)
494
+ mu_tgt, _ = compute_weighted_mean_triton(tgt_torch, weights_torch)
495
+ _ = compute_weighted_covariance_triton(src_torch, tgt_torch, weights_torch, mu_src, mu_tgt)
496
+ print(" - compute_weighted_covariance_triton warmed up.")
497
+ except Exception as e:
498
+ print(f" ! Failed to warm up compute_weighted_covariance_triton: {e}")
499
+
500
+ try:
501
+ residuals = torch.abs(torch.randn(n_points, device="cuda", dtype=torch.float32))
502
+ _ = compute_huber_weights_triton(residuals, delta)
503
+ print(" - compute_huber_weights_triton warmed up.")
504
+ except Exception as e:
505
+ print(f" ! Failed to warm up compute_huber_weights_triton: {e}")
506
+
507
+ print("Triton warm-up complete.\n")
508
+
509
+
510
+ def print_gpu_memory():
511
+ if torch.cuda.is_available():
512
+ allocated = torch.cuda.memory_allocated() / 1024**3 # GB
513
+ cached = torch.cuda.memory_reserved() / 1024**3 # GB
514
+ print(f"GPU Memory Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")
515
+
516
+
517
+ if __name__ == "__main__":
518
+
519
+ warmup_triton()
520
+
521
+ n_points = 7_500_000
522
+ src = np.random.randn(n_points, 3).astype(np.float32)
523
+
524
+ true_R = np.array([[0.866, -0.5, 0], [0.5, 0.866, 0], [0, 0, 1]], dtype=np.float32)
525
+ true_t = np.array([1.0, 2.0, 0.5], dtype=np.float32)
526
+ true_s = 1.2
527
+
528
+ tgt = true_s * (src @ true_R.T) + true_t
529
+ tgt += 0.01 * np.random.randn(*tgt.shape).astype(np.float32)
530
+
531
+ weights = np.ones(n_points, dtype=np.float32)
532
+
533
+ print_gpu_memory()
534
+
535
+ s, R, t = robust_weighted_estimate_sim3_triton(
536
+ src, tgt, weights, delta=0.1, max_iters=5, align_method="sim3"
537
+ )
538
+
539
+ print(f"\nEstimated scale: {s:.6f}")
540
+ print(f"Estimated rotation:\n{R}")
541
+ print(f"Estimated translation: {t}")
542
+
543
+ print_gpu_memory()
Depth-Anything-3/da3_streaming/loop_utils/config_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
16
+
17
+ import yaml
18
+
19
+
20
+ def load_config(path, default_path=None):
21
+ """
22
+ Loads config file.
23
+
24
+ Args:
25
+ path (str): path to config file.
26
+ default_path (str, optional): whether to use default path. Defaults to None.
27
+
28
+ Returns:
29
+ cfg (dict): config dict.
30
+
31
+ """
32
+ # load configuration from per scene/dataset cfg.
33
+ with open(path) as f:
34
+ cfg_special = yaml.full_load(f)
35
+
36
+ inherit_from = cfg_special.get("inherit_from")
37
+
38
+ if inherit_from is not None:
39
+ cfg = load_config(inherit_from, default_path)
40
+ elif default_path is not None:
41
+ with open(default_path) as f:
42
+ cfg = yaml.full_load(f)
43
+ else:
44
+ cfg = dict()
45
+
46
+ # merge per dataset cfg. and main cfg.
47
+ update_recursive(cfg, cfg_special)
48
+
49
+ return cfg
50
+
51
+
52
+ def update_recursive(dict1, dict2):
53
+ """
54
+ Update two config dictionaries recursively. dict1 get masked by dict2, and we retuen dict1.
55
+
56
+ Args:
57
+ dict1 (dict): first dictionary to be updated.
58
+ dict2 (dict): second dictionary which entries should be used.
59
+ """
60
+ for k, v in dict2.items():
61
+ if k not in dict1:
62
+ dict1[k] = dict()
63
+ if isinstance(v, dict):
64
+ update_recursive(dict1[k], v)
65
+ else:
66
+ dict1[k] = v
Depth-Anything-3/da3_streaming/loop_utils/logging_utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
16
+
17
+ import rich
18
+
19
+ _log_styles = {
20
+ "DA3-Streaming": "bold green",
21
+ }
22
+
23
+
24
+ def get_style(tag):
25
+ if tag in _log_styles.keys():
26
+ return _log_styles[tag]
27
+ return "bold blue"
28
+
29
+
30
+ def Log(*args, tag="DA3-Streaming"):
31
+ style = get_style(tag)
32
+ rich.print(f"[{style}]{tag}:[/{style}]", *args)
Depth-Anything-3/da3_streaming/loop_utils/loop_detector.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
16
+
17
+ import argparse
18
+ import os
19
+ import sys
20
+ from pathlib import Path
21
+ import faiss
22
+ import torch
23
+ import torchvision.transforms as T
24
+ from PIL import Image
25
+ from torch import nn
26
+ from tqdm import tqdm
27
+
28
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
29
+ SALAD_ROOT = os.path.join(CURRENT_DIR, "salad")
30
+ if SALAD_ROOT not in sys.path:
31
+ sys.path.insert(0, SALAD_ROOT)
32
+ from loop_utils.salad.models import helper
33
+
34
+
35
+ class VPRModel(nn.Module):
36
+ """This is the main model for Visual Place Recognition
37
+ we use Pytorch Lightning for modularity purposes.
38
+
39
+ Args:
40
+ pl (_type_): _description_
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ # ---- Backbone
46
+ backbone_arch="resnet50",
47
+ backbone_config={},
48
+ # ---- Aggregator
49
+ agg_arch="ConvAP",
50
+ agg_config={},
51
+ ):
52
+ super().__init__()
53
+
54
+ # Backbone
55
+ self.encoder_arch = backbone_arch
56
+ self.backbone_config = backbone_config
57
+
58
+ # Aggregator
59
+ self.agg_arch = agg_arch
60
+ self.agg_config = agg_config
61
+
62
+ # ----------------------------------
63
+ # get the backbone and the aggregator
64
+ self.backbone = helper.get_backbone(backbone_arch, backbone_config)
65
+ self.aggregator = helper.get_aggregator(agg_arch, agg_config)
66
+
67
+ # the forward pass of the lightning model
68
+ def forward(self, x):
69
+ x = self.backbone(x)
70
+ x = self.aggregator(x)
71
+ return x
72
+
73
+
74
+ class LoopDetector:
75
+ """Loop detector class for detecting loop closures in image sequences"""
76
+
77
+ def __init__(self, image_dir, output="loop_closures.txt", config=None):
78
+ """Initialize the loop detector
79
+
80
+ Args:
81
+ image_dir: Directory path containing images
82
+ ckpt_path: Model checkpoint path
83
+ image_size: Image resize dimensions [height width]
84
+ batch_size: Batch size for processing
85
+ similarity_threshold: Similarity threshold for loop closure
86
+ top_k: Number of nearest neighbors to check for each image
87
+ use_nms: Whether to use Non-Maximum Suppression (NMS) filtering
88
+ nms_threshold: NMS threshold for minimum frame difference between loop pairs
89
+ output: Output file path
90
+ """
91
+ self.config = config
92
+ self.image_dir = image_dir
93
+ self.ckpt_path = self.config["Weights"]["SALAD"]
94
+ self.image_size = self.config["Loop"]["SALAD"]["image_size"]
95
+ self.batch_size = self.config["Loop"]["SALAD"]["batch_size"]
96
+ self.similarity_threshold = self.config["Loop"]["SALAD"]["similarity_threshold"]
97
+ self.top_k = self.config["Loop"]["SALAD"]["top_k"]
98
+ self.use_nms = self.config["Loop"]["SALAD"]["use_nms"]
99
+ self.nms_threshold = self.config["Loop"]["SALAD"]["nms_threshold"]
100
+ self.output = output
101
+
102
+ self.model = None
103
+ self.device = None
104
+ self.image_paths = None
105
+ self.descriptors = None
106
+ self.loop_closures = None
107
+
108
+ def _input_transform(self, image_size=None):
109
+ """Create image transformation function"""
110
+ MEAN = [0.485, 0.456, 0.406]
111
+ STD = [0.229, 0.224, 0.225]
112
+ if image_size:
113
+ return T.Compose(
114
+ [
115
+ T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
116
+ T.ToTensor(),
117
+ T.Normalize(mean=MEAN, std=STD),
118
+ ]
119
+ )
120
+ else:
121
+ return T.Compose([T.ToTensor(), T.Normalize(mean=MEAN, std=STD)])
122
+
123
+ def load_model(self):
124
+ """Load model"""
125
+ model = VPRModel(
126
+ backbone_arch="dinov2_vitb14",
127
+ backbone_config={
128
+ "num_trainable_blocks": 4,
129
+ "return_token": True,
130
+ "norm_layer": True,
131
+ },
132
+ agg_arch="SALAD",
133
+ agg_config={
134
+ "num_channels": 768,
135
+ "num_clusters": 64,
136
+ "cluster_dim": 128,
137
+ "token_dim": 256,
138
+ },
139
+ )
140
+
141
+ model.load_state_dict(torch.load(self.ckpt_path))
142
+ model = model.eval()
143
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144
+ model = model.to(device)
145
+ print(f"Model loaded: {self.ckpt_path}")
146
+
147
+ self.model = model
148
+ self.device = device
149
+ return model, device
150
+
151
+ def get_image_paths(self):
152
+ """Get paths of all image files in directory"""
153
+ image_extensions = [".jpg", ".jpeg", ".png"]
154
+ image_paths = []
155
+
156
+ for ext in image_extensions:
157
+ image_paths.extend(list(Path(self.image_dir).glob(f"*{ext}")))
158
+ image_paths.extend(list(Path(self.image_dir).glob(f"*{ext.upper()}")))
159
+
160
+ image_paths = sorted(image_paths)
161
+ self.image_paths = image_paths
162
+ return image_paths
163
+
164
+ def extract_descriptors(self):
165
+ """Extract image feature descriptors"""
166
+ if self.model is None or self.device is None:
167
+ self.load_model()
168
+
169
+ if self.image_paths is None:
170
+ self.get_image_paths()
171
+
172
+ transform = self._input_transform(self.image_size)
173
+ descriptors = []
174
+
175
+ for i in tqdm(
176
+ range(0, len(self.image_paths), self.batch_size), desc="Extracting features"
177
+ ):
178
+ batch_paths = self.image_paths[i : i + self.batch_size]
179
+ batch_imgs = []
180
+
181
+ for path in batch_paths:
182
+ try:
183
+ img = Image.open(path).convert("RGB")
184
+ img = transform(img)
185
+ batch_imgs.append(img)
186
+ except Exception as e:
187
+ print(f"Error processing image {path}: {e}")
188
+ img = (
189
+ torch.zeros(3, 224, 224)
190
+ if self.image_size is None
191
+ else torch.zeros(3, self.image_size[0], self.image_size[1])
192
+ )
193
+ batch_imgs.append(img)
194
+
195
+ batch_tensor = torch.stack(batch_imgs).to(self.device)
196
+
197
+ with torch.no_grad():
198
+ with torch.autocast(
199
+ device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.float16
200
+ ):
201
+ batch_descriptors = self.model(batch_tensor).cpu()
202
+
203
+ descriptors.append(batch_descriptors)
204
+
205
+ self.descriptors = torch.cat(descriptors)
206
+ return self.descriptors
207
+
208
+ def _apply_nms_filter(self, loop_closures, nms_threshold):
209
+ """Apply Non-Maximum Suppression (NMS) filtering to loop pairs"""
210
+ if not loop_closures or nms_threshold <= 0:
211
+ return loop_closures
212
+
213
+ sorted_loops = sorted(loop_closures, key=lambda x: x[2], reverse=True)
214
+ filtered_loops = []
215
+ suppressed = set()
216
+
217
+ max_frame = max(max(idx1, idx2) for idx1, idx2, _ in loop_closures)
218
+
219
+ for idx1, idx2, sim in sorted_loops:
220
+ if idx1 in suppressed or idx2 in suppressed:
221
+ continue
222
+
223
+ filtered_loops.append((idx1, idx2, sim))
224
+
225
+ suppress_range = set()
226
+
227
+ start1 = max(0, idx1 - nms_threshold)
228
+ end1 = min(idx1 + nms_threshold + 1, idx2)
229
+ suppress_range.update(range(start1, end1))
230
+
231
+ start2 = max(idx1 + 1, idx2 - nms_threshold)
232
+ end2 = min(idx2 + nms_threshold + 1, max_frame + 1)
233
+ suppress_range.update(range(start2, end2))
234
+
235
+ suppressed.update(suppress_range)
236
+
237
+ return filtered_loops
238
+
239
+ def _ensure_decending_order(self, tuples_list):
240
+ return [(max(a, b), min(a, b), score) for a, b, score in tuples_list]
241
+
242
+ def find_loop_closures(self):
243
+ """Find loop closures"""
244
+ if self.descriptors is None:
245
+ self.extract_descriptors()
246
+
247
+ embed_size = self.descriptors.shape[1]
248
+ faiss_index = faiss.IndexFlatIP(embed_size)
249
+
250
+ normalized_descriptors = self.descriptors.numpy()
251
+ faiss_index.add(normalized_descriptors)
252
+
253
+ similarities, indices = faiss_index.search(
254
+ normalized_descriptors, self.top_k + 1
255
+ ) # +1 because self is most similar
256
+
257
+ loop_closures = []
258
+ for i in range(len(self.descriptors)):
259
+ # Skip first result (self)
260
+ for j in range(1, self.top_k + 1):
261
+ neighbor_idx = indices[i, j]
262
+ similarity = similarities[i, j]
263
+
264
+ if similarity > self.similarity_threshold and abs(i - neighbor_idx) > 10:
265
+ if i < neighbor_idx:
266
+ loop_closures.append((i, neighbor_idx, similarity))
267
+ else:
268
+ loop_closures.append((neighbor_idx, i, similarity))
269
+
270
+ loop_closures = list(set(loop_closures))
271
+ loop_closures.sort(key=lambda x: x[2], reverse=True)
272
+
273
+ if self.use_nms and self.nms_threshold > 0:
274
+ loop_closures = self._apply_nms_filter(loop_closures, self.nms_threshold)
275
+
276
+ self.loop_closures = self._ensure_decending_order(loop_closures)
277
+ return self.loop_closures
278
+
279
+ def save_results(self):
280
+ """Save loop detection results to file"""
281
+ if self.loop_closures is None:
282
+ self.find_loop_closures()
283
+
284
+ with open(self.output, "w") as f:
285
+ f.write("# Loop Detection Results (index1, index2, similarity)\n")
286
+ if self.use_nms:
287
+ f.write(f"# NMS filtering applied, threshold: {self.nms_threshold}\n")
288
+ f.write("\n# Loop pairs:\n")
289
+ for i, j, sim in self.loop_closures:
290
+ f.write(f"{i}, {j}, {sim:.4f}\n")
291
+ f.write("\n# Image path list:\n")
292
+ for i, path in enumerate(self.image_paths):
293
+ f.write(f"# {i}: {path}\n")
294
+
295
+ print(f"Found {len(self.loop_closures)} loop pairs, results saved to {self.output}")
296
+ if self.use_nms:
297
+ print(f"NMS filtering applied, threshold: {self.nms_threshold}")
298
+
299
+ if self.loop_closures:
300
+ print("\nTop 10 loop pairs:")
301
+ for i, (idx1, idx2, sim) in enumerate(self.loop_closures[:10]):
302
+ print(f"{idx1}, {idx2}, similarity: {sim:.4f}")
303
+ if i >= 9:
304
+ break
305
+
306
+ def get_loop_list(self):
307
+ return [(idx1, idx2) for idx1, idx2, _ in self.loop_closures]
308
+
309
+ def run(self):
310
+ """Run complete loop detection pipeline"""
311
+ print("Loading model...")
312
+ if self.model is None:
313
+ self.load_model()
314
+
315
+ self.get_image_paths()
316
+ if not self.image_paths:
317
+ print(f"No image files found in {self.image_dir}")
318
+ return
319
+
320
+ print(f"Found {len(self.image_paths)} image files")
321
+
322
+ self.extract_descriptors()
323
+
324
+ self.find_loop_closures()
325
+
326
+ self.save_results()
327
+
328
+ return self.loop_closures
329
+
330
+
331
+ def main():
332
+ parser = argparse.ArgumentParser(description="Loop detection using SALAD model")
333
+ parser.add_argument(
334
+ "--image_dir",
335
+ type=str,
336
+ default="/media/deng/Data/KITTIdataset/data_odometry_color/dataset/sequences/00/image_2",
337
+ help="Directory path containing images",
338
+ )
339
+ parser.add_argument(
340
+ "--ckpt_path", type=str, default="./weights/dino_salad.ckpt", help="Model checkpoint path"
341
+ )
342
+ parser.add_argument(
343
+ "--image_size",
344
+ nargs=2,
345
+ type=int,
346
+ default=[336, 336],
347
+ help="Image resize dimensions [height width]",
348
+ )
349
+ parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing")
350
+ parser.add_argument(
351
+ "--similarity_threshold",
352
+ type=float,
353
+ default=0.7,
354
+ help="Similarity threshold for loop closure",
355
+ )
356
+ parser.add_argument(
357
+ "--top_k", type=int, default=5, help="Number of nearest neighbors to check for each image"
358
+ )
359
+ parser.add_argument("--output", type=str, default="loop_closures.txt", help="Output file path")
360
+ parser.add_argument(
361
+ "--use_nms",
362
+ action="store_true",
363
+ default=True,
364
+ help="Whether to use Non-Maximum Suppression (NMS) filtering",
365
+ )
366
+ parser.add_argument(
367
+ "--nms_threshold",
368
+ type=int,
369
+ default=25,
370
+ help="NMS threshold for minimum frame difference between loop pairs",
371
+ )
372
+
373
+ args = parser.parse_args()
374
+
375
+ detector = LoopDetector(
376
+ image_dir=args.image_dir,
377
+ ckpt_path=args.ckpt_path,
378
+ image_size=args.image_size,
379
+ batch_size=args.batch_size,
380
+ similarity_threshold=args.similarity_threshold,
381
+ top_k=args.top_k,
382
+ use_nms=args.use_nms,
383
+ nms_threshold=args.nms_threshold,
384
+ output=args.output,
385
+ )
386
+
387
+ detector.run()
388
+
389
+
390
+ if __name__ == "__main__":
391
+ main()
Depth-Anything-3/da3_streaming/loop_utils/loop_refinement.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
16
+
17
+ import numba as nb
18
+ import numpy as np
19
+ import pypose as pp
20
+ import sim3solve
21
+ import torch
22
+ from einops import parse_shape, rearrange
23
+ from scipy.spatial.transform import Rotation as R
24
+
25
+
26
+ def make_pypose_Sim3(rot, t, s):
27
+ q = R.from_matrix(rot).as_quat()
28
+ data = np.concatenate([t, q, np.array(s).reshape((1,))])
29
+ return pp.Sim3(data)
30
+
31
+
32
+ def SE3_to_Sim3(x: pp.SE3):
33
+ out = torch.cat((x.data, torch.ones_like(x.data[..., :1])), dim=-1)
34
+ return pp.Sim3(out)
35
+
36
+
37
+ @nb.njit(cache=True)
38
+ def _format(es):
39
+ return np.asarray(es, dtype=np.int64).reshape((-1, 2))[1:]
40
+
41
+
42
+ @nb.njit(cache=True)
43
+ def reduce_edges(flow_mag, ii, jj, max_num_edges, nms):
44
+ es = [(-1, -1)]
45
+
46
+ if ii.size == 0:
47
+ return _format(es)
48
+
49
+ Ni, Nj = (ii.max() + 1), (jj.max() + 1)
50
+ ignore_lookup = np.zeros((Ni, Nj), dtype=nb.bool_)
51
+
52
+ idxs = np.argsort(flow_mag)
53
+ for idx in idxs: # edge index
54
+
55
+ if len(es) > max_num_edges:
56
+ break
57
+
58
+ i = ii[idx]
59
+ j = jj[idx]
60
+ mag = flow_mag[idx]
61
+
62
+ if (j - i) < 30:
63
+ continue
64
+
65
+ if mag >= 1000: # i.e., inf
66
+ continue
67
+
68
+ if ignore_lookup[i, j]:
69
+ continue
70
+
71
+ es.append((i, j))
72
+
73
+ for di in range(-nms, nms + 1):
74
+ i1 = i + di
75
+
76
+ if 0 <= i1 < Ni:
77
+ ignore_lookup[i1, j] = True
78
+
79
+ return _format(es)
80
+
81
+
82
+ @nb.njit(cache=True)
83
+ def umeyama_alignment(x: np.ndarray, y: np.ndarray):
84
+ """
85
+ The following function was copied from:
86
+ https://github.com/MichaelGrupp/evo/blob/3067541b350528fe46375423e5bc3a7c42c06c63/evo/core/geometry.py#L35
87
+
88
+ Computes the least squares solution parameters of an Sim(m) matrix
89
+ that minimizes the distance between a set of registered points.
90
+ Umeyama, Shinji: Least-squares estimation of transformation parameters
91
+ between two point patterns. IEEE PAMI, 1991
92
+ :param x: mxn matrix of points, m = dimension, n = nr. of data points
93
+ :param y: mxn matrix of points, m = dimension, n = nr. of data points
94
+ :param with_scale: set to True to align also the scale (default: 1.0 scale)
95
+ :return: r, t, c - rotation matrix, translation vector and scale factor
96
+ """
97
+
98
+ # m = dimension, n = nr. of data points
99
+ m, n = x.shape
100
+
101
+ # means, eq. 34 and 35
102
+ mean_x = x.sum(axis=1) / n
103
+ mean_y = y.sum(axis=1) / n
104
+
105
+ # variance, eq. 36
106
+ # "transpose" for column subtraction
107
+ sigma_x = 1.0 / n * (np.linalg.norm(x - mean_x[:, np.newaxis]) ** 2)
108
+
109
+ # covariance matrix, eq. 38
110
+ outer_sum = np.zeros((m, m))
111
+ for i in range(n):
112
+ outer_sum += np.outer((y[:, i] - mean_y), (x[:, i] - mean_x))
113
+ cov_xy = np.multiply(1.0 / n, outer_sum)
114
+
115
+ # SVD (text betw. eq. 38 and 39)
116
+ u, d, v = np.linalg.svd(cov_xy)
117
+ if np.count_nonzero(d > np.finfo(d.dtype).eps) < m - 1:
118
+ return None, None, None # Degenerate covariance rank, Umeyama alignment is not possible
119
+
120
+ # S matrix, eq. 43
121
+ s = np.eye(m)
122
+ if np.linalg.det(u) * np.linalg.det(v) < 0.0:
123
+ # Ensure a RHS coordinate system (Kabsch algorithm).
124
+ s[m - 1, m - 1] = -1
125
+
126
+ # rotation, eq. 40
127
+ r = u.dot(s).dot(v)
128
+
129
+ # scale & translation, eq. 42 and 41
130
+ c = 1 / sigma_x * np.trace(np.diag(d).dot(s))
131
+ t = mean_y - np.multiply(c, r.dot(mean_x))
132
+
133
+ return r, t, c
134
+
135
+
136
+ @nb.njit(cache=True)
137
+ def ransac_umeyama(src_points, dst_points, iterations=1, threshold=0.1):
138
+ best_inliers = 0
139
+ best_R = None
140
+ best_t = None
141
+ best_s = None
142
+ for _ in range(iterations):
143
+ # Randomly select three points
144
+ indices = np.random.choice(src_points.shape[0], 3, replace=False)
145
+ src_sample = src_points[indices]
146
+ dst_sample = dst_points[indices]
147
+
148
+ # Estimate transformation
149
+ R, t, s = umeyama_alignment(src_sample.T, dst_sample.T)
150
+ if t is None:
151
+ continue
152
+
153
+ # Apply transformation
154
+ transformed = (src_points @ (R * s).T) + t
155
+
156
+ # Count inliers (not ideal because depends on scene scale)
157
+ distances = np.sum((transformed - dst_points) ** 2, axis=1) ** 0.5
158
+ inlier_mask = distances < threshold
159
+ inliers = np.sum(inlier_mask)
160
+
161
+ # Update best transformation
162
+ if inliers > best_inliers:
163
+ best_inliers = inliers
164
+ best_R, best_t, best_s = umeyama_alignment(
165
+ src_points[inlier_mask].T, dst_points[inlier_mask].T
166
+ )
167
+
168
+ return best_R, best_t, best_s, best_inliers
169
+
170
+
171
+ def batch_jacobian(func, x):
172
+ def _func_sum(*x):
173
+ return func(*x).sum(dim=0)
174
+
175
+ _, b, c = torch.autograd.functional.jacobian(_func_sum, x, vectorize=True)
176
+ return rearrange(torch.stack((b, c)), "N O B I -> N B O I", N=2)
177
+
178
+
179
+ def _residual(C, Gi, Gj):
180
+ assert parse_shape(C, "N _") == parse_shape(Gi, "N _") == parse_shape(Gj, "N _")
181
+ out = C @ pp.Exp(Gi) @ pp.Exp(Gj).Inv()
182
+ return out.Log().tensor()
183
+
184
+
185
+ def residual(Ginv, input_poses, dSloop, ii, jj, jacobian=False):
186
+
187
+ # prep
188
+ device = Ginv.device
189
+ assert parse_shape(input_poses, "_ d") == dict(d=7)
190
+ pred_inv_poses = SE3_to_Sim3(input_poses).Inv()
191
+
192
+ # free variables
193
+ n, _ = pred_inv_poses.shape
194
+ kk = torch.arange(1, n, device=device)
195
+ ll = kk - 1
196
+
197
+ # constants
198
+ Ti = pred_inv_poses[kk]
199
+ Tj = pred_inv_poses[ll]
200
+ dSij = Tj @ Ti.Inv()
201
+
202
+ constants = torch.cat((dSij, dSloop), dim=0)
203
+ iii = torch.cat((kk, ii))
204
+ jjj = torch.cat((ll, jj))
205
+ resid = _residual(constants, Ginv[iii], Ginv[jjj])
206
+
207
+ if not jacobian:
208
+ return resid
209
+
210
+ J_Ginv_i, J_Ginv_j = batch_jacobian(_residual, (constants, Ginv[iii], Ginv[jjj]))
211
+ return resid, (J_Ginv_i, J_Ginv_j, iii, jjj)
212
+
213
+
214
+ def perform_updates(
215
+ input_poses, dSloop, ii_loop, jj_loop, iters=30, ep=0.0, lmbda=1e-6, fix_opt_window=False
216
+ ):
217
+ """Run the Levenberg Marquardt algorithm"""
218
+
219
+ input_poses = input_poses.clone()
220
+
221
+ if fix_opt_window:
222
+ freen = torch.cat((ii_loop, jj_loop)).max().item() + 1
223
+ else:
224
+ freen = -1
225
+
226
+ Ginv = SE3_to_Sim3(input_poses).Inv().Log()
227
+
228
+ residual_history = []
229
+
230
+ for itr in range(iters):
231
+ resid, (J_Ginv_i, J_Ginv_j, iii, jjj) = residual(
232
+ Ginv, input_poses, dSloop, ii_loop, jj_loop, jacobian=True
233
+ )
234
+ residual_history.append(resid.square().mean().item())
235
+ print(f"resid: {resid.square().mean().item()}")
236
+ (delta_pose,) = sim3solve.solve_system(
237
+ J_Ginv_i, J_Ginv_j, iii, jjj, resid, ep, lmbda, freen
238
+ )
239
+ assert Ginv.shape == delta_pose.shape
240
+ Ginv_tmp = Ginv + delta_pose
241
+
242
+ new_resid = residual(Ginv_tmp, input_poses, dSloop, ii_loop, jj_loop)
243
+ if new_resid.square().mean() < residual_history[-1]:
244
+ Ginv = Ginv_tmp
245
+ lmbda /= 2
246
+ else:
247
+ lmbda *= 2
248
+
249
+ if (
250
+ (residual_history[-1] < 1e-5)
251
+ and (itr >= 4)
252
+ and ((residual_history[-5] / residual_history[-1]) < 1.5)
253
+ ):
254
+ break
255
+
256
+ return pp.Exp(Ginv).Inv()
257
+
258
+
259
+ def pose_refinement(pred_poses, loop_poses, loop_ii, loop_jj):
260
+
261
+ final_est = perform_updates(pred_poses, loop_poses, loop_ii, loop_jj, iters=30)
262
+
263
+ safe_i = loop_ii.max().item() + 1
264
+ aa = SE3_to_Sim3(pred_poses.cpu())
265
+ final_est = (aa[[safe_i]] * final_est[[safe_i]].Inv()) * final_est
266
+ output = final_est[:safe_i]
267
+
268
+ return output
Depth-Anything-3/da3_streaming/loop_utils/sim3loop.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
16
+
17
+ import time
18
+ from typing import List, Tuple
19
+ import numpy as np
20
+ import pypose as pp
21
+ import torch
22
+ from fastloop.solve_python import solve_system_py
23
+ from scipy.spatial.transform import Rotation as R
24
+
25
+ cpp_version = False
26
+ try:
27
+ import sim3solve
28
+
29
+ cpp_version = True
30
+ except Exception:
31
+ print("Sim3solve of C++ Version failed, Will using Python Version.")
32
+
33
+
34
+ class Sim3LoopOptimizer:
35
+ """
36
+ Loop closure optimizer for sequences of Sim3 transformations
37
+
38
+ Input:
39
+ - sequential_transforms: List[Tuple[float, np.ndarray, np.ndarray]]
40
+ Each element is (s, R, t), where s is scalar scale, R is [3,3] rotation matrix,
41
+ t is [3,] translation vector
42
+ - loop_constraints: List[Tuple[int, int, Tuple[float, np.ndarray, np.ndarray]]]
43
+ Each element is (i, j, (s, R, t)), representing a loop closure constraint
44
+ from frame i to frame j
45
+
46
+ Output:
47
+ - Optimized sequential_transforms
48
+ """
49
+
50
+ def __init__(self, config, device="cpu"):
51
+ self.device = device
52
+ self.config = config
53
+ self.solve_system_version = self.config["Loop"]["SIM3_Optimizer"][
54
+ "lang_version"
55
+ ] # choose between 'python' and 'cpp'
56
+
57
+ if not cpp_version:
58
+ self.solve_system_version = "python"
59
+
60
+ def numpy_to_pypose_sim3(self, s: float, R_mat: np.ndarray, t_vec: np.ndarray) -> pp.Sim3:
61
+ """Convert numpy s,R,t to pypose Sim3"""
62
+ q = R.from_matrix(R_mat).as_quat() # [x,y,z,w]
63
+ # pypose requires [t, q, s] format
64
+ data = np.concatenate([t_vec, q, np.array([s])])
65
+ return pp.Sim3(torch.from_numpy(data).float().to(self.device))
66
+
67
+ def pypose_sim3_to_numpy(self, sim3: pp.Sim3) -> Tuple[float, np.ndarray, np.ndarray]:
68
+ """Convert pypose Sim3 to numpy s,R,t"""
69
+ data = sim3.data.cpu().numpy()
70
+ t = data[:3]
71
+ q = data[3:7] # [x,y,z,w]
72
+ s = data[7]
73
+ R_mat = R.from_quat(q).as_matrix()
74
+ return s, R_mat, t
75
+
76
+ def sequential_to_absolute_poses(
77
+ self, sequential_transforms: List[Tuple[float, np.ndarray, np.ndarray]]
78
+ ) -> torch.Tensor:
79
+ """
80
+ Convert sequential relative transforms to absolute pose sequence
81
+ S_01, S_12, S_23, ... -> T_0, T_1, T_2, T_3, ...
82
+ Where T_i is the transform from world coordinate to frame i
83
+ """
84
+ len(sequential_transforms) + 1
85
+ poses = []
86
+
87
+ identity = pp.Sim3(
88
+ torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], device=self.device)
89
+ )
90
+ poses.append(identity)
91
+
92
+ current_pose = identity
93
+ for s, R_mat, t_vec in sequential_transforms:
94
+ rel_transform = self.numpy_to_pypose_sim3(s, R_mat, t_vec)
95
+ current_pose = current_pose @ rel_transform
96
+ poses.append(current_pose)
97
+
98
+ return torch.stack(poses)
99
+
100
+ def absolute_to_sequential_transforms(
101
+ self, absolute_poses: pp.Sim3
102
+ ) -> List[Tuple[float, np.ndarray, np.ndarray]]:
103
+ """
104
+ Convert absolute pose sequence back to sequential relative transforms
105
+ T_0, T_1, T_2, ... -> S_01, S_12, S_23, ...
106
+ """
107
+ sequential_transforms = []
108
+ n = absolute_poses.shape[0]
109
+
110
+ for i in range(n - 1):
111
+ rel_transform = absolute_poses[i].Inv() @ absolute_poses[i + 1]
112
+ s, R_mat, t_vec = self.pypose_sim3_to_numpy(rel_transform)
113
+ sequential_transforms.append((s, R_mat, t_vec))
114
+
115
+ return sequential_transforms
116
+
117
+ def SE3_to_Sim3(self, x: torch.Tensor) -> pp.Sim3:
118
+ """Convert SE3 to Sim3 (add unit scale)"""
119
+ ones = torch.ones_like(x[..., :1])
120
+ out = torch.cat((x, ones), dim=-1)
121
+ return pp.Sim3(out)
122
+
123
+ def build_loop_constraints(
124
+ self, loop_constraints: List[Tuple[int, int, Tuple[float, np.ndarray, np.ndarray]]]
125
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
126
+ """Build loop closure constraints"""
127
+ if not loop_constraints:
128
+ return (
129
+ torch.empty(0, 8, device=self.device),
130
+ torch.empty(0, dtype=torch.long),
131
+ torch.empty(0, dtype=torch.long),
132
+ )
133
+
134
+ loop_transforms = []
135
+ ii_loop = []
136
+ jj_loop = []
137
+
138
+ for i, j, (s, R_mat, t_vec) in loop_constraints:
139
+ loop_sim3 = self.numpy_to_pypose_sim3(s, R_mat, t_vec)
140
+ loop_transforms.append(loop_sim3.data)
141
+ ii_loop.append(i)
142
+ jj_loop.append(j)
143
+
144
+ dSloop = pp.Sim3(torch.stack(loop_transforms))
145
+ ii_loop = torch.tensor(ii_loop, dtype=torch.long, device=self.device)
146
+ jj_loop = torch.tensor(jj_loop, dtype=torch.long, device=self.device)
147
+
148
+ return dSloop, ii_loop, jj_loop
149
+
150
+ def residual(self, Ginv, input_poses, dSloop, ii, jj, jacobian=False):
151
+ """Compute residuals (modified from original code)"""
152
+
153
+ def _residual(C, Gi, Gj):
154
+ out = C @ pp.Exp(Gi) @ pp.Exp(Gj).Inv()
155
+ return out.Log().tensor()
156
+
157
+ pred_inv_poses = pp.Sim3(input_poses).Inv()
158
+
159
+ n, _ = pred_inv_poses.shape
160
+ if n > 1:
161
+ kk = torch.arange(1, n, device=self.device)
162
+ ll = kk - 1
163
+ Ti = pred_inv_poses[kk]
164
+ Tj = pred_inv_poses[ll]
165
+ dSij = Tj @ Ti.Inv()
166
+ else:
167
+ kk = torch.empty(0, dtype=torch.long, device=self.device)
168
+ ll = torch.empty(0, dtype=torch.long, device=self.device)
169
+ dSij = pp.Sim3(torch.empty(0, 8, device=self.device))
170
+
171
+ constants = (
172
+ torch.cat((dSij.data, dSloop.data), dim=0) if dSloop.shape[0] > 0 else dSij.data
173
+ )
174
+ if constants.shape[0] > 0:
175
+ constants = pp.Sim3(constants)
176
+ iii = torch.cat((kk, ii))
177
+ jjj = torch.cat((ll, jj))
178
+ resid = _residual(constants, Ginv[iii], Ginv[jjj])
179
+ else:
180
+ iii = torch.empty(0, dtype=torch.long, device=self.device)
181
+ jjj = torch.empty(0, dtype=torch.long, device=self.device)
182
+ resid = torch.empty(0, device=self.device)
183
+
184
+ if not jacobian:
185
+ return resid
186
+
187
+ if constants.shape[0] > 0:
188
+
189
+ def batch_jacobian(func, x):
190
+ def _func_sum(*x):
191
+ return func(*x).sum(dim=0)
192
+
193
+ _, b, c = torch.autograd.functional.jacobian(_func_sum, x, vectorize=True)
194
+ from einops import rearrange
195
+
196
+ return rearrange(torch.stack((b, c)), "N O B I -> N B O I", N=2)
197
+
198
+ J_Ginv_i, J_Ginv_j = batch_jacobian(_residual, (constants, Ginv[iii], Ginv[jjj]))
199
+ else:
200
+ J_Ginv_i = torch.empty(0, device=self.device)
201
+ J_Ginv_j = torch.empty(0, device=self.device)
202
+
203
+ return resid, (J_Ginv_i, J_Ginv_j, iii, jjj)
204
+
205
+ def optimize(
206
+ self,
207
+ sequential_transforms: List[Tuple[float, np.ndarray, np.ndarray]],
208
+ loop_constraints: List[Tuple[int, int, Tuple[float, np.ndarray, np.ndarray]]],
209
+ max_iterations: int = None,
210
+ lambda_init: float = None,
211
+ ) -> List[Tuple[float, np.ndarray, np.ndarray]]:
212
+ """
213
+ Main optimization function
214
+
215
+ Args:
216
+ sequential_transforms: Input sequence of transforms
217
+ loop_constraints: List of loop closure constraints
218
+ max_iterations: Maximum iterations
219
+ lambda_init: Initial lambda for L-M algorithm
220
+
221
+ Returns:
222
+ Optimized sequence of transforms
223
+ """
224
+ if max_iterations is None:
225
+ max_iterations = self.config["Loop"]["SIM3_Optimizer"]["max_iterations"]
226
+ if lambda_init is None:
227
+ lambda_init = eval(self.config["Loop"]["SIM3_Optimizer"]["lambda_init"])
228
+
229
+ input_poses = self.sequential_to_absolute_poses(sequential_transforms)
230
+
231
+ dSloop, ii_loop, jj_loop = self.build_loop_constraints(loop_constraints)
232
+
233
+ if len(loop_constraints) == 0:
234
+ print("Warning: No loop constraints provided, returning original transforms")
235
+ return sequential_transforms
236
+
237
+ Ginv = pp.Sim3(input_poses).Inv().Log()
238
+ lmbda = lambda_init
239
+ residual_history = []
240
+
241
+ print(
242
+ f"Starting optimization with {len(sequential_transforms)} poses \
243
+ and {len(loop_constraints)} loop constraints"
244
+ )
245
+
246
+ # L-M loop
247
+ for itr in range(max_iterations):
248
+ resid, (J_Ginv_i, J_Ginv_j, iii, jjj) = self.residual(
249
+ Ginv, input_poses, dSloop, ii_loop, jj_loop, jacobian=True
250
+ )
251
+
252
+ if resid.numel() == 0:
253
+ print("No residuals to optimize")
254
+ break
255
+
256
+ current_cost = resid.square().mean().item()
257
+ residual_history.append(current_cost)
258
+
259
+ try: # Solve linear system
260
+ begin_time = time.time()
261
+ if self.solve_system_version == "cpp":
262
+ (delta_pose,) = sim3solve.solve_system(
263
+ J_Ginv_i, J_Ginv_j, iii, jjj, resid, 0.0, lmbda, -1
264
+ )
265
+ elif self.solve_system_version == "python":
266
+ delta_pose = solve_system_py(
267
+ J_Ginv_i, J_Ginv_j, iii, jjj, resid, 0.0, lmbda, -1
268
+ )
269
+ else:
270
+ print("Solver version has not been chosen! ('python' or 'cpp')")
271
+ end_time = time.time()
272
+ except Exception as e:
273
+ print(f"Solver failed at iteration {itr}: {e}")
274
+ break
275
+
276
+ Ginv_tmp = Ginv + delta_pose
277
+
278
+ new_resid = self.residual(Ginv_tmp, input_poses, dSloop, ii_loop, jj_loop)
279
+ new_cost = new_resid.square().mean().item() if new_resid.numel() > 0 else float("inf")
280
+
281
+ # L-M
282
+ if new_cost < current_cost:
283
+ Ginv = Ginv_tmp
284
+ lmbda /= 2
285
+ print(
286
+ f"Iteration {itr}: cost {current_cost:.14f} -> {new_cost:.14f} (accepted)",
287
+ end=" | ",
288
+ )
289
+ else:
290
+ lmbda *= 2
291
+ print(
292
+ f"Iteration {itr}: cost {current_cost:.14f} -> {new_cost:.14f} (rej) ",
293
+ end=" | ",
294
+ ) # more readible to accepted
295
+
296
+ print(
297
+ f"Time of solver ({self.solve_system_version}): \
298
+ {(end_time - begin_time)*1000:.4f} ms"
299
+ )
300
+
301
+ if (current_cost < 1e-5) and (itr >= 4):
302
+ if len(residual_history) >= 5:
303
+ improvement_ratio = residual_history[-5] / residual_history[-1]
304
+ if improvement_ratio < 1.5:
305
+ print(f"Converged at iteration {itr}")
306
+ break
307
+
308
+ optimized_absolute_poses = pp.Exp(Ginv).Inv()
309
+
310
+ optimized_sequential = self.absolute_to_sequential_transforms(optimized_absolute_poses)
311
+
312
+ print(
313
+ f"Optimization completed. Final cost: \
314
+ {residual_history[-1] if residual_history else 'N/A'}"
315
+ )
316
+
317
+ return optimized_sequential
318
+
319
+
320
+ # ======== TEST CODE ========
321
+
322
+
323
+ def create_ring_transforms(num_poses=6, radius=5.0, rot_noise_deg=2.0):
324
+ """Generate a ring of Sim3 transforms with rotation, adding slight rotational noise"""
325
+ transforms = []
326
+ angle_step = 2 * np.pi / num_poses
327
+
328
+ for i in range(num_poses):
329
+ angle = angle_step
330
+
331
+ # Main rotation (around Z-axis)
332
+ R_z = R.from_euler("z", angle, degrees=False)
333
+
334
+ # Add slight rotational noise (Gaussian noise in degrees)
335
+ noise_angles_deg = np.random.normal(loc=0.0, scale=rot_noise_deg, size=3)
336
+ R_noise = R.from_euler("xyz", noise_angles_deg, degrees=True)
337
+
338
+ # Combine rotations
339
+ R_mat = (R_noise * R_z).as_matrix()
340
+
341
+ # Translation: simulate a circular trajectory
342
+ t = np.array([radius * np.sin(angle), radius * (1 - np.cos(angle)), 0.0])
343
+
344
+ s = np.random.uniform(0.8, 1.2)
345
+
346
+ transforms.append((s, R_mat, t))
347
+
348
+ return transforms
349
+
350
+
351
+ def example_usage():
352
+ optimizer = Sim3LoopOptimizer(solve_system_version="cpp")
353
+
354
+ # Build rotating ring
355
+ sequential_transforms = create_ring_transforms(num_poses=20, radius=3.0)
356
+
357
+ # Add loop closure constraint: from frame 5 back to frame 0
358
+ loop_constraints = [
359
+ (20, 0, (1.0, np.eye(3), np.zeros(3))) # Temporary unit loop for simulation
360
+ ]
361
+
362
+ # Trajectory before/after optimization
363
+ input_abs_poses = optimizer.sequential_to_absolute_poses(sequential_transforms)
364
+ optimized_transforms = optimizer.optimize(sequential_transforms, loop_constraints)
365
+ optimized_abs_poses = optimizer.sequential_to_absolute_poses(optimized_transforms)
366
+
367
+ def extract_xyz(pose_tensor):
368
+ poses = pose_tensor.cpu().numpy()
369
+ return poses[:, 0], poses[:, 1], poses[:, 2]
370
+
371
+ x0, y0, z0 = extract_xyz(input_abs_poses)
372
+ x1, y1, z1 = extract_xyz(optimized_abs_poses)
373
+
374
+ # Visualize trajectory
375
+ import matplotlib
376
+ import matplotlib.pyplot as plt
377
+
378
+ matplotlib.use("Agg")
379
+
380
+ plt.figure(figsize=(8, 6))
381
+ plt.plot(x0, y0, "o--", label="Before Optimization")
382
+ plt.plot(x1, y1, "o-", label="After Optimization")
383
+ for i, j, _ in loop_constraints:
384
+ plt.plot([x0[i], x0[j]], [y0[i], y0[j]], "r--", label="Loop (Before)" if i == 5 else "")
385
+ plt.plot([x1[i], x1[j]], [y1[i], y1[j]], "g-", label="Loop (After)" if i == 5 else "")
386
+ plt.gca().set_aspect("equal")
387
+ plt.title("Sim3 Loop Closure Optimization (Rotating Ring)")
388
+ plt.xlabel("x")
389
+ plt.ylabel("y")
390
+ plt.legend()
391
+ plt.grid(True)
392
+ plt.axis("equal")
393
+ plt.show()
394
+
395
+ return optimized_transforms
396
+
397
+
398
+ if __name__ == "__main__":
399
+ example_usage()
Depth-Anything-3/da3_streaming/loop_utils/sim3utils.py ADDED
@@ -0,0 +1,1261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
16
+
17
+ import bisect
18
+ import glob
19
+ import os
20
+ import numpy as np
21
+ import trimesh
22
+ from loop_utils.alignment_torch import robust_weighted_estimate_sim3_torch
23
+ from loop_utils.alignment_triton import robust_weighted_estimate_sim3_triton
24
+ from numba import njit
25
+ from sklearn.linear_model import LinearRegression, RANSACRegressor
26
+
27
+
28
+ def accumulate_sim3_transforms(transforms):
29
+ """
30
+ Accumulate adjacent SIM(3) transforms into transforms
31
+ from the initial frame to each subsequent frame.
32
+
33
+ Args:
34
+ transforms: list, each element is a tuple (R, s, t)
35
+ R: 3x3 rotation matrix (np.array)
36
+ s: scale factor (scalar)
37
+ t: 3x1 translation vector (np.array)
38
+
39
+ Returns:
40
+ Cumulative transforms list, each element is (R_cum, s_cum, t_cum)
41
+ representing the transform from frame 0 to frame k
42
+ """
43
+ if not transforms:
44
+ return []
45
+
46
+ cumulative_transforms = [transforms[0]]
47
+
48
+ for i in range(1, len(transforms)):
49
+ s_cum_prev, R_cum_prev, t_cum_prev = cumulative_transforms[i - 1]
50
+ s_next, R_next, t_next = transforms[i]
51
+ R_cum_new = R_cum_prev @ R_next
52
+ s_cum_new = s_cum_prev * s_next
53
+ t_cum_new = s_cum_prev * (R_cum_prev @ t_next) + t_cum_prev
54
+ cumulative_transforms.append((s_cum_new, R_cum_new, t_cum_new))
55
+
56
+ return cumulative_transforms
57
+
58
+
59
+ def estimate_sim3(source_points, target_points):
60
+ mu_src = np.mean(source_points, axis=0)
61
+ mu_tgt = np.mean(target_points, axis=0)
62
+
63
+ src_centered = source_points - mu_src
64
+ tgt_centered = target_points - mu_tgt
65
+
66
+ scale_src = np.sqrt((src_centered**2).sum(axis=1).mean())
67
+ scale_tgt = np.sqrt((tgt_centered**2).sum(axis=1).mean())
68
+ s = scale_tgt / scale_src
69
+
70
+ src_scaled = src_centered * s
71
+
72
+ H = src_scaled.T @ tgt_centered
73
+ U, _, Vt = np.linalg.svd(H)
74
+ R = Vt.T @ U.T
75
+ if np.linalg.det(R) < 0:
76
+ Vt[2, :] *= -1
77
+ R = Vt.T @ U.T
78
+
79
+ t = mu_tgt - s * R @ mu_src
80
+ return s, R, t
81
+
82
+
83
+ def align_point_maps(point_map1, conf1, point_map2, conf2, conf_threshold):
84
+ """point_map2 -> point_map1"""
85
+ b1, _, _, _ = point_map1.shape
86
+ b2, _, _, _ = point_map2.shape
87
+ b = min(b1, b2)
88
+
89
+ aligned_points1 = []
90
+ aligned_points2 = []
91
+
92
+ for i in range(b):
93
+ mask1 = conf1[i] > conf_threshold
94
+ mask2 = conf2[i] > conf_threshold
95
+ valid_mask = mask1 & mask2
96
+
97
+ idx = np.where(valid_mask)
98
+ if len(idx[0]) == 0:
99
+ continue
100
+
101
+ pts1 = point_map1[i][idx]
102
+ pts2 = point_map2[i][idx]
103
+
104
+ aligned_points1.append(pts1)
105
+ aligned_points2.append(pts2)
106
+
107
+ if len(aligned_points1) == 0:
108
+ raise ValueError("No matching point pairs were found!")
109
+
110
+ all_pts1 = np.concatenate(aligned_points1, axis=0)
111
+ all_pts2 = np.concatenate(aligned_points2, axis=0)
112
+
113
+ print(f"The number of corresponding points matched: {all_pts1.shape[0]}")
114
+ s, R, t = estimate_sim3(all_pts2, all_pts1)
115
+
116
+ mean_error = compute_alignment_error(
117
+ point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t
118
+ )
119
+ print(f"Mean error: {mean_error}")
120
+
121
+ return s, R, t
122
+
123
+
124
+ def apply_sim3(points, s, R, t):
125
+ return (s * (R @ points.T)).T + t
126
+
127
+
128
+ def apply_sim3_direct(point_maps, s, R, t):
129
+ # point_maps: (b, h, w, 3) -> (b, h, w, 3, 1)
130
+ point_maps_expanded = point_maps[..., np.newaxis] # (b, h, w, 3, 1)
131
+
132
+ # R: (3, 3) -> (b, h, w, 3, 1) = (3, 3) @ (3, 1)
133
+ rotated = np.matmul(R, point_maps_expanded) # (b, h, w, 3, 1)
134
+ rotated = rotated.squeeze(-1) # (b, h, w, 3)
135
+ transformed = s * rotated + t # (b, h, w, 3)
136
+
137
+ return transformed
138
+
139
+
140
+ def compute_alignment_error(point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t):
141
+ """
142
+ Compute the average point alignment error (using only original inputs)
143
+
144
+ Args:
145
+ point_map1: target point map (b, h, w, 3)
146
+ conf1: target confidence map (b, h, w)
147
+ point_map2: source point map (b, h, w, 3)
148
+ conf2: source confidence map (b, h, w)
149
+ conf_threshold: confidence threshold
150
+ s, R, t: transformation parameters
151
+ """
152
+ b1, h1, w1, _ = point_map1.shape
153
+ b2, h2, w2, _ = point_map2.shape
154
+ b = min(b1, b2)
155
+ h = min(h1, h2)
156
+ w = min(w1, w2)
157
+
158
+ target_points = []
159
+ source_points = []
160
+
161
+ for i in range(b):
162
+ mask1 = conf1[i, :h, :w] > conf_threshold
163
+ mask2 = conf2[i, :h, :w] > conf_threshold
164
+ valid_mask = mask1 & mask2
165
+
166
+ idx = np.where(valid_mask)
167
+ if len(idx[0]) == 0:
168
+ continue
169
+
170
+ t_pts = point_map1[i, :h, :w][idx]
171
+ s_pts = point_map2[i, :h, :w][idx]
172
+
173
+ target_points.append(t_pts)
174
+ source_points.append(s_pts)
175
+
176
+ if len(target_points) == 0:
177
+ print("Warning: No matching point pairs found for error calculation")
178
+ return np.nan
179
+
180
+ all_target = np.concatenate(target_points, axis=0)
181
+ all_source = np.concatenate(source_points, axis=0)
182
+
183
+ transformed = (s * (R @ all_source.T)).T + t
184
+
185
+ errors = np.linalg.norm(transformed - all_target, axis=1)
186
+
187
+ mean_error = np.mean(errors)
188
+ std_error = np.std(errors)
189
+ median_error = np.median(errors)
190
+ max_error = np.max(errors)
191
+
192
+ print(
193
+ f"Alignment error statistics [using {len(errors)} points]: "
194
+ f"mean={mean_error:.4f}, std={std_error:.4f}, "
195
+ f"median={median_error:.4f}, max={max_error:.4f}"
196
+ )
197
+
198
+ return mean_error
199
+
200
+
201
+ def save_confident_pointcloud(
202
+ points, colors, confs, output_path, conf_threshold, sample_ratio=1.0
203
+ ):
204
+ """
205
+ Filter points based on confidence threshold
206
+ and save as PLY file, with optional random sampling ratio.
207
+
208
+ Args:
209
+ - points: np.ndarray, shape (H, W, 3) or (N, 3)
210
+ - colors: np.ndarray, shape (H, W, 3) or (N, 3)
211
+ - confs: np.ndarray, shape (H, W) or (N,)
212
+ - output_path: str, output PLY file path
213
+ - conf_threshold: float, confidence threshold for point filtering
214
+ - sample_ratio: float, sampling ratio (0 < sample_ratio <= 1.0)
215
+ """
216
+ points = points.reshape(-1, 3).astype(np.float32, copy=False)
217
+ colors = colors.reshape(-1, 3).astype(np.uint8, copy=False)
218
+ confs = confs.reshape(-1).astype(np.float32, copy=False)
219
+
220
+ conf_mask = (confs >= conf_threshold) & (confs > 1e-5)
221
+ points = points[conf_mask]
222
+ colors = colors[conf_mask]
223
+
224
+ if 0 < sample_ratio < 1.0 and len(points) > 0:
225
+ num_samples = int(len(points) * sample_ratio)
226
+ indices = np.random.choice(len(points), num_samples, replace=False)
227
+ points = points[indices]
228
+ colors = colors[indices]
229
+
230
+ os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
231
+
232
+ print(f"shape of sampled point: {points.shape}")
233
+ trimesh.PointCloud(points, colors=colors).export(output_path)
234
+ print(f"Saved point cloud with {len(points)} points to {output_path}")
235
+
236
+
237
+ def save_confident_pointcloud_batch(
238
+ points, colors, confs, output_path, conf_threshold, sample_ratio=1.0, batch_size=1000000
239
+ ):
240
+ """
241
+ - points: np.ndarray, (b, H, W, 3) / (N, 3)
242
+ - colors: np.ndarray, (b, H, W, 3) / (N, 3)
243
+ - confs: np.ndarray, (b, H, W) / (N,)
244
+ - output_path: str
245
+ - conf_threshold: float,
246
+ - sample_ratio: float (0 < sample_ratio <= 1.0)
247
+ - batch_size: int
248
+ """
249
+ if points.ndim == 2:
250
+ b = 1
251
+ points = points[np.newaxis, ...]
252
+ colors = colors[np.newaxis, ...]
253
+ confs = confs[np.newaxis, ...]
254
+ elif points.ndim == 4:
255
+ b = points.shape[0]
256
+ else:
257
+ raise ValueError("Unsupported points dimension. Must be 2 (N,3) or 4 (b,H,W,3)")
258
+
259
+ total_valid = 0
260
+ for i in range(b):
261
+ cfs = confs[i].reshape(-1)
262
+ total_valid += np.count_nonzero((cfs >= conf_threshold) & (cfs > 1e-5))
263
+
264
+ num_samples = int(total_valid * sample_ratio) if sample_ratio < 1.0 else total_valid
265
+
266
+ if num_samples == 0:
267
+ save_ply(np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8), output_path)
268
+ return
269
+
270
+ if sample_ratio == 1.0:
271
+ with open(output_path, "wb") as f:
272
+ write_ply_header(f, num_samples)
273
+
274
+ for i in range(b):
275
+ pts = points[i].reshape(-1, 3).astype(np.float32)
276
+ cls = colors[i].reshape(-1, 3).astype(np.uint8)
277
+ cfs = confs[i].reshape(-1).astype(np.float32)
278
+
279
+ mask = (cfs >= conf_threshold) & (cfs > 1e-5)
280
+ valid_pts = pts[mask]
281
+ valid_cls = cls[mask]
282
+
283
+ for j in range(0, len(valid_pts), batch_size):
284
+ batch_pts = valid_pts[j : j + batch_size]
285
+ batch_cls = valid_cls[j : j + batch_size]
286
+ write_ply_batch(f, batch_pts, batch_cls)
287
+
288
+ else:
289
+ reservoir_pts = np.zeros((num_samples, 3), dtype=np.float32)
290
+ reservoir_clr = np.zeros((num_samples, 3), dtype=np.uint8)
291
+ count = 0
292
+
293
+ for i in range(b):
294
+ pts = points[i].reshape(-1, 3).astype(np.float32)
295
+ cls = colors[i].reshape(-1, 3).astype(np.uint8)
296
+ cfs = confs[i].reshape(-1).astype(np.float32)
297
+
298
+ mask = (cfs >= conf_threshold) & (cfs > 1e-5)
299
+ valid_pts = pts[mask]
300
+ valid_cls = cls[mask]
301
+ n_valid = len(valid_pts)
302
+
303
+ if count < num_samples:
304
+ fill_count = min(num_samples - count, n_valid)
305
+
306
+ reservoir_pts[count : count + fill_count] = valid_pts[:fill_count]
307
+ reservoir_clr[count : count + fill_count] = valid_cls[:fill_count]
308
+ count += fill_count
309
+
310
+ if fill_count < n_valid:
311
+ remaining_pts = valid_pts[fill_count:]
312
+ remaining_cls = valid_cls[fill_count:]
313
+
314
+ count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling(
315
+ remaining_pts, remaining_cls, count, reservoir_pts, reservoir_clr
316
+ )
317
+ else:
318
+ count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling(
319
+ valid_pts, valid_cls, count, reservoir_pts, reservoir_clr
320
+ )
321
+
322
+ save_ply(reservoir_pts, reservoir_clr, output_path)
323
+
324
+
325
+ """ The following function is deprecated"""
326
+
327
+ # def vectorized_reservoir_sampling(new_pts, new_cls, current_count, reservoir_pts, reservoir_clr):
328
+ # """
329
+ # - new_pts: (M, 3)
330
+ # - new_cls: (M, 3)
331
+ # - current_count
332
+ # - reservoir_pts: (K, 3)
333
+ # - reservoir_clr: (K, 3)
334
+
335
+ # """
336
+ # k = len(reservoir_pts)
337
+ # n_new = len(new_pts)
338
+
339
+ # rand_indices = np.random.randint(0, current_count + n_new, size=n_new)
340
+
341
+ # replace_mask = rand_indices < k
342
+ # replace_indices = rand_indices[replace_mask]
343
+ # replace_pts = new_pts[replace_mask]
344
+ # replace_cls = new_cls[replace_mask]
345
+
346
+ # reservoir_pts[replace_indices] = replace_pts
347
+ # reservoir_clr[replace_indices] = replace_cls
348
+
349
+ # return current_count + n_new, reservoir_pts, reservoir_clr
350
+
351
+
352
+ """
353
+ Function `vectorized_reservoir_sampling` is not mathematically accurate in sampling.
354
+ This leads to inconsistent density in the downsampled point clouds.
355
+ The `optimized_vectorized_reservoir_sampling` function has fixed this bug.
356
+
357
+ Special thanks to @Horace89 for the detailed analysis and code assistance.
358
+
359
+ See https://github.com/DengKaiCQ/VGGT-Long/issues/28 for details
360
+ """
361
+
362
+
363
+ def optimized_vectorized_reservoir_sampling(
364
+ new_points: np.ndarray,
365
+ new_colors: np.ndarray,
366
+ current_count: int,
367
+ reservoir_points: np.ndarray,
368
+ reservoir_colors: np.ndarray,
369
+ ) -> tuple[int, np.ndarray, np.ndarray]:
370
+ """
371
+ Optimized vectorized reservoir sampling with batch probability calculations.
372
+
373
+ This maintains mathematical correctness while improving performance through
374
+ vectorized operations where possible.
375
+
376
+ Args:
377
+ new_points: New point coordinates to consider, shape (M, 3)
378
+ new_colors: New point colors to consider, shape (M, 3)
379
+ current_count: Number of elements seen so far
380
+ reservoir_points: Current reservoir of sampled points, shape (K, 3)
381
+ reservoir_colors: Current reservoir of sampled colors, shape (K, 3)
382
+
383
+ Returns:
384
+ Tuple of (updated_count, updated_reservoir_points, updated_reservoir_colors)
385
+ """
386
+ random_gen = np.random
387
+
388
+ reservoir_size = len(reservoir_points)
389
+ num_new_points = len(new_points)
390
+
391
+ if num_new_points == 0:
392
+ return current_count, reservoir_points, reservoir_colors
393
+
394
+ # Calculate sequential indices for each new point
395
+ point_indices = np.arange(current_count + 1, current_count + num_new_points + 1)
396
+
397
+ # Generate random numbers for each point
398
+ random_values = random_gen.randint(0, point_indices, size=num_new_points)
399
+
400
+ # Determine which points should replace reservoir elements
401
+ replacement_mask = random_values < reservoir_size
402
+ replacement_positions = random_values[replacement_mask]
403
+
404
+ # Apply replacements
405
+ if np.any(replacement_mask):
406
+ points_to_replace = new_points[replacement_mask]
407
+ colors_to_replace = new_colors[replacement_mask]
408
+
409
+ reservoir_points[replacement_positions] = points_to_replace
410
+ reservoir_colors[replacement_positions] = colors_to_replace
411
+
412
+ return current_count + num_new_points, reservoir_points, reservoir_colors
413
+
414
+
415
+ def write_ply_header(f, num_vertices):
416
+ header = [
417
+ "ply",
418
+ "format binary_little_endian 1.0",
419
+ f"element vertex {num_vertices}",
420
+ "property float x",
421
+ "property float y",
422
+ "property float z",
423
+ "property uchar red",
424
+ "property uchar green",
425
+ "property uchar blue",
426
+ "end_header",
427
+ ]
428
+ f.write("\n".join(header).encode() + b"\n")
429
+
430
+
431
+ def write_ply_batch(f, points, colors):
432
+ structured = np.zeros(
433
+ len(points),
434
+ dtype=[
435
+ ("x", np.float32),
436
+ ("y", np.float32),
437
+ ("z", np.float32),
438
+ ("red", np.uint8),
439
+ ("green", np.uint8),
440
+ ("blue", np.uint8),
441
+ ],
442
+ )
443
+
444
+ structured["x"] = points[:, 0]
445
+ structured["y"] = points[:, 1]
446
+ structured["z"] = points[:, 2]
447
+ structured["red"] = colors[:, 0]
448
+ structured["green"] = colors[:, 1]
449
+ structured["blue"] = colors[:, 2]
450
+
451
+ f.write(structured.tobytes())
452
+
453
+
454
+ def save_ply(points, colors, filename):
455
+ with open(filename, "wb") as f:
456
+ write_ply_header(f, len(points))
457
+ write_ply_batch(f, points, colors)
458
+
459
+
460
+ def find_chunk_index(chunks, idx):
461
+ """
462
+ Find the 0-based chunk index that contains the given index idx.
463
+ chunks: List of (begin_idx, end_idx).
464
+ idx: The index to search for.
465
+ Returns the 0-based chunk index.
466
+ """
467
+ starts = [chunk[0] for chunk in chunks]
468
+ pos = bisect.bisect_right(starts, idx) - 1 # Find position of idx in starts
469
+ if pos < 0 or pos >= len(chunks):
470
+ raise ValueError(f"Index {idx} not found in any chunk")
471
+ chunk_begin, chunk_end = chunks[pos]
472
+ if idx < chunk_begin or idx > chunk_end:
473
+ raise ValueError(f"Index {idx} not found in any chunk")
474
+ return pos
475
+
476
+
477
+ def get_frame_range(chunk, idx, half_window=10):
478
+ """
479
+ Calculate the frame range centered at idx with half_window
480
+ frames on each side within chunk boundaries.
481
+ If near boundaries, take 2 * half_window frames starting from the boundary.
482
+ chunk: (begin_idx, end_idx).
483
+ idx: Center index.
484
+ half_window: Number of frames to take on each side of center index.
485
+ Returns (start, end).
486
+ """
487
+ begin, end = chunk
488
+ window_size = 2 * half_window
489
+
490
+ if idx - half_window < begin:
491
+ start = begin
492
+ end_candidate = begin + window_size
493
+ end = min(end, end_candidate)
494
+
495
+ elif idx + half_window > end:
496
+ end_candidate = end
497
+ start_candidate = end - window_size
498
+ start = max(begin, start_candidate)
499
+
500
+ else:
501
+ start = idx - half_window
502
+ end = idx + half_window
503
+ return (start, end)
504
+
505
+
506
+ def process_loop_list(chunk_index, loop_list, half_window=10):
507
+ """
508
+ Process loop_list and return chunk indices and frame ranges for each (idx1, idx2) pair.
509
+ chunk_index: List of (begin_idx, end_idx) tuples.
510
+ loop_list: List of (idx1, idx2) tuples.
511
+ half_window: Number of frames to take on each side of center index (default 10).
512
+ Returns list of (chunk_idx1, range1, chunk_idx2, range2) tuples where:
513
+ - chunk_idx1, chunk_idx2: Chunk indices (1-based).
514
+ - range1, range2: Frame range tuples (start, end).
515
+ """
516
+ results = []
517
+ for idx1, idx2 in loop_list:
518
+ try:
519
+ chunk_idx1_0based = find_chunk_index(chunk_index, idx1)
520
+ chunk1 = chunk_index[chunk_idx1_0based]
521
+ range1 = get_frame_range(chunk1, idx1, half_window)
522
+
523
+ chunk_idx2_0based = find_chunk_index(chunk_index, idx2)
524
+ chunk2 = chunk_index[chunk_idx2_0based]
525
+ range2 = get_frame_range(chunk2, idx2, half_window)
526
+
527
+ result = (chunk_idx1_0based, range1, chunk_idx2_0based, range2)
528
+ results.append(result)
529
+ except ValueError as e:
530
+ print(f"Skipping pair ({idx1}, {idx2}): {e}")
531
+ return results
532
+
533
+
534
+ def compute_sim3_ab(S_a, S_b):
535
+
536
+ s_a, R_a, T_a = S_a
537
+ s_b, R_b, T_b = S_b
538
+
539
+ s_ab = s_b / s_a
540
+ R_ab = R_b @ R_a.T
541
+ T_ab = T_b - s_ab * (R_ab @ T_a)
542
+
543
+ return (s_ab, R_ab, T_ab)
544
+
545
+
546
+ def merge_ply_files(input_dir, output_path):
547
+ """
548
+ Merge all PLY files in a directory into one file (without loading into memory)
549
+
550
+ Args:
551
+ - input_dir: Input directory containing multiple '{idx}_pcd.ply' files
552
+ - output_path: Output file path (e.g., 'combined.ply')
553
+ """
554
+
555
+ print("Merging PLY files...")
556
+
557
+ input_files = sorted(glob.glob(os.path.join(input_dir, "*_pcd.ply")))
558
+
559
+ if not input_files:
560
+ print("No PLY files found")
561
+ return
562
+
563
+ idx_file = 0
564
+ len(input_files)
565
+
566
+ total_vertices = 0
567
+ for file in input_files: # Count total vertices
568
+ with open(file, "rb") as f:
569
+ for line in f:
570
+ if line.startswith(b"element vertex"):
571
+ vertex_count = int(line.split()[-1])
572
+ total_vertices += vertex_count
573
+ elif line.startswith(b"end_header"):
574
+ break
575
+
576
+ with open(output_path, "wb") as out_f:
577
+ # Write new header
578
+ out_f.write(b"ply\n")
579
+ out_f.write(b"format binary_little_endian 1.0\n")
580
+ out_f.write(f"element vertex {total_vertices}\n".encode())
581
+ out_f.write(b"property float x\n")
582
+ out_f.write(b"property float y\n")
583
+ out_f.write(b"property float z\n")
584
+ out_f.write(b"property uchar red\n")
585
+ out_f.write(b"property uchar green\n")
586
+ out_f.write(b"property uchar blue\n")
587
+ out_f.write(b"end_header\n")
588
+
589
+ for file in input_files:
590
+ print(f"Processing {idx_file}/{len(input_files)}: {file}")
591
+ idx_file += 1
592
+ with open(file, "rb") as in_f:
593
+ # Skip the head
594
+ in_header = True
595
+ while in_header:
596
+ line = in_f.readline()
597
+ if line.startswith(b"end_header"):
598
+ in_header = False
599
+ data = in_f.read()
600
+ out_f.write(data)
601
+
602
+ print(f"Merge completed! Total points: {total_vertices}")
603
+ print(f"Output file: {output_path}")
604
+
605
+
606
+ def weighted_estimate_se3(source_points, target_points, weights):
607
+ """
608
+ source_points: (Nx3)
609
+ target_points: (Nx3)
610
+ :weights: (N,) [0,1]
611
+ """
612
+ total_weight = np.sum(weights)
613
+ if total_weight < 1e-6:
614
+ raise ValueError("Total weight too small for meaningful estimation")
615
+
616
+ normalized_weights = weights / total_weight
617
+
618
+ mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
619
+ mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
620
+
621
+ src_centered = source_points - mu_src
622
+ tgt_centered = target_points - mu_tgt
623
+
624
+ weighted_src = src_centered * np.sqrt(normalized_weights)[:, None]
625
+ weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
626
+
627
+ H = weighted_src.T @ weighted_tgt
628
+
629
+ U, _, Vt = np.linalg.svd(H)
630
+ R = Vt.T @ U.T
631
+
632
+ if np.linalg.det(R) < 0:
633
+ Vt[2, :] *= -1
634
+ R = Vt.T @ U.T
635
+
636
+ t = mu_tgt - R @ mu_src
637
+
638
+ return 1.0, R, t
639
+
640
+
641
+ def weighted_estimate_sim3(source_points, target_points, weights):
642
+ """
643
+ source_points: (Nx3)
644
+ target_points: (Nx3)
645
+ :weights: (N,) [0,1]
646
+ """
647
+ total_weight = np.sum(weights)
648
+ if total_weight < 1e-6:
649
+ raise ValueError("Total weight too small for meaningful estimation")
650
+
651
+ normalized_weights = weights / total_weight
652
+
653
+ mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
654
+ mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
655
+
656
+ src_centered = source_points - mu_src
657
+ tgt_centered = target_points - mu_tgt
658
+
659
+ scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1)))
660
+ scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1)))
661
+ s = scale_tgt / scale_src
662
+
663
+ weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None]
664
+ weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
665
+
666
+ H = weighted_src.T @ weighted_tgt
667
+
668
+ U, _, Vt = np.linalg.svd(H)
669
+ R = Vt.T @ U.T
670
+
671
+ if np.linalg.det(R) < 0:
672
+ Vt[2, :] *= -1
673
+ R = Vt.T @ U.T
674
+
675
+ t = mu_tgt - s * R @ mu_src
676
+ return s, R, t
677
+
678
+
679
+ def huber_loss(r, delta):
680
+ abs_r = np.abs(r)
681
+ return np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta))
682
+
683
+
684
+ def robust_weighted_estimate_sim3(
685
+ src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
686
+ ):
687
+ """
688
+ src: (Nx3)
689
+ tgt: (Nx3)
690
+ init_weights: (N,)
691
+ """
692
+ if align_method == "sim3":
693
+ s, R, t = weighted_estimate_sim3(src, tgt, init_weights)
694
+ elif align_method == "se3" or align_method == "scale+se3":
695
+ s, R, t = weighted_estimate_se3(src, tgt, init_weights)
696
+
697
+ prev_error = float("inf")
698
+
699
+ for iter in range(max_iters):
700
+
701
+ transformed = s * (src @ R.T) + t
702
+ residuals = np.linalg.norm(tgt - transformed, axis=1) # (N,)
703
+ print(f"Residuals: {np.mean(residuals)}")
704
+
705
+ abs_res = np.abs(residuals)
706
+ huber_weights = np.ones_like(residuals)
707
+ large_res_mask = abs_res > delta
708
+ huber_weights[large_res_mask] = delta / abs_res[large_res_mask]
709
+
710
+ combined_weights = init_weights * huber_weights
711
+
712
+ combined_weights /= np.sum(combined_weights) + 1e-12
713
+
714
+ if align_method == "se3":
715
+ s_new, R_new, t_new = weighted_estimate_se3(src, tgt, combined_weights)
716
+ elif align_method == "sim3" or align_method == "scale+se3":
717
+ s_new, R_new, t_new = weighted_estimate_sim3(src, tgt, combined_weights)
718
+
719
+ param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
720
+ rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
721
+ current_error = np.sum(huber_loss(residuals, delta) * init_weights)
722
+
723
+ if (param_change < tol and rot_angle < np.radians(0.1)) or (
724
+ abs(prev_error - current_error) < tol * prev_error
725
+ ):
726
+ break
727
+
728
+ s, R, t = s_new, R_new, t_new
729
+ prev_error = current_error
730
+
731
+ return s, R, t
732
+
733
+
734
+ # ===== Speed Up Begin =====
735
+
736
+
737
+ @njit(cache=True)
738
+ def _weighted_estimate_se3_numba(source_points, target_points, weights):
739
+ # Ensure float32
740
+ source_points = source_points.astype(np.float32)
741
+ target_points = target_points.astype(np.float32)
742
+ weights = weights.astype(np.float32)
743
+
744
+ total_weight = np.sum(weights)
745
+ if total_weight < 1e-6:
746
+ return (
747
+ 1.0,
748
+ np.zeros(3, dtype=np.float32),
749
+ np.zeros(3, dtype=np.float32),
750
+ np.zeros((3, 3), dtype=np.float32),
751
+ )
752
+
753
+ normalized_weights = weights / total_weight
754
+
755
+ mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
756
+ mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
757
+
758
+ src_centered = source_points - mu_src
759
+ tgt_centered = target_points - mu_tgt
760
+
761
+ weighted_src = src_centered * np.sqrt(normalized_weights)[:, None]
762
+ weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
763
+
764
+ H = weighted_src.T @ weighted_tgt
765
+
766
+ return 1.0, mu_src, mu_tgt, H
767
+
768
+
769
+ @njit(cache=True)
770
+ def _weighted_estimate_sim3_numba(source_points, target_points, weights):
771
+ # Ensure float32
772
+ source_points = source_points.astype(np.float32)
773
+ target_points = target_points.astype(np.float32)
774
+ weights = weights.astype(np.float32)
775
+
776
+ total_weight = np.sum(weights)
777
+ if total_weight < 1e-6:
778
+ return (
779
+ -1.0,
780
+ np.zeros(3, dtype=np.float32),
781
+ np.zeros(3, dtype=np.float32),
782
+ np.zeros((3, 3), dtype=np.float32),
783
+ )
784
+
785
+ normalized_weights = weights / total_weight
786
+
787
+ mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
788
+ mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
789
+
790
+ src_centered = source_points - mu_src
791
+ tgt_centered = target_points - mu_tgt
792
+
793
+ scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1)))
794
+ scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1)))
795
+ s = scale_tgt / scale_src
796
+
797
+ weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None]
798
+ weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
799
+
800
+ H = weighted_src.T @ weighted_tgt
801
+
802
+ return s, mu_src, mu_tgt, H
803
+
804
+
805
+ def weighted_estimate_sim3_numba(source_points, target_points, weights, align_method="sim3"):
806
+ if align_method == "sim3":
807
+ s, mu_src, mu_tgt, H = _weighted_estimate_sim3_numba(source_points, target_points, weights)
808
+ elif align_method == "se3" or align_method == "scale+se3":
809
+ s, mu_src, mu_tgt, H = _weighted_estimate_se3_numba(source_points, target_points, weights)
810
+
811
+ if s < 0:
812
+ raise ValueError("Total weight too small for meaningful estimation")
813
+
814
+ # Ensure float32
815
+ H = H.astype(np.float32)
816
+ U, _, Vt = np.linalg.svd(H.astype(np.float32)) # float32 SVD
817
+
818
+ R = Vt.T @ U.T
819
+ if np.linalg.det(R) < 0:
820
+ Vt[2, :] *= -1
821
+ R = Vt.T @ U.T
822
+
823
+ if align_method == "se3" or align_method == "scale+se3":
824
+ t = mu_tgt - R @ mu_src
825
+ else:
826
+ t = mu_tgt - s * R @ mu_src
827
+
828
+ return s, R, t
829
+
830
+
831
+ @njit(cache=True)
832
+ def huber_loss_numba(r, delta):
833
+ r = r.astype(np.float32)
834
+ delta = np.float32(delta)
835
+ abs_r = np.abs(r)
836
+ result = np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta))
837
+ return result.astype(np.float32)
838
+
839
+
840
+ @njit(cache=True)
841
+ def compute_residuals_numba(tgt, transformed):
842
+ residuals = np.empty(tgt.shape[0], dtype=np.float32)
843
+ for i in range(tgt.shape[0]):
844
+ diff = tgt[i] - transformed[i]
845
+ residuals[i] = np.sqrt(np.sum(diff**2))
846
+ return residuals
847
+
848
+
849
+ @njit(cache=True)
850
+ def compute_huber_weights_numba(residuals, delta):
851
+ weights = np.ones(residuals.shape, dtype=np.float32)
852
+ for i in range(residuals.shape[0]):
853
+ r = residuals[i]
854
+ if r > delta:
855
+ weights[i] = delta / r
856
+ return weights
857
+
858
+
859
+ @njit(cache=True)
860
+ def apply_transformation_numba(src, s, R, t):
861
+ transformed = np.empty_like(src)
862
+ for i in range(src.shape[0]):
863
+ p = src[i]
864
+ transformed[i] = s * (R @ p) + t
865
+ return transformed
866
+
867
+
868
+ def robust_weighted_estimate_sim3_numba(
869
+ src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
870
+ ):
871
+ src = src.astype(np.float32)
872
+ tgt = tgt.astype(np.float32)
873
+ init_weights = init_weights.astype(np.float32)
874
+
875
+ s, R, t = weighted_estimate_sim3_numba(src, tgt, init_weights, align_method=align_method)
876
+
877
+ prev_error = float("inf")
878
+
879
+ for iter in range(max_iters):
880
+ transformed = apply_transformation_numba(src, s, R, t)
881
+ residuals = compute_residuals_numba(tgt, transformed)
882
+
883
+ print(f"Residuals: {np.mean(residuals)}")
884
+
885
+ huber_weights = compute_huber_weights_numba(residuals, delta)
886
+ combined_weights = init_weights * huber_weights
887
+ combined_weights /= np.sum(combined_weights) + 1e-12
888
+
889
+ s_new, R_new, t_new = weighted_estimate_sim3_numba(
890
+ src, tgt, combined_weights, align_method=align_method
891
+ )
892
+
893
+ param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
894
+ rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
895
+
896
+ current_error = np.sum(huber_loss_numba(residuals, delta) * init_weights)
897
+
898
+ if (param_change < tol and rot_angle < np.radians(0.1)) or (
899
+ abs(prev_error - current_error) < tol * prev_error
900
+ ):
901
+ break
902
+
903
+ s, R, t = s_new, R_new, t_new
904
+ prev_error = current_error
905
+
906
+ return s, R, t
907
+
908
+
909
+ def warmup_numba():
910
+
911
+ print("\nWarming up Numba JIT-compiled functions...")
912
+
913
+ src = np.random.randn(50000, 3).astype(np.float32)
914
+ tgt = np.random.randn(50000, 3).astype(np.float32)
915
+ weights = np.ones(50000, dtype=np.float32)
916
+ residuals = np.abs(np.random.randn(50000).astype(np.float32))
917
+ R = np.eye(3, dtype=np.float32)
918
+ t = np.zeros(3, dtype=np.float32)
919
+ s = np.float32(1.0)
920
+ delta = np.float32(1.0)
921
+
922
+ try:
923
+ _ = _weighted_estimate_sim3_numba(src, tgt, weights)
924
+ print(" - _weighted_estimate_sim3_numba warmed up.")
925
+ except Exception as e:
926
+ print(" ! Failed to warm up _weighted_estimate_sim3_numba:", e)
927
+
928
+ try:
929
+ _ = _weighted_estimate_se3_numba(src, tgt, weights)
930
+ print(" - _weighted_estimate_se3_numba warmed up.")
931
+ except Exception as e:
932
+ print(" ! Failed to warm up _weighted_estimate_se3_numba:", e)
933
+
934
+ try:
935
+ _ = huber_loss_numba(residuals, delta)
936
+ print(" - huber_loss_numba warmed up.")
937
+ except Exception as e:
938
+ print(" ! Failed to warm up huber_loss_numba:", e)
939
+
940
+ try:
941
+ _ = compute_huber_weights_numba(residuals, delta)
942
+ print(" - compute_huber_weights_numba warmed up.")
943
+ except Exception as e:
944
+ print(" ! Failed to warm up compute_huber_weights_numba:", e)
945
+
946
+ try:
947
+ _ = compute_residuals_numba(tgt, src)
948
+ print(" - compute_residuals_numba warmed up.")
949
+ except Exception as e:
950
+ print(" ! Failed to warm up compute_residuals_numba:", e)
951
+
952
+ try:
953
+ _ = apply_transformation_numba(src, s, R, t)
954
+ print(" - apply_transformation_numba warmed up.")
955
+ except Exception as e:
956
+ print(" ! Failed to warm up apply_transformation_numba:", e)
957
+
958
+ print("Numba warm-up complete.\n")
959
+
960
+
961
+ # ===== Speed Up End =====
962
+
963
+ # ===== Scale precompute begin =====
964
+
965
+
966
+ def compute_scale_ransac(
967
+ depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, max_samples=10000
968
+ ):
969
+ """
970
+ Args:
971
+ depth1: (n1, h, w)
972
+ depth2: (n2, h, w)
973
+ conf1: (n1, h, w)
974
+ conf2: (n2, h, w)
975
+
976
+ """
977
+
978
+ depth1_flat = depth1.reshape(-1)
979
+ depth2_flat = depth2.reshape(-1)
980
+ conf1_flat = conf1.reshape(-1)
981
+ conf2_flat = conf2.reshape(-1)
982
+
983
+ conf_threshold = max(
984
+ np.median(conf1_flat) * conf_threshold_ratio,
985
+ np.median(conf2_flat) * conf_threshold_ratio,
986
+ 1e-6,
987
+ )
988
+
989
+ valid_mask = (
990
+ (conf1_flat > conf_threshold)
991
+ & (conf2_flat > conf_threshold)
992
+ & (depth1_flat > 1e-3)
993
+ & (depth2_flat > 1e-3)
994
+ & (depth1_flat < 100)
995
+ & (depth2_flat < 100)
996
+ )
997
+
998
+ if np.sum(valid_mask) < 100:
999
+ print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0")
1000
+ return 1.0, 0.0
1001
+
1002
+ valid_depth1 = depth1_flat[valid_mask]
1003
+ valid_depth2 = depth2_flat[valid_mask]
1004
+
1005
+ if len(valid_depth1) > max_samples:
1006
+ indices = np.random.choice(len(valid_depth1), max_samples, replace=False)
1007
+ valid_depth1 = valid_depth1[indices]
1008
+ valid_depth2 = valid_depth2[indices]
1009
+
1010
+ X = valid_depth2.reshape(-1, 1)
1011
+ y = valid_depth1
1012
+
1013
+ base_estimator = LinearRegression(fit_intercept=False)
1014
+ ransac = RANSACRegressor(
1015
+ estimator=base_estimator,
1016
+ max_trials=1000,
1017
+ min_samples=max(10, len(X) // 100),
1018
+ residual_threshold=0.1,
1019
+ random_state=42,
1020
+ )
1021
+
1022
+ ransac.fit(X, y)
1023
+ scale_factor = ransac.estimator_.coef_[0]
1024
+ inlier_mask = ransac.inlier_mask_
1025
+ inlier_ratio = np.sum(inlier_mask) / len(inlier_mask)
1026
+
1027
+ print(f"RANSAC scale: {scale_factor:.6f}, inlier ratio: {inlier_ratio:.4f}")
1028
+
1029
+ if 0.1 < scale_factor < 10.0:
1030
+ return scale_factor, inlier_ratio
1031
+ else:
1032
+ print(f"Warning: Unreasonable scale {scale_factor}, using 1.0")
1033
+ return 1.0, inlier_ratio
1034
+
1035
+
1036
+ def compute_scale_weighted(
1037
+ depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, weight_power=2.0, robust_quantile=0.9
1038
+ ):
1039
+ """
1040
+ Args:
1041
+ depth1: (n1, h, w)
1042
+ depth2: (n2, h, w)
1043
+ conf1: (n1, h, w)
1044
+ conf2: (n2, h, w)
1045
+ """
1046
+ depth1_flat = depth1.reshape(-1)
1047
+ depth2_flat = depth2.reshape(-1)
1048
+ conf1_flat = conf1.reshape(-1)
1049
+ conf2_flat = conf2.reshape(-1)
1050
+
1051
+ conf_threshold = max(
1052
+ np.median(conf1_flat) * conf_threshold_ratio,
1053
+ np.median(conf2_flat) * conf_threshold_ratio,
1054
+ 1e-6,
1055
+ )
1056
+
1057
+ valid_mask = (
1058
+ (conf1_flat > conf_threshold)
1059
+ & (conf2_flat > conf_threshold)
1060
+ & (depth1_flat > 1e-3)
1061
+ & (depth2_flat > 1e-3)
1062
+ & (depth1_flat < 100)
1063
+ & (depth2_flat < 100)
1064
+ )
1065
+
1066
+ if np.sum(valid_mask) < 100:
1067
+ print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0")
1068
+ return 1.0, 0.0
1069
+
1070
+ valid_depth1 = depth1_flat[valid_mask]
1071
+ valid_depth2 = depth2_flat[valid_mask]
1072
+ valid_conf1 = conf1_flat[valid_mask]
1073
+ valid_conf2 = conf2_flat[valid_mask]
1074
+
1075
+ combined_weights = (valid_conf1 * valid_conf2) ** weight_power
1076
+
1077
+ combined_weights = combined_weights / (np.sum(combined_weights) + 1e-8)
1078
+
1079
+ ratios = valid_depth1 / (valid_depth2 + 1e-8)
1080
+
1081
+ sorted_indices = np.argsort(ratios)
1082
+ sorted_ratios = ratios[sorted_indices]
1083
+ sorted_weights = combined_weights[sorted_indices]
1084
+
1085
+ cumulative_weights = np.cumsum(sorted_weights)
1086
+ median_idx = np.searchsorted(cumulative_weights, 0.5)
1087
+ scale_median = sorted_ratios[median_idx] if median_idx < len(sorted_ratios) else 1.0
1088
+
1089
+ quantile_idx = np.searchsorted(cumulative_weights, robust_quantile)
1090
+ scale_quantile = (
1091
+ sorted_ratios[quantile_idx] if quantile_idx < len(sorted_ratios) else scale_median
1092
+ )
1093
+
1094
+ weight_entropy = -np.sum(combined_weights * np.log(combined_weights + 1e-8))
1095
+ max_entropy = np.log(len(combined_weights))
1096
+ confidence_score = 1.0 - (weight_entropy / max_entropy) if max_entropy > 0 else 0.0
1097
+
1098
+ print(f"Weighted scale: {scale_quantile:.6f}, confidence: {confidence_score:.4f}")
1099
+
1100
+ if 0.1 < scale_quantile < 10.0:
1101
+ return scale_quantile, confidence_score
1102
+ else:
1103
+ print(f"Warning: Unreasonable scale {scale_quantile}, using 1.0")
1104
+ return 1.0, confidence_score
1105
+
1106
+
1107
+ def compute_chunk_scale_advanced(depth1, depth2, conf1, conf2, method="auto"):
1108
+ """
1109
+ method: 'auto', 'ransac', 'weighted'
1110
+ """
1111
+ if method == "ransac":
1112
+ scale, score = compute_scale_ransac(depth1, depth2, conf1, conf2)
1113
+ return scale, score, "ransac"
1114
+
1115
+ elif method == "weighted":
1116
+ scale, score = compute_scale_weighted(depth1, depth2, conf1, conf2)
1117
+ return scale, score, "weighted"
1118
+
1119
+ elif method == "auto":
1120
+ scale_ransac, inlier_ratio = compute_scale_ransac(depth1, depth2, conf1, conf2)
1121
+ scale_weighted, conf_score = compute_scale_weighted(depth1, depth2, conf1, conf2)
1122
+
1123
+ ransac_quality = inlier_ratio
1124
+ weighted_quality = conf_score
1125
+
1126
+ print(f"RANSAC quality: {ransac_quality:.4f}, Weighted quality: {weighted_quality:.4f}")
1127
+
1128
+ if ransac_quality > 0.7 and weighted_quality > 0.7:
1129
+ # both method are good, we take both of them by average
1130
+ final_scale = (scale_ransac + scale_weighted) / 2
1131
+ final_method = "average"
1132
+ elif ransac_quality > weighted_quality:
1133
+ final_scale = scale_ransac
1134
+ final_method = "ransac"
1135
+ else:
1136
+ final_scale = scale_weighted
1137
+ final_method = "weighted"
1138
+
1139
+ final_quality = max(ransac_quality, weighted_quality)
1140
+ return final_scale, final_quality, final_method
1141
+
1142
+
1143
+ def precompute_scale_chunks_with_depth(
1144
+ chunk1_depth, chunk1_conf, chunk2_depth, chunk2_conf, method="auto"
1145
+ ):
1146
+ """
1147
+ Args:
1148
+ chunk1_depth: (n1, h, w)
1149
+ chunk1_conf: (n1, h, w)
1150
+ chunk2_depth: (n2, h, w)
1151
+ chunk2_conf: (n2, h, w)
1152
+ method: 'auto', 'ransac', 'weighted'
1153
+ """
1154
+
1155
+ scale_factor, quality_score, method_used = compute_chunk_scale_advanced(
1156
+ chunk1_depth, chunk2_depth, chunk1_conf, chunk2_conf, method
1157
+ )
1158
+
1159
+ print(f"Final scale: {scale_factor:.6f}, quality: {quality_score:.4f}, method: {method_used}")
1160
+
1161
+ return scale_factor, quality_score, method_used
1162
+
1163
+
1164
+ # ===== Scale precompute end =====
1165
+
1166
+
1167
+ def weighted_align_point_maps(
1168
+ point_map1, conf1, point_map2, conf2, conf_threshold, config, precompute_scale=None
1169
+ ):
1170
+ """point_map2 -> point_map1"""
1171
+ b1, _, _, _ = point_map1.shape
1172
+ b2, _, _, _ = point_map2.shape
1173
+ b = min(b1, b2)
1174
+
1175
+ if precompute_scale is not None: # meaning we are using align method 'scale+se3'
1176
+ point_map2 *= precompute_scale
1177
+
1178
+ aligned_points1 = []
1179
+ aligned_points2 = []
1180
+ confidence_weights = []
1181
+
1182
+ for i in range(b):
1183
+ mask1 = conf1[i] > conf_threshold
1184
+ mask2 = conf2[i] > conf_threshold
1185
+ valid_mask = mask1 & mask2
1186
+
1187
+ idx = np.where(valid_mask)
1188
+ if len(idx[0]) == 0:
1189
+ continue
1190
+
1191
+ pts1 = point_map1[i][idx]
1192
+ pts2 = point_map2[i][idx]
1193
+
1194
+ combined_conf = np.sqrt(conf1[i][idx] * conf2[i][idx])
1195
+
1196
+ aligned_points1.append(pts1)
1197
+ aligned_points2.append(pts2)
1198
+ confidence_weights.append(combined_conf)
1199
+
1200
+ if len(aligned_points1) == 0:
1201
+ raise ValueError("No matching point pairs were found!")
1202
+
1203
+ all_pts1 = np.concatenate(aligned_points1, axis=0)
1204
+ all_pts2 = np.concatenate(aligned_points2, axis=0)
1205
+ all_weights = np.concatenate(confidence_weights, axis=0)
1206
+
1207
+ print(f"The number of corresponding points matched: {all_pts1.shape[0]}")
1208
+
1209
+ if config["Model"]["align_lib"] == "numba":
1210
+ s, R, t = robust_weighted_estimate_sim3_numba(
1211
+ all_pts2,
1212
+ all_pts1,
1213
+ all_weights,
1214
+ delta=config["Model"]["IRLS"]["delta"],
1215
+ max_iters=config["Model"]["IRLS"]["max_iters"],
1216
+ tol=eval(config["Model"]["IRLS"]["tol"]),
1217
+ align_method=config["Model"]["align_method"],
1218
+ )
1219
+ elif config["Model"]["align_lib"] == "numpy": # numpy
1220
+ s, R, t = robust_weighted_estimate_sim3(
1221
+ all_pts2,
1222
+ all_pts1,
1223
+ all_weights,
1224
+ delta=config["Model"]["IRLS"]["delta"],
1225
+ max_iters=config["Model"]["IRLS"]["max_iters"],
1226
+ tol=eval(config["Model"]["IRLS"]["tol"]),
1227
+ align_method=config["Model"]["align_method"],
1228
+ )
1229
+ elif config["Model"]["align_lib"] == "torch": # torch
1230
+ s, R, t = robust_weighted_estimate_sim3_torch(
1231
+ all_pts2,
1232
+ all_pts1,
1233
+ all_weights,
1234
+ delta=config["Model"]["IRLS"]["delta"],
1235
+ max_iters=config["Model"]["IRLS"]["max_iters"],
1236
+ tol=eval(config["Model"]["IRLS"]["tol"]),
1237
+ align_method=config["Model"]["align_method"],
1238
+ )
1239
+ elif config["Model"]["align_lib"] == "triton": # triton
1240
+ s, R, t = robust_weighted_estimate_sim3_triton(
1241
+ all_pts2,
1242
+ all_pts1,
1243
+ all_weights,
1244
+ delta=config["Model"]["IRLS"]["delta"],
1245
+ max_iters=config["Model"]["IRLS"]["max_iters"],
1246
+ tol=eval(config["Model"]["IRLS"]["tol"]),
1247
+ align_method=config["Model"]["align_method"],
1248
+ )
1249
+ else:
1250
+ raise ValueError(f"Unknown align_lib: {config['Model']['align_lib']}")
1251
+
1252
+ if precompute_scale is not None: # meaning we are using align method 'scale+se3'
1253
+ # we need this precompute_scale for loop align
1254
+ s = precompute_scale
1255
+
1256
+ mean_error = compute_alignment_error(
1257
+ point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t
1258
+ )
1259
+ print(f"Mean error: {mean_error}")
1260
+
1261
+ return s, R, t
Depth-Anything-3/da3_streaming/scripts/download_weights.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ mkdir weights
4
+ cd ./weights
5
+
6
+ # SALAD (~ 340 MiB)
7
+ echo "Downloading SALAD weights (~ 340 MiB) ..."
8
+ SALAD_URL="https://github.com/serizba/salad/releases/download/v1.0.0/dino_salad.ckpt"
9
+ curl -L "$SALAD_URL" -o "./dino_salad.ckpt"
10
+
11
+
12
+ # DA3NESTED-GIANT-LARGE-1.1
13
+ echo "Downloading DA3NESTED-GIANT-LARGE-1.1 weights and config (~ 6.76 GiB)..."
14
+ BASE_URL="https://huggingface.co/depth-anything/DA3NESTED-GIANT-LARGE-1.1/resolve/main"
15
+
16
+ # download config.json (~ 3.1 KiB)
17
+ curl -L "$BASE_URL/config.json" -o "./config.json"
18
+
19
+ # download model.safetensors (~ 6.76 GiB)
20
+ curl -L "$BASE_URL/model.safetensors" -o "./model.safetensors"
Depth-Anything-3/docs/API.md ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📚 DepthAnything3 API Documentation
2
+
3
+ ## 📑 Table of Contents
4
+
5
+ 1. [📖 Overview](#overview)
6
+ 2. [💡 Usage Examples](#usage-examples)
7
+ 3. [🔧 Core API](#core-api)
8
+ - [DepthAnything3 Class](#depthanything3-class)
9
+ - [inference() Method](#inference-method)
10
+ 4. [⚙️ Parameters](#parameters)
11
+ - [Input Parameters](#input-parameters)
12
+ - [Pose Alignment Parameters](#pose-alignment-parameters)
13
+ - [Feature Export Parameters](#feature-export-parameters)
14
+ - [Rendering Parameters](#rendering-parameters)
15
+ - [Processing Parameters](#processing-parameters)
16
+ - [Export Parameters](#export-parameters)
17
+ 5. [📤 Export Formats](#export-formats)
18
+ 6. [↩️ Return Value](#return-value)
19
+
20
+ ## 📖 Overview
21
+
22
+ This documentation provides comprehensive API reference for DepthAnything3, including usage examples, parameter specifications, export formats, and advanced features. It covers both basic pose and depth estimation workflows and advanced pose-conditioned processing with multiple export capabilities.
23
+
24
+ ## 💡 Usage Examples
25
+
26
+ Here are quick examples to get you started:
27
+
28
+ ### 🚀 Basic Depth Estimation
29
+ ```python
30
+ from depth_anything_3.api import DepthAnything3
31
+
32
+ # Initialize and run inference
33
+ model = DepthAnything3.from_pretrained("depth-anything/DA3NESTED-GIANT-LARGE").to("cuda")
34
+ prediction = model.inference(["image1.jpg", "image2.jpg"])
35
+ ```
36
+
37
+ ### 📷 Pose-Conditioned Depth Estimation
38
+ ```python
39
+ import numpy as np
40
+
41
+ # With camera parameters for better consistency
42
+ prediction = model.inference(
43
+ image=["image1.jpg", "image2.jpg"],
44
+ extrinsics=extrinsics_array, # (N, 4, 4)
45
+ intrinsics=intrinsics_array # (N, 3, 3)
46
+ )
47
+ ```
48
+
49
+ ### 📤 Export Results
50
+ ```python
51
+ # Export depth data and 3D visualization
52
+ prediction = model.inference(
53
+ image=image_paths,
54
+ export_dir="./output",
55
+ export_format="mini_npz-glb"
56
+ )
57
+ ```
58
+
59
+ ### 🔍 Feature Extraction
60
+ ```python
61
+ # Export intermediate features from specific layers
62
+ prediction = model.inference(
63
+ image=image_paths,
64
+ export_dir="./output",
65
+ export_format="feat_vis",
66
+ export_feat_layers=[0, 1, 2] # Export features from layers 0, 1, 2
67
+ )
68
+ ```
69
+
70
+ ### ✨ Advanced Export with Gaussian Splatting
71
+ ```python
72
+ # Export multiple formats including Gaussian Splatting
73
+ # Note: infer_gs=True requires da3-giant or da3nested-giant-large model
74
+ model = DepthAnything3(model_name="da3-giant").to("cuda")
75
+
76
+ prediction = model.inference(
77
+ image=image_paths,
78
+ extrinsics=extrinsics_array,
79
+ intrinsics=intrinsics_array,
80
+ export_dir="./output",
81
+ export_format="npz-glb-gs_ply-gs_video",
82
+ align_to_input_ext_scale=True,
83
+ infer_gs=True, # Required for gs_ply and gs_video exports
84
+ )
85
+ ```
86
+
87
+ ### 🎨 Advanced Export with Feature Visualization
88
+ ```python
89
+ # Export with intermediate feature visualization
90
+ prediction = model.inference(
91
+ image=image_paths,
92
+ export_dir="./output",
93
+ export_format="mini_npz-glb-depth_vis-feat_vis",
94
+ export_feat_layers=[0, 5, 10, 15, 20],
95
+ feat_vis_fps=30,
96
+ )
97
+ ```
98
+
99
+ ### 📐 Using Ray-Based Pose Estimation
100
+ ```python
101
+ # Use ray-based pose estimation instead of camera decoder
102
+ prediction = model.inference(
103
+ image=image_paths,
104
+ export_dir="./output",
105
+ export_format="glb",
106
+ use_ray_pose=True, # Enable ray-based pose estimation
107
+ )
108
+ ```
109
+
110
+ ### 🎯 Reference View Selection
111
+ ```python
112
+ # For multi-view inputs, automatically select the best reference view
113
+ prediction = model.inference(
114
+ image=image_paths,
115
+ ref_view_strategy="saddle_balanced", # Default: balanced selection
116
+ )
117
+
118
+ # For video sequences, use middle frame as reference
119
+ prediction = model.inference(
120
+ image=video_frames,
121
+ ref_view_strategy="middle", # Good for temporally ordered inputs
122
+ )
123
+ ```
124
+
125
+ ## 🔧 Core API
126
+
127
+ ### 🔨 DepthAnything3 Class
128
+
129
+ The main API class that provides depth estimation capabilities with optional pose conditioning.
130
+
131
+ #### 🎯 Initialization
132
+
133
+ ```python
134
+ from depth_anything_3 import DepthAnything3
135
+
136
+ # Initialize the model with a model name
137
+ model = DepthAnything3(model_name="da3-large")
138
+ model = model.to("cuda") # Move to GPU
139
+ ```
140
+
141
+ **Parameters:**
142
+ - `model_name` (str, default: "da3-large"): The name of the model preset to use.
143
+ - **Available models:**
144
+ - 🦾 `"da3-giant"` - 1.15B params, any-view model with GS support
145
+ - ⭐ `"da3-large"` - 0.35B params, any-view model (recommended for most use cases)
146
+ - 📦 `"da3-base"` - 0.12B params, any-view model
147
+ - 🪶 `"da3-small"` - 0.08B params, any-view model
148
+ - 👁️ `"da3mono-large"` - 0.35B params, monocular depth only
149
+ - 📏 `"da3metric-large"` - 0.35B params, metric depth with sky segmentation
150
+ - 🎯 `"da3nested-giant-large"` - 1.40B params, nested model with all features
151
+
152
+ ### 🚀 inference() Method
153
+
154
+ The primary inference method that processes images and returns depth predictions.
155
+
156
+ ```python
157
+ prediction = model.inference(
158
+ image=image_list,
159
+ extrinsics=extrinsics_array, # Optional
160
+ intrinsics=intrinsics_array, # Optional
161
+ align_to_input_ext_scale=True, # Whether to align predicted poses to input scale
162
+ infer_gs=True, # Enable Gaussian branch for gs exports
163
+ use_ray_pose=False, # Use ray-based pose estimation instead of camera decoder
164
+ ref_view_strategy="saddle_balanced", # Reference view selection strategy
165
+ render_exts=render_extrinsics, # Optional renders for gs_video
166
+ render_ixts=render_intrinsics, # Optional renders for gs_video
167
+ render_hw=(height, width), # Optional renders for gs_video
168
+ process_res=504,
169
+ process_res_method="upper_bound_resize",
170
+ export_dir="output_directory", # Optional
171
+ export_format="mini_npz",
172
+ export_feat_layers=[], # List of layer indices to export features from
173
+ conf_thresh_percentile=40.0, # Confidence threshold percentile for depth map in GLB export
174
+ num_max_points=1_000_000, # Maximum number of points to export in GLB export
175
+ show_cameras=True, # Whether to show cameras in GLB export
176
+ feat_vis_fps=15, # Frames per second for feature visualization in feat_vis export
177
+ export_kwargs={} # Optional, additional arguments to export functions. export_format:key:val, see 'Parameters/Export Parameters' for details
178
+ )
179
+ ```
180
+
181
+ ## ⚙️ Parameters
182
+
183
+ ### 📸 Input Parameters
184
+
185
+ #### `image` (required)
186
+ - **Type**: `List[Union[np.ndarray, Image.Image, str]]`
187
+ - **Description**: List of input images. Can be numpy arrays, PIL Images, or file paths.
188
+ - **Example**:
189
+ ```python
190
+ # From file paths
191
+ image = ["image1.jpg", "image2.jpg", "image3.jpg"]
192
+
193
+ # From numpy arrays
194
+ image = [np.array(img1), np.array(img2)]
195
+
196
+ # From PIL Images
197
+ image = [Image.open("image1.jpg"), Image.open("image2.jpg")]
198
+ ```
199
+
200
+ #### `extrinsics` (optional)
201
+ - **Type**: `Optional[np.ndarray]`
202
+ - **Shape**: `(N, 4, 4)` where N is the number of input images
203
+ - **Description**: Camera extrinsic matrices (world-to-camera transformation). When provided, enables pose-conditioned depth estimation mode.
204
+ - **Note**: If not provided, the model operates in standard depth estimation mode.
205
+
206
+ #### `intrinsics` (optional)
207
+ - **Type**: `Optional[np.ndarray]`
208
+ - **Shape**: `(N, 3, 3)` where N is the number of input images
209
+ - **Description**: Camera intrinsic matrices containing focal length and principal point information. When provided, enables pose-conditioned depth estimation mode.
210
+
211
+ ### 🎯 Pose Alignment Parameters
212
+
213
+ #### `align_to_input_ext_scale` (default: True)
214
+ - **Type**: `bool`
215
+ - **Description**: When True the predicted extrinsics are replaced with the input
216
+ ones and the depth maps are rescaled to match their metric scale. When False the
217
+ function returns the internally aligned poses computed via Umeyama alignment.
218
+
219
+ #### `infer_gs` (default: False)
220
+ - **Type**: `bool`
221
+ - **Description**: Enable Gaussian Splatting branch for gaussian splatting exports. Required when using `gs_ply` or `gs_video` export formats.
222
+
223
+ #### `use_ray_pose` (default: False)
224
+ - **Type**: `bool`
225
+ - **Description**: Use ray-based pose estimation instead of camera decoder for pose prediction. When True, the model uses ray prediction heads to estimate camera poses; when False, it uses the camera decoder approach.
226
+
227
+ #### `ref_view_strategy` (default: "saddle_balanced")
228
+ - **Type**: `str`
229
+ - **Description**: Strategy for selecting the reference view from multiple input views. Options: `"first"`, `"middle"`, `"saddle_balanced"`, `"saddle_sim_range"`. Only applied when number of views ≥ 3. See [detailed documentation](funcs/ref_view_strategy.md) for strategy comparisons.
230
+ - **Available strategies**:
231
+ - `"saddle_balanced"`: Selects view with balanced features across multiple metrics (recommended default)
232
+ - `"saddle_sim_range"`: Selects view with largest similarity range
233
+ - `"first"`: Always uses first view (not recommended, equivalent to no reordering for views < 3)
234
+ - `"middle"`: Uses middle view (recommended for video sequences)
235
+
236
+ ### 🔍 Feature Export Parameters
237
+
238
+ #### `export_feat_layers` (default: [])
239
+ - **Type**: `List[int]`
240
+ - **Description**: List of layer indices to export intermediate features from. Features are stored in the `aux` dictionary of the Prediction object with keys like `feat_layer_0`, `feat_layer_1`, etc.
241
+
242
+ ### 🎥 Rendering Parameters
243
+
244
+ These arguments are only used when exporting Gaussian-splatting videos (include
245
+ `"gs_video"` in `export_format`). They describe an auxiliary camera trajectory
246
+ with ``M`` views.
247
+
248
+ #### `render_exts` (optional)
249
+ - **Type**: `Optional[np.ndarray]`
250
+ - **Shape**: `(M, 4, 4)`
251
+ - **Description**: Camera extrinsics for the synthesized trajectory. If omitted,
252
+ the exporter falls back to the predicted poses.
253
+
254
+ #### `render_ixts` (optional)
255
+ - **Type**: `Optional[np.ndarray]`
256
+ - **Shape**: `(M, 3, 3)`
257
+ - **Description**: Camera intrinsics for each rendered frame. Leave `None` to
258
+ reuse the input intrinsics.
259
+
260
+ #### `render_hw` (optional)
261
+ - **Type**: `Optional[Tuple[int, int]]`
262
+ - **Description**: Explicit output resolution `(height, width)` for the rendered
263
+ frames. Defaults to the input resolution when not provided.
264
+
265
+ ### ⚡ Processing Parameters
266
+
267
+ #### `process_res` (default: 504)
268
+ - **Type**: `int`
269
+ - **Description**: Base resolution for processing. The model will resize images to this resolution for inference.
270
+
271
+ #### `process_res_method` (default: "upper_bound_resize")
272
+ - **Type**: `str`
273
+ - **Description**: Method for resizing images to the target resolution.
274
+ - **Options**:
275
+ - `"upper_bound_resize"`: Resize so that the specified dimension (504) becomes the longer side
276
+ - `"lower_bound_resize"`: Resize so that the specified dimension (504) becomes the shorter side
277
+ - **Example**:
278
+ - Input: 1200×1600 → Output: 378×504 (with `process_res=504`, `process_res_method="upper_bound_resize"`)
279
+ - Input: 504×672 → Output: 504×672 (no change needed)
280
+
281
+ ### 📦 Export Parameters
282
+
283
+ #### `export_dir` (optional)
284
+ - **Type**: `Optional[str]`
285
+ - **Description**: Directory path where exported files will be saved. If not provided, no files will be exported.
286
+
287
+ #### `export_format` (default: "mini_npz")
288
+ - **Type**: `str`
289
+ - **Description**: Format for exporting results. Supports multiple formats separated by `-`.
290
+ - **Example**: `"mini_npz-glb"` exports both mini_npz and glb formats.
291
+
292
+ #### 🌐 GLB Export Parameters
293
+
294
+ These parameters are passed directly to the `inference()` method and only apply when `export_format` includes `"glb"`.
295
+
296
+ ##### `conf_thresh_percentile` (default: 40.0)
297
+ - **Type**: `float`
298
+ - **Description**: Lower percentile for adaptive confidence threshold. Points below this confidence percentile will be filtered out from the point cloud.
299
+
300
+ ##### `num_max_points` (default: 1,000,000)
301
+ - **Type**: `int`
302
+ - **Description**: Maximum number of points in the exported point cloud. If the point cloud exceeds this limit, it will be downsampled.
303
+
304
+ ##### `show_cameras` (default: True)
305
+ - **Type**: `bool`
306
+ - **Description**: Whether to include camera wireframes in the exported GLB file for visualization.
307
+
308
+ #### 🎨 Feature Visualization Parameters
309
+
310
+ These parameters are passed directly to the `inference()` method and only apply when `export_format` includes `"feat_vis"`.
311
+
312
+ ##### `feat_vis_fps` (default: 15)
313
+ - **Type**: `int`
314
+ - **Description**: Frame rate for the output video when visualizing features across multiple images.
315
+
316
+ #### ✨🎥 3DGS and 3DGS Video Parameters
317
+
318
+ These parameters are passed directly to the `inference()` method and only apply when `export_format` includes `"gs_ply"` or `"gs_video"`.
319
+
320
+ ##### `export_kwargs` (default: `{}`)
321
+ - Type: `dict[str, dict[str, Any]]`
322
+ - Description: Per-format extra arguments passed to export functions, mainly for `"gs_ply"` and `"gs_video"`.
323
+ - Access pattern: `export_kwargs[export_format][key] = value`
324
+ - Example:
325
+ ```python
326
+ {
327
+ "gs_ply": {
328
+ "gs_views_interval": 1,
329
+ },
330
+ "gs_video": {
331
+ "trj_mode": "interpolate_smooth",
332
+ "chunk_size": 1,
333
+ "vis_depth": None,
334
+ },
335
+ }
336
+ ```
337
+
338
+ ## 📤 Export Formats
339
+
340
+ The API supports multiple export formats for different use cases:
341
+
342
+ ### 📊 `mini_npz`
343
+ - **Description**: Minimal NPZ format containing essential data
344
+ - **Contents**: `depth`, `conf`, `exts`, `ixts`
345
+ - **Use case**: Lightweight storage for depth data with camera parameters
346
+
347
+ ### 📦 `npz`
348
+ - **Description**: Full NPZ format with comprehensive data
349
+ - **Contents**: `depth`, `conf`, `exts`, `ixts`, `image`, etc.
350
+ - **Use case**: Complete data export for advanced processing
351
+
352
+ ### 🌐 `glb`
353
+ - **Description**: 3D visualization format with point cloud and camera poses
354
+ - **Contents**:
355
+ - Point cloud with colors from original images
356
+ - Camera wireframes for visualization
357
+ - Confidence-based filtering and downsampling
358
+ - **Use case**: 3D visualization, inspection, and analysis
359
+ - **Features**:
360
+ - Automatic sky depth handling
361
+ - Confidence threshold filtering
362
+ - Background filtering (black/white)
363
+ - Scene scale normalization
364
+ - **Parameters** (passed via `inference()` method directly):
365
+ - `conf_thresh_percentile` (float, default: 40.0): Lower percentile for adaptive confidence threshold. Points below this confidence percentile will be filtered out.
366
+ - `num_max_points` (int, default: 1,000,000): Maximum number of points in the exported point cloud. If exceeded, points will be downsampled.
367
+ - `show_cameras` (bool, default: True): Whether to include camera wireframes in the exported GLB file for visualization.
368
+
369
+ ### ✨ `gs_ply`
370
+ - **Description**: Gaussian Splatting point cloud format
371
+ - **Contents**: 3DGS data in PLY format. Compatible with standard 3DGS viewers such as [SuperSplat](https://superspl.at/editor) (recommended), [SPARK](https://sparkjs.dev/viewer/).
372
+ - **Use case**: Gaussian Splatting reconstruction
373
+ - **Requirements**: Must set `infer_gs=True` when calling `inference()`. Only supported by `da3-giant` and `da3nested-giant-large` models.
374
+ - **Additional configs**, provided via `export_kwargs` (see [Export Parameters](#export-parameters)):
375
+ - `gs_views_interval`: Export to 3DGS every N views, default: `1`.
376
+
377
+ ### 🎥 `gs_video`
378
+ - **Description**: Rasterized 3DGS to obtain videos
379
+ - **Contents**: A video of 3DGS-rasterized views using either provided viewpoints or a predefined camera trajectory.
380
+ - **Use case**: Video rendering for Gaussian Splatting
381
+ - **Requirements**: Must set `infer_gs=True` when calling `inference()`. Only supported by `da3-giant` and `da3nested-giant-large` models.
382
+ - **Note**: Can optionally use `render_exts`, `render_ixts`, and `render_hw` parameters in `inference()` method to specify novel viewpoints.
383
+ - **Additional configs**, provided via `export_kwargs` (see [Export Parameters](#export-parameters)):
384
+ - `extrinsics`: Optional world-to-camera poses for novel views. Falls back to the predicted poses of input views if not provided. (Alternatively, use `render_exts` parameter in `inference()`)
385
+ - `intrinsics`: Optional camera intrinsics for novel views. Falls back to the predicted intrinsics of input views if not provided. (Alternatively, use `render_ixts` parameter in `inference()`)
386
+ - `out_image_hw`: Optional output resolution `H x W`. Falls back to input resolution if not provided. (Alternatively, use `render_hw` parameter in `inference()`)
387
+ - `chunk_size`: Number of views rasterized per batch. Default: `8`.
388
+ - `trj_mode`: Predefined camera trajectory for novel-view rendering.
389
+ - `color_mode`: Same as `render_mode` in [gsplat](https://docs.gsplat.studio/main/apis/rasterization.html#gsplat.rasterization).
390
+ - `vis_depth`: How depth is combined with RGB. Default: `hcat` (horizontal concatenation).
391
+ - `enable_tqdm`: Whether to display a tqdm progress bar during rendering.
392
+ - `output_name`: File name of the rendered video.
393
+ - `video_quality`: Video quality to save. Default: `high`.
394
+ - `high`: High quality video (default)
395
+ - `medium`: Medium quality video (balance of storage space and quality)
396
+ - `low`: Low quality video (fewer storage space)
397
+
398
+ ### 🔍 `feat_vis`
399
+ - **Description**: Feature visualization format
400
+ - **Contents**: PCA-visualized intermediate features from specified layers
401
+ - **Use case**: Model interpretability and feature analysis
402
+ - **Note**: Requires `export_feat_layers` to be specified
403
+ - **Parameters** (passed via `inference()` method directly):
404
+ - `feat_vis_fps` (int, default: 15): Frame rate for the output video when visualizing features across multiple images.
405
+
406
+ ### 🎨 `depth_vis`
407
+ - **Description**: Depth visualization format
408
+ - **Contents**: Color-coded depth maps alongside original images
409
+ - **Use case**: Visual inspection of depth estimation quality
410
+
411
+ ### 🔗 Multiple Format Export
412
+ You can export multiple formats simultaneously by separating them with `-`:
413
+
414
+ ```python
415
+ # Export both mini_npz and glb formats
416
+ export_format = "mini_npz-glb"
417
+
418
+ # Export multiple formats
419
+ export_format = "npz-glb-gs_ply"
420
+ ```
421
+
422
+ ## ↩️ Return Value
423
+
424
+ The `inference()` method returns a `Prediction` object with the following attributes:
425
+
426
+ ### 📊 Core Outputs
427
+
428
+ - **depth**: `np.ndarray` - Estimated depth maps with shape `(N, H, W)` where N is the number of images, H is height, and W is width.
429
+ - **conf**: `np.ndarray` - Confidence maps with shape `(N, H, W)` indicating prediction reliability (optional, depends on model).
430
+
431
+ ### 📷 Camera Parameters
432
+
433
+ - **extrinsics**: `np.ndarray` - Camera extrinsic matrices with shape `(N, 3, 4)` representing world-to-camera transformations. Only present if camera poses were estimated or provided as input.
434
+ - **intrinsics**: `np.ndarray` - Camera intrinsic matrices with shape `(N, 3, 3)` containing focal length and principal point information. Only present if poses were estimated or provided as input.
435
+
436
+ ### 🎁 Additional Outputs
437
+
438
+ - **processed_images**: `np.ndarray` - Preprocessed input images with shape `(N, H, W, 3)` in RGB format (0-255 uint8).
439
+ - **aux**: `dict` - Auxiliary outputs including:
440
+ - `feat_layer_X`: Intermediate features from layer X (if `export_feat_layers` was specified)
441
+ - `gaussians`: 3D Gaussian Splats data (if `infer_gs=True`)
442
+
443
+ ### 💻 Usage Example
444
+
445
+ ```python
446
+ prediction = model.inference(image=["img1.jpg", "img2.jpg"])
447
+
448
+ # Access depth maps
449
+ depth_maps = prediction.depth # shape: (2, H, W)
450
+
451
+ # Access confidence
452
+ if hasattr(prediction, 'conf'):
453
+ confidence = prediction.conf
454
+
455
+ # Access camera parameters (if available)
456
+ if hasattr(prediction, 'extrinsics'):
457
+ camera_poses = prediction.extrinsics # shape: (2, 4, 4)
458
+
459
+ if hasattr(prediction, 'intrinsics'):
460
+ camera_intrinsics = prediction.intrinsics # shape: (2, 3, 3)
461
+
462
+ # Access intermediate features (if export_feat_layers was set)
463
+ if hasattr(prediction, 'aux') and 'feat_layer_0' in prediction.aux:
464
+ features = prediction.aux['feat_layer_0']
465
+ ```
Depth-Anything-3/docs/BENCHMARK.md ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📏 Visual Geometry Benchmark
2
+
3
+ This document provides comprehensive instructions for running benchmark evaluation on Depth Anything 3.
4
+
5
+ ## ✨ Highlights
6
+
7
+ - 🗂️ **Diverse and Challenging Datasets**: 5 datasets (ETH3D, 7Scenes, ScanNet++, HiRoom, DTU) covering from objects to indoor and outdoor scenes. Part of datasets are recalibrated for high accuracy (see [ScanNet++](#scannet) details). All preprocessed datasets are uploaded to [depth-anything/DA3-BENCH](https://huggingface.co/datasets/depth-anything/DA3-BENCH).
8
+ - 🔧 **Robust Evaluation Pipeline**: Standardized pipeline featuring RANSAC-based pose alignment for better coordinate system alignment, TSDF fusion for directly reflecting depth 3D consistency.
9
+ - 📊 **Standardized Metrics**: Performance measured using established metrics: AUC for pose accuracy, F1-score and Chamfer Distance for reconstruction.
10
+
11
+ ---
12
+
13
+ ## 📑 Table of Contents
14
+
15
+ - [🚀 Quick Start](#quick-start)
16
+ - [📥 Dataset Download](#dataset-download)
17
+ - [⚙️ Evaluation Pipeline](#evaluation-pipeline)
18
+ - [🔧 Configuration](#configuration)
19
+ - [📊 Metrics](#metrics)
20
+ - [🗂️ Dataset Details](#dataset-details)
21
+ - [💻 Command Reference](#command-reference)
22
+ - [🔍 Troubleshooting](#troubleshooting)
23
+
24
+ ---
25
+
26
+ ## 🚀 Quick Start
27
+
28
+ ### 1. Download Benchmark Data
29
+
30
+ > 💡 **Note:** Install HuggingFace CLI first: `pip install -U huggingface_hub[cli]`
31
+ >
32
+ > 🌐 **Mirror:** If download is slow, try: `export HF_ENDPOINT=https://hf-mirror.com`
33
+
34
+ ```bash
35
+ cd da3_release
36
+
37
+ # Create directory and download from HuggingFace
38
+ mkdir -p workspace/benchmark_dataset
39
+ hf download depth-anything/DA3-BENCH \
40
+ --local-dir workspace/benchmark_dataset \
41
+ --repo-type dataset
42
+
43
+ # Extract all datasets
44
+ cd workspace/benchmark_dataset
45
+ for f in *.zip; do unzip -q "$f"; done
46
+ ```
47
+
48
+ ### 2. Run Evaluation
49
+
50
+ ```bash
51
+ # Set model (default: depth-anything/DA3-GIANT)
52
+ MODEL=depth-anything/DA3-GIANT
53
+
54
+ # Full evaluation (all datasets, all modes)
55
+ python -m depth_anything_3.bench.evaluator model.path=$MODEL
56
+
57
+ # View results
58
+ python -m depth_anything_3.bench.evaluator eval.print_only=true
59
+ ```
60
+
61
+ ---
62
+
63
+ ## 📥 Dataset Download
64
+
65
+ All benchmark datasets are hosted on HuggingFace: **[depth-anything/DA3-BENCH](https://huggingface.co/datasets/depth-anything/DA3-BENCH)**
66
+
67
+ | Dataset | File | Size | Description |
68
+ |---------|------|------|-------------|
69
+ | ETH3D | `eth3d.zip` | ~14.1 GB | High-resolution multi-view stereo (indoor/outdoor) |
70
+ | ScanNet++ | `scannetpp.zip` | ~10.1 GB | High-quality RGB-D indoor scenes |
71
+ | DTU-49 | `dtu.zip` | ~8.3 GB | Multi-view stereo benchmark (22 scenes × 49 views) |
72
+ | 7Scenes | `7scenes.zip` | ~3.3 GB | RGB-D indoor localization |
73
+ | DTU-64 | `dtu64.zip` | ~1.7 GB | DTU subset for pose evaluation (13 scenes × 64 views) |
74
+ | HiRoom | `hiroom.zip` | ~0.7 GB | High-resolution indoor rooms |
75
+
76
+ ### Download Options
77
+
78
+ **Option 1: Download All (Recommended)**
79
+ ```bash
80
+ hf download depth-anything/DA3-BENCH \
81
+ --local-dir workspace/benchmark_dataset \
82
+ --repo-type dataset
83
+ ```
84
+
85
+ **Option 2: Download Specific Dataset**
86
+ ```bash
87
+ # Download only HiRoom
88
+ hf download depth-anything/DA3-BENCH hiroom.zip \
89
+ --local-dir workspace/benchmark_dataset \
90
+ --repo-type dataset
91
+ ```
92
+
93
+ **Option 3: Manual Download**
94
+
95
+ Visit [https://huggingface.co/datasets/depth-anything/DA3-BENCH](https://huggingface.co/datasets/depth-anything/DA3-BENCH) and download the zip files manually.
96
+
97
+ ### Extract Datasets
98
+
99
+ ```bash
100
+ cd workspace/benchmark_dataset
101
+
102
+ # Extract all
103
+ for f in *.zip; do unzip -q "$f"; done
104
+
105
+ # Or extract specific dataset
106
+ unzip hiroom.zip
107
+ ```
108
+
109
+ ### Expected Directory Structure
110
+
111
+ After extraction, your directory should look like:
112
+ ```
113
+ workspace/benchmark_dataset/
114
+ ├── eth3d/
115
+ │ ├── courtyard/
116
+ │ ├── electro/
117
+ │ └── ...
118
+ ├── 7scenes/
119
+ │ └── 7Scenes/
120
+ │ ├── chess/
121
+ │ └── ...
122
+ ├── scannetpp/
123
+ │ ├── 09c1414f1b/
124
+ │ └── ...
125
+ ├── hiroom/
126
+ │ ├── data/
127
+ │ ├── fused_pcd/
128
+ │ └── selected_scene_list_val.txt
129
+ ├── dtu/
130
+ │ ├── Rectified/
131
+ │ ├── Cameras/
132
+ │ ├── Points/
133
+ │ ├── SampleSet/
134
+ │ └── depth_raw/
135
+ └── dtu64/
136
+ ├── Cameras/
137
+ ├── scan105/
138
+ └── ...
139
+ ```
140
+
141
+ ---
142
+
143
+ ## ⚙️ Evaluation Pipeline
144
+
145
+
146
+
147
+ ### Evaluation Modes
148
+
149
+ | Mode | Description | Metrics |
150
+ |------|-------------|---------|
151
+ | `pose` | Camera pose estimation | AUC@3°, AUC@30° |
152
+ | `recon_unposed` | 3D reconstruction with **predicted** poses | F-score, Overall |
153
+ | `recon_posed` | 3D reconstruction with **GT** poses | F-score, Overall |
154
+
155
+ ### Basic Usage
156
+
157
+ ```bash
158
+ cd da3_release
159
+ MODEL=depth-anything/DA3-GIANT
160
+
161
+ # Full evaluation (inference + evaluation + print results)
162
+ python -m depth_anything_3.bench.evaluator model.path=$MODEL
163
+
164
+ # Skip inference, only evaluate existing predictions
165
+ python -m depth_anything_3.bench.evaluator eval.eval_only=true
166
+
167
+ # Only print saved metrics
168
+ python -m depth_anything_3.bench.evaluator eval.print_only=true
169
+ ```
170
+
171
+ ### Selective Evaluation
172
+
173
+ ```bash
174
+ # Evaluate specific datasets
175
+ python -m depth_anything_3.bench.evaluator model.path=$MODEL eval.datasets=[hiroom]
176
+
177
+ # Evaluate specific modes
178
+ python -m depth_anything_3.bench.evaluator model.path=$MODEL eval.modes=[pose,recon_unposed]
179
+
180
+ # Combine dataset and mode selection
181
+ python -m depth_anything_3.bench.evaluator model.path=$MODEL \
182
+ eval.datasets=[hiroom] \
183
+ eval.modes=[pose]
184
+ ```
185
+
186
+ ### 🖥️ Multi-GPU Inference
187
+
188
+ The evaluator automatically distributes inference across available GPUs:
189
+
190
+ ```bash
191
+ # Use 4 GPUs
192
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m depth_anything_3.bench.evaluator model.path=$MODEL
193
+
194
+ # Use all available GPUs (default)
195
+ python -m depth_anything_3.bench.evaluator model.path=$MODEL
196
+
197
+ # Single GPU
198
+ CUDA_VISIBLE_DEVICES=0 python -m depth_anything_3.bench.evaluator model.path=$MODEL
199
+ ```
200
+
201
+ ---
202
+
203
+ ## 🔧 Configuration
204
+
205
+ ### Config File
206
+
207
+ Default config: `src/depth_anything_3/bench/configs/eval_bench.yaml`
208
+
209
+ ```yaml
210
+ # Model path
211
+ model:
212
+ path: depth-anything/DA3-GIANT
213
+
214
+ # Workspace directory
215
+ workspace:
216
+ work_dir: ./workspace/evaluation
217
+
218
+ # Evaluation settings
219
+ eval:
220
+ datasets: [eth3d, 7scenes, scannetpp, hiroom, dtu, dtu64]
221
+ modes: [pose, recon_unposed, recon_posed]
222
+ max_frames: 100 # Max frames per scene (-1 = no limit)
223
+ scenes: null # Specific scenes (null = all)
224
+
225
+ # Inference settings
226
+ inference:
227
+ num_fusion_workers: 4
228
+ debug: false
229
+ ```
230
+
231
+ ### Output Structure
232
+
233
+ ```
234
+ workspace/evaluation/
235
+ ├── model_results/ # Inference outputs
236
+ │ ├── eth3d/
237
+ │ │ └── {scene}/
238
+ │ │ ├── unposed/ # Predictions for recon_unposed
239
+ │ │ └── posed/ # Predictions for recon_posed
240
+ │ ├── 7scenes/
241
+ │ ├── scannetpp/
242
+ │ ├── hiroom/
243
+ │ ├── dtu/
244
+ │ └── dtu64/
245
+ └── metric_results/ # Evaluation metrics (JSON)
246
+ ├── eth3d_pose.json
247
+ ├── eth3d_recon_unposed.json
248
+ ├── eth3d_recon_posed.json
249
+ └── ...
250
+ ```
251
+
252
+ ---
253
+
254
+ ## 📊 Metrics
255
+
256
+ ### 🎯 Pose Estimation
257
+
258
+ | Metric | Description |
259
+ |--------|-------------|
260
+ | **Auc3** | Area Under Curve at 3° angular error threshold |
261
+ | **Auc30** | Area Under Curve at 30° angular error threshold |
262
+
263
+ ### 🏗️ 3D Reconstruction
264
+
265
+ | Metric | Description | Note |
266
+ |--------|-------------|------|
267
+ | **F-score** | Harmonic mean of Precision and Recall | Higher is better |
268
+ | **Overall** | (Accuracy + Completeness) / 2 | Lower is better (error in meters/mm) |
269
+
270
+ > **Note:** DTU reports Overall in millimeters; other datasets report in meters.
271
+
272
+ ### Expected Results for DA3-GIANT
273
+
274
+ If your setup is correct, you should get the following results when evaluating the **DA3-GIANT** model:
275
+
276
+ ```
277
+ ========================================================
278
+ 📊 SUMMARY
279
+ ========================================================
280
+
281
+ 🎯 POSE ESTIMATION
282
+ ---------------------------------------------------------------------------------------
283
+ Metric Avg HiRoom ETH3D DTU-64 7Scenes ScanNet++
284
+ ---------------------------------------------------------------------------------------
285
+ Auc3 0.6705 0.8030 0.4872 0.9408 0.2744 0.8470
286
+ Auc30 0.9436 0.9592 0.9153 0.9939 0.8668 0.9827
287
+
288
+ 🏗️ RECON_UNPOSED (Pred Pose)
289
+ ---------------------------------------------------------------------------------------
290
+ Metric Avg* HiRoom ETH3D DTU 7Scenes ScanNet++
291
+ ---------------------------------------------------------------------------------------
292
+ F-score 0.7345 0.8629 0.7876 N/A 0.5043 0.7831
293
+ Overall 0.1682 0.0457 0.4366 1.7927 0.1230 0.0676
294
+
295
+ 🏗️ RECON_POSED (GT Pose)
296
+ ---------------------------------------------------------------------------------------
297
+ Metric Avg* HiRoom ETH3D DTU 7Scenes ScanNet++
298
+ ---------------------------------------------------------------------------------------
299
+ F-score 0.7978 0.9546 0.8685 N/A 0.5635 0.8045
300
+ Overall 0.1408 0.0213 0.3679 1.7488 0.1092 0.0649
301
+
302
+ * Avg F-score / Overall = average over HiRoom, ETH3D, 7Scenes, ScanNet++ (4 datasets)
303
+ ```
304
+
305
+ ---
306
+
307
+ ## 🗂️ Dataset Details
308
+
309
+ ### ETH3D
310
+
311
+ High-resolution multi-view stereo benchmark with laser-scanned ground truth.
312
+
313
+ - **Scenes:** 11 (courtyard, electro, kicker, pipes, relief, delivery_area, facade, office, playground, relief_2, terrains)
314
+ - **Resolution:** Variable (high-res DSLR images)
315
+ - **GT:** Laser-scanned meshes + depth maps
316
+
317
+ > **⚠️ Image Filtering:** Some images with unusual camera rotations are filtered out for stable evaluation. See `ETH3D_FILTER_KEYS` in `constants.py`.
318
+
319
+ ### 7Scenes
320
+
321
+ RGB-D dataset for camera relocalization.
322
+
323
+ - **Scenes:** 7 (chess, fire, heads, office, pumpkin, redkitchen, stairs)
324
+ - **Resolution:** 640×480
325
+ - **GT:** Poses from KinectFusion, meshes from TSDF fusion
326
+
327
+ ### ScanNet++
328
+
329
+ High-quality indoor RGB-D dataset with dense annotations.
330
+
331
+ - **Scenes:** 20 validation scenes
332
+ - **Resolution:** 768×1024 (after undistortion)
333
+ - **GT:** High-quality meshes from FARO scanner
334
+
335
+ > **⚠️ Camera Pose Re-calibration:** The default ScanNet++ poses are often inaccurate due to motion blur and textureless frames from iPhone captures. We re-ran COLMAP with the following improvements:
336
+ > - **Frame filtering:** Removed blurry images during frame extraction
337
+ > - **Fisheye calibration:** Jointly calibrated fisheye camera for wider FOV and better accuracy
338
+ > - **Exhaustive matching:** Used COLMAP's exhaustive matcher and mapper for reliable poses (takes several days per scene but necessary for quality)
339
+ > - All processed scenes are available at [haotongl/scannetpp_zipnerf](https://huggingface.co/datasets/haotongl/scannetpp_zipnerf)
340
+
341
+ ### HiRoom
342
+
343
+ Indoor room scenes with high-resolution RGB-D data.
344
+
345
+ - **Scenes:** 24 validation scenes
346
+ - **GT:** Fused point clouds
347
+
348
+ ### DTU-49 (Reconstruction Only)
349
+
350
+ Multi-view stereo benchmark following MVSNet evaluation protocol.
351
+
352
+ - **Scenes:** 22 evaluation scenes
353
+ - **Views:** 49 images per scene
354
+ - **GT:** Laser-scanned point clouds with observation masks
355
+ - **Metrics:** Overall only (accuracy + completeness in mm)
356
+
357
+ ### DTU-64 (Pose Only)
358
+
359
+ DTU subset for pose estimation evaluation.
360
+
361
+ - **Scenes:** 13 scenes
362
+ - **Views:** 64 images per scene
363
+ - **Metrics:** AUC@3°, AUC@30°
364
+
365
+ > **Why two DTU settings?**
366
+ > - **DTU-64** (pose): More views = more challenging pose estimation
367
+ > - **DTU-49** (recon): Standard MVSNet protocol for fair comparison with MVS methods
368
+
369
+ ---
370
+
371
+ ## 💻 Command Reference
372
+
373
+ ```
374
+ python -m depth_anything_3.bench.evaluator [OPTIONS] [KEY=VALUE ...]
375
+
376
+ Configuration:
377
+ --config PATH Config YAML file (default: bench/configs/eval_bench.yaml)
378
+
379
+ Config Overrides (using dotlist notation):
380
+ model.path=VALUE Model path or HuggingFace ID
381
+ workspace.work_dir=VALUE Working directory for outputs
382
+ eval.datasets=[dataset1,dataset2] Datasets to evaluate (eth3d,7scenes,scannetpp,hiroom,dtu,dtu64)
383
+ eval.modes=[mode1,mode2] Evaluation modes (pose,recon_unposed,recon_posed)
384
+ eval.scenes=[scene1,scene2] Specific scenes to evaluate (null=all)
385
+ eval.max_frames=VALUE Max frames per scene (-1=no limit, default: 100)
386
+ eval.ref_view_strategy=VALUE Reference view strategy (default: first)
387
+ eval.eval_only=VALUE Only run evaluation (skip inference) (true/false)
388
+ eval.print_only=VALUE Only print saved metrics (true/false)
389
+ inference.num_fusion_workers=VALUE Number of parallel workers (default: 4)
390
+ inference.debug=VALUE Enable debug mode (true/false)
391
+
392
+ Special Flags:
393
+ --help, -h Show this help message
394
+
395
+ Multi-GPU:
396
+ Use CUDA_VISIBLE_DEVICES to specify GPUs (auto-detected and distributed)
397
+ ```
398
+
399
+ ### Examples
400
+
401
+ ```bash
402
+ MODEL=depth-anything/DA3-GIANT
403
+
404
+ # Full evaluation
405
+ python -m depth_anything_3.bench.evaluator model.path=$MODEL
406
+
407
+ # Quick test on HiRoom only
408
+ python -m depth_anything_3.bench.evaluator \
409
+ model.path=$MODEL \
410
+ eval.datasets=[hiroom] \
411
+ eval.modes=[pose]
412
+
413
+ # Pose-only evaluation (all 5 pose datasets)
414
+ python -m depth_anything_3.bench.evaluator \
415
+ model.path=$MODEL \
416
+ eval.datasets=[eth3d,7scenes,scannetpp,hiroom,dtu64] \
417
+ eval.modes=[pose]
418
+
419
+ # Recon-only evaluation (all 5 recon datasets)
420
+ python -m depth_anything_3.bench.evaluator \
421
+ model.path=$MODEL \
422
+ eval.datasets=[eth3d,7scenes,scannetpp,hiroom,dtu] \
423
+ eval.modes=[recon_unposed,recon_posed]
424
+
425
+ # Debug specific scenes
426
+ python -m depth_anything_3.bench.evaluator \
427
+ model.path=$MODEL \
428
+ eval.datasets=[eth3d] \
429
+ eval.scenes=[courtyard] \
430
+ inference.debug=true
431
+
432
+ # Re-evaluate without re-running inference
433
+ python -m depth_anything_3.bench.evaluator eval.eval_only=true
434
+
435
+ # Just view results
436
+ python -m depth_anything_3.bench.evaluator eval.print_only=true
437
+ ```
438
+
439
+ ---
440
+
441
+ ## 🔍 Troubleshooting
442
+
443
+ ### Data Path Issues
444
+
445
+ Ensure dataset paths in `src/depth_anything_3/utils/constants.py` are correct:
446
+
447
+ ```python
448
+ # Default paths (relative to project root)
449
+ ETH3D_EVAL_DATA_ROOT = "workspace/benchmark_dataset/eth3d"
450
+ SEVENSCENES_EVAL_DATA_ROOT = "workspace/benchmark_dataset/7scenes"
451
+ SCANNETPP_EVAL_DATA_ROOT = "workspace/benchmark_dataset/scannetpp"
452
+ HIROOM_EVAL_DATA_ROOT = "workspace/benchmark_dataset/hiroom/data"
453
+ DTU_EVAL_DATA_ROOT = "workspace/benchmark_dataset/dtu"
454
+ DTU64_EVAL_DATA_ROOT = "workspace/benchmark_dataset/dtu64"
455
+ ```
456
+
457
+ ---
458
+
459
+ ## 📝 Citation
460
+
461
+ If you find this benchmark useful, please cite:
462
+
463
+ ```
464
+ @article{depthanything3,
465
+ title={Depth Anything 3: Recovering the visual space from any views},
466
+ author={Haotong Lin and Sili Chen and Jun Hao Liew and Donny Y. Chen and Zhenyu Li and Guang Shi and Jiashi Feng and Bingyi Kang},
467
+ journal={arXiv preprint arXiv:2511.10647},
468
+ year={2025}
469
+ }
470
+ ```
471
+
472
+ Please also cite the original dataset papers for each benchmark you use.
473
+
474
+ ---
475
+
476
+ ## 📄 License
477
+
478
+ The benchmark datasets are provided for research purposes only. Users must follow the original licenses of each dataset:
479
+
480
+ - **ETH3D:** [https://www.eth3d.net/](https://www.eth3d.net/)
481
+ - **7Scenes:** [Microsoft Research](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/)
482
+ - **ScanNet++:** [http://www.scan-net.org/](http://www.scan-net.org/)
483
+ - **DTU:** [https://roboimagedata.compute.dtu.dk/](https://roboimagedata.compute.dtu.dk/)
484
+ - **HiRoom:** [SVLightVerse](https://jerrypiglet.github.io/SVLightVerse/)
Depth-Anything-3/docs/CLI.md ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 Depth Anything 3 Command Line Interface
2
+
3
+ ## 📋 Table of Contents
4
+
5
+ - [📖 Overview](#overview)
6
+ - [⚡ Quick Start](#quick-start)
7
+ - [📚 Command Reference](#command-reference)
8
+ - [🤖 auto - Auto Mode](#auto---auto-mode)
9
+ - [🖼️ image - Single Image Processing](#image---single-image-processing)
10
+ - [🗂️ images - Image Directory Processing](#images---image-directory-processing)
11
+ - [🎬 video - Video Processing](#video---video-processing)
12
+ - [📐 colmap - COLMAP Dataset Processing](#colmap---colmap-dataset-processing)
13
+ - [🔧 backend - Backend Service](#backend---backend-service)
14
+ - [🎨 gradio - Gradio Application](#gradio---gradio-application)
15
+ - [🖼️ gallery - Gallery Server](#gallery---gallery-server)
16
+ - [⚙️ Parameter Details](#parameter-details)
17
+ - [💡 Usage Examples](#usage-examples)
18
+
19
+ ## 📖 Overview
20
+
21
+ The Depth Anything 3 CLI provides a comprehensive command-line toolkit supporting image depth estimation, video processing, COLMAP dataset handling, and web applications.
22
+
23
+ The backend service enables cache model to GPU so that we do not need to reload model for each command.
24
+
25
+ ## ⚡ Quick Start
26
+
27
+ The CLI can run fully offline or connect to the backend for cached weights and task scheduling:
28
+
29
+ ```bash
30
+ # 🔧 Start backend service (optional, keeps model resident in GPU memory)
31
+ da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE
32
+
33
+ # 🚀 Use auto mode to process input
34
+ da3 auto path/to/input --export-dir ./workspace/scene001
35
+
36
+ # ♻️ Reuse backend for next job
37
+ da3 auto path/to/video.mp4 \
38
+ --export-dir ./workspace/scene002 \
39
+ --use-backend \
40
+ --backend-url http://localhost:8008
41
+ ```
42
+
43
+ Each export directory contains `scene.glb`, `scene.jpg`, and optional extras such as `depth_vis/` or `gs_video/` depending on the requested format.
44
+
45
+ ## 📚 Command Reference
46
+
47
+ ### 🤖 auto - Auto Mode
48
+
49
+ Automatically detect input type and dispatch to the appropriate handler.
50
+
51
+ **Usage:**
52
+
53
+ ```bash
54
+ da3 auto INPUT_PATH [OPTIONS]
55
+ ```
56
+
57
+ **Input Type Detection:**
58
+ - 🖼️ Single image file (.jpg, .png, .jpeg, .webp, .bmp, .tiff, .tif)
59
+ - 📁 Image directory
60
+ - 🎬 Video file (.mp4, .avi, .mov, .mkv, .flv, .wmv, .webm, .m4v)
61
+ - 📐 COLMAP directory (containing `images/` and `sparse/` subdirectories)
62
+
63
+ **Parameters:**
64
+
65
+ | Parameter | Type | Default | Description |
66
+ |-----------|------|---------|-------------|
67
+ | `INPUT_PATH` | str | Required | Input path (image, directory, video, or COLMAP) |
68
+ | `--model-dir` | str | Default model | Model directory path |
69
+ | `--export-dir` | str | `debug` | Export directory |
70
+ | `--export-format` | str | `glb` | Export format (supports `mini_npz`, `glb`, `feat_vis`, etc., can be combined with hyphens) |
71
+ | `--device` | str | `cuda` | Device to use |
72
+ | `--use-backend` | bool | `False` | Use backend service for inference |
73
+ | `--backend-url` | str | `http://localhost:8008` | Backend service URL |
74
+ | `--process-res` | int | `504` | Processing resolution |
75
+ | `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
76
+ | `--export-feat` | str | `""` | Export features from specified layers, comma-separated (e.g., `"0,1,2"`) |
77
+ | `--auto-cleanup` | bool | `False` | Automatically clean export directory without confirmation |
78
+ | `--fps` | float | `1.0` | [Video] Frame sampling FPS |
79
+ | `--sparse-subdir` | str | `""` | [COLMAP] Sparse reconstruction subdirectory (e.g., `"0"` for `sparse/0/`) |
80
+ | `--align-to-input-ext-scale` | bool | `True` | [COLMAP] Align prediction to input extrinsics scale |
81
+ | `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
82
+ | `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy: `first`, `middle`, `saddle_balanced`, `saddle_sim_range`. See [docs](funcs/ref_view_strategy.md) |
83
+ | `--conf-thresh-percentile` | float | `40.0` | [GLB] Lower percentile for adaptive confidence threshold |
84
+ | `--num-max-points` | int | `1000000` | [GLB] Maximum number of points in the point cloud |
85
+ | `--show-cameras` | bool | `True` | [GLB] Show camera wireframes in the exported scene |
86
+ | `--feat-vis-fps` | int | `15` | [FEAT_VIS] Frame rate for output video |
87
+
88
+ **Examples:**
89
+
90
+ ```bash
91
+ # 🖼️ Auto-process an image
92
+ da3 auto path/to/image.jpg --export-dir ./output
93
+
94
+ # 🎬 Auto-process a video
95
+ da3 auto path/to/video.mp4 --fps 2.0 --export-dir ./output
96
+
97
+ # 🔧 Use backend service
98
+ da3 auto path/to/input \
99
+ --export-format mini_npz-glb \
100
+ --use-backend \
101
+ --backend-url http://localhost:8008 \
102
+ --export-dir ./output
103
+ ```
104
+
105
+ ---
106
+
107
+ ### 🖼️ image - Single Image Processing
108
+
109
+ Process a single image for camera pose and depth estimation.
110
+
111
+ **Usage:**
112
+
113
+ ```bash
114
+ da3 image IMAGE_PATH [OPTIONS]
115
+ ```
116
+
117
+ **Parameters:**
118
+
119
+ | Parameter | Type | Default | Description |
120
+ |-----------|------|---------|-------------|
121
+ | `IMAGE_PATH` | str | Required | Input image file path |
122
+ | `--model-dir` | str | Default model | Model directory path |
123
+ | `--export-dir` | str | `debug` | Export directory |
124
+ | `--export-format` | str | `glb` | Export format |
125
+ | `--device` | str | `cuda` | Device to use |
126
+ | `--use-backend` | bool | `False` | Use backend service for inference |
127
+ | `--backend-url` | str | `http://localhost:8008` | Backend service URL |
128
+ | `--process-res` | int | `504` | Processing resolution |
129
+ | `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
130
+ | `--export-feat` | str | `""` | Export feature layer indices (comma-separated) |
131
+ | `--auto-cleanup` | bool | `False` | Automatically clean export directory |
132
+ | `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
133
+ | `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
134
+ | `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
135
+ | `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
136
+ | `--show-cameras` | bool | `True` | [GLB] Show cameras |
137
+ | `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
138
+
139
+ **Examples:**
140
+
141
+ ```bash
142
+ # ✨ Basic usage
143
+ da3 image path/to/image.png --export-dir ./output
144
+
145
+ # ⚡ With backend acceleration
146
+ da3 image path/to/image.png \
147
+ --use-backend \
148
+ --backend-url http://localhost:8008 \
149
+ --export-dir ./output
150
+
151
+ # 🔍 Export feature visualization
152
+ da3 image image.jpg \
153
+ --export-format feat_vis \
154
+ --export-feat "9,19,29,39" \
155
+ --export-dir ./results
156
+ ```
157
+
158
+ ---
159
+
160
+ ### 🗂️ images - Image Directory Processing
161
+
162
+ Process a directory of images for batch depth estimation.
163
+
164
+ **Usage:**
165
+
166
+ ```bash
167
+ da3 images IMAGES_DIR [OPTIONS]
168
+ ```
169
+
170
+ **Parameters:**
171
+
172
+ | Parameter | Type | Default | Description |
173
+ |-----------|------|---------|-------------|
174
+ | `IMAGES_DIR` | str | Required | Directory path containing images |
175
+ | `--image-extensions` | str | `png,jpg,jpeg` | Image file extensions to process (comma-separated) |
176
+ | `--model-dir` | str | Default model | Model directory path |
177
+ | `--export-dir` | str | `debug` | Export directory |
178
+ | `--export-format` | str | `glb` | Export format |
179
+ | `--device` | str | `cuda` | Device to use |
180
+ | `--use-backend` | bool | `False` | Use backend service for inference |
181
+ | `--backend-url` | str | `http://localhost:8008` | Backend service URL |
182
+ | `--process-res` | int | `504` | Processing resolution |
183
+ | `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
184
+ | `--export-feat` | str | `""` | Export feature layer indices |
185
+ | `--auto-cleanup` | bool | `False` | Automatically clean export directory |
186
+ | `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
187
+ | `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
188
+ | `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
189
+ | `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
190
+ | `--show-cameras` | bool | `True` | [GLB] Show cameras |
191
+ | `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
192
+
193
+ **Examples:**
194
+
195
+ ```bash
196
+ # 📁 Process directory (defaults to png/jpg/jpeg)
197
+ da3 images ./image_folder --export-dir ./output
198
+
199
+ # 🎯 Custom extensions
200
+ da3 images ./dataset --image-extensions "png,jpg,webp" --export-dir ./output
201
+
202
+ # 🔧 Use backend service
203
+ da3 images ./dataset \
204
+ --use-backend \
205
+ --backend-url http://localhost:8008 \
206
+ --export-dir ./output
207
+ ```
208
+
209
+ ---
210
+
211
+ ### 🎬 video - Video Processing
212
+
213
+ Process video by extracting frames for depth estimation.
214
+
215
+ **Usage:**
216
+
217
+ ```bash
218
+ da3 video VIDEO_PATH [OPTIONS]
219
+ ```
220
+
221
+ **Parameters:**
222
+
223
+ | Parameter | Type | Default | Description |
224
+ |-----------|------|---------|-------------|
225
+ | `VIDEO_PATH` | str | Required | Input video file path |
226
+ | `--fps` | float | `1.0` | Frame extraction sampling FPS |
227
+ | `--model-dir` | str | Default model | Model directory path |
228
+ | `--export-dir` | str | `debug` | Export directory |
229
+ | `--export-format` | str | `glb` | Export format |
230
+ | `--device` | str | `cuda` | Device to use |
231
+ | `--use-backend` | bool | `False` | Use backend service for inference |
232
+ | `--backend-url` | str | `http://localhost:8008` | Backend service URL |
233
+ | `--process-res` | int | `504` | Processing resolution |
234
+ | `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
235
+ | `--export-feat` | str | `""` | Export feature layer indices |
236
+ | `--auto-cleanup` | bool | `False` | Automatically clean export directory |
237
+ | `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
238
+ | `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
239
+ | `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
240
+ | `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
241
+ | `--show-cameras` | bool | `True` | [GLB] Show cameras |
242
+ | `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
243
+
244
+ **Examples:**
245
+
246
+ ```bash
247
+ # ��� Basic video processing
248
+ da3 video path/to/video.mp4 --export-dir ./output
249
+
250
+ # ⚙️ Control frame sampling and resolution
251
+ da3 video path/to/video.mp4 \
252
+ --fps 2.0 \
253
+ --process-res 1024 \
254
+ --export-dir ./output
255
+
256
+ # 🔧 Use backend service
257
+ da3 video path/to/video.mp4 \
258
+ --use-backend \
259
+ --backend-url http://localhost:8008 \
260
+ --export-dir ./output
261
+ ```
262
+
263
+ ---
264
+
265
+ ### 📐 colmap - COLMAP Dataset Processing
266
+
267
+ Run pose-conditioned depth estimation on COLMAP data.
268
+
269
+ **Usage:**
270
+
271
+ ```bash
272
+ da3 colmap COLMAP_DIR [OPTIONS]
273
+ ```
274
+
275
+ **Parameters:**
276
+
277
+ | Parameter | Type | Default | Description |
278
+ |-----------|------|---------|-------------|
279
+ | `COLMAP_DIR` | str | Required | COLMAP directory containing `images/` and `sparse/` subdirectories |
280
+ | `--sparse-subdir` | str | `""` | Sparse reconstruction subdirectory (e.g., `"0"` for `sparse/0/`) |
281
+ | `--align-to-input-ext-scale` | bool | `True` | Align prediction to input extrinsics scale |
282
+ | `--model-dir` | str | Default model | Model directory path |
283
+ | `--export-dir` | str | `debug` | Export directory |
284
+ | `--export-format` | str | `glb` | Export format |
285
+ | `--device` | str | `cuda` | Device to use |
286
+ | `--use-backend` | bool | `False` | Use backend service for inference |
287
+ | `--backend-url` | str | `http://localhost:8008` | Backend service URL |
288
+ | `--process-res` | int | `504` | Processing resolution |
289
+ | `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
290
+ | `--export-feat` | str | `""` | Export feature layer indices |
291
+ | `--auto-cleanup` | bool | `False` | Automatically clean export directory |
292
+ | `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
293
+ | `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
294
+ | `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
295
+ | `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
296
+ | `--show-cameras` | bool | `True` | [GLB] Show cameras |
297
+ | `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
298
+
299
+ **Examples:**
300
+
301
+ ```bash
302
+ # 📐 Process COLMAP dataset
303
+ da3 colmap ./colmap_dataset --export-dir ./output
304
+
305
+ # 🎯 Use specific sparse subdirectory and align scale
306
+ da3 colmap ./colmap_dataset \
307
+ --sparse-subdir 0 \
308
+ --align-to-input-ext-scale \
309
+ --export-dir ./output
310
+
311
+ # 🔧 Use backend service
312
+ da3 colmap ./colmap_dataset \
313
+ --use-backend \
314
+ --backend-url http://localhost:8008 \
315
+ --export-dir ./output
316
+ ```
317
+
318
+ ---
319
+
320
+ ### 🔧 backend - Backend Service
321
+
322
+ Start model backend service with integrated gallery.
323
+
324
+ **Usage:**
325
+
326
+ ```bash
327
+ da3 backend [OPTIONS]
328
+ ```
329
+
330
+ **Parameters:**
331
+
332
+ | Parameter | Type | Default | Description |
333
+ |-----------|------|---------|-------------|
334
+ | `--model-dir` | str | Default model | Model directory path |
335
+ | `--device` | str | `cuda` | Device to use |
336
+ | `--host` | str | `127.0.0.1` | Host address to bind to |
337
+ | `--port` | int | `8008` | Port number to bind to |
338
+ | `--gallery-dir` | str | Default gallery dir | Gallery directory path (optional) |
339
+
340
+ **Features:**
341
+ - 🎯 Keeps model resident in GPU memory
342
+ - 🔌 Provides REST inference API
343
+ - 📊 Integrated dashboard and status monitoring
344
+ - 🖼️ Optional gallery browser (if `--gallery-dir` is provided)
345
+
346
+ **Available Endpoints:**
347
+ - 🏠 `/` - Home page
348
+ - 📊 `/dashboard` - Dashboard
349
+ - ✅ `/status` - API status
350
+ - 🖼️ `/gallery/` - Gallery browser (if enabled)
351
+
352
+ **Examples:**
353
+
354
+ ```bash
355
+ # 🚀 Basic backend service
356
+ da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE
357
+
358
+ # 🖼️ Backend with gallery
359
+ da3 backend \
360
+ --model-dir depth-anything/DA3NESTED-GIANT-LARGE \
361
+ --device cuda \
362
+ --host 0.0.0.0 \
363
+ --port 8008 \
364
+ --gallery-dir ./workspace
365
+
366
+ # 💻 Use CPU
367
+ da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE --device cpu
368
+ ```
369
+
370
+ ---
371
+
372
+ ### 🎨 gradio - Gradio Application
373
+
374
+ Launch Depth Anything 3 Gradio interactive web application.
375
+
376
+ **Usage:**
377
+
378
+ ```bash
379
+ da3 gradio [OPTIONS]
380
+ ```
381
+
382
+ **Parameters:**
383
+
384
+ | Parameter | Type | Default | Description |
385
+ |-----------|------|---------|-------------|
386
+ | `--model-dir` | str | Required | Model directory path |
387
+ | `--workspace-dir` | str | Required | Workspace directory path |
388
+ | `--gallery-dir` | str | Required | Gallery directory path |
389
+ | `--host` | str | `127.0.0.1` | Host address to bind to |
390
+ | `--port` | int | `7860` | Port number to bind to |
391
+ | `--share` | bool | `False` | Create a public link |
392
+ | `--debug` | bool | `False` | Enable debug mode |
393
+ | `--cache-examples` | bool | `False` | Pre-cache all example scenes at startup |
394
+ | `--cache-gs-tag` | str | `""` | Tag to match scene names for high-res+3DGS caching |
395
+
396
+ **Examples:**
397
+
398
+ ```bash
399
+ # 🎨 Basic Gradio application
400
+ da3 gradio \
401
+ --model-dir depth-anything/DA3NESTED-GIANT-LARGE \
402
+ --workspace-dir ./workspace \
403
+ --gallery-dir ./gallery
404
+
405
+ # 🌐 Enable sharing and debug
406
+ da3 gradio \
407
+ --model-dir depth-anything/DA3NESTED-GIANT-LARGE \
408
+ --workspace-dir ./workspace \
409
+ --gallery-dir ./gallery \
410
+ --share \
411
+ --debug
412
+
413
+ # ⚡ Pre-cache examples
414
+ da3 gradio \
415
+ --model-dir depth-anything/DA3NESTED-GIANT-LARGE \
416
+ --workspace-dir ./workspace \
417
+ --gallery-dir ./gallery \
418
+ --cache-examples \
419
+ --cache-gs-tag "dl3dv"
420
+ ```
421
+
422
+ ---
423
+
424
+ ### 🖼️ gallery - Gallery Server
425
+
426
+ Launch standalone Depth Anything 3 Gallery server.
427
+
428
+ **Usage:**
429
+
430
+ ```bash
431
+ da3 gallery [OPTIONS]
432
+ ```
433
+
434
+ **Parameters:**
435
+
436
+ | Parameter | Type | Default | Description |
437
+ |-----------|------|---------|-------------|
438
+ | `--gallery-dir` | str | Default gallery dir | Gallery root directory |
439
+ | `--host` | str | `127.0.0.1` | Host address to bind to |
440
+ | `--port` | int | `8007` | Port number to bind to |
441
+ | `--open-browser` | bool | `False` | Open browser after launch |
442
+
443
+ **Note:**
444
+ The gallery expects each scene folder to contain at least `scene.glb` and `scene.jpg`, with optional subfolders such as `depth_vis/` or `gs_video/`.
445
+
446
+ **Examples:**
447
+
448
+ ```bash
449
+ # 🖼️ Basic gallery server
450
+ da3 gallery --gallery-dir ./workspace
451
+
452
+ # 🌐 Custom host and port
453
+ da3 gallery \
454
+ --gallery-dir ./workspace \
455
+ --host 0.0.0.0 \
456
+ --port 8007
457
+
458
+ # 🚀 Auto-open browser
459
+ da3 gallery --gallery-dir ./workspace --open-browser
460
+ ```
461
+
462
+ ---
463
+
464
+ ## ⚙️ Parameter Details
465
+
466
+ ### 🔧 Common Parameters
467
+
468
+ - **`--export-dir`**: Output directory, defaults to `debug`
469
+ - **`--export-format`**: Export format, supports combining multiple formats with hyphens:
470
+ - 📦 `mini_npz`: Compressed NumPy format
471
+ - 🎨 `glb`: glTF binary format (3D scene)
472
+ - 🔍 `feat_vis`: Feature visualization
473
+ - Example: `mini_npz-glb` exports both formats
474
+
475
+ - **`--process-res`** / **`--process-res-method`**: Control preprocessing resolution strategy
476
+ - `process-res`: Target resolution (default 504)
477
+ - `process-res-method`: Resize method (default `upper_bound_resize`)
478
+
479
+ - **`--auto-cleanup`**: Remove existing export directory without confirmation
480
+
481
+ - **`--use-backend`** / **`--backend-url`**: Reuse running backend service
482
+ - ⚡ Reduces model loading time
483
+ - 🌐 Supports distributed processing
484
+
485
+ - **`--export-feat`**: Layer indices for exporting intermediate features (comma-separated)
486
+ - Example: `"9,19,29,39"`
487
+
488
+ ### 🎨 GLB Export Parameters
489
+
490
+ - **`--conf-thresh-percentile`**: Lower percentile for adaptive confidence threshold (default 40.0)
491
+ - Used to filter low-confidence points
492
+
493
+ - **`--num-max-points`**: Maximum number of points in point cloud (default 1,000,000)
494
+ - Controls output file size and performance
495
+
496
+ - **`--show-cameras`**: Show camera wireframes in exported scene (default True)
497
+
498
+ ### 🔍 Feature Visualization Parameters
499
+
500
+ - **`--feat-vis-fps`**: Frame rate for feature visualization output video (default 15)
501
+
502
+ ### 🎬 Video-Specific Parameters
503
+
504
+ - **`--fps`**: Video frame extraction sampling rate (default 1.0 FPS)
505
+ - Higher values extract more frames
506
+
507
+ ### 📐 COLMAP-Specific Parameters
508
+
509
+ - **`--sparse-subdir`**: Sparse reconstruction subdirectory
510
+ - Empty string uses `sparse/` directory
511
+ - `"0"` uses `sparse/0/` directory
512
+
513
+ - **`--align-to-input-ext-scale`**: Align prediction to input extrinsics scale (default True)
514
+ - Ensures depth estimation is consistent with COLMAP scale
515
+
516
+ ---
517
+
518
+ ## 💡 Usage Examples
519
+
520
+ ### 1️⃣ Basic Workflow
521
+
522
+ ```bash
523
+ # 🔧 Start backend service
524
+ da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE --host 0.0.0.0 --port 8008
525
+
526
+ # 🖼️ Process single image
527
+ da3 image image.jpg --export-dir ./output1 --use-backend
528
+
529
+ # 🎬 Process video
530
+ da3 video video.mp4 --fps 2.0 --export-dir ./output2 --use-backend
531
+
532
+ # 📐 Process COLMAP dataset
533
+ da3 colmap ./colmap_data --export-dir ./output3 --use-backend
534
+ ```
535
+
536
+ ### 2️⃣ Using Auto Mode
537
+
538
+ ```bash
539
+ # 🤖 Auto-detect and process
540
+ da3 auto ./unknown_input --export-dir ./output
541
+
542
+ # ⚡ With backend acceleration
543
+ da3 auto ./unknown_input \
544
+ --use-backend \
545
+ --backend-url http://localhost:8008 \
546
+ --export-dir ./output
547
+ ```
548
+
549
+ ### 3️⃣ Multi-Format Export
550
+
551
+ ```bash
552
+ # 📦 Export both NPZ and GLB formats
553
+ da3 auto assets/examples/SOH \
554
+ --export-format mini_npz-glb \
555
+ --export-dir ./workspace/soh
556
+
557
+ # 🔍 Export feature visualization
558
+ da3 image image.jpg \
559
+ --export-format feat_vis \
560
+ --export-feat "9,19,29,39" \
561
+ --export-dir ./results
562
+ ```
563
+
564
+ ### 4️⃣ Advanced Configuration
565
+
566
+ ```bash
567
+ # ⚙️ Custom resolution and point cloud density
568
+ da3 image image.jpg \
569
+ --process-res 1024 \
570
+ --num-max-points 2000000 \
571
+ --conf-thresh-percentile 30.0 \
572
+ --export-dir ./output
573
+
574
+ # 📐 COLMAP advanced options
575
+ da3 colmap ./colmap_data \
576
+ --sparse-subdir 0 \
577
+ --align-to-input-ext-scale \
578
+ --process-res 756 \
579
+ --export-dir ./output
580
+ ```
581
+
582
+ ### 5️⃣ Batch Processing Workflow
583
+
584
+ ```bash
585
+ # 🔧 Start backend
586
+ da3 backend \
587
+ --model-dir depth-anything/DA3NESTED-GIANT-LARGE \
588
+ --device cuda \
589
+ --host 0.0.0.0 \
590
+ --port 8008 \
591
+ --gallery-dir ./workspace
592
+
593
+ # 🔄 Batch process multiple scenes
594
+ for scene in scene1 scene2 scene3; do
595
+ da3 auto ./data/$scene \
596
+ --export-dir ./workspace/$scene \
597
+ --use-backend \
598
+ --auto-cleanup
599
+ done
600
+
601
+ # 🖼️ Launch gallery to view results
602
+ da3 gallery --gallery-dir ./workspace --open-browser
603
+ ```
604
+
605
+ ### 6️⃣ Web Applications
606
+
607
+ ```bash
608
+ # 🎨 Launch Gradio application
609
+ da3 gradio \
610
+ --model-dir depth-anything/DA3NESTED-GIANT-LARGE \
611
+ --workspace-dir workspace/gradio \
612
+ --gallery-dir ./gallery \
613
+ --host 0.0.0.0 \
614
+ --port 7860 \
615
+ --share
616
+ ```
617
+
618
+ ### 7️⃣ Transformer Feature Visualization
619
+
620
+ ```bash
621
+ # 🔍 Export Transformer features
622
+ # 📦 Combined with numerical output
623
+ da3 auto video.mp4 \
624
+ --export-format glb-feat_vis \
625
+ --export-feat "11,21,31" \
626
+ --export-dir ./debug \
627
+ --use-backend
628
+ ```
629
+
630
+ ---
631
+
632
+ ## 📝 Notes
633
+
634
+ 1. **🔧 Backend Service**: Recommended for processing multiple tasks to improve efficiency
635
+ 2. **💾 GPU Memory**: Be mindful of GPU memory usage when processing high-resolution inputs
636
+ 3. **📁 Export Directory**: Use `--auto-cleanup` to avoid manual confirmation for deletion
637
+ 4. **🔀 Format Combination**: Multiple export formats can be combined with hyphens (e.g., `mini_npz-glb-feat_vis`)
638
+ 5. **📐 COLMAP Data**: Ensure COLMAP directory structure is correct (contains `images/` and `sparse/` subdirectories)
639
+
640
+ ---
641
+
642
+ ## ❓ Getting Help
643
+
644
+ View detailed help for any command:
645
+
646
+ ```bash
647
+ # 📖 View main help
648
+ da3 --help
649
+
650
+ # 🔍 View specific command help
651
+ da3 auto --help
652
+ da3 image --help
653
+ da3 backend --help
654
+ ```
Depth-Anything-3/docs/funcs/ref_view_strategy.md ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📐 Reference View Selection Strategy
2
+
3
+ ## 📖 Overview
4
+
5
+ Reference view selection is a component in multi-view depth estimation. When processing multiple input views, the model needs to determine which view should serve as the primary reference frame for depth prediction, defining the world coordinate system.
6
+
7
+ Different reference view will leads to different reconstruction results. This is a known consideration in multi-view geometry and was analyzed in [PI3](https://arxiv.org/abs/2507.13347). The choice of reference view can affect the quality and consistency of depth predictions across the scene.
8
+
9
+
10
+ ## 🚀 Our Simple Solution: Automatic Reference View Selection
11
+
12
+ DA3 provides a simple approach to address this through **automatic reference view selection** based on **class tokens**. Instead of relying on heuristics or manual selection, the model analyzes the class token features from all input views and intelligently selects the most suitable reference frame.
13
+
14
+ ---
15
+
16
+ ## 🎨 Available Strategies
17
+
18
+ ### 1. ⚖️ `saddle_balanced` (Recommended, Default)
19
+
20
+ **Philosophy:**
21
+ Select a view that achieves balance across multiple feature metrics. This strategy looks for a "middle ground" view that is neither too similar nor too different from other views, making it a stable reference point.
22
+
23
+ **How it works:**
24
+ 1. Extracts and normalizes class tokens from all views
25
+ 2. Computes three complementary metrics for each view:
26
+ - **Similarity score**: Average cosine similarity with other views
27
+ - **Feature norm**: L2 norm of the original features
28
+ - **Feature variance**: Variance across feature dimensions
29
+ 3. Normalizes each metric to [0, 1] range
30
+ 4. Selects the view closest to 0.5 (median) across all three metrics
31
+
32
+ ### 2. 🎢 `saddle_sim_range`
33
+
34
+ **Philosophy:**
35
+ Select a view with the largest similarity range to other views. This identifies "saddle point" views that are highly similar to some views but dissimilar to others, making them information-rich anchor points.
36
+
37
+ **How it works:**
38
+ 1. Computes pairwise cosine similarity between all views
39
+ 2. For each view, calculates the range (max - min) of similarities to other views
40
+ 3. Selects the view with the maximum similarity range
41
+
42
+ ---
43
+
44
+ ### 3. 1️⃣ `first` (Not Recommended)
45
+
46
+ **Philosophy:**
47
+ Always use the first view in the input sequence as the reference.
48
+
49
+ **How it works:**
50
+ Simply returns index 0.
51
+
52
+ **When to use:**
53
+ - ⛔ **Not recommended** in general
54
+ - 🔧 Only use when you have manually pre-sorted your views and know the first view is optimal
55
+ - 🐛 Debugging or baseline comparisons
56
+
57
+ ---
58
+
59
+ ### 4. ⏸️ `middle`
60
+
61
+ **Philosophy:**
62
+ Select the view in the middle of the input sequence.
63
+
64
+ **How it works:**
65
+ Returns the view at index `S // 2` where S is the number of views.
66
+
67
+ **When to use:**
68
+ - ⏱️ **Only recommended when input images are temporally ordered**
69
+ - 🎬 Video sequences (e.g., **DA3-LONG** setting)
70
+ - 📹 Sequential captures where the middle frame likely has the most stable viewpoint
71
+
72
+ **Specific use case: DA3-LONG** 🎬
73
+ In video-based depth estimation scenarios (like DA3-LONG), where inputs are consecutive frames, `middle` is often the **optimal choice** because that it has maximum overlap with all other frames.
74
+
75
+
76
+ ## 💻 Usage
77
+
78
+ ### 🐍 Python API
79
+
80
+ ```python
81
+ from depth_anything_3 import DepthAnything3
82
+
83
+ model = DepthAnything3.from_pretrained("depth-anything/DA3NESTED-GIANT-LARGE")
84
+
85
+ # Use default (saddle_balanced)
86
+ prediction = model.inference(
87
+ images,
88
+ ref_view_strategy="saddle_balanced"
89
+ )
90
+
91
+ # For video sequences, consider using middle
92
+ prediction = model.inference(
93
+ video_frames,
94
+ ref_view_strategy="middle" # Good for temporal sequences
95
+ )
96
+
97
+ # For complex scenes with wide baselines
98
+ prediction = model.inference(
99
+ images,
100
+ ref_view_strategy="saddle_sim_range"
101
+ )
102
+ ```
103
+
104
+ ### 🖥️ Command Line Interface
105
+
106
+ ```bash
107
+ # Default (saddle_balanced)
108
+ da3 auto input/ --export-dir output/
109
+
110
+ # Explicitly specify strategy
111
+ da3 auto input/ --ref-view-strategy saddle_balanced
112
+
113
+ # For video processing
114
+ da3 video input.mp4 --ref-view-strategy middle
115
+
116
+ # For wide-baseline multi-view
117
+ da3 images captures/ --ref-view-strategy saddle_sim_range
118
+ ```
119
+
120
+ ---
121
+
122
+ ### 🎯 When Selection Is Applied
123
+
124
+ Reference view selection is applied when:
125
+ - 3️⃣ Number of views S ≥ 3
126
+
127
+ ---
128
+
129
+ ## 💡 Recommendations
130
+
131
+ ### 📋 Quick Guide
132
+
133
+ | Scenario | Recommended Strategy | Rationale |
134
+ |----------|---------------------|-----------|
135
+ | **Default / Unknown** | `saddle_balanced` | Robust, balanced, works well across diverse scenarios |
136
+ | **Video frames** | `middle` | Temporal coherence, stable middle frame |
137
+ | **Wide-baseline multi-view** | `saddle_sim_range` | Maximizes information coverage |
138
+ | **Pre-sorted inputs** | `first` | Use only if you've manually optimized ordering |
139
+ | **Single image** | `first` | Automatically used (no reordering needed for S ≤ 2) |
140
+
141
+ ### ✨ Best Practices
142
+
143
+ 1. 🎯 **Start with defaults**: `saddle_balanced` works well in most cases
144
+ 2. 🎬 **Consider your input type**: Use `middle` for videos, `saddle_balanced` for photos
145
+ 3. 🔬 **Experiment if needed**: Try different strategies if results are suboptimal
146
+ 4. 📊 **Monitor performance**: Check `glb` quality and consistency across views.
147
+
148
+ ---
149
+
150
+ ## 🔧 Technical Details
151
+
152
+ ### 🎚️ Selection Threshold
153
+
154
+ The reference view selection is only triggered when:
155
+ ```python
156
+ num_views >= 3 # At least 3 views required
157
+ ```
158
+
159
+ For 1-2 views, no reordering is performed (equivalent to using `first`).
160
+
161
+ ### ⚙️ Implementation
162
+
163
+ The selection happens at layer `alt_start - 1` in the vision transformer, before the first global attention layer. This ensures the selected reference view influences the entire depth prediction pipeline.
164
+
165
+ ---
166
+
167
+ ## ❓ FAQ
168
+
169
+ **Q: 🤔 Why is this feature provided?**
170
+ A: The model can handle any view order, but this feature provides automatic optimization for reference view selection, which can help improve depth prediction quality in multi-view scenarios.
171
+
172
+ **Q: ⏱️ Does this add computational cost?**
173
+ A: The overhead is totally negligible.
174
+
175
+ **Q: 🎮 Can I manually specify which view to use as reference?**
176
+ A: Not directly through this parameter. You can pre-sort your input images to place your preferred reference view first and use `ref_view_strategy="first"`.
177
+
178
+ **Q: ⚙️ What happens if I don't specify this parameter?**
179
+ A: The default `saddle_balanced` strategy is used automatically.
180
+
181
+ **Q: 📊 Is this feature used in the DA3 paper benchmarks?**
182
+ A: No, the paper used `first` as the default strategy for all multi-view experiments. The current default has been updated to `saddle_balanced` for better robustness.
183
+
Depth-Anything-3/notebooks/da3.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Depth-Anything-3/src/depth_anything_3/api.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Depth Anything 3 API module.
16
+
17
+ This module provides the main API for Depth Anything 3, including model loading,
18
+ inference, and export capabilities. It supports both single and nested model architectures.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import time
24
+ from typing import Optional, Sequence
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ from huggingface_hub import PyTorchModelHubMixin
29
+ from PIL import Image
30
+
31
+ from depth_anything_3.cfg import create_object, load_config
32
+ from depth_anything_3.registry import MODEL_REGISTRY
33
+ from depth_anything_3.specs import Prediction
34
+ from depth_anything_3.utils.export import export
35
+ from depth_anything_3.utils.geometry import affine_inverse
36
+ from depth_anything_3.utils.io.input_processor import InputProcessor
37
+ from depth_anything_3.utils.io.output_processor import OutputProcessor
38
+ from depth_anything_3.utils.logger import logger
39
+ from depth_anything_3.utils.pose_align import align_poses_umeyama
40
+
41
+ torch.backends.cudnn.benchmark = False
42
+ # logger.info("CUDNN Benchmark Disabled")
43
+
44
+ SAFETENSORS_NAME = "model.safetensors"
45
+ CONFIG_NAME = "config.json"
46
+
47
+
48
+ class DepthAnything3(nn.Module, PyTorchModelHubMixin):
49
+ """
50
+ Depth Anything 3 main API class.
51
+
52
+ This class provides a high-level interface for depth estimation using Depth Anything 3.
53
+ It supports both single and nested model architectures with metric scaling capabilities.
54
+
55
+ Features:
56
+ - Hugging Face Hub integration via PyTorchModelHubMixin
57
+ - Support for multiple model presets (vitb, vitg, nested variants)
58
+ - Automatic mixed precision inference
59
+ - Export capabilities for various formats (GLB, PLY, NPZ, etc.)
60
+ - Camera pose estimation and metric depth scaling
61
+
62
+ Usage:
63
+ # Load from Hugging Face Hub
64
+ model = DepthAnything3.from_pretrained("huggingface/model-name")
65
+
66
+ # Or create with specific preset
67
+ model = DepthAnything3(preset="vitg")
68
+
69
+ # Run inference
70
+ prediction = model.inference(images, export_dir="output", export_format="glb")
71
+ """
72
+
73
+ _commit_hash: str | None = None # Set by mixin when loading from Hub
74
+
75
+ def __init__(self, model_name: str = "da3-large", **kwargs):
76
+ """
77
+ Initialize DepthAnything3 with specified preset.
78
+
79
+ Args:
80
+ model_name: The name of the model preset to use.
81
+ Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'.
82
+ **kwargs: Additional keyword arguments (currently unused).
83
+ """
84
+ super().__init__()
85
+ self.model_name = model_name
86
+
87
+ # Build the underlying network
88
+ self.config = load_config(MODEL_REGISTRY[self.model_name])
89
+ self.model = create_object(self.config)
90
+ self.model.eval()
91
+
92
+ # Initialize processors
93
+ self.input_processor = InputProcessor()
94
+ self.output_processor = OutputProcessor()
95
+
96
+ # Device management (set by user)
97
+ self.device = None
98
+
99
+ @torch.inference_mode()
100
+ def forward(
101
+ self,
102
+ image: torch.Tensor,
103
+ extrinsics: torch.Tensor | None = None,
104
+ intrinsics: torch.Tensor | None = None,
105
+ export_feat_layers: list[int] | None = None,
106
+ infer_gs: bool = False,
107
+ use_ray_pose: bool = False,
108
+ ref_view_strategy: str = "saddle_balanced",
109
+ ) -> dict[str, torch.Tensor]:
110
+ """
111
+ Forward pass through the model.
112
+
113
+ Args:
114
+ image: Input batch with shape ``(B, N, 3, H, W)`` on the model device.
115
+ extrinsics: Optional camera extrinsics with shape ``(B, N, 4, 4)``.
116
+ intrinsics: Optional camera intrinsics with shape ``(B, N, 3, 3)``.
117
+ export_feat_layers: Layer indices to return intermediate features for.
118
+ infer_gs: Enable Gaussian Splatting branch.
119
+ use_ray_pose: Use ray-based pose estimation instead of camera decoder.
120
+ ref_view_strategy: Strategy for selecting reference view from multiple views.
121
+
122
+ Returns:
123
+ Dictionary containing model predictions
124
+ """
125
+ # Determine optimal autocast dtype
126
+ autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
127
+ with torch.no_grad():
128
+ with torch.autocast(device_type=image.device.type, dtype=autocast_dtype):
129
+ return self.model(
130
+ image, extrinsics, intrinsics, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy
131
+ )
132
+
133
+ def inference(
134
+ self,
135
+ image: list[np.ndarray | Image.Image | str],
136
+ extrinsics: np.ndarray | None = None,
137
+ intrinsics: np.ndarray | None = None,
138
+ align_to_input_ext_scale: bool = True,
139
+ infer_gs: bool = False,
140
+ use_ray_pose: bool = False,
141
+ ref_view_strategy: str = "saddle_balanced",
142
+ render_exts: np.ndarray | None = None,
143
+ render_ixts: np.ndarray | None = None,
144
+ render_hw: tuple[int, int] | None = None,
145
+ process_res: int = 504,
146
+ process_res_method: str = "upper_bound_resize",
147
+ export_dir: str | None = None,
148
+ export_format: str = "mini_npz",
149
+ export_feat_layers: Sequence[int] | None = None,
150
+ # GLB export parameters
151
+ conf_thresh_percentile: float = 40.0,
152
+ num_max_points: int = 1_000_000,
153
+ show_cameras: bool = True,
154
+ # Feat_vis export parameters
155
+ feat_vis_fps: int = 15,
156
+ # Other export parameters, e.g., gs_ply, gs_video
157
+ export_kwargs: Optional[dict] = {},
158
+ ) -> Prediction:
159
+ """
160
+ Run inference on input images.
161
+
162
+ Args:
163
+ image: List of input images (numpy arrays, PIL Images, or file paths)
164
+ extrinsics: Camera extrinsics (N, 4, 4)
165
+ intrinsics: Camera intrinsics (N, 3, 3)
166
+ align_to_input_ext_scale: whether to align the input pose scale to the prediction
167
+ infer_gs: Enable the 3D Gaussian branch (needed for `gs_ply`/`gs_video` exports)
168
+ use_ray_pose: Use ray-based pose estimation instead of camera decoder (default: False)
169
+ ref_view_strategy: Strategy for selecting reference view from multiple views.
170
+ Options: "first", "middle", "saddle_balanced", "saddle_sim_range".
171
+ Default: "saddle_balanced". For single view input (S ≤ 2), no reordering is performed.
172
+ render_exts: Optional render extrinsics for Gaussian video export
173
+ render_ixts: Optional render intrinsics for Gaussian video export
174
+ render_hw: Optional render resolution for Gaussian video export
175
+ process_res: Processing resolution
176
+ process_res_method: Resize method for processing
177
+ export_dir: Directory to export results
178
+ export_format: Export format (mini_npz, npz, glb, ply, gs, gs_video)
179
+ export_feat_layers: Layer indices to export intermediate features from
180
+ conf_thresh_percentile: [GLB] Lower percentile for adaptive confidence threshold (default: 40.0) # noqa: E501
181
+ num_max_points: [GLB] Maximum number of points in the point cloud (default: 1,000,000)
182
+ show_cameras: [GLB] Show camera wireframes in the exported scene (default: True)
183
+ feat_vis_fps: [FEAT_VIS] Frame rate for output video (default: 15)
184
+ export_kwargs: additional arguments to export functions.
185
+
186
+ Returns:
187
+ Prediction object containing depth maps and camera parameters
188
+ """
189
+ if "gs" in export_format:
190
+ assert infer_gs, "must set `infer_gs=True` to perform gs-related export."
191
+
192
+ if "colmap" in export_format:
193
+ assert isinstance(image[0], str), "`image` must be image paths for COLMAP export."
194
+
195
+ # Preprocess images
196
+ imgs_cpu, extrinsics, intrinsics = self._preprocess_inputs(
197
+ image, extrinsics, intrinsics, process_res, process_res_method
198
+ )
199
+
200
+ # Prepare tensors for model
201
+ imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics)
202
+
203
+ # Normalize extrinsics
204
+ ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None)
205
+
206
+ # Run model forward pass
207
+ export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else []
208
+
209
+ raw_output = self._run_model_forward(
210
+ imgs, ex_t_norm, in_t, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy
211
+ )
212
+
213
+ # Convert raw output to prediction
214
+ prediction = self._convert_to_prediction(raw_output)
215
+
216
+ # Align prediction to extrinsincs
217
+ prediction = self._align_to_input_extrinsics_intrinsics(
218
+ extrinsics, intrinsics, prediction, align_to_input_ext_scale
219
+ )
220
+
221
+ # Add processed images for visualization
222
+ prediction = self._add_processed_images(prediction, imgs_cpu)
223
+
224
+ # Export if requested
225
+ if export_dir is not None:
226
+
227
+ if "gs" in export_format:
228
+ if infer_gs and "gs_video" not in export_format:
229
+ export_format = f"{export_format}-gs_video"
230
+ if "gs_video" in export_format:
231
+ if "gs_video" not in export_kwargs:
232
+ export_kwargs["gs_video"] = {}
233
+ export_kwargs["gs_video"].update(
234
+ {
235
+ "extrinsics": render_exts,
236
+ "intrinsics": render_ixts,
237
+ "out_image_hw": render_hw,
238
+ }
239
+ )
240
+ # Add GLB export parameters
241
+ if "glb" in export_format:
242
+ if "glb" not in export_kwargs:
243
+ export_kwargs["glb"] = {}
244
+ export_kwargs["glb"].update(
245
+ {
246
+ "conf_thresh_percentile": conf_thresh_percentile,
247
+ "num_max_points": num_max_points,
248
+ "show_cameras": show_cameras,
249
+ }
250
+ )
251
+ # Add Feat_vis export parameters
252
+ if "feat_vis" in export_format:
253
+ if "feat_vis" not in export_kwargs:
254
+ export_kwargs["feat_vis"] = {}
255
+ export_kwargs["feat_vis"].update(
256
+ {
257
+ "fps": feat_vis_fps,
258
+ }
259
+ )
260
+ # Add COLMAP export parameters
261
+ if "colmap" in export_format:
262
+ if "colmap" not in export_kwargs:
263
+ export_kwargs["colmap"] = {}
264
+ export_kwargs["colmap"].update(
265
+ {
266
+ "image_paths": image,
267
+ "conf_thresh_percentile": conf_thresh_percentile,
268
+ "process_res_method": process_res_method,
269
+ }
270
+ )
271
+ self._export_results(prediction, export_format, export_dir, **export_kwargs)
272
+
273
+ return prediction
274
+
275
+ def _preprocess_inputs(
276
+ self,
277
+ image: list[np.ndarray | Image.Image | str],
278
+ extrinsics: np.ndarray | None = None,
279
+ intrinsics: np.ndarray | None = None,
280
+ process_res: int = 504,
281
+ process_res_method: str = "upper_bound_resize",
282
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
283
+ """Preprocess input images using input processor."""
284
+ start_time = time.time()
285
+ imgs_cpu, extrinsics, intrinsics = self.input_processor(
286
+ image,
287
+ extrinsics.copy() if extrinsics is not None else None,
288
+ intrinsics.copy() if intrinsics is not None else None,
289
+ process_res,
290
+ process_res_method,
291
+ )
292
+ end_time = time.time()
293
+ logger.info(
294
+ "Processed Images Done taking",
295
+ end_time - start_time,
296
+ "seconds. Shape: ",
297
+ imgs_cpu.shape,
298
+ )
299
+ return imgs_cpu, extrinsics, intrinsics
300
+
301
+ def _prepare_model_inputs(
302
+ self,
303
+ imgs_cpu: torch.Tensor,
304
+ extrinsics: torch.Tensor | None,
305
+ intrinsics: torch.Tensor | None,
306
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
307
+ """Prepare tensors for model input."""
308
+ device = self._get_model_device()
309
+
310
+ # Move images to model device
311
+ imgs = imgs_cpu.to(device, non_blocking=True)[None].float()
312
+
313
+ # Convert camera parameters to tensors
314
+ ex_t = (
315
+ extrinsics.to(device, non_blocking=True)[None].float()
316
+ if extrinsics is not None
317
+ else None
318
+ )
319
+ in_t = (
320
+ intrinsics.to(device, non_blocking=True)[None].float()
321
+ if intrinsics is not None
322
+ else None
323
+ )
324
+
325
+ return imgs, ex_t, in_t
326
+
327
+ def _normalize_extrinsics(self, ex_t: torch.Tensor | None) -> torch.Tensor | None:
328
+ """Normalize extrinsics"""
329
+ if ex_t is None:
330
+ return None
331
+ transform = affine_inverse(ex_t[:, :1])
332
+ ex_t_norm = ex_t @ transform
333
+ c2ws = affine_inverse(ex_t_norm)
334
+ translations = c2ws[..., :3, 3]
335
+ dists = translations.norm(dim=-1)
336
+ median_dist = torch.median(dists)
337
+ median_dist = torch.clamp(median_dist, min=1e-1)
338
+ ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist
339
+ return ex_t_norm
340
+
341
+ def _align_to_input_extrinsics_intrinsics(
342
+ self,
343
+ extrinsics: torch.Tensor | None,
344
+ intrinsics: torch.Tensor | None,
345
+ prediction: Prediction,
346
+ align_to_input_ext_scale: bool = True,
347
+ ransac_view_thresh: int = 10,
348
+ ) -> Prediction:
349
+ """Align depth map to input extrinsics"""
350
+ if extrinsics is None:
351
+ return prediction
352
+ prediction.intrinsics = intrinsics.numpy()
353
+ _, _, scale, aligned_extrinsics = align_poses_umeyama(
354
+ prediction.extrinsics,
355
+ extrinsics.numpy(),
356
+ ransac=len(extrinsics) >= ransac_view_thresh,
357
+ return_aligned=True,
358
+ random_state=42,
359
+ )
360
+ if align_to_input_ext_scale:
361
+ prediction.extrinsics = extrinsics[..., :3, :].numpy()
362
+ prediction.depth /= scale
363
+ else:
364
+ prediction.extrinsics = aligned_extrinsics
365
+ return prediction
366
+
367
+ def _run_model_forward(
368
+ self,
369
+ imgs: torch.Tensor,
370
+ ex_t: torch.Tensor | None,
371
+ in_t: torch.Tensor | None,
372
+ export_feat_layers: Sequence[int] | None = None,
373
+ infer_gs: bool = False,
374
+ use_ray_pose: bool = False,
375
+ ref_view_strategy: str = "saddle_balanced",
376
+ ) -> dict[str, torch.Tensor]:
377
+ """Run model forward pass."""
378
+ device = imgs.device
379
+ need_sync = device.type == "cuda"
380
+ if need_sync:
381
+ torch.cuda.synchronize(device)
382
+ start_time = time.time()
383
+ feat_layers = list(export_feat_layers) if export_feat_layers is not None else None
384
+ output = self.forward(imgs, ex_t, in_t, feat_layers, infer_gs, use_ray_pose, ref_view_strategy)
385
+ if need_sync:
386
+ torch.cuda.synchronize(device)
387
+ end_time = time.time()
388
+ logger.info(f"Model Forward Pass Done. Time: {end_time - start_time} seconds")
389
+ return output
390
+
391
+ def _convert_to_prediction(self, raw_output: dict[str, torch.Tensor]) -> Prediction:
392
+ """Convert raw model output to Prediction object."""
393
+ start_time = time.time()
394
+ output = self.output_processor(raw_output)
395
+ end_time = time.time()
396
+ logger.info(f"Conversion to Prediction Done. Time: {end_time - start_time} seconds")
397
+ return output
398
+
399
+ def _add_processed_images(self, prediction: Prediction, imgs_cpu: torch.Tensor) -> Prediction:
400
+ """Add processed images to prediction for visualization."""
401
+ # Convert from (N, 3, H, W) to (N, H, W, 3) and denormalize
402
+ processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy() # (N, H, W, 3)
403
+
404
+ # Denormalize from ImageNet normalization
405
+ mean = np.array([0.485, 0.456, 0.406])
406
+ std = np.array([0.229, 0.224, 0.225])
407
+ processed_imgs = processed_imgs * std + mean
408
+ processed_imgs = np.clip(processed_imgs, 0, 1)
409
+ processed_imgs = (processed_imgs * 255).astype(np.uint8)
410
+
411
+ prediction.processed_images = processed_imgs
412
+ return prediction
413
+
414
+ def _export_results(
415
+ self, prediction: Prediction, export_format: str, export_dir: str, **kwargs
416
+ ) -> None:
417
+ """Export results to specified format and directory."""
418
+ start_time = time.time()
419
+ export(prediction, export_format, export_dir, **kwargs)
420
+ end_time = time.time()
421
+ logger.info(f"Export Results Done. Time: {end_time - start_time} seconds")
422
+
423
+ def _get_model_device(self) -> torch.device:
424
+ """
425
+ Get the device where the model is located.
426
+
427
+ Returns:
428
+ Device where the model parameters are located
429
+
430
+ Raises:
431
+ ValueError: If no tensors are found in the model
432
+ """
433
+ if self.device is not None:
434
+ return self.device
435
+
436
+ # Find device from parameters
437
+ for param in self.parameters():
438
+ self.device = param.device
439
+ return param.device
440
+
441
+ # Find device from buffers
442
+ for buffer in self.buffers():
443
+ self.device = buffer.device
444
+ return buffer.device
445
+
446
+ raise ValueError("No tensor found in model")
Depth-Anything-3/src/depth_anything_3/app/css_and_html.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E501
2
+
3
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ CSS and HTML content for the Depth Anything 3 Gradio application.
19
+ This module contains all the CSS styles and HTML content blocks
20
+ used in the Gradio interface.
21
+ """
22
+
23
+ # CSS Styles for the Gradio interface
24
+ GRADIO_CSS = """
25
+ /* Add Font Awesome CDN with all styles including brands and colors */
26
+ @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css');
27
+
28
+ /* Add custom styles for colored icons */
29
+ .fa-color-blue {
30
+ color: #3b82f6;
31
+ }
32
+
33
+ .fa-color-purple {
34
+ color: #8b5cf6;
35
+ }
36
+
37
+ .fa-color-cyan {
38
+ color: #06b6d4;
39
+ }
40
+
41
+ .fa-color-green {
42
+ color: #10b981;
43
+ }
44
+
45
+ .fa-color-yellow {
46
+ color: #f59e0b;
47
+ }
48
+
49
+ .fa-color-red {
50
+ color: #ef4444;
51
+ }
52
+
53
+ .link-btn {
54
+ display: inline-flex;
55
+ align-items: center;
56
+ gap: 8px;
57
+ text-decoration: none;
58
+ padding: 12px 24px;
59
+ border-radius: 50px;
60
+ font-weight: 500;
61
+ transition: all 0.3s ease;
62
+ }
63
+
64
+ /* Dark mode tech theme */
65
+ @media (prefers-color-scheme: dark) {
66
+ html, body {
67
+ background: #1e293b;
68
+ color: #ffffff;
69
+ }
70
+
71
+ .gradio-container {
72
+ background: #1e293b;
73
+ color: #ffffff;
74
+ }
75
+
76
+ .link-btn {
77
+ background: rgba(255, 255, 255, 0.2);
78
+ color: white;
79
+ backdrop-filter: blur(10px);
80
+ border: 1px solid rgba(255, 255, 255, 0.3);
81
+ }
82
+
83
+ .link-btn:hover {
84
+ background: rgba(255, 255, 255, 0.3);
85
+ transform: translateY(-2px);
86
+ box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
87
+ }
88
+
89
+ .tech-bg {
90
+ background: linear-gradient(135deg, #0f172a, #1e293b); /* Darker colors */
91
+ position: relative;
92
+ overflow: hidden;
93
+ }
94
+
95
+ .tech-bg::before {
96
+ content: '';
97
+ position: absolute;
98
+ top: 0;
99
+ left: 0;
100
+ right: 0;
101
+ bottom: 0;
102
+ background:
103
+ radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
104
+ radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
105
+ radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.1) 0%, transparent 50%); /* Reduced opacity */
106
+ animation: techPulse 8s ease-in-out infinite;
107
+ }
108
+
109
+ .gradio-container .panel,
110
+ .gradio-container .block,
111
+ .gradio-container .form {
112
+ background: rgba(0, 0, 0, 0.3);
113
+ border: 1px solid rgba(59, 130, 246, 0.2);
114
+ border-radius: 10px;
115
+ }
116
+
117
+ .gradio-container * {
118
+ color: #ffffff;
119
+ }
120
+
121
+ .gradio-container label {
122
+ color: #e0e0e0;
123
+ }
124
+
125
+ .gradio-container .markdown {
126
+ color: #e0e0e0;
127
+ }
128
+ }
129
+
130
+ /* Light mode tech theme */
131
+ @media (prefers-color-scheme: light) {
132
+ html, body {
133
+ background: #ffffff;
134
+ color: #1e293b;
135
+ }
136
+
137
+ .gradio-container {
138
+ background: #ffffff;
139
+ color: #1e293b;
140
+ }
141
+
142
+ .tech-bg {
143
+ background: linear-gradient(135deg, #ffffff, #f1f5f9);
144
+ position: relative;
145
+ overflow: hidden;
146
+ }
147
+
148
+ .link-btn {
149
+ background: rgba(59, 130, 246, 0.15);
150
+ color: var(--body-text-color);
151
+ border: 1px solid rgba(59, 130, 246, 0.3);
152
+ }
153
+
154
+ .link-btn:hover {
155
+ background: rgba(59, 130, 246, 0.25);
156
+ transform: translateY(-2px);
157
+ box-shadow: 0 8px 25px rgba(59, 130, 246, 0.2);
158
+ }
159
+
160
+ .tech-bg::before {
161
+ content: '';
162
+ position: absolute;
163
+ top: 0;
164
+ left: 0;
165
+ right: 0;
166
+ bottom: 0;
167
+ background:
168
+ radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.1) 0%, transparent 50%),
169
+ radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.1) 0%, transparent 50%),
170
+ radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.08) 0%, transparent 50%);
171
+ animation: techPulse 8s ease-in-out infinite;
172
+ }
173
+
174
+ .gradio-container .panel,
175
+ .gradio-container .block,
176
+ .gradio-container .form {
177
+ background: rgba(255, 255, 255, 0.8);
178
+ border: 1px solid rgba(59, 130, 246, 0.3);
179
+ border-radius: 10px;
180
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
181
+ }
182
+
183
+ .gradio-container * {
184
+ color: #1e293b;
185
+ }
186
+
187
+ .gradio-container label {
188
+ color: #334155;
189
+ }
190
+
191
+ .gradio-container .markdown {
192
+ color: #334155;
193
+ }
194
+ }
195
+
196
+
197
+
198
+
199
+ @keyframes techPulse {
200
+ 0%, 100% { opacity: 0.5; }
201
+ 50% { opacity: 0.8; }
202
+ }
203
+
204
+ /* Custom log with tech gradient */
205
+ .custom-log * {
206
+ font-style: italic;
207
+ font-size: 22px !important;
208
+ background: linear-gradient(135deg, #3b82f6, #8b5cf6);
209
+ background-size: 400% 400%;
210
+ -webkit-background-clip: text;
211
+ background-clip: text;
212
+ font-weight: bold !important;
213
+ color: transparent !important;
214
+ text-align: center !important;
215
+ animation: techGradient 3s ease infinite;
216
+ }
217
+
218
+ @keyframes techGradient {
219
+ 0% { background-position: 0% 50%; }
220
+ 50% { background-position: 100% 50%; }
221
+ 100% { background-position: 0% 50%; }
222
+ }
223
+
224
+ @keyframes metricPulse {
225
+ 0%, 100% { background-position: 0% 50%; }
226
+ 50% { background-position: 100% 50%; }
227
+ }
228
+
229
+ @keyframes pointcloudPulse {
230
+ 0%, 100% { background-position: 0% 50%; }
231
+ 50% { background-position: 100% 50%; }
232
+ }
233
+
234
+ @keyframes camerasPulse {
235
+ 0%, 100% { background-position: 0% 50%; }
236
+ 50% { background-position: 100% 50%; }
237
+ }
238
+
239
+ @keyframes gaussiansPulse {
240
+ 0%, 100% { background-position: 0% 50%; }
241
+ 50% { background-position: 100% 50%; }
242
+ }
243
+
244
+ /* Special colors for key terms - Global styles */
245
+ .metric-text {
246
+ background: linear-gradient(45deg, #ff6b6b, #ff8e53, #ff6b6b);
247
+ background-size: 200% 200%;
248
+ -webkit-background-clip: text;
249
+ background-clip: text;
250
+ color: transparent !important;
251
+ animation: metricPulse 2s ease-in-out infinite;
252
+ font-weight: 700;
253
+ text-shadow: 0 0 10px rgba(255, 107, 107, 0.5);
254
+ }
255
+
256
+ .pointcloud-text {
257
+ background: linear-gradient(45deg, #4ecdc4, #44a08d, #4ecdc4);
258
+ background-size: 200% 200%;
259
+ -webkit-background-clip: text;
260
+ background-clip: text;
261
+ color: transparent !important;
262
+ animation: pointcloudPulse 2.5s ease-in-out infinite;
263
+ font-weight: 700;
264
+ text-shadow: 0 0 10px rgba(78, 205, 196, 0.5);
265
+ }
266
+
267
+ .cameras-text {
268
+ background: linear-gradient(45deg, #667eea, #764ba2, #667eea);
269
+ background-size: 200% 200%;
270
+ -webkit-background-clip: text;
271
+ background-clip: text;
272
+ color: transparent !important;
273
+ animation: camerasPulse 3s ease-in-out infinite;
274
+ font-weight: 700;
275
+ text-shadow: 0 0 10px rgba(102, 126, 234, 0.5);
276
+ }
277
+
278
+ .gaussians-text {
279
+ background: linear-gradient(45deg, #f093fb, #f5576c, #f093fb);
280
+ background-size: 200% 200%;
281
+ -webkit-background-clip: text;
282
+ background-clip: text;
283
+ color: transparent !important;
284
+ animation: gaussiansPulse 2.2s ease-in-out infinite;
285
+ font-weight: 700;
286
+ text-shadow: 0 0 10px rgba(240, 147, 251, 0.5);
287
+ }
288
+
289
+ .example-log * {
290
+ font-style: italic;
291
+ font-size: 16px !important;
292
+ background: linear-gradient(135deg, #3b82f6, #8b5cf6);
293
+ -webkit-background-clip: text;
294
+ background-clip: text;
295
+ color: transparent !important;
296
+ }
297
+
298
+ #my_radio .wrap {
299
+ display: flex;
300
+ flex-wrap: nowrap;
301
+ justify-content: center;
302
+ align-items: center;
303
+ }
304
+
305
+ #my_radio .wrap label {
306
+ display: flex;
307
+ width: 50%;
308
+ justify-content: center;
309
+ align-items: center;
310
+ margin: 0;
311
+ padding: 10px 0;
312
+ box-sizing: border-box;
313
+ }
314
+
315
+ /* Align navigation buttons with dropdown bottom */
316
+ .navigation-row {
317
+ display: flex !important;
318
+ align-items: flex-end !important;
319
+ gap: 8px !important;
320
+ }
321
+
322
+ .navigation-row > div:nth-child(1),
323
+ .navigation-row > div:nth-child(3) {
324
+ align-self: flex-end !important;
325
+ }
326
+
327
+ .navigation-row > div:nth-child(2) {
328
+ flex: 1 !important;
329
+ }
330
+
331
+ /* Make thumbnails clickable with pointer cursor */
332
+ .clickable-thumbnail img {
333
+ cursor: pointer !important;
334
+ }
335
+
336
+ .clickable-thumbnail:hover img {
337
+ cursor: pointer !important;
338
+ opacity: 0.8;
339
+ transition: opacity 0.3s ease;
340
+ }
341
+
342
+ /* Make thumbnail containers narrower horizontally */
343
+ .clickable-thumbnail {
344
+ padding: 5px 2px !important;
345
+ margin: 0 2px !important;
346
+ }
347
+
348
+ .clickable-thumbnail .image-container {
349
+ margin: 0 !important;
350
+ padding: 0 !important;
351
+ }
352
+
353
+ .scene-info {
354
+ text-align: center !important;
355
+ padding: 5px 2px !important;
356
+ margin: 0 !important;
357
+ }
358
+ """
359
+
360
+
361
+ def get_header_html(logo_base64=None):
362
+ """
363
+ Generate the main header HTML with logo and title.
364
+
365
+ Args:
366
+ logo_base64 (str, optional): Base64 encoded logo image
367
+
368
+ Returns:
369
+ str: HTML string for the header
370
+ """
371
+ return """
372
+ <div class="tech-bg" style="text-align: center; margin-bottom: 5px; padding: 40px 20px; border-radius: 15px; position: relative; overflow: hidden;">
373
+ <div style="position: relative; z-index: 2;">
374
+ <h1 style="margin: 0; font-size: 3.5em; font-weight: 700;
375
+ background: linear-gradient(135deg, #3b82f6, #8b5cf6);
376
+ background-size: 400% 400%;
377
+ -webkit-background-clip: text;
378
+ background-clip: text;
379
+ color: transparent;
380
+ animation: techGradient 3s ease infinite;
381
+ text-shadow: 0 0 30px rgba(59, 130, 246, 0.5);
382
+ letter-spacing: 2px;">
383
+ Depth Anything 3
384
+ </h1>
385
+ <p style="margin: 15px 0 0 0; font-size: 2.16em; font-weight: 300;" class="header-subtitle">
386
+ Recovering the Visual Space from Any Views
387
+ </p>
388
+ <div style="margin-top: 20px;">
389
+ <!-- Revert buttons to original inline styles -->
390
+ <a href="https://depth-anything-3.github.io" target="_blank" class="link-btn">
391
+ <i class="fas fa-globe" style="margin-right: 8px;"></i> Project Page
392
+ </a>
393
+ <a href="https://arxiv.org/abs/2406.09414" target="_blank" class="link-btn">
394
+ <i class="fas fa-file-pdf" style="margin-right: 8px;"></i> Paper
395
+ </a>
396
+ <a href="https://github.com/ByteDance-Seed/Depth-Anything-3" target="_blank" class="link-btn">
397
+ <i class="fab fa-github" style="margin-right: 8px;"></i> Code
398
+ </a>
399
+ </div>
400
+ </div>
401
+ </div>
402
+
403
+ <style>
404
+ /* Ensure tech-bg class is properly applied in dark mode */
405
+ @media (prefers-color-scheme: dark) {
406
+ .header-subtitle {
407
+ color: #cbd5e1;
408
+ }
409
+ /* Increase priority to ensure background color is properly applied */
410
+ .tech-bg {
411
+ background: linear-gradient(135deg, #0f172a, #1e293b) !important;
412
+ }
413
+ }
414
+
415
+ @media (prefers-color-scheme: light) {
416
+ .header-subtitle {
417
+ color: #475569;
418
+ }
419
+ /* Also add explicit background color for light mode */
420
+ .tech-bg {
421
+ background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%) !important;
422
+ }
423
+ }
424
+ </style>
425
+ """
426
+
427
+
428
+ def get_description_html():
429
+ """
430
+ Generate the main description and getting started HTML.
431
+
432
+ Returns:
433
+ str: HTML string for the description
434
+ """
435
+ return """
436
+ <div class="description-container" style="padding: 25px; border-radius: 15px; margin: 0 0 20px 0;">
437
+ <h2 class="description-title" style="margin-top: 0; font-size: 1.6em; text-align: center;">
438
+ <i class="fas fa-bullseye fa-color-red" style="margin-right: 8px;"></i> What This Demo Does
439
+ </h2>
440
+ <div class="description-content" style="padding: 20px; border-radius: 10px; margin: 15px 0; text-align: center;">
441
+ <p class="description-main" style="line-height: 1.6; margin: 0; font-size: 1.45em;">
442
+ <strong>Upload images or videos</strong> → <strong>Get <span class="metric-text">Metric</span> <span class="pointcloud-text">Point Clouds</span>, <span class="cameras-text">Cameras</span> and <span class="gaussians-text">Novel Views</span></strong> → <strong>Explore in 3D</strong>
443
+ </p>
444
+ </div>
445
+
446
+ <div style="text-align: center; margin-top: 15px;">
447
+ <p class="description-tip" style="font-style: italic; margin: 0;">
448
+ <i class="fas fa-lightbulb fa-color-yellow" style="margin-right: 8px;"></i> <strong>Tip:</strong> Landscape-oriented images or videos are preferred for best 3D recovering.
449
+ </p>
450
+ </div>
451
+ </div>
452
+
453
+ <style>
454
+ @media (prefers-color-scheme: dark) {
455
+ .description-container {
456
+ background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
457
+ border: 1px solid rgba(59, 130, 246, 0.2);
458
+ }
459
+ .description-title { color: #3b82f6; }
460
+ .description-content { background: rgba(0, 0, 0, 0.3); }
461
+ .description-main { color: #e0e0e0; }
462
+ .description-text { color: #cbd5e1; }
463
+ .description-tip { color: #cbd5e1; }
464
+ }
465
+
466
+ @media (prefers-color-scheme: light) {
467
+ .description-container {
468
+ background: linear-gradient(135deg, rgba(59, 130, 246, 0.05) 0%, rgba(139, 92, 246, 0.05) 100%);
469
+ border: 1px solid rgba(59, 130, 246, 0.3);
470
+ }
471
+ .description-title { color: #3b82f6; }
472
+ .description-content { background: transparent; }
473
+ .description-main { color: #1e293b; }
474
+ .description-text { color: #475569; }
475
+ .description-tip { color: #475569; }
476
+ }
477
+ </style>
478
+ """
479
+
480
+
481
+ def get_acknowledgements_html():
482
+ """
483
+ Generate the acknowledgements section HTML.
484
+
485
+ Returns:
486
+ str: HTML string for the acknowledgements
487
+ """
488
+ return """
489
+ <div style="background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
490
+ padding: 25px; border-radius: 15px; margin: 20px 0; border: 1px solid rgba(59, 130, 246, 0.2);">
491
+ <h3 style="color: #3b82f6; margin-top: 0; text-align: center; font-size: 1.4em;">
492
+ <i class="fas fa-trophy fa-color-yellow" style="margin-right: 8px;"></i> Research Credits & Acknowledgments
493
+ </h3>
494
+
495
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin: 15px 0;">
496
+ <!-- Original Research Section (Left) -->
497
+ <div style="text-align: center;">
498
+ <h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-flask fa-color-green" style="margin-right: 8px;"></i> Original Research</h4>
499
+ <p style="color: #e0e0e0; margin: 5px 0;">
500
+ <a href="https://depth-anything-3.github.io" target="_blank"
501
+ style="color: #3b82f6; text-decoration: none; font-weight: 600;">
502
+ Depth Anything 3
503
+ </a>
504
+ </p>
505
+ </div>
506
+
507
+ <!-- Previous Versions Section (Right) -->
508
+ <div style="text-align: center;">
509
+ <h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-history fa-color-blue" style="margin-right: 8px;"></i> Previous Versions</h4>
510
+ <div style="display: flex; flex-direction: row; gap: 15px; justify-content: center; align-items: center;">
511
+ <p style="color: #e0e0e0; margin: 0;">
512
+ <a href="https://huggingface.co/spaces/LiheYoung/Depth-Anything" target="_blank"
513
+ style="color: #3b82f6; text-decoration: none; font-weight: 600;">
514
+ Depth-Anything
515
+ </a>
516
+ </p>
517
+ <span style="color: #e0e0e0;">•</span>
518
+ <p style="color: #e0e0e0; margin: 0;">
519
+ <a href="https://huggingface.co/spaces/depth-anything/Depth-Anything-V2" target="_blank"
520
+ style="color: #3b82f6; text-decoration: none; font-weight: 600;">
521
+ Depth-Anything-V2
522
+ </a>
523
+ </p>
524
+ </div>
525
+ </div>
526
+ </div>
527
+
528
+ <!-- HF Demo Adapted from - Centered at the bottom of the whole block -->
529
+ <div style="margin-top: 20px; padding-top: 15px; border-top: 1px solid rgba(59, 130, 246, 0.3); text-align: center;">
530
+ <p style="color: #a0a0a0; font-size: 0.9em; margin: 0;">
531
+ <i class="fas fa-code-branch fa-color-gray" style="margin-right: 5px;"></i> HF demo adapted from <a href="https://huggingface.co/spaces/facebook/map-anything" target="_blank" style="color: inherit; text-decoration: none;">Map Anything</a>
532
+ </p>
533
+ </div>
534
+ </div>
535
+ """
536
+
537
+
538
+ def get_gradio_theme():
539
+ """
540
+ Get the configured Gradio theme with adaptive tech colors.
541
+
542
+ Returns:
543
+ gr.themes.Base: Configured Gradio theme
544
+ """
545
+ import gradio as gr
546
+
547
+ return gr.themes.Base(
548
+ primary_hue=gr.themes.Color(
549
+ c50="#eff6ff",
550
+ c100="#dbeafe",
551
+ c200="#bfdbfe",
552
+ c300="#93c5fd",
553
+ c400="#60a5fa",
554
+ c500="#3b82f6",
555
+ c600="#2563eb",
556
+ c700="#1d4ed8",
557
+ c800="#1e40af",
558
+ c900="#1e3a8a",
559
+ c950="#172554",
560
+ ),
561
+ secondary_hue=gr.themes.Color(
562
+ c50="#f5f3ff",
563
+ c100="#ede9fe",
564
+ c200="#ddd6fe",
565
+ c300="#c4b5fd",
566
+ c400="#a78bfa",
567
+ c500="#8b5cf6",
568
+ c600="#7c3aed",
569
+ c700="#6d28d9",
570
+ c800="#5b21b6",
571
+ c900="#4c1d95",
572
+ c950="#2e1065",
573
+ ),
574
+ neutral_hue=gr.themes.Color(
575
+ c50="#f8fafc",
576
+ c100="#f1f5f9",
577
+ c200="#e2e8f0",
578
+ c300="#cbd5e1",
579
+ c400="#94a3b8",
580
+ c500="#64748b",
581
+ c600="#475569",
582
+ c700="#334155",
583
+ c800="#1e293b",
584
+ c900="#0f172a",
585
+ c950="#020617",
586
+ ),
587
+ )
588
+
589
+
590
+ # Measure tab instructions HTML
591
+ MEASURE_INSTRUCTIONS_HTML = """
592
+ ### Click points on the image to compute distance.
593
+ > <i class="fas fa-triangle-exclamation fa-color-red" style="margin-right: 5px;"></i> Metric scale estimation is difficult on aerial/drone images.
594
+ """
Depth-Anything-3/src/depth_anything_3/app/gradio_app.py ADDED
@@ -0,0 +1,724 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Refactored Gradio App for Depth Anything 3.
17
+
18
+ This is the main application file that orchestrates all components.
19
+ The original functionality has been split into modular components for better maintainability.
20
+ """
21
+
22
+ import argparse
23
+ import os
24
+ from typing import Any, Dict, List
25
+ import gradio as gr
26
+
27
+ from depth_anything_3.app.css_and_html import GRADIO_CSS, get_gradio_theme
28
+ from depth_anything_3.app.modules.event_handlers import EventHandlers
29
+ from depth_anything_3.app.modules.ui_components import UIComponents
30
+
31
+ # Set environment variables
32
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
33
+
34
+
35
+ class DepthAnything3App:
36
+ """
37
+ Main application class for Depth Anything 3 Gradio app.
38
+ """
39
+
40
+ def __init__(self, model_dir: str = None, workspace_dir: str = None, gallery_dir: str = None):
41
+ """
42
+ Initialize the application.
43
+
44
+ Args:
45
+ model_dir: Path to the model directory
46
+ workspace_dir: Path to the workspace directory
47
+ gallery_dir: Path to the gallery directory
48
+ """
49
+ self.model_dir = model_dir
50
+ self.workspace_dir = workspace_dir
51
+ self.gallery_dir = gallery_dir
52
+
53
+ # Set environment variables for directories
54
+ if self.model_dir:
55
+ os.environ["DA3_MODEL_DIR"] = self.model_dir
56
+ if self.workspace_dir:
57
+ os.environ["DA3_WORKSPACE_DIR"] = self.workspace_dir
58
+ if self.gallery_dir:
59
+ os.environ["DA3_GALLERY_DIR"] = self.gallery_dir
60
+
61
+ self.event_handlers = EventHandlers()
62
+ self.ui_components = UIComponents()
63
+
64
+ def cache_examples(
65
+ self,
66
+ show_cam: bool = True,
67
+ filter_black_bg: bool = False,
68
+ filter_white_bg: bool = False,
69
+ save_percentage: float = 20.0,
70
+ num_max_points: int = 1000,
71
+ cache_gs_tag: str = "",
72
+ gs_trj_mode: str = "smooth",
73
+ gs_video_quality: str = "low",
74
+ ) -> None:
75
+ """
76
+ Pre-cache all example scenes at startup.
77
+
78
+ Args:
79
+ show_cam: Whether to show camera in visualization
80
+ filter_black_bg: Whether to filter black background
81
+ filter_white_bg: Whether to filter white background
82
+ save_percentage: Filter percentage for point cloud
83
+ num_max_points: Maximum number of points
84
+ cache_gs_tag: Tag to match scene names for high-res+3DGS caching (e.g., "dl3dv")
85
+ gs_trj_mode: Trajectory mode for 3DGS
86
+ gs_video_quality: Video quality for 3DGS
87
+ """
88
+ from depth_anything_3.app.modules.utils import get_scene_info
89
+
90
+ examples_dir = os.path.join(self.workspace_dir, "examples")
91
+ if not os.path.exists(examples_dir):
92
+ print(f"Examples directory not found: {examples_dir}")
93
+ return
94
+
95
+ scenes = get_scene_info(examples_dir)
96
+ if not scenes:
97
+ print("No example scenes found to cache.")
98
+ return
99
+
100
+ print(f"\n{'='*60}")
101
+ print(f"Caching {len(scenes)} example scenes...")
102
+ print(f"{'='*60}\n")
103
+
104
+ for i, scene in enumerate(scenes, 1):
105
+ scene_name = scene["name"]
106
+
107
+ # Check if scene name matches the gs tag for high-res+3DGS caching
108
+ use_high_res_gs = cache_gs_tag and cache_gs_tag.lower() in scene_name.lower()
109
+
110
+ if use_high_res_gs:
111
+ print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (HIGH-RES + 3DGS)")
112
+ print(f" - Number of images: {scene['num_images']}")
113
+ print(f" - Matched tag: '{cache_gs_tag}' - using high_res + 3DGS")
114
+ else:
115
+ print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (LOW-RES)")
116
+ print(f" - Number of images: {scene['num_images']}")
117
+
118
+ try:
119
+ # Load example scene
120
+ _, target_dir, _, _, _, _, _, _, _ = self.event_handlers.load_example_scene(
121
+ scene_name
122
+ )
123
+
124
+ if target_dir and target_dir != "None":
125
+ # Run reconstruction with appropriate settings
126
+ print(" - Running reconstruction...")
127
+ result = self.event_handlers.gradio_demo(
128
+ target_dir=target_dir,
129
+ show_cam=show_cam,
130
+ filter_black_bg=filter_black_bg,
131
+ filter_white_bg=filter_white_bg,
132
+ process_res_method="high_res" if use_high_res_gs else "low_res",
133
+ save_percentage=save_percentage,
134
+ num_max_points=num_max_points,
135
+ infer_gs=use_high_res_gs,
136
+ ref_view_strategy="saddle_balanced",
137
+ gs_trj_mode=gs_trj_mode,
138
+ gs_video_quality=gs_video_quality,
139
+ )
140
+
141
+ # Check if successful
142
+ if result[0] is not None: # reconstruction_output
143
+ print(f" ✓ Scene '{scene_name}' cached successfully")
144
+ else:
145
+ print(f" ✗ Scene '{scene_name}' caching failed: {result[1]}")
146
+ else:
147
+ print(f" ✗ Scene '{scene_name}' loading failed")
148
+
149
+ except Exception as e:
150
+ print(f" ✗ Error caching scene '{scene_name}': {str(e)}")
151
+
152
+ print()
153
+
154
+ print("=" * 60)
155
+ print("Example scene caching completed!")
156
+ print("=" * 60 + "\n")
157
+
158
+ def create_app(self) -> gr.Blocks:
159
+ """
160
+ Create and configure the Gradio application.
161
+
162
+ Returns:
163
+ Configured Gradio Blocks interface
164
+ """
165
+
166
+ # Initialize theme
167
+ def get_theme():
168
+ return get_gradio_theme()
169
+
170
+ with gr.Blocks(theme=get_theme(), css=GRADIO_CSS) as demo:
171
+ # State variables for the tabbed interface
172
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
173
+ processed_data_state = gr.State(value=None)
174
+ measure_points_state = gr.State(value=[])
175
+ selected_image_index_state = gr.State(value=0) # Track selected image index
176
+ # current_view_index = gr.State(value=0) # noqa: F841 Track current view index
177
+
178
+ # Header and description
179
+ self.ui_components.create_header_section()
180
+ self.ui_components.create_description_section()
181
+
182
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
183
+
184
+ # Main content area
185
+ with gr.Row():
186
+ with gr.Column(scale=2):
187
+ # Upload section
188
+ (
189
+ input_video,
190
+ s_time_interval,
191
+ input_images,
192
+ image_gallery,
193
+ ) = self.ui_components.create_upload_section()
194
+
195
+ with gr.Column(scale=4):
196
+ with gr.Column():
197
+ # gr.Markdown("**Metric 3D Reconstruction (Point Cloud and Camera Poses)**")
198
+ # Reconstruction control section (buttons) - moved below tabs
199
+
200
+ log_output = gr.Markdown(
201
+ "Please upload a video or images, then click Reconstruct.",
202
+ elem_classes=["custom-log"],
203
+ )
204
+
205
+ # Tabbed interface
206
+ with gr.Tabs():
207
+ with gr.Tab("Point Cloud & Cameras"):
208
+ reconstruction_output = (
209
+ self.ui_components.create_3d_viewer_section()
210
+ )
211
+
212
+ with gr.Tab("Metric Depth"):
213
+ (
214
+ prev_measure_btn,
215
+ measure_view_selector,
216
+ next_measure_btn,
217
+ measure_image,
218
+ measure_depth_image,
219
+ measure_text,
220
+ ) = self.ui_components.create_measure_section()
221
+
222
+ with gr.Tab("3DGS Rendered Novel Views"):
223
+ gs_video, gs_info = self.ui_components.create_nvs_video()
224
+
225
+ # Inference control section (before inference)
226
+ (process_res_method_dropdown, infer_gs, ref_view_strategy_dropdown) = (
227
+ self.ui_components.create_inference_control_section()
228
+ )
229
+
230
+ # Display control section - includes 3DGS options, buttons, and Visualization Options # noqa: E501
231
+ (
232
+ show_cam,
233
+ filter_black_bg,
234
+ filter_white_bg,
235
+ save_percentage,
236
+ num_max_points,
237
+ gs_trj_mode,
238
+ gs_video_quality,
239
+ submit_btn,
240
+ clear_btn,
241
+ ) = self.ui_components.create_display_control_section()
242
+
243
+ # bind visibility of gs_trj_mode to infer_gs
244
+ infer_gs.change(
245
+ fn=lambda checked: (
246
+ gr.update(visible=checked),
247
+ gr.update(visible=checked),
248
+ gr.update(visible=checked),
249
+ gr.update(visible=(not checked)),
250
+ ),
251
+ inputs=infer_gs,
252
+ outputs=[gs_trj_mode, gs_video_quality, gs_video, gs_info],
253
+ )
254
+
255
+ # Example scenes section
256
+ gr.Markdown("## Example Scenes")
257
+
258
+ scenes = self.ui_components.create_example_scenes_section()
259
+ scene_components = self.ui_components.create_example_scene_grid(scenes)
260
+
261
+ # Set up event handlers
262
+ self._setup_event_handlers(
263
+ demo,
264
+ is_example,
265
+ processed_data_state,
266
+ measure_points_state,
267
+ target_dir_output,
268
+ input_video,
269
+ input_images,
270
+ s_time_interval,
271
+ image_gallery,
272
+ reconstruction_output,
273
+ log_output,
274
+ show_cam,
275
+ filter_black_bg,
276
+ filter_white_bg,
277
+ process_res_method_dropdown,
278
+ save_percentage,
279
+ submit_btn,
280
+ clear_btn,
281
+ num_max_points,
282
+ infer_gs,
283
+ ref_view_strategy_dropdown,
284
+ selected_image_index_state,
285
+ measure_view_selector,
286
+ measure_image,
287
+ measure_depth_image,
288
+ measure_text,
289
+ prev_measure_btn,
290
+ next_measure_btn,
291
+ scenes,
292
+ scene_components,
293
+ gs_video,
294
+ gs_info,
295
+ gs_trj_mode,
296
+ gs_video_quality,
297
+ )
298
+
299
+ # Acknowledgements
300
+ self.ui_components.create_acknowledgements_section()
301
+
302
+ return demo
303
+
304
+ def _setup_event_handlers(
305
+ self,
306
+ demo: gr.Blocks,
307
+ is_example: gr.Textbox,
308
+ processed_data_state: gr.State,
309
+ measure_points_state: gr.State,
310
+ target_dir_output: gr.Textbox,
311
+ input_video: gr.Video,
312
+ input_images: gr.File,
313
+ s_time_interval: gr.Slider,
314
+ image_gallery: gr.Gallery,
315
+ reconstruction_output: gr.Model3D,
316
+ log_output: gr.Markdown,
317
+ show_cam: gr.Checkbox,
318
+ filter_black_bg: gr.Checkbox,
319
+ filter_white_bg: gr.Checkbox,
320
+ process_res_method_dropdown: gr.Dropdown,
321
+ save_percentage: gr.Slider,
322
+ submit_btn: gr.Button,
323
+ clear_btn: gr.ClearButton,
324
+ num_max_points: gr.Slider,
325
+ infer_gs: gr.Checkbox,
326
+ ref_view_strategy_dropdown: gr.Dropdown,
327
+ selected_image_index_state: gr.State,
328
+ measure_view_selector: gr.Dropdown,
329
+ measure_image: gr.Image,
330
+ measure_depth_image: gr.Image,
331
+ measure_text: gr.Markdown,
332
+ prev_measure_btn: gr.Button,
333
+ next_measure_btn: gr.Button,
334
+ scenes: List[Dict[str, Any]],
335
+ scene_components: List[gr.Image],
336
+ gs_video: gr.Video,
337
+ gs_info: gr.Markdown,
338
+ gs_trj_mode: gr.Dropdown,
339
+ gs_video_quality: gr.Dropdown,
340
+ ) -> None:
341
+ """
342
+ Set up all event handlers for the application.
343
+
344
+ Args:
345
+ demo: Gradio Blocks interface
346
+ All other arguments: Gradio components to connect
347
+ """
348
+ # Configure clear button
349
+ clear_btn.add(
350
+ [
351
+ input_video,
352
+ input_images,
353
+ reconstruction_output,
354
+ log_output,
355
+ target_dir_output,
356
+ image_gallery,
357
+ gs_video,
358
+ ]
359
+ )
360
+
361
+ # Main reconstruction button
362
+ submit_btn.click(
363
+ fn=self.event_handlers.clear_fields, inputs=[], outputs=[reconstruction_output]
364
+ ).then(fn=self.event_handlers.update_log, inputs=[], outputs=[log_output]).then(
365
+ fn=self.event_handlers.gradio_demo,
366
+ inputs=[
367
+ target_dir_output,
368
+ show_cam,
369
+ filter_black_bg,
370
+ filter_white_bg,
371
+ process_res_method_dropdown,
372
+ save_percentage,
373
+ # pass num_max_points
374
+ num_max_points,
375
+ infer_gs,
376
+ ref_view_strategy_dropdown,
377
+ gs_trj_mode,
378
+ gs_video_quality,
379
+ ],
380
+ outputs=[
381
+ reconstruction_output,
382
+ log_output,
383
+ processed_data_state,
384
+ measure_image,
385
+ measure_depth_image,
386
+ measure_text,
387
+ measure_view_selector,
388
+ gs_video,
389
+ gs_video, # gs_video visibility
390
+ gs_info, # gs_info visibility
391
+ ],
392
+ ).then(
393
+ fn=lambda: "False",
394
+ inputs=[],
395
+ outputs=[is_example], # set is_example to "False"
396
+ )
397
+
398
+ # Real-time visualization updates
399
+ self._setup_visualization_handlers(
400
+ show_cam,
401
+ filter_black_bg,
402
+ filter_white_bg,
403
+ process_res_method_dropdown,
404
+ target_dir_output,
405
+ is_example,
406
+ reconstruction_output,
407
+ log_output,
408
+ )
409
+
410
+ # File upload handlers
411
+ input_video.change(
412
+ fn=self.event_handlers.handle_uploads,
413
+ inputs=[input_video, input_images, s_time_interval],
414
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
415
+ )
416
+ input_images.change(
417
+ fn=self.event_handlers.handle_uploads,
418
+ inputs=[input_video, input_images, s_time_interval],
419
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
420
+ )
421
+
422
+ # Navigation handlers
423
+ self._setup_navigation_handlers(
424
+ prev_measure_btn,
425
+ next_measure_btn,
426
+ measure_view_selector,
427
+ measure_image,
428
+ measure_depth_image,
429
+ measure_points_state,
430
+ processed_data_state,
431
+ )
432
+
433
+ # Measurement handler
434
+ measure_image.select(
435
+ fn=self.event_handlers.measure,
436
+ inputs=[processed_data_state, measure_points_state, measure_view_selector],
437
+ outputs=[measure_image, measure_depth_image, measure_points_state, measure_text],
438
+ )
439
+
440
+ # Example scene handlers
441
+ self._setup_example_scene_handlers(
442
+ scenes,
443
+ scene_components,
444
+ reconstruction_output,
445
+ target_dir_output,
446
+ image_gallery,
447
+ log_output,
448
+ is_example,
449
+ processed_data_state,
450
+ measure_view_selector,
451
+ measure_image,
452
+ measure_depth_image,
453
+ gs_video,
454
+ gs_info,
455
+ )
456
+
457
+ def _setup_visualization_handlers(
458
+ self,
459
+ show_cam: gr.Checkbox,
460
+ filter_black_bg: gr.Checkbox,
461
+ filter_white_bg: gr.Checkbox,
462
+ process_res_method_dropdown: gr.Dropdown,
463
+ target_dir_output: gr.Textbox,
464
+ is_example: gr.Textbox,
465
+ reconstruction_output: gr.Model3D,
466
+ log_output: gr.Markdown,
467
+ ) -> None:
468
+ """Set up visualization update handlers."""
469
+ # Common inputs for visualization updates
470
+ viz_inputs = [
471
+ target_dir_output,
472
+ show_cam,
473
+ is_example,
474
+ filter_black_bg,
475
+ filter_white_bg,
476
+ process_res_method_dropdown,
477
+ ]
478
+
479
+ # Set up change handlers for all visualization controls
480
+ for component in [show_cam, filter_black_bg, filter_white_bg]:
481
+ component.change(
482
+ fn=self.event_handlers.update_visualization,
483
+ inputs=viz_inputs,
484
+ outputs=[reconstruction_output, log_output],
485
+ )
486
+
487
+ def _setup_navigation_handlers(
488
+ self,
489
+ prev_measure_btn: gr.Button,
490
+ next_measure_btn: gr.Button,
491
+ measure_view_selector: gr.Dropdown,
492
+ measure_image: gr.Image,
493
+ measure_depth_image: gr.Image,
494
+ measure_points_state: gr.State,
495
+ processed_data_state: gr.State,
496
+ ) -> None:
497
+ """Set up navigation handlers for measure tab."""
498
+ # Measure tab navigation
499
+ prev_measure_btn.click(
500
+ fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
501
+ processed_data, current_selector, -1
502
+ ),
503
+ inputs=[processed_data_state, measure_view_selector],
504
+ outputs=[
505
+ measure_view_selector,
506
+ measure_image,
507
+ measure_depth_image,
508
+ measure_points_state,
509
+ ],
510
+ )
511
+
512
+ next_measure_btn.click(
513
+ fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
514
+ processed_data, current_selector, 1
515
+ ),
516
+ inputs=[processed_data_state, measure_view_selector],
517
+ outputs=[
518
+ measure_view_selector,
519
+ measure_image,
520
+ measure_depth_image,
521
+ measure_points_state,
522
+ ],
523
+ )
524
+
525
+ measure_view_selector.change(
526
+ fn=lambda processed_data, selector_value: (
527
+ self.event_handlers.update_measure_view(
528
+ processed_data, int(selector_value.split()[1]) - 1
529
+ )
530
+ if selector_value
531
+ else (None, None, [])
532
+ ),
533
+ inputs=[processed_data_state, measure_view_selector],
534
+ outputs=[measure_image, measure_depth_image, measure_points_state],
535
+ )
536
+
537
+ def _setup_example_scene_handlers(
538
+ self,
539
+ scenes: List[Dict[str, Any]],
540
+ scene_components: List[gr.Image],
541
+ reconstruction_output: gr.Model3D,
542
+ target_dir_output: gr.Textbox,
543
+ image_gallery: gr.Gallery,
544
+ log_output: gr.Markdown,
545
+ is_example: gr.Textbox,
546
+ processed_data_state: gr.State,
547
+ measure_view_selector: gr.Dropdown,
548
+ measure_image: gr.Image,
549
+ measure_depth_image: gr.Image,
550
+ gs_video: gr.Video,
551
+ gs_info: gr.Markdown,
552
+ ) -> None:
553
+ """Set up example scene handlers."""
554
+
555
+ def load_and_update_measure(name):
556
+ result = self.event_handlers.load_example_scene(name)
557
+ # result = (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501
558
+
559
+ # Update measure view if processed_data is available
560
+ measure_img = None
561
+ measure_depth = None
562
+ if result[4] is not None: # processed_data exists
563
+ measure_img, measure_depth, _ = (
564
+ self.event_handlers.visualization_handler.update_measure_view(result[4], 0)
565
+ )
566
+
567
+ return result + ("True", measure_img, measure_depth)
568
+
569
+ for i, scene in enumerate(scenes):
570
+ if i < len(scene_components):
571
+ scene_components[i].select(
572
+ fn=lambda name=scene["name"]: load_and_update_measure(name),
573
+ outputs=[
574
+ reconstruction_output,
575
+ target_dir_output,
576
+ image_gallery,
577
+ log_output,
578
+ processed_data_state,
579
+ measure_view_selector,
580
+ gs_video,
581
+ gs_video, # gs_video_visibility
582
+ gs_info, # gs_info_visibility
583
+ is_example,
584
+ measure_image,
585
+ measure_depth_image,
586
+ ],
587
+ )
588
+
589
+ def launch(self, host: str = "127.0.0.1", port: int = 7860, **kwargs) -> None:
590
+ """
591
+ Launch the application.
592
+
593
+ Args:
594
+ host: Host address to bind to
595
+ port: Port number to bind to
596
+ **kwargs: Additional arguments for demo.launch()
597
+ """
598
+ demo = self.create_app()
599
+ demo.queue(max_size=20).launch(
600
+ show_error=True, ssr_mode=False, server_name=host, server_port=port, **kwargs
601
+ )
602
+
603
+
604
+ def main():
605
+ """Main function to run the application."""
606
+ parser = argparse.ArgumentParser(
607
+ description="Depth Anything 3 Gradio Application",
608
+ formatter_class=argparse.RawDescriptionHelpFormatter,
609
+ epilog="""
610
+ Examples:
611
+ # Basic usage
612
+ python gradio_app.py --help
613
+ python gradio_app.py --host 0.0.0.0 --port 8080
614
+ python gradio_app.py --model-dir /path/to/model --workspace-dir /path/to/workspace
615
+
616
+ # Cache examples at startup (all low-res)
617
+ python gradio_app.py --cache-examples
618
+
619
+ # Cache with selective high-res+3DGS for scenes matching tag
620
+ python gradio_app.py --cache-examples --cache-gs-tag dl3dv
621
+ # This will use high-res + 3DGS for scenes containing "dl3dv" in their name,
622
+ # and low-res only for other scenes
623
+ """,
624
+ )
625
+
626
+ # Server configuration
627
+ parser.add_argument(
628
+ "--host", default="127.0.0.1", help="Host address to bind to (default: 127.0.0.1)"
629
+ )
630
+ parser.add_argument(
631
+ "--port", type=int, default=7860, help="Port number to bind to (default: 7860)"
632
+ )
633
+
634
+ # Directory configuration
635
+ parser.add_argument(
636
+ "--model-dir",
637
+ default="depth-anything/DA3NESTED-GIANT-LARGE",
638
+ help="Path to the model directory (default: depth-anything/DA3NESTED-GIANT-LARGE)",
639
+ )
640
+ parser.add_argument(
641
+ "--workspace-dir",
642
+ default="workspace/gradio", # noqa: E501
643
+ help="Path to the workspace directory (default: workspace/gradio)", # noqa: E501
644
+ )
645
+ parser.add_argument(
646
+ "--gallery-dir",
647
+ default="workspace/gallery",
648
+ help="Path to the gallery directory (default: workspace/gallery)", # noqa: E501
649
+ )
650
+
651
+ # Additional Gradio options
652
+ parser.add_argument("--share", action="store_true", help="Create a public link for the app")
653
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
654
+
655
+ # Example caching options
656
+ parser.add_argument(
657
+ "--cache-examples",
658
+ action="store_true",
659
+ help="Pre-cache all example scenes at startup for faster loading",
660
+ )
661
+ parser.add_argument(
662
+ "--cache-gs-tag",
663
+ type=str,
664
+ default="",
665
+ help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.", # noqa: E501
666
+ )
667
+
668
+ args = parser.parse_args()
669
+
670
+ # Create directories if they don't exist
671
+ os.makedirs(args.workspace_dir, exist_ok=True)
672
+ os.makedirs(args.gallery_dir, exist_ok=True)
673
+
674
+ # Initialize and launch the application
675
+ app = DepthAnything3App(
676
+ model_dir=args.model_dir, workspace_dir=args.workspace_dir, gallery_dir=args.gallery_dir
677
+ )
678
+
679
+ # Prepare launch arguments
680
+ launch_kwargs = {"share": args.share, "debug": args.debug}
681
+
682
+ print("Starting Depth Anything 3 Gradio App...")
683
+ print(f"Host: {args.host}")
684
+ print(f"Port: {args.port}")
685
+ print(f"Model Directory: {args.model_dir}")
686
+ print(f"Workspace Directory: {args.workspace_dir}")
687
+ print(f"Gallery Directory: {args.gallery_dir}")
688
+ print(f"Share: {args.share}")
689
+ print(f"Debug: {args.debug}")
690
+ print(f"Cache Examples: {args.cache_examples}")
691
+ if args.cache_examples:
692
+ if args.cache_gs_tag:
693
+ print(
694
+ f"Cache GS Tag: '{args.cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)" # noqa: E501
695
+ ) # noqa: E501
696
+ else:
697
+ print("Cache GS Tag: None (all scenes will use low-res only)")
698
+
699
+ # Pre-cache examples if requested
700
+ if args.cache_examples:
701
+ print("\n" + "=" * 60)
702
+ print("Pre-caching mode enabled")
703
+ if args.cache_gs_tag:
704
+ print(f"Scenes containing '{args.cache_gs_tag}' will use HIGH-RES + 3DGS")
705
+ print("Other scenes will use LOW-RES only")
706
+ else:
707
+ print("All scenes will use LOW-RES only")
708
+ print("=" * 60)
709
+ app.cache_examples(
710
+ show_cam=True,
711
+ filter_black_bg=False,
712
+ filter_white_bg=False,
713
+ save_percentage=5.0,
714
+ num_max_points=1000,
715
+ cache_gs_tag=args.cache_gs_tag,
716
+ gs_trj_mode="smooth",
717
+ gs_video_quality="low",
718
+ )
719
+
720
+ app.launch(host=args.host, port=args.port, **launch_kwargs)
721
+
722
+
723
+ if __name__ == "__main__":
724
+ main()
Depth-Anything-3/src/depth_anything_3/app/modules/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Modules package for Depth Anything 3 Gradio app.
17
+
18
+ This package contains all the modular components for the Gradio application.
19
+ """
20
+
21
+ from depth_anything_3.app.modules.event_handlers import EventHandlers
22
+ from depth_anything_3.app.modules.file_handlers import FileHandler
23
+ from depth_anything_3.app.modules.model_inference import ModelInference
24
+ from depth_anything_3.app.modules.ui_components import UIComponents
25
+ from depth_anything_3.app.modules.utils import (
26
+ create_depth_visualization,
27
+ get_logo_base64,
28
+ get_scene_info,
29
+ save_to_gallery_func,
30
+ )
31
+ from depth_anything_3.app.modules.visualization import VisualizationHandler
32
+
33
+ __all__ = [
34
+ "ModelInference",
35
+ "FileHandler",
36
+ "VisualizationHandler",
37
+ "EventHandlers",
38
+ "UIComponents",
39
+ "create_depth_visualization",
40
+ "save_to_gallery_func",
41
+ "get_scene_info",
42
+ "get_logo_base64",
43
+ ]
Depth-Anything-3/src/depth_anything_3/app/modules/event_handlers.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Event handling module for Depth Anything 3 Gradio app.
17
+
18
+ This module handles all event callbacks and user interactions.
19
+ """
20
+
21
+ import os
22
+ import time
23
+ from glob import glob
24
+ from typing import Any, Dict, List, Optional, Tuple
25
+ import gradio as gr
26
+ import numpy as np
27
+ import torch
28
+
29
+ from depth_anything_3.app.modules.file_handlers import FileHandler
30
+ from depth_anything_3.app.modules.model_inference import ModelInference
31
+ from depth_anything_3.utils.memory import cleanup_cuda_memory
32
+ from depth_anything_3.app.modules.visualization import VisualizationHandler
33
+
34
+
35
+ class EventHandlers:
36
+ """
37
+ Handles all event callbacks and user interactions for the Gradio app.
38
+ """
39
+
40
+ def __init__(self):
41
+ """Initialize the event handlers."""
42
+ self.model_inference = ModelInference()
43
+ self.file_handler = FileHandler()
44
+ self.visualization_handler = VisualizationHandler()
45
+
46
+ def clear_fields(self) -> None:
47
+ """
48
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
49
+ """
50
+ return None
51
+
52
+ def update_log(self) -> str:
53
+ """
54
+ Display a quick log message while waiting.
55
+ """
56
+ return "Loading and Reconstructing..."
57
+
58
+ def save_current_visualization(
59
+ self,
60
+ target_dir: str,
61
+ save_percentage: float,
62
+ show_cam: bool,
63
+ filter_black_bg: bool,
64
+ filter_white_bg: bool,
65
+ processed_data: Optional[Dict],
66
+ scene_name: str = "",
67
+ ) -> str:
68
+ """
69
+ Save current visualization results to gallery with specified save percentage.
70
+
71
+ Args:
72
+ target_dir: Directory containing results
73
+ save_percentage: Percentage of points to save (0-100)
74
+ show_cam: Whether to show cameras
75
+ filter_black_bg: Whether to filter black background
76
+ filter_white_bg: Whether to filter white background
77
+ processed_data: Processed data from reconstruction
78
+
79
+ Returns:
80
+ Status message
81
+ """
82
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
83
+ return "No reconstruction available. Please run 'Reconstruct' first."
84
+
85
+ if processed_data is None:
86
+ return "No processed data available. Please run 'Reconstruct' first."
87
+
88
+ try:
89
+ # Add debug information
90
+ print("[DEBUG] save_current_visualization called with:")
91
+ print(f" target_dir: {target_dir}")
92
+ print(f" save_percentage: {save_percentage}")
93
+ print(f" show_cam: {show_cam}")
94
+ print(f" filter_black_bg: {filter_black_bg}")
95
+ print(f" filter_white_bg: {filter_white_bg}")
96
+ print(f" processed_data: {processed_data is not None}")
97
+
98
+ # Import the gallery save function
99
+ # Create gallery name with user input or auto-generated
100
+ import datetime
101
+
102
+ from .utils import save_to_gallery_func
103
+
104
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
105
+ if scene_name and scene_name.strip():
106
+ gallery_name = f"{scene_name.strip()}_{timestamp}_pct{save_percentage:.0f}"
107
+ else:
108
+ gallery_name = f"save_{timestamp}_pct{save_percentage:.0f}"
109
+
110
+ print(f"[DEBUG] Saving to gallery with name: {gallery_name}")
111
+
112
+ # Save entire process folder to gallery
113
+ success, message = save_to_gallery_func(
114
+ target_dir=target_dir, processed_data=processed_data, gallery_name=gallery_name
115
+ )
116
+
117
+ if success:
118
+ print(f"[DEBUG] Gallery save completed successfully: {message}")
119
+ return (
120
+ "Successfully saved to gallery!\n"
121
+ f"Gallery name: {gallery_name}\n"
122
+ f"Save percentage: {save_percentage}%\n"
123
+ f"Show cameras: {show_cam}\n"
124
+ f"Filter black bg: {filter_black_bg}\n"
125
+ f"Filter white bg: {filter_white_bg}\n\n"
126
+ f"{message}"
127
+ )
128
+ else:
129
+ print(f"[DEBUG] Gallery save failed: {message}")
130
+ return f"Failed to save to gallery: {message}"
131
+
132
+ except Exception as e:
133
+ return f"Error saving visualization: {str(e)}"
134
+
135
+ def gradio_demo(
136
+ self,
137
+ target_dir: str,
138
+ show_cam: bool = True,
139
+ filter_black_bg: bool = False,
140
+ filter_white_bg: bool = False,
141
+ process_res_method: str = "upper_bound_resize",
142
+ save_percentage: float = 30.0,
143
+ num_max_points: int = 1_000_000,
144
+ infer_gs: bool = False,
145
+ ref_view_strategy: str = "saddle_balanced",
146
+ gs_trj_mode: str = "extend",
147
+ gs_video_quality: str = "high",
148
+ ) -> Tuple[
149
+ Optional[str],
150
+ str,
151
+ Optional[Dict],
152
+ Optional[np.ndarray],
153
+ Optional[np.ndarray],
154
+ str,
155
+ gr.Dropdown,
156
+ Optional[str], # gs video path
157
+ gr.update, # gs video visibility update
158
+ gr.update, # gs info visibility update
159
+ ]:
160
+ """
161
+ Perform reconstruction using the already-created target_dir/images.
162
+
163
+ Args:
164
+ target_dir: Directory containing images
165
+ show_cam: Whether to show camera
166
+ filter_black_bg: Whether to filter black background
167
+ filter_white_bg: Whether to filter white background
168
+ process_res_method: Method for resizing input images
169
+ save_percentage: Filter percentage for point cloud
170
+ num_max_points: Maximum number of points
171
+ infer_gs: Whether to infer 3D Gaussian Splatting
172
+ ref_view_strategy: Reference view selection strategy
173
+
174
+ Returns:
175
+ Tuple of reconstruction results
176
+ """
177
+ if not os.path.isdir(target_dir) or target_dir == "None":
178
+ return (
179
+ None,
180
+ "No valid target directory found. Please upload first.",
181
+ None,
182
+ None,
183
+ None,
184
+ "",
185
+ None,
186
+ None,
187
+ gr.update(visible=False), # gs_video
188
+ gr.update(visible=True), # gs_info
189
+ )
190
+
191
+ start_time = time.time()
192
+ cleanup_cuda_memory()
193
+
194
+ # Get image files for logging
195
+ target_dir_images = os.path.join(target_dir, "images")
196
+ all_files = (
197
+ sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
198
+ )
199
+
200
+ print("Running DepthAnything3 model...")
201
+ print(f"Reference view strategy: {ref_view_strategy}")
202
+
203
+ with torch.no_grad():
204
+ prediction, processed_data = self.model_inference.run_inference(
205
+ target_dir,
206
+ process_res_method=process_res_method,
207
+ show_camera=show_cam,
208
+ save_percentage=save_percentage,
209
+ num_max_points=int(num_max_points * 1000), # Convert K to actual count
210
+ infer_gs=infer_gs,
211
+ ref_view_strategy=ref_view_strategy,
212
+ gs_trj_mode=gs_trj_mode,
213
+ gs_video_quality=gs_video_quality,
214
+ )
215
+
216
+ # The GLB file is already generated by the API
217
+ glbfile = os.path.join(target_dir, "scene.glb")
218
+
219
+ # Handle 3DGS video based on infer_gs flag
220
+ gsvideo_path = None
221
+ gs_video_visible = False
222
+ gs_info_visible = True
223
+
224
+ if infer_gs:
225
+ try:
226
+ gsvideo_path = sorted(glob(os.path.join(target_dir, "gs_video", "*.mp4")))[-1]
227
+ gs_video_visible = True
228
+ gs_info_visible = False
229
+ except IndexError:
230
+ gsvideo_path = None
231
+ print("3DGS video not found, but infer_gs was enabled")
232
+
233
+ # Cleanup
234
+ cleanup_cuda_memory()
235
+
236
+ end_time = time.time()
237
+ print(f"Total time: {end_time - start_time:.2f} seconds")
238
+ log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
239
+
240
+ # Populate visualization tabs with processed data
241
+ depth_vis, measure_img, measure_depth_vis, measure_pts = (
242
+ self.visualization_handler.populate_visualization_tabs(processed_data)
243
+ )
244
+
245
+ # Update view selectors based on available views
246
+ depth_selector, measure_selector = self.visualization_handler.update_view_selectors(
247
+ processed_data
248
+ )
249
+
250
+ return (
251
+ glbfile,
252
+ log_msg,
253
+ processed_data,
254
+ measure_img, # measure_image
255
+ measure_depth_vis, # measure_depth_image
256
+ "", # measure_text (empty initially)
257
+ measure_selector, # measure_view_selector
258
+ gsvideo_path,
259
+ gr.update(visible=gs_video_visible), # gs_video visibility
260
+ gr.update(visible=gs_info_visible), # gs_info visibility
261
+ )
262
+
263
+ def update_visualization(
264
+ self,
265
+ target_dir: str,
266
+ show_cam: bool,
267
+ is_example: str,
268
+ filter_black_bg: bool = False,
269
+ filter_white_bg: bool = False,
270
+ process_res_method: str = "upper_bound_resize",
271
+ ) -> Tuple[gr.update, str]:
272
+ """
273
+ Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
274
+ and return it for the 3D viewer.
275
+
276
+ Args:
277
+ target_dir: Directory containing results
278
+ show_cam: Whether to show camera
279
+ is_example: Whether this is an example scene
280
+ filter_black_bg: Whether to filter black background
281
+ filter_white_bg: Whether to filter white background
282
+ process_res_method: Method for resizing input images
283
+
284
+ Returns:
285
+ Tuple of (glb_file, log_message)
286
+ """
287
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
288
+ return (
289
+ gr.update(),
290
+ "No reconstruction available. Please click the Reconstruct button first.",
291
+ )
292
+
293
+ # Check if GLB exists (could be cached example or reconstructed scene)
294
+ glbfile = os.path.join(target_dir, "scene.glb")
295
+ if os.path.exists(glbfile):
296
+ return (
297
+ glbfile,
298
+ (
299
+ "Visualization loaded from cache."
300
+ if is_example == "True"
301
+ else "Visualization updated."
302
+ ),
303
+ )
304
+
305
+ # If no GLB but it's an example that hasn't been reconstructed yet
306
+ if is_example == "True":
307
+ return (
308
+ gr.update(),
309
+ "No reconstruction available. Please click the Reconstruct button first.",
310
+ )
311
+
312
+ # For non-examples, check predictions.npz
313
+ predictions_path = os.path.join(target_dir, "predictions.npz")
314
+ if not os.path.exists(predictions_path):
315
+ error_message = (
316
+ f"No reconstruction available at {predictions_path}. "
317
+ "Please run 'Reconstruct' first."
318
+ )
319
+ return gr.update(), error_message
320
+
321
+ loaded = np.load(predictions_path, allow_pickle=True)
322
+ predictions = {key: loaded[key] for key in loaded.keys()} # noqa: F841
323
+
324
+ return (
325
+ glbfile,
326
+ "Visualization updated.",
327
+ )
328
+
329
+ def handle_uploads(
330
+ self,
331
+ input_video: Optional[str],
332
+ input_images: Optional[List],
333
+ s_time_interval: float = 10.0,
334
+ ) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
335
+ """
336
+ Handle file uploads and update gallery.
337
+
338
+ Args:
339
+ input_video: Path to input video file
340
+ input_images: List of input image files
341
+ s_time_interval: Sampling FPS (frames per second) for frame extraction
342
+
343
+ Returns:
344
+ Tuple of (reconstruction_output, target_dir, image_paths, log_message)
345
+ """
346
+ return self.file_handler.update_gallery_on_upload(
347
+ input_video, input_images, s_time_interval
348
+ )
349
+
350
+ def load_example_scene(self, scene_name: str, examples_dir: str = None) -> Tuple[
351
+ Optional[str],
352
+ Optional[str],
353
+ Optional[List],
354
+ str,
355
+ Optional[Dict],
356
+ gr.Dropdown,
357
+ Optional[str],
358
+ gr.update,
359
+ gr.update,
360
+ ]:
361
+ """
362
+ Load a scene from examples directory.
363
+
364
+ Args:
365
+ scene_name: Name of the scene to load
366
+ examples_dir: Path to examples directory (if None, uses workspace_dir/examples)
367
+
368
+ Returns:
369
+ Tuple of (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501
370
+ """
371
+ if examples_dir is None:
372
+ # Get workspace directory from environment variable
373
+ workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
374
+ examples_dir = os.path.join(workspace_dir, "examples")
375
+
376
+ reconstruction_output, target_dir, image_paths, log_message = (
377
+ self.file_handler.load_example_scene(scene_name, examples_dir)
378
+ )
379
+
380
+ # Try to load cached processed data if available
381
+ processed_data = None
382
+ measure_view_selector = gr.Dropdown(choices=["View 1"], value="View 1")
383
+ gs_video_path = None
384
+ gs_video_visible = False
385
+ gs_info_visible = True
386
+
387
+ if target_dir and target_dir != "None":
388
+ predictions_path = os.path.join(target_dir, "predictions.npz")
389
+ if os.path.exists(predictions_path):
390
+ try:
391
+ # Load predictions from cache
392
+ loaded = np.load(predictions_path, allow_pickle=True)
393
+ predictions = {key: loaded[key] for key in loaded.keys()}
394
+
395
+ # Reconstruct processed_data structure
396
+ num_images = len(predictions.get("images", []))
397
+ processed_data = {}
398
+
399
+ for i in range(num_images):
400
+ processed_data[i] = {
401
+ "image": predictions["images"][i] if "images" in predictions else None,
402
+ "depth": predictions["depths"][i] if "depths" in predictions else None,
403
+ "depth_image": os.path.join(
404
+ target_dir, "depth_vis", f"{i:04d}.jpg" # Fixed: use .jpg not .png
405
+ ),
406
+ "intrinsics": (
407
+ predictions["intrinsics"][i]
408
+ if "intrinsics" in predictions
409
+ and i < len(predictions["intrinsics"])
410
+ else None
411
+ ),
412
+ "mask": None,
413
+ }
414
+
415
+ # Update measure view selector
416
+ choices = [f"View {i + 1}" for i in range(num_images)]
417
+ measure_view_selector = gr.Dropdown(choices=choices, value=choices[0])
418
+
419
+ except Exception as e:
420
+ print(f"Error loading cached data: {e}")
421
+
422
+ # Check for cached 3DGS video
423
+ gs_video_dir = os.path.join(target_dir, "gs_video")
424
+ if os.path.exists(gs_video_dir):
425
+ try:
426
+ from glob import glob
427
+
428
+ gs_videos = sorted(glob(os.path.join(gs_video_dir, "*.mp4")))
429
+ if gs_videos:
430
+ gs_video_path = gs_videos[-1]
431
+ gs_video_visible = True
432
+ gs_info_visible = False
433
+ print(f"Loaded cached 3DGS video: {gs_video_path}")
434
+ except Exception as e:
435
+ print(f"Error loading cached 3DGS video: {e}")
436
+
437
+ return (
438
+ reconstruction_output,
439
+ target_dir,
440
+ image_paths,
441
+ log_message,
442
+ processed_data,
443
+ measure_view_selector,
444
+ gs_video_path,
445
+ gr.update(visible=gs_video_visible),
446
+ gr.update(visible=gs_info_visible),
447
+ )
448
+
449
+ def navigate_depth_view(
450
+ self,
451
+ processed_data: Optional[Dict[int, Dict[str, Any]]],
452
+ current_selector: str,
453
+ direction: int,
454
+ ) -> Tuple[str, Optional[str]]:
455
+ """
456
+ Navigate depth view.
457
+
458
+ Args:
459
+ processed_data: Processed data dictionary
460
+ current_selector: Current selector value
461
+ direction: Direction to navigate
462
+
463
+ Returns:
464
+ Tuple of (new_selector_value, depth_vis)
465
+ """
466
+ return self.visualization_handler.navigate_depth_view(
467
+ processed_data, current_selector, direction
468
+ )
469
+
470
+ def update_depth_view(
471
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
472
+ ) -> Optional[str]:
473
+ """
474
+ Update depth view for a specific view index.
475
+
476
+ Args:
477
+ processed_data: Processed data dictionary
478
+ view_index: Index of the view to update
479
+
480
+ Returns:
481
+ Path to depth visualization image or None
482
+ """
483
+ return self.visualization_handler.update_depth_view(processed_data, view_index)
484
+
485
+ def navigate_measure_view(
486
+ self,
487
+ processed_data: Optional[Dict[int, Dict[str, Any]]],
488
+ current_selector: str,
489
+ direction: int,
490
+ ) -> Tuple[str, Optional[np.ndarray], Optional[np.ndarray], List]:
491
+ """
492
+ Navigate measure view.
493
+
494
+ Args:
495
+ processed_data: Processed data dictionary
496
+ current_selector: Current selector value
497
+ direction: Direction to navigate
498
+
499
+ Returns:
500
+ Tuple of (new_selector_value, measure_image, depth_right_half, measure_points)
501
+ """
502
+ return self.visualization_handler.navigate_measure_view(
503
+ processed_data, current_selector, direction
504
+ )
505
+
506
+ def update_measure_view(
507
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
508
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
509
+ """
510
+ Update measure view for a specific view index.
511
+
512
+ Args:
513
+ processed_data: Processed data dictionary
514
+ view_index: Index of the view to update
515
+
516
+ Returns:
517
+ Tuple of (measure_image, depth_right_half, measure_points)
518
+ """
519
+ return self.visualization_handler.update_measure_view(processed_data, view_index)
520
+
521
+ def measure(
522
+ self,
523
+ processed_data: Optional[Dict[int, Dict[str, Any]]],
524
+ measure_points: List,
525
+ current_view_selector: str,
526
+ event: gr.SelectData,
527
+ ) -> List:
528
+ """
529
+ Handle measurement on images.
530
+
531
+ Args:
532
+ processed_data: Processed data dictionary
533
+ measure_points: List of current measure points
534
+ current_view_selector: Current view selector value
535
+ event: Gradio select event
536
+
537
+ Returns:
538
+ List of [image, depth_right_half, measure_points, text]
539
+ """
540
+ return self.visualization_handler.measure(
541
+ processed_data, measure_points, current_view_selector, event
542
+ )
543
+
544
+ def select_first_frame(
545
+ self, image_gallery: List, selected_index: int = 0
546
+ ) -> Tuple[List, str, str]:
547
+ """
548
+ Select the first frame from the image gallery.
549
+
550
+ Args:
551
+ image_gallery: List of images in the gallery
552
+ selected_index: Index of the selected image (default: 0)
553
+
554
+ Returns:
555
+ Tuple of (updated_image_gallery, log_message, selected_frame_path)
556
+ """
557
+ try:
558
+ if not image_gallery or len(image_gallery) == 0:
559
+ return image_gallery, "No images available to select as first frame.", ""
560
+
561
+ # Handle None or invalid selected_index
562
+ if (
563
+ selected_index is None
564
+ or selected_index < 0
565
+ or selected_index >= len(image_gallery)
566
+ ):
567
+ selected_index = 0
568
+ print(f"Invalid selected_index: {selected_index}, using default: 0")
569
+
570
+ # Get the selected image based on index
571
+ selected_image = image_gallery[selected_index]
572
+ print(f"Selected image index: {selected_index}")
573
+ print(f"Total images: {len(image_gallery)}")
574
+
575
+ # Extract the file path from the selected image
576
+ selected_frame_path = ""
577
+ print(f"Selected image type: {type(selected_image)}")
578
+ print(f"Selected image: {selected_image}")
579
+
580
+ if isinstance(selected_image, tuple):
581
+ # Gradio Gallery returns tuple (path, None)
582
+ selected_frame_path = selected_image[0]
583
+ elif isinstance(selected_image, str):
584
+ selected_frame_path = selected_image
585
+ elif hasattr(selected_image, "name"):
586
+ selected_frame_path = selected_image.name
587
+ elif isinstance(selected_image, dict):
588
+ if "name" in selected_image:
589
+ selected_frame_path = selected_image["name"]
590
+ elif "path" in selected_image:
591
+ selected_frame_path = selected_image["path"]
592
+ elif "src" in selected_image:
593
+ selected_frame_path = selected_image["src"]
594
+ else:
595
+ # Try to convert to string
596
+ selected_frame_path = str(selected_image)
597
+
598
+ print(f"Extracted path: {selected_frame_path}")
599
+
600
+ # Extract filename from the path for matching
601
+ import os
602
+
603
+ selected_filename = os.path.basename(selected_frame_path)
604
+ print(f"Selected filename: {selected_filename}")
605
+
606
+ # Move the selected image to the front
607
+ updated_gallery = [selected_image] + [
608
+ img for img in image_gallery if img != selected_image
609
+ ]
610
+
611
+ log_message = (
612
+ f"Selected frame: {selected_filename}. "
613
+ f"Moved to first position. Total frames: {len(updated_gallery)}"
614
+ )
615
+ return updated_gallery, log_message, selected_filename
616
+
617
+ except Exception as e:
618
+ print(f"Error selecting first frame: {e}")
619
+ return image_gallery, f"Error selecting first frame: {e}", ""
Depth-Anything-3/src/depth_anything_3/app/modules/file_handlers.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ File handling module for Depth Anything 3 Gradio app.
17
+
18
+ This module handles file uploads, video processing, and file operations.
19
+ """
20
+
21
+ import os
22
+ import shutil
23
+ import time
24
+ from datetime import datetime
25
+ from typing import List, Optional, Tuple
26
+ import cv2
27
+ from PIL import Image
28
+ from pillow_heif import register_heif_opener
29
+
30
+ register_heif_opener()
31
+
32
+
33
+ class FileHandler:
34
+ """
35
+ Handles file uploads and processing for the Gradio app.
36
+ """
37
+
38
+ def __init__(self):
39
+ """Initialize the file handler."""
40
+
41
+ def handle_uploads(
42
+ self,
43
+ input_video: Optional[str],
44
+ input_images: Optional[List],
45
+ s_time_interval: float = 10.0,
46
+ ) -> Tuple[str, List[str]]:
47
+ """
48
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
49
+ images or extracted frames from video into it.
50
+
51
+ Args:
52
+ input_video: Path to input video file
53
+ input_images: List of input image files
54
+ s_time_interval: Sampling FPS (frames per second) for frame extraction
55
+
56
+ Returns:
57
+ Tuple of (target_dir, image_paths)
58
+ """
59
+ start_time = time.time()
60
+
61
+ # Get workspace directory from environment variable or use default
62
+ workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
63
+ if not os.path.exists(workspace_dir):
64
+ os.makedirs(workspace_dir)
65
+
66
+ # Create input_images subdirectory
67
+ input_images_dir = os.path.join(workspace_dir, "input_images")
68
+ if not os.path.exists(input_images_dir):
69
+ os.makedirs(input_images_dir)
70
+
71
+ # Create a unique folder name within input_images
72
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
73
+ target_dir = os.path.join(input_images_dir, f"session_{timestamp}")
74
+ target_dir_images = os.path.join(target_dir, "images")
75
+
76
+ # Clean up if somehow that folder already exists
77
+ if os.path.exists(target_dir):
78
+ shutil.rmtree(target_dir)
79
+ os.makedirs(target_dir)
80
+ os.makedirs(target_dir_images)
81
+
82
+ image_paths = []
83
+
84
+ # Handle images
85
+ if input_images is not None:
86
+ image_paths.extend(self._process_images(input_images, target_dir_images))
87
+
88
+ # Handle video
89
+ if input_video is not None:
90
+ image_paths.extend(
91
+ self._process_video(input_video, target_dir_images, s_time_interval)
92
+ )
93
+
94
+ # Sort final images for gallery
95
+ image_paths = sorted(image_paths)
96
+
97
+ end_time = time.time()
98
+ print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
99
+ return target_dir, image_paths
100
+
101
+ def _process_images(self, input_images: List, target_dir_images: str) -> List[str]:
102
+ """
103
+ Process uploaded images.
104
+
105
+ Args:
106
+ input_images: List of input image files
107
+ target_dir_images: Target directory for images
108
+
109
+ Returns:
110
+ List of processed image paths
111
+ """
112
+ image_paths = []
113
+
114
+ for file_data in input_images:
115
+ if isinstance(file_data, dict) and "name" in file_data:
116
+ file_path = file_data["name"]
117
+ else:
118
+ file_path = file_data
119
+
120
+ # Check if the file is a HEIC image
121
+ file_ext = os.path.splitext(file_path)[1].lower()
122
+ if file_ext in [".heic", ".heif"]:
123
+ # Convert HEIC to JPEG for better gallery compatibility
124
+ try:
125
+ with Image.open(file_path) as img:
126
+ # Convert to RGB if necessary (HEIC can have different color modes)
127
+ if img.mode not in ("RGB", "L"):
128
+ img = img.convert("RGB")
129
+
130
+ # Create JPEG filename
131
+ base_name = os.path.splitext(os.path.basename(file_path))[0]
132
+ dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
133
+
134
+ # Save as JPEG with high quality
135
+ img.save(dst_path, "JPEG", quality=95)
136
+ image_paths.append(dst_path)
137
+ print(
138
+ f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> "
139
+ f"{os.path.basename(dst_path)}"
140
+ )
141
+ except Exception as e:
142
+ print(f"Error converting HEIC file {file_path}: {e}")
143
+ # Fall back to copying as is
144
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
145
+ shutil.copy(file_path, dst_path)
146
+ image_paths.append(dst_path)
147
+ else:
148
+ # Regular image files - copy as is
149
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
150
+ shutil.copy(file_path, dst_path)
151
+ image_paths.append(dst_path)
152
+
153
+ return image_paths
154
+
155
+ def _process_video(
156
+ self, input_video: str, target_dir_images: str, s_time_interval: float
157
+ ) -> List[str]:
158
+ """
159
+ Process video file and extract frames.
160
+
161
+ Args:
162
+ input_video: Path to input video file
163
+ target_dir_images: Target directory for extracted frames
164
+ s_time_interval: Sampling FPS (frames per second) for frame extraction
165
+
166
+ Returns:
167
+ List of extracted frame paths
168
+ """
169
+ image_paths = []
170
+
171
+ if isinstance(input_video, dict) and "name" in input_video:
172
+ video_path = input_video["name"]
173
+ else:
174
+ video_path = input_video
175
+
176
+ vs = cv2.VideoCapture(video_path)
177
+ fps = vs.get(cv2.CAP_PROP_FPS)
178
+ frame_interval = max(1, int(fps / s_time_interval)) # Convert FPS to frame interval
179
+
180
+ count = 0
181
+ video_frame_num = 0
182
+ while True:
183
+ gotit, frame = vs.read()
184
+ if not gotit:
185
+ break
186
+ count += 1
187
+ if count % frame_interval == 0:
188
+ image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
189
+ cv2.imwrite(image_path, frame)
190
+ image_paths.append(image_path)
191
+ video_frame_num += 1
192
+
193
+ return image_paths
194
+
195
+ def update_gallery_on_upload(
196
+ self,
197
+ input_video: Optional[str],
198
+ input_images: Optional[List],
199
+ s_time_interval: float = 10.0,
200
+ ) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
201
+ """
202
+ Handle file uploads and update gallery.
203
+
204
+ Args:
205
+ input_video: Path to input video file
206
+ input_images: List of input image files
207
+ s_time_interval: Sampling FPS (frames per second) for frame extraction
208
+
209
+ Returns:
210
+ Tuple of (reconstruction_output, target_dir, image_paths, log_message)
211
+ """
212
+ if not input_video and not input_images:
213
+ return None, None, None, None
214
+
215
+ target_dir, image_paths = self.handle_uploads(input_video, input_images, s_time_interval)
216
+ return (
217
+ None,
218
+ target_dir,
219
+ image_paths,
220
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
221
+ )
222
+
223
+ def load_example_scene(
224
+ self, scene_name: str, examples_dir: str = "examples"
225
+ ) -> Tuple[Optional[str], Optional[str], Optional[List], str]:
226
+ """
227
+ Load a scene from examples directory.
228
+
229
+ Args:
230
+ scene_name: Name of the scene to load
231
+ examples_dir: Path to examples directory
232
+
233
+ Returns:
234
+ Tuple of (reconstruction_output, target_dir, image_paths, log_message)
235
+ """
236
+ from depth_anything_3.app.modules.utils import get_scene_info
237
+
238
+ scenes = get_scene_info(examples_dir)
239
+
240
+ # Find the selected scene
241
+ selected_scene = None
242
+ for scene in scenes:
243
+ if scene["name"] == scene_name:
244
+ selected_scene = scene
245
+ break
246
+
247
+ if selected_scene is None:
248
+ return None, None, None, "Scene not found"
249
+
250
+ # Use fixed directory name for examples (not timestamp-based)
251
+ workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
252
+ input_images_dir = os.path.join(workspace_dir, "input_images")
253
+ if not os.path.exists(input_images_dir):
254
+ os.makedirs(input_images_dir)
255
+
256
+ # Create a fixed folder name based on scene name
257
+ target_dir = os.path.join(input_images_dir, f"example_{scene_name}")
258
+ target_dir_images = os.path.join(target_dir, "images")
259
+
260
+ # Check if already cached (GLB file exists)
261
+ glb_path = os.path.join(target_dir, "scene.glb")
262
+ is_cached = os.path.exists(glb_path)
263
+
264
+ # Create directory if it doesn't exist
265
+ if not os.path.exists(target_dir):
266
+ os.makedirs(target_dir)
267
+ os.makedirs(target_dir_images)
268
+
269
+ # Copy images if directory is new or empty
270
+ if not os.path.exists(target_dir_images) or len(os.listdir(target_dir_images)) == 0:
271
+ os.makedirs(target_dir_images, exist_ok=True)
272
+ image_paths = []
273
+ for file_path in selected_scene["image_files"]:
274
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
275
+ shutil.copy(file_path, dst_path)
276
+ image_paths.append(dst_path)
277
+ else:
278
+ # Use existing images
279
+ image_paths = sorted(
280
+ [
281
+ os.path.join(target_dir_images, f)
282
+ for f in os.listdir(target_dir_images)
283
+ if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"))
284
+ ]
285
+ )
286
+
287
+ # Return cached GLB if available
288
+ if is_cached:
289
+ return (
290
+ glb_path, # Return cached reconstruction
291
+ target_dir, # Set target directory
292
+ image_paths, # Set gallery
293
+ f"Loaded cached scene '{scene_name}' with {selected_scene['num_images']} images.",
294
+ )
295
+ else:
296
+ return (
297
+ None, # No cached reconstruction
298
+ target_dir, # Set target directory
299
+ image_paths, # Set gallery
300
+ (
301
+ f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. "
302
+ "Click 'Reconstruct' to begin 3D processing."
303
+ ),
304
+ )
Depth-Anything-3/src/depth_anything_3/app/modules/model_inference.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Model inference module for Depth Anything 3 Gradio app.
17
+
18
+ This module handles all model-related operations including inference,
19
+ data processing, and result preparation.
20
+ """
21
+
22
+ import glob
23
+ import os
24
+ from typing import Any, Dict, Optional, Tuple
25
+ import numpy as np
26
+ import torch
27
+
28
+ from depth_anything_3.api import DepthAnything3
29
+ from depth_anything_3.utils.memory import cleanup_cuda_memory
30
+ from depth_anything_3.utils.export.glb import export_to_glb
31
+ from depth_anything_3.utils.export.gs import export_to_gs_video
32
+
33
+
34
+ class ModelInference:
35
+ """
36
+ Handles model inference and data processing for Depth Anything 3.
37
+ """
38
+
39
+ def __init__(self):
40
+ """Initialize the model inference handler."""
41
+ self.model = None
42
+
43
+ def initialize_model(self, device: str = "cuda") -> None:
44
+ """
45
+ Initialize the DepthAnything3 model.
46
+
47
+ Args:
48
+ device: Device to load the model on
49
+ """
50
+ if self.model is None:
51
+ # Get model directory from environment variable or use default
52
+ model_dir = os.environ.get(
53
+ "DA3_MODEL_DIR", "/dev/shm/da3_models/DA3HF-VITG-METRIC_VITL"
54
+ )
55
+ self.model = DepthAnything3.from_pretrained(model_dir)
56
+ self.model = self.model.to(device)
57
+ else:
58
+ self.model = self.model.to(device)
59
+
60
+ self.model.eval()
61
+
62
+ def run_inference(
63
+ self,
64
+ target_dir: str,
65
+ filter_black_bg: bool = False,
66
+ filter_white_bg: bool = False,
67
+ process_res_method: str = "upper_bound_resize",
68
+ show_camera: bool = True,
69
+ save_percentage: float = 30.0,
70
+ num_max_points: int = 1_000_000,
71
+ infer_gs: bool = False,
72
+ ref_view_strategy: str = "saddle_balanced",
73
+ gs_trj_mode: str = "extend",
74
+ gs_video_quality: str = "high",
75
+ ) -> Tuple[Any, Dict[int, Dict[str, Any]]]:
76
+ """
77
+ Run DepthAnything3 model inference on images.
78
+
79
+ Args:
80
+ target_dir: Directory containing images
81
+ filter_black_bg: Whether to filter black background
82
+ filter_white_bg: Whether to filter white background
83
+ process_res_method: Method for resizing input images
84
+ show_camera: Whether to show camera in 3D view
85
+ save_percentage: Percentage of points to save (0-100)
86
+ num_max_points: Maximum number of points in point cloud
87
+ infer_gs: Whether to infer 3D Gaussian Splatting
88
+ ref_view_strategy: Reference view selection strategy
89
+ gs_trj_mode: Trajectory mode for 3DGS
90
+ gs_video_quality: Video quality for 3DGS
91
+
92
+ Returns:
93
+ Tuple of (prediction, processed_data)
94
+ """
95
+ print(f"Processing images from {target_dir}")
96
+
97
+ # Device check
98
+ device = "cuda" if torch.cuda.is_available() else "cpu"
99
+ device = torch.device(device)
100
+
101
+ # Initialize model if needed
102
+ self.initialize_model(device)
103
+
104
+ # Get image paths
105
+ print("Loading images...")
106
+ image_folder_path = os.path.join(target_dir, "images")
107
+ all_image_paths = sorted(glob.glob(os.path.join(image_folder_path, "*")))
108
+
109
+ # Filter for image files
110
+ image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"]
111
+ all_image_paths = [
112
+ path
113
+ for path in all_image_paths
114
+ if any(path.lower().endswith(ext) for ext in image_extensions)
115
+ ]
116
+
117
+ print(f"Found {len(all_image_paths)} images")
118
+ print(f"All image paths: {all_image_paths}")
119
+
120
+ # Use sorted image order (reference view will be selected automatically)
121
+ image_paths = all_image_paths
122
+ print(f"Reference view selection strategy: {ref_view_strategy}")
123
+
124
+ if len(image_paths) == 0:
125
+ raise ValueError("No images found. Check your upload.")
126
+
127
+ # Map UI options to actual method names
128
+ method_mapping = {"high_res": "lower_bound_resize", "low_res": "upper_bound_resize"}
129
+ actual_method = method_mapping.get(process_res_method, "upper_bound_crop")
130
+
131
+ # Run model inference
132
+ print(f"Running inference with method: {actual_method}")
133
+ with torch.no_grad():
134
+ prediction = self.model.inference(
135
+ image_paths,
136
+ export_dir=None,
137
+ process_res_method=actual_method,
138
+ infer_gs=infer_gs,
139
+ ref_view_strategy=ref_view_strategy,
140
+ )
141
+ # num_max_points: int = 1_000_000,
142
+ export_to_glb(
143
+ prediction,
144
+ filter_black_bg=filter_black_bg,
145
+ filter_white_bg=filter_white_bg,
146
+ export_dir=target_dir,
147
+ show_cameras=show_camera,
148
+ conf_thresh_percentile=save_percentage,
149
+ num_max_points=int(num_max_points),
150
+ )
151
+
152
+ # export to gs video if needed
153
+ if infer_gs:
154
+ mode_mapping = {"extend": "extend", "smooth": "interpolate_smooth"}
155
+ print(f"GS mode: {gs_trj_mode}; Backend mode: {mode_mapping[gs_trj_mode]}")
156
+ export_to_gs_video(
157
+ prediction,
158
+ export_dir=target_dir,
159
+ chunk_size=4,
160
+ trj_mode=mode_mapping.get(gs_trj_mode, "extend"),
161
+ enable_tqdm=True,
162
+ vis_depth="hcat",
163
+ video_quality=gs_video_quality,
164
+ )
165
+
166
+ # Save predictions.npz for caching metric depth data
167
+ self._save_predictions_cache(target_dir, prediction)
168
+
169
+ # Process results
170
+ processed_data = self._process_results(target_dir, prediction, image_paths)
171
+
172
+ # Clean up using centralized memory utilities for consistency with backend
173
+ cleanup_cuda_memory()
174
+
175
+ return prediction, processed_data
176
+
177
+ def _save_predictions_cache(self, target_dir: str, prediction: Any) -> None:
178
+ """
179
+ Save predictions data to predictions.npz for caching.
180
+
181
+ Args:
182
+ target_dir: Directory to save the cache
183
+ prediction: Model prediction object
184
+ """
185
+ try:
186
+ output_file = os.path.join(target_dir, "predictions.npz")
187
+
188
+ # Build save dict with prediction data
189
+ save_dict = {}
190
+
191
+ # Save processed images if available
192
+ if prediction.processed_images is not None:
193
+ save_dict["images"] = prediction.processed_images
194
+
195
+ # Save depth data
196
+ if prediction.depth is not None:
197
+ save_dict["depths"] = np.round(prediction.depth, 6)
198
+
199
+ # Save confidence if available
200
+ if prediction.conf is not None:
201
+ save_dict["conf"] = np.round(prediction.conf, 2)
202
+
203
+ # Save camera parameters
204
+ if prediction.extrinsics is not None:
205
+ save_dict["extrinsics"] = prediction.extrinsics
206
+ if prediction.intrinsics is not None:
207
+ save_dict["intrinsics"] = prediction.intrinsics
208
+
209
+ # Save to file
210
+ np.savez_compressed(output_file, **save_dict)
211
+ print(f"Saved predictions cache to: {output_file}")
212
+
213
+ except Exception as e:
214
+ print(f"Warning: Failed to save predictions cache: {e}")
215
+
216
+ def _process_results(
217
+ self, target_dir: str, prediction: Any, image_paths: list
218
+ ) -> Dict[int, Dict[str, Any]]:
219
+ """
220
+ Process model results into structured data.
221
+
222
+ Args:
223
+ target_dir: Directory containing results
224
+ prediction: Model prediction object
225
+ image_paths: List of input image paths
226
+
227
+ Returns:
228
+ Dictionary containing processed data for each view
229
+ """
230
+ processed_data = {}
231
+
232
+ # Read generated depth visualization files
233
+ depth_vis_dir = os.path.join(target_dir, "depth_vis")
234
+
235
+ if os.path.exists(depth_vis_dir):
236
+ depth_files = sorted(glob.glob(os.path.join(depth_vis_dir, "*.jpg")))
237
+ for i, depth_file in enumerate(depth_files):
238
+ # Use processed images directly from API
239
+ processed_image = None
240
+ if prediction.processed_images is not None and i < len(
241
+ prediction.processed_images
242
+ ):
243
+ processed_image = prediction.processed_images[i]
244
+
245
+ processed_data[i] = {
246
+ "depth_image": depth_file,
247
+ "image": processed_image,
248
+ "original_image_path": image_paths[i] if i < len(image_paths) else None,
249
+ "depth": prediction.depth[i] if i < len(prediction.depth) else None,
250
+ "intrinsics": (
251
+ prediction.intrinsics[i]
252
+ if prediction.intrinsics is not None and i < len(prediction.intrinsics)
253
+ else None
254
+ ),
255
+ "mask": None, # No mask information available
256
+ }
257
+
258
+ return processed_data
259
+
260
+ # cleanup() removed: call cleanup_cuda_memory() directly where needed.
Depth-Anything-3/src/depth_anything_3/app/modules/ui_components.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ UI components module for Depth Anything 3 Gradio app.
17
+
18
+ This module contains UI component definitions and layout functions.
19
+ """
20
+
21
+ import os
22
+ from typing import Any, Dict, List, Tuple
23
+ import gradio as gr
24
+
25
+ from depth_anything_3.app.modules.utils import get_logo_base64, get_scene_info
26
+
27
+
28
+ class UIComponents:
29
+ """
30
+ Handles UI component creation and layout for the Gradio app.
31
+ """
32
+
33
+ def __init__(self):
34
+ """Initialize the UI components handler."""
35
+
36
+ def create_upload_section(self) -> Tuple[gr.Video, gr.Slider, gr.File, gr.Gallery]:
37
+ """
38
+ Create the upload section with video, images, and gallery components.
39
+
40
+ Returns:
41
+ A tuple of Gradio components: (input_video, s_time_interval, input_images, image_gallery).
42
+ """
43
+ input_video = gr.Video(label="Upload Video", interactive=True)
44
+ s_time_interval = gr.Slider(
45
+ minimum=0.1,
46
+ maximum=60,
47
+ value=10,
48
+ step=0.1,
49
+ label="Sampling FPS (Frames Per Second)",
50
+ interactive=True,
51
+ visible=True,
52
+ )
53
+ input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
54
+ image_gallery = gr.Gallery(
55
+ label="Preview",
56
+ columns=4,
57
+ height="300px",
58
+ show_download_button=True,
59
+ object_fit="contain",
60
+ preview=True,
61
+ interactive=False,
62
+ )
63
+
64
+ return input_video, s_time_interval, input_images, image_gallery
65
+
66
+ def create_3d_viewer_section(self) -> gr.Model3D:
67
+ """
68
+ Create the 3D viewer component.
69
+
70
+ Returns:
71
+ 3D model viewer component
72
+ """
73
+ return gr.Model3D(
74
+ height=520,
75
+ zoom_speed=0.5,
76
+ pan_speed=0.5,
77
+ clear_color=[0.0, 0.0, 0.0, 0.0],
78
+ key="persistent_3d_viewer",
79
+ elem_id="reconstruction_3d_viewer",
80
+ )
81
+
82
+ def create_nvs_video(self) -> Tuple[gr.Video, gr.Markdown]:
83
+ """
84
+ Create the 3DGS rendered video display component and info message.
85
+
86
+ Returns:
87
+ Tuple of (video component, info message component)
88
+ """
89
+ with gr.Column():
90
+ gs_info = gr.Markdown(
91
+ (
92
+ "‼️ **3D Gaussian Splatting rendering is currently DISABLED.** <br><br><br>"
93
+ "To render novel views from 3DGS, "
94
+ "enable **Infer 3D Gaussian Splatting** below. <br>"
95
+ "Next, in **Visualization Options**, "
96
+ "*optionally* configure the **rendering trajectory** (default: smooth) "
97
+ "and **video quality** (default: low), "
98
+ "then click **Reconstruct**."
99
+ ),
100
+ visible=True,
101
+ height=520,
102
+ )
103
+ gs_video = gr.Video(
104
+ height=520,
105
+ label="3DGS Rendered NVS Video (depth shown for reference only)",
106
+ interactive=False,
107
+ visible=False,
108
+ )
109
+ return gs_video, gs_info
110
+
111
+ def create_depth_section(self) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image]:
112
+ """
113
+ Create the depth visualization section.
114
+
115
+ Returns:
116
+ A tuple of (prev_depth_btn, depth_view_selector, next_depth_btn, depth_map)
117
+ """
118
+ with gr.Row(elem_classes=["navigation-row"]):
119
+ prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
120
+ depth_view_selector = gr.Dropdown(
121
+ choices=["View 1"],
122
+ value="View 1",
123
+ label="Select View",
124
+ scale=2,
125
+ interactive=True,
126
+ allow_custom_value=True,
127
+ )
128
+ next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
129
+ depth_map = gr.Image(
130
+ type="numpy",
131
+ label="Colorized Depth Map",
132
+ format="png",
133
+ interactive=False,
134
+ )
135
+
136
+ return prev_depth_btn, depth_view_selector, next_depth_btn, depth_map
137
+
138
+ def create_measure_section(
139
+ self,
140
+ ) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image, gr.Image, gr.Markdown]:
141
+ """
142
+ Create the measurement section.
143
+
144
+ Returns:
145
+ A tuple of (prev_measure_btn, measure_view_selector, next_measure_btn, measure_image,
146
+ measure_depth_image, measure_text)
147
+ """
148
+ from depth_anything_3.app.css_and_html import MEASURE_INSTRUCTIONS_HTML
149
+
150
+ gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
151
+ with gr.Row(elem_classes=["navigation-row"]):
152
+ prev_measure_btn = gr.Button("◀ Previous", size="sm", scale=1)
153
+ measure_view_selector = gr.Dropdown(
154
+ choices=["View 1"],
155
+ value="View 1",
156
+ label="Select View",
157
+ scale=2,
158
+ interactive=True,
159
+ allow_custom_value=True,
160
+ )
161
+ next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
162
+ with gr.Row():
163
+ measure_image = gr.Image(
164
+ type="numpy",
165
+ show_label=False,
166
+ format="webp",
167
+ interactive=False,
168
+ sources=[],
169
+ label="RGB Image",
170
+ scale=1,
171
+ height=275,
172
+ )
173
+ measure_depth_image = gr.Image(
174
+ type="numpy",
175
+ show_label=False,
176
+ format="webp",
177
+ interactive=False,
178
+ sources=[],
179
+ label="Depth Visualization (Right Half)",
180
+ scale=1,
181
+ height=275,
182
+ )
183
+ gr.Markdown(
184
+ "**Note:** Images have been adjusted to model processing size. "
185
+ "Click two points on the RGB image to measure distance."
186
+ )
187
+ measure_text = gr.Markdown("")
188
+
189
+ return (
190
+ prev_measure_btn,
191
+ measure_view_selector,
192
+ next_measure_btn,
193
+ measure_image,
194
+ measure_depth_image,
195
+ measure_text,
196
+ )
197
+
198
+ def create_inference_control_section(self) -> Tuple[gr.Dropdown, gr.Checkbox, gr.Dropdown]:
199
+ """
200
+ Create the inference control section (before inference).
201
+
202
+ Returns:
203
+ Tuple of (process_res_method_dropdown, infer_gs, ref_view_strategy)
204
+ """
205
+ with gr.Row():
206
+ process_res_method_dropdown = gr.Dropdown(
207
+ choices=["high_res", "low_res"],
208
+ value="low_res",
209
+ label="Image Processing Method",
210
+ info="low_res for much more images",
211
+ scale=1,
212
+ )
213
+ # Modify line 220, add color class
214
+ infer_gs = gr.Checkbox(
215
+ label="Infer 3D Gaussian Splatting",
216
+ value=False,
217
+ info=(
218
+ 'Enable novel view rendering from 3DGS (<i class="fas fa-triangle-exclamation '
219
+ 'fa-color-red"></i> requires extra processing time)'
220
+ ),
221
+ scale=1,
222
+ )
223
+ ref_view_strategy = gr.Dropdown(
224
+ choices=["saddle_balanced", "saddle_sim_range", "first", "middle"],
225
+ value="saddle_balanced",
226
+ label="Reference View Strategy",
227
+ info="Strategy for selecting reference view from multiple inputs",
228
+ scale=1,
229
+ )
230
+
231
+ return (process_res_method_dropdown, infer_gs, ref_view_strategy)
232
+
233
+ def create_display_control_section(
234
+ self,
235
+ ) -> Tuple[
236
+ gr.Checkbox,
237
+ gr.Checkbox,
238
+ gr.Checkbox,
239
+ gr.Slider,
240
+ gr.Slider,
241
+ gr.Dropdown,
242
+ gr.Dropdown,
243
+ gr.Button,
244
+ gr.ClearButton,
245
+ ]:
246
+ """
247
+ Create the display control section (options for visualization).
248
+
249
+ Returns:
250
+ Tuple of display control components including buttons
251
+ """
252
+ with gr.Column():
253
+ # 3DGS options at the top
254
+ with gr.Row():
255
+ gs_trj_mode = gr.Dropdown(
256
+ choices=["smooth", "extend"],
257
+ value="smooth",
258
+ label=("Rendering trajectory for 3DGS viewpoints (requires n_views ≥ 2)"),
259
+ info=("'smooth' for view interpolation; 'extend' for longer trajectory"),
260
+ visible=False, # initially hidden
261
+ )
262
+ gs_video_quality = gr.Dropdown(
263
+ choices=["low", "medium", "high"],
264
+ value="low",
265
+ label=("Video quality for 3DGS rendered outputs"),
266
+ info=("'low' for faster loading speed; 'high' for better visual quality"),
267
+ visible=False, # initially hidden
268
+ )
269
+
270
+ # Reconstruct and Clear buttons (before Visualization Options)
271
+ with gr.Row():
272
+ submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
273
+ clear_btn = gr.ClearButton(scale=1)
274
+
275
+ gr.Markdown("### Visualization Options: (Click Reconstruct to update)")
276
+ show_cam = gr.Checkbox(label="Show Camera", value=True)
277
+ filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
278
+ filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
279
+ save_percentage = gr.Slider(
280
+ minimum=0,
281
+ maximum=100,
282
+ value=10,
283
+ step=1,
284
+ label="Filter Percentage",
285
+ info="Confidence Threshold (%): Higher values filter more points.",
286
+ )
287
+ num_max_points = gr.Slider(
288
+ minimum=1000,
289
+ maximum=100000,
290
+ value=1000,
291
+ step=1000,
292
+ label="Max Points (K points)",
293
+ info="Maximum number of points to export to GLB (in thousands)",
294
+ )
295
+
296
+ return (
297
+ show_cam,
298
+ filter_black_bg,
299
+ filter_white_bg,
300
+ save_percentage,
301
+ num_max_points,
302
+ gs_trj_mode,
303
+ gs_video_quality,
304
+ submit_btn,
305
+ clear_btn,
306
+ )
307
+
308
+ def create_control_section(
309
+ self,
310
+ ) -> Tuple[
311
+ gr.Button,
312
+ gr.ClearButton,
313
+ gr.Dropdown,
314
+ gr.Checkbox,
315
+ gr.Checkbox,
316
+ gr.Checkbox,
317
+ gr.Checkbox,
318
+ gr.Checkbox,
319
+ gr.Dropdown,
320
+ gr.Checkbox,
321
+ gr.Textbox,
322
+ ]:
323
+ """
324
+ Create the control section with buttons and options.
325
+
326
+ Returns:
327
+ Tuple of control components
328
+ """
329
+ with gr.Row():
330
+ submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
331
+ clear_btn = gr.ClearButton(
332
+ scale=1,
333
+ )
334
+
335
+ with gr.Row():
336
+ frame_filter = gr.Dropdown(
337
+ choices=["All"], value="All", label="Show Points from Frame"
338
+ )
339
+ with gr.Column():
340
+ gr.Markdown("### Visualization Option: (Click Reconstruct to update)")
341
+ show_cam = gr.Checkbox(label="Show Camera", value=True)
342
+ show_mesh = gr.Checkbox(label="Show Mesh", value=True)
343
+ filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
344
+ filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
345
+ gr.Markdown("### Reconstruction Options: (updated on next run)")
346
+ apply_mask_checkbox = gr.Checkbox(
347
+ label="Apply mask for predicted ambiguous depth classes & edges",
348
+ value=True,
349
+ )
350
+ process_res_method_dropdown = gr.Dropdown(
351
+ choices=[
352
+ "upper_bound_resize",
353
+ "upper_bound_crop",
354
+ "lower_bound_resize",
355
+ "lower_bound_crop",
356
+ ],
357
+ value="upper_bound_resize",
358
+ label="Image Processing Method",
359
+ info="Method for resizing input images",
360
+ )
361
+ save_to_gallery_checkbox = gr.Checkbox(
362
+ label="Save to Gallery",
363
+ value=False,
364
+ info="Save current reconstruction results to gallery directory",
365
+ )
366
+ gallery_name_input = gr.Textbox(
367
+ label="Gallery Name",
368
+ placeholder="Enter a name for the gallery folder",
369
+ value="",
370
+ info="Leave empty for auto-generated name with timestamp",
371
+ )
372
+
373
+ return (
374
+ submit_btn,
375
+ clear_btn,
376
+ frame_filter,
377
+ show_cam,
378
+ show_mesh,
379
+ filter_black_bg,
380
+ filter_white_bg,
381
+ apply_mask_checkbox,
382
+ process_res_method_dropdown,
383
+ save_to_gallery_checkbox,
384
+ gallery_name_input,
385
+ )
386
+
387
+ def create_example_scenes_section(self) -> List[Dict[str, Any]]:
388
+ """
389
+ Create the example scenes section.
390
+
391
+ Returns:
392
+ List of scene information dictionaries
393
+ """
394
+ # Get workspace directory from environment variable
395
+ workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
396
+ examples_dir = os.path.join(workspace_dir, "examples")
397
+
398
+ # Get scene information
399
+ scenes = get_scene_info(examples_dir)
400
+
401
+ return scenes
402
+
403
+ def create_example_scene_grid(self, scenes: List[Dict[str, Any]]) -> List[gr.Image]:
404
+ """
405
+ Create the example scene grid.
406
+
407
+ Args:
408
+ scenes: List of scene information dictionaries
409
+
410
+ Returns:
411
+ List of scene image components
412
+ """
413
+ scene_components = []
414
+
415
+ if scenes:
416
+ for i in range(0, len(scenes), 4): # Process 4 scenes per row
417
+ with gr.Row():
418
+ for j in range(4):
419
+ scene_idx = i + j
420
+ if scene_idx < len(scenes):
421
+ scene = scenes[scene_idx]
422
+ with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
423
+ # Clickable thumbnail
424
+ scene_img = gr.Image(
425
+ value=scene["thumbnail"],
426
+ height=150,
427
+ interactive=False,
428
+ show_label=False,
429
+ elem_id=f"scene_thumb_{scene['name']}",
430
+ sources=[],
431
+ )
432
+ scene_components.append(scene_img)
433
+
434
+ # Scene name and image count as text below thumbnail
435
+ gr.Markdown(
436
+ f"**{scene['name']}** \n {scene['num_images']} images",
437
+ elem_classes=["scene-info"],
438
+ )
439
+ else:
440
+ # Empty column to maintain grid structure
441
+ with gr.Column(scale=1):
442
+ pass
443
+
444
+ return scene_components
445
+
446
+ def create_header_section(self) -> gr.HTML:
447
+ """
448
+ Create the header section with logo and title.
449
+
450
+ Returns:
451
+ Header HTML component
452
+ """
453
+ from depth_anything_3.app.css_and_html import get_header_html
454
+
455
+ return gr.HTML(get_header_html(get_logo_base64()))
456
+
457
+ def create_description_section(self) -> gr.HTML:
458
+ """
459
+ Create the description section.
460
+
461
+ Returns:
462
+ Description HTML component
463
+ """
464
+ from depth_anything_3.app.css_and_html import get_description_html
465
+
466
+ return gr.HTML(get_description_html())
467
+
468
+ def create_acknowledgements_section(self) -> gr.HTML:
469
+ """
470
+ Create the acknowledgements section.
471
+
472
+ Returns:
473
+ Acknowledgements HTML component
474
+ """
475
+ from depth_anything_3.app.css_and_html import get_acknowledgements_html
476
+
477
+ return gr.HTML(get_acknowledgements_html())
Depth-Anything-3/src/depth_anything_3/app/modules/utils.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Utility functions for Depth Anything 3 Gradio app.
17
+
18
+ This module contains helper functions for data processing, visualization,
19
+ and file operations.
20
+ """
21
+
22
+
23
+ import json
24
+ import os
25
+ import shutil
26
+ from datetime import datetime
27
+ from typing import Any, Dict, List, Optional, Tuple
28
+ import numpy as np
29
+
30
+ def create_depth_visualization(depth: np.ndarray) -> Optional[np.ndarray]:
31
+ """
32
+ Create a colored depth visualization.
33
+
34
+ Args:
35
+ depth: Depth array
36
+
37
+ Returns:
38
+ Colored depth visualization or None
39
+ """
40
+ if depth is None:
41
+ return None
42
+
43
+ # Normalize depth to 0-1 range
44
+ depth_min = depth[depth > 0].min() if (depth > 0).any() else 0
45
+ depth_max = depth.max()
46
+
47
+ if depth_max <= depth_min:
48
+ return None
49
+
50
+ # Normalize depth
51
+ depth_norm = (depth - depth_min) / (depth_max - depth_min)
52
+ depth_norm = np.clip(depth_norm, 0, 1)
53
+
54
+ # Apply colormap (using matplotlib's viridis colormap)
55
+ import matplotlib.cm as cm
56
+
57
+ # Convert to colored image
58
+ depth_colored = cm.viridis(depth_norm)[:, :, :3] # Remove alpha channel
59
+ depth_colored = (depth_colored * 255).astype(np.uint8)
60
+
61
+ return depth_colored
62
+
63
+
64
+ def save_to_gallery_func(
65
+ target_dir: str, processed_data: Dict[int, Dict[str, Any]], gallery_name: Optional[str] = None
66
+ ) -> Tuple[bool, str]:
67
+ """
68
+ Save the current reconstruction results to the gallery directory.
69
+
70
+ Args:
71
+ target_dir: Source directory containing reconstruction results
72
+ processed_data: Processed data dictionary
73
+ gallery_name: Name for the gallery folder
74
+
75
+ Returns:
76
+ Tuple of (success, message)
77
+ """
78
+ try:
79
+ # Get gallery directory from environment variable or use default
80
+ gallery_dir = os.environ.get(
81
+ "DA3_GALLERY_DIR",
82
+ "workspace/gallery",
83
+ )
84
+ if not os.path.exists(gallery_dir):
85
+ os.makedirs(gallery_dir)
86
+
87
+ # Use provided name or create a unique name
88
+ if gallery_name is None or gallery_name.strip() == "":
89
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
90
+ gallery_name = f"reconstruction_{timestamp}"
91
+
92
+ gallery_path = os.path.join(gallery_dir, gallery_name)
93
+
94
+ # Check if directory already exists
95
+ if os.path.exists(gallery_path):
96
+ return False, f"Save failed: folder '{gallery_name}' already exists"
97
+
98
+ # Create the gallery directory
99
+ os.makedirs(gallery_path, exist_ok=True)
100
+
101
+ # Copy GLB file
102
+ glb_source = os.path.join(target_dir, "scene.glb")
103
+ glb_dest = os.path.join(gallery_path, "scene.glb")
104
+ if os.path.exists(glb_source):
105
+ shutil.copy2(glb_source, glb_dest)
106
+
107
+ # Copy depth visualization images
108
+ depth_vis_dir = os.path.join(target_dir, "depth_vis")
109
+ if os.path.exists(depth_vis_dir):
110
+ gallery_depth_vis = os.path.join(gallery_path, "depth_vis")
111
+ shutil.copytree(depth_vis_dir, gallery_depth_vis)
112
+
113
+ # Copy original images
114
+ images_source = os.path.join(target_dir, "images")
115
+ if os.path.exists(images_source):
116
+ gallery_images = os.path.join(gallery_path, "images")
117
+ shutil.copytree(images_source, gallery_images)
118
+
119
+ scene_preview_source = os.path.join(target_dir, "scene.jpg")
120
+ scene_preview_dest = os.path.join(gallery_path, "scene.jpg")
121
+ shutil.copy2(scene_preview_source, scene_preview_dest)
122
+
123
+ # Save metadata
124
+ metadata = {
125
+ "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
126
+ "num_images": len(processed_data) if processed_data else 0,
127
+ "gallery_name": gallery_name,
128
+ }
129
+
130
+ with open(os.path.join(gallery_path, "metadata.json"), "w") as f:
131
+ json.dump(metadata, f, indent=2)
132
+
133
+ print(f"Saved reconstruction to gallery: {gallery_path}")
134
+ return True, f"Save successful: saved to {gallery_path}"
135
+
136
+ except Exception as e:
137
+ print(f"Error saving to gallery: {e}")
138
+ return False, f"Save failed: {str(e)}"
139
+
140
+
141
+ def get_scene_info(examples_dir: str) -> List[Dict[str, Any]]:
142
+ """
143
+ Get information about scenes in the examples directory.
144
+
145
+ Args:
146
+ examples_dir: Path to examples directory
147
+
148
+ Returns:
149
+ List of scene information dictionaries
150
+ """
151
+ import glob
152
+
153
+ scenes = []
154
+ if not os.path.exists(examples_dir):
155
+ return scenes
156
+
157
+ for scene_folder in sorted(os.listdir(examples_dir)):
158
+ scene_path = os.path.join(examples_dir, scene_folder)
159
+ if os.path.isdir(scene_path):
160
+ # Find all image files in the scene folder
161
+ image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
162
+ image_files = []
163
+ for ext in image_extensions:
164
+ image_files.extend(glob.glob(os.path.join(scene_path, ext)))
165
+ image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
166
+
167
+ if image_files:
168
+ # Sort images and get the first one for thumbnail
169
+ image_files = sorted(image_files)
170
+ first_image = image_files[0]
171
+ num_images = len(image_files)
172
+
173
+ scenes.append(
174
+ {
175
+ "name": scene_folder,
176
+ "path": scene_path,
177
+ "thumbnail": first_image,
178
+ "num_images": num_images,
179
+ "image_files": image_files,
180
+ }
181
+ )
182
+
183
+ return scenes
184
+
185
+
186
+ # NOTE: cleanup was moved to a single canonical helper in
187
+ # `depth_anything_3.utils.memory.cleanup_cuda_memory`.
188
+ # Callers should import and call that directly instead of using this module.
189
+
190
+
191
+ def get_logo_base64() -> Optional[str]:
192
+ """
193
+ Convert WAI logo to base64 for embedding in HTML.
194
+
195
+ Returns:
196
+ Base64 encoded logo string or None
197
+ """
198
+ import base64
199
+
200
+ logo_path = "examples/WAI-Logo/wai_logo.png"
201
+ try:
202
+ with open(logo_path, "rb") as img_file:
203
+ img_data = img_file.read()
204
+ base64_str = base64.b64encode(img_data).decode()
205
+ return f"data:image/png;base64,{base64_str}"
206
+ except FileNotFoundError:
207
+ return None
Depth-Anything-3/src/depth_anything_3/app/modules/visualization.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Visualization module for Depth Anything 3 Gradio app.
17
+
18
+ This module handles visualization updates, navigation, and measurement functionality.
19
+ """
20
+
21
+ import os
22
+ from typing import Any, Dict, List, Optional, Tuple
23
+ import cv2
24
+ import gradio as gr
25
+ import numpy as np
26
+
27
+
28
+ class VisualizationHandler:
29
+ """
30
+ Handles visualization updates and navigation for the Gradio app.
31
+ """
32
+
33
+ def __init__(self):
34
+ """Initialize the visualization handler."""
35
+
36
+ def update_view_selectors(
37
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]]
38
+ ) -> Tuple[gr.Dropdown, gr.Dropdown]:
39
+ """
40
+ Update view selector dropdowns based on available views.
41
+
42
+ Args:
43
+ processed_data: Processed data dictionary
44
+
45
+ Returns:
46
+ Tuple of (depth_view_selector, measure_view_selector)
47
+ """
48
+ if processed_data is None or len(processed_data) == 0:
49
+ choices = ["View 1"]
50
+ else:
51
+ num_views = len(processed_data)
52
+ choices = [f"View {i + 1}" for i in range(num_views)]
53
+
54
+ return (
55
+ gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
56
+ gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
57
+ )
58
+
59
+ def get_view_data_by_index(
60
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
61
+ ) -> Optional[Dict[str, Any]]:
62
+ """
63
+ Get view data by index, handling bounds.
64
+
65
+ Args:
66
+ processed_data: Processed data dictionary
67
+ view_index: Index of the view to get
68
+
69
+ Returns:
70
+ View data dictionary or None
71
+ """
72
+ if processed_data is None or len(processed_data) == 0:
73
+ return None
74
+
75
+ view_keys = list(processed_data.keys())
76
+ if view_index < 0 or view_index >= len(view_keys):
77
+ view_index = 0
78
+
79
+ return processed_data[view_keys[view_index]]
80
+
81
+ def update_depth_view(
82
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
83
+ ) -> Optional[str]:
84
+ """
85
+ Update depth view for a specific view index.
86
+
87
+ Args:
88
+ processed_data: Processed data dictionary
89
+ view_index: Index of the view to update
90
+
91
+ Returns:
92
+ Path to depth visualization image or None
93
+ """
94
+ view_data = self.get_view_data_by_index(processed_data, view_index)
95
+ if view_data is None or view_data.get("depth_image") is None:
96
+ return None
97
+
98
+ # Return the depth visualization image directly
99
+ return view_data["depth_image"]
100
+
101
+ def navigate_depth_view(
102
+ self,
103
+ processed_data: Optional[Dict[int, Dict[str, Any]]],
104
+ current_selector_value: str,
105
+ direction: int,
106
+ ) -> Tuple[str, Optional[str]]:
107
+ """
108
+ Navigate depth view (direction: -1 for previous, +1 for next).
109
+
110
+ Args:
111
+ processed_data: Processed data dictionary
112
+ current_selector_value: Current selector value
113
+ direction: Direction to navigate (-1 for previous, +1 for next)
114
+
115
+ Returns:
116
+ Tuple of (new_selector_value, depth_vis)
117
+ """
118
+ if processed_data is None or len(processed_data) == 0:
119
+ return "View 1", None
120
+
121
+ # Parse current view number
122
+ try:
123
+ current_view = int(current_selector_value.split()[1]) - 1
124
+ except: # noqa
125
+ current_view = 0
126
+
127
+ num_views = len(processed_data)
128
+ new_view = (current_view + direction) % num_views
129
+
130
+ new_selector_value = f"View {new_view + 1}"
131
+ depth_vis = self.update_depth_view(processed_data, new_view)
132
+
133
+ return new_selector_value, depth_vis
134
+
135
+ def update_measure_view(
136
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
137
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
138
+ """
139
+ Update measure view for a specific view index.
140
+
141
+ Args:
142
+ processed_data: Processed data dictionary
143
+ view_index: Index of the view to update
144
+
145
+ Returns:
146
+ Tuple of (measure_image, depth_right_half, measure_points)
147
+ """
148
+ view_data = self.get_view_data_by_index(processed_data, view_index)
149
+ if view_data is None:
150
+ return None, None, [] # image, depth_right_half, measure_points
151
+
152
+ # Get the processed (resized) image
153
+ if "image" in view_data and view_data["image"] is not None:
154
+ image = view_data["image"].copy()
155
+ else:
156
+ return None, None, []
157
+
158
+ # Ensure image is in uint8 format
159
+ if image.dtype != np.uint8:
160
+ if image.max() <= 1.0:
161
+ image = (image * 255).astype(np.uint8)
162
+ else:
163
+ image = image.astype(np.uint8)
164
+
165
+ # Extract right half of the depth visualization (pure depth part)
166
+ depth_image_path = view_data.get("depth_image", None)
167
+ depth_right_half = None
168
+
169
+ if depth_image_path and os.path.exists(depth_image_path):
170
+ try:
171
+ # Load the combined depth visualization image
172
+ depth_combined = cv2.imread(depth_image_path)
173
+ depth_combined = cv2.cvtColor(depth_combined, cv2.COLOR_BGR2RGB)
174
+ if depth_combined is not None:
175
+ height, width = depth_combined.shape[:2]
176
+ # Extract right half (depth visualization part)
177
+ depth_right_half = depth_combined[:, width // 2 :]
178
+ except Exception as e:
179
+ print(f"Error extracting depth right half: {e}")
180
+
181
+ return image, depth_right_half, []
182
+
183
+ def navigate_measure_view(
184
+ self,
185
+ processed_data: Optional[Dict[int, Dict[str, Any]]],
186
+ current_selector_value: str,
187
+ direction: int,
188
+ ) -> Tuple[str, Optional[np.ndarray], Optional[str], List]:
189
+ """
190
+ Navigate measure view (direction: -1 for previous, +1 for next).
191
+
192
+ Args:
193
+ processed_data: Processed data dictionary
194
+ current_selector_value: Current selector value
195
+ direction: Direction to navigate (-1 for previous, +1 for next)
196
+
197
+ Returns:
198
+ Tuple of (new_selector_value, measure_image, depth_image_path, measure_points)
199
+ """
200
+ if processed_data is None or len(processed_data) == 0:
201
+ return "View 1", None, None, []
202
+
203
+ # Parse current view number
204
+ try:
205
+ current_view = int(current_selector_value.split()[1]) - 1
206
+ except: # noqa
207
+ current_view = 0
208
+
209
+ num_views = len(processed_data)
210
+ new_view = (current_view + direction) % num_views
211
+
212
+ new_selector_value = f"View {new_view + 1}"
213
+ measure_image, depth_right_half, measure_points = self.update_measure_view(
214
+ processed_data, new_view
215
+ )
216
+
217
+ return new_selector_value, measure_image, depth_right_half, measure_points
218
+
219
+ def populate_visualization_tabs(
220
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]]
221
+ ) -> Tuple[Optional[str], Optional[np.ndarray], Optional[str], List]:
222
+ """
223
+ Populate the depth and measure tabs with processed data.
224
+
225
+ Args:
226
+ processed_data: Processed data dictionary
227
+
228
+ Returns:
229
+ Tuple of (depth_vis, measure_img, depth_image_path, measure_points)
230
+ """
231
+ if processed_data is None or len(processed_data) == 0:
232
+ return None, None, None, []
233
+
234
+ # Use update function to get depth visualization
235
+ depth_vis = self.update_depth_view(processed_data, 0)
236
+ measure_img, depth_right_half, _ = self.update_measure_view(processed_data, 0)
237
+
238
+ return depth_vis, measure_img, depth_right_half, []
239
+
240
+ def reset_measure(
241
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]]
242
+ ) -> Tuple[Optional[np.ndarray], List, str]:
243
+ """
244
+ Reset measure points.
245
+
246
+ Args:
247
+ processed_data: Processed data dictionary
248
+
249
+ Returns:
250
+ Tuple of (image, measure_points, text)
251
+ """
252
+ if processed_data is None or len(processed_data) == 0:
253
+ return None, [], ""
254
+
255
+ # Return the first view image
256
+ first_view = list(processed_data.values())[0]
257
+ return first_view["image"], [], ""
258
+
259
+ def measure(
260
+ self,
261
+ processed_data: Optional[Dict[int, Dict[str, Any]]],
262
+ measure_points: List,
263
+ current_view_selector: str,
264
+ event: gr.SelectData,
265
+ ) -> List:
266
+ """
267
+ Handle measurement on images.
268
+
269
+ Args:
270
+ processed_data: Processed data dictionary
271
+ measure_points: List of current measure points
272
+ current_view_selector: Current view selector value
273
+ event: Gradio select event
274
+
275
+ Returns:
276
+ List of [image, depth_right_half, measure_points, text]
277
+ """
278
+ try:
279
+ print(f"Measure function called with selector: {current_view_selector}")
280
+
281
+ if processed_data is None or len(processed_data) == 0:
282
+ return [None, [], "No data available"]
283
+
284
+ # Use the currently selected view instead of always using the first view
285
+ try:
286
+ current_view_index = int(current_view_selector.split()[1]) - 1
287
+ except: # noqa
288
+ current_view_index = 0
289
+
290
+ print(f"Using view index: {current_view_index}")
291
+
292
+ # Get view data safely
293
+ if current_view_index < 0 or current_view_index >= len(processed_data):
294
+ current_view_index = 0
295
+
296
+ view_keys = list(processed_data.keys())
297
+ current_view = processed_data[view_keys[current_view_index]]
298
+
299
+ if current_view is None:
300
+ return [None, [], "No view data available"]
301
+
302
+ point2d = event.index[0], event.index[1]
303
+ print(f"Clicked point: {point2d}")
304
+
305
+ measure_points.append(point2d)
306
+
307
+ # Get image and depth visualization
308
+ image, depth_right_half, _ = self.update_measure_view(
309
+ processed_data, current_view_index
310
+ )
311
+ if image is None:
312
+ return [None, [], "No image available"]
313
+
314
+ image = image.copy()
315
+
316
+ # Ensure image is in uint8 format for proper cv2 operations
317
+ try:
318
+ if image.dtype != np.uint8:
319
+ if image.max() <= 1.0:
320
+ # Image is in [0, 1] range, convert to [0, 255]
321
+ image = (image * 255).astype(np.uint8)
322
+ else:
323
+ # Image is already in [0, 255] range
324
+ image = image.astype(np.uint8)
325
+ except Exception as e:
326
+ print(f"Image conversion error: {e}")
327
+ return [None, [], f"Image conversion error: {e}"]
328
+
329
+ # Draw circles for points
330
+ try:
331
+ for p in measure_points:
332
+ if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
333
+ image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
334
+ except Exception as e:
335
+ print(f"Drawing error: {e}")
336
+ return [None, [], f"Drawing error: {e}"]
337
+
338
+ # Get depth information from processed_data
339
+ depth_text = ""
340
+ try:
341
+ for i, p in enumerate(measure_points):
342
+ if (
343
+ current_view["depth"] is not None
344
+ and 0 <= p[1] < current_view["depth"].shape[0]
345
+ and 0 <= p[0] < current_view["depth"].shape[1]
346
+ ):
347
+ d = current_view["depth"][p[1], p[0]]
348
+ depth_text += f"- **P{i + 1} depth: {d:.2f}m**\n"
349
+ else:
350
+ depth_text += f"- **P{i + 1}: Click position ({p[0]}, {p[1]}) - No depth information**\n" # noqa: E501
351
+ except Exception as e:
352
+ print(f"Depth text error: {e}")
353
+ depth_text = f"Error computing depth: {e}\n"
354
+
355
+ if len(measure_points) == 2:
356
+ try:
357
+ point1, point2 = measure_points
358
+ # Draw line
359
+ if (
360
+ 0 <= point1[0] < image.shape[1]
361
+ and 0 <= point1[1] < image.shape[0]
362
+ and 0 <= point2[0] < image.shape[1]
363
+ and 0 <= point2[1] < image.shape[0]
364
+ ):
365
+ image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
366
+
367
+ # Compute 3D distance using depth information and camera intrinsics
368
+ distance_text = "- **Distance: Unable to calculate 3D distance**"
369
+ if (
370
+ current_view["depth"] is not None
371
+ and 0 <= point1[1] < current_view["depth"].shape[0]
372
+ and 0 <= point1[0] < current_view["depth"].shape[1]
373
+ and 0 <= point2[1] < current_view["depth"].shape[0]
374
+ and 0 <= point2[0] < current_view["depth"].shape[1]
375
+ ):
376
+ try:
377
+ # Get depth values at the two points
378
+ d1 = current_view["depth"][point1[1], point1[0]]
379
+ d2 = current_view["depth"][point2[1], point2[0]]
380
+
381
+ # Convert 2D pixel coordinates to 3D world coordinates
382
+ if current_view["intrinsics"] is not None:
383
+ # Get camera intrinsics
384
+ K = current_view["intrinsics"] # 3x3 intrinsic matrix
385
+ fx, fy = K[0, 0], K[1, 1] # focal lengths
386
+ cx, cy = K[0, 2], K[1, 2] # principal point
387
+
388
+ # Convert pixel coordinates to normalized camera coordinates
389
+ # Point 1: (u1, v1) -> (x1, y1, z1)
390
+ u1, v1 = point1[0], point1[1]
391
+ x1 = (u1 - cx) * d1 / fx
392
+ y1 = (v1 - cy) * d1 / fy
393
+ z1 = d1
394
+
395
+ # Point 2: (u2, v2) -> (x2, y2, z2)
396
+ u2, v2 = point2[0], point2[1]
397
+ x2 = (u2 - cx) * d2 / fx
398
+ y2 = (v2 - cy) * d2 / fy
399
+ z2 = d2
400
+
401
+ # Calculate 3D Euclidean distance
402
+ p1_3d = np.array([x1, y1, z1])
403
+ p2_3d = np.array([x2, y2, z2])
404
+ distance_3d = np.linalg.norm(p1_3d - p2_3d)
405
+
406
+ distance_text = f"- **Distance: {distance_3d:.2f}m**"
407
+ else:
408
+ # Fallback to simplified calculation if no intrinsics
409
+ pixel_distance = np.sqrt(
410
+ (point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
411
+ )
412
+ avg_depth = (d1 + d2) / 2
413
+ scale_factor = avg_depth / 1000 # Rough scaling factor
414
+ estimated_3d_distance = pixel_distance * scale_factor
415
+ distance_text = f"- **Distance: {estimated_3d_distance:.2f}m (estimated, no intrinsics)**" # noqa: E501
416
+
417
+ except Exception as e:
418
+ print(f"Distance computation error: {e}")
419
+ distance_text = f"- **Distance computation error: {e}**"
420
+
421
+ measure_points = []
422
+ text = depth_text + distance_text
423
+ print(f"Measurement complete: {text}")
424
+ return [image, depth_right_half, measure_points, text]
425
+ except Exception as e:
426
+ print(f"Final measurement error: {e}")
427
+ return [None, [], f"Measurement error: {e}"]
428
+ else:
429
+ print(f"Single point measurement: {depth_text}")
430
+ return [image, depth_right_half, measure_points, depth_text]
431
+
432
+ except Exception as e:
433
+ print(f"Overall measure function error: {e}")
434
+ return [None, [], f"Measure function error: {e}"]
Depth-Anything-3/src/depth_anything_3/bench/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Depth Anything 3 Benchmark Evaluation Module.
17
+
18
+ This module provides tools for evaluating DepthAnything3 model on various benchmark datasets.
19
+ Currently supported datasets:
20
+ - DTU (3D Reconstruction)
21
+ - DTU-64 (Pose Evaluation Only)
22
+ - ETH3D (3D Reconstruction)
23
+ - 7Scenes (3D Reconstruction)
24
+ - ScanNet++ (3D Reconstruction)
25
+ - HiRoom (3D Reconstruction)
26
+
27
+ Supported evaluation modes:
28
+ - pose: Camera pose estimation evaluation
29
+ - recon_unposed: 3D reconstruction with predicted poses
30
+ - recon_posed: 3D reconstruction with ground truth poses
31
+ """
32
+
33
+ from depth_anything_3.bench.registries import MV_REGISTRY, MONO_REGISTRY
34
+
35
+
36
+ def __getattr__(name):
37
+ """Lazy import to avoid circular import when running as __main__."""
38
+ if name == "Evaluator":
39
+ from depth_anything_3.bench.evaluator import Evaluator
40
+ return Evaluator
41
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
42
+
43
+
44
+ __all__ = ["Evaluator", "MV_REGISTRY", "MONO_REGISTRY"]
45
+
Depth-Anything-3/src/depth_anything_3/bench/configs/eval_bench.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DepthAnything3 Benchmark Evaluation Configuration
2
+ #
3
+ # This config can be loaded and overridden via command line.
4
+ # Example: python -m depth_anything_3.bench.evaluator --model /path/to/model --work_dir /path/to/workspace
5
+ #
6
+ # See depth_anything_3.cfg for config utility functions.
7
+
8
+ # ==============================================================================
9
+ # Model Configuration
10
+ # ==============================================================================
11
+ model:
12
+ # Path to model checkpoint or HuggingFace model ID
13
+ path: depth-anything/DA3-GIANT
14
+
15
+ # ==============================================================================
16
+ # Workspace Configuration
17
+ # ==============================================================================
18
+ workspace:
19
+ # Working directory for outputs (model results, metrics, etc.)
20
+ work_dir: ./workspace/evaluation
21
+
22
+ # ==============================================================================
23
+ # Evaluation Configuration
24
+ # ==============================================================================
25
+ eval:
26
+ # Datasets to evaluate
27
+ # Options: dtu, dtu64, eth3d, 7scenes (sevenscenes), scannetpp, hiroom
28
+ datasets:
29
+ - eth3d
30
+ - 7scenes
31
+ - scannetpp
32
+ - hiroom
33
+ - dtu
34
+ - dtu64
35
+
36
+ # Evaluation modes
37
+ # Options: pose, recon_unposed, recon_posed, view_syn
38
+ modes:
39
+ - pose
40
+ - recon_unposed
41
+ - recon_posed
42
+
43
+ # Reference view selection strategy for inference
44
+ # Options: first, saddle_balanced, auto, mid
45
+ ref_view_strategy: "first"
46
+
47
+ # Specific scenes to evaluate (null = all scenes)
48
+ # Example: [courtyard, relief] for eth3d
49
+ scenes: null
50
+
51
+ # Maximum number of frames per scene (for sampling)
52
+ # If a scene has more frames, randomly sample to this limit.
53
+ # Set to -1 to disable sampling.
54
+ max_frames: 100
55
+
56
+ # Only run evaluation (skip inference)
57
+ eval_only: false
58
+
59
+ # Only print saved metrics (skip inference and evaluation)
60
+ print_only: false
61
+
62
+ # ==============================================================================
63
+ # Inference Configuration
64
+ # ==============================================================================
65
+ inference:
66
+ # Number of parallel workers for TSDF fusion
67
+ num_fusion_workers: 4
68
+
69
+ # Enable debug mode with verbose output
70
+ debug: false
71
+
72
+ # ==============================================================================
73
+ # Preset Configurations
74
+ # ==============================================================================
75
+ # These can be activated via command line: --preset full_eval
76
+
77
+ presets:
78
+ # Full evaluation on all 6 datasets
79
+ full_eval:
80
+ datasets: [eth3d, 7scenes, scannetpp, hiroom, dtu, dtu64]
81
+ modes: [pose, recon_unposed, recon_posed]
82
+
83
+ # Pose-only evaluation
84
+ pose_only:
85
+ datasets: [eth3d, 7scenes, scannetpp, hiroom, dtu64]
86
+ modes: [pose]
87
+
88
+ # Reconstruction-only evaluation (5 datasets, excluding dtu64)
89
+ recon_only:
90
+ datasets: [eth3d, 7scenes, scannetpp, hiroom, dtu]
91
+ modes: [recon_unposed, recon_posed]
92
+
93
+ # Quick test (single scene per dataset)
94
+ quick_test:
95
+ datasets: [eth3d]
96
+ modes: [pose, recon_unposed]
97
+ scenes: [courtyard]
98
+
Depth-Anything-3/src/depth_anything_3/bench/dataset.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Base dataset class for benchmark evaluation.
17
+
18
+ All dataset implementations should inherit from this class and implement
19
+ the required abstract methods.
20
+ """
21
+
22
+ import os
23
+ import time
24
+ from abc import abstractmethod
25
+ from typing import Dict as TDict
26
+
27
+ import numpy as np
28
+ import torch
29
+ from addict import Dict
30
+
31
+ from depth_anything_3.bench.utils import compute_pose
32
+ from depth_anything_3.utils.geometry import as_homogeneous
33
+
34
+
35
+ def _wait_for_file_ready(path: str, timeout: float = 3.0, interval: float = 0.2) -> None:
36
+ """Wait until file size stabilizes for 2 consecutive checks."""
37
+ last_size = -1
38
+ stable_count = 0
39
+ start = time.time()
40
+ while time.time() - start < timeout:
41
+ time.sleep(interval)
42
+ size = os.path.getsize(path)
43
+ if size == last_size and size > 0:
44
+ stable_count += 1
45
+ if stable_count >= 2: # Need 2 consecutive stable checks
46
+ return
47
+ else:
48
+ stable_count = 0
49
+ last_size = size
50
+
51
+
52
+ class Dataset:
53
+ """
54
+ Base class for all benchmark datasets.
55
+
56
+ Subclasses must implement:
57
+ - SCENES: List of scene identifiers
58
+ - data_root: Path to dataset root
59
+ - get_data(scene): Return scene data (images, intrinsics, extrinsics, etc.)
60
+ - eval3d(scene, fuse_path): Evaluate 3D reconstruction
61
+ - fuse3d(scene, result_path, fuse_path, mode): Fuse depth maps into point cloud
62
+
63
+ Optional overrides:
64
+ - eval_pose(scene, result_path): Evaluate pose estimation (default provided)
65
+ """
66
+
67
+ # Subclasses should define these
68
+ SCENES: list = []
69
+ data_root: str = ""
70
+
71
+ def __init__(self):
72
+ pass
73
+
74
+ def eval_pose(self, scene: str, result_path: str) -> TDict[str, float]:
75
+ """
76
+ Evaluate camera pose estimation accuracy.
77
+
78
+ Args:
79
+ scene: Scene identifier
80
+ result_path: Path to .npz file containing predicted extrinsics
81
+
82
+ Returns:
83
+ Dict with pose metrics (auc30, auc15, auc05, auc03)
84
+ """
85
+ _wait_for_file_ready(result_path)
86
+ pred = np.load(result_path)
87
+ gt = self.get_data(scene)
88
+ return compute_pose(
89
+ torch.from_numpy(as_homogeneous(pred["extrinsics"])),
90
+ torch.from_numpy(as_homogeneous(gt["extrinsics"])),
91
+ )
92
+
93
+ @abstractmethod
94
+ def get_data(self, scene: str) -> Dict:
95
+ """
96
+ Get scene data including images, camera parameters, and auxiliary info.
97
+
98
+ Args:
99
+ scene: Scene identifier
100
+
101
+ Returns:
102
+ Dict with:
103
+ - image_files: List[str] - paths to images
104
+ - extrinsics: np.ndarray [N, 4, 4] - camera extrinsics (world-to-camera)
105
+ - intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
106
+ - aux: Dict - auxiliary data (masks, GT paths, etc.)
107
+ """
108
+ raise NotImplementedError
109
+
110
+ @abstractmethod
111
+ def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
112
+ """
113
+ Evaluate 3D reconstruction quality against ground truth.
114
+
115
+ Args:
116
+ scene: Scene identifier
117
+ fuse_path: Path to fused point cloud (.ply)
118
+
119
+ Returns:
120
+ Dict with reconstruction metrics (e.g., acc, comp, overall)
121
+ """
122
+ raise NotImplementedError
123
+
124
+ @abstractmethod
125
+ def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
126
+ """
127
+ Fuse per-view depth maps into a single point cloud.
128
+
129
+ Args:
130
+ scene: Scene identifier
131
+ result_path: Path to .npz file with predicted depths and poses
132
+ fuse_path: Output path for fused point cloud (.ply)
133
+ mode: Fusion mode ("recon_unposed" or "recon_posed")
134
+ """
135
+ raise NotImplementedError
136
+
Depth-Anything-3/src/depth_anything_3/bench/datasets/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Benchmark dataset implementations.
17
+
18
+ Datasets are auto-registered via decorators when imported.
19
+ Add new dataset files here and they will be automatically discovered.
20
+ """
21
+
Depth-Anything-3/src/depth_anything_3/bench/datasets/dtu.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ DTU Benchmark dataset implementation.
17
+
18
+ DTU is a multi-view stereo benchmark for 3D reconstruction evaluation.
19
+ Reference: https://roboimagedata.compute.dtu.dk/
20
+
21
+ Note: DepthAnything3 was never trained on any images from DTU.
22
+ """
23
+
24
+ import glob
25
+ import os
26
+ from typing import Dict as TDict, List
27
+
28
+ import numpy as np
29
+ import open3d as o3d
30
+ import torch
31
+ import torch.nn.functional as F
32
+ from addict import Dict
33
+ from PIL import Image
34
+ from plyfile import PlyData
35
+ from scipy.io import loadmat
36
+ from sklearn import neighbors as skln
37
+ from tqdm import tqdm
38
+
39
+ from depth_anything_3.bench.dataset import Dataset
40
+ from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
41
+ from depth_anything_3.utils.constants import (
42
+ DTU_DIST_THRESH,
43
+ DTU_EVAL_DATA_ROOT,
44
+ DTU_MAX_POINTS,
45
+ DTU_NUM_CONSIST,
46
+ DTU_SCENES,
47
+ )
48
+ from depth_anything_3.utils.pose_align import align_poses_umeyama
49
+
50
+
51
+ @MV_REGISTRY.register(name="dtu")
52
+ @MONO_REGISTRY.register(name="dtu")
53
+ class DTU(Dataset):
54
+ """
55
+ DTU Benchmark dataset wrapper for DepthAnything3 evaluation.
56
+
57
+ Supports:
58
+ - Camera pose estimation evaluation (AUC metrics)
59
+ - 3D reconstruction evaluation (accuracy, completeness, overall)
60
+ - Point cloud fusion from depth maps
61
+
62
+ The dataset uses MVSNet evaluation protocol:
63
+ https://drive.google.com/file/d/1rX0EXlUL4prRxrRu2DgLJv2j7-tpUD4D/view
64
+ """
65
+
66
+ data_root = DTU_EVAL_DATA_ROOT
67
+ SCENES = DTU_SCENES
68
+
69
+ # Evaluation/triangulation hyperparameters from constants
70
+ dist_thresh = DTU_DIST_THRESH
71
+ num_consist = DTU_NUM_CONSIST
72
+
73
+ # ------------------------------
74
+ # Public API
75
+ # ------------------------------
76
+
77
+ def read_cam_file(self, filename: str) -> tuple:
78
+ """
79
+ Read DTU camera file containing extrinsics and intrinsics.
80
+
81
+ Args:
82
+ filename: Path to camera text file
83
+
84
+ Returns:
85
+ Tuple of (intrinsics [3,3], extrinsics [4,4])
86
+ """
87
+ with open(filename) as f:
88
+ lines = [line.rstrip() for line in f.readlines()]
89
+ extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ").reshape((4, 4))
90
+ intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ").reshape((3, 3))
91
+ return intrinsics, extrinsics
92
+
93
+ def get_data(self, scene: str) -> Dict:
94
+ """
95
+ Collect per-view image paths, intrinsics/extrinsics, and GT masks.
96
+
97
+ Args:
98
+ scene: Scene identifier (e.g., "scan1")
99
+
100
+ Returns:
101
+ Dict with:
102
+ - image_files: List[str] - paths to images
103
+ - extrinsics: np.ndarray [N, 4, 4]
104
+ - intrinsics: np.ndarray [N, 3, 3]
105
+ - aux.mask_files: List[str] - paths to depth masks
106
+ """
107
+ rgb_folder = os.path.join(self.data_root, "Rectified", scene)
108
+ camera_folder = os.path.join(self.data_root, "Cameras")
109
+
110
+ files = sorted(glob.glob(os.path.join(rgb_folder, "*.png")))
111
+ # Reorder: place index 33 first (reference view convention)
112
+ files = [files[33]] + files[:33] + files[34:]
113
+
114
+ out = Dict(
115
+ {
116
+ "image_files": files,
117
+ "extrinsics": [],
118
+ "intrinsics": [],
119
+ "aux": Dict({"mask_files": []}),
120
+ }
121
+ )
122
+
123
+ for rgb_file in files:
124
+ basename = os.path.basename(rgb_file)
125
+ file_idx = basename.split("_")[1]
126
+ cam_idx = depth_idx = int(file_idx) - 1
127
+
128
+ mask_file = self._depth_mask_path(scene, depth_idx)
129
+ proj_mat_filename = os.path.join(camera_folder, f"{cam_idx:0>8}_cam.txt")
130
+
131
+ ixt, ext = self.read_cam_file(proj_mat_filename)
132
+ out.extrinsics.append(ext)
133
+ out.intrinsics.append(ixt)
134
+ out.aux.mask_files.append(mask_file)
135
+
136
+ out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
137
+ out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
138
+ return out
139
+
140
+ def get_3dgtpath(self, scene: str) -> str:
141
+ """Get path to ground truth point cloud for a scene."""
142
+ scene_id = int(scene[4:])
143
+ return os.path.join(self.data_root, f"Points/stl/stl{scene_id:03}_total.ply")
144
+
145
+ def eval3d(self, scene: str, fuse_path: str, use_gpu: bool = False) -> TDict[str, float]:
146
+ """
147
+ Evaluate fused point cloud against DTU GT with ObsMask/Plane.
148
+
149
+ Args:
150
+ scene: Scene identifier
151
+ fuse_path: Path to fused point cloud
152
+ use_gpu: If True, use GPU-accelerated distance computation (faster but may have minor numerical differences)
153
+
154
+ Returns:
155
+ Dict with metrics: {"comp": float, "acc": float, "overall": float}
156
+ """
157
+ scene_id = int(scene[4:])
158
+ gt_ply = os.path.join(self.data_root, f"Points/stl/stl{scene_id:03}_total.ply")
159
+ mask_file = os.path.join(
160
+ self.data_root, f"SampleSet/mvs_data/ObsMask/ObsMask{scene_id}_10.mat"
161
+ )
162
+ plane_file = os.path.join(
163
+ self.data_root, f"SampleSet/mvs_data/ObsMask/Plane{scene_id}.mat"
164
+ )
165
+ result = self._evaluate_reconstruction(
166
+ scene, fuse_path, gt_ply, mask_file, plane_file, use_gpu=use_gpu
167
+ )
168
+ return {"comp": result[0], "acc": result[1], "overall": result[2]}
169
+
170
+ def load_masks(self, mask_files: List[str]) -> np.ndarray:
171
+ """
172
+ Load DTU depth validity masks.
173
+
174
+ Args:
175
+ mask_files: List of paths to mask images
176
+
177
+ Returns:
178
+ Boolean array [N, H, W] indicating valid depth regions
179
+ """
180
+ masks = []
181
+ for mask_file in mask_files:
182
+ mask = Image.open(mask_file)
183
+ mask = np.array(mask, dtype=np.float32)
184
+ masks.append(mask > 10)
185
+ return np.asarray(masks)
186
+
187
+ def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
188
+ """
189
+ Fuse per-view depths into a point cloud and save to PLY.
190
+
191
+ Args:
192
+ scene: Scene identifier (e.g., "scan114")
193
+ result_path: Path to npz file containing predicted depths/poses
194
+ fuse_path: Output path for fused point cloud (.ply)
195
+ mode: "recon_unposed" or "recon_posed"
196
+ """
197
+ gt_data = self.get_data(scene)
198
+ pred_data = Dict({k: v for k, v in np.load(result_path).items()})
199
+ masks = self.load_masks(gt_data.aux.mask_files)
200
+
201
+ if mode == "recon_unposed":
202
+ depths, intrinsics, extrinsics = self._prep_unposed(pred_data, gt_data, masks)
203
+ elif mode == "recon_posed":
204
+ depths, intrinsics, extrinsics = self._prep_posed(pred_data, gt_data, masks)
205
+ else:
206
+ raise ValueError(f"Invalid mode: {mode}")
207
+
208
+ proj_mat = self._build_proj_mats(intrinsics, extrinsics)
209
+
210
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
211
+ dtype = torch.float32
212
+ depths_t = torch.from_numpy(depths).to(device=device, dtype=dtype).unsqueeze(1)
213
+ proj_t = torch.from_numpy(proj_mat).to(device=device, dtype=dtype)
214
+ height, width = depths_t.shape[-2:]
215
+
216
+ points: List[np.ndarray] = []
217
+ for idx in range(len(gt_data.image_files)):
218
+ if mode == "recon_unposed":
219
+ # Simple unfiltered back-projection per frame
220
+ cur_p_pcd = self._generate_points_from_depth(
221
+ depths_t[idx : idx + 1], proj_t[idx : idx + 1]
222
+ )
223
+ mask = (depths_t[idx : idx + 1] > 0.001).squeeze()
224
+ cur_p_pcd = cur_p_pcd[:, :, mask]
225
+ no_filter_pc = cur_p_pcd.squeeze(0).permute(1, 0).cpu().numpy()
226
+ points.append(no_filter_pc)
227
+ else: # recon_posed
228
+ final_pc = self._fuse_consistent_points(depths_t, proj_t, idx, height, width)
229
+ points.append(final_pc)
230
+
231
+ # Concatenate and optionally downsample to hard cap
232
+ points_np = np.concatenate(points, axis=0)
233
+ points_np = self._cap_points(points_np, max_points=DTU_MAX_POINTS)
234
+
235
+ os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
236
+ pcd = o3d.geometry.PointCloud()
237
+ pcd.points = o3d.utility.Vector3dVector(points_np)
238
+ o3d.io.write_point_cloud(fuse_path, pcd)
239
+
240
+ # ------------------------------
241
+ # Geometry helpers
242
+ # ------------------------------
243
+
244
+ def _generate_points_from_depth(
245
+ self, depth: torch.Tensor, proj: torch.Tensor
246
+ ) -> torch.Tensor:
247
+ """
248
+ Back-project depth map into 3D world coordinates.
249
+
250
+ Args:
251
+ depth: Depth tensor [B, 1, H, W]
252
+ proj: Projection matrix [B, 4, 4] = [[K@R, K@t], [0,0,0,1]]
253
+
254
+ Returns:
255
+ Point cloud tensor [B, 3, H, W]
256
+ """
257
+ batch, height, width = depth.shape[0], depth.shape[2], depth.shape[3]
258
+ inv_proj = torch.inverse(proj)
259
+ rot = inv_proj[:, :3, :3]
260
+ trans = inv_proj[:, :3, 3:4]
261
+
262
+ y, x = torch.meshgrid(
263
+ [
264
+ torch.arange(0, height, dtype=torch.float32, device=depth.device),
265
+ torch.arange(0, width, dtype=torch.float32, device=depth.device),
266
+ ],
267
+ indexing="ij",
268
+ )
269
+ y, x = y.contiguous(), x.contiguous()
270
+ y, x = y.view(height * width), x.view(height * width)
271
+ xyz = torch.stack((x, y, torch.ones_like(x)))
272
+ xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)
273
+ rot_xyz = torch.matmul(rot, xyz)
274
+ rot_depth_xyz = rot_xyz * depth.view(batch, 1, -1)
275
+ proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1)
276
+ return proj_xyz.view(batch, 3, height, width)
277
+
278
+ def _homo_warping(
279
+ self,
280
+ src_fea: torch.Tensor,
281
+ src_proj: torch.Tensor,
282
+ ref_proj: torch.Tensor,
283
+ depth_values: torch.Tensor,
284
+ ) -> torch.Tensor:
285
+ """
286
+ Homography warping for multi-view consistency checking.
287
+
288
+ Args:
289
+ src_fea: Source features [B, C, H, W]
290
+ src_proj: Source projection [B, 4, 4]
291
+ ref_proj: Reference projection [B, 4, 4]
292
+ depth_values: Depth values [B, Ndepth] or [B, Ndepth, H, W]
293
+
294
+ Returns:
295
+ Warped features [B, C, H, W]
296
+ """
297
+ batch, channels = src_fea.shape[0], src_fea.shape[1]
298
+ height, width = src_fea.shape[2], src_fea.shape[3]
299
+
300
+ with torch.no_grad():
301
+ proj = torch.matmul(src_proj, torch.inverse(ref_proj))
302
+ rot = proj[:, :3, :3]
303
+ trans = proj[:, :3, 3:4]
304
+
305
+ y, x = torch.meshgrid(
306
+ [
307
+ torch.arange(0, height, dtype=torch.float32, device=src_fea.device),
308
+ torch.arange(0, width, dtype=torch.float32, device=src_fea.device),
309
+ ],
310
+ indexing="ij",
311
+ )
312
+ y, x = y.contiguous(), x.contiguous()
313
+ y, x = y.view(height * width), x.view(height * width)
314
+ xyz = torch.stack((x, y, torch.ones_like(x)))
315
+ xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)
316
+ rot_xyz = torch.matmul(rot, xyz)
317
+
318
+ rot_depth_xyz = rot_xyz.unsqueeze(2) * depth_values.view(-1, 1, 1, height * width)
319
+ proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)
320
+ proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]
321
+ proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
322
+ proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
323
+ grid = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)
324
+
325
+ warped_src_fea = F.grid_sample(
326
+ src_fea,
327
+ grid.view(batch, height, width, 2),
328
+ mode="bilinear",
329
+ padding_mode="zeros",
330
+ align_corners=True,
331
+ )
332
+ return warped_src_fea.view(batch, channels, height, width)
333
+
334
+ def _filter_depth(
335
+ self,
336
+ ref_depth: torch.Tensor,
337
+ src_depths: torch.Tensor,
338
+ ref_proj: torch.Tensor,
339
+ src_projs: torch.Tensor,
340
+ ) -> tuple:
341
+ """
342
+ Compute geometric consistency between reference and source depths.
343
+
344
+ Args:
345
+ ref_depth: Reference depth [1, 1, H, W]
346
+ src_depths: Source depths [B, 1, H, W]
347
+ ref_proj: Reference projection [1, 4, 4]
348
+ src_projs: Source projections [B, 4, 4]
349
+
350
+ Returns:
351
+ Tuple of (ref_pc, aligned_pcs, dist)
352
+ """
353
+ ref_pc = self._generate_points_from_depth(ref_depth, ref_proj)
354
+ src_pcs = self._generate_points_from_depth(src_depths, src_projs)
355
+ aligned_pcs = self._homo_warping(src_pcs, src_projs, ref_proj, ref_depth)
356
+ x_2 = (ref_pc[:, 0] - aligned_pcs[:, 0]) ** 2
357
+ y_2 = (ref_pc[:, 1] - aligned_pcs[:, 1]) ** 2
358
+ z_2 = (ref_pc[:, 2] - aligned_pcs[:, 2]) ** 2
359
+ dist = torch.sqrt(x_2 + y_2 + z_2).unsqueeze(1)
360
+ return ref_pc, aligned_pcs, dist
361
+
362
+ def _extract_points(
363
+ self, pc: torch.Tensor, mask: torch.Tensor, rgb: np.ndarray = None
364
+ ) -> np.ndarray:
365
+ """Extract masked points from a dense grid."""
366
+ pc = pc.cpu().numpy()
367
+ mask = mask.cpu().numpy().reshape(-1)
368
+ pc = pc.reshape(-1, 3)
369
+ points = pc[np.where(mask)]
370
+ if rgb is not None:
371
+ rgb = rgb.reshape(-1, 3)
372
+ colors = rgb[np.where(mask)]
373
+ return np.concatenate([points, colors], axis=1)
374
+ return points
375
+
376
+ # ------------------------------
377
+ # 3D Reconstruction Evaluation
378
+ # ------------------------------
379
+
380
+ def _evaluate_reconstruction(
381
+ self,
382
+ scanid: str,
383
+ pred_ply: str,
384
+ gt_ply: str,
385
+ mask_file: str,
386
+ plane_file: str,
387
+ down_dense: float = 0.2,
388
+ patch: int = 60,
389
+ max_dist: int = 20,
390
+ use_gpu: bool = False,
391
+ ) -> tuple:
392
+ """
393
+ Compute accuracy, completeness, and overall metrics for one scan.
394
+
395
+ Args:
396
+ scanid: Scan identifier
397
+ pred_ply: Predicted point cloud path or array
398
+ gt_ply: Ground truth point cloud path or array
399
+ mask_file: ObsMask file path
400
+ plane_file: Plane file path
401
+ down_dense: Downsample density (min distance between points)
402
+ patch: Patch size for boundary
403
+ max_dist: Outlier threshold in mm
404
+ use_gpu: If True, use GPU-accelerated distance computation
405
+
406
+ Returns:
407
+ Tuple of (mean_d2s, mean_s2d, overall)
408
+ """
409
+ thresh = down_dense
410
+
411
+ # Load and downsample predicted point cloud
412
+ data_pcd = self._read_ply(pred_ply) if isinstance(pred_ply, str) else pred_ply
413
+ # Use fixed seed for reproducibility
414
+ shuffle_rng = np.random.default_rng(seed=42)
415
+ shuffle_rng.shuffle(data_pcd, axis=0)
416
+
417
+ # Downsample point cloud
418
+ nn_engine = skln.NearestNeighbors(
419
+ n_neighbors=1, radius=thresh, algorithm="kd_tree", n_jobs=-1
420
+ )
421
+ nn_engine.fit(data_pcd)
422
+ rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False)
423
+ mask = np.ones(data_pcd.shape[0], dtype=np.bool_)
424
+ for curr, idxs in enumerate(rnn_idxs):
425
+ if mask[curr]:
426
+ mask[idxs] = 0
427
+ mask[curr] = 1
428
+ data_down = data_pcd[mask]
429
+
430
+ # Restrict to observed volume (ObsMask)
431
+ obs_mask_file = loadmat(mask_file)
432
+ ObsMask, BB, Res = (obs_mask_file[attr] for attr in ["ObsMask", "BB", "Res"])
433
+ BB = BB.astype(np.float32)
434
+
435
+ inbound = ((data_down >= BB[:1] - patch) & (data_down < BB[1:] + patch * 2)).sum(
436
+ axis=-1
437
+ ) == 3
438
+ data_in = data_down[inbound]
439
+
440
+ data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32)
441
+ grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(
442
+ axis=-1
443
+ ) == 3
444
+ data_grid_in = data_grid[grid_inbound]
445
+ in_obs = ObsMask[data_grid_in[:, 0], data_grid_in[:, 1], data_grid_in[:, 2]].astype(
446
+ np.bool_
447
+ )
448
+ data_in_obs = data_in[grid_inbound][in_obs]
449
+
450
+ # Compute accuracy (pred -> GT) and completeness (GT -> pred)
451
+ stl = self._read_ply(gt_ply) if isinstance(gt_ply, str) else gt_ply
452
+
453
+ if use_gpu and torch.cuda.is_available():
454
+ # GPU-accelerated distance computation
455
+ mean_d2s = self._knn_dist_gpu(data_in_obs, stl, max_dist)
456
+ else:
457
+ # CPU version (original, for exact reproduction)
458
+ nn_engine.fit(stl)
459
+ dist_d2s, _ = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True)
460
+ mean_d2s = dist_d2s[dist_d2s < max_dist].mean()
461
+
462
+ ground_plane = loadmat(plane_file)["P"]
463
+ stl_hom = np.concatenate([stl, np.ones_like(stl[:, :1])], -1)
464
+ above = (ground_plane.reshape((1, 4)) * stl_hom).sum(-1) > 0
465
+ stl_above = stl[above]
466
+
467
+ if use_gpu and torch.cuda.is_available():
468
+ # GPU-accelerated distance computation
469
+ mean_s2d = self._knn_dist_gpu(stl_above, data_in, max_dist)
470
+ else:
471
+ # CPU version (original, for exact reproduction)
472
+ nn_engine.fit(data_in)
473
+ dist_s2d, _ = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True)
474
+ mean_s2d = dist_s2d[dist_s2d < max_dist].mean()
475
+
476
+ overall = (mean_d2s + mean_s2d) / 2
477
+ return mean_d2s, mean_s2d, overall
478
+
479
+ def _knn_dist_gpu(
480
+ self,
481
+ query: np.ndarray,
482
+ target: np.ndarray,
483
+ max_dist: float,
484
+ batch_size: int = 8192,
485
+ target_batch_size: int = 50000,
486
+ ) -> float:
487
+ """
488
+ GPU-accelerated nearest neighbor distance computation.
489
+
490
+ Args:
491
+ query: Query points [N, 3]
492
+ target: Target points [M, 3]
493
+ max_dist: Outlier threshold
494
+ batch_size: Batch size for query to avoid OOM (tuned for 16GB GPU)
495
+ target_batch_size: Batch size for target to avoid OOM
496
+
497
+ Returns:
498
+ Mean distance (excluding outliers)
499
+ """
500
+ device = torch.device("cuda")
501
+
502
+ all_min_dists = []
503
+ n_query_batches = (len(query) + batch_size - 1) // batch_size
504
+ n_target_batches = (len(target) + target_batch_size - 1) // target_batch_size
505
+
506
+ # Pre-load target batches to GPU to avoid repeated transfers
507
+ # Memory: ~50000 pts * 3 coords * 4 bytes * n_batches
508
+ target_batches = []
509
+ for j in range(0, len(target), target_batch_size):
510
+ target_batch = target[j : j + target_batch_size]
511
+ target_t = torch.from_numpy(target_batch).float().to(device)
512
+ target_batches.append(target_t)
513
+
514
+ with tqdm(total=n_query_batches, desc=" GPU KNN", leave=False, ncols=100) as pbar:
515
+ for i in range(0, len(query), batch_size):
516
+ batch = query[i : i + batch_size]
517
+ query_t = torch.from_numpy(batch).float().to(device)
518
+
519
+ # Compute distances to all target batches
520
+ # Memory peak: query_batch × target_batch_size × 4 bytes
521
+ # = 8192 × 50000 × 4 = ~1.6 GB per cdist call
522
+ batch_min_dists = []
523
+ for target_t in target_batches:
524
+ dists = torch.cdist(query_t, target_t)
525
+ batch_min_dists.append(dists.min(dim=1).values)
526
+ del dists # Free immediately
527
+
528
+ # Get minimum distance across all target batches
529
+ min_dists = torch.stack(batch_min_dists, dim=1).min(dim=1).values
530
+ all_min_dists.append(min_dists.cpu().numpy())
531
+
532
+ del query_t, min_dists, batch_min_dists
533
+ pbar.update(1)
534
+
535
+ # Clean up target batches
536
+ for target_t in target_batches:
537
+ del target_t
538
+ torch.cuda.empty_cache()
539
+
540
+ all_min_dists = np.concatenate(all_min_dists)
541
+ return all_min_dists[all_min_dists < max_dist].mean()
542
+
543
+ def _read_ply(self, file: str) -> np.ndarray:
544
+ """Read point cloud from PLY file."""
545
+ data = PlyData.read(file)
546
+ vertex = data["vertex"]
547
+ return np.stack([vertex["x"], vertex["y"], vertex["z"]], axis=-1)
548
+
549
+ # ------------------------------
550
+ # Private helpers
551
+ # ------------------------------
552
+
553
+ def _depth_mask_path(self, scene: str, depth_idx: int) -> str:
554
+ """Get path to depth mask for a scene and frame."""
555
+ return os.path.join(
556
+ self.data_root, "depth_raw", "Depths", scene, f"depth_visual_{depth_idx:04d}.png"
557
+ )
558
+
559
+ def _prep_unposed(
560
+ self, pred_data: Dict, gt_data: Dict, masks: np.ndarray
561
+ ) -> tuple:
562
+ """
563
+ Prepare depths/intrinsics/extrinsics for recon_unposed mode.
564
+
565
+ Applies Umeyama scale, rescales intrinsics if depth resolution differs,
566
+ and zeroes invalid-mask depths (nearest interpolation as in paper).
567
+ """
568
+ _, _, scale, extrinsics = align_poses_umeyama(
569
+ gt_data.extrinsics.copy(),
570
+ pred_data.extrinsics.copy(),
571
+ ransac=True,
572
+ return_aligned=True,
573
+ random_state=42,
574
+ )
575
+ depths = pred_data.depth * scale
576
+ intrinsics = pred_data.intrinsics.copy()
577
+
578
+ if depths.shape[-2:] != masks.shape[-2:]:
579
+ # When resizing depths to mask size, adjust intrinsics accordingly
580
+ sx = masks.shape[-1] / depths.shape[-1]
581
+ sy = masks.shape[-2] / depths.shape[-2]
582
+ intrinsics[:, 0:1] *= sx
583
+ intrinsics[:, 1:2] *= sy
584
+ depths = F.interpolate(
585
+ torch.from_numpy(depths)[None].float(),
586
+ size=(masks.shape[-2], masks.shape[-1]),
587
+ mode="nearest",
588
+ )[0].numpy()
589
+ depths[masks == False] = 0.0 # noqa: E712
590
+
591
+ return depths, intrinsics, extrinsics
592
+
593
+ def _prep_posed(
594
+ self, pred_data: Dict, gt_data: Dict, masks: np.ndarray
595
+ ) -> tuple:
596
+ """
597
+ Prepare depths/intrinsics/extrinsics for recon_posed mode.
598
+
599
+ Uses GT intrinsics/extrinsics but aligns scale via Umeyama.
600
+ Same mask order as other datasets: mask BEFORE scale.
601
+ """
602
+ _, _, scale, _ = align_poses_umeyama(
603
+ gt_data.extrinsics.copy(),
604
+ pred_data.extrinsics.copy(),
605
+ ransac=True,
606
+ return_aligned=True,
607
+ random_state=42,
608
+ )
609
+ depths = pred_data.depth.copy()
610
+ intrinsics = gt_data.intrinsics.copy()
611
+ extrinsics = gt_data.extrinsics.copy()
612
+
613
+ if depths.shape[-2:] != masks.shape[-2:]:
614
+ depths = F.interpolate(
615
+ torch.from_numpy(depths)[None].float(),
616
+ size=(masks.shape[-2], masks.shape[-1]),
617
+ mode="nearest",
618
+ )[0].numpy()
619
+
620
+ # Mask BEFORE scale (same as other datasets)
621
+ depths[masks == False] = 0.0 # noqa: E712
622
+ depths = depths * scale
623
+
624
+ return depths, intrinsics, extrinsics
625
+
626
+ def _build_proj_mats(
627
+ self, intrinsics: np.ndarray, extrinsics: np.ndarray
628
+ ) -> np.ndarray:
629
+ """Compute per-view 4x4 projection matrices from K and [R|t]."""
630
+ proj_mat_list = []
631
+ for i in range(len(intrinsics)):
632
+ proj_mat = np.eye(4, dtype=np.float32)
633
+ proj_mat[:3, :4] = np.dot(intrinsics[i], extrinsics[i][:3])
634
+ proj_mat_list.append(proj_mat)
635
+ return np.stack(proj_mat_list, axis=0)
636
+
637
+ def _fuse_consistent_points(
638
+ self,
639
+ depths_t: torch.Tensor,
640
+ proj_t: torch.Tensor,
641
+ idx: int,
642
+ H: int,
643
+ W: int,
644
+ ) -> np.ndarray:
645
+ """Fuse points consistent across multiple source views for a reference index."""
646
+ device, dtype = depths_t.device, depths_t.dtype
647
+ pc_buff = torch.zeros((3, H, W), device=device, dtype=dtype)
648
+ val_cnt = torch.zeros((1, H, W), device=device, dtype=dtype)
649
+
650
+ j = 0
651
+ batch_size = 20
652
+ tot_frame = depths_t.shape[0]
653
+ while True:
654
+ ref_pc, pcs, dist = self._filter_depth(
655
+ ref_depth=depths_t[idx : idx + 1],
656
+ src_depths=depths_t[j : min(j + batch_size, tot_frame)],
657
+ ref_proj=proj_t[idx : idx + 1],
658
+ src_projs=proj_t[j : min(j + batch_size, tot_frame)],
659
+ )
660
+ masks = (dist < self.dist_thresh).float()
661
+ masked_pc = pcs * masks
662
+ pc_buff += masked_pc.sum(dim=0, keepdim=False)
663
+ val_cnt += masks.sum(dim=0, keepdim=False)
664
+ j += batch_size
665
+ if j >= tot_frame:
666
+ break
667
+
668
+ final_mask = (val_cnt >= self.num_consist).squeeze(0)
669
+ avg_points = torch.div(pc_buff, val_cnt).permute(1, 2, 0)
670
+ final_pc = self._extract_points(avg_points, final_mask)
671
+ return final_pc
672
+
673
+ def _cap_points(self, points: np.ndarray, max_points: int) -> np.ndarray:
674
+ """Downsample points if exceeding max count."""
675
+ if len(points) <= max_points:
676
+ return points
677
+ # Use fixed seed for reproducibility
678
+ rng = np.random.default_rng(seed=42)
679
+ random_idx = rng.choice(len(points), max_points, replace=False)
680
+ return points[random_idx]
681
+
Depth-Anything-3/src/depth_anything_3/bench/datasets/dtu64.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ DTU-64 Dataset implementation for POSE EVALUATION ONLY.
17
+
18
+ This is a subset of DTU with 64 images per scene, specifically designed for
19
+ camera pose estimation evaluation. It does NOT support 3D reconstruction.
20
+
21
+ Note: GT depth loading is not implemented as it's not needed for pose evaluation.
22
+ """
23
+
24
+ import glob
25
+ import os
26
+ from typing import Dict as TDict
27
+
28
+ import numpy as np
29
+ from addict import Dict
30
+
31
+ from depth_anything_3.bench.dataset import Dataset
32
+ from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
33
+ from depth_anything_3.utils.constants import (
34
+ DTU64_CAMERA_ROOT,
35
+ DTU64_EVAL_DATA_ROOT,
36
+ DTU64_SCENES,
37
+ )
38
+
39
+
40
+ @MV_REGISTRY.register(name="dtu64")
41
+ @MONO_REGISTRY.register(name="dtu64")
42
+ class DTU64(Dataset):
43
+ """
44
+ DTU-64 Dataset wrapper for DepthAnything3 POSE EVALUATION ONLY.
45
+
46
+ This dataset is a subset of DTU with 64 images per scene.
47
+ It is specifically designed for camera pose estimation evaluation
48
+ and does NOT support 3D reconstruction evaluation.
49
+
50
+ Dataset structure:
51
+ DTU/scans/
52
+ ├── {scene}/
53
+ │ └── image/ # RGB images (64 per scene)
54
+ └── Cameras/
55
+ └── {idx}_cam.txt # Camera parameters
56
+
57
+ Supported modes:
58
+ - pose: Camera pose estimation evaluation
59
+
60
+ NOT supported:
61
+ - recon_unposed: 3D reconstruction (no GT depth available)
62
+ - recon_posed: 3D reconstruction (no GT depth available)
63
+ """
64
+
65
+ data_root = DTU64_EVAL_DATA_ROOT
66
+ camera_root = DTU64_CAMERA_ROOT
67
+ SCENES = DTU64_SCENES
68
+
69
+ def __init__(self):
70
+ super().__init__()
71
+ self._scene_cache = {}
72
+
73
+ # ------------------------------
74
+ # Camera file parsing
75
+ # ------------------------------
76
+
77
+ def read_cam_file(self, filename: str) -> tuple:
78
+ """
79
+ Read DTU camera file containing extrinsics and intrinsics.
80
+
81
+ Args:
82
+ filename: Path to camera text file
83
+
84
+ Returns:
85
+ Tuple of (intrinsics [3,3], extrinsics [4,4])
86
+ """
87
+ with open(filename) as f:
88
+ lines = [line.rstrip() for line in f.readlines()]
89
+ # extrinsics: line [1,5), 4x4 matrix
90
+ extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ").reshape((4, 4))
91
+ # intrinsics: line [7-10), 3x3 matrix
92
+ intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ").reshape((3, 3))
93
+ return intrinsics, extrinsics
94
+
95
+ # ------------------------------
96
+ # Public API
97
+ # ------------------------------
98
+
99
+ def get_data(self, scene: str) -> Dict:
100
+ """
101
+ Collect per-view image paths, intrinsics/extrinsics for a scene.
102
+
103
+ Args:
104
+ scene: Scene identifier (e.g., "scan105")
105
+
106
+ Returns:
107
+ Dict with:
108
+ - image_files: List[str] - paths to images (64 per scene)
109
+ - extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
110
+ - intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
111
+ - aux: Dict (empty for this dataset)
112
+ """
113
+ if scene in self._scene_cache:
114
+ return self._scene_cache[scene]
115
+
116
+ rgb_folder = os.path.join(self.data_root, scene, "image")
117
+
118
+ # Get all PNG files sorted
119
+ files = sorted(glob.glob(os.path.join(rgb_folder, "*.png")))
120
+
121
+ # Reorder: place index 33 first (reference view convention)
122
+ if len(files) > 33:
123
+ files = [files[33]] + files[:33] + files[34:]
124
+
125
+ out = Dict({
126
+ "image_files": [],
127
+ "extrinsics": [],
128
+ "intrinsics": [],
129
+ "aux": Dict({}),
130
+ })
131
+
132
+ for rgb_file in files:
133
+ basename = os.path.basename(rgb_file)
134
+ # File naming: "00000033.png" -> cam_idx = 33
135
+ file_idx = basename.split(".")[0]
136
+ cam_idx = int(file_idx)
137
+
138
+ # Camera file path
139
+ cam_file = os.path.join(self.camera_root, f"{cam_idx:0>8}_cam.txt")
140
+
141
+ if not os.path.exists(cam_file):
142
+ print(f"[DTU-64] Warning: Camera file not found: {cam_file}")
143
+ continue
144
+
145
+ intrinsics, extrinsics = self.read_cam_file(cam_file)
146
+
147
+ out.image_files.append(rgb_file)
148
+ out.extrinsics.append(extrinsics)
149
+ out.intrinsics.append(intrinsics)
150
+
151
+ out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
152
+ out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
153
+
154
+ print(f"[DTU-64] {scene}: {len(out.image_files)} images (pose evaluation only)")
155
+
156
+ self._scene_cache[scene] = out
157
+ return out
158
+
159
+ def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
160
+ """
161
+ NOT SUPPORTED for DTU-64.
162
+
163
+ DTU-64 is only for pose evaluation, not 3D reconstruction.
164
+ """
165
+ raise NotImplementedError(
166
+ "DTU-64 dataset is for POSE EVALUATION ONLY. "
167
+ "3D reconstruction evaluation is not supported. "
168
+ "Use the standard 'dtu' dataset for 3D reconstruction evaluation."
169
+ )
170
+
171
+ def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
172
+ """
173
+ NOT SUPPORTED for DTU-64.
174
+
175
+ DTU-64 is only for pose evaluation, not 3D reconstruction.
176
+ """
177
+ raise NotImplementedError(
178
+ "DTU-64 dataset is for POSE EVALUATION ONLY. "
179
+ "3D reconstruction (fuse3d) is not supported. "
180
+ "Use the standard 'dtu' dataset for 3D reconstruction."
181
+ )
182
+
Depth-Anything-3/src/depth_anything_3/bench/datasets/eth3d.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ ETH3D Benchmark dataset implementation.
17
+
18
+ ETH3D is a multi-view stereo benchmark with high-resolution images and
19
+ accurate ground truth geometry from laser scanning.
20
+ Reference: https://www.eth3d.net/
21
+
22
+ Evaluation metrics:
23
+ - 3D reconstruction: Accuracy, Completeness, F-score
24
+ - Camera pose estimation: AUC metrics
25
+ """
26
+
27
+ import glob
28
+ import os
29
+ from typing import Dict as TDict, List, Optional
30
+
31
+ import cv2
32
+ import numpy as np
33
+ import open3d as o3d
34
+ import torch
35
+ import torch.nn.functional as F
36
+ from addict import Dict
37
+ from PIL import Image
38
+
39
+ from depth_anything_3.bench.dataset import Dataset, _wait_for_file_ready
40
+ from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
41
+ from depth_anything_3.bench.utils import (
42
+ create_tsdf_volume,
43
+ evaluate_3d_reconstruction,
44
+ fuse_depth_to_tsdf,
45
+ quat2rotmat,
46
+ sample_points_from_mesh,
47
+ )
48
+ from depth_anything_3.utils.constants import (
49
+ ETH3D_DOWN_SAMPLE,
50
+ ETH3D_EVAL_DATA_ROOT,
51
+ ETH3D_EVAL_THRESHOLD,
52
+ ETH3D_FILTER_KEYS,
53
+ ETH3D_MAX_DEPTH,
54
+ ETH3D_SAMPLING_NUMBER,
55
+ ETH3D_SCENES,
56
+ ETH3D_SDF_TRUNC,
57
+ ETH3D_VOXEL_LENGTH,
58
+ )
59
+ from depth_anything_3.utils.pose_align import align_poses_umeyama
60
+
61
+
62
+ @MV_REGISTRY.register(name="eth3d")
63
+ @MONO_REGISTRY.register(name="eth3d")
64
+ class ETH3D(Dataset):
65
+ """
66
+ ETH3D Benchmark dataset wrapper for DepthAnything3 evaluation.
67
+
68
+ Supports:
69
+ - Camera pose estimation evaluation (AUC metrics)
70
+ - 3D reconstruction evaluation (Accuracy, Completeness, F-score)
71
+ - TSDF-based point cloud fusion
72
+
73
+ Dataset structure:
74
+ eth3d/multiview/
75
+ ├── scene_name/
76
+ │ ├── images/ # RGB images
77
+ │ ├── dslr_calibration_jpg/
78
+ │ │ ├── cameras.txt # Camera intrinsics
79
+ │ │ └── images.txt # Camera poses
80
+ │ ├── combined_mesh.ply # Ground truth mesh
81
+ │ └── ground_truth_depth/ # GT depth maps (optional)
82
+ """
83
+
84
+ data_root = ETH3D_EVAL_DATA_ROOT
85
+ SCENES = ETH3D_SCENES
86
+
87
+ # Evaluation hyperparameters from constants
88
+ max_depth = ETH3D_MAX_DEPTH
89
+ sampling_number = ETH3D_SAMPLING_NUMBER
90
+ voxel_length = ETH3D_VOXEL_LENGTH
91
+ sdf_trunc = ETH3D_SDF_TRUNC
92
+ eval_threshold = ETH3D_EVAL_THRESHOLD
93
+ down_sample = ETH3D_DOWN_SAMPLE
94
+
95
+ def __init__(self):
96
+ super().__init__()
97
+ # Pre-load scene data for efficiency
98
+ self._scene_cache = {}
99
+
100
+ # ------------------------------
101
+ # Camera file parsing
102
+ # ------------------------------
103
+
104
+ def _parse_cameras_txt(self, filepath: str) -> dict:
105
+ """
106
+ Parse COLMAP-style cameras.txt file.
107
+
108
+ Returns:
109
+ Dict mapping camera_id to intrinsic parameters
110
+ """
111
+ camera_dict = {}
112
+ with open(filepath) as f:
113
+ lines = f.readlines()
114
+ for line in lines[3:]: # Skip header
115
+ line = line.strip()
116
+ if not line or line.startswith("#"):
117
+ continue
118
+ parts = line.split()
119
+ if len(parts) < 8:
120
+ continue
121
+ cam_id = parts[0]
122
+ # Format: ID, MODEL, WIDTH, HEIGHT, fx, fy, cx, cy, [distortion params...]
123
+ camera_dict[cam_id] = {
124
+ "width": float(parts[2]),
125
+ "height": float(parts[3]),
126
+ "fx": float(parts[4]),
127
+ "fy": float(parts[5]),
128
+ "cx": float(parts[6]),
129
+ "cy": float(parts[7]),
130
+ }
131
+ return camera_dict
132
+
133
+ def _parse_images_txt(self, filepath: str) -> dict:
134
+ """
135
+ Parse COLMAP-style images.txt file.
136
+
137
+ Returns:
138
+ Dict mapping image path to pose parameters
139
+ """
140
+ pose_dict = {}
141
+ with open(filepath) as f:
142
+ lines = f.readlines()
143
+ for idx, line in enumerate(lines[4:]): # Skip header
144
+ line = line.strip()
145
+ if not line or line.startswith("#"):
146
+ continue
147
+ # Every other line contains pose info
148
+ if idx % 2 == 0:
149
+ parts = line.split()
150
+ if len(parts) < 10:
151
+ continue
152
+ # Format: IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME
153
+ image_id = parts[0]
154
+ qw, qx, qy, qz = float(parts[1]), float(parts[2]), float(parts[3]), float(parts[4])
155
+ tx, ty, tz = float(parts[5]), float(parts[6]), float(parts[7])
156
+ camera_id = parts[8]
157
+ name = parts[9]
158
+ pose_dict[name] = {
159
+ "image_id": image_id,
160
+ "quat": [qw, qx, qy, qz],
161
+ "trans": [tx, ty, tz],
162
+ "camera_id": camera_id,
163
+ }
164
+ return pose_dict
165
+
166
+ def _should_filter_image(self, scene: str, image_name: str) -> bool:
167
+ """Check if image should be filtered out based on known problematic views."""
168
+ filter_keys = ETH3D_FILTER_KEYS.get(scene, [])
169
+ for key in filter_keys:
170
+ if image_name.endswith(key):
171
+ return True
172
+ return False
173
+
174
+ # ------------------------------
175
+ # Public API
176
+ # ------------------------------
177
+
178
+ def get_data(self, scene: str) -> Dict:
179
+ """
180
+ Collect per-view image paths, intrinsics/extrinsics for a scene.
181
+
182
+ Args:
183
+ scene: Scene identifier (e.g., "courtyard")
184
+
185
+ Returns:
186
+ Dict with:
187
+ - image_files: List[str] - paths to images
188
+ - extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
189
+ - intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
190
+ - aux: Dict with gt_mesh_path
191
+ """
192
+ # Check cache
193
+ if scene in self._scene_cache:
194
+ return self._scene_cache[scene]
195
+
196
+ scene_dir = os.path.join(self.data_root, scene)
197
+
198
+ # Parse camera files
199
+ cameras_file = os.path.join(scene_dir, "dslr_calibration_jpg", "cameras.txt")
200
+ images_file = os.path.join(scene_dir, "dslr_calibration_jpg", "images.txt")
201
+ camera_dict = self._parse_cameras_txt(cameras_file)
202
+ pose_dict = self._parse_images_txt(images_file)
203
+
204
+ # Ground truth mesh path
205
+ gt_mesh_path = os.path.join(scene_dir, "combined_mesh.ply")
206
+
207
+ out = Dict({
208
+ "image_files": [],
209
+ "extrinsics": [],
210
+ "intrinsics": [],
211
+ "aux": Dict({
212
+ "gt_mesh_path": gt_mesh_path,
213
+ "heights": [],
214
+ "widths": [],
215
+ }),
216
+ })
217
+
218
+ # Process each image (preserve original order from images.txt)
219
+ filtered_count = 0
220
+ for image_name, pose_info in pose_dict.items():
221
+ # Filter problematic views
222
+ if self._should_filter_image(scene, image_name):
223
+ filtered_count += 1
224
+ continue
225
+
226
+ image_path = os.path.join(scene_dir, "images", image_name)
227
+ if not os.path.exists(image_path):
228
+ continue
229
+
230
+ cam_info = camera_dict.get(pose_info["camera_id"])
231
+ if cam_info is None:
232
+ continue
233
+
234
+ # Build intrinsics matrix
235
+ ixt = np.array([
236
+ [cam_info["fx"], 0, cam_info["cx"]],
237
+ [0, cam_info["fy"], cam_info["cy"]],
238
+ [0, 0, 1],
239
+ ], dtype=np.float32)
240
+
241
+ # Build extrinsics matrix (world-to-camera)
242
+ # COLMAP format: world point -> camera point
243
+ rot = quat2rotmat(pose_info["quat"])
244
+ ext = np.eye(4, dtype=np.float32)
245
+ ext[:3, :3] = rot
246
+ ext[:3, 3] = pose_info["trans"]
247
+
248
+ out.image_files.append(image_path)
249
+ out.extrinsics.append(ext)
250
+ out.intrinsics.append(ixt)
251
+ out.aux.heights.append(cam_info["height"])
252
+ out.aux.widths.append(cam_info["width"])
253
+
254
+ out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
255
+ out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
256
+
257
+ # Print scene info
258
+ total_images = len(pose_dict)
259
+ used_images = len(out.image_files)
260
+ print(f"[ETH3D] {scene}: {used_images}/{total_images} images "
261
+ f"(filtered {filtered_count}, missing {total_images - used_images - filtered_count})")
262
+
263
+ if used_images < 3:
264
+ print(f"[ETH3D] ⚠️ WARNING: {scene} has only {used_images} images - evaluation may fail!")
265
+
266
+ # Cache result
267
+ self._scene_cache[scene] = out
268
+ return out
269
+
270
+ def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
271
+ """
272
+ Evaluate fused point cloud against ETH3D ground truth mesh.
273
+
274
+ Args:
275
+ scene: Scene identifier
276
+ fuse_path: Path to fused point cloud (.ply)
277
+
278
+ Returns:
279
+ Dict with metrics: acc, comp, overall, precision, recall, fscore
280
+ """
281
+ gt_data = self.get_data(scene)
282
+ gt_mesh_path = gt_data.aux.gt_mesh_path
283
+
284
+ # Load and sample ground truth mesh
285
+ gt_mesh = o3d.io.read_triangle_mesh(gt_mesh_path)
286
+ gt_pcd = sample_points_from_mesh(gt_mesh, self.sampling_number)
287
+
288
+ # Load predicted point cloud
289
+ pred_pcd = o3d.io.read_point_cloud(fuse_path)
290
+
291
+ # Evaluate using shared utility function
292
+ metrics = evaluate_3d_reconstruction(
293
+ pred_pcd,
294
+ gt_pcd,
295
+ threshold=self.eval_threshold,
296
+ down_sample=self.down_sample,
297
+ )
298
+
299
+ return metrics
300
+
301
+ def _load_gt_meta(self, result_path: str) -> Dict:
302
+ """
303
+ Load saved GT meta (extrinsics, intrinsics, image_files) for fusion.
304
+
305
+ This is needed when frames are sampled, so fuse3d uses the correct
306
+ (sampled) GT instead of full dataset GT.
307
+
308
+ Args:
309
+ result_path: Path to npz file (used to derive gt_meta.npz path)
310
+
311
+ Returns:
312
+ Dict with GT data, or None if gt_meta.npz doesn't exist
313
+ """
314
+ # gt_meta.npz is in the same exports/ directory as results.npz
315
+ export_dir = os.path.dirname(result_path) # exports/mini_npz/
316
+ gt_meta_path = os.path.join(os.path.dirname(export_dir), "gt_meta.npz")
317
+
318
+ if os.path.exists(gt_meta_path):
319
+ data = np.load(gt_meta_path, allow_pickle=True)
320
+ return Dict({
321
+ "extrinsics": data["extrinsics"],
322
+ "intrinsics": data["intrinsics"],
323
+ "image_files": data["image_files"] if "image_files" in data else None,
324
+ })
325
+ return None
326
+
327
+ def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
328
+ """
329
+ Fuse per-view depths into a point cloud using TSDF fusion.
330
+
331
+ Pipeline:
332
+ 1. Load original images (keep original size)
333
+ 2. Resize depth to original image size (nearest interpolation)
334
+ 3. Adjust intrinsics to original image size
335
+ 4. Apply scale alignment and mask invalid depths
336
+ 5. TSDF fusion
337
+
338
+ Args:
339
+ scene: Scene identifier
340
+ result_path: Path to npz file with predicted depths/poses
341
+ fuse_path: Output path for fused point cloud (.ply)
342
+ mode: "recon_unposed" or "recon_posed"
343
+ """
344
+ # Try to load saved GT meta (handles frame sampling)
345
+ gt_meta = self._load_gt_meta(result_path)
346
+ if gt_meta is not None:
347
+ gt_data = gt_meta
348
+ else:
349
+ gt_data = self.get_data(scene)
350
+ _wait_for_file_ready(result_path)
351
+ pred_data = Dict({k: v for k, v in np.load(result_path).items()})
352
+
353
+ # Load original images (keep original size)
354
+ images = []
355
+ orig_sizes = [] # (H, W) for each image
356
+ for img_path in gt_data.image_files:
357
+ img = cv2.imread(img_path)
358
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
359
+ images.append(img)
360
+ orig_sizes.append((img.shape[0], img.shape[1]))
361
+
362
+ # Prepare depths, intrinsics, extrinsics with resize to original size
363
+ if mode == "recon_unposed":
364
+ depths, intrinsics, extrinsics = self._prep_unposed(
365
+ pred_data, gt_data, orig_sizes, scene=scene
366
+ )
367
+ elif mode == "recon_posed":
368
+ depths, intrinsics, extrinsics = self._prep_posed(
369
+ pred_data, gt_data, orig_sizes, scene=scene
370
+ )
371
+ else:
372
+ raise ValueError(f"Invalid mode: {mode}")
373
+
374
+ images = np.stack(images, axis=0)
375
+
376
+ # Create TSDF volume and fuse
377
+ volume = create_tsdf_volume(
378
+ voxel_length=self.voxel_length,
379
+ sdf_trunc=self.sdf_trunc,
380
+ )
381
+ mesh = fuse_depth_to_tsdf(
382
+ volume, depths, images, intrinsics, extrinsics, max_depth=self.max_depth
383
+ )
384
+
385
+ # Sample points from mesh
386
+ pcd = sample_points_from_mesh(mesh, self.sampling_number)
387
+
388
+ # Save point cloud
389
+ os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
390
+ o3d.io.write_point_cloud(fuse_path, pcd)
391
+
392
+ # ------------------------------
393
+ # Private helpers
394
+ # ------------------------------
395
+
396
+ def _prep_unposed(
397
+ self, pred_data: Dict, gt_data: Dict, orig_sizes: list, scene: str = None
398
+ ) -> tuple:
399
+ """
400
+ Prepare depths/intrinsics/extrinsics for recon_unposed mode.
401
+
402
+ Pipeline:
403
+ 1. Umeyama scale alignment
404
+ 2. Load GT mask for each frame
405
+ 3. Resize depth to original image size (nearest)
406
+ 4. Apply GT mask BEFORE scale
407
+ 5. Apply scale
408
+ 6. Adjust intrinsics to original image size
409
+ """
410
+ # Scale alignment with fixed random_state for reproducibility
411
+ _, _, scale, extrinsics = align_poses_umeyama(
412
+ gt_data.extrinsics.copy(),
413
+ pred_data.extrinsics.copy(),
414
+ return_aligned=True,
415
+ ransac=True,
416
+ random_state=42,
417
+ )
418
+
419
+ # Get model output size
420
+ model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
421
+
422
+ # Process each frame
423
+ depths_out = []
424
+ intrinsics_out = []
425
+ for i in range(len(pred_data.depth)):
426
+ orig_h, orig_w = orig_sizes[i]
427
+ image_name = os.path.basename(gt_data.image_files[i])
428
+
429
+ # Resize depth to original image size (nearest interpolation)
430
+ depth = cv2.resize(
431
+ pred_data.depth[i],
432
+ (orig_w, orig_h),
433
+ interpolation=cv2.INTER_NEAREST,
434
+ )
435
+
436
+ # Load GT mask (apply BEFORE scale)
437
+ gt_zero_mask = None
438
+ if scene is not None:
439
+ gt_zero_mask = self._load_gt_mask(scene, image_name, (orig_h, orig_w))
440
+
441
+ # Mask invalid depths BEFORE scale
442
+ depth = self._mask_invalid_depth(depth, gt_zero_mask)
443
+
444
+ # Apply scale AFTER mask
445
+ depth = depth * scale
446
+
447
+ # Adjust intrinsics to original image size
448
+ h_ratio = orig_h / model_h
449
+ w_ratio = orig_w / model_w
450
+ ixt = pred_data.intrinsics[i].copy()
451
+ ixt[0, :] *= w_ratio # fx, 0, cx
452
+ ixt[1, :] *= h_ratio # 0, fy, cy
453
+
454
+ depths_out.append(depth)
455
+ intrinsics_out.append(ixt)
456
+
457
+ return np.stack(depths_out), np.stack(intrinsics_out), extrinsics
458
+
459
+ def _prep_posed(
460
+ self, pred_data: Dict, gt_data: Dict, orig_sizes: list, scene: str = None
461
+ ) -> tuple:
462
+ """
463
+ Prepare depths/intrinsics/extrinsics for recon_posed mode.
464
+
465
+ Uses GT intrinsics/extrinsics but aligns depth scale via Umeyama.
466
+ Depth is resized to original image size.
467
+ """
468
+ # Scale alignment with fixed random_state for reproducibility
469
+ _, _, scale, _ = align_poses_umeyama(
470
+ gt_data.extrinsics.copy(),
471
+ pred_data.extrinsics.copy(),
472
+ return_aligned=True,
473
+ ransac=True,
474
+ random_state=42,
475
+ )
476
+
477
+ # Process each frame
478
+ depths_out = []
479
+ for i in range(len(pred_data.depth)):
480
+ orig_h, orig_w = orig_sizes[i]
481
+ image_name = os.path.basename(gt_data.image_files[i])
482
+
483
+ # Resize depth to original image size (nearest interpolation)
484
+ depth = cv2.resize(
485
+ pred_data.depth[i],
486
+ (orig_w, orig_h),
487
+ interpolation=cv2.INTER_NEAREST,
488
+ )
489
+
490
+ # Load GT mask (apply BEFORE scale)
491
+ gt_zero_mask = None
492
+ if scene is not None:
493
+ gt_zero_mask = self._load_gt_mask(scene, image_name, (orig_h, orig_w))
494
+
495
+ # Mask invalid depths BEFORE scale
496
+ depth = self._mask_invalid_depth(depth, gt_zero_mask)
497
+
498
+ # Apply scale AFTER mask
499
+ depth = depth * scale
500
+
501
+ depths_out.append(depth)
502
+
503
+ # Use GT intrinsics and extrinsics (already at original image size)
504
+ return np.stack(depths_out), gt_data.intrinsics.copy(), gt_data.extrinsics.copy()
505
+
506
+ def _load_gt_mask(self, scene: str, image_name: str, shape: tuple) -> np.ndarray:
507
+ """
508
+ Load GT mask for masking invalid regions.
509
+
510
+ GT mask marks occluded or invalid regions that should be excluded
511
+ from depth fusion and evaluation.
512
+
513
+ Args:
514
+ scene: Scene identifier
515
+ image_name: Image filename (e.g., "DSC_0307.JPG")
516
+ shape: (height, width) of the image
517
+
518
+ Returns:
519
+ Boolean mask where True = valid region to keep
520
+ """
521
+ h, w = shape
522
+
523
+ # GT mask file path
524
+ gt_mask_path = os.path.join(
525
+ self.data_root, scene, "masks_for_images", "dslr_images",
526
+ image_name.replace(".JPG", ".png")
527
+ )
528
+
529
+ # GT depth file path (used to determine valid depth regions)
530
+ gt_depth_path = os.path.join(
531
+ self.data_root, scene, "ground_truth_depth", "dslr_images", image_name
532
+ )
533
+
534
+ # Load GT depth
535
+ if os.path.exists(gt_depth_path):
536
+ gt_depth = np.fromfile(gt_depth_path, dtype=np.float32).reshape(h, w)
537
+ else:
538
+ gt_depth = np.ones((h, w), dtype=np.float32)
539
+
540
+ # Load GT mask
541
+ if os.path.exists(gt_mask_path):
542
+ gt_mask = cv2.imread(gt_mask_path, cv2.IMREAD_GRAYSCALE)
543
+ gt_mask = np.asarray(gt_mask)
544
+ else:
545
+ gt_mask = np.zeros((h, w), dtype=np.uint8)
546
+
547
+ # Compute zero_mask
548
+ # gt_mask == 1 means occluded/invalid region
549
+ invalid_mask_from_gt = gt_mask == 1
550
+ gt_depth_copy = gt_depth.copy()
551
+ gt_depth_copy[gt_mask == 1] = 0
552
+
553
+ invalid_mask_from_gt_depth = np.logical_or(gt_depth_copy == 0, gt_depth_copy == np.inf)
554
+
555
+ # zero_mask: valid region that should be kept
556
+ zero_mask = np.logical_and(
557
+ np.logical_not(invalid_mask_from_gt),
558
+ np.logical_not(invalid_mask_from_gt_depth)
559
+ )
560
+
561
+ return zero_mask
562
+
563
+ def _mask_invalid_depth(
564
+ self, depth: np.ndarray, gt_zero_mask: np.ndarray = None
565
+ ) -> np.ndarray:
566
+ """
567
+ Mask invalid depth values by setting them to 0.
568
+
569
+ Logic:
570
+ 1. Apply GT mask (if provided) - marks occluded/invalid regions
571
+ 2. Mask pred invalid values (nan, inf)
572
+
573
+ Args:
574
+ depth: Depth map to mask
575
+ gt_zero_mask: Optional GT mask (True = valid region)
576
+
577
+ Returns:
578
+ Masked depth map with invalid regions set to 0
579
+ """
580
+ depth = depth.copy()
581
+
582
+ # Apply GT mask first (before scale)
583
+ if gt_zero_mask is not None:
584
+ # Also mask out invalid pred depth
585
+ pred_invalid = np.isnan(depth) | np.isinf(depth)
586
+ combined_mask = np.logical_and(gt_zero_mask, np.logical_not(pred_invalid))
587
+ depth = depth * combined_mask.astype(np.float32)
588
+ else:
589
+ # Fallback: only mask pred invalid values
590
+ invalid_mask = np.isnan(depth) | np.isinf(depth) | (depth <= 0)
591
+ depth[invalid_mask] = 0.0
592
+
593
+ return depth
594
+
Depth-Anything-3/src/depth_anything_3/bench/datasets/hiroom.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ HiRoom Dataset implementation.
17
+
18
+ HiRoom is an indoor RGB-D dataset containing ground truth camera poses,
19
+ depth maps, and fused point clouds.
20
+
21
+ Evaluation metrics:
22
+ - 3D reconstruction: Accuracy, Completeness, F-score
23
+ - Camera pose estimation: AUC metrics
24
+ """
25
+
26
+ import os
27
+ from typing import Dict as TDict, List
28
+
29
+ import cv2
30
+ import numpy as np
31
+ import open3d as o3d
32
+ from addict import Dict
33
+
34
+ from depth_anything_3.bench.dataset import Dataset, _wait_for_file_ready
35
+ from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
36
+ from depth_anything_3.bench.utils import (
37
+ create_tsdf_volume,
38
+ evaluate_3d_reconstruction,
39
+ fuse_depth_to_tsdf,
40
+ sample_points_from_mesh,
41
+ )
42
+ from depth_anything_3.utils.constants import (
43
+ HIROOM_DOWN_SAMPLE,
44
+ HIROOM_EVAL_DATA_ROOT,
45
+ HIROOM_EVAL_THRESHOLD,
46
+ HIROOM_GT_ROOT_PATH,
47
+ HIROOM_MAX_DEPTH,
48
+ HIROOM_SAMPLING_NUMBER,
49
+ HIROOM_SCENE_LIST_PATH,
50
+ HIROOM_SDF_TRUNC,
51
+ HIROOM_VOXEL_LENGTH,
52
+ )
53
+ from depth_anything_3.utils.pose_align import align_poses_umeyama
54
+
55
+
56
+ def _load_scene_list() -> List[str]:
57
+ """Load scene list from file."""
58
+ if os.path.exists(HIROOM_SCENE_LIST_PATH):
59
+ with open(HIROOM_SCENE_LIST_PATH, "r") as f:
60
+ return f.read().splitlines()
61
+ return []
62
+
63
+
64
+ @MV_REGISTRY.register(name="hiroom")
65
+ @MONO_REGISTRY.register(name="hiroom")
66
+ class HiRoomDataset(Dataset):
67
+ """
68
+ HiRoom Dataset wrapper for DepthAnything3 evaluation.
69
+
70
+ Supports:
71
+ - Camera pose estimation evaluation (AUC metrics)
72
+ - 3D reconstruction evaluation (Accuracy, Completeness, F-score)
73
+ - TSDF-based point cloud fusion
74
+
75
+ Dataset structure:
76
+ HiRoom/
77
+ ├── {scene_path}/
78
+ │ ├── image/ # RGB images
79
+ │ ├── depth/ # GT depth maps
80
+ │ ├── pose/ # Camera poses (.npy)
81
+ │ ├── cam_K.npy # Camera intrinsics
82
+ │ └── aliasing_mask/ # Aliasing masks
83
+
84
+ fused_pcd/
85
+ └── {scene_name}.ply # Ground truth fused point cloud
86
+ """
87
+
88
+ data_root = HIROOM_EVAL_DATA_ROOT
89
+ gt_root_path = HIROOM_GT_ROOT_PATH
90
+ SCENES = _load_scene_list()
91
+
92
+ # Evaluation hyperparameters from constants
93
+ max_depth = HIROOM_MAX_DEPTH
94
+ sampling_number = HIROOM_SAMPLING_NUMBER
95
+ voxel_length = HIROOM_VOXEL_LENGTH
96
+ sdf_trunc = HIROOM_SDF_TRUNC
97
+ eval_threshold = HIROOM_EVAL_THRESHOLD
98
+ down_sample = HIROOM_DOWN_SAMPLE
99
+
100
+ def __init__(self):
101
+ super().__init__()
102
+ self._scene_cache = {}
103
+
104
+ # ------------------------------
105
+ # Public API
106
+ # ------------------------------
107
+
108
+ def get_data(self, scene: str) -> Dict:
109
+ """
110
+ Collect per-view image paths, intrinsics/extrinsics for a scene.
111
+
112
+ Args:
113
+ scene: Scene path (e.g., "xxx/yyy/zzz")
114
+
115
+ Returns:
116
+ Dict with:
117
+ - image_files: List[str] - paths to images
118
+ - extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
119
+ - intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
120
+ - aux: Dict with gt_pcd_path, gt_depth_files, aliasing_mask_files
121
+ """
122
+ if scene in self._scene_cache:
123
+ return self._scene_cache[scene]
124
+
125
+ scene_dir = os.path.join(self.data_root, scene)
126
+ image_dir = os.path.join(scene_dir, "image")
127
+
128
+ # Get scene name for GT point cloud
129
+ scene_name = "-".join(scene.split("/")[-3:])
130
+ gt_pcd_path = os.path.join(self.gt_root_path, f"{scene_name}.ply")
131
+
132
+ # Load shared camera intrinsics
133
+ intrin_path = os.path.join(scene_dir, "cam_K.npy")
134
+ ixt_shared = np.load(intrin_path).astype(np.float32)
135
+
136
+ # Get all image names sorted
137
+ image_names = sorted(os.listdir(image_dir))
138
+
139
+ out = Dict({
140
+ "image_files": [],
141
+ "extrinsics": [],
142
+ "intrinsics": [],
143
+ "aux": Dict({
144
+ "gt_pcd_path": gt_pcd_path,
145
+ "gt_depth_files": [],
146
+ "aliasing_mask_files": [],
147
+ }),
148
+ })
149
+
150
+ for img_name in image_names:
151
+ img_path = os.path.join(image_dir, img_name)
152
+ frame_name = img_name.split(".")[0]
153
+
154
+ # Depth and pose paths
155
+ depth_path = os.path.join(scene_dir, "depth", f"{frame_name}.png")
156
+ pose_path = os.path.join(scene_dir, "pose", f"{frame_name}.npy")
157
+ aliasing_mask_path = os.path.join(scene_dir, "aliasing_mask", f"{frame_name}.png")
158
+
159
+ if not os.path.exists(pose_path):
160
+ continue
161
+
162
+ # Load extrinsics (world-to-camera)
163
+ ext = np.load(pose_path).astype(np.float32)
164
+
165
+ out.image_files.append(img_path)
166
+ out.extrinsics.append(ext)
167
+ out.intrinsics.append(ixt_shared.copy())
168
+ out.aux.gt_depth_files.append(depth_path)
169
+ out.aux.aliasing_mask_files.append(aliasing_mask_path)
170
+
171
+ out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
172
+ out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
173
+
174
+ print(f"[HiRoom] {scene}: {len(out.image_files)} images")
175
+
176
+ self._scene_cache[scene] = out
177
+ return out
178
+
179
+ def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
180
+ """
181
+ Evaluate fused point cloud against HiRoom ground truth point cloud.
182
+
183
+ Args:
184
+ scene: Scene identifier
185
+ fuse_path: Path to fused point cloud (.ply)
186
+
187
+ Returns:
188
+ Dict with metrics: acc, comp, overall, precision, recall, fscore
189
+ """
190
+ gt_data = self.get_data(scene)
191
+ gt_pcd_path = gt_data.aux.gt_pcd_path
192
+
193
+ # Load ground truth point cloud
194
+ gt_pcd = o3d.io.read_point_cloud(gt_pcd_path)
195
+
196
+ # Load predicted point cloud
197
+ pred_pcd = o3d.io.read_point_cloud(fuse_path)
198
+
199
+ # Evaluate using shared utility function
200
+ metrics = evaluate_3d_reconstruction(
201
+ pred_pcd,
202
+ gt_pcd,
203
+ threshold=self.eval_threshold,
204
+ down_sample=self.down_sample,
205
+ )
206
+
207
+ return metrics
208
+
209
+ def _load_gt_meta(self, result_path: str) -> Dict:
210
+ """Load saved GT meta for fusion."""
211
+ export_dir = os.path.dirname(result_path)
212
+ gt_meta_path = os.path.join(os.path.dirname(export_dir), "gt_meta.npz")
213
+
214
+ if os.path.exists(gt_meta_path):
215
+ data = np.load(gt_meta_path, allow_pickle=True)
216
+ image_files = list(data["image_files"])
217
+ return Dict({
218
+ "extrinsics": data["extrinsics"],
219
+ "intrinsics": data["intrinsics"],
220
+ "image_files": image_files,
221
+ })
222
+ return None
223
+
224
+ def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
225
+ """
226
+ Fuse per-view depths into a point cloud using TSDF fusion.
227
+
228
+ Args:
229
+ scene: Scene identifier
230
+ result_path: Path to npz file with predicted depths/poses
231
+ fuse_path: Output path for fused point cloud (.ply)
232
+ mode: "recon_unposed" or "recon_posed"
233
+ """
234
+ # Get full GT data
235
+ full_gt_data = self.get_data(scene)
236
+
237
+ # Try to load saved GT meta (handles frame sampling)
238
+ gt_meta = self._load_gt_meta(result_path)
239
+ if gt_meta is not None:
240
+ gt_data = gt_meta
241
+ image_indices = [
242
+ full_gt_data.image_files.index(f)
243
+ for f in gt_data.image_files
244
+ if f in full_gt_data.image_files
245
+ ]
246
+ else:
247
+ gt_data = full_gt_data
248
+ image_indices = list(range(len(full_gt_data.image_files)))
249
+
250
+ _wait_for_file_ready(result_path)
251
+ pred_data = Dict({k: v for k, v in np.load(result_path).items()})
252
+
253
+ # Load images
254
+ images = []
255
+ orig_sizes = []
256
+ for img_idx in image_indices:
257
+ img_path = full_gt_data.image_files[img_idx]
258
+ img = cv2.imread(img_path)
259
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
260
+ images.append(img)
261
+ orig_sizes.append((img.shape[0], img.shape[1]))
262
+
263
+ images = np.stack(images, axis=0)
264
+
265
+ # Prepare depths, intrinsics, extrinsics
266
+ if mode == "recon_unposed":
267
+ depths, intrinsics, extrinsics = self._prep_unposed(
268
+ pred_data, gt_data, full_gt_data, image_indices, orig_sizes, scene=scene
269
+ )
270
+ elif mode == "recon_posed":
271
+ depths, intrinsics, extrinsics = self._prep_posed(
272
+ pred_data, gt_data, full_gt_data, image_indices, orig_sizes, scene=scene
273
+ )
274
+ else:
275
+ raise ValueError(f"Invalid mode: {mode}")
276
+
277
+ # Create TSDF volume and fuse
278
+ volume = create_tsdf_volume(
279
+ voxel_length=self.voxel_length,
280
+ sdf_trunc=self.sdf_trunc,
281
+ )
282
+ mesh = fuse_depth_to_tsdf(
283
+ volume, depths, images, intrinsics, extrinsics, max_depth=self.max_depth
284
+ )
285
+
286
+ # Sample points from mesh
287
+ pcd = sample_points_from_mesh(mesh, self.sampling_number)
288
+
289
+ # Save point cloud
290
+ os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
291
+ o3d.io.write_point_cloud(fuse_path, pcd)
292
+
293
+ # ------------------------------
294
+ # Private helpers
295
+ # ------------------------------
296
+
297
+ def _prep_unposed(
298
+ self, pred_data: Dict, gt_data: Dict, full_gt_data: Dict,
299
+ image_indices: list, orig_sizes: list, scene: str = None
300
+ ) -> tuple:
301
+ """Prepare depths/intrinsics/extrinsics for recon_unposed mode."""
302
+ # Scale alignment with fixed random_state for reproducibility
303
+ _, _, scale, extrinsics = align_poses_umeyama(
304
+ gt_data.extrinsics.copy(),
305
+ pred_data.extrinsics.copy(),
306
+ return_aligned=True,
307
+ ransac=True,
308
+ random_state=42,
309
+ )
310
+
311
+ model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
312
+
313
+ depths_out = []
314
+ intrinsics_out = []
315
+ for i in range(len(pred_data.depth)):
316
+ orig_h, orig_w = orig_sizes[i]
317
+ img_idx = image_indices[i]
318
+
319
+ # Resize depth to original image size
320
+ depth = cv2.resize(
321
+ pred_data.depth[i],
322
+ (orig_w, orig_h),
323
+ interpolation=cv2.INTER_NEAREST,
324
+ )
325
+
326
+ # Load GT mask
327
+ gt_zero_mask = self._load_gt_mask(
328
+ full_gt_data.aux.gt_depth_files[img_idx],
329
+ full_gt_data.aux.aliasing_mask_files[img_idx],
330
+ )
331
+
332
+ # Mask invalid depths BEFORE scale
333
+ depth = self._mask_invalid_depth(depth, gt_zero_mask)
334
+
335
+ # Apply scale AFTER mask
336
+ depth = depth * scale
337
+
338
+ # Adjust intrinsics to original image size
339
+ h_ratio = orig_h / model_h
340
+ w_ratio = orig_w / model_w
341
+ ixt = pred_data.intrinsics[i].copy()
342
+ ixt[0, :] *= w_ratio
343
+ ixt[1, :] *= h_ratio
344
+
345
+ depths_out.append(depth)
346
+ intrinsics_out.append(ixt)
347
+
348
+ return np.stack(depths_out), np.stack(intrinsics_out), extrinsics
349
+
350
+ def _prep_posed(
351
+ self, pred_data: Dict, gt_data: Dict, full_gt_data: Dict,
352
+ image_indices: list, orig_sizes: list, scene: str = None
353
+ ) -> tuple:
354
+ """Prepare depths/intrinsics/extrinsics for recon_posed mode."""
355
+ # Scale alignment
356
+ _, _, scale, _ = align_poses_umeyama(
357
+ gt_data.extrinsics.copy(),
358
+ pred_data.extrinsics.copy(),
359
+ return_aligned=True,
360
+ ransac=True,
361
+ random_state=42,
362
+ )
363
+
364
+ depths_out = []
365
+ for i in range(len(pred_data.depth)):
366
+ orig_h, orig_w = orig_sizes[i]
367
+ img_idx = image_indices[i]
368
+
369
+ # Resize depth to original image size
370
+ depth = cv2.resize(
371
+ pred_data.depth[i],
372
+ (orig_w, orig_h),
373
+ interpolation=cv2.INTER_NEAREST,
374
+ )
375
+
376
+ # Load GT mask
377
+ gt_zero_mask = self._load_gt_mask(
378
+ full_gt_data.aux.gt_depth_files[img_idx],
379
+ full_gt_data.aux.aliasing_mask_files[img_idx],
380
+ )
381
+
382
+ # Mask invalid depths BEFORE scale
383
+ depth = self._mask_invalid_depth(depth, gt_zero_mask)
384
+
385
+ # Apply scale AFTER mask
386
+ depth = depth * scale
387
+
388
+ depths_out.append(depth)
389
+
390
+ # Use GT intrinsics and extrinsics
391
+ gt_intrinsics = np.stack([full_gt_data.intrinsics[idx] for idx in image_indices])
392
+ gt_extrinsics = np.stack([full_gt_data.extrinsics[idx] for idx in image_indices])
393
+
394
+ return np.stack(depths_out), gt_intrinsics, gt_extrinsics
395
+
396
+ def _load_gt_mask(self, gt_depth_path: str, aliasing_mask_path: str) -> np.ndarray:
397
+ """
398
+ Load GT depth and aliasing mask to create valid mask.
399
+
400
+ For HiRoom:
401
+ - GT depth is stored as 16-bit PNG, scaled to 100m range
402
+ - Aliasing mask marks regions to exclude
403
+
404
+ Returns:
405
+ Boolean mask where True = valid region to keep
406
+ """
407
+ # Load GT depth
408
+ if os.path.exists(gt_depth_path):
409
+ gt_depth = cv2.imread(gt_depth_path, -1) / 65535.0 * 100.0
410
+ else:
411
+ return None
412
+
413
+ # Load aliasing mask
414
+ aliasing_mask = None
415
+ if os.path.exists(aliasing_mask_path):
416
+ aliasing_mask = cv2.imread(aliasing_mask_path, -1) > 0
417
+
418
+ # Valid mask: depth > 0 and not in aliasing region
419
+ valid_mask = gt_depth > 0
420
+ if aliasing_mask is not None:
421
+ valid_mask = np.logical_and(valid_mask, np.logical_not(aliasing_mask))
422
+
423
+ return valid_mask
424
+
425
+ def _mask_invalid_depth(
426
+ self, depth: np.ndarray, gt_zero_mask: np.ndarray = None
427
+ ) -> np.ndarray:
428
+ """Mask invalid depth values by setting them to 0."""
429
+ depth = depth.copy()
430
+
431
+ if gt_zero_mask is not None:
432
+ pred_invalid = np.isnan(depth) | np.isinf(depth)
433
+ combined_mask = np.logical_and(gt_zero_mask, np.logical_not(pred_invalid))
434
+ depth = depth * combined_mask.astype(np.float32)
435
+ else:
436
+ invalid_mask = np.isnan(depth) | np.isinf(depth) | (depth <= 0)
437
+ depth[invalid_mask] = 0.0
438
+
439
+ return depth
440
+
Depth-Anything-3/src/depth_anything_3/bench/datasets/scannetpp.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ ScanNet++ Benchmark dataset implementation.
17
+
18
+ ScanNet++ is a high-quality indoor RGB-D dataset with iPhone and DSLR images,
19
+ ground truth camera poses from COLMAP, and high-resolution 3D meshes.
20
+ Reference: https://kaldir.vc.in.tum.de/scannetpp/
21
+
22
+ Evaluation metrics:
23
+ - 3D reconstruction: Accuracy, Completeness, F-score
24
+ - Camera pose estimation: AUC metrics
25
+ """
26
+
27
+ import os
28
+ from typing import Dict as TDict
29
+
30
+ import cv2
31
+ import imageio
32
+ import numpy as np
33
+ import open3d as o3d
34
+ from addict import Dict
35
+
36
+ from depth_anything_3.bench.dataset import Dataset, _wait_for_file_ready
37
+ from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
38
+ from depth_anything_3.bench.utils import (
39
+ create_tsdf_volume,
40
+ fuse_depth_to_tsdf,
41
+ nn_correspondance,
42
+ sample_points_from_mesh,
43
+ )
44
+ from depth_anything_3.utils.constants import (
45
+ SCANNETPP_DOWN_SAMPLE,
46
+ SCANNETPP_EVAL_DATA_ROOT,
47
+ SCANNETPP_EVAL_THRESHOLD,
48
+ SCANNETPP_INPUT_H,
49
+ SCANNETPP_INPUT_W,
50
+ SCANNETPP_MAX_DEPTH,
51
+ SCANNETPP_SAMPLING_NUMBER,
52
+ SCANNETPP_SCENES,
53
+ SCANNETPP_SDF_TRUNC,
54
+ SCANNETPP_VOXEL_LENGTH,
55
+ )
56
+ from depth_anything_3.utils.pose_align import align_poses_umeyama
57
+ from depth_anything_3.utils.read_write_model import read_model
58
+
59
+
60
+ @MV_REGISTRY.register(name="scannetpp")
61
+ @MONO_REGISTRY.register(name="scannetpp")
62
+ class ScanNetPP(Dataset):
63
+ """
64
+ ScanNet++ Benchmark dataset wrapper for DepthAnything3 evaluation.
65
+
66
+ Supports:
67
+ - Camera pose estimation evaluation (AUC metrics)
68
+ - 3D reconstruction evaluation (Accuracy, Completeness, F-score)
69
+ - TSDF-based point cloud fusion
70
+
71
+ Dataset structure:
72
+ scannetpp/data/
73
+ ├── {scene_id}/
74
+ │ ├── merge_dslr_iphone/
75
+ │ │ ├── colmap/sparse_render_rgb/ # COLMAP reconstruction
76
+ │ │ ├── images/ # RGB images
77
+ │ │ └── render_depth/ # GT depth maps
78
+ │ └── scans/
79
+ │ └── mesh_aligned_0.05.ply # Ground truth mesh
80
+ """
81
+
82
+ data_root = SCANNETPP_EVAL_DATA_ROOT
83
+ SCENES = SCANNETPP_SCENES
84
+
85
+ # Input resolution after undistortion and resize
86
+ input_h = SCANNETPP_INPUT_H
87
+ input_w = SCANNETPP_INPUT_W
88
+
89
+ # Evaluation hyperparameters from constants
90
+ max_depth = SCANNETPP_MAX_DEPTH
91
+ sampling_number = SCANNETPP_SAMPLING_NUMBER
92
+ voxel_length = SCANNETPP_VOXEL_LENGTH
93
+ sdf_trunc = SCANNETPP_SDF_TRUNC
94
+ eval_threshold = SCANNETPP_EVAL_THRESHOLD
95
+ down_sample = SCANNETPP_DOWN_SAMPLE
96
+
97
+ def __init__(self):
98
+ super().__init__()
99
+ self._scene_cache = {}
100
+
101
+ # ------------------------------
102
+ # Public API
103
+ # ------------------------------
104
+
105
+ def get_data(self, scene: str) -> Dict:
106
+ """
107
+ Collect per-view image paths, intrinsics/extrinsics for a scene.
108
+
109
+ Only uses iPhone images (not DSLR).
110
+
111
+ Args:
112
+ scene: Scene identifier (e.g., "09c1414f1b")
113
+
114
+ Returns:
115
+ Dict with:
116
+ - image_files: List[str] - paths to images
117
+ - extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
118
+ - intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
119
+ - aux: Dict with gt_mesh_path, dist, roi, cam_hw, etc.
120
+ """
121
+ if scene in self._scene_cache:
122
+ return self._scene_cache[scene]
123
+
124
+ input_path = os.path.join(self.data_root, scene, "merge_dslr_iphone")
125
+ colmap_path = os.path.join(input_path, "colmap/sparse_render_rgb")
126
+ image_path = os.path.join(input_path, "images")
127
+ depth_path_dir = os.path.join(input_path, "render_depth")
128
+
129
+ # Read COLMAP model
130
+ cams, images, points3d = read_model(colmap_path)
131
+
132
+ # Map image names to IDs
133
+ name2id = {image.name: k for k, image in images.items()}
134
+ names = sorted([image.name for k, image in images.items()])
135
+ # Only use iPhone images
136
+ names = [name for name in names if "iphone" in name]
137
+
138
+ gt_mesh_path = os.path.join(
139
+ input_path.replace("merge_dslr_iphone", "scans"), "mesh_aligned_0.05.ply"
140
+ )
141
+
142
+ out = Dict({
143
+ "image_files": [],
144
+ "extrinsics": [],
145
+ "intrinsics": [],
146
+ "aux": Dict({
147
+ "gt_mesh_path": gt_mesh_path,
148
+ "dist_list": [],
149
+ "roi_list": [],
150
+ "cam_hw_list": [],
151
+ "ixt_raw_list": [],
152
+ "gt_depth_files": [],
153
+ }),
154
+ })
155
+
156
+ for name in names:
157
+ image = images[name2id[name]]
158
+ img_path = os.path.join(image_path, name)
159
+
160
+ if not os.path.exists(img_path):
161
+ continue
162
+
163
+ # Build extrinsics (world-to-camera)
164
+ ext = np.eye(4, dtype=np.float32)
165
+ ext[:3, :3] = image.qvec2rotmat()
166
+ ext[:3, 3] = image.tvec
167
+
168
+ # Get camera parameters
169
+ cam_id = image.camera_id
170
+ camera = cams[cam_id]
171
+ cam_height, cam_width = camera.height, camera.width
172
+
173
+ # Build intrinsics
174
+ ixt = np.eye(3, dtype=np.float32)
175
+ ixt[0, 0], ixt[1, 1], ixt[0, 2], ixt[1, 2] = camera.params[:4]
176
+ ixt[:2, 2] -= 0.5 # COLMAP convention adjustment
177
+ ixt_raw = ixt.copy()
178
+
179
+ # Handle distortion (OPENCV model)
180
+ dist = np.zeros(5, dtype=np.float32)
181
+ roi = (0, 0, cam_width, cam_height)
182
+ if camera.model == "OPENCV":
183
+ dist[:4] = camera.params[4:]
184
+ ixt, roi = cv2.getOptimalNewCameraMatrix(
185
+ ixt, dist, (cam_width, cam_height), 1, (cam_width, cam_height)
186
+ )
187
+
188
+ # Depth file path
189
+ frame_name = os.path.basename(name)[:-4] # Remove .jpg
190
+ depth_file = os.path.join(depth_path_dir, f"{frame_name}.png")
191
+
192
+ out.image_files.append(img_path)
193
+ out.extrinsics.append(ext)
194
+ out.intrinsics.append(ixt)
195
+ out.aux.dist_list.append(dist)
196
+ out.aux.roi_list.append(roi)
197
+ out.aux.cam_hw_list.append((cam_height, cam_width))
198
+ out.aux.ixt_raw_list.append(ixt_raw)
199
+ out.aux.gt_depth_files.append(depth_file)
200
+
201
+ out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
202
+ out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
203
+
204
+ print(f"[ScanNet++] {scene}: {len(out.image_files)} images")
205
+
206
+ self._scene_cache[scene] = out
207
+ return out
208
+
209
+ def load_image(self, img_path: str, idx: int, aux: Dict) -> np.ndarray:
210
+ """
211
+ Load and preprocess image with undistortion and cropping.
212
+
213
+ Args:
214
+ img_path: Path to image file
215
+ idx: Index of the image in the dataset
216
+ aux: Auxiliary data from get_data
217
+
218
+ Returns:
219
+ Preprocessed RGB image
220
+ """
221
+ image = imageio.imread(img_path).astype(np.uint8)
222
+ ixt_raw = aux.ixt_raw_list[idx]
223
+ ixt = aux.intrinsics[idx] if hasattr(aux, 'intrinsics') else None
224
+ dist = aux.dist_list[idx]
225
+ roi = aux.roi_list[idx]
226
+
227
+ # Undistort using raw intrinsics
228
+ # Use the stored intrinsics from get_data for newCameraMatrix
229
+ stored_ixt = self._scene_cache.get(aux.scene, {}).get('intrinsics', [None])[idx] if hasattr(aux, 'scene') else None
230
+ if stored_ixt is None:
231
+ # Recompute optimal camera matrix for undistortion
232
+ cam_h, cam_w = aux.cam_hw_list[idx]
233
+ ixt_for_undistort = ixt_raw.copy()
234
+ ixt_for_undistort, _ = cv2.getOptimalNewCameraMatrix(
235
+ ixt_raw, dist, (cam_w, cam_h), 1, (cam_w, cam_h)
236
+ )
237
+ else:
238
+ ixt_for_undistort = stored_ixt
239
+
240
+ image = cv2.undistort(image, ixt_raw, dist, newCameraMatrix=ixt_for_undistort)
241
+
242
+ # Crop to ROI
243
+ x, y, w, h = roi
244
+ image = image[y:y+h, x:x+w]
245
+
246
+ # Resize to target resolution
247
+ image = cv2.resize(image, (self.input_w, self.input_h), interpolation=cv2.INTER_AREA)
248
+
249
+ return image
250
+
251
+ def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
252
+ """
253
+ Evaluate fused point cloud against ScanNet++ ground truth mesh.
254
+
255
+ Uses AABB cropping to only evaluate points within GT bounding box.
256
+
257
+ Args:
258
+ scene: Scene identifier
259
+ fuse_path: Path to fused point cloud (.ply)
260
+
261
+ Returns:
262
+ Dict with metrics: acc, comp, overall, precision, recall, fscore
263
+ """
264
+ gt_data = self.get_data(scene)
265
+ gt_mesh_path = gt_data.aux.gt_mesh_path
266
+
267
+ # Load ground truth mesh and sample points
268
+ gt_mesh = o3d.io.read_triangle_mesh(gt_mesh_path)
269
+ gt_pcd = sample_points_from_mesh(gt_mesh, self.sampling_number)
270
+
271
+ # Load predicted point cloud
272
+ pred_pcd = o3d.io.read_point_cloud(fuse_path)
273
+
274
+ # Crop prediction to GT bounding box (with 0.1m margin)
275
+ aabb = gt_pcd.get_axis_aligned_bounding_box()
276
+ points = np.asarray(pred_pcd.points)
277
+ inside_mask = (
278
+ (points[:, 0] >= aabb.min_bound[0] - 0.1) &
279
+ (points[:, 0] <= aabb.max_bound[0] + 0.1) &
280
+ (points[:, 1] >= aabb.min_bound[1] - 0.1) &
281
+ (points[:, 1] <= aabb.max_bound[1] + 0.1) &
282
+ (points[:, 2] >= aabb.min_bound[2] - 0.1) &
283
+ (points[:, 2] <= aabb.max_bound[2] + 0.1)
284
+ )
285
+ pred_pcd = pred_pcd.select_by_index(inside_mask.nonzero()[0])
286
+
287
+ # Downsample
288
+ if self.down_sample > 0:
289
+ pred_pcd = pred_pcd.voxel_down_sample(self.down_sample)
290
+ gt_pcd = gt_pcd.voxel_down_sample(self.down_sample)
291
+
292
+ verts_pred = np.asarray(pred_pcd.points)
293
+ verts_gt = np.asarray(gt_pcd.points)
294
+
295
+ if len(verts_pred) == 0 or len(verts_gt) == 0:
296
+ return {
297
+ "acc": float("inf"),
298
+ "comp": float("inf"),
299
+ "overall": float("inf"),
300
+ "precision": 0.0,
301
+ "recall": 0.0,
302
+ "fscore": 0.0,
303
+ }
304
+
305
+ # Compute distances
306
+ dist_pred_to_gt = nn_correspondance(verts_gt, verts_pred)
307
+ dist_gt_to_pred = nn_correspondance(verts_pred, verts_gt)
308
+
309
+ # Compute metrics
310
+ accuracy = float(np.mean(dist_pred_to_gt))
311
+ completeness = float(np.mean(dist_gt_to_pred))
312
+ overall = (accuracy + completeness) / 2
313
+
314
+ precision = float(np.mean((dist_pred_to_gt < self.eval_threshold).astype(float)))
315
+ recall = float(np.mean((dist_gt_to_pred < self.eval_threshold).astype(float)))
316
+
317
+ if precision + recall > 0:
318
+ fscore = 2 * precision * recall / (precision + recall)
319
+ else:
320
+ fscore = 0.0
321
+
322
+ return {
323
+ "acc": accuracy,
324
+ "comp": completeness,
325
+ "overall": overall,
326
+ "precision": precision,
327
+ "recall": recall,
328
+ "fscore": fscore,
329
+ }
330
+
331
+ def _load_gt_meta(self, result_path: str) -> Dict:
332
+ """Load saved GT meta for fusion."""
333
+ export_dir = os.path.dirname(result_path)
334
+ gt_meta_path = os.path.join(os.path.dirname(export_dir), "gt_meta.npz")
335
+
336
+ if os.path.exists(gt_meta_path):
337
+ data = np.load(gt_meta_path, allow_pickle=True)
338
+ image_files = list(data["image_files"])
339
+
340
+ # Reconstruct aux data from image files
341
+ return Dict({
342
+ "extrinsics": data["extrinsics"],
343
+ "intrinsics": data["intrinsics"],
344
+ "image_files": image_files,
345
+ })
346
+ return None
347
+
348
+ def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
349
+ """
350
+ Fuse per-view depths into a point cloud using TSDF fusion.
351
+
352
+ Args:
353
+ scene: Scene identifier
354
+ result_path: Path to npz file with predicted depths/poses
355
+ fuse_path: Output path for fused point cloud (.ply)
356
+ mode: "recon_unposed" or "recon_posed"
357
+ """
358
+ # Get GT data
359
+ full_gt_data = self.get_data(scene)
360
+
361
+ # Try to load saved GT meta (handles frame sampling)
362
+ gt_meta = self._load_gt_meta(result_path)
363
+ if gt_meta is not None:
364
+ gt_data = gt_meta
365
+ # Need to rebuild aux from full GT data based on image indices
366
+ image_indices = [
367
+ full_gt_data.image_files.index(f)
368
+ for f in gt_data.image_files
369
+ if f in full_gt_data.image_files
370
+ ]
371
+ else:
372
+ gt_data = full_gt_data
373
+ image_indices = list(range(len(full_gt_data.image_files)))
374
+
375
+ _wait_for_file_ready(result_path)
376
+ pred_data = Dict({k: v for k, v in np.load(result_path).items()})
377
+
378
+ # Load and preprocess images
379
+ images = []
380
+ for idx, img_idx in enumerate(image_indices):
381
+ img_path = full_gt_data.image_files[img_idx]
382
+ image = imageio.imread(img_path).astype(np.uint8)
383
+
384
+ # Undistort and crop
385
+ ixt_raw = full_gt_data.aux.ixt_raw_list[img_idx]
386
+ ixt = full_gt_data.intrinsics[img_idx]
387
+ dist = full_gt_data.aux.dist_list[img_idx]
388
+ roi = full_gt_data.aux.roi_list[img_idx]
389
+
390
+ image = cv2.undistort(image, ixt_raw, dist, newCameraMatrix=ixt)
391
+ x, y, w, h = roi
392
+ image = image[y:y+h, x:x+w]
393
+ image = cv2.resize(image, (self.input_w, self.input_h), interpolation=cv2.INTER_AREA)
394
+
395
+ images.append(image)
396
+
397
+ images = np.stack(images, axis=0)
398
+
399
+ # Prepare depths, intrinsics, extrinsics
400
+ if mode == "recon_unposed":
401
+ depths, intrinsics, extrinsics = self._prep_unposed(
402
+ pred_data, gt_data, full_gt_data, image_indices, scene=scene
403
+ )
404
+ elif mode == "recon_posed":
405
+ depths, intrinsics, extrinsics = self._prep_posed(
406
+ pred_data, gt_data, full_gt_data, image_indices, scene=scene
407
+ )
408
+ else:
409
+ raise ValueError(f"Invalid mode: {mode}")
410
+
411
+ # Create TSDF volume and fuse
412
+ volume = create_tsdf_volume(
413
+ voxel_length=self.voxel_length,
414
+ sdf_trunc=self.sdf_trunc,
415
+ )
416
+ mesh = fuse_depth_to_tsdf(
417
+ volume, depths, images, intrinsics, extrinsics, max_depth=self.max_depth
418
+ )
419
+
420
+ # Sample points from mesh
421
+ pcd = sample_points_from_mesh(mesh, self.sampling_number)
422
+
423
+ # Save point cloud
424
+ os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
425
+ o3d.io.write_point_cloud(fuse_path, pcd)
426
+
427
+ # ------------------------------
428
+ # Private helpers
429
+ # ------------------------------
430
+
431
+ def _prep_unposed(
432
+ self, pred_data: Dict, gt_data: Dict, full_gt_data: Dict,
433
+ image_indices: list, scene: str = None
434
+ ) -> tuple:
435
+ """Prepare depths/intrinsics/extrinsics for recon_unposed mode."""
436
+ # Scale alignment with fixed random_state for reproducibility
437
+ _, _, scale, extrinsics = align_poses_umeyama(
438
+ gt_data.extrinsics.copy(),
439
+ pred_data.extrinsics.copy(),
440
+ return_aligned=True,
441
+ ransac=True,
442
+ random_state=42,
443
+ )
444
+
445
+ model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
446
+
447
+ depths_out = []
448
+ intrinsics_out = []
449
+ for i in range(len(pred_data.depth)):
450
+ img_idx = image_indices[i]
451
+
452
+ # Get original image size (after undistort+crop, before resize to input_h/w)
453
+ orig_h, orig_w = full_gt_data.aux.cam_hw_list[img_idx]
454
+
455
+ # Step 1: nearest resize to original image size
456
+ depth = cv2.resize(
457
+ pred_data.depth[i],
458
+ (orig_w, orig_h),
459
+ interpolation=cv2.INTER_NEAREST,
460
+ )
461
+
462
+ # Step 2: linear resize to target resolution
463
+ depth = cv2.resize(
464
+ depth,
465
+ (self.input_w, self.input_h),
466
+ interpolation=cv2.INTER_LINEAR,
467
+ ).astype(np.float32)
468
+
469
+ # Load GT depth for masking
470
+ gt_zero_mask = self._load_gt_mask(full_gt_data.aux.gt_depth_files[img_idx])
471
+
472
+ # Mask invalid depths BEFORE scale
473
+ depth = self._mask_invalid_depth(depth, gt_zero_mask)
474
+
475
+ # Apply scale AFTER mask
476
+ depth = depth * scale
477
+
478
+ # Adjust intrinsics to target resolution
479
+ h_ratio = self.input_h / model_h
480
+ w_ratio = self.input_w / model_w
481
+ ixt = pred_data.intrinsics[i].copy()
482
+ ixt[0, :] *= w_ratio
483
+ ixt[1, :] *= h_ratio
484
+
485
+ depths_out.append(depth)
486
+ intrinsics_out.append(ixt)
487
+
488
+ return np.stack(depths_out), np.stack(intrinsics_out), extrinsics
489
+
490
+ def _prep_posed(
491
+ self, pred_data: Dict, gt_data: Dict, full_gt_data: Dict,
492
+ image_indices: list, scene: str = None
493
+ ) -> tuple:
494
+ """Prepare depths/intrinsics/extrinsics for recon_posed mode."""
495
+ # Scale alignment
496
+ _, _, scale, _ = align_poses_umeyama(
497
+ gt_data.extrinsics.copy(),
498
+ pred_data.extrinsics.copy(),
499
+ return_aligned=True,
500
+ ransac=True,
501
+ random_state=42,
502
+ )
503
+
504
+ depths_out = []
505
+ intrinsics_out = []
506
+ extrinsics_out = []
507
+
508
+ for i in range(len(pred_data.depth)):
509
+ img_idx = image_indices[i]
510
+
511
+ # Get original image size (after undistort+crop, before resize to input_h/w)
512
+ orig_h, orig_w = full_gt_data.aux.cam_hw_list[img_idx]
513
+
514
+ # Step 1: nearest resize to original image size
515
+ depth = cv2.resize(
516
+ pred_data.depth[i],
517
+ (orig_w, orig_h),
518
+ interpolation=cv2.INTER_NEAREST,
519
+ )
520
+
521
+ # Step 2: linear resize to target resolution
522
+ depth = cv2.resize(
523
+ depth,
524
+ (self.input_w, self.input_h),
525
+ interpolation=cv2.INTER_LINEAR,
526
+ ).astype(np.float32)
527
+
528
+ # Load GT depth for masking
529
+ gt_zero_mask = self._load_gt_mask(full_gt_data.aux.gt_depth_files[img_idx])
530
+
531
+ # Mask invalid depths BEFORE scale
532
+ depth = self._mask_invalid_depth(depth, gt_zero_mask)
533
+
534
+ # Apply scale AFTER mask
535
+ depth = depth * scale
536
+
537
+ depths_out.append(depth)
538
+
539
+ # Get GT intrinsics and scale to target resolution
540
+ ixt = full_gt_data.intrinsics[img_idx].copy()
541
+ cam_h, cam_w = full_gt_data.aux.cam_hw_list[img_idx]
542
+ ixt[:2, 2] += 0.5 # Undo COLMAP convention
543
+ ixt[0, :] *= self.input_w / cam_w
544
+ ixt[1, :] *= self.input_h / cam_h
545
+ intrinsics_out.append(ixt)
546
+
547
+ extrinsics_out.append(full_gt_data.extrinsics[img_idx])
548
+
549
+ return np.stack(depths_out), np.stack(intrinsics_out), np.stack(extrinsics_out)
550
+
551
+ def _load_gt_mask(self, gt_depth_path: str) -> np.ndarray:
552
+ """
553
+ Load GT depth and create valid mask.
554
+
555
+ For ScanNet++, GT depth is stored as 16-bit PNG in millimeters.
556
+
557
+ Returns:
558
+ Boolean mask where True = valid region to keep
559
+ """
560
+ if not os.path.exists(gt_depth_path):
561
+ return None
562
+
563
+ gt_depth = imageio.imread(gt_depth_path) / 1000.0 # mm to meters
564
+
565
+ # Resize to target resolution
566
+ gt_depth = cv2.resize(
567
+ gt_depth,
568
+ (self.input_w, self.input_h),
569
+ interpolation=cv2.INTER_LINEAR,
570
+ ).astype(np.float32)
571
+
572
+ # Valid mask: depth > 0 and not inf
573
+ valid_mask = np.logical_and(gt_depth > 0, gt_depth != np.inf)
574
+ return valid_mask
575
+
576
+ def _mask_invalid_depth(
577
+ self, depth: np.ndarray, gt_zero_mask: np.ndarray = None
578
+ ) -> np.ndarray:
579
+ """Mask invalid depth values by setting them to 0."""
580
+ depth = depth.copy()
581
+
582
+ if gt_zero_mask is not None:
583
+ pred_invalid = np.isnan(depth) | np.isinf(depth)
584
+ combined_mask = np.logical_and(gt_zero_mask, np.logical_not(pred_invalid))
585
+ depth = depth * combined_mask.astype(np.float32)
586
+ else:
587
+ invalid_mask = np.isnan(depth) | np.isinf(depth) | (depth <= 0)
588
+ depth[invalid_mask] = 0.0
589
+
590
+ return depth
591
+
Depth-Anything-3/src/depth_anything_3/bench/datasets/sevenscenes.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ 7Scenes Benchmark dataset implementation.
17
+
18
+ 7Scenes is an indoor RGB-D dataset with ground truth camera poses and 3D meshes.
19
+ Reference: https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/
20
+
21
+ Evaluation metrics:
22
+ - 3D reconstruction: Accuracy, Completeness, F-score
23
+ - Camera pose estimation: AUC metrics
24
+ """
25
+
26
+ import os
27
+ from typing import Dict as TDict
28
+
29
+ import cv2
30
+ import numpy as np
31
+ import open3d as o3d
32
+ from addict import Dict
33
+
34
+ from depth_anything_3.bench.dataset import Dataset, _wait_for_file_ready
35
+ from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
36
+ from depth_anything_3.bench.utils import (
37
+ create_tsdf_volume,
38
+ evaluate_3d_reconstruction,
39
+ fuse_depth_to_tsdf,
40
+ sample_points_from_mesh,
41
+ )
42
+ from depth_anything_3.utils.constants import (
43
+ SEVENSCENES_CX,
44
+ SEVENSCENES_CY,
45
+ SEVENSCENES_DOWN_SAMPLE,
46
+ SEVENSCENES_EVAL_DATA_ROOT,
47
+ SEVENSCENES_EVAL_THRESHOLD,
48
+ SEVENSCENES_FX,
49
+ SEVENSCENES_FY,
50
+ SEVENSCENES_MAX_DEPTH,
51
+ SEVENSCENES_SAMPLING_NUMBER,
52
+ SEVENSCENES_SCENES,
53
+ SEVENSCENES_SDF_TRUNC,
54
+ SEVENSCENES_VOXEL_LENGTH,
55
+ )
56
+ from depth_anything_3.utils.pose_align import align_poses_umeyama
57
+
58
+
59
+ @MV_REGISTRY.register(name="7scenes")
60
+ @MONO_REGISTRY.register(name="7scenes")
61
+ class SevenScenes(Dataset):
62
+ """
63
+ 7Scenes Benchmark dataset wrapper for DepthAnything3 evaluation.
64
+
65
+ Supports:
66
+ - Camera pose estimation evaluation (AUC metrics)
67
+ - 3D reconstruction evaluation (Accuracy, Completeness, F-score)
68
+ - TSDF-based point cloud fusion
69
+
70
+ Dataset structure:
71
+ 7scenes/
72
+ ├── 7Scenes/
73
+ │ ├── {scene}/
74
+ │ │ └── seq-01/ (or seq-02 for stairs)
75
+ │ │ ├── frame-XXXXXX.color.png
76
+ │ │ ├── frame-XXXXXX.depth.png
77
+ │ │ └── frame-XXXXXX.pose.txt
78
+ │ └── meshes/
79
+ │ └── {scene}.ply # Ground truth mesh
80
+ """
81
+
82
+ data_root = SEVENSCENES_EVAL_DATA_ROOT
83
+ SCENES = SEVENSCENES_SCENES
84
+
85
+ # Evaluation hyperparameters from constants
86
+ max_depth = SEVENSCENES_MAX_DEPTH
87
+ sampling_number = SEVENSCENES_SAMPLING_NUMBER
88
+ voxel_length = SEVENSCENES_VOXEL_LENGTH
89
+ sdf_trunc = SEVENSCENES_SDF_TRUNC
90
+ eval_threshold = SEVENSCENES_EVAL_THRESHOLD
91
+ down_sample = SEVENSCENES_DOWN_SAMPLE
92
+
93
+ # Fixed camera intrinsics for all 7Scenes images
94
+ fx = SEVENSCENES_FX
95
+ fy = SEVENSCENES_FY
96
+ cx = SEVENSCENES_CX
97
+ cy = SEVENSCENES_CY
98
+
99
+ def __init__(self):
100
+ super().__init__()
101
+ self._scene_cache = {}
102
+
103
+ # ------------------------------
104
+ # Public API
105
+ # ------------------------------
106
+
107
+ def get_data(self, scene: str) -> Dict:
108
+ """
109
+ Collect per-view image paths, intrinsics/extrinsics for a scene.
110
+
111
+ Args:
112
+ scene: Scene identifier (e.g., "chess")
113
+
114
+ Returns:
115
+ Dict with:
116
+ - image_files: List[str] - paths to images
117
+ - extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
118
+ - intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
119
+ - aux: Dict with gt_mesh_path, gt_depth_files
120
+ """
121
+ if scene in self._scene_cache:
122
+ return self._scene_cache[scene]
123
+
124
+ # Different sequence for stairs scene
125
+ if scene == "stairs":
126
+ data_folder = os.path.join(self.data_root, "7Scenes", scene, "seq-02")
127
+ n_imgs = 500
128
+ else:
129
+ data_folder = os.path.join(self.data_root, "7Scenes", scene, "seq-01")
130
+ n_imgs = 1000
131
+
132
+ gt_mesh_path = os.path.join(self.data_root, "7Scenes", "meshes", f"{scene}.ply")
133
+
134
+ # Fixed intrinsics for all images
135
+ ixt = np.array([
136
+ [self.fx, 0, self.cx],
137
+ [0, self.fy, self.cy],
138
+ [0, 0, 1],
139
+ ], dtype=np.float32)
140
+
141
+ out = Dict({
142
+ "image_files": [],
143
+ "extrinsics": [],
144
+ "intrinsics": [],
145
+ "aux": Dict({
146
+ "gt_mesh_path": gt_mesh_path,
147
+ "gt_depth_files": [],
148
+ }),
149
+ })
150
+
151
+ for i in range(0, n_imgs, 1):
152
+ img_path = os.path.join(data_folder, f"frame-{i:06d}.color.png")
153
+ pose_path = os.path.join(data_folder, f"frame-{i:06d}.pose.txt")
154
+ depth_path = os.path.join(data_folder, f"frame-{i:06d}.depth.png")
155
+
156
+ if not os.path.exists(img_path) or not os.path.exists(pose_path):
157
+ continue
158
+
159
+ # Load camera-to-world pose and convert to world-to-camera (extrinsic)
160
+ c2w = np.loadtxt(pose_path)
161
+ ext = np.linalg.inv(c2w).astype(np.float32)
162
+
163
+ out.image_files.append(img_path)
164
+ out.extrinsics.append(ext)
165
+ out.intrinsics.append(ixt.copy())
166
+ out.aux.gt_depth_files.append(depth_path)
167
+
168
+ out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
169
+ out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
170
+
171
+ print(f"[7Scenes] {scene}: {len(out.image_files)} images")
172
+
173
+ self._scene_cache[scene] = out
174
+ return out
175
+
176
+ def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
177
+ """
178
+ Evaluate fused point cloud against 7Scenes ground truth mesh.
179
+
180
+ Args:
181
+ scene: Scene identifier
182
+ fuse_path: Path to fused point cloud (.ply)
183
+
184
+ Returns:
185
+ Dict with metrics: acc, comp, overall, precision, recall, fscore
186
+ """
187
+ gt_data = self.get_data(scene)
188
+ gt_mesh_path = gt_data.aux.gt_mesh_path
189
+
190
+ # Load and sample ground truth mesh
191
+ gt_mesh = o3d.io.read_triangle_mesh(gt_mesh_path)
192
+ gt_pcd = sample_points_from_mesh(gt_mesh, self.sampling_number)
193
+
194
+ # Load predicted point cloud
195
+ pred_pcd = o3d.io.read_point_cloud(fuse_path)
196
+
197
+ # Evaluate using shared utility function
198
+ metrics = evaluate_3d_reconstruction(
199
+ pred_pcd,
200
+ gt_pcd,
201
+ threshold=self.eval_threshold,
202
+ down_sample=self.down_sample,
203
+ )
204
+
205
+ return metrics
206
+
207
+ def _load_gt_meta(self, result_path: str) -> Dict:
208
+ """
209
+ Load saved GT meta (extrinsics, intrinsics, image_files) for fusion.
210
+
211
+ This is needed when frames are sampled, so fuse3d uses the correct
212
+ (sampled) GT instead of full dataset GT.
213
+
214
+ Args:
215
+ result_path: Path to npz file (used to derive gt_meta.npz path)
216
+
217
+ Returns:
218
+ Dict with GT data, or None if gt_meta.npz doesn't exist
219
+ """
220
+ export_dir = os.path.dirname(result_path) # exports/mini_npz/
221
+ gt_meta_path = os.path.join(os.path.dirname(export_dir), "gt_meta.npz")
222
+
223
+ if os.path.exists(gt_meta_path):
224
+ data = np.load(gt_meta_path, allow_pickle=True)
225
+ # Build aux with gt_depth_files derived from image_files
226
+ image_files = list(data["image_files"])
227
+ gt_depth_files = [
228
+ img_path.replace("color", "depth").replace(".color.", ".depth.")
229
+ for img_path in image_files
230
+ ]
231
+ return Dict({
232
+ "extrinsics": data["extrinsics"],
233
+ "intrinsics": data["intrinsics"],
234
+ "image_files": image_files,
235
+ "aux": Dict({"gt_depth_files": gt_depth_files}),
236
+ })
237
+ return None
238
+
239
+ def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
240
+ """
241
+ Fuse per-view depths into a point cloud using TSDF fusion.
242
+
243
+ Args:
244
+ scene: Scene identifier
245
+ result_path: Path to npz file with predicted depths/poses
246
+ fuse_path: Output path for fused point cloud (.ply)
247
+ mode: "recon_unposed" or "recon_posed"
248
+ """
249
+ # Try to load saved GT meta (handles frame sampling)
250
+ gt_meta = self._load_gt_meta(result_path)
251
+ if gt_meta is not None:
252
+ gt_data = gt_meta
253
+ else:
254
+ gt_data = self.get_data(scene)
255
+ _wait_for_file_ready(result_path)
256
+ pred_data = Dict({k: v for k, v in np.load(result_path).items()})
257
+
258
+ # Load original images (keep original size)
259
+ images = []
260
+ orig_sizes = []
261
+ for img_path in gt_data.image_files:
262
+ img = cv2.imread(img_path)
263
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
264
+ images.append(img)
265
+ orig_sizes.append((img.shape[0], img.shape[1]))
266
+
267
+ # Prepare depths, intrinsics, extrinsics
268
+ if mode == "recon_unposed":
269
+ depths, intrinsics, extrinsics = self._prep_unposed(
270
+ pred_data, gt_data, orig_sizes, scene=scene
271
+ )
272
+ elif mode == "recon_posed":
273
+ depths, intrinsics, extrinsics = self._prep_posed(
274
+ pred_data, gt_data, orig_sizes, scene=scene
275
+ )
276
+ else:
277
+ raise ValueError(f"Invalid mode: {mode}")
278
+
279
+ images = np.stack(images, axis=0)
280
+
281
+ # Create TSDF volume and fuse
282
+ volume = create_tsdf_volume(
283
+ voxel_length=self.voxel_length,
284
+ sdf_trunc=self.sdf_trunc,
285
+ )
286
+ mesh = fuse_depth_to_tsdf(
287
+ volume, depths, images, intrinsics, extrinsics, max_depth=self.max_depth
288
+ )
289
+
290
+ # Sample points from mesh
291
+ pcd = sample_points_from_mesh(mesh, self.sampling_number)
292
+
293
+ # Save point cloud
294
+ os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
295
+ o3d.io.write_point_cloud(fuse_path, pcd)
296
+
297
+ # ------------------------------
298
+ # Private helpers
299
+ # ------------------------------
300
+
301
+ def _prep_unposed(
302
+ self, pred_data: Dict, gt_data: Dict, orig_sizes: list, scene: str
303
+ ) -> tuple:
304
+ """
305
+ Prepare depths/intrinsics/extrinsics for recon_unposed mode.
306
+
307
+ Similar to ETH3D but uses GT depth for masking instead of separate mask files.
308
+ """
309
+ # Scale alignment with fixed random_state for reproducibility
310
+ _, _, scale, extrinsics = align_poses_umeyama(
311
+ gt_data.extrinsics.copy(),
312
+ pred_data.extrinsics.copy(),
313
+ return_aligned=True,
314
+ ransac=True,
315
+ random_state=42,
316
+ )
317
+
318
+ model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
319
+
320
+ depths_out = []
321
+ intrinsics_out = []
322
+ for i in range(len(pred_data.depth)):
323
+ orig_h, orig_w = orig_sizes[i]
324
+
325
+ # Resize depth to original image size (nearest interpolation)
326
+ depth = cv2.resize(
327
+ pred_data.depth[i],
328
+ (orig_w, orig_h),
329
+ interpolation=cv2.INTER_NEAREST,
330
+ )
331
+
332
+ # Load GT depth for masking
333
+ gt_zero_mask = self._load_gt_mask(gt_data.aux.gt_depth_files[i])
334
+
335
+ # Mask invalid depths BEFORE scale
336
+ depth = self._mask_invalid_depth(depth, gt_zero_mask)
337
+
338
+ # Apply scale AFTER mask
339
+ depth = depth * scale
340
+
341
+ # Adjust intrinsics to original image size
342
+ h_ratio = orig_h / model_h
343
+ w_ratio = orig_w / model_w
344
+ ixt = pred_data.intrinsics[i].copy()
345
+ ixt[0, :] *= w_ratio
346
+ ixt[1, :] *= h_ratio
347
+
348
+ depths_out.append(depth)
349
+ intrinsics_out.append(ixt)
350
+
351
+ return np.stack(depths_out), np.stack(intrinsics_out), extrinsics
352
+
353
+ def _prep_posed(
354
+ self, pred_data: Dict, gt_data: Dict, orig_sizes: list, scene: str
355
+ ) -> tuple:
356
+ """
357
+ Prepare depths/intrinsics/extrinsics for recon_posed mode.
358
+ Uses GT intrinsics/extrinsics but aligns depth scale via Umeyama.
359
+ """
360
+ # Scale alignment with fixed random_state
361
+ _, _, scale, _ = align_poses_umeyama(
362
+ gt_data.extrinsics.copy(),
363
+ pred_data.extrinsics.copy(),
364
+ return_aligned=True,
365
+ ransac=True,
366
+ random_state=42,
367
+ )
368
+
369
+ model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
370
+
371
+ depths_out = []
372
+ for i in range(len(pred_data.depth)):
373
+ orig_h, orig_w = orig_sizes[i]
374
+
375
+ # Resize depth to original image size
376
+ depth = cv2.resize(
377
+ pred_data.depth[i],
378
+ (orig_w, orig_h),
379
+ interpolation=cv2.INTER_NEAREST,
380
+ )
381
+
382
+ # Load GT depth for masking
383
+ gt_zero_mask = self._load_gt_mask(gt_data.aux.gt_depth_files[i])
384
+
385
+ # Mask invalid depths BEFORE scale
386
+ depth = self._mask_invalid_depth(depth, gt_zero_mask)
387
+
388
+ # Apply scale AFTER mask
389
+ depth = depth * scale
390
+
391
+ depths_out.append(depth)
392
+
393
+ # Use GT intrinsics and extrinsics
394
+ return np.stack(depths_out), gt_data.intrinsics.copy(), gt_data.extrinsics.copy()
395
+
396
+ def _load_gt_mask(self, gt_depth_path: str) -> np.ndarray:
397
+ """
398
+ Load GT depth and create valid mask.
399
+
400
+ For 7Scenes, GT depth is stored as 16-bit PNG in millimeters.
401
+ Value 65535 indicates invalid depth.
402
+
403
+ Returns:
404
+ Boolean mask where True = valid region to keep
405
+ """
406
+ if not os.path.exists(gt_depth_path):
407
+ return None
408
+
409
+ gt_depth = cv2.imread(gt_depth_path, -1)
410
+ if gt_depth is None:
411
+ return None
412
+
413
+ # 65535 is invalid depth marker in 7Scenes
414
+ gt_depth[gt_depth == 65535] = 0
415
+ # Convert to meters
416
+ gt_depth = gt_depth / 1000.0
417
+
418
+ # Valid mask: depth > 0
419
+ valid_mask = gt_depth > 0
420
+ return valid_mask
421
+
422
+ def _mask_invalid_depth(
423
+ self, depth: np.ndarray, gt_zero_mask: np.ndarray = None
424
+ ) -> np.ndarray:
425
+ """
426
+ Mask invalid depth values by setting them to 0.
427
+
428
+ Args:
429
+ depth: Depth map to mask
430
+ gt_zero_mask: Optional GT mask (True = valid region)
431
+
432
+ Returns:
433
+ Masked depth map with invalid regions set to 0
434
+ """
435
+ depth = depth.copy()
436
+
437
+ if gt_zero_mask is not None:
438
+ # Also mask out invalid pred depth
439
+ pred_invalid = np.isnan(depth) | np.isinf(depth)
440
+ combined_mask = np.logical_and(gt_zero_mask, np.logical_not(pred_invalid))
441
+ depth = depth * combined_mask.astype(np.float32)
442
+ else:
443
+ # Fallback: only mask pred invalid values
444
+ invalid_mask = np.isnan(depth) | np.isinf(depth) | (depth <= 0)
445
+ depth[invalid_mask] = 0.0
446
+
447
+ return depth
448
+
449
+
Depth-Anything-3/src/depth_anything_3/bench/evaluator.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Main Evaluator class for DepthAnything3 benchmark evaluation.
17
+
18
+ Supports multiple datasets and evaluation modes:
19
+ - pose: Camera pose estimation (AUC metrics)
20
+ - recon_unposed: 3D reconstruction with predicted poses
21
+ - recon_posed: 3D reconstruction with GT poses
22
+ - view_syn: Novel view synthesis (TODO)
23
+ """
24
+
25
+ import json
26
+ import os
27
+ import random
28
+ from typing import Dict as TDict, Iterable, List
29
+
30
+ import numpy as np
31
+ import torch
32
+ from addict import Dict
33
+ from tqdm import tqdm
34
+
35
+ from depth_anything_3.bench.print_metrics import MetricsPrinter
36
+ from depth_anything_3.utils.parallel_utils import parallel_execution
37
+ from depth_anything_3.bench.registries import MV_REGISTRY
38
+ from depth_anything_3.utils.constants import EVAL_REF_VIEW_STRATEGY
39
+
40
+
41
+ class Evaluator:
42
+ """
43
+ Main evaluation orchestrator for DepthAnything3 benchmarks.
44
+
45
+ Usage:
46
+ evaluator = Evaluator(
47
+ work_dir="./eval_workspace",
48
+ datas=["dtu"],
49
+ modes=["pose", "recon_unposed", "recon_posed"],
50
+ )
51
+ api = DepthAnything3.from_pretrained("...")
52
+ evaluator.infer(api)
53
+ metrics = evaluator.eval()
54
+ evaluator.print_metrics()
55
+ """
56
+
57
+ VALID_MODES = {"pose", "recon_unposed", "recon_posed", "view_syn"}
58
+
59
+ def __init__(
60
+ self,
61
+ work_dir: str = "./eval_workspace",
62
+ datas: List[str] = ("dtu",),
63
+ modes: List[str] = ("recon_unposed",),
64
+ ref_view_strategy: str = EVAL_REF_VIEW_STRATEGY,
65
+ scenes: List[str] = None,
66
+ debug: bool = False,
67
+ num_fusion_workers: int = 4,
68
+ max_frames: int = 100,
69
+ gpu_id: int = 0,
70
+ total_gpus: int = 1,
71
+ ):
72
+ """
73
+ Initialize the evaluator.
74
+
75
+ Args:
76
+ work_dir: Base directory for model outputs and metric files
77
+ datas: List of dataset names (must be registered in MV_REGISTRY)
78
+ modes: List of evaluation modes to run
79
+ ref_view_strategy: Reference view selection strategy for inference
80
+ ("first", "saddle_balanced", etc.)
81
+ scenes: Specific scenes to evaluate (None = all scenes)
82
+ debug: Enable verbose debug output
83
+ num_fusion_workers: Number of parallel workers for TSDF fusion (default: 4)
84
+ max_frames: Maximum number of frames per scene (default: 100).
85
+ If a scene has more frames, randomly sample to this limit.
86
+ Set to -1 to disable sampling.
87
+ gpu_id: GPU index for multi-GPU (0-indexed)
88
+ total_gpus: Total number of GPUs for task distribution
89
+ """
90
+ self.work_dir = work_dir
91
+ self.datas = list(datas)
92
+ self.modes = set(modes)
93
+ self.ref_view_strategy = ref_view_strategy
94
+ self.scenes_filter = scenes
95
+ self.debug = debug
96
+ self.num_fusion_workers = num_fusion_workers
97
+ self.max_frames = max_frames
98
+ self.gpu_id = gpu_id
99
+ self.total_gpus = total_gpus
100
+
101
+ # Validate modes
102
+ unknown = self.modes - self.VALID_MODES
103
+ if unknown:
104
+ raise ValueError(f"Unknown modes: {unknown}. Valid: {sorted(self.VALID_MODES)}")
105
+
106
+ os.makedirs(self.work_dir, exist_ok=True)
107
+
108
+ # Initialize datasets
109
+ self.datasets = Dict()
110
+ for data in self.datas:
111
+ if not MV_REGISTRY.has(data):
112
+ available = list(MV_REGISTRY.all().keys())
113
+ raise ValueError(f"Dataset '{data}' not found. Available: {available}")
114
+ self.datasets[data] = MV_REGISTRY.get(data)()
115
+
116
+ # Initialize metrics printer
117
+ self._printer = MetricsPrinter()
118
+
119
+ # -------------------- Public APIs -------------------- #
120
+
121
+ def all(self, api) -> TDict[str, dict]:
122
+ """
123
+ Run complete evaluation pipeline: inference + evaluation.
124
+
125
+ Args:
126
+ api: DepthAnything3 API instance
127
+
128
+ Returns:
129
+ Combined metrics dictionary
130
+ """
131
+ self.infer(api)
132
+ return self.eval()
133
+
134
+ def _get_scenes(self, dataset) -> List[str]:
135
+ """Get list of scenes to evaluate, optionally filtered."""
136
+ all_scenes = dataset.SCENES
137
+ if self.scenes_filter:
138
+ scenes = [s for s in all_scenes if s in self.scenes_filter]
139
+ if self.debug:
140
+ print(f"[DEBUG] Filtered scenes: {scenes} (from {len(all_scenes)} total)")
141
+ return scenes
142
+ return all_scenes
143
+
144
+ def infer(self, api, model_path: str = None) -> None:
145
+ """
146
+ Run inference according to requested modes.
147
+
148
+ - Unposed export if 'pose' or 'recon_unposed' is in modes
149
+ - Posed export if 'recon_posed' or 'view_syn' is in modes
150
+
151
+ Multi-GPU: Use --gpu_id and --total_gpus to distribute tasks.
152
+ Example: Launch 4 processes with gpu_id=0,1,2,3 and total_gpus=4
153
+
154
+ Args:
155
+ api: DepthAnything3 API instance
156
+ model_path: Model path (unused, kept for API compatibility)
157
+ """
158
+ need_unposed = {"pose", "recon_unposed"} & self.modes
159
+ need_posed = {"recon_posed", "view_syn"} & self.modes
160
+ export_format = "mini_npz-glb" if self.debug else "mini_npz"
161
+
162
+ # Collect all tasks
163
+ all_tasks = []
164
+ for data in self.datas:
165
+ dataset = self.datasets[data]
166
+ for scene in self._get_scenes(dataset):
167
+ all_tasks.append((data, scene))
168
+
169
+ # Distribute tasks across GPUs
170
+ if self.total_gpus > 1:
171
+ tasks = [t for i, t in enumerate(all_tasks) if i % self.total_gpus == self.gpu_id]
172
+ print(f"[INFO] GPU {self.gpu_id}/{self.total_gpus}: {len(tasks)}/{len(all_tasks)} tasks")
173
+ else:
174
+ tasks = all_tasks
175
+ print(f"[INFO] Total inference tasks: {len(tasks)}")
176
+
177
+ for data, scene in tqdm(tasks, desc=f"Inference (GPU {self.gpu_id})"):
178
+ dataset = self.datasets[data]
179
+ scene_data = dataset.get_data(scene)
180
+ scene_data = self._sample_frames(scene_data, scene)
181
+
182
+ if need_unposed:
183
+ export_dir = self._export_dir(data, scene, posed=False)
184
+ api.inference(
185
+ scene_data.image_files,
186
+ export_dir=export_dir,
187
+ export_format=export_format,
188
+ ref_view_strategy=self.ref_view_strategy,
189
+ )
190
+ self._save_gt_meta(export_dir, scene_data)
191
+
192
+ if need_posed:
193
+ export_dir = self._export_dir(data, scene, posed=True)
194
+ api.inference(
195
+ scene_data.image_files,
196
+ scene_data.extrinsics,
197
+ scene_data.intrinsics,
198
+ export_dir=export_dir,
199
+ export_format=export_format,
200
+ ref_view_strategy=self.ref_view_strategy,
201
+ )
202
+ self._save_gt_meta(export_dir, scene_data)
203
+
204
+ def eval(self) -> TDict[str, dict]:
205
+ """
206
+ Evaluate for all configured modes and write JSON files.
207
+
208
+ Evaluation order by mode (all datasets per mode):
209
+ 1. pose - all datasets
210
+ 2. recon_unposed - all datasets
211
+ 3. recon_posed - all datasets
212
+
213
+ Returns:
214
+ Summary mapping: {"<data>_<mode>": metrics_dict}
215
+ """
216
+ summary: TDict[str, dict] = {}
217
+
218
+ # Evaluate by mode (all datasets per mode)
219
+ if "pose" in self.modes:
220
+ print(f"\n{'='*60}")
221
+ print(f"📊 Evaluating POSE for all datasets...")
222
+ print(f"{'='*60}")
223
+ for data, result in self._eval_pose():
224
+ summary[f"{data}_pose"] = result
225
+
226
+ if "recon_unposed" in self.modes:
227
+ print(f"\n{'='*60}")
228
+ print(f"📊 Evaluating RECON_UNPOSED for all datasets...")
229
+ print(f"{'='*60}")
230
+ for data, result in self._eval_reconstruction("recon_unposed"):
231
+ summary[f"{data}_recon_unposed"] = result
232
+
233
+ if "recon_posed" in self.modes:
234
+ print(f"\n{'='*60}")
235
+ print(f"📊 Evaluating RECON_POSED for all datasets...")
236
+ print(f"{'='*60}")
237
+ for data, result in self._eval_reconstruction("recon_posed"):
238
+ summary[f"{data}_recon_posed"] = result
239
+
240
+ if "view_syn" in self.modes:
241
+ # TODO: Add view synthesis metrics here when available
242
+ pass
243
+
244
+ return summary
245
+
246
+ def print_metrics(self, metrics: TDict[str, dict] = None) -> None:
247
+ """
248
+ Print evaluation metrics in a beautiful tabular format.
249
+
250
+ Args:
251
+ metrics: Metrics dictionary. If None, loads from saved JSON files.
252
+ """
253
+ if metrics is None:
254
+ metrics = self._load_metrics()
255
+
256
+ self._printer.print_results(metrics)
257
+
258
+ # -------------------- Evaluation Methods -------------------- #
259
+
260
+ def _eval_pose(self) -> Iterable[tuple]:
261
+ """Compute pose-estimation metrics for each dataset and scene."""
262
+ os.makedirs(self._metric_dir, exist_ok=True)
263
+
264
+ for data in tqdm(self.datas, desc="Datasets (pose eval)"):
265
+ dataset = self.datasets[data]
266
+ dataset_results = Dict()
267
+ scenes = self._get_scenes(dataset)
268
+
269
+ for scene in tqdm(scenes, desc=f"{data} scenes", leave=False):
270
+ export_dir = self._export_dir(data, scene, posed=False)
271
+ result_path = os.path.join(export_dir, "exports", "mini_npz", "results.npz")
272
+
273
+ # Check if result file exists and is valid
274
+ if not os.path.exists(result_path):
275
+ print(f"\n[ERROR] Result file not found: {result_path}")
276
+ print(f"[ERROR] CWD: {os.getcwd()}")
277
+ print(f"[ERROR] Please run inference first (remove --eval_only)")
278
+ continue
279
+
280
+ try:
281
+ # Use saved GT meta (handles frame sampling correctly)
282
+ gt_meta = self._load_gt_meta(export_dir)
283
+ if gt_meta is not None:
284
+ result = self._compute_pose_with_gt(result_path, gt_meta)
285
+ else:
286
+ # Fallback to dataset GT (no sampling was done)
287
+ result = dataset.eval_pose(scene, result_path)
288
+ dataset_results[scene] = self._to_float_dict(result)
289
+ except Exception as e:
290
+ print(f"\n[ERROR] Failed to evaluate pose for {data}/{scene}: {e}")
291
+ print(f"[ERROR] File path: {os.path.abspath(result_path)}")
292
+ if self.debug:
293
+ import traceback
294
+ traceback.print_exc()
295
+ continue
296
+
297
+ if not dataset_results:
298
+ print(f"[WARNING] No valid results for {data}")
299
+ continue
300
+
301
+ dataset_results["mean"] = self._mean_of_dicts(dataset_results.values())
302
+ out_path = os.path.join(self._metric_dir, f"{data}_pose.json")
303
+ self._dump_json(out_path, dataset_results)
304
+ yield data, dataset_results
305
+
306
+ def _eval_reconstruction(self, mode: str) -> Iterable[tuple]:
307
+ """
308
+ Compute reconstruction metrics for each dataset and scene.
309
+
310
+ Args:
311
+ mode: "recon_unposed" or "recon_posed"
312
+ """
313
+ assert mode in {"recon_unposed", "recon_posed"}
314
+ os.makedirs(self._metric_dir, exist_ok=True)
315
+
316
+ posed_flag = mode == "recon_posed"
317
+
318
+ # Filter out datasets that don't support reconstruction (e.g., dtu64)
319
+ recon_datas = [d for d in self.datas if d != "dtu64"]
320
+
321
+ for data in tqdm(recon_datas, desc=f"Datasets ({mode} eval)"):
322
+ dataset = self.datasets[data]
323
+ dataset_results = Dict()
324
+ scenes = self._get_scenes(dataset)
325
+
326
+ # Prepare paths for all scenes
327
+ scene_list = []
328
+ result_paths = []
329
+ fuse_paths = []
330
+ for scene in scenes:
331
+ export_dir = self._export_dir(data, scene, posed=posed_flag)
332
+ result_path = os.path.join(export_dir, "exports", "mini_npz", "results.npz")
333
+ fuse_path = os.path.join(export_dir, "exports", "fuse", "pcd.ply")
334
+ scene_list.append(scene)
335
+ result_paths.append(result_path)
336
+ fuse_paths.append(fuse_path)
337
+
338
+ # Parallel fusion (default 4 workers)
339
+ # DTU uses CUDA operations in fusion, which doesn't work well with ThreadPool
340
+ use_sequential = (data == "dtu")
341
+ parallel_execution(
342
+ scene_list,
343
+ result_paths,
344
+ fuse_paths,
345
+ action=lambda s, rp, fp: dataset.fuse3d(s, rp, fp, mode),
346
+ num_processes=self.num_fusion_workers,
347
+ print_progress=True,
348
+ desc=f"{data} fusion",
349
+ sequential=use_sequential,
350
+ )
351
+
352
+ # Sequential evaluation (fast, no need to parallelize)
353
+ for scene, fuse_path in zip(scene_list, fuse_paths):
354
+ # DTU supports CPU-based evaluation
355
+ if data == "dtu" and hasattr(dataset, "eval3d"):
356
+ result = dataset.eval3d(scene, fuse_path)
357
+ else:
358
+ result = dataset.eval3d(scene, fuse_path)
359
+ dataset_results[scene] = self._to_float_dict(result)
360
+ print(f" {mode} | {data} | {scene}: {result}")
361
+
362
+ dataset_results["mean"] = self._mean_of_dicts(dataset_results.values())
363
+ out_path = os.path.join(self._metric_dir, f"{data}_{mode}.json")
364
+ self._dump_json(out_path, dataset_results)
365
+ yield data, dataset_results
366
+
367
+ # -------------------- Helpers -------------------- #
368
+
369
+ def _save_gt_meta(self, export_dir: str, scene_data: Dict) -> None:
370
+ """
371
+ Save GT extrinsics/intrinsics/image_files for evaluation.
372
+
373
+ This is needed when frames are sampled, so eval_pose and fuse3d can use
374
+ the correct (sampled) GT instead of full dataset GT.
375
+
376
+ Args:
377
+ export_dir: Export directory for the scene
378
+ scene_data: Sampled scene data
379
+ """
380
+ meta_path = os.path.join(export_dir, "exports", "gt_meta.npz")
381
+ os.makedirs(os.path.dirname(meta_path), exist_ok=True)
382
+ np.savez_compressed(
383
+ meta_path,
384
+ extrinsics=scene_data.extrinsics,
385
+ intrinsics=scene_data.intrinsics,
386
+ image_files=np.array(scene_data.image_files, dtype=object),
387
+ )
388
+
389
+ def _load_gt_meta(self, export_dir: str) -> Dict:
390
+ """
391
+ Load saved GT extrinsics/intrinsics for evaluation.
392
+
393
+ Returns:
394
+ Dict with extrinsics and intrinsics, or None if not found
395
+ """
396
+ meta_path = os.path.join(export_dir, "exports", "gt_meta.npz")
397
+ if os.path.exists(meta_path):
398
+ data = np.load(meta_path)
399
+ return Dict({
400
+ "extrinsics": data["extrinsics"],
401
+ "intrinsics": data["intrinsics"],
402
+ })
403
+ return None
404
+
405
+ def _compute_pose_with_gt(self, result_path: str, gt_meta: Dict) -> TDict[str, float]:
406
+ """
407
+ Compute pose metrics using saved GT meta (handles frame sampling).
408
+
409
+ Args:
410
+ result_path: Path to npz with predicted extrinsics
411
+ gt_meta: Dict with GT extrinsics from saved meta
412
+
413
+ Returns:
414
+ Dict with pose metrics
415
+ """
416
+ from depth_anything_3.bench.dataset import _wait_for_file_ready
417
+ from depth_anything_3.bench.utils import compute_pose
418
+ from depth_anything_3.utils.geometry import as_homogeneous
419
+
420
+ _wait_for_file_ready(result_path)
421
+ pred = np.load(result_path)
422
+ return compute_pose(
423
+ torch.from_numpy(as_homogeneous(pred["extrinsics"])),
424
+ torch.from_numpy(as_homogeneous(gt_meta["extrinsics"])),
425
+ )
426
+
427
+ def _sample_frames(self, scene_data: Dict, scene: str) -> Dict:
428
+ """
429
+ Sample frames if scene has more than max_frames.
430
+
431
+ Uses fixed random seed (42) for reproducibility.
432
+
433
+ Args:
434
+ scene_data: Scene data dict with image_files, extrinsics, intrinsics, aux
435
+ scene: Scene name (for logging)
436
+
437
+ Returns:
438
+ Sampled scene_data if num_frames > max_frames, otherwise original
439
+ """
440
+ if self.max_frames <= 0:
441
+ return scene_data
442
+
443
+ num_frames = len(scene_data.image_files)
444
+ if num_frames <= self.max_frames:
445
+ return scene_data
446
+
447
+ # Sample with fixed seed for reproducibility
448
+ random.seed(42)
449
+ indices = list(range(num_frames))
450
+ random.shuffle(indices)
451
+ sampled_indices = sorted(indices[:self.max_frames])
452
+
453
+ print(f" [Sampling] {scene}: {num_frames} -> {self.max_frames} frames")
454
+
455
+ # Create new scene_data with sampled frames
456
+ sampled = Dict()
457
+ sampled.image_files = [scene_data.image_files[i] for i in sampled_indices]
458
+ sampled.extrinsics = scene_data.extrinsics[sampled_indices]
459
+ sampled.intrinsics = scene_data.intrinsics[sampled_indices]
460
+
461
+ # Copy aux data, sampling lists if needed
462
+ sampled.aux = Dict()
463
+ for key, val in scene_data.aux.items():
464
+ if isinstance(val, list) and len(val) == num_frames:
465
+ sampled.aux[key] = [val[i] for i in sampled_indices]
466
+ elif isinstance(val, np.ndarray) and len(val) == num_frames:
467
+ sampled.aux[key] = val[sampled_indices]
468
+ else:
469
+ sampled.aux[key] = val
470
+
471
+ return sampled
472
+
473
+ @property
474
+ def _metric_dir(self) -> str:
475
+ """Directory for storing metric JSON files."""
476
+ return os.path.join(self.work_dir, "metric_results")
477
+
478
+ def _export_dir(self, data: str, scene: str, posed: bool) -> str:
479
+ """
480
+ Get export directory path.
481
+
482
+ Structure: .../model_results/{data}/{scene}/{posed|unposed}
483
+ """
484
+ suffix = "posed" if posed else "unposed"
485
+ export_dir = os.path.join(self.work_dir, "model_results", data, scene, suffix)
486
+ os.makedirs(export_dir, exist_ok=True)
487
+ return export_dir
488
+
489
+ @staticmethod
490
+ def _to_float_dict(d: TDict[str, float]) -> dict:
491
+ """Convert numpy scalars to plain Python floats for JSON safety."""
492
+ return {k: float(v) for k, v in d.items()}
493
+
494
+ @staticmethod
495
+ def _mean_of_dicts(dicts: Iterable[dict]) -> dict:
496
+ """Compute elementwise mean across a list of homogeneous metric dicts."""
497
+ dicts = list(dicts)
498
+ if not dicts:
499
+ return {}
500
+ keys = dicts[0].keys()
501
+ return {k: float(np.mean([d[k] for d in dicts]).item()) for k in keys}
502
+
503
+ @staticmethod
504
+ def _dump_json(path: str, obj: dict, indent: int = 4) -> None:
505
+ """Write JSON with UTF-8 and pretty indentation."""
506
+ os.makedirs(os.path.dirname(path), exist_ok=True)
507
+ with open(path, "w", encoding="utf-8") as f:
508
+ json.dump(obj, f, indent=indent, ensure_ascii=False)
509
+
510
+ def _load_metrics(self) -> TDict[str, dict]:
511
+ """Load evaluation metrics from JSON files."""
512
+ metrics = {}
513
+ metric_dir = self._metric_dir
514
+
515
+ if not os.path.exists(metric_dir):
516
+ return metrics
517
+
518
+ for filename in os.listdir(metric_dir):
519
+ if filename.endswith(".json"):
520
+ filepath = os.path.join(metric_dir, filename)
521
+ try:
522
+ with open(filepath, encoding="utf-8") as f:
523
+ data = json.load(f)
524
+ key = filename[:-5] # Remove .json extension
525
+ metrics[key] = data
526
+ except Exception as e:
527
+ print(f"Warning: Failed to read metrics file: {filename} - {e}")
528
+
529
+ return metrics
530
+
531
+
532
+ # -------------------- CLI Entry Point -------------------- #
533
+
534
+
535
+ if __name__ == "__main__":
536
+ import sys
537
+ from omegaconf import OmegaConf
538
+ from depth_anything_3.cfg import load_config
539
+
540
+ # Get default config path (relative to this file)
541
+ _default_config = os.path.join(
542
+ os.path.dirname(__file__), "configs", "eval_bench.yaml"
543
+ )
544
+
545
+ # Check for help flag first (we need to handle this before OmegaConf)
546
+ if "--help" in sys.argv or "-h" in sys.argv:
547
+ pass # Will handle after config loading
548
+
549
+ # Set up argv for OmegaConf processing
550
+ argv = sys.argv[1:]
551
+
552
+ # Check if user provides custom config
553
+ config_path = _default_config
554
+ if "--config" in argv:
555
+ config_idx = argv.index("--config")
556
+ if config_idx + 1 < len(argv):
557
+ config_path = argv[config_idx + 1]
558
+ # Remove --config and its value
559
+ argv = argv[:config_idx] + argv[config_idx + 2:]
560
+
561
+ # Print help if requested
562
+ if "--help" in sys.argv or "-h" in sys.argv:
563
+ print("""
564
+ DepthAnything3 Benchmark Evaluation
565
+
566
+ Usage:
567
+ python -m depth_anything_3.bench.evaluator [OPTIONS] [KEY=VALUE ...]
568
+
569
+ Configuration:
570
+ --config PATH Config YAML file (default: bench/configs/eval_bench.yaml)
571
+
572
+ Config Overrides (using dotlist notation):
573
+ model.path=VALUE Model path or HuggingFace ID
574
+ workspace.work_dir=VALUE Working directory for outputs
575
+ eval.datasets=[dataset1,dataset2] Datasets to evaluate (eth3d,7scenes,scannetpp,hiroom,dtu,dtu64)
576
+ eval.modes=[mode1,mode2] Evaluation modes (pose,recon_unposed,recon_posed)
577
+ eval.scenes=[scene1,scene2] Specific scenes to evaluate (null=all)
578
+ eval.max_frames=VALUE Max frames per scene (-1=no limit, default: 100)
579
+ eval.ref_view_strategy=VALUE Reference view strategy (default: first)
580
+ eval.eval_only=VALUE Only run evaluation (skip inference) (true/false)
581
+ eval.print_only=VALUE Only print saved metrics (true/false)
582
+ inference.num_fusion_workers=VALUE Number of parallel workers (default: 4)
583
+ inference.debug=VALUE Enable debug mode (true/false)
584
+
585
+ Special Flags:
586
+ --help, -h Show this help message
587
+
588
+ Multi-GPU:
589
+ Use CUDA_VISIBLE_DEVICES to specify GPUs (auto-detected and distributed)
590
+
591
+ Examples:
592
+ # Use default config
593
+ python -m depth_anything_3.bench.evaluator
594
+
595
+ # Override model path
596
+ python -m depth_anything_3.bench.evaluator model.path=depth-anything/DA3-LARGE
597
+
598
+ # Evaluate specific datasets and modes
599
+ python -m depth_anything_3.bench.evaluator \\
600
+ eval.datasets=[eth3d,hiroom] \\
601
+ eval.modes=[pose]
602
+
603
+ # Use custom config with overrides
604
+ python -m depth_anything_3.bench.evaluator \\
605
+ --config my_config.yaml \\
606
+ model.path=/path/to/model \\
607
+ eval.max_frames=50
608
+
609
+ # Multi-GPU inference (auto-distributed)
610
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m depth_anything_3.bench.evaluator
611
+
612
+ # Debug specific scenes
613
+ python -m depth_anything_3.bench.evaluator \\
614
+ eval.datasets=[eth3d] \\
615
+ eval.scenes=[courtyard] \\
616
+ inference.debug=true
617
+
618
+ # Only evaluate (skip inference)
619
+ python -m depth_anything_3.bench.evaluator eval.eval_only=true
620
+
621
+ # Only print saved metrics
622
+ python -m depth_anything_3.bench.evaluator eval.print_only=true
623
+
624
+ """)
625
+ sys.exit(0)
626
+
627
+ # Load config with CLI overrides using OmegaConf dotlist
628
+ # Example: python evaluator.py model.path=/path/to/model eval.datasets=[eth3d,dtu]
629
+ config = load_config(config_path, argv=argv)
630
+
631
+ # Extract config values
632
+ work_dir = config.workspace.work_dir
633
+ model_path = config.model.path
634
+ datasets = config.eval.datasets
635
+ modes = config.eval.modes
636
+ ref_view_strategy = config.eval.ref_view_strategy
637
+ scenes = config.eval.scenes
638
+ max_frames = config.eval.max_frames
639
+ eval_only = config.eval.eval_only
640
+ print_only = config.eval.print_only
641
+ debug = config.inference.debug
642
+ num_fusion_workers = config.inference.num_fusion_workers
643
+
644
+ # GPU settings: parse from CLI dotlist args (gpu_id=X total_gpus=Y)
645
+ # These are passed by the main process when spawning workers
646
+ gpu_id = 0
647
+ total_gpus = 1
648
+ for arg in argv:
649
+ if arg.startswith("gpu_id="):
650
+ gpu_id = int(arg.split("=")[1])
651
+ elif arg.startswith("total_gpus="):
652
+ total_gpus = int(arg.split("=")[1])
653
+
654
+ # Override dataset scenes if specified
655
+ if scenes:
656
+ print(f"[INFO] Running on specific scenes: {scenes}")
657
+
658
+ evaluator = Evaluator(
659
+ work_dir=work_dir,
660
+ datas=datasets,
661
+ modes=modes,
662
+ ref_view_strategy=ref_view_strategy,
663
+ scenes=scenes,
664
+ debug=debug,
665
+ num_fusion_workers=num_fusion_workers,
666
+ max_frames=max_frames,
667
+ gpu_id=gpu_id,
668
+ total_gpus=total_gpus,
669
+ )
670
+
671
+ if print_only:
672
+ evaluator.print_metrics()
673
+ elif eval_only:
674
+ metrics = evaluator.eval()
675
+ evaluator.print_metrics(metrics)
676
+ else:
677
+ # Parse CUDA_VISIBLE_DEVICES to get GPU list
678
+ # If not set, use all available GPUs
679
+ cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
680
+ if cuda_devices is not None and cuda_devices.strip():
681
+ gpu_list = [g.strip() for g in cuda_devices.split(",") if g.strip()]
682
+ else:
683
+ # CUDA_VISIBLE_DEVICES not set, use all available GPUs
684
+ num_available = torch.cuda.device_count()
685
+ gpu_list = [str(i) for i in range(num_available)] if num_available > 0 else ["0"]
686
+
687
+ # Auto multi-GPU: if multiple GPUs and not a worker process
688
+ is_worker = os.environ.get("_DA3_WORKER") == "1"
689
+
690
+ if len(gpu_list) > 1 and not is_worker:
691
+ # Launch worker processes
692
+ import subprocess
693
+
694
+ num_gpus = len(gpu_list)
695
+ print(f"[INFO] Detected {num_gpus} GPUs: {gpu_list}")
696
+ print(f"[INFO] Launching {num_gpus} workers...")
697
+
698
+ # Build base command
699
+ base_cmd = [sys.executable, "-m", "depth_anything_3.bench.evaluator"]
700
+ # Pass config via dotlist instead of CLI args
701
+ if config_path != _default_config:
702
+ base_cmd += ["--config", config_path]
703
+ base_cmd += [f"model.path={model_path}"]
704
+ base_cmd += [f"workspace.work_dir={work_dir}"]
705
+ base_cmd += [f"eval.datasets=[{','.join(datasets)}]"]
706
+ base_cmd += [f"eval.modes=[{','.join(modes)}]"]
707
+ if scenes:
708
+ base_cmd += [f"eval.scenes=[{','.join(scenes)}]"]
709
+ base_cmd += [f"eval.max_frames={max_frames}"]
710
+ base_cmd += [f"eval.ref_view_strategy={ref_view_strategy}"]
711
+ base_cmd += [f"inference.debug={str(debug).lower()}"]
712
+ base_cmd += [f"inference.num_fusion_workers={num_fusion_workers}"]
713
+
714
+ # Launch workers
715
+ processes = []
716
+ for idx, gpu_id in enumerate(gpu_list):
717
+ env = os.environ.copy()
718
+ env["CUDA_VISIBLE_DEVICES"] = gpu_id
719
+ env["_DA3_WORKER"] = "1" # Mark as worker process
720
+
721
+ cmd = base_cmd.copy()
722
+ # GPU-specific worker config
723
+ cmd += [f"gpu_id={idx}", f"total_gpus={num_gpus}"]
724
+
725
+ print(f"[INFO] Starting worker {idx} on GPU {gpu_id}")
726
+ p = subprocess.Popen(cmd, env=env)
727
+ processes.append(p)
728
+
729
+ # Wait for all workers
730
+ for p in processes:
731
+ p.wait()
732
+
733
+ print(f"[INFO] All {num_gpus} workers completed")
734
+
735
+ # Run evaluation after all inference is done
736
+ metrics = evaluator.eval()
737
+ evaluator.print_metrics(metrics)
738
+ else:
739
+ # Single GPU or worker process
740
+ from depth_anything_3.api import DepthAnything3
741
+
742
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
743
+ api = DepthAnything3.from_pretrained(model_path)
744
+ api = api.to(device)
745
+
746
+ evaluator.infer(api, model_path=model_path)
747
+
748
+ # Only run eval if single GPU mode (workers don't eval)
749
+ if not is_worker:
750
+ metrics = evaluator.eval()
751
+ evaluator.print_metrics(metrics)
752
+
Depth-Anything-3/src/depth_anything_3/bench/print_metrics.py ADDED
@@ -0,0 +1,618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Beautiful metrics printing utilities for benchmark evaluation.
17
+
18
+ Provides colorized, well-formatted tabular output for evaluation results.
19
+ Supports highlighting best/worst values and grouping by dataset/mode.
20
+ """
21
+
22
+ import argparse
23
+ import json
24
+ import os
25
+ import re
26
+ from typing import Dict as TDict, List, Optional
27
+
28
+
29
+ # ANSI color codes for terminal output
30
+ class Colors:
31
+ """ANSI escape codes for terminal colors."""
32
+
33
+ RESET = "\033[0m"
34
+ BOLD = "\033[1m"
35
+ RED = "\033[31m"
36
+ GREEN = "\033[32m"
37
+ YELLOW = "\033[33m"
38
+ BLUE = "\033[34m"
39
+ MAGENTA = "\033[35m"
40
+ CYAN = "\033[36m"
41
+ WHITE = "\033[37m"
42
+
43
+ # Bold variants
44
+ BOLD_RED = "\033[1;31m"
45
+ BOLD_GREEN = "\033[1;32m"
46
+ BOLD_YELLOW = "\033[1;33m"
47
+ BOLD_BLUE = "\033[1;34m"
48
+ BOLD_MAGENTA = "\033[1;35m"
49
+ BOLD_CYAN = "\033[1;36m"
50
+
51
+ # Background
52
+ BG_DARK = "\033[48;5;236m"
53
+
54
+
55
+ def strip_ansi(text: str) -> str:
56
+ """Remove ANSI escape sequences from string for length calculation."""
57
+ ansi_escape = re.compile(r"\x1b\[[0-9;]*m")
58
+ return ansi_escape.sub("", text)
59
+
60
+
61
+ def colorize_value(
62
+ value: str,
63
+ is_best: bool = False,
64
+ is_worst: bool = False,
65
+ lower_is_better: bool = False,
66
+ ) -> str:
67
+ """
68
+ Apply color to a metric value based on whether it's best/worst.
69
+
70
+ Args:
71
+ value: String representation of the value
72
+ is_best: Whether this is the best value in its column
73
+ is_worst: Whether this is the worst value in its column
74
+ lower_is_better: If True, lower values are better (e.g., error metrics)
75
+
76
+ Returns:
77
+ Colorized string
78
+ """
79
+ if lower_is_better:
80
+ # For metrics like error/distance, lower is better
81
+ if is_best:
82
+ return f"{Colors.BOLD_GREEN}{value}{Colors.RESET}"
83
+ elif is_worst:
84
+ return f"{Colors.BOLD_RED}{value}{Colors.RESET}"
85
+ else:
86
+ # For metrics like accuracy/AUC, higher is better
87
+ if is_best:
88
+ return f"{Colors.BOLD_GREEN}{value}{Colors.RESET}"
89
+ elif is_worst:
90
+ return f"{Colors.BOLD_RED}{value}{Colors.RESET}"
91
+ return value
92
+
93
+
94
+ class MetricsPrinter:
95
+ """
96
+ Beautiful tabular metrics printer with color support.
97
+
98
+ Features:
99
+ - Colorized best/worst values
100
+ - Grouped by dataset and evaluation mode
101
+ - Automatic column width calculation
102
+ - Support for multiple input directories comparison
103
+ """
104
+
105
+ # Metrics where lower values are better
106
+ LOWER_IS_BETTER = {"comp", "acc", "overall", "error", "loss", "rmse", "mae"}
107
+
108
+ def __init__(self, use_color: bool = True):
109
+ """
110
+ Initialize the printer.
111
+
112
+ Args:
113
+ use_color: Whether to use ANSI colors in output
114
+ """
115
+ self.use_color = use_color
116
+
117
+ def print_results(self, metrics: TDict[str, dict], summary_only: bool = True) -> None:
118
+ """
119
+ Print evaluation metrics in a beautiful tabular format.
120
+
121
+ Args:
122
+ metrics: Dictionary mapping "dataset_mode" to metric results
123
+ summary_only: If True, only print summary table. If False, print per-dataset details too.
124
+ """
125
+ if not metrics:
126
+ print(f"\n{Colors.BOLD_RED}❌ No evaluation metrics found{Colors.RESET}")
127
+ return
128
+
129
+ if not summary_only:
130
+ self._print_header()
131
+ grouped = self._group_by_dataset(metrics)
132
+
133
+ for dataset, modes_data in grouped.items():
134
+ self._print_dataset_section(dataset, modes_data)
135
+
136
+ # Print summary table with average metrics across datasets
137
+ self._print_summary(metrics)
138
+
139
+ self._print_footer()
140
+
141
+ def print_comparison(
142
+ self,
143
+ metrics_list: List[TDict[str, dict]],
144
+ labels: List[str],
145
+ ) -> None:
146
+ """
147
+ Print comparison table for multiple evaluation runs.
148
+
149
+ Args:
150
+ metrics_list: List of metrics dictionaries
151
+ labels: Labels for each metrics dictionary
152
+ """
153
+ if not metrics_list or not all(metrics_list):
154
+ print(f"\n{Colors.BOLD_RED}❌ No metrics to compare{Colors.RESET}")
155
+ return
156
+
157
+ # Collect all datasets and modes
158
+ all_keys = set()
159
+ for metrics in metrics_list:
160
+ all_keys.update(metrics.keys())
161
+
162
+ self._print_header("COMPARISON")
163
+
164
+ for key in sorted(all_keys):
165
+ parts = key.rsplit("_", 1)
166
+ if len(parts) == 2:
167
+ dataset, mode = parts[0], parts[1]
168
+ else:
169
+ dataset, mode = key, "unknown"
170
+
171
+ print(f"\n{Colors.BOLD_CYAN}📊 {dataset.upper()} - {mode.upper()}{Colors.RESET}")
172
+ print("-" * 100)
173
+
174
+ # Collect metrics from all runs
175
+ all_metric_names = set()
176
+ for metrics in metrics_list:
177
+ if key in metrics and "mean" in metrics[key]:
178
+ all_metric_names.update(metrics[key]["mean"].keys())
179
+
180
+ if not all_metric_names:
181
+ continue
182
+
183
+ # Build comparison table
184
+ metric_width = max(15, max(len(m) for m in all_metric_names) + 2)
185
+ label_width = max(15, max(len(l) for l in labels) + 2)
186
+
187
+ # Header
188
+ header = f"{'Metric':<{metric_width}}"
189
+ for label in labels:
190
+ header += f"{label:<{label_width}}"
191
+ print(header)
192
+ print("-" * len(strip_ansi(header)))
193
+
194
+ # Collect values for highlighting
195
+ for metric_name in sorted(all_metric_names):
196
+ values = []
197
+ for metrics in metrics_list:
198
+ if key in metrics and "mean" in metrics[key]:
199
+ val = metrics[key]["mean"].get(metric_name)
200
+ values.append(val if val is not None else float("nan"))
201
+ else:
202
+ values.append(float("nan"))
203
+
204
+ # Find best/worst
205
+ valid_values = [v for v in values if not (v != v)] # Filter NaN
206
+ if valid_values:
207
+ lower_better = any(
208
+ lb in metric_name.lower() for lb in self.LOWER_IS_BETTER
209
+ )
210
+ best_val = min(valid_values) if lower_better else max(valid_values)
211
+ worst_val = max(valid_values) if lower_better else min(valid_values)
212
+ else:
213
+ best_val = worst_val = None
214
+
215
+ # Print row
216
+ row = f"{metric_name:<{metric_width}}"
217
+ for val in values:
218
+ if val != val: # NaN check
219
+ val_str = "N/A"
220
+ else:
221
+ val_str = f"{val:.4f}"
222
+ if self.use_color and len(valid_values) > 1:
223
+ lower_better = any(
224
+ lb in metric_name.lower() for lb in self.LOWER_IS_BETTER
225
+ )
226
+ is_best = abs(val - best_val) < 1e-8 if best_val else False
227
+ is_worst = abs(val - worst_val) < 1e-8 if worst_val else False
228
+ val_str_padded = f"{val_str:<{label_width}}"
229
+ val_str = colorize_value(
230
+ val_str_padded, is_best, is_worst, lower_better
231
+ )
232
+ row += val_str
233
+ continue
234
+ row += f"{val_str:<{label_width}}"
235
+ print(row)
236
+
237
+ self._print_footer()
238
+
239
+ def _print_header(self, title: str = "EVALUATION RESULTS") -> None:
240
+ """Print report header."""
241
+ width = 100
242
+ print()
243
+ print("=" * width)
244
+ print(f"{Colors.BOLD_CYAN}📊 DEPTH ANYTHING 3 {title}{Colors.RESET}")
245
+ print("=" * width)
246
+
247
+ def _print_footer(self) -> None:
248
+ """Print report footer."""
249
+ width = 100
250
+ print()
251
+ print("=" * width)
252
+ print(f"{Colors.BOLD_GREEN}✅ Evaluation Complete{Colors.RESET}")
253
+ print("=" * width)
254
+ print()
255
+
256
+ def _group_by_dataset(self, metrics: TDict[str, dict]) -> TDict[str, dict]:
257
+ """Group metrics by dataset."""
258
+ grouped = {}
259
+ for key, data in metrics.items():
260
+ if not isinstance(data, dict) or "mean" not in data:
261
+ continue
262
+ # Parse key format: "dataset_mode" (e.g., "dtu_recon_unposed")
263
+ parts = key.split("_", 1)
264
+ if len(parts) == 2:
265
+ dataset, mode = parts
266
+ if dataset not in grouped:
267
+ grouped[dataset] = {}
268
+ grouped[dataset][mode] = data
269
+ return grouped
270
+
271
+ def _print_dataset_section(self, dataset: str, modes_data: TDict[str, dict]) -> None:
272
+ """Print metrics section for a single dataset."""
273
+ print(f"\n{Colors.BOLD_MAGENTA}🔍 {dataset.upper()}{Colors.RESET}")
274
+ print("-" * 100)
275
+
276
+ # Collect all unique metrics across all modes
277
+ all_metrics = set()
278
+ for mode_data in modes_data.values():
279
+ all_metrics.update(mode_data["mean"].keys())
280
+ all_metrics = sorted(list(all_metrics))
281
+
282
+ if not all_metrics:
283
+ print(" No metrics available")
284
+ return
285
+
286
+ # Calculate column widths
287
+ metric_width = max(18, max(len(m) for m in all_metrics) + 2)
288
+ mode_width = 18
289
+ modes = list(modes_data.keys())
290
+
291
+ # Print header
292
+ header = f"{'Metric':<{metric_width}}"
293
+ for mode in modes:
294
+ header += f"{mode.upper():<{mode_width}}"
295
+ print(f"{Colors.BOLD}{header}{Colors.RESET}")
296
+ print("-" * len(header))
297
+
298
+ # Print each metric row
299
+ for metric in all_metrics:
300
+ row = f"{metric:<{metric_width}}"
301
+
302
+ # Collect values for this metric across modes
303
+ values = []
304
+ for mode in modes:
305
+ if metric in modes_data[mode]["mean"]:
306
+ values.append(modes_data[mode]["mean"][metric])
307
+ else:
308
+ values.append(None)
309
+
310
+ # Find best/worst values
311
+ valid_values = [v for v in values if v is not None]
312
+ if valid_values:
313
+ lower_better = any(lb in metric.lower() for lb in self.LOWER_IS_BETTER)
314
+ best_val = min(valid_values) if lower_better else max(valid_values)
315
+ worst_val = max(valid_values) if lower_better else min(valid_values)
316
+ else:
317
+ best_val = worst_val = None
318
+
319
+ # Format each value
320
+ for val in values:
321
+ if val is None:
322
+ row += f"{'N/A':<{mode_width}}"
323
+ else:
324
+ val_str = f"{val:.4f}"
325
+ if self.use_color and len(valid_values) > 1:
326
+ is_best = abs(val - best_val) < 1e-8 if best_val else False
327
+ is_worst = abs(val - worst_val) < 1e-8 if worst_val else False
328
+ lower_better = any(
329
+ lb in metric.lower() for lb in self.LOWER_IS_BETTER
330
+ )
331
+ # Pad before colorizing to maintain alignment
332
+ val_str_padded = f"{val_str:<{mode_width}}"
333
+ row += colorize_value(
334
+ val_str_padded, is_best, is_worst, lower_better
335
+ )
336
+ else:
337
+ row += f"{val_str:<{mode_width}}"
338
+ print(row)
339
+
340
+ # Show scene counts
341
+ scene_info = []
342
+ for mode, mode_data in modes_data.items():
343
+ scene_count = len([k for k in mode_data.keys() if k != "mean"])
344
+ scene_info.append(f"{mode}: {scene_count} scenes")
345
+ print(f"\n{Colors.CYAN}📈 {' | '.join(scene_info)}{Colors.RESET}")
346
+
347
+ def _print_summary(self, metrics: TDict[str, dict]) -> None:
348
+ """
349
+ Print summary table with key metrics across all datasets.
350
+
351
+ Format: One row per metric, datasets as columns.
352
+ Order: HiRoom, ETH3D, DTU, 7Scenes, ScanNet++, (DTU-64 for pose only)
353
+ """
354
+ print(f"\n{Colors.BOLD_CYAN}{'=' * 120}{Colors.RESET}")
355
+ print(f"{Colors.BOLD_CYAN}📊 SUMMARY{Colors.RESET}")
356
+ print(f"{Colors.BOLD_CYAN}{'=' * 120}{Colors.RESET}")
357
+
358
+ # Dataset display order and names
359
+ DATASET_ORDER = ["hiroom", "eth3d", "dtu", "7scenes", "scannetpp", "dtu64"]
360
+ DATASET_DISPLAY = {
361
+ "hiroom": "HiRoom",
362
+ "eth3d": "ETH3D",
363
+ "dtu": "DTU",
364
+ "7scenes": "7Scenes",
365
+ "scannetpp": "ScanNet++",
366
+ "dtu64": "DTU-64",
367
+ }
368
+
369
+ # Collect all metrics into a structured dict
370
+ # metric_data[dataset][mode] = {"Auc_3": x, "Auc_30": x, "fscore": x, "overall": x}
371
+ metric_data = {}
372
+ for key, data in metrics.items():
373
+ if not isinstance(data, dict) or "mean" not in data:
374
+ continue
375
+ parts = key.split("_", 1)
376
+ if len(parts) != 2:
377
+ continue
378
+ dataset, mode = parts
379
+ dataset_lower = dataset.lower()
380
+ if dataset_lower not in metric_data:
381
+ metric_data[dataset_lower] = {}
382
+ metric_data[dataset_lower][mode] = data["mean"]
383
+
384
+ col_width = 12
385
+
386
+ def fmt_val(val):
387
+ """Format value or return N/A."""
388
+ if val is None:
389
+ return "N/A"
390
+ return f"{val:.4f}"
391
+
392
+ def get_metric(dataset, mode, metric_name):
393
+ """Get metric value or None."""
394
+ if dataset not in metric_data:
395
+ return None
396
+ if mode not in metric_data[dataset]:
397
+ return None
398
+ return metric_data[dataset][mode].get(metric_name)
399
+
400
+ # ============ POSE METRICS ============
401
+ print(f"\n{Colors.BOLD_MAGENTA}🎯 POSE ESTIMATION{Colors.RESET}")
402
+
403
+ # Pose: show all datasets except DTU (keep DTU-64 only)
404
+ # Order: HiRoom, ETH3D, DTU-64, 7Scenes, ScanNet++
405
+ pose_datasets = ["hiroom", "eth3d", "dtu64", "7scenes", "scannetpp"]
406
+
407
+ # Header: Avg first, then datasets
408
+ header = f"{'Metric':<15}{'Avg':<{col_width}}"
409
+ for ds in pose_datasets:
410
+ header += f"{DATASET_DISPLAY[ds]:<{col_width}}"
411
+ print("-" * len(strip_ansi(header)))
412
+ print(f"{Colors.BOLD}{header}{Colors.RESET}")
413
+ print("-" * len(strip_ansi(header)))
414
+
415
+ # Helper to get metric with fallback names
416
+ def get_pose_metric(dataset, metric_name):
417
+ """Get pose metric with fallback for different naming conventions."""
418
+ # Try different naming conventions
419
+ names = {
420
+ "Auc3": ["Auc_3", "auc03", "auc_3", "AUC_3", "Auc3", "auc3"],
421
+ "Auc30": ["Auc_30", "auc30", "auc_30", "AUC_30", "Auc30"],
422
+ }
423
+ for name in names.get(metric_name, [metric_name]):
424
+ val = get_metric(dataset, "pose", name)
425
+ if val is not None:
426
+ return val
427
+ return None
428
+
429
+ # Auc3 row
430
+ values = []
431
+ for ds in pose_datasets:
432
+ val = get_pose_metric(ds, "Auc3")
433
+ if val is not None:
434
+ values.append(val)
435
+ avg = sum(values) / len(values) if values else None
436
+ row = f"{'Auc3':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
437
+ for ds in pose_datasets:
438
+ val = get_pose_metric(ds, "Auc3")
439
+ row += f"{fmt_val(val):<{col_width}}"
440
+ print(row)
441
+
442
+ # Auc30 row
443
+ values = []
444
+ for ds in pose_datasets:
445
+ val = get_pose_metric(ds, "Auc30")
446
+ if val is not None:
447
+ values.append(val)
448
+ avg = sum(values) / len(values) if values else None
449
+ row = f"{'Auc30':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
450
+ for ds in pose_datasets:
451
+ val = get_pose_metric(ds, "Auc30")
452
+ row += f"{fmt_val(val):<{col_width}}"
453
+ print(row)
454
+
455
+ # ============ RECON_UNPOSED METRICS ============
456
+ print(f"\n{Colors.BOLD_MAGENTA}🏗️ RECON_UNPOSED (Pred Pose){Colors.RESET}")
457
+
458
+ # For recon, exclude dtu64 from columns
459
+ recon_datasets = ["hiroom", "eth3d", "dtu", "7scenes", "scannetpp"]
460
+ avg_datasets = ["hiroom", "eth3d", "7scenes", "scannetpp"] # Exclude DTU from avg
461
+
462
+ # Header: Avg first, then datasets
463
+ header = f"{'Metric':<15}{'Avg*':<{col_width}}"
464
+ for ds in recon_datasets:
465
+ header += f"{DATASET_DISPLAY[ds]:<{col_width}}"
466
+ print("-" * len(strip_ansi(header)))
467
+ print(f"{Colors.BOLD}{header}{Colors.RESET}")
468
+ print("-" * len(strip_ansi(header)))
469
+
470
+ # F-score row (only metric for avg)
471
+ values = []
472
+ for ds in recon_datasets:
473
+ val = get_metric(ds, "recon_unposed", "fscore")
474
+ if val is not None and ds in avg_datasets:
475
+ values.append(val)
476
+ avg = sum(values) / len(values) if values else None
477
+ row = f"{'F-score':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
478
+ for ds in recon_datasets:
479
+ val = get_metric(ds, "recon_unposed", "fscore")
480
+ row += f"{fmt_val(val):<{col_width}}"
481
+ print(row)
482
+
483
+ # Overall row (avg over 4 datasets excluding DTU)
484
+ values = []
485
+ for ds in recon_datasets:
486
+ val = get_metric(ds, "recon_unposed", "overall")
487
+ if val is not None and ds in avg_datasets:
488
+ values.append(val)
489
+ avg = sum(values) / len(values) if values else None
490
+ row = f"{'Overall':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
491
+ for ds in recon_datasets:
492
+ val = get_metric(ds, "recon_unposed", "overall")
493
+ row += f"{fmt_val(val):<{col_width}}"
494
+ print(row)
495
+
496
+ # ============ RECON_POSED METRICS ============
497
+ print(f"\n{Colors.BOLD_MAGENTA}🏗️ RECON_POSED (GT Pose){Colors.RESET}")
498
+
499
+ # Header: Avg first, then datasets
500
+ header = f"{'Metric':<15}{'Avg*':<{col_width}}"
501
+ for ds in recon_datasets:
502
+ header += f"{DATASET_DISPLAY[ds]:<{col_width}}"
503
+ print("-" * len(strip_ansi(header)))
504
+ print(f"{Colors.BOLD}{header}{Colors.RESET}")
505
+ print("-" * len(strip_ansi(header)))
506
+
507
+ # F-score row (only metric for avg)
508
+ values = []
509
+ for ds in recon_datasets:
510
+ val = get_metric(ds, "recon_posed", "fscore")
511
+ if val is not None and ds in avg_datasets:
512
+ values.append(val)
513
+ avg = sum(values) / len(values) if values else None
514
+ row = f"{'F-score':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
515
+ for ds in recon_datasets:
516
+ val = get_metric(ds, "recon_posed", "fscore")
517
+ row += f"{fmt_val(val):<{col_width}}"
518
+ print(row)
519
+
520
+ # Overall row (avg over 4 datasets excluding DTU)
521
+ values = []
522
+ for ds in recon_datasets:
523
+ val = get_metric(ds, "recon_posed", "overall")
524
+ if val is not None and ds in avg_datasets:
525
+ values.append(val)
526
+ avg = sum(values) / len(values) if values else None
527
+ row = f"{'Overall':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
528
+ for ds in recon_datasets:
529
+ val = get_metric(ds, "recon_posed", "overall")
530
+ row += f"{fmt_val(val):<{col_width}}"
531
+ print(row)
532
+
533
+ print(f"\n{Colors.CYAN}* Avg F-score / Overall = average over HiRoom, ETH3D, 7Scenes, ScanNet++ (4 datasets){Colors.RESET}")
534
+
535
+
536
+ def load_metrics_from_dir(metric_dir: str) -> TDict[str, dict]:
537
+ """
538
+ Load all metrics JSON files from a directory.
539
+
540
+ Args:
541
+ metric_dir: Path to directory containing metric JSON files
542
+
543
+ Returns:
544
+ Dictionary mapping filename (without .json) to metric data
545
+ """
546
+ metrics = {}
547
+ if not os.path.exists(metric_dir):
548
+ return metrics
549
+
550
+ for filename in os.listdir(metric_dir):
551
+ if filename.endswith(".json"):
552
+ filepath = os.path.join(metric_dir, filename)
553
+ try:
554
+ with open(filepath, encoding="utf-8") as f:
555
+ content = f.read()
556
+ # Handle trailing commas in JSON
557
+ content = re.sub(r",\s*([\]\}])", r"\1", content)
558
+ data = json.loads(content)
559
+ key = filename[:-5]
560
+ metrics[key] = data
561
+ except Exception as e:
562
+ print(f"Warning: Failed to load {filename}: {e}")
563
+
564
+ return metrics
565
+
566
+
567
+ def main():
568
+ """Command-line interface for metrics printing."""
569
+ parser = argparse.ArgumentParser(
570
+ description="Print DepthAnything3 benchmark evaluation metrics."
571
+ )
572
+ parser.add_argument(
573
+ "--input_dir",
574
+ type=str,
575
+ default="./eval_workspace/metric_results",
576
+ help="Directory containing metric JSON files (comma-separated for comparison)",
577
+ )
578
+ parser.add_argument(
579
+ "--no_color",
580
+ action="store_true",
581
+ help="Disable colored output",
582
+ )
583
+ parser.add_argument(
584
+ "--key",
585
+ type=str,
586
+ default=None,
587
+ help="Specific metric key to highlight",
588
+ )
589
+ args = parser.parse_args()
590
+
591
+ # Support multiple directories for comparison
592
+ input_dirs = [d.strip() for d in args.input_dir.split(",") if d.strip()]
593
+
594
+ printer = MetricsPrinter(use_color=not args.no_color)
595
+
596
+ if len(input_dirs) == 1:
597
+ # Single directory - simple print
598
+ metrics = load_metrics_from_dir(input_dirs[0])
599
+ printer.print_results(metrics)
600
+ else:
601
+ # Multiple directories - comparison mode
602
+ metrics_list = []
603
+ labels = []
604
+ for d in input_dirs:
605
+ metrics = load_metrics_from_dir(d)
606
+ if metrics:
607
+ metrics_list.append(metrics)
608
+ labels.append(os.path.basename(d.rstrip("/")))
609
+
610
+ if metrics_list:
611
+ printer.print_comparison(metrics_list, labels)
612
+ else:
613
+ print("No metrics found in specified directories")
614
+
615
+
616
+ if __name__ == "__main__":
617
+ main()
618
+
Depth-Anything-3/src/depth_anything_3/bench/registries.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Auto-loading registry system for benchmark datasets.
17
+
18
+ This module provides registry classes that automatically discover and import
19
+ dataset implementations from the datasets subpackage on first access.
20
+ """
21
+
22
+ import importlib
23
+ import pkgutil
24
+ import threading
25
+
26
+ from depth_anything_3.utils.registry import Registry
27
+
28
+ __all__ = ["METRIC_REGISTRY", "MONO_REGISTRY", "MV_REGISTRY", "NVS_REGISTRY"]
29
+
30
+ # ---- Lazy import: Only scan and import all datasets submodules on first registry access ----
31
+ _loaded = False
32
+ _lock = threading.Lock()
33
+
34
+
35
+ def _import_all_datasets_once():
36
+ """
37
+ Scan and import all .py submodules under depth_anything_3.bench.datasets
38
+ (skip files/packages starting with underscore), to trigger @REGISTRY.register(...) in each module.
39
+ """
40
+ global _loaded
41
+ if _loaded:
42
+ return
43
+
44
+ with _lock:
45
+ if _loaded:
46
+ return
47
+
48
+ pkg_name = "depth_anything_3.bench.datasets"
49
+ pkg = importlib.import_module(pkg_name)
50
+ pkg_paths = list(getattr(pkg, "__path__", []))
51
+
52
+ for finder, name, ispkg in pkgutil.walk_packages(pkg_paths, prefix=pkg_name + "."):
53
+ base = name.rsplit(".", 1)[-1]
54
+ if base.startswith("_"):
55
+ continue
56
+ try:
57
+ importlib.import_module(name)
58
+ except Exception as e:
59
+ print(f"[datasets auto-import] Failed to import {name}: {e}")
60
+
61
+ _loaded = True
62
+
63
+
64
+ class AutoRegistry(Registry):
65
+ """Registry that ensures all datasets are auto-discovered and imported on first use."""
66
+
67
+ def get(self, name):
68
+ _import_all_datasets_once()
69
+ return super().get(name)
70
+
71
+ def all(self):
72
+ _import_all_datasets_once()
73
+ return super().all()
74
+
75
+ def has(self, name):
76
+ _import_all_datasets_once()
77
+ return name in self._map
78
+
79
+
80
+ # Four auto-lazy registry instances for different evaluation types
81
+ METRIC_REGISTRY = AutoRegistry() # For metric depth evaluation
82
+ MONO_REGISTRY = AutoRegistry() # For monocular depth evaluation
83
+ MV_REGISTRY = AutoRegistry() # For multi-view evaluation
84
+ NVS_REGISTRY = AutoRegistry() # For novel view synthesis evaluation
85
+
Depth-Anything-3/src/depth_anything_3/bench/utils.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Utility functions for benchmark evaluation.
17
+
18
+ Contains:
19
+ - Pose evaluation metrics (AUC) and helper functions
20
+ - 3D reconstruction evaluation metrics (Acc/Comp/F-score)
21
+ - Geometry utilities (quaternion conversion, etc.)
22
+ """
23
+
24
+ from typing import Dict as TDict, Optional, Tuple, Union
25
+
26
+ import numpy as np
27
+ import open3d as o3d
28
+ import torch
29
+ from addict import Dict
30
+ from scipy.spatial import KDTree
31
+
32
+ from depth_anything_3.utils.geometry import mat_to_quat
33
+
34
+
35
+ # =============================================================================
36
+ # Geometry Utilities
37
+ # =============================================================================
38
+
39
+
40
+ def quat2rotmat(qvec: list) -> np.ndarray:
41
+ """
42
+ Convert quaternion (WXYZ order) to rotation matrix.
43
+
44
+ Args:
45
+ qvec: Quaternion as [w, x, y, z]
46
+
47
+ Returns:
48
+ 3x3 rotation matrix
49
+ """
50
+ rotmat = np.array(
51
+ [
52
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
53
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
54
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
55
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
56
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
57
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
58
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
59
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
60
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
61
+ ]
62
+ )
63
+ rotmat = rotmat.reshape(3, 3)
64
+ return rotmat
65
+
66
+
67
+ # =============================================================================
68
+ # 3D Reconstruction Evaluation
69
+ # =============================================================================
70
+
71
+
72
+ def nn_correspondance(verts1: np.ndarray, verts2: np.ndarray) -> np.ndarray:
73
+ """
74
+ Compute nearest neighbor distances from verts2 to verts1 using KDTree.
75
+
76
+ Args:
77
+ verts1: Reference point cloud [N, 3]
78
+ verts2: Query point cloud [M, 3]
79
+
80
+ Returns:
81
+ Distance array [M,] - distance from each point in verts2 to nearest in verts1
82
+ """
83
+ if len(verts1) == 0 or len(verts2) == 0:
84
+ return np.array([])
85
+
86
+ kdtree = KDTree(verts1)
87
+ distances, _ = kdtree.query(verts2)
88
+ return distances.reshape(-1)
89
+
90
+
91
+ def evaluate_3d_reconstruction(
92
+ pcd_pred: Union[o3d.geometry.PointCloud, np.ndarray],
93
+ pcd_trgt: Union[o3d.geometry.PointCloud, np.ndarray],
94
+ threshold: float = 0.05,
95
+ down_sample: Optional[float] = None,
96
+ ) -> TDict[str, float]:
97
+ """
98
+ Evaluate 3D reconstruction quality using standard metrics.
99
+
100
+ This function computes:
101
+ - Accuracy: Mean distance from predicted points to GT surface
102
+ - Completeness: Mean distance from GT points to predicted surface
103
+ - Overall: Average of accuracy and completeness
104
+ - Precision: Fraction of predicted points within threshold of GT
105
+ - Recall: Fraction of GT points within threshold of prediction
106
+ - F-score: Harmonic mean of precision and recall
107
+
108
+ Args:
109
+ pcd_pred: Predicted point cloud (Open3D or numpy array)
110
+ pcd_trgt: Ground truth point cloud (Open3D or numpy array)
111
+ threshold: Distance threshold for precision/recall (meters)
112
+ down_sample: Voxel size for downsampling (None to skip)
113
+
114
+ Returns:
115
+ Dict with metrics: acc, comp, overall, precision, recall, fscore
116
+ """
117
+ # Convert to Open3D if needed
118
+ if isinstance(pcd_pred, np.ndarray):
119
+ pcd_pred_o3d = o3d.geometry.PointCloud()
120
+ pcd_pred_o3d.points = o3d.utility.Vector3dVector(pcd_pred)
121
+ pcd_pred = pcd_pred_o3d
122
+ if isinstance(pcd_trgt, np.ndarray):
123
+ pcd_trgt_o3d = o3d.geometry.PointCloud()
124
+ pcd_trgt_o3d.points = o3d.utility.Vector3dVector(pcd_trgt)
125
+ pcd_trgt = pcd_trgt_o3d
126
+
127
+ # Downsample if requested
128
+ if down_sample is not None and down_sample > 0:
129
+ pcd_pred = pcd_pred.voxel_down_sample(down_sample)
130
+ pcd_trgt = pcd_trgt.voxel_down_sample(down_sample)
131
+
132
+ verts_pred = np.asarray(pcd_pred.points)
133
+ verts_trgt = np.asarray(pcd_trgt.points)
134
+
135
+ # Handle empty point clouds
136
+ if len(verts_pred) == 0 or len(verts_trgt) == 0:
137
+ return {
138
+ "acc": float("inf"),
139
+ "comp": float("inf"),
140
+ "overall": float("inf"),
141
+ "precision": 0.0,
142
+ "recall": 0.0,
143
+ "fscore": 0.0,
144
+ }
145
+
146
+ # Compute distances
147
+ dist_pred_to_gt = nn_correspondance(verts_trgt, verts_pred) # Accuracy
148
+ dist_gt_to_pred = nn_correspondance(verts_pred, verts_trgt) # Completeness
149
+
150
+ # Compute metrics
151
+ accuracy = float(np.mean(dist_pred_to_gt))
152
+ completeness = float(np.mean(dist_gt_to_pred))
153
+ overall = (accuracy + completeness) / 2
154
+
155
+ precision = float(np.mean((dist_pred_to_gt < threshold).astype(float)))
156
+ recall = float(np.mean((dist_gt_to_pred < threshold).astype(float)))
157
+
158
+ if precision + recall > 0:
159
+ fscore = 2 * precision * recall / (precision + recall)
160
+ else:
161
+ fscore = 0.0
162
+
163
+ return {
164
+ "acc": accuracy,
165
+ "comp": completeness,
166
+ "overall": overall,
167
+ "precision": precision,
168
+ "recall": recall,
169
+ "fscore": fscore,
170
+ }
171
+
172
+
173
+ def create_tsdf_volume(
174
+ voxel_length: float = 4.0 / 512.0,
175
+ sdf_trunc: float = 0.04,
176
+ color_type: str = "RGB8",
177
+ ) -> o3d.pipelines.integration.ScalableTSDFVolume:
178
+ """
179
+ Create a scalable TSDF volume for depth fusion.
180
+
181
+ Args:
182
+ voxel_length: Size of each voxel
183
+ sdf_trunc: Truncation distance for SDF
184
+ color_type: Color integration type ("RGB8" or "Gray32")
185
+
186
+ Returns:
187
+ Initialized ScalableTSDFVolume
188
+ """
189
+ if color_type == "RGB8":
190
+ color_enum = o3d.pipelines.integration.TSDFVolumeColorType.RGB8
191
+ else:
192
+ color_enum = o3d.pipelines.integration.TSDFVolumeColorType.Gray32
193
+
194
+ volume = o3d.pipelines.integration.ScalableTSDFVolume(
195
+ voxel_length=voxel_length,
196
+ sdf_trunc=sdf_trunc,
197
+ color_type=color_enum,
198
+ )
199
+ return volume
200
+
201
+
202
+ def fuse_depth_to_tsdf(
203
+ volume: o3d.pipelines.integration.ScalableTSDFVolume,
204
+ depths: np.ndarray,
205
+ images: np.ndarray,
206
+ intrinsics: np.ndarray,
207
+ extrinsics: np.ndarray,
208
+ max_depth: float = 10.0,
209
+ ) -> o3d.geometry.TriangleMesh:
210
+ """
211
+ Fuse multiple depth maps into TSDF volume and extract mesh.
212
+
213
+ Args:
214
+ volume: TSDF volume to integrate into
215
+ depths: Depth maps [N, H, W]
216
+ images: RGB images [N, H, W, 3]
217
+ intrinsics: Camera intrinsics [N, 3, 3]
218
+ extrinsics: Camera extrinsics (world-to-camera) [N, 4, 4]
219
+ max_depth: Maximum depth for truncation
220
+
221
+ Returns:
222
+ Extracted triangle mesh
223
+ """
224
+ for i in range(len(depths)):
225
+ depth = depths[i]
226
+ image = images[i]
227
+ ixt = intrinsics[i]
228
+ ext = extrinsics[i]
229
+
230
+ h, w = depth.shape[:2]
231
+
232
+ # Create RGBD image
233
+ depth_o3d = o3d.geometry.Image(depth.astype(np.float32))
234
+ color_o3d = o3d.geometry.Image(image.astype(np.uint8))
235
+ rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
236
+ color_o3d,
237
+ depth_o3d,
238
+ depth_trunc=max_depth,
239
+ convert_rgb_to_intensity=False,
240
+ depth_scale=1.0,
241
+ )
242
+
243
+ # Create camera intrinsics
244
+ ixt_o3d = o3d.camera.PinholeCameraIntrinsic(
245
+ w, h, ixt[0, 0], ixt[1, 1], ixt[0, 2], ixt[1, 2]
246
+ )
247
+
248
+ # Integrate into volume
249
+ volume.integrate(rgbd, ixt_o3d, ext)
250
+
251
+ # Extract mesh
252
+ mesh = volume.extract_triangle_mesh()
253
+ return mesh
254
+
255
+
256
+ def sample_points_from_mesh(
257
+ mesh: o3d.geometry.TriangleMesh,
258
+ num_points: int = 1000000,
259
+ ) -> o3d.geometry.PointCloud:
260
+ """
261
+ Uniformly sample points from a triangle mesh.
262
+
263
+ Args:
264
+ mesh: Input triangle mesh
265
+ num_points: Number of points to sample
266
+
267
+ Returns:
268
+ Sampled point cloud
269
+ """
270
+ try:
271
+ pcd = mesh.sample_points_uniformly(number_of_points=num_points)
272
+ # Clamp colors to valid range [0, 1] for Open3D PLY export
273
+ if pcd.has_colors():
274
+ colors = np.asarray(pcd.colors)
275
+ colors = np.clip(colors, 0.0, 1.0)
276
+ pcd.colors = o3d.utility.Vector3dVector(colors)
277
+ except Exception:
278
+ # Fallback: create random points if mesh is invalid (with fixed seed for reproducibility)
279
+ rng = np.random.default_rng(seed=42)
280
+ points = rng.uniform(-1, 1, size=(num_points, 3))
281
+ pcd = o3d.geometry.PointCloud()
282
+ pcd.points = o3d.utility.Vector3dVector(points)
283
+ return pcd
284
+
285
+
286
+ # =============================================================================
287
+ # Pose Evaluation
288
+ # =============================================================================
289
+
290
+
291
+ def build_pair_index(N: int, B: int = 1):
292
+ """
293
+ Build indices for all possible pairs of frames.
294
+
295
+ Args:
296
+ N: Number of frames
297
+ B: Batch size
298
+
299
+ Returns:
300
+ i1, i2: Indices for all possible pairs
301
+ """
302
+ i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)
303
+ i1, i2 = ((i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_])
304
+ return i1, i2
305
+
306
+
307
+ def compute_pose(pred_se3: torch.Tensor, gt_se3: torch.Tensor) -> Dict:
308
+ """
309
+ Compute pose estimation metrics between predicted and ground truth trajectories.
310
+
311
+ Args:
312
+ pred_se3: Predicted SE(3) transformations [N, 4, 4]
313
+ gt_se3: Ground truth SE(3) transformations [N, 4, 4]
314
+
315
+ Returns:
316
+ Dict with AUC metrics at different thresholds (auc30, auc15, auc05, auc03)
317
+ """
318
+ pred_se3 = align_to_first_camera(pred_se3)
319
+ gt_se3 = align_to_first_camera(gt_se3)
320
+
321
+ rel_rangle_deg, rel_tangle_deg = se3_to_relative_pose_error(pred_se3, gt_se3, len(pred_se3))
322
+ rError = rel_rangle_deg.cpu().numpy()
323
+ tError = rel_tangle_deg.cpu().numpy()
324
+
325
+ output = Dict()
326
+ output.auc30, _ = calculate_auc_np(rError, tError, max_threshold=30)
327
+ output.auc15, _ = calculate_auc_np(rError, tError, max_threshold=15)
328
+ output.auc05, _ = calculate_auc_np(rError, tError, max_threshold=5)
329
+ output.auc03, _ = calculate_auc_np(rError, tError, max_threshold=3)
330
+ return output
331
+
332
+
333
+ def align_to_first_camera(camera_poses: torch.Tensor) -> torch.Tensor:
334
+ """
335
+ Align all camera poses to the first camera's coordinate frame.
336
+
337
+ Args:
338
+ camera_poses: Camera poses as SE3 transformations [N, 4, 4]
339
+
340
+ Returns:
341
+ Aligned camera poses [N, 4, 4]
342
+ """
343
+ first_cam_extrinsic_inv = closed_form_inverse_se3(camera_poses[0][None])
344
+ aligned_poses = torch.matmul(camera_poses, first_cam_extrinsic_inv)
345
+ return aligned_poses
346
+
347
+
348
+ def rotation_angle(
349
+ rot_gt: torch.Tensor, rot_pred: torch.Tensor, batch_size: int = None, eps: float = 1e-15
350
+ ) -> torch.Tensor:
351
+ """
352
+ Calculate rotation angle error between ground truth and predicted rotations.
353
+
354
+ Args:
355
+ rot_gt: Ground truth rotation matrices
356
+ rot_pred: Predicted rotation matrices
357
+ batch_size: Batch size for reshaping the result
358
+ eps: Small value to avoid numerical issues
359
+
360
+ Returns:
361
+ Rotation angle error in degrees
362
+ """
363
+ q_pred = mat_to_quat(rot_pred)
364
+ q_gt = mat_to_quat(rot_gt)
365
+
366
+ loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps)
367
+ err_q = torch.arccos(1 - 2 * loss_q)
368
+
369
+ rel_rangle_deg = err_q * 180 / np.pi
370
+
371
+ if batch_size is not None:
372
+ rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1)
373
+
374
+ return rel_rangle_deg
375
+
376
+
377
+ def translation_angle(
378
+ tvec_gt: torch.Tensor,
379
+ tvec_pred: torch.Tensor,
380
+ batch_size: int = None,
381
+ ambiguity: bool = True,
382
+ ) -> torch.Tensor:
383
+ """
384
+ Calculate translation angle error between ground truth and predicted translations.
385
+
386
+ Args:
387
+ tvec_gt: Ground truth translation vectors
388
+ tvec_pred: Predicted translation vectors
389
+ batch_size: Batch size for reshaping the result
390
+ ambiguity: Whether to handle direction ambiguity
391
+
392
+ Returns:
393
+ Translation angle error in degrees
394
+ """
395
+ rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred)
396
+ rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi
397
+
398
+ if ambiguity:
399
+ rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs())
400
+
401
+ if batch_size is not None:
402
+ rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1)
403
+
404
+ return rel_tangle_deg
405
+
406
+
407
+ def compare_translation_by_angle(
408
+ t_gt: torch.Tensor, t: torch.Tensor, eps: float = 1e-15, default_err: float = 1e6
409
+ ) -> torch.Tensor:
410
+ """
411
+ Normalize the translation vectors and compute the angle between them.
412
+
413
+ Args:
414
+ t_gt: Ground truth translation vectors
415
+ t: Predicted translation vectors
416
+ eps: Small value to avoid division by zero
417
+ default_err: Default error value for invalid cases
418
+
419
+ Returns:
420
+ Angular error between translation vectors in radians
421
+ """
422
+ t_norm = torch.norm(t, dim=1, keepdim=True)
423
+ t = t / (t_norm + eps)
424
+
425
+ t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True)
426
+ t_gt = t_gt / (t_gt_norm + eps)
427
+
428
+ loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps)
429
+ err_t = torch.acos(torch.sqrt(1 - loss_t))
430
+
431
+ err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err
432
+ return err_t
433
+
434
+
435
+ def calculate_auc_np(
436
+ r_error: np.ndarray, t_error: np.ndarray, max_threshold: int = 30
437
+ ) -> tuple:
438
+ """
439
+ Calculate the Area Under the Curve (AUC) for the given error arrays.
440
+
441
+ Args:
442
+ r_error: Rotation error values in degrees
443
+ t_error: Translation error values in degrees
444
+ max_threshold: Maximum threshold value for binning
445
+
446
+ Returns:
447
+ Tuple of (AUC value, normalized histogram)
448
+ """
449
+ error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1)
450
+ max_errors = np.max(error_matrix, axis=1)
451
+ bins = np.arange(max_threshold + 1)
452
+ histogram, _ = np.histogram(max_errors, bins=bins)
453
+ num_pairs = float(len(max_errors))
454
+ normalized_histogram = histogram.astype(float) / num_pairs
455
+ return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
456
+
457
+
458
+ def se3_to_relative_pose_error(
459
+ pred_se3: torch.Tensor, gt_se3: torch.Tensor, num_frames: int
460
+ ) -> tuple:
461
+ """
462
+ Compute rotation and translation errors between predicted and ground truth poses.
463
+
464
+ Args:
465
+ pred_se3: Predicted SE(3) transformations
466
+ gt_se3: Ground truth SE(3) transformations
467
+ num_frames: Number of frames
468
+
469
+ Returns:
470
+ Tuple of (rotation angle errors, translation angle errors) in degrees
471
+ """
472
+ pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames)
473
+
474
+ # Compute relative camera poses between pairs using closed-form inverse
475
+ relative_pose_gt = closed_form_inverse_se3(gt_se3[pair_idx_i1]).bmm(gt_se3[pair_idx_i2])
476
+ relative_pose_pred = closed_form_inverse_se3(pred_se3[pair_idx_i1]).bmm(pred_se3[pair_idx_i2])
477
+
478
+ # Compute the difference in rotation and translation
479
+ rel_rangle_deg = rotation_angle(relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3])
480
+ rel_tangle_deg = translation_angle(relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3])
481
+
482
+ return rel_rangle_deg, rel_tangle_deg
483
+
484
+
485
+ def closed_form_inverse_se3(
486
+ se3: torch.Tensor, R: torch.Tensor = None, T: torch.Tensor = None
487
+ ) -> torch.Tensor:
488
+ """
489
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
490
+
491
+ Uses closed-form solution instead of torch.inverse() for numerical stability.
492
+
493
+ Args:
494
+ se3: Nx4x4 or Nx3x4 tensor of SE3 matrices
495
+ R: Optional Nx3x3 rotation matrices
496
+ T: Optional Nx3x1 translation vectors
497
+
498
+ Returns:
499
+ Inverted SE3 matrices with same shape as input
500
+ """
501
+ is_numpy = isinstance(se3, np.ndarray)
502
+
503
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
504
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
505
+
506
+ if R is None:
507
+ R = se3[:, :3, :3]
508
+ if T is None:
509
+ T = se3[:, :3, 3:]
510
+
511
+ if is_numpy:
512
+ R_transposed = np.transpose(R, (0, 2, 1))
513
+ top_right = -np.matmul(R_transposed, T)
514
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
515
+ else:
516
+ R_transposed = R.transpose(1, 2)
517
+ top_right = -torch.bmm(R_transposed, T)
518
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
519
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
520
+
521
+ inverted_matrix[:, :3, :3] = R_transposed
522
+ inverted_matrix[:, :3, 3:] = top_right
523
+
524
+ return inverted_matrix
525
+
Depth-Anything-3/src/depth_anything_3/cfg.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Configuration utility functions
17
+ """
18
+
19
+ import importlib
20
+ from pathlib import Path
21
+ from typing import Any, Callable, List, Union
22
+ from omegaconf import DictConfig, ListConfig, OmegaConf
23
+
24
+ try:
25
+ OmegaConf.register_new_resolver("eval", eval)
26
+ except Exception as e:
27
+ # if eval is not available, we can just pass
28
+ print(f"Error registering eval resolver: {e}")
29
+
30
+
31
+ def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
32
+ """
33
+ Load a configuration. Will resolve inheritance.
34
+ Supports both file paths and module paths (e.g., depth_anything_3.configs.giant).
35
+ """
36
+ # Check if path is a module path (contains dots but no slashes and doesn't end with .yaml)
37
+ if "." in path and "/" not in path and not path.endswith(".yaml"):
38
+ # It's a module path, load from package resources
39
+ path_parts = path.split(".")[1:]
40
+ config_path = Path(__file__).resolve().parent
41
+ for part in path_parts:
42
+ config_path = config_path.joinpath(part)
43
+ config_path = config_path.with_suffix(".yaml")
44
+ config = OmegaConf.load(str(config_path))
45
+ else:
46
+ # It's a file path (absolute, relative, or with .yaml extension)
47
+ config = OmegaConf.load(path)
48
+
49
+ if argv is not None:
50
+ config_argv = OmegaConf.from_dotlist(argv)
51
+ config = OmegaConf.merge(config, config_argv)
52
+ config = resolve_recursive(config, resolve_inheritance)
53
+ return config
54
+
55
+
56
+ def resolve_recursive(
57
+ config: Any,
58
+ resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
59
+ ) -> Any:
60
+ config = resolver(config)
61
+ if isinstance(config, DictConfig):
62
+ for k in config.keys():
63
+ v = config.get(k)
64
+ if isinstance(v, (DictConfig, ListConfig)):
65
+ config[k] = resolve_recursive(v, resolver)
66
+ if isinstance(config, ListConfig):
67
+ for i in range(len(config)):
68
+ v = config.get(i)
69
+ if isinstance(v, (DictConfig, ListConfig)):
70
+ config[i] = resolve_recursive(v, resolver)
71
+ return config
72
+
73
+
74
+ def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
75
+ """
76
+ Recursively resolve inheritance if the config contains:
77
+ __inherit__: path/to/parent.yaml or a ListConfig of such paths.
78
+ """
79
+ if isinstance(config, DictConfig):
80
+ inherit = config.pop("__inherit__", None)
81
+
82
+ if inherit:
83
+ inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]
84
+
85
+ parent_config = None
86
+ for parent_path in inherit_list:
87
+ assert isinstance(parent_path, str)
88
+ parent_config = (
89
+ load_config(parent_path)
90
+ if parent_config is None
91
+ else OmegaConf.merge(parent_config, load_config(parent_path))
92
+ )
93
+
94
+ if len(config.keys()) > 0:
95
+ config = OmegaConf.merge(parent_config, config)
96
+ else:
97
+ config = parent_config
98
+ return config
99
+
100
+
101
+ def import_item(path: str, name: str) -> Any:
102
+ """
103
+ Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
104
+ """
105
+ return getattr(importlib.import_module(path), name)
106
+
107
+
108
+ def create_object(config: DictConfig) -> Any:
109
+ """
110
+ Create an object from config.
111
+ The config is expected to contains the following:
112
+ __object__:
113
+ path: path.to.module
114
+ name: MyClass
115
+ args: as_config | as_params (default to as_config)
116
+ """
117
+ config = DictConfig(config)
118
+ item = import_item(
119
+ path=config.__object__.path,
120
+ name=config.__object__.name,
121
+ )
122
+ args = config.__object__.get("args", "as_config")
123
+ if args == "as_config":
124
+ return item(config)
125
+ if args == "as_params":
126
+ config = OmegaConf.to_object(config)
127
+ config.pop("__object__")
128
+ return item(**config)
129
+ raise NotImplementedError(f"Unknown args type: {args}")
130
+
131
+
132
+ def create_dataset(path: str, *args, **kwargs) -> Any:
133
+ """
134
+ Create a dataset. Requires the file to contain a "create_dataset" function.
135
+ """
136
+ return import_item(path, "create_dataset")(*args, **kwargs)
137
+
138
+
139
+ def to_dict_recursive(config_obj):
140
+ if isinstance(config_obj, DictConfig):
141
+ return {k: to_dict_recursive(v) for k, v in config_obj.items()}
142
+ elif isinstance(config_obj, ListConfig):
143
+ return [to_dict_recursive(item) for item in config_obj]
144
+ return config_obj
Depth-Anything-3/src/depth_anything_3/cli.py ADDED
@@ -0,0 +1,803 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E402
2
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Refactored Depth Anything 3 CLI
17
+ Clean, modular command-line interface
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import os
23
+ import typer
24
+
25
+ from depth_anything_3.services import start_server
26
+ from depth_anything_3.services.gallery import gallery as gallery_main
27
+ from depth_anything_3.services.inference_service import run_inference
28
+ from depth_anything_3.services.input_handlers import (
29
+ ColmapHandler,
30
+ ImageHandler,
31
+ ImagesHandler,
32
+ InputHandler,
33
+ VideoHandler,
34
+ parse_export_feat,
35
+ )
36
+ from depth_anything_3.utils.constants import (
37
+ DEFAULT_EXPORT_DIR,
38
+ DEFAULT_GALLERY_DIR,
39
+ DEFAULT_GRADIO_DIR,
40
+ DEFAULT_MODEL,
41
+ )
42
+
43
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
44
+
45
+ app = typer.Typer(help="Depth Anything 3 - Video depth estimation CLI", add_completion=False)
46
+
47
+
48
+ # ============================================================================
49
+ # Input type detection utilities
50
+ # ============================================================================
51
+
52
+ # Supported file extensions
53
+ IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff", ".tif"}
54
+ VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}
55
+
56
+
57
+ def detect_input_type(input_path: str) -> str:
58
+ """
59
+ Detect input type from path.
60
+
61
+ Returns:
62
+ - "image": Single image file
63
+ - "images": Directory containing images
64
+ - "video": Video file
65
+ - "colmap": COLMAP directory structure
66
+ - "unknown": Cannot determine type
67
+ """
68
+ if not os.path.exists(input_path):
69
+ return "unknown"
70
+
71
+ # Check if it's a file
72
+ if os.path.isfile(input_path):
73
+ ext = os.path.splitext(input_path)[1].lower()
74
+ if ext in IMAGE_EXTENSIONS:
75
+ return "image"
76
+ elif ext in VIDEO_EXTENSIONS:
77
+ return "video"
78
+ return "unknown"
79
+
80
+ # Check if it's a directory
81
+ if os.path.isdir(input_path):
82
+ # Check for COLMAP structure
83
+ images_dir = os.path.join(input_path, "images")
84
+ sparse_dir = os.path.join(input_path, "sparse")
85
+
86
+ if os.path.isdir(images_dir) and os.path.isdir(sparse_dir):
87
+ return "colmap"
88
+
89
+ # Check if directory contains image files
90
+ for item in os.listdir(input_path):
91
+ item_path = os.path.join(input_path, item)
92
+ if os.path.isfile(item_path):
93
+ ext = os.path.splitext(item)[1].lower()
94
+ if ext in IMAGE_EXTENSIONS:
95
+ return "images"
96
+
97
+ return "unknown"
98
+
99
+ return "unknown"
100
+
101
+
102
+ # ============================================================================
103
+ # Common parameters and configuration
104
+ # ============================================================================
105
+
106
+ # ============================================================================
107
+ # Inference commands
108
+ # ============================================================================
109
+
110
+
111
+ @app.command()
112
+ def auto(
113
+ input_path: str = typer.Argument(
114
+ ..., help="Path to input (image, directory, video, or COLMAP)"
115
+ ),
116
+ model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
117
+ export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
118
+ export_format: str = typer.Option("glb", help="Export format"),
119
+ device: str = typer.Option("cuda", help="Device to use"),
120
+ use_backend: bool = typer.Option(False, help="Use backend service for inference"),
121
+ backend_url: str = typer.Option(
122
+ "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
123
+ ),
124
+ process_res: int = typer.Option(504, help="Processing resolution"),
125
+ process_res_method: str = typer.Option(
126
+ "upper_bound_resize", help="Processing resolution method"
127
+ ),
128
+ export_feat: str = typer.Option(
129
+ "",
130
+ help="[FEAT_VIS]Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
131
+ ),
132
+ auto_cleanup: bool = typer.Option(
133
+ False, help="Automatically clean export directory if it exists (no prompt)"
134
+ ),
135
+ # Video-specific options
136
+ fps: float = typer.Option(1.0, help="[Video] Sampling FPS for frame extraction"),
137
+ # COLMAP-specific options
138
+ sparse_subdir: str = typer.Option(
139
+ "", help="[COLMAP] Sparse reconstruction subdirectory (e.g., '0' for sparse/0/)"
140
+ ),
141
+ align_to_input_ext_scale: bool = typer.Option(
142
+ True, help="[COLMAP] Align prediction to input extrinsics scale"
143
+ ),
144
+ # Pose estimation options
145
+ use_ray_pose: bool = typer.Option(
146
+ False, help="Use ray-based pose estimation instead of camera decoder"
147
+ ),
148
+ ref_view_strategy: str = typer.Option(
149
+ "saddle_balanced",
150
+ help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
151
+ ),
152
+ # GLB export options
153
+ conf_thresh_percentile: float = typer.Option(
154
+ 40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
155
+ ),
156
+ num_max_points: int = typer.Option(
157
+ 1_000_000, help="[GLB] Maximum number of points in the point cloud"
158
+ ),
159
+ show_cameras: bool = typer.Option(
160
+ True, help="[GLB] Show camera wireframes in the exported scene"
161
+ ),
162
+ # Feat_vis export options
163
+ feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
164
+ ):
165
+ """
166
+ Automatically detect input type and run appropriate processing.
167
+
168
+ Supports:
169
+ - Single image file (.jpg, .png, etc.)
170
+ - Directory of images
171
+ - Video file (.mp4, .avi, etc.)
172
+ - COLMAP directory (with 'images' and 'sparse' subdirectories)
173
+ """
174
+ # Detect input type
175
+ input_type = detect_input_type(input_path)
176
+
177
+ if input_type == "unknown":
178
+ typer.echo(f"❌ Error: Cannot determine input type for: {input_path}", err=True)
179
+ typer.echo("Supported inputs:", err=True)
180
+ typer.echo(" - Single image file (.jpg, .png, etc.)", err=True)
181
+ typer.echo(" - Directory containing images", err=True)
182
+ typer.echo(" - Video file (.mp4, .avi, etc.)", err=True)
183
+ typer.echo(" - COLMAP directory (with 'images/' and 'sparse/' subdirectories)", err=True)
184
+ raise typer.Exit(1)
185
+
186
+ # Display detected type
187
+ typer.echo(f"🔍 Detected input type: {input_type.upper()}")
188
+ typer.echo(f"📁 Input path: {input_path}")
189
+ typer.echo()
190
+
191
+ # Determine backend URL based on use_backend flag
192
+ final_backend_url = backend_url if use_backend else None
193
+
194
+ # Parse export_feat parameter
195
+ export_feat_layers = parse_export_feat(export_feat)
196
+
197
+ # Route to appropriate handler
198
+ if input_type == "image":
199
+ typer.echo("Processing single image...")
200
+ # Process input
201
+ image_files = ImageHandler.process(input_path)
202
+
203
+ # Handle export directory
204
+ export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
205
+
206
+ # Run inference
207
+ run_inference(
208
+ image_paths=image_files,
209
+ export_dir=export_dir,
210
+ model_dir=model_dir,
211
+ device=device,
212
+ backend_url=final_backend_url,
213
+ export_format=export_format,
214
+ process_res=process_res,
215
+ process_res_method=process_res_method,
216
+ export_feat_layers=export_feat_layers,
217
+ use_ray_pose=use_ray_pose,
218
+ ref_view_strategy=ref_view_strategy,
219
+ conf_thresh_percentile=conf_thresh_percentile,
220
+ num_max_points=num_max_points,
221
+ show_cameras=show_cameras,
222
+ feat_vis_fps=feat_vis_fps,
223
+ )
224
+
225
+ elif input_type == "images":
226
+ typer.echo("Processing directory of images...")
227
+ # Process input - use default extensions
228
+ image_files = ImagesHandler.process(input_path, "png,jpg,jpeg")
229
+
230
+ # Handle export directory
231
+ export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
232
+
233
+ # Run inference
234
+ run_inference(
235
+ image_paths=image_files,
236
+ export_dir=export_dir,
237
+ model_dir=model_dir,
238
+ device=device,
239
+ backend_url=final_backend_url,
240
+ export_format=export_format,
241
+ process_res=process_res,
242
+ process_res_method=process_res_method,
243
+ export_feat_layers=export_feat_layers,
244
+ use_ray_pose=use_ray_pose,
245
+ ref_view_strategy=ref_view_strategy,
246
+ conf_thresh_percentile=conf_thresh_percentile,
247
+ num_max_points=num_max_points,
248
+ show_cameras=show_cameras,
249
+ feat_vis_fps=feat_vis_fps,
250
+ )
251
+
252
+ elif input_type == "video":
253
+ typer.echo(f"Processing video with FPS={fps}...")
254
+ # Handle export directory
255
+ export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
256
+
257
+ # Process input
258
+ image_files = VideoHandler.process(input_path, export_dir, fps)
259
+
260
+ # Run inference
261
+ run_inference(
262
+ image_paths=image_files,
263
+ export_dir=export_dir,
264
+ model_dir=model_dir,
265
+ device=device,
266
+ backend_url=final_backend_url,
267
+ export_format=export_format,
268
+ process_res=process_res,
269
+ process_res_method=process_res_method,
270
+ export_feat_layers=export_feat_layers,
271
+ use_ray_pose=use_ray_pose,
272
+ ref_view_strategy=ref_view_strategy,
273
+ conf_thresh_percentile=conf_thresh_percentile,
274
+ num_max_points=num_max_points,
275
+ show_cameras=show_cameras,
276
+ feat_vis_fps=feat_vis_fps,
277
+ )
278
+
279
+ elif input_type == "colmap":
280
+ typer.echo(
281
+ f"Processing COLMAP directory (sparse subdirectory: '{sparse_subdir or 'default'}')..."
282
+ )
283
+ # Process input
284
+ image_files, extrinsics, intrinsics = ColmapHandler.process(input_path, sparse_subdir)
285
+
286
+ # Handle export directory
287
+ export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
288
+
289
+ # Run inference
290
+ run_inference(
291
+ image_paths=image_files,
292
+ export_dir=export_dir,
293
+ model_dir=model_dir,
294
+ device=device,
295
+ backend_url=final_backend_url,
296
+ export_format=export_format,
297
+ process_res=process_res,
298
+ process_res_method=process_res_method,
299
+ export_feat_layers=export_feat_layers,
300
+ extrinsics=extrinsics,
301
+ intrinsics=intrinsics,
302
+ align_to_input_ext_scale=align_to_input_ext_scale,
303
+ use_ray_pose=use_ray_pose,
304
+ ref_view_strategy=ref_view_strategy,
305
+ conf_thresh_percentile=conf_thresh_percentile,
306
+ num_max_points=num_max_points,
307
+ show_cameras=show_cameras,
308
+ feat_vis_fps=feat_vis_fps,
309
+ )
310
+
311
+ typer.echo()
312
+ typer.echo("✅ Processing completed successfully!")
313
+
314
+
315
+ @app.command()
316
+ def image(
317
+ image_path: str = typer.Argument(..., help="Path to input image file"),
318
+ model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
319
+ export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
320
+ export_format: str = typer.Option("glb", help="Export format"),
321
+ device: str = typer.Option("cuda", help="Device to use"),
322
+ use_backend: bool = typer.Option(False, help="Use backend service for inference"),
323
+ backend_url: str = typer.Option(
324
+ "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
325
+ ),
326
+ process_res: int = typer.Option(504, help="Processing resolution"),
327
+ process_res_method: str = typer.Option(
328
+ "upper_bound_resize", help="Processing resolution method"
329
+ ),
330
+ export_feat: str = typer.Option(
331
+ "",
332
+ help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
333
+ ),
334
+ auto_cleanup: bool = typer.Option(
335
+ False, help="Automatically clean export directory if it exists (no prompt)"
336
+ ),
337
+ # Pose estimation options
338
+ use_ray_pose: bool = typer.Option(
339
+ False, help="Use ray-based pose estimation instead of camera decoder"
340
+ ),
341
+ ref_view_strategy: str = typer.Option(
342
+ "saddle_balanced",
343
+ help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
344
+ ),
345
+ # GLB export options
346
+ conf_thresh_percentile: float = typer.Option(
347
+ 40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
348
+ ),
349
+ num_max_points: int = typer.Option(
350
+ 1_000_000, help="[GLB] Maximum number of points in the point cloud"
351
+ ),
352
+ show_cameras: bool = typer.Option(
353
+ True, help="[GLB] Show camera wireframes in the exported scene"
354
+ ),
355
+ # Feat_vis export options
356
+ feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
357
+ ):
358
+ """Run camera pose and depth estimation on a single image."""
359
+ # Process input
360
+ image_files = ImageHandler.process(image_path)
361
+
362
+ # Handle export directory
363
+ export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
364
+
365
+ # Parse export_feat parameter
366
+ export_feat_layers = parse_export_feat(export_feat)
367
+
368
+ # Determine backend URL based on use_backend flag
369
+ final_backend_url = backend_url if use_backend else None
370
+
371
+ # Run inference
372
+ run_inference(
373
+ image_paths=image_files,
374
+ export_dir=export_dir,
375
+ model_dir=model_dir,
376
+ device=device,
377
+ backend_url=final_backend_url,
378
+ export_format=export_format,
379
+ process_res=process_res,
380
+ process_res_method=process_res_method,
381
+ export_feat_layers=export_feat_layers,
382
+ use_ray_pose=use_ray_pose,
383
+ reference_view_strategy=reference_view_strategy,
384
+ conf_thresh_percentile=conf_thresh_percentile,
385
+ num_max_points=num_max_points,
386
+ show_cameras=show_cameras,
387
+ feat_vis_fps=feat_vis_fps,
388
+ )
389
+
390
+
391
+ @app.command()
392
+ def images(
393
+ images_dir: str = typer.Argument(..., help="Path to directory containing input images"),
394
+ image_extensions: str = typer.Option(
395
+ "png,jpg,jpeg", help="Comma-separated image file extensions to process"
396
+ ),
397
+ model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
398
+ export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
399
+ export_format: str = typer.Option("glb", help="Export format"),
400
+ device: str = typer.Option("cuda", help="Device to use"),
401
+ use_backend: bool = typer.Option(False, help="Use backend service for inference"),
402
+ backend_url: str = typer.Option(
403
+ "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
404
+ ),
405
+ process_res: int = typer.Option(504, help="Processing resolution"),
406
+ process_res_method: str = typer.Option(
407
+ "upper_bound_resize", help="Processing resolution method"
408
+ ),
409
+ export_feat: str = typer.Option(
410
+ "",
411
+ help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
412
+ ),
413
+ auto_cleanup: bool = typer.Option(
414
+ False, help="Automatically clean export directory if it exists (no prompt)"
415
+ ),
416
+ # Pose estimation options
417
+ use_ray_pose: bool = typer.Option(
418
+ False, help="Use ray-based pose estimation instead of camera decoder"
419
+ ),
420
+ ref_view_strategy: str = typer.Option(
421
+ "saddle_balanced",
422
+ help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
423
+ ),
424
+ # GLB export options
425
+ conf_thresh_percentile: float = typer.Option(
426
+ 40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
427
+ ),
428
+ num_max_points: int = typer.Option(
429
+ 1_000_000, help="[GLB] Maximum number of points in the point cloud"
430
+ ),
431
+ show_cameras: bool = typer.Option(
432
+ True, help="[GLB] Show camera wireframes in the exported scene"
433
+ ),
434
+ # Feat_vis export options
435
+ feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
436
+ ):
437
+ """Run camera pose and depth estimation on a directory of images."""
438
+ # Process input
439
+ image_files = ImagesHandler.process(images_dir, image_extensions)
440
+
441
+ # Handle export directory
442
+ export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
443
+
444
+ # Parse export_feat parameter
445
+ export_feat_layers = parse_export_feat(export_feat)
446
+
447
+ # Determine backend URL based on use_backend flag
448
+ final_backend_url = backend_url if use_backend else None
449
+
450
+ # Run inference
451
+ run_inference(
452
+ image_paths=image_files,
453
+ export_dir=export_dir,
454
+ model_dir=model_dir,
455
+ device=device,
456
+ backend_url=final_backend_url,
457
+ export_format=export_format,
458
+ process_res=process_res,
459
+ process_res_method=process_res_method,
460
+ export_feat_layers=export_feat_layers,
461
+ use_ray_pose=use_ray_pose,
462
+ reference_view_strategy=reference_view_strategy,
463
+ conf_thresh_percentile=conf_thresh_percentile,
464
+ num_max_points=num_max_points,
465
+ show_cameras=show_cameras,
466
+ feat_vis_fps=feat_vis_fps,
467
+ )
468
+
469
+
470
+ @app.command()
471
+ def colmap(
472
+ colmap_dir: str = typer.Argument(
473
+ ..., help="Path to COLMAP directory containing 'images' and 'sparse' subdirectories"
474
+ ),
475
+ sparse_subdir: str = typer.Option(
476
+ "", help="Sparse reconstruction subdirectory (e.g., '0' for sparse/0/, empty for sparse/)"
477
+ ),
478
+ align_to_input_ext_scale: bool = typer.Option(
479
+ True, help="Align prediction to input extrinsics scale"
480
+ ),
481
+ model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
482
+ export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
483
+ export_format: str = typer.Option("glb", help="Export format"),
484
+ device: str = typer.Option("cuda", help="Device to use"),
485
+ use_backend: bool = typer.Option(False, help="Use backend service for inference"),
486
+ backend_url: str = typer.Option(
487
+ "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
488
+ ),
489
+ process_res: int = typer.Option(504, help="Processing resolution"),
490
+ process_res_method: str = typer.Option(
491
+ "upper_bound_resize", help="Processing resolution method"
492
+ ),
493
+ export_feat: str = typer.Option(
494
+ "",
495
+ help="Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
496
+ ),
497
+ auto_cleanup: bool = typer.Option(
498
+ False, help="Automatically clean export directory if it exists (no prompt)"
499
+ ),
500
+ # Pose estimation options
501
+ use_ray_pose: bool = typer.Option(
502
+ False, help="Use ray-based pose estimation instead of camera decoder"
503
+ ),
504
+ ref_view_strategy: str = typer.Option(
505
+ "saddle_balanced",
506
+ help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
507
+ ),
508
+ # GLB export options
509
+ conf_thresh_percentile: float = typer.Option(
510
+ 40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
511
+ ),
512
+ num_max_points: int = typer.Option(
513
+ 1_000_000, help="[GLB] Maximum number of points in the point cloud"
514
+ ),
515
+ show_cameras: bool = typer.Option(
516
+ True, help="[GLB] Show camera wireframes in the exported scene"
517
+ ),
518
+ # Feat_vis export options
519
+ feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
520
+ ):
521
+ """Run pose conditioned depth estimation on COLMAP data."""
522
+ # Process input
523
+ image_files, extrinsics, intrinsics = ColmapHandler.process(colmap_dir, sparse_subdir)
524
+
525
+ # Handle export directory
526
+ export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
527
+
528
+ # Parse export_feat parameter
529
+ export_feat_layers = parse_export_feat(export_feat)
530
+
531
+ # Determine backend URL based on use_backend flag
532
+ final_backend_url = backend_url if use_backend else None
533
+
534
+ # Run inference
535
+ run_inference(
536
+ image_paths=image_files,
537
+ export_dir=export_dir,
538
+ model_dir=model_dir,
539
+ device=device,
540
+ backend_url=final_backend_url,
541
+ export_format=export_format,
542
+ process_res=process_res,
543
+ process_res_method=process_res_method,
544
+ export_feat_layers=export_feat_layers,
545
+ extrinsics=extrinsics,
546
+ intrinsics=intrinsics,
547
+ align_to_input_ext_scale=align_to_input_ext_scale,
548
+ use_ray_pose=use_ray_pose,
549
+ reference_view_strategy=reference_view_strategy,
550
+ conf_thresh_percentile=conf_thresh_percentile,
551
+ num_max_points=num_max_points,
552
+ show_cameras=show_cameras,
553
+ feat_vis_fps=feat_vis_fps,
554
+ )
555
+
556
+
557
+ @app.command()
558
+ def video(
559
+ video_path: str = typer.Argument(..., help="Path to input video file"),
560
+ fps: float = typer.Option(1.0, help="Sampling FPS for frame extraction"),
561
+ model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
562
+ export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
563
+ export_format: str = typer.Option("glb", help="Export format"),
564
+ device: str = typer.Option("cuda", help="Device to use"),
565
+ use_backend: bool = typer.Option(False, help="Use backend service for inference"),
566
+ backend_url: str = typer.Option(
567
+ "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
568
+ ),
569
+ process_res: int = typer.Option(504, help="Processing resolution"),
570
+ process_res_method: str = typer.Option(
571
+ "upper_bound_resize", help="Processing resolution method"
572
+ ),
573
+ export_feat: str = typer.Option(
574
+ "",
575
+ help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
576
+ ),
577
+ auto_cleanup: bool = typer.Option(
578
+ False, help="Automatically clean export directory if it exists (no prompt)"
579
+ ),
580
+ # Pose estimation options
581
+ use_ray_pose: bool = typer.Option(
582
+ False, help="Use ray-based pose estimation instead of camera decoder"
583
+ ),
584
+ ref_view_strategy: str = typer.Option(
585
+ "saddle_balanced",
586
+ help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
587
+ ),
588
+ # GLB export options
589
+ conf_thresh_percentile: float = typer.Option(
590
+ 40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
591
+ ),
592
+ num_max_points: int = typer.Option(
593
+ 1_000_000, help="[GLB] Maximum number of points in the point cloud"
594
+ ),
595
+ show_cameras: bool = typer.Option(
596
+ True, help="[GLB] Show camera wireframes in the exported scene"
597
+ ),
598
+ # Feat_vis export options
599
+ feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
600
+ ):
601
+ """Run depth estimation on video by extracting frames and processing them."""
602
+ # Handle export directory
603
+ export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
604
+
605
+ # Process input
606
+ image_files = VideoHandler.process(video_path, export_dir, fps)
607
+
608
+ # Parse export_feat parameter
609
+ export_feat_layers = parse_export_feat(export_feat)
610
+
611
+ # Determine backend URL based on use_backend flag
612
+ final_backend_url = backend_url if use_backend else None
613
+
614
+ # Run inference
615
+ run_inference(
616
+ image_paths=image_files,
617
+ export_dir=export_dir,
618
+ model_dir=model_dir,
619
+ device=device,
620
+ backend_url=final_backend_url,
621
+ export_format=export_format,
622
+ process_res=process_res,
623
+ process_res_method=process_res_method,
624
+ export_feat_layers=export_feat_layers,
625
+ use_ray_pose=use_ray_pose,
626
+ reference_view_strategy=reference_view_strategy,
627
+ conf_thresh_percentile=conf_thresh_percentile,
628
+ num_max_points=num_max_points,
629
+ show_cameras=show_cameras,
630
+ feat_vis_fps=feat_vis_fps,
631
+ )
632
+
633
+
634
+ # ============================================================================
635
+ # Service management commands
636
+ # ============================================================================
637
+
638
+
639
+ @app.command()
640
+ def backend(
641
+ model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
642
+ device: str = typer.Option("cuda", help="Device to use"),
643
+ host: str = typer.Option("127.0.0.1", help="Host to bind to"),
644
+ port: int = typer.Option(8008, help="Port to bind to"),
645
+ gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path (optional)"),
646
+ ):
647
+ """Start model backend service with integrated gallery."""
648
+ typer.echo("=" * 60)
649
+ typer.echo("🚀 Starting Depth Anything 3 Backend Server")
650
+ typer.echo("=" * 60)
651
+ typer.echo(f"Model directory: {model_dir}")
652
+ typer.echo(f"Device: {device}")
653
+
654
+ # Check if gallery directory exists
655
+ if gallery_dir and os.path.exists(gallery_dir):
656
+ typer.echo(f"Gallery directory: {gallery_dir}")
657
+ else:
658
+ gallery_dir = None # Disable gallery if directory doesn't exist
659
+
660
+ typer.echo()
661
+ typer.echo("📡 Server URLs (Ctrl/CMD+Click to open):")
662
+ typer.echo(f" 🏠 Home: http://{host}:{port}")
663
+ typer.echo(f" 📊 Dashboard: http://{host}:{port}/dashboard")
664
+ typer.echo(f" 📈 API Status: http://{host}:{port}/status")
665
+
666
+ if gallery_dir:
667
+ typer.echo(f" 🎨 Gallery: http://{host}:{port}/gallery/")
668
+
669
+ typer.echo("=" * 60)
670
+
671
+ try:
672
+ start_server(model_dir, device, host, port, gallery_dir)
673
+ except KeyboardInterrupt:
674
+ typer.echo("\n👋 Backend server stopped.")
675
+ except Exception as e:
676
+ typer.echo(f"❌ Failed to start backend: {e}")
677
+ raise typer.Exit(1)
678
+
679
+
680
+ # ============================================================================
681
+ # Application launch commands
682
+ # ============================================================================
683
+
684
+
685
+ @app.command()
686
+ def gradio(
687
+ model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
688
+ workspace_dir: str = typer.Option(DEFAULT_GRADIO_DIR, help="Workspace directory path"),
689
+ gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path"),
690
+ host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
691
+ port: int = typer.Option(7860, help="Port number to bind to"),
692
+ share: bool = typer.Option(False, help="Create a public link for the app"),
693
+ debug: bool = typer.Option(False, help="Enable debug mode"),
694
+ cache_examples: bool = typer.Option(
695
+ False, help="Pre-cache all example scenes at startup for faster loading"
696
+ ),
697
+ cache_gs_tag: str = typer.Option(
698
+ "",
699
+ help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.",
700
+ ),
701
+ ):
702
+ """Launch Depth Anything 3 Gradio interactive web application"""
703
+ from depth_anything_3.app.gradio_app import DepthAnything3App
704
+
705
+ # Create necessary directories
706
+ os.makedirs(workspace_dir, exist_ok=True)
707
+ os.makedirs(gallery_dir, exist_ok=True)
708
+
709
+ typer.echo("Launching Depth Anything 3 Gradio application...")
710
+ typer.echo(f"Model directory: {model_dir}")
711
+ typer.echo(f"Workspace directory: {workspace_dir}")
712
+ typer.echo(f"Gallery directory: {gallery_dir}")
713
+ typer.echo(f"Host: {host}")
714
+ typer.echo(f"Port: {port}")
715
+ typer.echo(f"Share: {share}")
716
+ typer.echo(f"Debug mode: {debug}")
717
+ typer.echo(f"Cache examples: {cache_examples}")
718
+ if cache_examples:
719
+ if cache_gs_tag:
720
+ typer.echo(
721
+ f"Cache GS Tag: '{cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)"
722
+ )
723
+ else:
724
+ typer.echo(f"Cache GS Tag: None (all scenes will use low-res only)")
725
+
726
+ try:
727
+ # Initialize and launch application
728
+ app = DepthAnything3App(
729
+ model_dir=model_dir, workspace_dir=workspace_dir, gallery_dir=gallery_dir
730
+ )
731
+
732
+ # Pre-cache examples if requested
733
+ if cache_examples:
734
+ typer.echo("\n" + "=" * 60)
735
+ typer.echo("Pre-caching mode enabled")
736
+ if cache_gs_tag:
737
+ typer.echo(f"Scenes containing '{cache_gs_tag}' will use HIGH-RES + 3DGS")
738
+ typer.echo(f"Other scenes will use LOW-RES only")
739
+ else:
740
+ typer.echo(f"All scenes will use LOW-RES only")
741
+ typer.echo("=" * 60)
742
+ app.cache_examples(
743
+ show_cam=True,
744
+ filter_black_bg=False,
745
+ filter_white_bg=False,
746
+ save_percentage=20.0,
747
+ num_max_points=1000,
748
+ cache_gs_tag=cache_gs_tag,
749
+ gs_trj_mode="smooth",
750
+ gs_video_quality="low",
751
+ )
752
+
753
+ # Prepare launch arguments
754
+ launch_kwargs = {"share": share, "debug": debug}
755
+
756
+ app.launch(host=host, port=port, **launch_kwargs)
757
+
758
+ except KeyboardInterrupt:
759
+ typer.echo("\nGradio application stopped.")
760
+ except Exception as e:
761
+ typer.echo(f"Failed to launch Gradio application: {e}")
762
+ raise typer.Exit(1)
763
+
764
+
765
+ @app.command()
766
+ def gallery(
767
+ gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery root directory"),
768
+ host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
769
+ port: int = typer.Option(8007, help="Port number to bind to"),
770
+ open_browser: bool = typer.Option(False, help="Open browser after launch"),
771
+ ):
772
+ """Launch Depth Anything 3 Gallery server"""
773
+
774
+ # Validate gallery directory
775
+ if not os.path.exists(gallery_dir):
776
+ raise typer.BadParameter(f"Gallery directory not found: {gallery_dir}")
777
+
778
+ typer.echo("Launching Depth Anything 3 Gallery server...")
779
+ typer.echo(f"Gallery directory: {gallery_dir}")
780
+ typer.echo(f"Host: {host}")
781
+ typer.echo(f"Port: {port}")
782
+ typer.echo(f"Auto-open browser: {open_browser}")
783
+
784
+ try:
785
+ # Set command line arguments
786
+ import sys
787
+
788
+ sys.argv = ["gallery", "--dir", gallery_dir, "--host", host, "--port", str(port)]
789
+ if open_browser:
790
+ sys.argv.append("--open")
791
+
792
+ # Launch gallery server
793
+ gallery_main()
794
+
795
+ except KeyboardInterrupt:
796
+ typer.echo("\nGallery server stopped.")
797
+ except Exception as e:
798
+ typer.echo(f"Failed to launch Gallery server: {e}")
799
+ raise typer.Exit(1)
800
+
801
+
802
+ if __name__ == "__main__":
803
+ app()
Depth-Anything-3/src/depth_anything_3/configs/da3-base.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __object__:
2
+ path: depth_anything_3.model.da3
3
+ name: DepthAnything3Net
4
+ args: as_params
5
+
6
+ net:
7
+ __object__:
8
+ path: depth_anything_3.model.dinov2.dinov2
9
+ name: DinoV2
10
+ args: as_params
11
+
12
+ name: vitb
13
+ out_layers: [5, 7, 9, 11]
14
+ alt_start: 4
15
+ qknorm_start: 4
16
+ rope_start: 4
17
+ cat_token: True
18
+
19
+ head:
20
+ __object__:
21
+ path: depth_anything_3.model.dualdpt
22
+ name: DualDPT
23
+ args: as_params
24
+
25
+ dim_in: &head_dim_in 1536
26
+ output_dim: 2
27
+ features: &head_features 128
28
+ out_channels: &head_out_channels [96, 192, 384, 768]
29
+
30
+
31
+ cam_enc:
32
+ __object__:
33
+ path: depth_anything_3.model.cam_enc
34
+ name: CameraEnc
35
+ args: as_params
36
+
37
+ dim_out: 768
38
+
39
+ cam_dec:
40
+ __object__:
41
+ path: depth_anything_3.model.cam_dec
42
+ name: CameraDec
43
+ args: as_params
44
+
45
+ dim_in: 1536
Depth-Anything-3/src/depth_anything_3/configs/da3-giant.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __object__:
2
+ path: depth_anything_3.model.da3
3
+ name: DepthAnything3Net
4
+ args: as_params
5
+
6
+ net:
7
+ __object__:
8
+ path: depth_anything_3.model.dinov2.dinov2
9
+ name: DinoV2
10
+ args: as_params
11
+
12
+ name: vitg
13
+ out_layers: [19, 27, 33, 39]
14
+ alt_start: 13
15
+ qknorm_start: 13
16
+ rope_start: 13
17
+ cat_token: True
18
+
19
+ head:
20
+ __object__:
21
+ path: depth_anything_3.model.dualdpt
22
+ name: DualDPT
23
+ args: as_params
24
+
25
+ dim_in: &head_dim_in 3072
26
+ output_dim: 2
27
+ features: &head_features 256
28
+ out_channels: &head_out_channels [256, 512, 1024, 1024]
29
+
30
+
31
+ cam_enc:
32
+ __object__:
33
+ path: depth_anything_3.model.cam_enc
34
+ name: CameraEnc
35
+ args: as_params
36
+
37
+ dim_out: 1536
38
+
39
+ cam_dec:
40
+ __object__:
41
+ path: depth_anything_3.model.cam_dec
42
+ name: CameraDec
43
+ args: as_params
44
+
45
+ dim_in: 3072
46
+
47
+
48
+ gs_head:
49
+ __object__:
50
+ path: depth_anything_3.model.gsdpt
51
+ name: GSDPT
52
+ args: as_params
53
+
54
+ dim_in: *head_dim_in
55
+ output_dim: 38 # should align with gs_adapter's setting, for gs params
56
+ features: *head_features
57
+ out_channels: *head_out_channels
58
+
59
+
60
+ gs_adapter:
61
+ __object__:
62
+ path: depth_anything_3.model.gs_adapter
63
+ name: GaussianAdapter
64
+ args: as_params
65
+
66
+ sh_degree: 2
67
+ pred_color: false # predict SH coefficient if false
68
+ pred_offset_depth: true
69
+ pred_offset_xy: true
70
+ gaussian_scale_min: 1e-5
71
+ gaussian_scale_max: 30.0
Depth-Anything-3/src/depth_anything_3/configs/da3-large.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __object__:
2
+ path: depth_anything_3.model.da3
3
+ name: DepthAnything3Net
4
+ args: as_params
5
+
6
+ net:
7
+ __object__:
8
+ path: depth_anything_3.model.dinov2.dinov2
9
+ name: DinoV2
10
+ args: as_params
11
+
12
+ name: vitl
13
+ out_layers: [11, 15, 19, 23]
14
+ alt_start: 8
15
+ qknorm_start: 8
16
+ rope_start: 8
17
+ cat_token: True
18
+
19
+ head:
20
+ __object__:
21
+ path: depth_anything_3.model.dualdpt
22
+ name: DualDPT
23
+ args: as_params
24
+
25
+ dim_in: &head_dim_in 2048
26
+ output_dim: 2
27
+ features: &head_features 256
28
+ out_channels: &head_out_channels [256, 512, 1024, 1024]
29
+
30
+
31
+ cam_enc:
32
+ __object__:
33
+ path: depth_anything_3.model.cam_enc
34
+ name: CameraEnc
35
+ args: as_params
36
+
37
+ dim_out: 1024
38
+
39
+ cam_dec:
40
+ __object__:
41
+ path: depth_anything_3.model.cam_dec
42
+ name: CameraDec
43
+ args: as_params
44
+
45
+ dim_in: 2048
Depth-Anything-3/src/depth_anything_3/configs/da3-small.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __object__:
2
+ path: depth_anything_3.model.da3
3
+ name: DepthAnything3Net
4
+ args: as_params
5
+
6
+ net:
7
+ __object__:
8
+ path: depth_anything_3.model.dinov2.dinov2
9
+ name: DinoV2
10
+ args: as_params
11
+
12
+ name: vits
13
+ out_layers: [5, 7, 9, 11]
14
+ alt_start: 4
15
+ qknorm_start: 4
16
+ rope_start: 4
17
+ cat_token: True
18
+
19
+ head:
20
+ __object__:
21
+ path: depth_anything_3.model.dualdpt
22
+ name: DualDPT
23
+ args: as_params
24
+
25
+ dim_in: &head_dim_in 768
26
+ output_dim: 2
27
+ features: &head_features 64
28
+ out_channels: &head_out_channels [48, 96, 192, 384]
29
+
30
+
31
+ cam_enc:
32
+ __object__:
33
+ path: depth_anything_3.model.cam_enc
34
+ name: CameraEnc
35
+ args: as_params
36
+
37
+ dim_out: 384
38
+
39
+ cam_dec:
40
+ __object__:
41
+ path: depth_anything_3.model.cam_dec
42
+ name: CameraDec
43
+ args: as_params
44
+
45
+ dim_in: 768
Depth-Anything-3/src/depth_anything_3/configs/da3metric-large.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __object__:
2
+ path: depth_anything_3.model.da3
3
+ name: DepthAnything3Net
4
+ args: as_params
5
+
6
+ net:
7
+ __object__:
8
+ path: depth_anything_3.model.dinov2.dinov2
9
+ name: DinoV2
10
+ args: as_params
11
+
12
+ name: vitl
13
+ out_layers: [4, 11, 17, 23]
14
+ alt_start: -1 # -1 means disable
15
+ qknorm_start: -1
16
+ rope_start: -1
17
+ cat_token: False
18
+
19
+ head:
20
+ __object__:
21
+ path: depth_anything_3.model.dpt
22
+ name: DPT
23
+ args: as_params
24
+
25
+ dim_in: 1024
26
+ output_dim: 1
27
+ features: 256
28
+ out_channels: [256, 512, 1024, 1024]
Depth-Anything-3/src/depth_anything_3/configs/da3mono-large.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __object__:
2
+ path: depth_anything_3.model.da3
3
+ name: DepthAnything3Net
4
+ args: as_params
5
+
6
+ net:
7
+ __object__:
8
+ path: depth_anything_3.model.dinov2.dinov2
9
+ name: DinoV2
10
+ args: as_params
11
+
12
+ name: vitl
13
+ out_layers: [4, 11, 17, 23]
14
+ alt_start: -1 # -1 means disable
15
+ qknorm_start: -1
16
+ rope_start: -1
17
+ cat_token: False
18
+
19
+ head:
20
+ __object__:
21
+ path: depth_anything_3.model.dpt
22
+ name: DPT
23
+ args: as_params
24
+
25
+ dim_in: 1024
26
+ output_dim: 1
27
+ features: 256
28
+ out_channels: [256, 512, 1024, 1024]
Depth-Anything-3/src/depth_anything_3/configs/da3nested-giant-large.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __object__:
2
+ path: depth_anything_3.model.da3
3
+ name: NestedDepthAnything3Net
4
+ args: as_params
5
+
6
+ anyview:
7
+ __inherit__: depth_anything_3.configs.da3-giant
8
+
9
+ metric:
10
+ __inherit__: depth_anything_3.configs.da3metric-large
Depth-Anything-3/src/depth_anything_3/model/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from depth_anything_3.model.da3 import DepthAnything3Net, NestedDepthAnything3Net
16
+
17
+ __export__ = [
18
+ NestedDepthAnything3Net,
19
+ DepthAnything3Net,
20
+ ]
Depth-Anything-3/src/depth_anything_3/model/cam_dec.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
+ class CameraDec(nn.Module):
20
+ def __init__(self, dim_in=1536):
21
+ super().__init__()
22
+ output_dim = dim_in
23
+ self.backbone = nn.Sequential(
24
+ nn.Linear(output_dim, output_dim),
25
+ nn.ReLU(),
26
+ nn.Linear(output_dim, output_dim),
27
+ nn.ReLU(),
28
+ )
29
+ self.fc_t = nn.Linear(output_dim, 3)
30
+ self.fc_qvec = nn.Linear(output_dim, 4)
31
+ self.fc_fov = nn.Sequential(nn.Linear(output_dim, 2), nn.ReLU())
32
+
33
+ def forward(self, feat, camera_encoding=None, *args, **kwargs):
34
+ B, N = feat.shape[:2]
35
+ feat = feat.reshape(B * N, -1)
36
+ feat = self.backbone(feat)
37
+ out_t = self.fc_t(feat.float()).reshape(B, N, 3)
38
+ if camera_encoding is None:
39
+ out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
40
+ out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
41
+ else:
42
+ out_qvec = camera_encoding[..., 3:7]
43
+ out_fov = camera_encoding[..., -2:]
44
+ pose_enc = torch.cat([out_t, out_qvec, out_fov], dim=-1)
45
+ return pose_enc