Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Depth-Anything-3/da3_streaming/loop_utils/__init__.py +15 -0
- Depth-Anything-3/da3_streaming/loop_utils/alignment_torch.py +395 -0
- Depth-Anything-3/da3_streaming/loop_utils/alignment_triton.py +543 -0
- Depth-Anything-3/da3_streaming/loop_utils/config_utils.py +66 -0
- Depth-Anything-3/da3_streaming/loop_utils/logging_utils.py +32 -0
- Depth-Anything-3/da3_streaming/loop_utils/loop_detector.py +391 -0
- Depth-Anything-3/da3_streaming/loop_utils/loop_refinement.py +268 -0
- Depth-Anything-3/da3_streaming/loop_utils/sim3loop.py +399 -0
- Depth-Anything-3/da3_streaming/loop_utils/sim3utils.py +1261 -0
- Depth-Anything-3/da3_streaming/scripts/download_weights.sh +20 -0
- Depth-Anything-3/docs/API.md +465 -0
- Depth-Anything-3/docs/BENCHMARK.md +484 -0
- Depth-Anything-3/docs/CLI.md +654 -0
- Depth-Anything-3/docs/funcs/ref_view_strategy.md +183 -0
- Depth-Anything-3/notebooks/da3.ipynb +0 -0
- Depth-Anything-3/src/depth_anything_3/api.py +446 -0
- Depth-Anything-3/src/depth_anything_3/app/css_and_html.py +594 -0
- Depth-Anything-3/src/depth_anything_3/app/gradio_app.py +724 -0
- Depth-Anything-3/src/depth_anything_3/app/modules/__init__.py +43 -0
- Depth-Anything-3/src/depth_anything_3/app/modules/event_handlers.py +619 -0
- Depth-Anything-3/src/depth_anything_3/app/modules/file_handlers.py +304 -0
- Depth-Anything-3/src/depth_anything_3/app/modules/model_inference.py +260 -0
- Depth-Anything-3/src/depth_anything_3/app/modules/ui_components.py +477 -0
- Depth-Anything-3/src/depth_anything_3/app/modules/utils.py +207 -0
- Depth-Anything-3/src/depth_anything_3/app/modules/visualization.py +434 -0
- Depth-Anything-3/src/depth_anything_3/bench/__init__.py +45 -0
- Depth-Anything-3/src/depth_anything_3/bench/configs/eval_bench.yaml +98 -0
- Depth-Anything-3/src/depth_anything_3/bench/dataset.py +136 -0
- Depth-Anything-3/src/depth_anything_3/bench/datasets/__init__.py +21 -0
- Depth-Anything-3/src/depth_anything_3/bench/datasets/dtu.py +681 -0
- Depth-Anything-3/src/depth_anything_3/bench/datasets/dtu64.py +182 -0
- Depth-Anything-3/src/depth_anything_3/bench/datasets/eth3d.py +594 -0
- Depth-Anything-3/src/depth_anything_3/bench/datasets/hiroom.py +440 -0
- Depth-Anything-3/src/depth_anything_3/bench/datasets/scannetpp.py +591 -0
- Depth-Anything-3/src/depth_anything_3/bench/datasets/sevenscenes.py +449 -0
- Depth-Anything-3/src/depth_anything_3/bench/evaluator.py +752 -0
- Depth-Anything-3/src/depth_anything_3/bench/print_metrics.py +618 -0
- Depth-Anything-3/src/depth_anything_3/bench/registries.py +85 -0
- Depth-Anything-3/src/depth_anything_3/bench/utils.py +525 -0
- Depth-Anything-3/src/depth_anything_3/cfg.py +144 -0
- Depth-Anything-3/src/depth_anything_3/cli.py +803 -0
- Depth-Anything-3/src/depth_anything_3/configs/da3-base.yaml +45 -0
- Depth-Anything-3/src/depth_anything_3/configs/da3-giant.yaml +71 -0
- Depth-Anything-3/src/depth_anything_3/configs/da3-large.yaml +45 -0
- Depth-Anything-3/src/depth_anything_3/configs/da3-small.yaml +45 -0
- Depth-Anything-3/src/depth_anything_3/configs/da3metric-large.yaml +28 -0
- Depth-Anything-3/src/depth_anything_3/configs/da3mono-large.yaml +28 -0
- Depth-Anything-3/src/depth_anything_3/configs/da3nested-giant-large.yaml +10 -0
- Depth-Anything-3/src/depth_anything_3/model/__init__.py +20 -0
- Depth-Anything-3/src/depth_anything_3/model/cam_dec.py +45 -0
Depth-Anything-3/da3_streaming/loop_utils/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
|
Depth-Anything-3/da3_streaming/loop_utils/alignment_torch.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def weighted_estimate_se3_torch(source_points, target_points, weights):
|
| 22 |
+
source_points = torch.from_numpy(source_points).cuda().float()
|
| 23 |
+
target_points = torch.from_numpy(target_points).cuda().float()
|
| 24 |
+
weights = torch.from_numpy(weights).cuda().float()
|
| 25 |
+
|
| 26 |
+
total_weight = torch.sum(weights)
|
| 27 |
+
if total_weight < 1e-6:
|
| 28 |
+
return (
|
| 29 |
+
1.0,
|
| 30 |
+
np.zeros(3, dtype=np.float32),
|
| 31 |
+
np.zeros(3, dtype=np.float32),
|
| 32 |
+
np.zeros((3, 3), dtype=np.float32),
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
normalized_weights = weights / total_weight
|
| 36 |
+
|
| 37 |
+
mu_src = torch.sum(normalized_weights[:, None] * source_points, dim=0)
|
| 38 |
+
mu_tgt = torch.sum(normalized_weights[:, None] * target_points, dim=0)
|
| 39 |
+
|
| 40 |
+
src_centered = source_points - mu_src
|
| 41 |
+
tgt_centered = target_points - mu_tgt
|
| 42 |
+
|
| 43 |
+
weighted_src = src_centered * torch.sqrt(normalized_weights)[:, None]
|
| 44 |
+
weighted_tgt = tgt_centered * torch.sqrt(normalized_weights)[:, None]
|
| 45 |
+
|
| 46 |
+
H = weighted_src.T @ weighted_tgt
|
| 47 |
+
|
| 48 |
+
return 1.0, mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def weighted_estimate_sim3_torch(source_points, target_points, weights):
|
| 52 |
+
|
| 53 |
+
source_points = torch.from_numpy(source_points).cuda().float()
|
| 54 |
+
target_points = torch.from_numpy(target_points).cuda().float()
|
| 55 |
+
weights = torch.from_numpy(weights).cuda().float()
|
| 56 |
+
|
| 57 |
+
total_weight = torch.sum(weights)
|
| 58 |
+
if total_weight < 1e-6:
|
| 59 |
+
return (
|
| 60 |
+
-1.0,
|
| 61 |
+
np.zeros(3, dtype=np.float32),
|
| 62 |
+
np.zeros(3, dtype=np.float32),
|
| 63 |
+
np.zeros((3, 3), dtype=np.float32),
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
normalized_weights = weights / total_weight
|
| 67 |
+
|
| 68 |
+
mu_src = torch.sum(normalized_weights[:, None] * source_points, dim=0)
|
| 69 |
+
mu_tgt = torch.sum(normalized_weights[:, None] * target_points, dim=0)
|
| 70 |
+
|
| 71 |
+
src_centered = source_points - mu_src
|
| 72 |
+
tgt_centered = target_points - mu_tgt
|
| 73 |
+
|
| 74 |
+
scale_src = torch.sqrt(torch.sum(normalized_weights * torch.sum(src_centered**2, dim=1)))
|
| 75 |
+
scale_tgt = torch.sqrt(torch.sum(normalized_weights * torch.sum(tgt_centered**2, dim=1)))
|
| 76 |
+
s = scale_tgt / scale_src
|
| 77 |
+
|
| 78 |
+
weighted_src = (s * src_centered) * torch.sqrt(normalized_weights)[:, None]
|
| 79 |
+
weighted_tgt = tgt_centered * torch.sqrt(normalized_weights)[:, None]
|
| 80 |
+
|
| 81 |
+
H = weighted_src.T @ weighted_tgt
|
| 82 |
+
|
| 83 |
+
return s.cpu().numpy(), mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def weighted_estimate_sim3_numba_torch(source_points, target_points, weights, align_method="sim3"):
|
| 87 |
+
|
| 88 |
+
if align_method == "sim3":
|
| 89 |
+
s, mu_src, mu_tgt, H = weighted_estimate_sim3_torch(source_points, target_points, weights)
|
| 90 |
+
elif align_method == "se3" or align_method == "scale+se3":
|
| 91 |
+
s, mu_src, mu_tgt, H = weighted_estimate_se3_torch(source_points, target_points, weights)
|
| 92 |
+
|
| 93 |
+
if s < 0:
|
| 94 |
+
raise ValueError("Total weight too small for meaningful estimation")
|
| 95 |
+
|
| 96 |
+
H_torch = torch.from_numpy(H).cuda().float()
|
| 97 |
+
U, _, Vt = torch.linalg.svd(H_torch)
|
| 98 |
+
|
| 99 |
+
U = U.cpu().numpy()
|
| 100 |
+
Vt = Vt.cpu().numpy()
|
| 101 |
+
|
| 102 |
+
R = Vt.T @ U.T
|
| 103 |
+
if np.linalg.det(R) < 0:
|
| 104 |
+
Vt[2, :] *= -1
|
| 105 |
+
R = Vt.T @ U.T
|
| 106 |
+
|
| 107 |
+
mu_src = mu_src.astype(np.float32)
|
| 108 |
+
mu_tgt = mu_tgt.astype(np.float32)
|
| 109 |
+
R = R.astype(np.float32)
|
| 110 |
+
|
| 111 |
+
if align_method == "se3" or align_method == "scale+se3":
|
| 112 |
+
t = mu_tgt - R @ mu_src
|
| 113 |
+
else:
|
| 114 |
+
t = mu_tgt - s * R @ mu_src
|
| 115 |
+
|
| 116 |
+
return s, R, t.astype(np.float32)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def huber_loss_torch(r, delta):
|
| 120 |
+
|
| 121 |
+
r_torch = torch.from_numpy(r).cuda().float()
|
| 122 |
+
delta_torch = torch.tensor(delta, device="cuda", dtype=torch.float32)
|
| 123 |
+
|
| 124 |
+
abs_r = torch.abs(r_torch)
|
| 125 |
+
result = torch.where(
|
| 126 |
+
abs_r <= delta_torch, 0.5 * r_torch**2, delta_torch * (abs_r - 0.5 * delta_torch)
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
return result.cpu().numpy()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def compute_residuals_torch(tgt, transformed):
|
| 133 |
+
|
| 134 |
+
tgt_torch = torch.from_numpy(tgt).cuda().float()
|
| 135 |
+
transformed_torch = torch.from_numpy(transformed).cuda().float()
|
| 136 |
+
|
| 137 |
+
residuals = torch.sqrt(torch.sum((tgt_torch - transformed_torch) ** 2, dim=1))
|
| 138 |
+
return residuals.cpu().numpy()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def compute_huber_weights_torch(residuals, delta):
|
| 142 |
+
|
| 143 |
+
residuals_torch = torch.from_numpy(residuals).cuda().float()
|
| 144 |
+
delta_torch = torch.tensor(delta, device="cuda", dtype=torch.float32)
|
| 145 |
+
|
| 146 |
+
weights = torch.ones_like(residuals_torch)
|
| 147 |
+
mask = residuals_torch > delta_torch
|
| 148 |
+
weights[mask] = delta_torch / residuals_torch[mask]
|
| 149 |
+
|
| 150 |
+
return weights.cpu().numpy()
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def apply_transformation_torch(src, s, R, t):
|
| 154 |
+
|
| 155 |
+
src_torch = torch.from_numpy(src).cuda().float()
|
| 156 |
+
R_torch = torch.from_numpy(R).cuda().float()
|
| 157 |
+
t_torch = torch.from_numpy(t).cuda().float()
|
| 158 |
+
s_torch = torch.tensor(s, device="cuda", dtype=torch.float32)
|
| 159 |
+
|
| 160 |
+
transformed = s_torch * (src_torch @ R_torch.T) + t_torch
|
| 161 |
+
return transformed.cpu().numpy()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def robust_weighted_estimate_sim3_torch(
|
| 165 |
+
src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
|
| 166 |
+
):
|
| 167 |
+
|
| 168 |
+
src = src.astype(np.float32)
|
| 169 |
+
tgt = tgt.astype(np.float32)
|
| 170 |
+
init_weights = init_weights.astype(np.float32)
|
| 171 |
+
|
| 172 |
+
s, R, t = weighted_estimate_sim3_numba_torch(src, tgt, init_weights, align_method=align_method)
|
| 173 |
+
|
| 174 |
+
prev_error = float("inf")
|
| 175 |
+
|
| 176 |
+
for iter in range(max_iters):
|
| 177 |
+
transformed = apply_transformation_torch(src, s, R, t)
|
| 178 |
+
residuals = compute_residuals_torch(tgt, transformed)
|
| 179 |
+
|
| 180 |
+
print(f"Iter {iter}: Mean residual = {np.mean(residuals):.6f}")
|
| 181 |
+
|
| 182 |
+
huber_weights = compute_huber_weights_torch(residuals, delta)
|
| 183 |
+
combined_weights = init_weights * huber_weights
|
| 184 |
+
combined_weights /= np.sum(combined_weights) + 1e-12
|
| 185 |
+
|
| 186 |
+
s_new, R_new, t_new = weighted_estimate_sim3_numba_torch(
|
| 187 |
+
src, tgt, combined_weights, align_method=align_method
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
|
| 191 |
+
rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
|
| 192 |
+
|
| 193 |
+
current_error = np.sum(huber_loss_torch(residuals, delta) * init_weights)
|
| 194 |
+
|
| 195 |
+
if (param_change < tol and rot_angle < np.radians(0.1)) or (
|
| 196 |
+
abs(prev_error - current_error) < tol * prev_error
|
| 197 |
+
):
|
| 198 |
+
print(f"Converged at iteration {iter}")
|
| 199 |
+
break
|
| 200 |
+
|
| 201 |
+
s, R, t = s_new, R_new, t_new
|
| 202 |
+
prev_error = current_error
|
| 203 |
+
|
| 204 |
+
return s, R, t
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def apply_sim3_direct_torch(point_maps, s, R, t, device=None):
|
| 208 |
+
"""
|
| 209 |
+
PyTorch SIM3
|
| 210 |
+
point_maps: (b, h, w, 3) numpy array
|
| 211 |
+
s: scalar or (b,) array
|
| 212 |
+
R: (3, 3) or (b, 3, 3) numpy array
|
| 213 |
+
t: (3,) or (b, 3) numpy array
|
| 214 |
+
"""
|
| 215 |
+
if isinstance(point_maps, np.ndarray):
|
| 216 |
+
point_maps_torch = torch.from_numpy(point_maps).float()
|
| 217 |
+
R_torch = torch.from_numpy(R).float()
|
| 218 |
+
t_torch = torch.from_numpy(t).float()
|
| 219 |
+
s_torch = torch.tensor(s).float() if np.isscalar(s) else torch.from_numpy(s).float()
|
| 220 |
+
else:
|
| 221 |
+
point_maps_torch = point_maps
|
| 222 |
+
R_torch = R
|
| 223 |
+
t_torch = t
|
| 224 |
+
s_torch = s
|
| 225 |
+
|
| 226 |
+
if device is not None:
|
| 227 |
+
point_maps_torch = point_maps_torch.to(device)
|
| 228 |
+
R_torch = R_torch.to(device)
|
| 229 |
+
t_torch = t_torch.to(device)
|
| 230 |
+
s_torch = s_torch.to(device)
|
| 231 |
+
|
| 232 |
+
b, h, w, c = point_maps_torch.shape
|
| 233 |
+
|
| 234 |
+
points_flat = point_maps_torch.reshape(b, -1, 3) # (b, h*w, 3)
|
| 235 |
+
|
| 236 |
+
if R_torch.dim() == 2:
|
| 237 |
+
R_torch = R_torch.unsqueeze(0).expand(b, 3, 3) # (b, 3, 3)
|
| 238 |
+
|
| 239 |
+
if t_torch.dim() == 1:
|
| 240 |
+
t_torch = t_torch.unsqueeze(0).expand(b, 3) # (b, 3)
|
| 241 |
+
|
| 242 |
+
if s_torch.dim() == 0:
|
| 243 |
+
s_torch = s_torch.unsqueeze(0).expand(b) # (b,)
|
| 244 |
+
|
| 245 |
+
rotated_flat = torch.bmm(points_flat, R_torch.transpose(1, 2)) # (b, h*w, 3)
|
| 246 |
+
|
| 247 |
+
transformed_flat = s_torch[:, None, None] * rotated_flat + t_torch[:, None, :]
|
| 248 |
+
|
| 249 |
+
transformed = transformed_flat.reshape(b, h, w, 3)
|
| 250 |
+
|
| 251 |
+
if isinstance(point_maps, np.ndarray):
|
| 252 |
+
return transformed.cpu().numpy()
|
| 253 |
+
return transformed
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def depth_to_point_cloud_optimized_torch(depth, intrinsics, extrinsics, device=None):
|
| 257 |
+
|
| 258 |
+
input_is_numpy = isinstance(depth, np.ndarray)
|
| 259 |
+
|
| 260 |
+
if input_is_numpy:
|
| 261 |
+
depth_tensor = torch.from_numpy(depth).float()
|
| 262 |
+
intrinsics_tensor = torch.from_numpy(intrinsics).float()
|
| 263 |
+
extrinsics_tensor = torch.from_numpy(extrinsics).float()
|
| 264 |
+
else:
|
| 265 |
+
depth_tensor = depth
|
| 266 |
+
intrinsics_tensor = intrinsics
|
| 267 |
+
extrinsics_tensor = extrinsics
|
| 268 |
+
|
| 269 |
+
if device is not None:
|
| 270 |
+
depth_tensor = depth_tensor.to(device)
|
| 271 |
+
intrinsics_tensor = intrinsics_tensor.to(device)
|
| 272 |
+
extrinsics_tensor = extrinsics_tensor.to(device)
|
| 273 |
+
|
| 274 |
+
N, H, W = depth_tensor.shape
|
| 275 |
+
device = depth_tensor.device
|
| 276 |
+
|
| 277 |
+
u = torch.arange(W, device=device, dtype=torch.float32).view(1, 1, W)
|
| 278 |
+
v = torch.arange(H, device=device, dtype=torch.float32).view(1, H, 1)
|
| 279 |
+
|
| 280 |
+
u_expanded = u.expand(N, H, W)
|
| 281 |
+
v_expanded = v.expand(N, H, W)
|
| 282 |
+
|
| 283 |
+
ones = torch.ones((N, H, W), device=device)
|
| 284 |
+
pixel_coords = torch.stack([u_expanded, v_expanded, ones], dim=-1) # [N, H, W, 3]
|
| 285 |
+
|
| 286 |
+
intrinsics_inv = torch.inverse(intrinsics_tensor) # [N, 3, 3]
|
| 287 |
+
|
| 288 |
+
camera_coords = torch.einsum("nij,nhwj->nhwi", intrinsics_inv, pixel_coords)
|
| 289 |
+
|
| 290 |
+
camera_coords = camera_coords * depth_tensor.unsqueeze(-1) # [N, H, W, 3]
|
| 291 |
+
|
| 292 |
+
camera_coords_homo = torch.cat(
|
| 293 |
+
[camera_coords, torch.ones((N, H, W, 1), device=device)], dim=-1
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
extrinsics_4x4 = torch.zeros(N, 4, 4, device=device)
|
| 297 |
+
extrinsics_4x4[:, :3, :4] = extrinsics_tensor
|
| 298 |
+
extrinsics_4x4[:, 3, 3] = 1.0
|
| 299 |
+
|
| 300 |
+
c2w = torch.inverse(extrinsics_4x4) # [N, 4, 4]
|
| 301 |
+
|
| 302 |
+
world_coords_homo = torch.einsum("nij,nhwj->nhwi", c2w, camera_coords_homo)
|
| 303 |
+
point_cloud_world = world_coords_homo[..., :3] # [N, H, W, 3]
|
| 304 |
+
|
| 305 |
+
if input_is_numpy:
|
| 306 |
+
return point_cloud_world.cpu().numpy()
|
| 307 |
+
return point_cloud_world
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def warmup_torch():
|
| 311 |
+
|
| 312 |
+
print("\nWarming up PyTorch alignment...")
|
| 313 |
+
|
| 314 |
+
src = np.random.randn(100000, 3).astype(np.float32)
|
| 315 |
+
tgt = np.random.randn(100000, 3).astype(np.float32)
|
| 316 |
+
weights = np.ones(100000, dtype=np.float32)
|
| 317 |
+
residuals = np.abs(np.random.randn(100000).astype(np.float32))
|
| 318 |
+
R = np.eye(3, dtype=np.float32)
|
| 319 |
+
t = np.zeros(3, dtype=np.float32)
|
| 320 |
+
s = np.float32(1.0)
|
| 321 |
+
delta = np.float32(1.0)
|
| 322 |
+
|
| 323 |
+
try:
|
| 324 |
+
_ = weighted_estimate_sim3_torch(src, tgt, weights)
|
| 325 |
+
print(" - weighted_estimate_sim3_torch warmed up.")
|
| 326 |
+
except Exception as e:
|
| 327 |
+
print(" ! Failed to warm up weighted_estimate_sim3_torch:", e)
|
| 328 |
+
|
| 329 |
+
try:
|
| 330 |
+
_ = weighted_estimate_se3_torch(src, tgt, weights)
|
| 331 |
+
print(" - weighted_estimate_se3_torch warmed up.")
|
| 332 |
+
except Exception as e:
|
| 333 |
+
print(" ! Failed to warm up weighted_estimate_se3_torch:", e)
|
| 334 |
+
|
| 335 |
+
try:
|
| 336 |
+
_ = huber_loss_torch(residuals, delta)
|
| 337 |
+
print(" - huber_loss_torch warmed up.")
|
| 338 |
+
except Exception as e:
|
| 339 |
+
print(" ! Failed to warm up huber_loss_torch:", e)
|
| 340 |
+
|
| 341 |
+
try:
|
| 342 |
+
_ = compute_huber_weights_torch(residuals, delta)
|
| 343 |
+
print(" - compute_huber_weights_torch warmed up.")
|
| 344 |
+
except Exception as e:
|
| 345 |
+
print(" ! Failed to warm up compute_huber_weights_torch:", e)
|
| 346 |
+
|
| 347 |
+
try:
|
| 348 |
+
_ = compute_residuals_torch(tgt, src)
|
| 349 |
+
print(" - compute_residuals_torch warmed up.")
|
| 350 |
+
except Exception as e:
|
| 351 |
+
print(" ! Failed to warm up compute_residuals_torch:", e)
|
| 352 |
+
|
| 353 |
+
try:
|
| 354 |
+
_ = apply_transformation_torch(src, s, R, t)
|
| 355 |
+
print(" - apply_transformation_torch warmed up.")
|
| 356 |
+
except Exception as e:
|
| 357 |
+
print(" ! Failed to warm up apply_transformation_torch:", e)
|
| 358 |
+
|
| 359 |
+
print("PyTorch warm-up complete.\n")
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def print_gpu_memory():
|
| 363 |
+
if torch.cuda.is_available():
|
| 364 |
+
allocated = torch.cuda.memory_allocated() / 1024**3 # GB
|
| 365 |
+
cached = torch.cuda.memory_reserved() / 1024**3 # GB
|
| 366 |
+
print(f"GPU Memory Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
if __name__ == "__main__":
|
| 370 |
+
|
| 371 |
+
warmup_torch()
|
| 372 |
+
|
| 373 |
+
n_points = 7_500_000
|
| 374 |
+
src = np.random.randn(n_points, 3).astype(np.float32)
|
| 375 |
+
|
| 376 |
+
true_R = np.array([[0.866, -0.5, 0], [0.5, 0.866, 0], [0, 0, 1]], dtype=np.float32)
|
| 377 |
+
true_t = np.array([1.0, 2.0, 0.5], dtype=np.float32)
|
| 378 |
+
true_s = 1.2
|
| 379 |
+
|
| 380 |
+
tgt = true_s * (src @ true_R.T) + true_t
|
| 381 |
+
tgt += 0.01 * np.random.randn(*tgt.shape).astype(np.float32)
|
| 382 |
+
|
| 383 |
+
weights = np.ones(n_points, dtype=np.float32)
|
| 384 |
+
|
| 385 |
+
print_gpu_memory()
|
| 386 |
+
|
| 387 |
+
s, R, t = robust_weighted_estimate_sim3_torch(
|
| 388 |
+
src, tgt, weights, delta=0.1, max_iters=5, align_method="sim3"
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
print(f"\nEstimated scale: {s:.6f}")
|
| 392 |
+
print(f"Estimated rotation:\n{R}")
|
| 393 |
+
print(f"Estimated translation: {t}")
|
| 394 |
+
|
| 395 |
+
print_gpu_memory()
|
Depth-Anything-3/da3_streaming/loop_utils/alignment_triton.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import triton
|
| 20 |
+
import triton.language as tl
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@triton.jit
|
| 24 |
+
def apply_transformation_residual_kernel(
|
| 25 |
+
src_ptr, # [n, 3]
|
| 26 |
+
tgt_ptr, # [n, 3]
|
| 27 |
+
transformed_ptr, # [n, 3]
|
| 28 |
+
residuals_ptr, # [n]
|
| 29 |
+
s,
|
| 30 |
+
R00,
|
| 31 |
+
R01,
|
| 32 |
+
R02,
|
| 33 |
+
R10,
|
| 34 |
+
R11,
|
| 35 |
+
R12,
|
| 36 |
+
R20,
|
| 37 |
+
R21,
|
| 38 |
+
R22,
|
| 39 |
+
t0,
|
| 40 |
+
t1,
|
| 41 |
+
t2,
|
| 42 |
+
n_points,
|
| 43 |
+
BLOCK_SIZE: tl.constexpr,
|
| 44 |
+
):
|
| 45 |
+
pid = tl.program_id(0)
|
| 46 |
+
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 47 |
+
mask = offsets < n_points
|
| 48 |
+
|
| 49 |
+
src_x = tl.load(src_ptr + offsets * 3 + 0, mask=mask)
|
| 50 |
+
src_y = tl.load(src_ptr + offsets * 3 + 1, mask=mask)
|
| 51 |
+
src_z = tl.load(src_ptr + offsets * 3 + 2, mask=mask)
|
| 52 |
+
|
| 53 |
+
tgt_x = tl.load(tgt_ptr + offsets * 3 + 0, mask=mask)
|
| 54 |
+
tgt_y = tl.load(tgt_ptr + offsets * 3 + 1, mask=mask)
|
| 55 |
+
tgt_z = tl.load(tgt_ptr + offsets * 3 + 2, mask=mask)
|
| 56 |
+
|
| 57 |
+
# transformed = s * (R @ p) + t
|
| 58 |
+
transformed_x = s * (R00 * src_x + R01 * src_y + R02 * src_z) + t0
|
| 59 |
+
transformed_y = s * (R10 * src_x + R11 * src_y + R12 * src_z) + t1
|
| 60 |
+
transformed_z = s * (R20 * src_x + R21 * src_y + R22 * src_z) + t2
|
| 61 |
+
|
| 62 |
+
tl.store(transformed_ptr + offsets * 3 + 0, transformed_x, mask=mask)
|
| 63 |
+
tl.store(transformed_ptr + offsets * 3 + 1, transformed_y, mask=mask)
|
| 64 |
+
tl.store(transformed_ptr + offsets * 3 + 2, transformed_z, mask=mask)
|
| 65 |
+
|
| 66 |
+
dx = tgt_x - transformed_x
|
| 67 |
+
dy = tgt_y - transformed_y
|
| 68 |
+
dz = tgt_z - transformed_z
|
| 69 |
+
residual = tl.sqrt(dx * dx + dy * dy + dz * dz)
|
| 70 |
+
tl.store(residuals_ptr + offsets, residual, mask=mask)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@triton.jit
|
| 74 |
+
def weighted_covariance_kernel(
|
| 75 |
+
src_ptr, # [n, 3]
|
| 76 |
+
tgt_ptr, # [n, 3]
|
| 77 |
+
weights_ptr, # [n]
|
| 78 |
+
mu_src0,
|
| 79 |
+
mu_src1,
|
| 80 |
+
mu_src2,
|
| 81 |
+
mu_tgt0,
|
| 82 |
+
mu_tgt1,
|
| 83 |
+
mu_tgt2,
|
| 84 |
+
H_ptr, # [3, 3]
|
| 85 |
+
n_points,
|
| 86 |
+
BLOCK_SIZE: tl.constexpr,
|
| 87 |
+
):
|
| 88 |
+
pid = tl.program_id(0)
|
| 89 |
+
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 90 |
+
mask = offsets < n_points
|
| 91 |
+
|
| 92 |
+
w = tl.load(weights_ptr + offsets, mask=mask)
|
| 93 |
+
src_x = tl.load(src_ptr + offsets * 3 + 0, mask=mask)
|
| 94 |
+
src_y = tl.load(src_ptr + offsets * 3 + 1, mask=mask)
|
| 95 |
+
src_z = tl.load(src_ptr + offsets * 3 + 2, mask=mask)
|
| 96 |
+
tgt_x = tl.load(tgt_ptr + offsets * 3 + 0, mask=mask)
|
| 97 |
+
tgt_y = tl.load(tgt_ptr + offsets * 3 + 1, mask=mask)
|
| 98 |
+
tgt_z = tl.load(tgt_ptr + offsets * 3 + 2, mask=mask)
|
| 99 |
+
|
| 100 |
+
src_centered_x = src_x - mu_src0
|
| 101 |
+
src_centered_y = src_y - mu_src1
|
| 102 |
+
src_centered_z = src_z - mu_src2
|
| 103 |
+
|
| 104 |
+
tgt_centered_x = tgt_x - mu_tgt0
|
| 105 |
+
tgt_centered_y = tgt_y - mu_tgt1
|
| 106 |
+
tgt_centered_z = tgt_z - mu_tgt2
|
| 107 |
+
|
| 108 |
+
sqrt_w = tl.sqrt(w)
|
| 109 |
+
weighted_src_x = src_centered_x * sqrt_w
|
| 110 |
+
weighted_src_y = src_centered_y * sqrt_w
|
| 111 |
+
weighted_src_z = src_centered_z * sqrt_w
|
| 112 |
+
|
| 113 |
+
weighted_tgt_x = tgt_centered_x * sqrt_w
|
| 114 |
+
weighted_tgt_y = tgt_centered_y * sqrt_w
|
| 115 |
+
weighted_tgt_z = tgt_centered_z * sqrt_w
|
| 116 |
+
|
| 117 |
+
h00 = weighted_src_x * weighted_tgt_x
|
| 118 |
+
h01 = weighted_src_x * weighted_tgt_y
|
| 119 |
+
h02 = weighted_src_x * weighted_tgt_z
|
| 120 |
+
|
| 121 |
+
h10 = weighted_src_y * weighted_tgt_x
|
| 122 |
+
h11 = weighted_src_y * weighted_tgt_y
|
| 123 |
+
h12 = weighted_src_y * weighted_tgt_z
|
| 124 |
+
|
| 125 |
+
h20 = weighted_src_z * weighted_tgt_x
|
| 126 |
+
h21 = weighted_src_z * weighted_tgt_y
|
| 127 |
+
h22 = weighted_src_z * weighted_tgt_z
|
| 128 |
+
|
| 129 |
+
tl.atomic_add(H_ptr + 0, tl.sum(h00, axis=0))
|
| 130 |
+
tl.atomic_add(H_ptr + 1, tl.sum(h01, axis=0))
|
| 131 |
+
tl.atomic_add(H_ptr + 2, tl.sum(h02, axis=0))
|
| 132 |
+
|
| 133 |
+
tl.atomic_add(H_ptr + 3, tl.sum(h10, axis=0))
|
| 134 |
+
tl.atomic_add(H_ptr + 4, tl.sum(h11, axis=0))
|
| 135 |
+
tl.atomic_add(H_ptr + 5, tl.sum(h12, axis=0))
|
| 136 |
+
|
| 137 |
+
tl.atomic_add(H_ptr + 6, tl.sum(h20, axis=0))
|
| 138 |
+
tl.atomic_add(H_ptr + 7, tl.sum(h21, axis=0))
|
| 139 |
+
tl.atomic_add(H_ptr + 8, tl.sum(h22, axis=0))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@triton.jit
|
| 143 |
+
def compute_huber_weights_kernel(
|
| 144 |
+
residuals_ptr,
|
| 145 |
+
weights_ptr,
|
| 146 |
+
delta,
|
| 147 |
+
n_points,
|
| 148 |
+
BLOCK_SIZE: tl.constexpr,
|
| 149 |
+
):
|
| 150 |
+
pid = tl.program_id(0)
|
| 151 |
+
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 152 |
+
mask = offsets < n_points
|
| 153 |
+
|
| 154 |
+
r = tl.load(residuals_ptr + offsets, mask=mask)
|
| 155 |
+
|
| 156 |
+
weight = tl.where(r > delta, delta / r, 1.0)
|
| 157 |
+
|
| 158 |
+
tl.store(weights_ptr + offsets, weight, mask=mask)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@triton.jit
|
| 162 |
+
def weighted_mean_kernel(
|
| 163 |
+
points_ptr, # [n, 3]
|
| 164 |
+
weights_ptr, # [n]
|
| 165 |
+
mean_ptr, # [sum(w*x), sum(w*y), sum(w*z), sum(w)]
|
| 166 |
+
n_points,
|
| 167 |
+
BLOCK_SIZE: tl.constexpr,
|
| 168 |
+
):
|
| 169 |
+
pid = tl.program_id(0)
|
| 170 |
+
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
| 171 |
+
mask = offsets < n_points
|
| 172 |
+
|
| 173 |
+
w = tl.load(weights_ptr + offsets, mask=mask)
|
| 174 |
+
x = tl.load(points_ptr + offsets * 3 + 0, mask=mask)
|
| 175 |
+
y = tl.load(points_ptr + offsets * 3 + 1, mask=mask)
|
| 176 |
+
z = tl.load(points_ptr + offsets * 3 + 2, mask=mask)
|
| 177 |
+
|
| 178 |
+
wx = w * x
|
| 179 |
+
wy = w * y
|
| 180 |
+
wz = w * z
|
| 181 |
+
|
| 182 |
+
tl.atomic_add(mean_ptr + 0, tl.sum(wx, axis=0))
|
| 183 |
+
tl.atomic_add(mean_ptr + 1, tl.sum(wy, axis=0))
|
| 184 |
+
tl.atomic_add(mean_ptr + 2, tl.sum(wz, axis=0))
|
| 185 |
+
tl.atomic_add(mean_ptr + 3, tl.sum(w, axis=0))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def apply_transformation_residual_triton(src, tgt, s, R, t):
|
| 189 |
+
n_points = src.shape[0]
|
| 190 |
+
|
| 191 |
+
transformed = torch.empty_like(src)
|
| 192 |
+
residuals = torch.empty(n_points, device=src.device, dtype=src.dtype)
|
| 193 |
+
|
| 194 |
+
BLOCK_SIZE = 256
|
| 195 |
+
grid = (triton.cdiv(n_points, BLOCK_SIZE),)
|
| 196 |
+
|
| 197 |
+
R_flat = R.contiguous().view(-1)
|
| 198 |
+
t_flat = t.contiguous().view(-1)
|
| 199 |
+
|
| 200 |
+
apply_transformation_residual_kernel[grid](
|
| 201 |
+
src,
|
| 202 |
+
tgt,
|
| 203 |
+
transformed,
|
| 204 |
+
residuals,
|
| 205 |
+
float(s),
|
| 206 |
+
float(R_flat[0]),
|
| 207 |
+
float(R_flat[1]),
|
| 208 |
+
float(R_flat[2]),
|
| 209 |
+
float(R_flat[3]),
|
| 210 |
+
float(R_flat[4]),
|
| 211 |
+
float(R_flat[5]),
|
| 212 |
+
float(R_flat[6]),
|
| 213 |
+
float(R_flat[7]),
|
| 214 |
+
float(R_flat[8]),
|
| 215 |
+
float(t_flat[0]),
|
| 216 |
+
float(t_flat[1]),
|
| 217 |
+
float(t_flat[2]),
|
| 218 |
+
n_points,
|
| 219 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
return transformed, residuals
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def compute_weighted_mean_triton(points, weights):
|
| 226 |
+
n_points = points.shape[0]
|
| 227 |
+
|
| 228 |
+
# [sum(w*x), sum(w*y), sum(w*z), sum(w)]
|
| 229 |
+
mean_buffer = torch.zeros(4, device=points.device, dtype=points.dtype)
|
| 230 |
+
|
| 231 |
+
BLOCK_SIZE = 256
|
| 232 |
+
grid = (triton.cdiv(n_points, BLOCK_SIZE),)
|
| 233 |
+
|
| 234 |
+
weighted_mean_kernel[grid](points, weights, mean_buffer, n_points, BLOCK_SIZE=BLOCK_SIZE)
|
| 235 |
+
|
| 236 |
+
total_weight = mean_buffer[3]
|
| 237 |
+
if total_weight > 1e-12:
|
| 238 |
+
mean = mean_buffer[:3] / total_weight
|
| 239 |
+
else:
|
| 240 |
+
mean = torch.zeros(3, device=points.device, dtype=points.dtype)
|
| 241 |
+
|
| 242 |
+
return mean, total_weight
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def compute_weighted_covariance_triton(src, tgt, weights, mu_src, mu_tgt):
|
| 246 |
+
n_points = src.shape[0]
|
| 247 |
+
|
| 248 |
+
H = torch.zeros(9, device=src.device, dtype=src.dtype)
|
| 249 |
+
|
| 250 |
+
BLOCK_SIZE = 256
|
| 251 |
+
grid = (triton.cdiv(n_points, BLOCK_SIZE),)
|
| 252 |
+
|
| 253 |
+
mu_src_flat = mu_src.contiguous().view(-1)
|
| 254 |
+
mu_tgt_flat = mu_tgt.contiguous().view(-1)
|
| 255 |
+
|
| 256 |
+
weighted_covariance_kernel[grid](
|
| 257 |
+
src,
|
| 258 |
+
tgt,
|
| 259 |
+
weights,
|
| 260 |
+
float(mu_src_flat[0]),
|
| 261 |
+
float(mu_src_flat[1]),
|
| 262 |
+
float(mu_src_flat[2]),
|
| 263 |
+
float(mu_tgt_flat[0]),
|
| 264 |
+
float(mu_tgt_flat[1]),
|
| 265 |
+
float(mu_tgt_flat[2]),
|
| 266 |
+
H,
|
| 267 |
+
n_points,
|
| 268 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
return H.reshape(3, 3)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def compute_huber_weights_triton(residuals, delta):
|
| 275 |
+
n_points = residuals.shape[0]
|
| 276 |
+
weights = torch.empty_like(residuals)
|
| 277 |
+
|
| 278 |
+
BLOCK_SIZE = 256
|
| 279 |
+
grid = (triton.cdiv(n_points, BLOCK_SIZE),)
|
| 280 |
+
|
| 281 |
+
compute_huber_weights_kernel[grid](
|
| 282 |
+
residuals, weights, float(delta), n_points, BLOCK_SIZE=BLOCK_SIZE
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
return weights
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def weighted_estimate_se3_triton(source_points, target_points, weights):
|
| 289 |
+
|
| 290 |
+
source_points = torch.from_numpy(source_points).cuda().float()
|
| 291 |
+
target_points = torch.from_numpy(target_points).cuda().float()
|
| 292 |
+
weights = torch.from_numpy(weights).cuda().float()
|
| 293 |
+
|
| 294 |
+
total_weight = torch.sum(weights)
|
| 295 |
+
if total_weight < 1e-6:
|
| 296 |
+
return (
|
| 297 |
+
1.0,
|
| 298 |
+
np.zeros(3, dtype=np.float32),
|
| 299 |
+
np.zeros(3, dtype=np.float32),
|
| 300 |
+
np.zeros((3, 3), dtype=np.float32),
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
normalized_weights = weights / total_weight
|
| 304 |
+
|
| 305 |
+
mu_src, _ = compute_weighted_mean_triton(source_points, normalized_weights)
|
| 306 |
+
mu_tgt, _ = compute_weighted_mean_triton(target_points, normalized_weights)
|
| 307 |
+
|
| 308 |
+
H = compute_weighted_covariance_triton(
|
| 309 |
+
source_points, target_points, normalized_weights, mu_src, mu_tgt
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return 1.0, mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy()
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def weighted_estimate_sim3_triton(source_points, target_points, weights):
|
| 316 |
+
|
| 317 |
+
source_points = torch.from_numpy(source_points).cuda().float()
|
| 318 |
+
target_points = torch.from_numpy(target_points).cuda().float()
|
| 319 |
+
weights = torch.from_numpy(weights).cuda().float()
|
| 320 |
+
|
| 321 |
+
total_weight = torch.sum(weights)
|
| 322 |
+
if total_weight < 1e-6:
|
| 323 |
+
return (
|
| 324 |
+
-1.0,
|
| 325 |
+
np.zeros(3, dtype=np.float32),
|
| 326 |
+
np.zeros(3, dtype=np.float32),
|
| 327 |
+
np.zeros((3, 3), dtype=np.float32),
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
normalized_weights = weights / total_weight
|
| 331 |
+
|
| 332 |
+
mu_src, _ = compute_weighted_mean_triton(source_points, normalized_weights)
|
| 333 |
+
mu_tgt, _ = compute_weighted_mean_triton(target_points, normalized_weights)
|
| 334 |
+
|
| 335 |
+
src_centered = source_points - mu_src
|
| 336 |
+
tgt_centered = target_points - mu_tgt
|
| 337 |
+
|
| 338 |
+
scale_src = torch.sqrt(torch.sum(normalized_weights * torch.sum(src_centered**2, dim=1)))
|
| 339 |
+
scale_tgt = torch.sqrt(torch.sum(normalized_weights * torch.sum(tgt_centered**2, dim=1)))
|
| 340 |
+
s = scale_tgt / scale_src
|
| 341 |
+
|
| 342 |
+
weighted_src = s * src_centered
|
| 343 |
+
H = compute_weighted_covariance_triton(
|
| 344 |
+
weighted_src,
|
| 345 |
+
tgt_centered,
|
| 346 |
+
normalized_weights,
|
| 347 |
+
torch.zeros_like(mu_src),
|
| 348 |
+
torch.zeros_like(mu_tgt),
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
return s.cpu().numpy(), mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy()
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def weighted_estimate_sim3_numba_triton(
|
| 355 |
+
source_points, target_points, weights, align_method="sim3"
|
| 356 |
+
):
|
| 357 |
+
|
| 358 |
+
if align_method == "sim3":
|
| 359 |
+
s, mu_src, mu_tgt, H = weighted_estimate_sim3_triton(source_points, target_points, weights)
|
| 360 |
+
elif align_method == "se3" or align_method == "scale+se3":
|
| 361 |
+
s, mu_src, mu_tgt, H = weighted_estimate_se3_triton(source_points, target_points, weights)
|
| 362 |
+
|
| 363 |
+
if s < 0:
|
| 364 |
+
raise ValueError("Total weight too small for meaningful estimation")
|
| 365 |
+
|
| 366 |
+
H_torch = torch.from_numpy(H).cuda().float()
|
| 367 |
+
U, _, Vt = torch.linalg.svd(H_torch)
|
| 368 |
+
|
| 369 |
+
U = U.cpu().numpy()
|
| 370 |
+
Vt = Vt.cpu().numpy()
|
| 371 |
+
|
| 372 |
+
R = Vt.T @ U.T
|
| 373 |
+
if np.linalg.det(R) < 0:
|
| 374 |
+
Vt[2, :] *= -1
|
| 375 |
+
R = Vt.T @ U.T
|
| 376 |
+
|
| 377 |
+
mu_src = mu_src.astype(np.float32)
|
| 378 |
+
mu_tgt = mu_tgt.astype(np.float32)
|
| 379 |
+
R = R.astype(np.float32)
|
| 380 |
+
|
| 381 |
+
if align_method == "se3" or align_method == "scale+se3":
|
| 382 |
+
t = mu_tgt - R @ mu_src
|
| 383 |
+
else:
|
| 384 |
+
t = mu_tgt - s * R @ mu_src
|
| 385 |
+
|
| 386 |
+
return s, R, t.astype(np.float32)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def robust_weighted_estimate_sim3_triton(
|
| 390 |
+
src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
|
| 391 |
+
):
|
| 392 |
+
|
| 393 |
+
src = src.astype(np.float32)
|
| 394 |
+
tgt = tgt.astype(np.float32)
|
| 395 |
+
init_weights = init_weights.astype(np.float32)
|
| 396 |
+
|
| 397 |
+
src_torch = torch.from_numpy(src).cuda().float()
|
| 398 |
+
tgt_torch = torch.from_numpy(tgt).cuda().float()
|
| 399 |
+
init_weights_torch = torch.from_numpy(init_weights).cuda().float()
|
| 400 |
+
|
| 401 |
+
s, R, t = weighted_estimate_sim3_numba_triton(
|
| 402 |
+
src, tgt, init_weights, align_method=align_method
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
R_torch = torch.from_numpy(R).cuda().float()
|
| 406 |
+
t_torch = torch.from_numpy(t).cuda().float()
|
| 407 |
+
s_torch = torch.tensor(s, device="cuda", dtype=torch.float32)
|
| 408 |
+
|
| 409 |
+
prev_error = float("inf")
|
| 410 |
+
|
| 411 |
+
for iter in range(max_iters):
|
| 412 |
+
transformed, residuals = apply_transformation_residual_triton(
|
| 413 |
+
src_torch, tgt_torch, s_torch, R_torch, t_torch
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
mean_residual = torch.mean(residuals).cpu().numpy()
|
| 417 |
+
print(f"Iter {iter}: Mean residual = {mean_residual:.6f}")
|
| 418 |
+
|
| 419 |
+
huber_weights = compute_huber_weights_triton(residuals, delta)
|
| 420 |
+
|
| 421 |
+
combined_weights = init_weights_torch * huber_weights
|
| 422 |
+
combined_weights_sum = torch.sum(combined_weights)
|
| 423 |
+
if combined_weights_sum > 1e-12:
|
| 424 |
+
combined_weights /= combined_weights_sum
|
| 425 |
+
else:
|
| 426 |
+
combined_weights = init_weights_torch / torch.sum(init_weights_torch)
|
| 427 |
+
|
| 428 |
+
combined_weights_np = combined_weights.cpu().numpy()
|
| 429 |
+
s_new, R_new, t_new = weighted_estimate_sim3_numba_triton(
|
| 430 |
+
src, tgt, combined_weights_np, align_method=align_method
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
|
| 434 |
+
rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
|
| 435 |
+
|
| 436 |
+
residuals_np = residuals.cpu().numpy()
|
| 437 |
+
huber_loss_values = np.where(
|
| 438 |
+
residuals_np <= delta, 0.5 * residuals_np**2, delta * (residuals_np - 0.5 * delta)
|
| 439 |
+
)
|
| 440 |
+
current_error = np.sum(huber_loss_values * init_weights)
|
| 441 |
+
|
| 442 |
+
if (param_change < tol and rot_angle < np.radians(0.1)) or (
|
| 443 |
+
abs(prev_error - current_error) < tol * prev_error
|
| 444 |
+
):
|
| 445 |
+
print(f"Converged at iteration {iter}")
|
| 446 |
+
break
|
| 447 |
+
|
| 448 |
+
s, R, t = s_new, R_new, t_new
|
| 449 |
+
s_torch = torch.tensor(s, device="cuda", dtype=torch.float32)
|
| 450 |
+
R_torch = torch.from_numpy(R).cuda().float()
|
| 451 |
+
t_torch = torch.from_numpy(t).cuda().float()
|
| 452 |
+
prev_error = current_error
|
| 453 |
+
|
| 454 |
+
return s, R, t
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def warmup_triton():
|
| 458 |
+
print("\nWarming up Triton functions...")
|
| 459 |
+
|
| 460 |
+
n_points = 10000
|
| 461 |
+
src = np.random.randn(n_points, 3).astype(np.float32)
|
| 462 |
+
tgt = np.random.randn(n_points, 3).astype(np.float32)
|
| 463 |
+
weights = np.ones(n_points, dtype=np.float32)
|
| 464 |
+
|
| 465 |
+
src_torch = torch.from_numpy(src).cuda().float()
|
| 466 |
+
tgt_torch = torch.from_numpy(tgt).cuda().float()
|
| 467 |
+
weights_torch = torch.from_numpy(weights).cuda().float()
|
| 468 |
+
|
| 469 |
+
R = np.eye(3, dtype=np.float32)
|
| 470 |
+
t = np.zeros(3, dtype=np.float32)
|
| 471 |
+
s = np.float32(1.0)
|
| 472 |
+
delta = np.float32(0.1)
|
| 473 |
+
|
| 474 |
+
R_torch = torch.from_numpy(R).cuda().float()
|
| 475 |
+
t_torch = torch.from_numpy(t).cuda().float()
|
| 476 |
+
s_torch = torch.tensor(s, device="cuda", dtype=torch.float32)
|
| 477 |
+
|
| 478 |
+
try:
|
| 479 |
+
_, _ = apply_transformation_residual_triton(
|
| 480 |
+
src_torch, tgt_torch, s_torch, R_torch, t_torch
|
| 481 |
+
)
|
| 482 |
+
print(" - apply_transformation_residual_triton warmed up.")
|
| 483 |
+
except Exception as e:
|
| 484 |
+
print(f" ! Failed to warm up apply_transformation_residual_triton: {e}")
|
| 485 |
+
|
| 486 |
+
try:
|
| 487 |
+
_, _ = compute_weighted_mean_triton(src_torch, weights_torch)
|
| 488 |
+
print(" - compute_weighted_mean_triton warmed up.")
|
| 489 |
+
except Exception as e:
|
| 490 |
+
print(f" ! Failed to warm up compute_weighted_mean_triton: {e}")
|
| 491 |
+
|
| 492 |
+
try:
|
| 493 |
+
mu_src, _ = compute_weighted_mean_triton(src_torch, weights_torch)
|
| 494 |
+
mu_tgt, _ = compute_weighted_mean_triton(tgt_torch, weights_torch)
|
| 495 |
+
_ = compute_weighted_covariance_triton(src_torch, tgt_torch, weights_torch, mu_src, mu_tgt)
|
| 496 |
+
print(" - compute_weighted_covariance_triton warmed up.")
|
| 497 |
+
except Exception as e:
|
| 498 |
+
print(f" ! Failed to warm up compute_weighted_covariance_triton: {e}")
|
| 499 |
+
|
| 500 |
+
try:
|
| 501 |
+
residuals = torch.abs(torch.randn(n_points, device="cuda", dtype=torch.float32))
|
| 502 |
+
_ = compute_huber_weights_triton(residuals, delta)
|
| 503 |
+
print(" - compute_huber_weights_triton warmed up.")
|
| 504 |
+
except Exception as e:
|
| 505 |
+
print(f" ! Failed to warm up compute_huber_weights_triton: {e}")
|
| 506 |
+
|
| 507 |
+
print("Triton warm-up complete.\n")
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def print_gpu_memory():
|
| 511 |
+
if torch.cuda.is_available():
|
| 512 |
+
allocated = torch.cuda.memory_allocated() / 1024**3 # GB
|
| 513 |
+
cached = torch.cuda.memory_reserved() / 1024**3 # GB
|
| 514 |
+
print(f"GPU Memory Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
if __name__ == "__main__":
|
| 518 |
+
|
| 519 |
+
warmup_triton()
|
| 520 |
+
|
| 521 |
+
n_points = 7_500_000
|
| 522 |
+
src = np.random.randn(n_points, 3).astype(np.float32)
|
| 523 |
+
|
| 524 |
+
true_R = np.array([[0.866, -0.5, 0], [0.5, 0.866, 0], [0, 0, 1]], dtype=np.float32)
|
| 525 |
+
true_t = np.array([1.0, 2.0, 0.5], dtype=np.float32)
|
| 526 |
+
true_s = 1.2
|
| 527 |
+
|
| 528 |
+
tgt = true_s * (src @ true_R.T) + true_t
|
| 529 |
+
tgt += 0.01 * np.random.randn(*tgt.shape).astype(np.float32)
|
| 530 |
+
|
| 531 |
+
weights = np.ones(n_points, dtype=np.float32)
|
| 532 |
+
|
| 533 |
+
print_gpu_memory()
|
| 534 |
+
|
| 535 |
+
s, R, t = robust_weighted_estimate_sim3_triton(
|
| 536 |
+
src, tgt, weights, delta=0.1, max_iters=5, align_method="sim3"
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
print(f"\nEstimated scale: {s:.6f}")
|
| 540 |
+
print(f"Estimated rotation:\n{R}")
|
| 541 |
+
print(f"Estimated translation: {t}")
|
| 542 |
+
|
| 543 |
+
print_gpu_memory()
|
Depth-Anything-3/da3_streaming/loop_utils/config_utils.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
|
| 16 |
+
|
| 17 |
+
import yaml
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_config(path, default_path=None):
|
| 21 |
+
"""
|
| 22 |
+
Loads config file.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
path (str): path to config file.
|
| 26 |
+
default_path (str, optional): whether to use default path. Defaults to None.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
cfg (dict): config dict.
|
| 30 |
+
|
| 31 |
+
"""
|
| 32 |
+
# load configuration from per scene/dataset cfg.
|
| 33 |
+
with open(path) as f:
|
| 34 |
+
cfg_special = yaml.full_load(f)
|
| 35 |
+
|
| 36 |
+
inherit_from = cfg_special.get("inherit_from")
|
| 37 |
+
|
| 38 |
+
if inherit_from is not None:
|
| 39 |
+
cfg = load_config(inherit_from, default_path)
|
| 40 |
+
elif default_path is not None:
|
| 41 |
+
with open(default_path) as f:
|
| 42 |
+
cfg = yaml.full_load(f)
|
| 43 |
+
else:
|
| 44 |
+
cfg = dict()
|
| 45 |
+
|
| 46 |
+
# merge per dataset cfg. and main cfg.
|
| 47 |
+
update_recursive(cfg, cfg_special)
|
| 48 |
+
|
| 49 |
+
return cfg
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def update_recursive(dict1, dict2):
|
| 53 |
+
"""
|
| 54 |
+
Update two config dictionaries recursively. dict1 get masked by dict2, and we retuen dict1.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
dict1 (dict): first dictionary to be updated.
|
| 58 |
+
dict2 (dict): second dictionary which entries should be used.
|
| 59 |
+
"""
|
| 60 |
+
for k, v in dict2.items():
|
| 61 |
+
if k not in dict1:
|
| 62 |
+
dict1[k] = dict()
|
| 63 |
+
if isinstance(v, dict):
|
| 64 |
+
update_recursive(dict1[k], v)
|
| 65 |
+
else:
|
| 66 |
+
dict1[k] = v
|
Depth-Anything-3/da3_streaming/loop_utils/logging_utils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
|
| 16 |
+
|
| 17 |
+
import rich
|
| 18 |
+
|
| 19 |
+
_log_styles = {
|
| 20 |
+
"DA3-Streaming": "bold green",
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_style(tag):
|
| 25 |
+
if tag in _log_styles.keys():
|
| 26 |
+
return _log_styles[tag]
|
| 27 |
+
return "bold blue"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def Log(*args, tag="DA3-Streaming"):
|
| 31 |
+
style = get_style(tag)
|
| 32 |
+
rich.print(f"[{style}]{tag}:[/{style}]", *args)
|
Depth-Anything-3/da3_streaming/loop_utils/loop_detector.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
import faiss
|
| 22 |
+
import torch
|
| 23 |
+
import torchvision.transforms as T
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from torch import nn
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
|
| 28 |
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 29 |
+
SALAD_ROOT = os.path.join(CURRENT_DIR, "salad")
|
| 30 |
+
if SALAD_ROOT not in sys.path:
|
| 31 |
+
sys.path.insert(0, SALAD_ROOT)
|
| 32 |
+
from loop_utils.salad.models import helper
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class VPRModel(nn.Module):
|
| 36 |
+
"""This is the main model for Visual Place Recognition
|
| 37 |
+
we use Pytorch Lightning for modularity purposes.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
pl (_type_): _description_
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
# ---- Backbone
|
| 46 |
+
backbone_arch="resnet50",
|
| 47 |
+
backbone_config={},
|
| 48 |
+
# ---- Aggregator
|
| 49 |
+
agg_arch="ConvAP",
|
| 50 |
+
agg_config={},
|
| 51 |
+
):
|
| 52 |
+
super().__init__()
|
| 53 |
+
|
| 54 |
+
# Backbone
|
| 55 |
+
self.encoder_arch = backbone_arch
|
| 56 |
+
self.backbone_config = backbone_config
|
| 57 |
+
|
| 58 |
+
# Aggregator
|
| 59 |
+
self.agg_arch = agg_arch
|
| 60 |
+
self.agg_config = agg_config
|
| 61 |
+
|
| 62 |
+
# ----------------------------------
|
| 63 |
+
# get the backbone and the aggregator
|
| 64 |
+
self.backbone = helper.get_backbone(backbone_arch, backbone_config)
|
| 65 |
+
self.aggregator = helper.get_aggregator(agg_arch, agg_config)
|
| 66 |
+
|
| 67 |
+
# the forward pass of the lightning model
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
x = self.backbone(x)
|
| 70 |
+
x = self.aggregator(x)
|
| 71 |
+
return x
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class LoopDetector:
|
| 75 |
+
"""Loop detector class for detecting loop closures in image sequences"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, image_dir, output="loop_closures.txt", config=None):
|
| 78 |
+
"""Initialize the loop detector
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
image_dir: Directory path containing images
|
| 82 |
+
ckpt_path: Model checkpoint path
|
| 83 |
+
image_size: Image resize dimensions [height width]
|
| 84 |
+
batch_size: Batch size for processing
|
| 85 |
+
similarity_threshold: Similarity threshold for loop closure
|
| 86 |
+
top_k: Number of nearest neighbors to check for each image
|
| 87 |
+
use_nms: Whether to use Non-Maximum Suppression (NMS) filtering
|
| 88 |
+
nms_threshold: NMS threshold for minimum frame difference between loop pairs
|
| 89 |
+
output: Output file path
|
| 90 |
+
"""
|
| 91 |
+
self.config = config
|
| 92 |
+
self.image_dir = image_dir
|
| 93 |
+
self.ckpt_path = self.config["Weights"]["SALAD"]
|
| 94 |
+
self.image_size = self.config["Loop"]["SALAD"]["image_size"]
|
| 95 |
+
self.batch_size = self.config["Loop"]["SALAD"]["batch_size"]
|
| 96 |
+
self.similarity_threshold = self.config["Loop"]["SALAD"]["similarity_threshold"]
|
| 97 |
+
self.top_k = self.config["Loop"]["SALAD"]["top_k"]
|
| 98 |
+
self.use_nms = self.config["Loop"]["SALAD"]["use_nms"]
|
| 99 |
+
self.nms_threshold = self.config["Loop"]["SALAD"]["nms_threshold"]
|
| 100 |
+
self.output = output
|
| 101 |
+
|
| 102 |
+
self.model = None
|
| 103 |
+
self.device = None
|
| 104 |
+
self.image_paths = None
|
| 105 |
+
self.descriptors = None
|
| 106 |
+
self.loop_closures = None
|
| 107 |
+
|
| 108 |
+
def _input_transform(self, image_size=None):
|
| 109 |
+
"""Create image transformation function"""
|
| 110 |
+
MEAN = [0.485, 0.456, 0.406]
|
| 111 |
+
STD = [0.229, 0.224, 0.225]
|
| 112 |
+
if image_size:
|
| 113 |
+
return T.Compose(
|
| 114 |
+
[
|
| 115 |
+
T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
|
| 116 |
+
T.ToTensor(),
|
| 117 |
+
T.Normalize(mean=MEAN, std=STD),
|
| 118 |
+
]
|
| 119 |
+
)
|
| 120 |
+
else:
|
| 121 |
+
return T.Compose([T.ToTensor(), T.Normalize(mean=MEAN, std=STD)])
|
| 122 |
+
|
| 123 |
+
def load_model(self):
|
| 124 |
+
"""Load model"""
|
| 125 |
+
model = VPRModel(
|
| 126 |
+
backbone_arch="dinov2_vitb14",
|
| 127 |
+
backbone_config={
|
| 128 |
+
"num_trainable_blocks": 4,
|
| 129 |
+
"return_token": True,
|
| 130 |
+
"norm_layer": True,
|
| 131 |
+
},
|
| 132 |
+
agg_arch="SALAD",
|
| 133 |
+
agg_config={
|
| 134 |
+
"num_channels": 768,
|
| 135 |
+
"num_clusters": 64,
|
| 136 |
+
"cluster_dim": 128,
|
| 137 |
+
"token_dim": 256,
|
| 138 |
+
},
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
model.load_state_dict(torch.load(self.ckpt_path))
|
| 142 |
+
model = model.eval()
|
| 143 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 144 |
+
model = model.to(device)
|
| 145 |
+
print(f"Model loaded: {self.ckpt_path}")
|
| 146 |
+
|
| 147 |
+
self.model = model
|
| 148 |
+
self.device = device
|
| 149 |
+
return model, device
|
| 150 |
+
|
| 151 |
+
def get_image_paths(self):
|
| 152 |
+
"""Get paths of all image files in directory"""
|
| 153 |
+
image_extensions = [".jpg", ".jpeg", ".png"]
|
| 154 |
+
image_paths = []
|
| 155 |
+
|
| 156 |
+
for ext in image_extensions:
|
| 157 |
+
image_paths.extend(list(Path(self.image_dir).glob(f"*{ext}")))
|
| 158 |
+
image_paths.extend(list(Path(self.image_dir).glob(f"*{ext.upper()}")))
|
| 159 |
+
|
| 160 |
+
image_paths = sorted(image_paths)
|
| 161 |
+
self.image_paths = image_paths
|
| 162 |
+
return image_paths
|
| 163 |
+
|
| 164 |
+
def extract_descriptors(self):
|
| 165 |
+
"""Extract image feature descriptors"""
|
| 166 |
+
if self.model is None or self.device is None:
|
| 167 |
+
self.load_model()
|
| 168 |
+
|
| 169 |
+
if self.image_paths is None:
|
| 170 |
+
self.get_image_paths()
|
| 171 |
+
|
| 172 |
+
transform = self._input_transform(self.image_size)
|
| 173 |
+
descriptors = []
|
| 174 |
+
|
| 175 |
+
for i in tqdm(
|
| 176 |
+
range(0, len(self.image_paths), self.batch_size), desc="Extracting features"
|
| 177 |
+
):
|
| 178 |
+
batch_paths = self.image_paths[i : i + self.batch_size]
|
| 179 |
+
batch_imgs = []
|
| 180 |
+
|
| 181 |
+
for path in batch_paths:
|
| 182 |
+
try:
|
| 183 |
+
img = Image.open(path).convert("RGB")
|
| 184 |
+
img = transform(img)
|
| 185 |
+
batch_imgs.append(img)
|
| 186 |
+
except Exception as e:
|
| 187 |
+
print(f"Error processing image {path}: {e}")
|
| 188 |
+
img = (
|
| 189 |
+
torch.zeros(3, 224, 224)
|
| 190 |
+
if self.image_size is None
|
| 191 |
+
else torch.zeros(3, self.image_size[0], self.image_size[1])
|
| 192 |
+
)
|
| 193 |
+
batch_imgs.append(img)
|
| 194 |
+
|
| 195 |
+
batch_tensor = torch.stack(batch_imgs).to(self.device)
|
| 196 |
+
|
| 197 |
+
with torch.no_grad():
|
| 198 |
+
with torch.autocast(
|
| 199 |
+
device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.float16
|
| 200 |
+
):
|
| 201 |
+
batch_descriptors = self.model(batch_tensor).cpu()
|
| 202 |
+
|
| 203 |
+
descriptors.append(batch_descriptors)
|
| 204 |
+
|
| 205 |
+
self.descriptors = torch.cat(descriptors)
|
| 206 |
+
return self.descriptors
|
| 207 |
+
|
| 208 |
+
def _apply_nms_filter(self, loop_closures, nms_threshold):
|
| 209 |
+
"""Apply Non-Maximum Suppression (NMS) filtering to loop pairs"""
|
| 210 |
+
if not loop_closures or nms_threshold <= 0:
|
| 211 |
+
return loop_closures
|
| 212 |
+
|
| 213 |
+
sorted_loops = sorted(loop_closures, key=lambda x: x[2], reverse=True)
|
| 214 |
+
filtered_loops = []
|
| 215 |
+
suppressed = set()
|
| 216 |
+
|
| 217 |
+
max_frame = max(max(idx1, idx2) for idx1, idx2, _ in loop_closures)
|
| 218 |
+
|
| 219 |
+
for idx1, idx2, sim in sorted_loops:
|
| 220 |
+
if idx1 in suppressed or idx2 in suppressed:
|
| 221 |
+
continue
|
| 222 |
+
|
| 223 |
+
filtered_loops.append((idx1, idx2, sim))
|
| 224 |
+
|
| 225 |
+
suppress_range = set()
|
| 226 |
+
|
| 227 |
+
start1 = max(0, idx1 - nms_threshold)
|
| 228 |
+
end1 = min(idx1 + nms_threshold + 1, idx2)
|
| 229 |
+
suppress_range.update(range(start1, end1))
|
| 230 |
+
|
| 231 |
+
start2 = max(idx1 + 1, idx2 - nms_threshold)
|
| 232 |
+
end2 = min(idx2 + nms_threshold + 1, max_frame + 1)
|
| 233 |
+
suppress_range.update(range(start2, end2))
|
| 234 |
+
|
| 235 |
+
suppressed.update(suppress_range)
|
| 236 |
+
|
| 237 |
+
return filtered_loops
|
| 238 |
+
|
| 239 |
+
def _ensure_decending_order(self, tuples_list):
|
| 240 |
+
return [(max(a, b), min(a, b), score) for a, b, score in tuples_list]
|
| 241 |
+
|
| 242 |
+
def find_loop_closures(self):
|
| 243 |
+
"""Find loop closures"""
|
| 244 |
+
if self.descriptors is None:
|
| 245 |
+
self.extract_descriptors()
|
| 246 |
+
|
| 247 |
+
embed_size = self.descriptors.shape[1]
|
| 248 |
+
faiss_index = faiss.IndexFlatIP(embed_size)
|
| 249 |
+
|
| 250 |
+
normalized_descriptors = self.descriptors.numpy()
|
| 251 |
+
faiss_index.add(normalized_descriptors)
|
| 252 |
+
|
| 253 |
+
similarities, indices = faiss_index.search(
|
| 254 |
+
normalized_descriptors, self.top_k + 1
|
| 255 |
+
) # +1 because self is most similar
|
| 256 |
+
|
| 257 |
+
loop_closures = []
|
| 258 |
+
for i in range(len(self.descriptors)):
|
| 259 |
+
# Skip first result (self)
|
| 260 |
+
for j in range(1, self.top_k + 1):
|
| 261 |
+
neighbor_idx = indices[i, j]
|
| 262 |
+
similarity = similarities[i, j]
|
| 263 |
+
|
| 264 |
+
if similarity > self.similarity_threshold and abs(i - neighbor_idx) > 10:
|
| 265 |
+
if i < neighbor_idx:
|
| 266 |
+
loop_closures.append((i, neighbor_idx, similarity))
|
| 267 |
+
else:
|
| 268 |
+
loop_closures.append((neighbor_idx, i, similarity))
|
| 269 |
+
|
| 270 |
+
loop_closures = list(set(loop_closures))
|
| 271 |
+
loop_closures.sort(key=lambda x: x[2], reverse=True)
|
| 272 |
+
|
| 273 |
+
if self.use_nms and self.nms_threshold > 0:
|
| 274 |
+
loop_closures = self._apply_nms_filter(loop_closures, self.nms_threshold)
|
| 275 |
+
|
| 276 |
+
self.loop_closures = self._ensure_decending_order(loop_closures)
|
| 277 |
+
return self.loop_closures
|
| 278 |
+
|
| 279 |
+
def save_results(self):
|
| 280 |
+
"""Save loop detection results to file"""
|
| 281 |
+
if self.loop_closures is None:
|
| 282 |
+
self.find_loop_closures()
|
| 283 |
+
|
| 284 |
+
with open(self.output, "w") as f:
|
| 285 |
+
f.write("# Loop Detection Results (index1, index2, similarity)\n")
|
| 286 |
+
if self.use_nms:
|
| 287 |
+
f.write(f"# NMS filtering applied, threshold: {self.nms_threshold}\n")
|
| 288 |
+
f.write("\n# Loop pairs:\n")
|
| 289 |
+
for i, j, sim in self.loop_closures:
|
| 290 |
+
f.write(f"{i}, {j}, {sim:.4f}\n")
|
| 291 |
+
f.write("\n# Image path list:\n")
|
| 292 |
+
for i, path in enumerate(self.image_paths):
|
| 293 |
+
f.write(f"# {i}: {path}\n")
|
| 294 |
+
|
| 295 |
+
print(f"Found {len(self.loop_closures)} loop pairs, results saved to {self.output}")
|
| 296 |
+
if self.use_nms:
|
| 297 |
+
print(f"NMS filtering applied, threshold: {self.nms_threshold}")
|
| 298 |
+
|
| 299 |
+
if self.loop_closures:
|
| 300 |
+
print("\nTop 10 loop pairs:")
|
| 301 |
+
for i, (idx1, idx2, sim) in enumerate(self.loop_closures[:10]):
|
| 302 |
+
print(f"{idx1}, {idx2}, similarity: {sim:.4f}")
|
| 303 |
+
if i >= 9:
|
| 304 |
+
break
|
| 305 |
+
|
| 306 |
+
def get_loop_list(self):
|
| 307 |
+
return [(idx1, idx2) for idx1, idx2, _ in self.loop_closures]
|
| 308 |
+
|
| 309 |
+
def run(self):
|
| 310 |
+
"""Run complete loop detection pipeline"""
|
| 311 |
+
print("Loading model...")
|
| 312 |
+
if self.model is None:
|
| 313 |
+
self.load_model()
|
| 314 |
+
|
| 315 |
+
self.get_image_paths()
|
| 316 |
+
if not self.image_paths:
|
| 317 |
+
print(f"No image files found in {self.image_dir}")
|
| 318 |
+
return
|
| 319 |
+
|
| 320 |
+
print(f"Found {len(self.image_paths)} image files")
|
| 321 |
+
|
| 322 |
+
self.extract_descriptors()
|
| 323 |
+
|
| 324 |
+
self.find_loop_closures()
|
| 325 |
+
|
| 326 |
+
self.save_results()
|
| 327 |
+
|
| 328 |
+
return self.loop_closures
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def main():
|
| 332 |
+
parser = argparse.ArgumentParser(description="Loop detection using SALAD model")
|
| 333 |
+
parser.add_argument(
|
| 334 |
+
"--image_dir",
|
| 335 |
+
type=str,
|
| 336 |
+
default="/media/deng/Data/KITTIdataset/data_odometry_color/dataset/sequences/00/image_2",
|
| 337 |
+
help="Directory path containing images",
|
| 338 |
+
)
|
| 339 |
+
parser.add_argument(
|
| 340 |
+
"--ckpt_path", type=str, default="./weights/dino_salad.ckpt", help="Model checkpoint path"
|
| 341 |
+
)
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
"--image_size",
|
| 344 |
+
nargs=2,
|
| 345 |
+
type=int,
|
| 346 |
+
default=[336, 336],
|
| 347 |
+
help="Image resize dimensions [height width]",
|
| 348 |
+
)
|
| 349 |
+
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing")
|
| 350 |
+
parser.add_argument(
|
| 351 |
+
"--similarity_threshold",
|
| 352 |
+
type=float,
|
| 353 |
+
default=0.7,
|
| 354 |
+
help="Similarity threshold for loop closure",
|
| 355 |
+
)
|
| 356 |
+
parser.add_argument(
|
| 357 |
+
"--top_k", type=int, default=5, help="Number of nearest neighbors to check for each image"
|
| 358 |
+
)
|
| 359 |
+
parser.add_argument("--output", type=str, default="loop_closures.txt", help="Output file path")
|
| 360 |
+
parser.add_argument(
|
| 361 |
+
"--use_nms",
|
| 362 |
+
action="store_true",
|
| 363 |
+
default=True,
|
| 364 |
+
help="Whether to use Non-Maximum Suppression (NMS) filtering",
|
| 365 |
+
)
|
| 366 |
+
parser.add_argument(
|
| 367 |
+
"--nms_threshold",
|
| 368 |
+
type=int,
|
| 369 |
+
default=25,
|
| 370 |
+
help="NMS threshold for minimum frame difference between loop pairs",
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
args = parser.parse_args()
|
| 374 |
+
|
| 375 |
+
detector = LoopDetector(
|
| 376 |
+
image_dir=args.image_dir,
|
| 377 |
+
ckpt_path=args.ckpt_path,
|
| 378 |
+
image_size=args.image_size,
|
| 379 |
+
batch_size=args.batch_size,
|
| 380 |
+
similarity_threshold=args.similarity_threshold,
|
| 381 |
+
top_k=args.top_k,
|
| 382 |
+
use_nms=args.use_nms,
|
| 383 |
+
nms_threshold=args.nms_threshold,
|
| 384 |
+
output=args.output,
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
detector.run()
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
if __name__ == "__main__":
|
| 391 |
+
main()
|
Depth-Anything-3/da3_streaming/loop_utils/loop_refinement.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
|
| 16 |
+
|
| 17 |
+
import numba as nb
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pypose as pp
|
| 20 |
+
import sim3solve
|
| 21 |
+
import torch
|
| 22 |
+
from einops import parse_shape, rearrange
|
| 23 |
+
from scipy.spatial.transform import Rotation as R
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def make_pypose_Sim3(rot, t, s):
|
| 27 |
+
q = R.from_matrix(rot).as_quat()
|
| 28 |
+
data = np.concatenate([t, q, np.array(s).reshape((1,))])
|
| 29 |
+
return pp.Sim3(data)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def SE3_to_Sim3(x: pp.SE3):
|
| 33 |
+
out = torch.cat((x.data, torch.ones_like(x.data[..., :1])), dim=-1)
|
| 34 |
+
return pp.Sim3(out)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@nb.njit(cache=True)
|
| 38 |
+
def _format(es):
|
| 39 |
+
return np.asarray(es, dtype=np.int64).reshape((-1, 2))[1:]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@nb.njit(cache=True)
|
| 43 |
+
def reduce_edges(flow_mag, ii, jj, max_num_edges, nms):
|
| 44 |
+
es = [(-1, -1)]
|
| 45 |
+
|
| 46 |
+
if ii.size == 0:
|
| 47 |
+
return _format(es)
|
| 48 |
+
|
| 49 |
+
Ni, Nj = (ii.max() + 1), (jj.max() + 1)
|
| 50 |
+
ignore_lookup = np.zeros((Ni, Nj), dtype=nb.bool_)
|
| 51 |
+
|
| 52 |
+
idxs = np.argsort(flow_mag)
|
| 53 |
+
for idx in idxs: # edge index
|
| 54 |
+
|
| 55 |
+
if len(es) > max_num_edges:
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
i = ii[idx]
|
| 59 |
+
j = jj[idx]
|
| 60 |
+
mag = flow_mag[idx]
|
| 61 |
+
|
| 62 |
+
if (j - i) < 30:
|
| 63 |
+
continue
|
| 64 |
+
|
| 65 |
+
if mag >= 1000: # i.e., inf
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
if ignore_lookup[i, j]:
|
| 69 |
+
continue
|
| 70 |
+
|
| 71 |
+
es.append((i, j))
|
| 72 |
+
|
| 73 |
+
for di in range(-nms, nms + 1):
|
| 74 |
+
i1 = i + di
|
| 75 |
+
|
| 76 |
+
if 0 <= i1 < Ni:
|
| 77 |
+
ignore_lookup[i1, j] = True
|
| 78 |
+
|
| 79 |
+
return _format(es)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@nb.njit(cache=True)
|
| 83 |
+
def umeyama_alignment(x: np.ndarray, y: np.ndarray):
|
| 84 |
+
"""
|
| 85 |
+
The following function was copied from:
|
| 86 |
+
https://github.com/MichaelGrupp/evo/blob/3067541b350528fe46375423e5bc3a7c42c06c63/evo/core/geometry.py#L35
|
| 87 |
+
|
| 88 |
+
Computes the least squares solution parameters of an Sim(m) matrix
|
| 89 |
+
that minimizes the distance between a set of registered points.
|
| 90 |
+
Umeyama, Shinji: Least-squares estimation of transformation parameters
|
| 91 |
+
between two point patterns. IEEE PAMI, 1991
|
| 92 |
+
:param x: mxn matrix of points, m = dimension, n = nr. of data points
|
| 93 |
+
:param y: mxn matrix of points, m = dimension, n = nr. of data points
|
| 94 |
+
:param with_scale: set to True to align also the scale (default: 1.0 scale)
|
| 95 |
+
:return: r, t, c - rotation matrix, translation vector and scale factor
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
# m = dimension, n = nr. of data points
|
| 99 |
+
m, n = x.shape
|
| 100 |
+
|
| 101 |
+
# means, eq. 34 and 35
|
| 102 |
+
mean_x = x.sum(axis=1) / n
|
| 103 |
+
mean_y = y.sum(axis=1) / n
|
| 104 |
+
|
| 105 |
+
# variance, eq. 36
|
| 106 |
+
# "transpose" for column subtraction
|
| 107 |
+
sigma_x = 1.0 / n * (np.linalg.norm(x - mean_x[:, np.newaxis]) ** 2)
|
| 108 |
+
|
| 109 |
+
# covariance matrix, eq. 38
|
| 110 |
+
outer_sum = np.zeros((m, m))
|
| 111 |
+
for i in range(n):
|
| 112 |
+
outer_sum += np.outer((y[:, i] - mean_y), (x[:, i] - mean_x))
|
| 113 |
+
cov_xy = np.multiply(1.0 / n, outer_sum)
|
| 114 |
+
|
| 115 |
+
# SVD (text betw. eq. 38 and 39)
|
| 116 |
+
u, d, v = np.linalg.svd(cov_xy)
|
| 117 |
+
if np.count_nonzero(d > np.finfo(d.dtype).eps) < m - 1:
|
| 118 |
+
return None, None, None # Degenerate covariance rank, Umeyama alignment is not possible
|
| 119 |
+
|
| 120 |
+
# S matrix, eq. 43
|
| 121 |
+
s = np.eye(m)
|
| 122 |
+
if np.linalg.det(u) * np.linalg.det(v) < 0.0:
|
| 123 |
+
# Ensure a RHS coordinate system (Kabsch algorithm).
|
| 124 |
+
s[m - 1, m - 1] = -1
|
| 125 |
+
|
| 126 |
+
# rotation, eq. 40
|
| 127 |
+
r = u.dot(s).dot(v)
|
| 128 |
+
|
| 129 |
+
# scale & translation, eq. 42 and 41
|
| 130 |
+
c = 1 / sigma_x * np.trace(np.diag(d).dot(s))
|
| 131 |
+
t = mean_y - np.multiply(c, r.dot(mean_x))
|
| 132 |
+
|
| 133 |
+
return r, t, c
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@nb.njit(cache=True)
|
| 137 |
+
def ransac_umeyama(src_points, dst_points, iterations=1, threshold=0.1):
|
| 138 |
+
best_inliers = 0
|
| 139 |
+
best_R = None
|
| 140 |
+
best_t = None
|
| 141 |
+
best_s = None
|
| 142 |
+
for _ in range(iterations):
|
| 143 |
+
# Randomly select three points
|
| 144 |
+
indices = np.random.choice(src_points.shape[0], 3, replace=False)
|
| 145 |
+
src_sample = src_points[indices]
|
| 146 |
+
dst_sample = dst_points[indices]
|
| 147 |
+
|
| 148 |
+
# Estimate transformation
|
| 149 |
+
R, t, s = umeyama_alignment(src_sample.T, dst_sample.T)
|
| 150 |
+
if t is None:
|
| 151 |
+
continue
|
| 152 |
+
|
| 153 |
+
# Apply transformation
|
| 154 |
+
transformed = (src_points @ (R * s).T) + t
|
| 155 |
+
|
| 156 |
+
# Count inliers (not ideal because depends on scene scale)
|
| 157 |
+
distances = np.sum((transformed - dst_points) ** 2, axis=1) ** 0.5
|
| 158 |
+
inlier_mask = distances < threshold
|
| 159 |
+
inliers = np.sum(inlier_mask)
|
| 160 |
+
|
| 161 |
+
# Update best transformation
|
| 162 |
+
if inliers > best_inliers:
|
| 163 |
+
best_inliers = inliers
|
| 164 |
+
best_R, best_t, best_s = umeyama_alignment(
|
| 165 |
+
src_points[inlier_mask].T, dst_points[inlier_mask].T
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
return best_R, best_t, best_s, best_inliers
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def batch_jacobian(func, x):
|
| 172 |
+
def _func_sum(*x):
|
| 173 |
+
return func(*x).sum(dim=0)
|
| 174 |
+
|
| 175 |
+
_, b, c = torch.autograd.functional.jacobian(_func_sum, x, vectorize=True)
|
| 176 |
+
return rearrange(torch.stack((b, c)), "N O B I -> N B O I", N=2)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _residual(C, Gi, Gj):
|
| 180 |
+
assert parse_shape(C, "N _") == parse_shape(Gi, "N _") == parse_shape(Gj, "N _")
|
| 181 |
+
out = C @ pp.Exp(Gi) @ pp.Exp(Gj).Inv()
|
| 182 |
+
return out.Log().tensor()
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def residual(Ginv, input_poses, dSloop, ii, jj, jacobian=False):
|
| 186 |
+
|
| 187 |
+
# prep
|
| 188 |
+
device = Ginv.device
|
| 189 |
+
assert parse_shape(input_poses, "_ d") == dict(d=7)
|
| 190 |
+
pred_inv_poses = SE3_to_Sim3(input_poses).Inv()
|
| 191 |
+
|
| 192 |
+
# free variables
|
| 193 |
+
n, _ = pred_inv_poses.shape
|
| 194 |
+
kk = torch.arange(1, n, device=device)
|
| 195 |
+
ll = kk - 1
|
| 196 |
+
|
| 197 |
+
# constants
|
| 198 |
+
Ti = pred_inv_poses[kk]
|
| 199 |
+
Tj = pred_inv_poses[ll]
|
| 200 |
+
dSij = Tj @ Ti.Inv()
|
| 201 |
+
|
| 202 |
+
constants = torch.cat((dSij, dSloop), dim=0)
|
| 203 |
+
iii = torch.cat((kk, ii))
|
| 204 |
+
jjj = torch.cat((ll, jj))
|
| 205 |
+
resid = _residual(constants, Ginv[iii], Ginv[jjj])
|
| 206 |
+
|
| 207 |
+
if not jacobian:
|
| 208 |
+
return resid
|
| 209 |
+
|
| 210 |
+
J_Ginv_i, J_Ginv_j = batch_jacobian(_residual, (constants, Ginv[iii], Ginv[jjj]))
|
| 211 |
+
return resid, (J_Ginv_i, J_Ginv_j, iii, jjj)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def perform_updates(
|
| 215 |
+
input_poses, dSloop, ii_loop, jj_loop, iters=30, ep=0.0, lmbda=1e-6, fix_opt_window=False
|
| 216 |
+
):
|
| 217 |
+
"""Run the Levenberg Marquardt algorithm"""
|
| 218 |
+
|
| 219 |
+
input_poses = input_poses.clone()
|
| 220 |
+
|
| 221 |
+
if fix_opt_window:
|
| 222 |
+
freen = torch.cat((ii_loop, jj_loop)).max().item() + 1
|
| 223 |
+
else:
|
| 224 |
+
freen = -1
|
| 225 |
+
|
| 226 |
+
Ginv = SE3_to_Sim3(input_poses).Inv().Log()
|
| 227 |
+
|
| 228 |
+
residual_history = []
|
| 229 |
+
|
| 230 |
+
for itr in range(iters):
|
| 231 |
+
resid, (J_Ginv_i, J_Ginv_j, iii, jjj) = residual(
|
| 232 |
+
Ginv, input_poses, dSloop, ii_loop, jj_loop, jacobian=True
|
| 233 |
+
)
|
| 234 |
+
residual_history.append(resid.square().mean().item())
|
| 235 |
+
print(f"resid: {resid.square().mean().item()}")
|
| 236 |
+
(delta_pose,) = sim3solve.solve_system(
|
| 237 |
+
J_Ginv_i, J_Ginv_j, iii, jjj, resid, ep, lmbda, freen
|
| 238 |
+
)
|
| 239 |
+
assert Ginv.shape == delta_pose.shape
|
| 240 |
+
Ginv_tmp = Ginv + delta_pose
|
| 241 |
+
|
| 242 |
+
new_resid = residual(Ginv_tmp, input_poses, dSloop, ii_loop, jj_loop)
|
| 243 |
+
if new_resid.square().mean() < residual_history[-1]:
|
| 244 |
+
Ginv = Ginv_tmp
|
| 245 |
+
lmbda /= 2
|
| 246 |
+
else:
|
| 247 |
+
lmbda *= 2
|
| 248 |
+
|
| 249 |
+
if (
|
| 250 |
+
(residual_history[-1] < 1e-5)
|
| 251 |
+
and (itr >= 4)
|
| 252 |
+
and ((residual_history[-5] / residual_history[-1]) < 1.5)
|
| 253 |
+
):
|
| 254 |
+
break
|
| 255 |
+
|
| 256 |
+
return pp.Exp(Ginv).Inv()
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def pose_refinement(pred_poses, loop_poses, loop_ii, loop_jj):
|
| 260 |
+
|
| 261 |
+
final_est = perform_updates(pred_poses, loop_poses, loop_ii, loop_jj, iters=30)
|
| 262 |
+
|
| 263 |
+
safe_i = loop_ii.max().item() + 1
|
| 264 |
+
aa = SE3_to_Sim3(pred_poses.cpu())
|
| 265 |
+
final_est = (aa[[safe_i]] * final_est[[safe_i]].Inv()) * final_est
|
| 266 |
+
output = final_est[:safe_i]
|
| 267 |
+
|
| 268 |
+
return output
|
Depth-Anything-3/da3_streaming/loop_utils/sim3loop.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
|
| 16 |
+
|
| 17 |
+
import time
|
| 18 |
+
from typing import List, Tuple
|
| 19 |
+
import numpy as np
|
| 20 |
+
import pypose as pp
|
| 21 |
+
import torch
|
| 22 |
+
from fastloop.solve_python import solve_system_py
|
| 23 |
+
from scipy.spatial.transform import Rotation as R
|
| 24 |
+
|
| 25 |
+
cpp_version = False
|
| 26 |
+
try:
|
| 27 |
+
import sim3solve
|
| 28 |
+
|
| 29 |
+
cpp_version = True
|
| 30 |
+
except Exception:
|
| 31 |
+
print("Sim3solve of C++ Version failed, Will using Python Version.")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Sim3LoopOptimizer:
|
| 35 |
+
"""
|
| 36 |
+
Loop closure optimizer for sequences of Sim3 transformations
|
| 37 |
+
|
| 38 |
+
Input:
|
| 39 |
+
- sequential_transforms: List[Tuple[float, np.ndarray, np.ndarray]]
|
| 40 |
+
Each element is (s, R, t), where s is scalar scale, R is [3,3] rotation matrix,
|
| 41 |
+
t is [3,] translation vector
|
| 42 |
+
- loop_constraints: List[Tuple[int, int, Tuple[float, np.ndarray, np.ndarray]]]
|
| 43 |
+
Each element is (i, j, (s, R, t)), representing a loop closure constraint
|
| 44 |
+
from frame i to frame j
|
| 45 |
+
|
| 46 |
+
Output:
|
| 47 |
+
- Optimized sequential_transforms
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, config, device="cpu"):
|
| 51 |
+
self.device = device
|
| 52 |
+
self.config = config
|
| 53 |
+
self.solve_system_version = self.config["Loop"]["SIM3_Optimizer"][
|
| 54 |
+
"lang_version"
|
| 55 |
+
] # choose between 'python' and 'cpp'
|
| 56 |
+
|
| 57 |
+
if not cpp_version:
|
| 58 |
+
self.solve_system_version = "python"
|
| 59 |
+
|
| 60 |
+
def numpy_to_pypose_sim3(self, s: float, R_mat: np.ndarray, t_vec: np.ndarray) -> pp.Sim3:
|
| 61 |
+
"""Convert numpy s,R,t to pypose Sim3"""
|
| 62 |
+
q = R.from_matrix(R_mat).as_quat() # [x,y,z,w]
|
| 63 |
+
# pypose requires [t, q, s] format
|
| 64 |
+
data = np.concatenate([t_vec, q, np.array([s])])
|
| 65 |
+
return pp.Sim3(torch.from_numpy(data).float().to(self.device))
|
| 66 |
+
|
| 67 |
+
def pypose_sim3_to_numpy(self, sim3: pp.Sim3) -> Tuple[float, np.ndarray, np.ndarray]:
|
| 68 |
+
"""Convert pypose Sim3 to numpy s,R,t"""
|
| 69 |
+
data = sim3.data.cpu().numpy()
|
| 70 |
+
t = data[:3]
|
| 71 |
+
q = data[3:7] # [x,y,z,w]
|
| 72 |
+
s = data[7]
|
| 73 |
+
R_mat = R.from_quat(q).as_matrix()
|
| 74 |
+
return s, R_mat, t
|
| 75 |
+
|
| 76 |
+
def sequential_to_absolute_poses(
|
| 77 |
+
self, sequential_transforms: List[Tuple[float, np.ndarray, np.ndarray]]
|
| 78 |
+
) -> torch.Tensor:
|
| 79 |
+
"""
|
| 80 |
+
Convert sequential relative transforms to absolute pose sequence
|
| 81 |
+
S_01, S_12, S_23, ... -> T_0, T_1, T_2, T_3, ...
|
| 82 |
+
Where T_i is the transform from world coordinate to frame i
|
| 83 |
+
"""
|
| 84 |
+
len(sequential_transforms) + 1
|
| 85 |
+
poses = []
|
| 86 |
+
|
| 87 |
+
identity = pp.Sim3(
|
| 88 |
+
torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], device=self.device)
|
| 89 |
+
)
|
| 90 |
+
poses.append(identity)
|
| 91 |
+
|
| 92 |
+
current_pose = identity
|
| 93 |
+
for s, R_mat, t_vec in sequential_transforms:
|
| 94 |
+
rel_transform = self.numpy_to_pypose_sim3(s, R_mat, t_vec)
|
| 95 |
+
current_pose = current_pose @ rel_transform
|
| 96 |
+
poses.append(current_pose)
|
| 97 |
+
|
| 98 |
+
return torch.stack(poses)
|
| 99 |
+
|
| 100 |
+
def absolute_to_sequential_transforms(
|
| 101 |
+
self, absolute_poses: pp.Sim3
|
| 102 |
+
) -> List[Tuple[float, np.ndarray, np.ndarray]]:
|
| 103 |
+
"""
|
| 104 |
+
Convert absolute pose sequence back to sequential relative transforms
|
| 105 |
+
T_0, T_1, T_2, ... -> S_01, S_12, S_23, ...
|
| 106 |
+
"""
|
| 107 |
+
sequential_transforms = []
|
| 108 |
+
n = absolute_poses.shape[0]
|
| 109 |
+
|
| 110 |
+
for i in range(n - 1):
|
| 111 |
+
rel_transform = absolute_poses[i].Inv() @ absolute_poses[i + 1]
|
| 112 |
+
s, R_mat, t_vec = self.pypose_sim3_to_numpy(rel_transform)
|
| 113 |
+
sequential_transforms.append((s, R_mat, t_vec))
|
| 114 |
+
|
| 115 |
+
return sequential_transforms
|
| 116 |
+
|
| 117 |
+
def SE3_to_Sim3(self, x: torch.Tensor) -> pp.Sim3:
|
| 118 |
+
"""Convert SE3 to Sim3 (add unit scale)"""
|
| 119 |
+
ones = torch.ones_like(x[..., :1])
|
| 120 |
+
out = torch.cat((x, ones), dim=-1)
|
| 121 |
+
return pp.Sim3(out)
|
| 122 |
+
|
| 123 |
+
def build_loop_constraints(
|
| 124 |
+
self, loop_constraints: List[Tuple[int, int, Tuple[float, np.ndarray, np.ndarray]]]
|
| 125 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 126 |
+
"""Build loop closure constraints"""
|
| 127 |
+
if not loop_constraints:
|
| 128 |
+
return (
|
| 129 |
+
torch.empty(0, 8, device=self.device),
|
| 130 |
+
torch.empty(0, dtype=torch.long),
|
| 131 |
+
torch.empty(0, dtype=torch.long),
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
loop_transforms = []
|
| 135 |
+
ii_loop = []
|
| 136 |
+
jj_loop = []
|
| 137 |
+
|
| 138 |
+
for i, j, (s, R_mat, t_vec) in loop_constraints:
|
| 139 |
+
loop_sim3 = self.numpy_to_pypose_sim3(s, R_mat, t_vec)
|
| 140 |
+
loop_transforms.append(loop_sim3.data)
|
| 141 |
+
ii_loop.append(i)
|
| 142 |
+
jj_loop.append(j)
|
| 143 |
+
|
| 144 |
+
dSloop = pp.Sim3(torch.stack(loop_transforms))
|
| 145 |
+
ii_loop = torch.tensor(ii_loop, dtype=torch.long, device=self.device)
|
| 146 |
+
jj_loop = torch.tensor(jj_loop, dtype=torch.long, device=self.device)
|
| 147 |
+
|
| 148 |
+
return dSloop, ii_loop, jj_loop
|
| 149 |
+
|
| 150 |
+
def residual(self, Ginv, input_poses, dSloop, ii, jj, jacobian=False):
|
| 151 |
+
"""Compute residuals (modified from original code)"""
|
| 152 |
+
|
| 153 |
+
def _residual(C, Gi, Gj):
|
| 154 |
+
out = C @ pp.Exp(Gi) @ pp.Exp(Gj).Inv()
|
| 155 |
+
return out.Log().tensor()
|
| 156 |
+
|
| 157 |
+
pred_inv_poses = pp.Sim3(input_poses).Inv()
|
| 158 |
+
|
| 159 |
+
n, _ = pred_inv_poses.shape
|
| 160 |
+
if n > 1:
|
| 161 |
+
kk = torch.arange(1, n, device=self.device)
|
| 162 |
+
ll = kk - 1
|
| 163 |
+
Ti = pred_inv_poses[kk]
|
| 164 |
+
Tj = pred_inv_poses[ll]
|
| 165 |
+
dSij = Tj @ Ti.Inv()
|
| 166 |
+
else:
|
| 167 |
+
kk = torch.empty(0, dtype=torch.long, device=self.device)
|
| 168 |
+
ll = torch.empty(0, dtype=torch.long, device=self.device)
|
| 169 |
+
dSij = pp.Sim3(torch.empty(0, 8, device=self.device))
|
| 170 |
+
|
| 171 |
+
constants = (
|
| 172 |
+
torch.cat((dSij.data, dSloop.data), dim=0) if dSloop.shape[0] > 0 else dSij.data
|
| 173 |
+
)
|
| 174 |
+
if constants.shape[0] > 0:
|
| 175 |
+
constants = pp.Sim3(constants)
|
| 176 |
+
iii = torch.cat((kk, ii))
|
| 177 |
+
jjj = torch.cat((ll, jj))
|
| 178 |
+
resid = _residual(constants, Ginv[iii], Ginv[jjj])
|
| 179 |
+
else:
|
| 180 |
+
iii = torch.empty(0, dtype=torch.long, device=self.device)
|
| 181 |
+
jjj = torch.empty(0, dtype=torch.long, device=self.device)
|
| 182 |
+
resid = torch.empty(0, device=self.device)
|
| 183 |
+
|
| 184 |
+
if not jacobian:
|
| 185 |
+
return resid
|
| 186 |
+
|
| 187 |
+
if constants.shape[0] > 0:
|
| 188 |
+
|
| 189 |
+
def batch_jacobian(func, x):
|
| 190 |
+
def _func_sum(*x):
|
| 191 |
+
return func(*x).sum(dim=0)
|
| 192 |
+
|
| 193 |
+
_, b, c = torch.autograd.functional.jacobian(_func_sum, x, vectorize=True)
|
| 194 |
+
from einops import rearrange
|
| 195 |
+
|
| 196 |
+
return rearrange(torch.stack((b, c)), "N O B I -> N B O I", N=2)
|
| 197 |
+
|
| 198 |
+
J_Ginv_i, J_Ginv_j = batch_jacobian(_residual, (constants, Ginv[iii], Ginv[jjj]))
|
| 199 |
+
else:
|
| 200 |
+
J_Ginv_i = torch.empty(0, device=self.device)
|
| 201 |
+
J_Ginv_j = torch.empty(0, device=self.device)
|
| 202 |
+
|
| 203 |
+
return resid, (J_Ginv_i, J_Ginv_j, iii, jjj)
|
| 204 |
+
|
| 205 |
+
def optimize(
|
| 206 |
+
self,
|
| 207 |
+
sequential_transforms: List[Tuple[float, np.ndarray, np.ndarray]],
|
| 208 |
+
loop_constraints: List[Tuple[int, int, Tuple[float, np.ndarray, np.ndarray]]],
|
| 209 |
+
max_iterations: int = None,
|
| 210 |
+
lambda_init: float = None,
|
| 211 |
+
) -> List[Tuple[float, np.ndarray, np.ndarray]]:
|
| 212 |
+
"""
|
| 213 |
+
Main optimization function
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
sequential_transforms: Input sequence of transforms
|
| 217 |
+
loop_constraints: List of loop closure constraints
|
| 218 |
+
max_iterations: Maximum iterations
|
| 219 |
+
lambda_init: Initial lambda for L-M algorithm
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
Optimized sequence of transforms
|
| 223 |
+
"""
|
| 224 |
+
if max_iterations is None:
|
| 225 |
+
max_iterations = self.config["Loop"]["SIM3_Optimizer"]["max_iterations"]
|
| 226 |
+
if lambda_init is None:
|
| 227 |
+
lambda_init = eval(self.config["Loop"]["SIM3_Optimizer"]["lambda_init"])
|
| 228 |
+
|
| 229 |
+
input_poses = self.sequential_to_absolute_poses(sequential_transforms)
|
| 230 |
+
|
| 231 |
+
dSloop, ii_loop, jj_loop = self.build_loop_constraints(loop_constraints)
|
| 232 |
+
|
| 233 |
+
if len(loop_constraints) == 0:
|
| 234 |
+
print("Warning: No loop constraints provided, returning original transforms")
|
| 235 |
+
return sequential_transforms
|
| 236 |
+
|
| 237 |
+
Ginv = pp.Sim3(input_poses).Inv().Log()
|
| 238 |
+
lmbda = lambda_init
|
| 239 |
+
residual_history = []
|
| 240 |
+
|
| 241 |
+
print(
|
| 242 |
+
f"Starting optimization with {len(sequential_transforms)} poses \
|
| 243 |
+
and {len(loop_constraints)} loop constraints"
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# L-M loop
|
| 247 |
+
for itr in range(max_iterations):
|
| 248 |
+
resid, (J_Ginv_i, J_Ginv_j, iii, jjj) = self.residual(
|
| 249 |
+
Ginv, input_poses, dSloop, ii_loop, jj_loop, jacobian=True
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
if resid.numel() == 0:
|
| 253 |
+
print("No residuals to optimize")
|
| 254 |
+
break
|
| 255 |
+
|
| 256 |
+
current_cost = resid.square().mean().item()
|
| 257 |
+
residual_history.append(current_cost)
|
| 258 |
+
|
| 259 |
+
try: # Solve linear system
|
| 260 |
+
begin_time = time.time()
|
| 261 |
+
if self.solve_system_version == "cpp":
|
| 262 |
+
(delta_pose,) = sim3solve.solve_system(
|
| 263 |
+
J_Ginv_i, J_Ginv_j, iii, jjj, resid, 0.0, lmbda, -1
|
| 264 |
+
)
|
| 265 |
+
elif self.solve_system_version == "python":
|
| 266 |
+
delta_pose = solve_system_py(
|
| 267 |
+
J_Ginv_i, J_Ginv_j, iii, jjj, resid, 0.0, lmbda, -1
|
| 268 |
+
)
|
| 269 |
+
else:
|
| 270 |
+
print("Solver version has not been chosen! ('python' or 'cpp')")
|
| 271 |
+
end_time = time.time()
|
| 272 |
+
except Exception as e:
|
| 273 |
+
print(f"Solver failed at iteration {itr}: {e}")
|
| 274 |
+
break
|
| 275 |
+
|
| 276 |
+
Ginv_tmp = Ginv + delta_pose
|
| 277 |
+
|
| 278 |
+
new_resid = self.residual(Ginv_tmp, input_poses, dSloop, ii_loop, jj_loop)
|
| 279 |
+
new_cost = new_resid.square().mean().item() if new_resid.numel() > 0 else float("inf")
|
| 280 |
+
|
| 281 |
+
# L-M
|
| 282 |
+
if new_cost < current_cost:
|
| 283 |
+
Ginv = Ginv_tmp
|
| 284 |
+
lmbda /= 2
|
| 285 |
+
print(
|
| 286 |
+
f"Iteration {itr}: cost {current_cost:.14f} -> {new_cost:.14f} (accepted)",
|
| 287 |
+
end=" | ",
|
| 288 |
+
)
|
| 289 |
+
else:
|
| 290 |
+
lmbda *= 2
|
| 291 |
+
print(
|
| 292 |
+
f"Iteration {itr}: cost {current_cost:.14f} -> {new_cost:.14f} (rej) ",
|
| 293 |
+
end=" | ",
|
| 294 |
+
) # more readible to accepted
|
| 295 |
+
|
| 296 |
+
print(
|
| 297 |
+
f"Time of solver ({self.solve_system_version}): \
|
| 298 |
+
{(end_time - begin_time)*1000:.4f} ms"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
if (current_cost < 1e-5) and (itr >= 4):
|
| 302 |
+
if len(residual_history) >= 5:
|
| 303 |
+
improvement_ratio = residual_history[-5] / residual_history[-1]
|
| 304 |
+
if improvement_ratio < 1.5:
|
| 305 |
+
print(f"Converged at iteration {itr}")
|
| 306 |
+
break
|
| 307 |
+
|
| 308 |
+
optimized_absolute_poses = pp.Exp(Ginv).Inv()
|
| 309 |
+
|
| 310 |
+
optimized_sequential = self.absolute_to_sequential_transforms(optimized_absolute_poses)
|
| 311 |
+
|
| 312 |
+
print(
|
| 313 |
+
f"Optimization completed. Final cost: \
|
| 314 |
+
{residual_history[-1] if residual_history else 'N/A'}"
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
return optimized_sequential
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# ======== TEST CODE ========
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def create_ring_transforms(num_poses=6, radius=5.0, rot_noise_deg=2.0):
|
| 324 |
+
"""Generate a ring of Sim3 transforms with rotation, adding slight rotational noise"""
|
| 325 |
+
transforms = []
|
| 326 |
+
angle_step = 2 * np.pi / num_poses
|
| 327 |
+
|
| 328 |
+
for i in range(num_poses):
|
| 329 |
+
angle = angle_step
|
| 330 |
+
|
| 331 |
+
# Main rotation (around Z-axis)
|
| 332 |
+
R_z = R.from_euler("z", angle, degrees=False)
|
| 333 |
+
|
| 334 |
+
# Add slight rotational noise (Gaussian noise in degrees)
|
| 335 |
+
noise_angles_deg = np.random.normal(loc=0.0, scale=rot_noise_deg, size=3)
|
| 336 |
+
R_noise = R.from_euler("xyz", noise_angles_deg, degrees=True)
|
| 337 |
+
|
| 338 |
+
# Combine rotations
|
| 339 |
+
R_mat = (R_noise * R_z).as_matrix()
|
| 340 |
+
|
| 341 |
+
# Translation: simulate a circular trajectory
|
| 342 |
+
t = np.array([radius * np.sin(angle), radius * (1 - np.cos(angle)), 0.0])
|
| 343 |
+
|
| 344 |
+
s = np.random.uniform(0.8, 1.2)
|
| 345 |
+
|
| 346 |
+
transforms.append((s, R_mat, t))
|
| 347 |
+
|
| 348 |
+
return transforms
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def example_usage():
|
| 352 |
+
optimizer = Sim3LoopOptimizer(solve_system_version="cpp")
|
| 353 |
+
|
| 354 |
+
# Build rotating ring
|
| 355 |
+
sequential_transforms = create_ring_transforms(num_poses=20, radius=3.0)
|
| 356 |
+
|
| 357 |
+
# Add loop closure constraint: from frame 5 back to frame 0
|
| 358 |
+
loop_constraints = [
|
| 359 |
+
(20, 0, (1.0, np.eye(3), np.zeros(3))) # Temporary unit loop for simulation
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
+
# Trajectory before/after optimization
|
| 363 |
+
input_abs_poses = optimizer.sequential_to_absolute_poses(sequential_transforms)
|
| 364 |
+
optimized_transforms = optimizer.optimize(sequential_transforms, loop_constraints)
|
| 365 |
+
optimized_abs_poses = optimizer.sequential_to_absolute_poses(optimized_transforms)
|
| 366 |
+
|
| 367 |
+
def extract_xyz(pose_tensor):
|
| 368 |
+
poses = pose_tensor.cpu().numpy()
|
| 369 |
+
return poses[:, 0], poses[:, 1], poses[:, 2]
|
| 370 |
+
|
| 371 |
+
x0, y0, z0 = extract_xyz(input_abs_poses)
|
| 372 |
+
x1, y1, z1 = extract_xyz(optimized_abs_poses)
|
| 373 |
+
|
| 374 |
+
# Visualize trajectory
|
| 375 |
+
import matplotlib
|
| 376 |
+
import matplotlib.pyplot as plt
|
| 377 |
+
|
| 378 |
+
matplotlib.use("Agg")
|
| 379 |
+
|
| 380 |
+
plt.figure(figsize=(8, 6))
|
| 381 |
+
plt.plot(x0, y0, "o--", label="Before Optimization")
|
| 382 |
+
plt.plot(x1, y1, "o-", label="After Optimization")
|
| 383 |
+
for i, j, _ in loop_constraints:
|
| 384 |
+
plt.plot([x0[i], x0[j]], [y0[i], y0[j]], "r--", label="Loop (Before)" if i == 5 else "")
|
| 385 |
+
plt.plot([x1[i], x1[j]], [y1[i], y1[j]], "g-", label="Loop (After)" if i == 5 else "")
|
| 386 |
+
plt.gca().set_aspect("equal")
|
| 387 |
+
plt.title("Sim3 Loop Closure Optimization (Rotating Ring)")
|
| 388 |
+
plt.xlabel("x")
|
| 389 |
+
plt.ylabel("y")
|
| 390 |
+
plt.legend()
|
| 391 |
+
plt.grid(True)
|
| 392 |
+
plt.axis("equal")
|
| 393 |
+
plt.show()
|
| 394 |
+
|
| 395 |
+
return optimized_transforms
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
if __name__ == "__main__":
|
| 399 |
+
example_usage()
|
Depth-Anything-3/da3_streaming/loop_utils/sim3utils.py
ADDED
|
@@ -0,0 +1,1261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
|
| 16 |
+
|
| 17 |
+
import bisect
|
| 18 |
+
import glob
|
| 19 |
+
import os
|
| 20 |
+
import numpy as np
|
| 21 |
+
import trimesh
|
| 22 |
+
from loop_utils.alignment_torch import robust_weighted_estimate_sim3_torch
|
| 23 |
+
from loop_utils.alignment_triton import robust_weighted_estimate_sim3_triton
|
| 24 |
+
from numba import njit
|
| 25 |
+
from sklearn.linear_model import LinearRegression, RANSACRegressor
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def accumulate_sim3_transforms(transforms):
|
| 29 |
+
"""
|
| 30 |
+
Accumulate adjacent SIM(3) transforms into transforms
|
| 31 |
+
from the initial frame to each subsequent frame.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
transforms: list, each element is a tuple (R, s, t)
|
| 35 |
+
R: 3x3 rotation matrix (np.array)
|
| 36 |
+
s: scale factor (scalar)
|
| 37 |
+
t: 3x1 translation vector (np.array)
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Cumulative transforms list, each element is (R_cum, s_cum, t_cum)
|
| 41 |
+
representing the transform from frame 0 to frame k
|
| 42 |
+
"""
|
| 43 |
+
if not transforms:
|
| 44 |
+
return []
|
| 45 |
+
|
| 46 |
+
cumulative_transforms = [transforms[0]]
|
| 47 |
+
|
| 48 |
+
for i in range(1, len(transforms)):
|
| 49 |
+
s_cum_prev, R_cum_prev, t_cum_prev = cumulative_transforms[i - 1]
|
| 50 |
+
s_next, R_next, t_next = transforms[i]
|
| 51 |
+
R_cum_new = R_cum_prev @ R_next
|
| 52 |
+
s_cum_new = s_cum_prev * s_next
|
| 53 |
+
t_cum_new = s_cum_prev * (R_cum_prev @ t_next) + t_cum_prev
|
| 54 |
+
cumulative_transforms.append((s_cum_new, R_cum_new, t_cum_new))
|
| 55 |
+
|
| 56 |
+
return cumulative_transforms
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def estimate_sim3(source_points, target_points):
|
| 60 |
+
mu_src = np.mean(source_points, axis=0)
|
| 61 |
+
mu_tgt = np.mean(target_points, axis=0)
|
| 62 |
+
|
| 63 |
+
src_centered = source_points - mu_src
|
| 64 |
+
tgt_centered = target_points - mu_tgt
|
| 65 |
+
|
| 66 |
+
scale_src = np.sqrt((src_centered**2).sum(axis=1).mean())
|
| 67 |
+
scale_tgt = np.sqrt((tgt_centered**2).sum(axis=1).mean())
|
| 68 |
+
s = scale_tgt / scale_src
|
| 69 |
+
|
| 70 |
+
src_scaled = src_centered * s
|
| 71 |
+
|
| 72 |
+
H = src_scaled.T @ tgt_centered
|
| 73 |
+
U, _, Vt = np.linalg.svd(H)
|
| 74 |
+
R = Vt.T @ U.T
|
| 75 |
+
if np.linalg.det(R) < 0:
|
| 76 |
+
Vt[2, :] *= -1
|
| 77 |
+
R = Vt.T @ U.T
|
| 78 |
+
|
| 79 |
+
t = mu_tgt - s * R @ mu_src
|
| 80 |
+
return s, R, t
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def align_point_maps(point_map1, conf1, point_map2, conf2, conf_threshold):
|
| 84 |
+
"""point_map2 -> point_map1"""
|
| 85 |
+
b1, _, _, _ = point_map1.shape
|
| 86 |
+
b2, _, _, _ = point_map2.shape
|
| 87 |
+
b = min(b1, b2)
|
| 88 |
+
|
| 89 |
+
aligned_points1 = []
|
| 90 |
+
aligned_points2 = []
|
| 91 |
+
|
| 92 |
+
for i in range(b):
|
| 93 |
+
mask1 = conf1[i] > conf_threshold
|
| 94 |
+
mask2 = conf2[i] > conf_threshold
|
| 95 |
+
valid_mask = mask1 & mask2
|
| 96 |
+
|
| 97 |
+
idx = np.where(valid_mask)
|
| 98 |
+
if len(idx[0]) == 0:
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
pts1 = point_map1[i][idx]
|
| 102 |
+
pts2 = point_map2[i][idx]
|
| 103 |
+
|
| 104 |
+
aligned_points1.append(pts1)
|
| 105 |
+
aligned_points2.append(pts2)
|
| 106 |
+
|
| 107 |
+
if len(aligned_points1) == 0:
|
| 108 |
+
raise ValueError("No matching point pairs were found!")
|
| 109 |
+
|
| 110 |
+
all_pts1 = np.concatenate(aligned_points1, axis=0)
|
| 111 |
+
all_pts2 = np.concatenate(aligned_points2, axis=0)
|
| 112 |
+
|
| 113 |
+
print(f"The number of corresponding points matched: {all_pts1.shape[0]}")
|
| 114 |
+
s, R, t = estimate_sim3(all_pts2, all_pts1)
|
| 115 |
+
|
| 116 |
+
mean_error = compute_alignment_error(
|
| 117 |
+
point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t
|
| 118 |
+
)
|
| 119 |
+
print(f"Mean error: {mean_error}")
|
| 120 |
+
|
| 121 |
+
return s, R, t
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def apply_sim3(points, s, R, t):
|
| 125 |
+
return (s * (R @ points.T)).T + t
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def apply_sim3_direct(point_maps, s, R, t):
|
| 129 |
+
# point_maps: (b, h, w, 3) -> (b, h, w, 3, 1)
|
| 130 |
+
point_maps_expanded = point_maps[..., np.newaxis] # (b, h, w, 3, 1)
|
| 131 |
+
|
| 132 |
+
# R: (3, 3) -> (b, h, w, 3, 1) = (3, 3) @ (3, 1)
|
| 133 |
+
rotated = np.matmul(R, point_maps_expanded) # (b, h, w, 3, 1)
|
| 134 |
+
rotated = rotated.squeeze(-1) # (b, h, w, 3)
|
| 135 |
+
transformed = s * rotated + t # (b, h, w, 3)
|
| 136 |
+
|
| 137 |
+
return transformed
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def compute_alignment_error(point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t):
|
| 141 |
+
"""
|
| 142 |
+
Compute the average point alignment error (using only original inputs)
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
point_map1: target point map (b, h, w, 3)
|
| 146 |
+
conf1: target confidence map (b, h, w)
|
| 147 |
+
point_map2: source point map (b, h, w, 3)
|
| 148 |
+
conf2: source confidence map (b, h, w)
|
| 149 |
+
conf_threshold: confidence threshold
|
| 150 |
+
s, R, t: transformation parameters
|
| 151 |
+
"""
|
| 152 |
+
b1, h1, w1, _ = point_map1.shape
|
| 153 |
+
b2, h2, w2, _ = point_map2.shape
|
| 154 |
+
b = min(b1, b2)
|
| 155 |
+
h = min(h1, h2)
|
| 156 |
+
w = min(w1, w2)
|
| 157 |
+
|
| 158 |
+
target_points = []
|
| 159 |
+
source_points = []
|
| 160 |
+
|
| 161 |
+
for i in range(b):
|
| 162 |
+
mask1 = conf1[i, :h, :w] > conf_threshold
|
| 163 |
+
mask2 = conf2[i, :h, :w] > conf_threshold
|
| 164 |
+
valid_mask = mask1 & mask2
|
| 165 |
+
|
| 166 |
+
idx = np.where(valid_mask)
|
| 167 |
+
if len(idx[0]) == 0:
|
| 168 |
+
continue
|
| 169 |
+
|
| 170 |
+
t_pts = point_map1[i, :h, :w][idx]
|
| 171 |
+
s_pts = point_map2[i, :h, :w][idx]
|
| 172 |
+
|
| 173 |
+
target_points.append(t_pts)
|
| 174 |
+
source_points.append(s_pts)
|
| 175 |
+
|
| 176 |
+
if len(target_points) == 0:
|
| 177 |
+
print("Warning: No matching point pairs found for error calculation")
|
| 178 |
+
return np.nan
|
| 179 |
+
|
| 180 |
+
all_target = np.concatenate(target_points, axis=0)
|
| 181 |
+
all_source = np.concatenate(source_points, axis=0)
|
| 182 |
+
|
| 183 |
+
transformed = (s * (R @ all_source.T)).T + t
|
| 184 |
+
|
| 185 |
+
errors = np.linalg.norm(transformed - all_target, axis=1)
|
| 186 |
+
|
| 187 |
+
mean_error = np.mean(errors)
|
| 188 |
+
std_error = np.std(errors)
|
| 189 |
+
median_error = np.median(errors)
|
| 190 |
+
max_error = np.max(errors)
|
| 191 |
+
|
| 192 |
+
print(
|
| 193 |
+
f"Alignment error statistics [using {len(errors)} points]: "
|
| 194 |
+
f"mean={mean_error:.4f}, std={std_error:.4f}, "
|
| 195 |
+
f"median={median_error:.4f}, max={max_error:.4f}"
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
return mean_error
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def save_confident_pointcloud(
|
| 202 |
+
points, colors, confs, output_path, conf_threshold, sample_ratio=1.0
|
| 203 |
+
):
|
| 204 |
+
"""
|
| 205 |
+
Filter points based on confidence threshold
|
| 206 |
+
and save as PLY file, with optional random sampling ratio.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
- points: np.ndarray, shape (H, W, 3) or (N, 3)
|
| 210 |
+
- colors: np.ndarray, shape (H, W, 3) or (N, 3)
|
| 211 |
+
- confs: np.ndarray, shape (H, W) or (N,)
|
| 212 |
+
- output_path: str, output PLY file path
|
| 213 |
+
- conf_threshold: float, confidence threshold for point filtering
|
| 214 |
+
- sample_ratio: float, sampling ratio (0 < sample_ratio <= 1.0)
|
| 215 |
+
"""
|
| 216 |
+
points = points.reshape(-1, 3).astype(np.float32, copy=False)
|
| 217 |
+
colors = colors.reshape(-1, 3).astype(np.uint8, copy=False)
|
| 218 |
+
confs = confs.reshape(-1).astype(np.float32, copy=False)
|
| 219 |
+
|
| 220 |
+
conf_mask = (confs >= conf_threshold) & (confs > 1e-5)
|
| 221 |
+
points = points[conf_mask]
|
| 222 |
+
colors = colors[conf_mask]
|
| 223 |
+
|
| 224 |
+
if 0 < sample_ratio < 1.0 and len(points) > 0:
|
| 225 |
+
num_samples = int(len(points) * sample_ratio)
|
| 226 |
+
indices = np.random.choice(len(points), num_samples, replace=False)
|
| 227 |
+
points = points[indices]
|
| 228 |
+
colors = colors[indices]
|
| 229 |
+
|
| 230 |
+
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
|
| 231 |
+
|
| 232 |
+
print(f"shape of sampled point: {points.shape}")
|
| 233 |
+
trimesh.PointCloud(points, colors=colors).export(output_path)
|
| 234 |
+
print(f"Saved point cloud with {len(points)} points to {output_path}")
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def save_confident_pointcloud_batch(
|
| 238 |
+
points, colors, confs, output_path, conf_threshold, sample_ratio=1.0, batch_size=1000000
|
| 239 |
+
):
|
| 240 |
+
"""
|
| 241 |
+
- points: np.ndarray, (b, H, W, 3) / (N, 3)
|
| 242 |
+
- colors: np.ndarray, (b, H, W, 3) / (N, 3)
|
| 243 |
+
- confs: np.ndarray, (b, H, W) / (N,)
|
| 244 |
+
- output_path: str
|
| 245 |
+
- conf_threshold: float,
|
| 246 |
+
- sample_ratio: float (0 < sample_ratio <= 1.0)
|
| 247 |
+
- batch_size: int
|
| 248 |
+
"""
|
| 249 |
+
if points.ndim == 2:
|
| 250 |
+
b = 1
|
| 251 |
+
points = points[np.newaxis, ...]
|
| 252 |
+
colors = colors[np.newaxis, ...]
|
| 253 |
+
confs = confs[np.newaxis, ...]
|
| 254 |
+
elif points.ndim == 4:
|
| 255 |
+
b = points.shape[0]
|
| 256 |
+
else:
|
| 257 |
+
raise ValueError("Unsupported points dimension. Must be 2 (N,3) or 4 (b,H,W,3)")
|
| 258 |
+
|
| 259 |
+
total_valid = 0
|
| 260 |
+
for i in range(b):
|
| 261 |
+
cfs = confs[i].reshape(-1)
|
| 262 |
+
total_valid += np.count_nonzero((cfs >= conf_threshold) & (cfs > 1e-5))
|
| 263 |
+
|
| 264 |
+
num_samples = int(total_valid * sample_ratio) if sample_ratio < 1.0 else total_valid
|
| 265 |
+
|
| 266 |
+
if num_samples == 0:
|
| 267 |
+
save_ply(np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8), output_path)
|
| 268 |
+
return
|
| 269 |
+
|
| 270 |
+
if sample_ratio == 1.0:
|
| 271 |
+
with open(output_path, "wb") as f:
|
| 272 |
+
write_ply_header(f, num_samples)
|
| 273 |
+
|
| 274 |
+
for i in range(b):
|
| 275 |
+
pts = points[i].reshape(-1, 3).astype(np.float32)
|
| 276 |
+
cls = colors[i].reshape(-1, 3).astype(np.uint8)
|
| 277 |
+
cfs = confs[i].reshape(-1).astype(np.float32)
|
| 278 |
+
|
| 279 |
+
mask = (cfs >= conf_threshold) & (cfs > 1e-5)
|
| 280 |
+
valid_pts = pts[mask]
|
| 281 |
+
valid_cls = cls[mask]
|
| 282 |
+
|
| 283 |
+
for j in range(0, len(valid_pts), batch_size):
|
| 284 |
+
batch_pts = valid_pts[j : j + batch_size]
|
| 285 |
+
batch_cls = valid_cls[j : j + batch_size]
|
| 286 |
+
write_ply_batch(f, batch_pts, batch_cls)
|
| 287 |
+
|
| 288 |
+
else:
|
| 289 |
+
reservoir_pts = np.zeros((num_samples, 3), dtype=np.float32)
|
| 290 |
+
reservoir_clr = np.zeros((num_samples, 3), dtype=np.uint8)
|
| 291 |
+
count = 0
|
| 292 |
+
|
| 293 |
+
for i in range(b):
|
| 294 |
+
pts = points[i].reshape(-1, 3).astype(np.float32)
|
| 295 |
+
cls = colors[i].reshape(-1, 3).astype(np.uint8)
|
| 296 |
+
cfs = confs[i].reshape(-1).astype(np.float32)
|
| 297 |
+
|
| 298 |
+
mask = (cfs >= conf_threshold) & (cfs > 1e-5)
|
| 299 |
+
valid_pts = pts[mask]
|
| 300 |
+
valid_cls = cls[mask]
|
| 301 |
+
n_valid = len(valid_pts)
|
| 302 |
+
|
| 303 |
+
if count < num_samples:
|
| 304 |
+
fill_count = min(num_samples - count, n_valid)
|
| 305 |
+
|
| 306 |
+
reservoir_pts[count : count + fill_count] = valid_pts[:fill_count]
|
| 307 |
+
reservoir_clr[count : count + fill_count] = valid_cls[:fill_count]
|
| 308 |
+
count += fill_count
|
| 309 |
+
|
| 310 |
+
if fill_count < n_valid:
|
| 311 |
+
remaining_pts = valid_pts[fill_count:]
|
| 312 |
+
remaining_cls = valid_cls[fill_count:]
|
| 313 |
+
|
| 314 |
+
count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling(
|
| 315 |
+
remaining_pts, remaining_cls, count, reservoir_pts, reservoir_clr
|
| 316 |
+
)
|
| 317 |
+
else:
|
| 318 |
+
count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling(
|
| 319 |
+
valid_pts, valid_cls, count, reservoir_pts, reservoir_clr
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
save_ply(reservoir_pts, reservoir_clr, output_path)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
""" The following function is deprecated"""
|
| 326 |
+
|
| 327 |
+
# def vectorized_reservoir_sampling(new_pts, new_cls, current_count, reservoir_pts, reservoir_clr):
|
| 328 |
+
# """
|
| 329 |
+
# - new_pts: (M, 3)
|
| 330 |
+
# - new_cls: (M, 3)
|
| 331 |
+
# - current_count
|
| 332 |
+
# - reservoir_pts: (K, 3)
|
| 333 |
+
# - reservoir_clr: (K, 3)
|
| 334 |
+
|
| 335 |
+
# """
|
| 336 |
+
# k = len(reservoir_pts)
|
| 337 |
+
# n_new = len(new_pts)
|
| 338 |
+
|
| 339 |
+
# rand_indices = np.random.randint(0, current_count + n_new, size=n_new)
|
| 340 |
+
|
| 341 |
+
# replace_mask = rand_indices < k
|
| 342 |
+
# replace_indices = rand_indices[replace_mask]
|
| 343 |
+
# replace_pts = new_pts[replace_mask]
|
| 344 |
+
# replace_cls = new_cls[replace_mask]
|
| 345 |
+
|
| 346 |
+
# reservoir_pts[replace_indices] = replace_pts
|
| 347 |
+
# reservoir_clr[replace_indices] = replace_cls
|
| 348 |
+
|
| 349 |
+
# return current_count + n_new, reservoir_pts, reservoir_clr
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
"""
|
| 353 |
+
Function `vectorized_reservoir_sampling` is not mathematically accurate in sampling.
|
| 354 |
+
This leads to inconsistent density in the downsampled point clouds.
|
| 355 |
+
The `optimized_vectorized_reservoir_sampling` function has fixed this bug.
|
| 356 |
+
|
| 357 |
+
Special thanks to @Horace89 for the detailed analysis and code assistance.
|
| 358 |
+
|
| 359 |
+
See https://github.com/DengKaiCQ/VGGT-Long/issues/28 for details
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def optimized_vectorized_reservoir_sampling(
|
| 364 |
+
new_points: np.ndarray,
|
| 365 |
+
new_colors: np.ndarray,
|
| 366 |
+
current_count: int,
|
| 367 |
+
reservoir_points: np.ndarray,
|
| 368 |
+
reservoir_colors: np.ndarray,
|
| 369 |
+
) -> tuple[int, np.ndarray, np.ndarray]:
|
| 370 |
+
"""
|
| 371 |
+
Optimized vectorized reservoir sampling with batch probability calculations.
|
| 372 |
+
|
| 373 |
+
This maintains mathematical correctness while improving performance through
|
| 374 |
+
vectorized operations where possible.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
new_points: New point coordinates to consider, shape (M, 3)
|
| 378 |
+
new_colors: New point colors to consider, shape (M, 3)
|
| 379 |
+
current_count: Number of elements seen so far
|
| 380 |
+
reservoir_points: Current reservoir of sampled points, shape (K, 3)
|
| 381 |
+
reservoir_colors: Current reservoir of sampled colors, shape (K, 3)
|
| 382 |
+
|
| 383 |
+
Returns:
|
| 384 |
+
Tuple of (updated_count, updated_reservoir_points, updated_reservoir_colors)
|
| 385 |
+
"""
|
| 386 |
+
random_gen = np.random
|
| 387 |
+
|
| 388 |
+
reservoir_size = len(reservoir_points)
|
| 389 |
+
num_new_points = len(new_points)
|
| 390 |
+
|
| 391 |
+
if num_new_points == 0:
|
| 392 |
+
return current_count, reservoir_points, reservoir_colors
|
| 393 |
+
|
| 394 |
+
# Calculate sequential indices for each new point
|
| 395 |
+
point_indices = np.arange(current_count + 1, current_count + num_new_points + 1)
|
| 396 |
+
|
| 397 |
+
# Generate random numbers for each point
|
| 398 |
+
random_values = random_gen.randint(0, point_indices, size=num_new_points)
|
| 399 |
+
|
| 400 |
+
# Determine which points should replace reservoir elements
|
| 401 |
+
replacement_mask = random_values < reservoir_size
|
| 402 |
+
replacement_positions = random_values[replacement_mask]
|
| 403 |
+
|
| 404 |
+
# Apply replacements
|
| 405 |
+
if np.any(replacement_mask):
|
| 406 |
+
points_to_replace = new_points[replacement_mask]
|
| 407 |
+
colors_to_replace = new_colors[replacement_mask]
|
| 408 |
+
|
| 409 |
+
reservoir_points[replacement_positions] = points_to_replace
|
| 410 |
+
reservoir_colors[replacement_positions] = colors_to_replace
|
| 411 |
+
|
| 412 |
+
return current_count + num_new_points, reservoir_points, reservoir_colors
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def write_ply_header(f, num_vertices):
|
| 416 |
+
header = [
|
| 417 |
+
"ply",
|
| 418 |
+
"format binary_little_endian 1.0",
|
| 419 |
+
f"element vertex {num_vertices}",
|
| 420 |
+
"property float x",
|
| 421 |
+
"property float y",
|
| 422 |
+
"property float z",
|
| 423 |
+
"property uchar red",
|
| 424 |
+
"property uchar green",
|
| 425 |
+
"property uchar blue",
|
| 426 |
+
"end_header",
|
| 427 |
+
]
|
| 428 |
+
f.write("\n".join(header).encode() + b"\n")
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def write_ply_batch(f, points, colors):
|
| 432 |
+
structured = np.zeros(
|
| 433 |
+
len(points),
|
| 434 |
+
dtype=[
|
| 435 |
+
("x", np.float32),
|
| 436 |
+
("y", np.float32),
|
| 437 |
+
("z", np.float32),
|
| 438 |
+
("red", np.uint8),
|
| 439 |
+
("green", np.uint8),
|
| 440 |
+
("blue", np.uint8),
|
| 441 |
+
],
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
structured["x"] = points[:, 0]
|
| 445 |
+
structured["y"] = points[:, 1]
|
| 446 |
+
structured["z"] = points[:, 2]
|
| 447 |
+
structured["red"] = colors[:, 0]
|
| 448 |
+
structured["green"] = colors[:, 1]
|
| 449 |
+
structured["blue"] = colors[:, 2]
|
| 450 |
+
|
| 451 |
+
f.write(structured.tobytes())
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def save_ply(points, colors, filename):
|
| 455 |
+
with open(filename, "wb") as f:
|
| 456 |
+
write_ply_header(f, len(points))
|
| 457 |
+
write_ply_batch(f, points, colors)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def find_chunk_index(chunks, idx):
|
| 461 |
+
"""
|
| 462 |
+
Find the 0-based chunk index that contains the given index idx.
|
| 463 |
+
chunks: List of (begin_idx, end_idx).
|
| 464 |
+
idx: The index to search for.
|
| 465 |
+
Returns the 0-based chunk index.
|
| 466 |
+
"""
|
| 467 |
+
starts = [chunk[0] for chunk in chunks]
|
| 468 |
+
pos = bisect.bisect_right(starts, idx) - 1 # Find position of idx in starts
|
| 469 |
+
if pos < 0 or pos >= len(chunks):
|
| 470 |
+
raise ValueError(f"Index {idx} not found in any chunk")
|
| 471 |
+
chunk_begin, chunk_end = chunks[pos]
|
| 472 |
+
if idx < chunk_begin or idx > chunk_end:
|
| 473 |
+
raise ValueError(f"Index {idx} not found in any chunk")
|
| 474 |
+
return pos
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def get_frame_range(chunk, idx, half_window=10):
|
| 478 |
+
"""
|
| 479 |
+
Calculate the frame range centered at idx with half_window
|
| 480 |
+
frames on each side within chunk boundaries.
|
| 481 |
+
If near boundaries, take 2 * half_window frames starting from the boundary.
|
| 482 |
+
chunk: (begin_idx, end_idx).
|
| 483 |
+
idx: Center index.
|
| 484 |
+
half_window: Number of frames to take on each side of center index.
|
| 485 |
+
Returns (start, end).
|
| 486 |
+
"""
|
| 487 |
+
begin, end = chunk
|
| 488 |
+
window_size = 2 * half_window
|
| 489 |
+
|
| 490 |
+
if idx - half_window < begin:
|
| 491 |
+
start = begin
|
| 492 |
+
end_candidate = begin + window_size
|
| 493 |
+
end = min(end, end_candidate)
|
| 494 |
+
|
| 495 |
+
elif idx + half_window > end:
|
| 496 |
+
end_candidate = end
|
| 497 |
+
start_candidate = end - window_size
|
| 498 |
+
start = max(begin, start_candidate)
|
| 499 |
+
|
| 500 |
+
else:
|
| 501 |
+
start = idx - half_window
|
| 502 |
+
end = idx + half_window
|
| 503 |
+
return (start, end)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def process_loop_list(chunk_index, loop_list, half_window=10):
|
| 507 |
+
"""
|
| 508 |
+
Process loop_list and return chunk indices and frame ranges for each (idx1, idx2) pair.
|
| 509 |
+
chunk_index: List of (begin_idx, end_idx) tuples.
|
| 510 |
+
loop_list: List of (idx1, idx2) tuples.
|
| 511 |
+
half_window: Number of frames to take on each side of center index (default 10).
|
| 512 |
+
Returns list of (chunk_idx1, range1, chunk_idx2, range2) tuples where:
|
| 513 |
+
- chunk_idx1, chunk_idx2: Chunk indices (1-based).
|
| 514 |
+
- range1, range2: Frame range tuples (start, end).
|
| 515 |
+
"""
|
| 516 |
+
results = []
|
| 517 |
+
for idx1, idx2 in loop_list:
|
| 518 |
+
try:
|
| 519 |
+
chunk_idx1_0based = find_chunk_index(chunk_index, idx1)
|
| 520 |
+
chunk1 = chunk_index[chunk_idx1_0based]
|
| 521 |
+
range1 = get_frame_range(chunk1, idx1, half_window)
|
| 522 |
+
|
| 523 |
+
chunk_idx2_0based = find_chunk_index(chunk_index, idx2)
|
| 524 |
+
chunk2 = chunk_index[chunk_idx2_0based]
|
| 525 |
+
range2 = get_frame_range(chunk2, idx2, half_window)
|
| 526 |
+
|
| 527 |
+
result = (chunk_idx1_0based, range1, chunk_idx2_0based, range2)
|
| 528 |
+
results.append(result)
|
| 529 |
+
except ValueError as e:
|
| 530 |
+
print(f"Skipping pair ({idx1}, {idx2}): {e}")
|
| 531 |
+
return results
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def compute_sim3_ab(S_a, S_b):
|
| 535 |
+
|
| 536 |
+
s_a, R_a, T_a = S_a
|
| 537 |
+
s_b, R_b, T_b = S_b
|
| 538 |
+
|
| 539 |
+
s_ab = s_b / s_a
|
| 540 |
+
R_ab = R_b @ R_a.T
|
| 541 |
+
T_ab = T_b - s_ab * (R_ab @ T_a)
|
| 542 |
+
|
| 543 |
+
return (s_ab, R_ab, T_ab)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def merge_ply_files(input_dir, output_path):
|
| 547 |
+
"""
|
| 548 |
+
Merge all PLY files in a directory into one file (without loading into memory)
|
| 549 |
+
|
| 550 |
+
Args:
|
| 551 |
+
- input_dir: Input directory containing multiple '{idx}_pcd.ply' files
|
| 552 |
+
- output_path: Output file path (e.g., 'combined.ply')
|
| 553 |
+
"""
|
| 554 |
+
|
| 555 |
+
print("Merging PLY files...")
|
| 556 |
+
|
| 557 |
+
input_files = sorted(glob.glob(os.path.join(input_dir, "*_pcd.ply")))
|
| 558 |
+
|
| 559 |
+
if not input_files:
|
| 560 |
+
print("No PLY files found")
|
| 561 |
+
return
|
| 562 |
+
|
| 563 |
+
idx_file = 0
|
| 564 |
+
len(input_files)
|
| 565 |
+
|
| 566 |
+
total_vertices = 0
|
| 567 |
+
for file in input_files: # Count total vertices
|
| 568 |
+
with open(file, "rb") as f:
|
| 569 |
+
for line in f:
|
| 570 |
+
if line.startswith(b"element vertex"):
|
| 571 |
+
vertex_count = int(line.split()[-1])
|
| 572 |
+
total_vertices += vertex_count
|
| 573 |
+
elif line.startswith(b"end_header"):
|
| 574 |
+
break
|
| 575 |
+
|
| 576 |
+
with open(output_path, "wb") as out_f:
|
| 577 |
+
# Write new header
|
| 578 |
+
out_f.write(b"ply\n")
|
| 579 |
+
out_f.write(b"format binary_little_endian 1.0\n")
|
| 580 |
+
out_f.write(f"element vertex {total_vertices}\n".encode())
|
| 581 |
+
out_f.write(b"property float x\n")
|
| 582 |
+
out_f.write(b"property float y\n")
|
| 583 |
+
out_f.write(b"property float z\n")
|
| 584 |
+
out_f.write(b"property uchar red\n")
|
| 585 |
+
out_f.write(b"property uchar green\n")
|
| 586 |
+
out_f.write(b"property uchar blue\n")
|
| 587 |
+
out_f.write(b"end_header\n")
|
| 588 |
+
|
| 589 |
+
for file in input_files:
|
| 590 |
+
print(f"Processing {idx_file}/{len(input_files)}: {file}")
|
| 591 |
+
idx_file += 1
|
| 592 |
+
with open(file, "rb") as in_f:
|
| 593 |
+
# Skip the head
|
| 594 |
+
in_header = True
|
| 595 |
+
while in_header:
|
| 596 |
+
line = in_f.readline()
|
| 597 |
+
if line.startswith(b"end_header"):
|
| 598 |
+
in_header = False
|
| 599 |
+
data = in_f.read()
|
| 600 |
+
out_f.write(data)
|
| 601 |
+
|
| 602 |
+
print(f"Merge completed! Total points: {total_vertices}")
|
| 603 |
+
print(f"Output file: {output_path}")
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def weighted_estimate_se3(source_points, target_points, weights):
|
| 607 |
+
"""
|
| 608 |
+
source_points: (Nx3)
|
| 609 |
+
target_points: (Nx3)
|
| 610 |
+
:weights: (N,) [0,1]
|
| 611 |
+
"""
|
| 612 |
+
total_weight = np.sum(weights)
|
| 613 |
+
if total_weight < 1e-6:
|
| 614 |
+
raise ValueError("Total weight too small for meaningful estimation")
|
| 615 |
+
|
| 616 |
+
normalized_weights = weights / total_weight
|
| 617 |
+
|
| 618 |
+
mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
|
| 619 |
+
mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
|
| 620 |
+
|
| 621 |
+
src_centered = source_points - mu_src
|
| 622 |
+
tgt_centered = target_points - mu_tgt
|
| 623 |
+
|
| 624 |
+
weighted_src = src_centered * np.sqrt(normalized_weights)[:, None]
|
| 625 |
+
weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
|
| 626 |
+
|
| 627 |
+
H = weighted_src.T @ weighted_tgt
|
| 628 |
+
|
| 629 |
+
U, _, Vt = np.linalg.svd(H)
|
| 630 |
+
R = Vt.T @ U.T
|
| 631 |
+
|
| 632 |
+
if np.linalg.det(R) < 0:
|
| 633 |
+
Vt[2, :] *= -1
|
| 634 |
+
R = Vt.T @ U.T
|
| 635 |
+
|
| 636 |
+
t = mu_tgt - R @ mu_src
|
| 637 |
+
|
| 638 |
+
return 1.0, R, t
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def weighted_estimate_sim3(source_points, target_points, weights):
|
| 642 |
+
"""
|
| 643 |
+
source_points: (Nx3)
|
| 644 |
+
target_points: (Nx3)
|
| 645 |
+
:weights: (N,) [0,1]
|
| 646 |
+
"""
|
| 647 |
+
total_weight = np.sum(weights)
|
| 648 |
+
if total_weight < 1e-6:
|
| 649 |
+
raise ValueError("Total weight too small for meaningful estimation")
|
| 650 |
+
|
| 651 |
+
normalized_weights = weights / total_weight
|
| 652 |
+
|
| 653 |
+
mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
|
| 654 |
+
mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
|
| 655 |
+
|
| 656 |
+
src_centered = source_points - mu_src
|
| 657 |
+
tgt_centered = target_points - mu_tgt
|
| 658 |
+
|
| 659 |
+
scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1)))
|
| 660 |
+
scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1)))
|
| 661 |
+
s = scale_tgt / scale_src
|
| 662 |
+
|
| 663 |
+
weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None]
|
| 664 |
+
weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
|
| 665 |
+
|
| 666 |
+
H = weighted_src.T @ weighted_tgt
|
| 667 |
+
|
| 668 |
+
U, _, Vt = np.linalg.svd(H)
|
| 669 |
+
R = Vt.T @ U.T
|
| 670 |
+
|
| 671 |
+
if np.linalg.det(R) < 0:
|
| 672 |
+
Vt[2, :] *= -1
|
| 673 |
+
R = Vt.T @ U.T
|
| 674 |
+
|
| 675 |
+
t = mu_tgt - s * R @ mu_src
|
| 676 |
+
return s, R, t
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
def huber_loss(r, delta):
|
| 680 |
+
abs_r = np.abs(r)
|
| 681 |
+
return np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta))
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
def robust_weighted_estimate_sim3(
|
| 685 |
+
src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
|
| 686 |
+
):
|
| 687 |
+
"""
|
| 688 |
+
src: (Nx3)
|
| 689 |
+
tgt: (Nx3)
|
| 690 |
+
init_weights: (N,)
|
| 691 |
+
"""
|
| 692 |
+
if align_method == "sim3":
|
| 693 |
+
s, R, t = weighted_estimate_sim3(src, tgt, init_weights)
|
| 694 |
+
elif align_method == "se3" or align_method == "scale+se3":
|
| 695 |
+
s, R, t = weighted_estimate_se3(src, tgt, init_weights)
|
| 696 |
+
|
| 697 |
+
prev_error = float("inf")
|
| 698 |
+
|
| 699 |
+
for iter in range(max_iters):
|
| 700 |
+
|
| 701 |
+
transformed = s * (src @ R.T) + t
|
| 702 |
+
residuals = np.linalg.norm(tgt - transformed, axis=1) # (N,)
|
| 703 |
+
print(f"Residuals: {np.mean(residuals)}")
|
| 704 |
+
|
| 705 |
+
abs_res = np.abs(residuals)
|
| 706 |
+
huber_weights = np.ones_like(residuals)
|
| 707 |
+
large_res_mask = abs_res > delta
|
| 708 |
+
huber_weights[large_res_mask] = delta / abs_res[large_res_mask]
|
| 709 |
+
|
| 710 |
+
combined_weights = init_weights * huber_weights
|
| 711 |
+
|
| 712 |
+
combined_weights /= np.sum(combined_weights) + 1e-12
|
| 713 |
+
|
| 714 |
+
if align_method == "se3":
|
| 715 |
+
s_new, R_new, t_new = weighted_estimate_se3(src, tgt, combined_weights)
|
| 716 |
+
elif align_method == "sim3" or align_method == "scale+se3":
|
| 717 |
+
s_new, R_new, t_new = weighted_estimate_sim3(src, tgt, combined_weights)
|
| 718 |
+
|
| 719 |
+
param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
|
| 720 |
+
rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
|
| 721 |
+
current_error = np.sum(huber_loss(residuals, delta) * init_weights)
|
| 722 |
+
|
| 723 |
+
if (param_change < tol and rot_angle < np.radians(0.1)) or (
|
| 724 |
+
abs(prev_error - current_error) < tol * prev_error
|
| 725 |
+
):
|
| 726 |
+
break
|
| 727 |
+
|
| 728 |
+
s, R, t = s_new, R_new, t_new
|
| 729 |
+
prev_error = current_error
|
| 730 |
+
|
| 731 |
+
return s, R, t
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
# ===== Speed Up Begin =====
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
@njit(cache=True)
|
| 738 |
+
def _weighted_estimate_se3_numba(source_points, target_points, weights):
|
| 739 |
+
# Ensure float32
|
| 740 |
+
source_points = source_points.astype(np.float32)
|
| 741 |
+
target_points = target_points.astype(np.float32)
|
| 742 |
+
weights = weights.astype(np.float32)
|
| 743 |
+
|
| 744 |
+
total_weight = np.sum(weights)
|
| 745 |
+
if total_weight < 1e-6:
|
| 746 |
+
return (
|
| 747 |
+
1.0,
|
| 748 |
+
np.zeros(3, dtype=np.float32),
|
| 749 |
+
np.zeros(3, dtype=np.float32),
|
| 750 |
+
np.zeros((3, 3), dtype=np.float32),
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
normalized_weights = weights / total_weight
|
| 754 |
+
|
| 755 |
+
mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
|
| 756 |
+
mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
|
| 757 |
+
|
| 758 |
+
src_centered = source_points - mu_src
|
| 759 |
+
tgt_centered = target_points - mu_tgt
|
| 760 |
+
|
| 761 |
+
weighted_src = src_centered * np.sqrt(normalized_weights)[:, None]
|
| 762 |
+
weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
|
| 763 |
+
|
| 764 |
+
H = weighted_src.T @ weighted_tgt
|
| 765 |
+
|
| 766 |
+
return 1.0, mu_src, mu_tgt, H
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
@njit(cache=True)
|
| 770 |
+
def _weighted_estimate_sim3_numba(source_points, target_points, weights):
|
| 771 |
+
# Ensure float32
|
| 772 |
+
source_points = source_points.astype(np.float32)
|
| 773 |
+
target_points = target_points.astype(np.float32)
|
| 774 |
+
weights = weights.astype(np.float32)
|
| 775 |
+
|
| 776 |
+
total_weight = np.sum(weights)
|
| 777 |
+
if total_weight < 1e-6:
|
| 778 |
+
return (
|
| 779 |
+
-1.0,
|
| 780 |
+
np.zeros(3, dtype=np.float32),
|
| 781 |
+
np.zeros(3, dtype=np.float32),
|
| 782 |
+
np.zeros((3, 3), dtype=np.float32),
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
normalized_weights = weights / total_weight
|
| 786 |
+
|
| 787 |
+
mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
|
| 788 |
+
mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
|
| 789 |
+
|
| 790 |
+
src_centered = source_points - mu_src
|
| 791 |
+
tgt_centered = target_points - mu_tgt
|
| 792 |
+
|
| 793 |
+
scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1)))
|
| 794 |
+
scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1)))
|
| 795 |
+
s = scale_tgt / scale_src
|
| 796 |
+
|
| 797 |
+
weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None]
|
| 798 |
+
weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
|
| 799 |
+
|
| 800 |
+
H = weighted_src.T @ weighted_tgt
|
| 801 |
+
|
| 802 |
+
return s, mu_src, mu_tgt, H
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
def weighted_estimate_sim3_numba(source_points, target_points, weights, align_method="sim3"):
|
| 806 |
+
if align_method == "sim3":
|
| 807 |
+
s, mu_src, mu_tgt, H = _weighted_estimate_sim3_numba(source_points, target_points, weights)
|
| 808 |
+
elif align_method == "se3" or align_method == "scale+se3":
|
| 809 |
+
s, mu_src, mu_tgt, H = _weighted_estimate_se3_numba(source_points, target_points, weights)
|
| 810 |
+
|
| 811 |
+
if s < 0:
|
| 812 |
+
raise ValueError("Total weight too small for meaningful estimation")
|
| 813 |
+
|
| 814 |
+
# Ensure float32
|
| 815 |
+
H = H.astype(np.float32)
|
| 816 |
+
U, _, Vt = np.linalg.svd(H.astype(np.float32)) # float32 SVD
|
| 817 |
+
|
| 818 |
+
R = Vt.T @ U.T
|
| 819 |
+
if np.linalg.det(R) < 0:
|
| 820 |
+
Vt[2, :] *= -1
|
| 821 |
+
R = Vt.T @ U.T
|
| 822 |
+
|
| 823 |
+
if align_method == "se3" or align_method == "scale+se3":
|
| 824 |
+
t = mu_tgt - R @ mu_src
|
| 825 |
+
else:
|
| 826 |
+
t = mu_tgt - s * R @ mu_src
|
| 827 |
+
|
| 828 |
+
return s, R, t
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
@njit(cache=True)
|
| 832 |
+
def huber_loss_numba(r, delta):
|
| 833 |
+
r = r.astype(np.float32)
|
| 834 |
+
delta = np.float32(delta)
|
| 835 |
+
abs_r = np.abs(r)
|
| 836 |
+
result = np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta))
|
| 837 |
+
return result.astype(np.float32)
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
@njit(cache=True)
|
| 841 |
+
def compute_residuals_numba(tgt, transformed):
|
| 842 |
+
residuals = np.empty(tgt.shape[0], dtype=np.float32)
|
| 843 |
+
for i in range(tgt.shape[0]):
|
| 844 |
+
diff = tgt[i] - transformed[i]
|
| 845 |
+
residuals[i] = np.sqrt(np.sum(diff**2))
|
| 846 |
+
return residuals
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
@njit(cache=True)
|
| 850 |
+
def compute_huber_weights_numba(residuals, delta):
|
| 851 |
+
weights = np.ones(residuals.shape, dtype=np.float32)
|
| 852 |
+
for i in range(residuals.shape[0]):
|
| 853 |
+
r = residuals[i]
|
| 854 |
+
if r > delta:
|
| 855 |
+
weights[i] = delta / r
|
| 856 |
+
return weights
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
@njit(cache=True)
|
| 860 |
+
def apply_transformation_numba(src, s, R, t):
|
| 861 |
+
transformed = np.empty_like(src)
|
| 862 |
+
for i in range(src.shape[0]):
|
| 863 |
+
p = src[i]
|
| 864 |
+
transformed[i] = s * (R @ p) + t
|
| 865 |
+
return transformed
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
def robust_weighted_estimate_sim3_numba(
|
| 869 |
+
src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
|
| 870 |
+
):
|
| 871 |
+
src = src.astype(np.float32)
|
| 872 |
+
tgt = tgt.astype(np.float32)
|
| 873 |
+
init_weights = init_weights.astype(np.float32)
|
| 874 |
+
|
| 875 |
+
s, R, t = weighted_estimate_sim3_numba(src, tgt, init_weights, align_method=align_method)
|
| 876 |
+
|
| 877 |
+
prev_error = float("inf")
|
| 878 |
+
|
| 879 |
+
for iter in range(max_iters):
|
| 880 |
+
transformed = apply_transformation_numba(src, s, R, t)
|
| 881 |
+
residuals = compute_residuals_numba(tgt, transformed)
|
| 882 |
+
|
| 883 |
+
print(f"Residuals: {np.mean(residuals)}")
|
| 884 |
+
|
| 885 |
+
huber_weights = compute_huber_weights_numba(residuals, delta)
|
| 886 |
+
combined_weights = init_weights * huber_weights
|
| 887 |
+
combined_weights /= np.sum(combined_weights) + 1e-12
|
| 888 |
+
|
| 889 |
+
s_new, R_new, t_new = weighted_estimate_sim3_numba(
|
| 890 |
+
src, tgt, combined_weights, align_method=align_method
|
| 891 |
+
)
|
| 892 |
+
|
| 893 |
+
param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
|
| 894 |
+
rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
|
| 895 |
+
|
| 896 |
+
current_error = np.sum(huber_loss_numba(residuals, delta) * init_weights)
|
| 897 |
+
|
| 898 |
+
if (param_change < tol and rot_angle < np.radians(0.1)) or (
|
| 899 |
+
abs(prev_error - current_error) < tol * prev_error
|
| 900 |
+
):
|
| 901 |
+
break
|
| 902 |
+
|
| 903 |
+
s, R, t = s_new, R_new, t_new
|
| 904 |
+
prev_error = current_error
|
| 905 |
+
|
| 906 |
+
return s, R, t
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
def warmup_numba():
|
| 910 |
+
|
| 911 |
+
print("\nWarming up Numba JIT-compiled functions...")
|
| 912 |
+
|
| 913 |
+
src = np.random.randn(50000, 3).astype(np.float32)
|
| 914 |
+
tgt = np.random.randn(50000, 3).astype(np.float32)
|
| 915 |
+
weights = np.ones(50000, dtype=np.float32)
|
| 916 |
+
residuals = np.abs(np.random.randn(50000).astype(np.float32))
|
| 917 |
+
R = np.eye(3, dtype=np.float32)
|
| 918 |
+
t = np.zeros(3, dtype=np.float32)
|
| 919 |
+
s = np.float32(1.0)
|
| 920 |
+
delta = np.float32(1.0)
|
| 921 |
+
|
| 922 |
+
try:
|
| 923 |
+
_ = _weighted_estimate_sim3_numba(src, tgt, weights)
|
| 924 |
+
print(" - _weighted_estimate_sim3_numba warmed up.")
|
| 925 |
+
except Exception as e:
|
| 926 |
+
print(" ! Failed to warm up _weighted_estimate_sim3_numba:", e)
|
| 927 |
+
|
| 928 |
+
try:
|
| 929 |
+
_ = _weighted_estimate_se3_numba(src, tgt, weights)
|
| 930 |
+
print(" - _weighted_estimate_se3_numba warmed up.")
|
| 931 |
+
except Exception as e:
|
| 932 |
+
print(" ! Failed to warm up _weighted_estimate_se3_numba:", e)
|
| 933 |
+
|
| 934 |
+
try:
|
| 935 |
+
_ = huber_loss_numba(residuals, delta)
|
| 936 |
+
print(" - huber_loss_numba warmed up.")
|
| 937 |
+
except Exception as e:
|
| 938 |
+
print(" ! Failed to warm up huber_loss_numba:", e)
|
| 939 |
+
|
| 940 |
+
try:
|
| 941 |
+
_ = compute_huber_weights_numba(residuals, delta)
|
| 942 |
+
print(" - compute_huber_weights_numba warmed up.")
|
| 943 |
+
except Exception as e:
|
| 944 |
+
print(" ! Failed to warm up compute_huber_weights_numba:", e)
|
| 945 |
+
|
| 946 |
+
try:
|
| 947 |
+
_ = compute_residuals_numba(tgt, src)
|
| 948 |
+
print(" - compute_residuals_numba warmed up.")
|
| 949 |
+
except Exception as e:
|
| 950 |
+
print(" ! Failed to warm up compute_residuals_numba:", e)
|
| 951 |
+
|
| 952 |
+
try:
|
| 953 |
+
_ = apply_transformation_numba(src, s, R, t)
|
| 954 |
+
print(" - apply_transformation_numba warmed up.")
|
| 955 |
+
except Exception as e:
|
| 956 |
+
print(" ! Failed to warm up apply_transformation_numba:", e)
|
| 957 |
+
|
| 958 |
+
print("Numba warm-up complete.\n")
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
# ===== Speed Up End =====
|
| 962 |
+
|
| 963 |
+
# ===== Scale precompute begin =====
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
def compute_scale_ransac(
|
| 967 |
+
depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, max_samples=10000
|
| 968 |
+
):
|
| 969 |
+
"""
|
| 970 |
+
Args:
|
| 971 |
+
depth1: (n1, h, w)
|
| 972 |
+
depth2: (n2, h, w)
|
| 973 |
+
conf1: (n1, h, w)
|
| 974 |
+
conf2: (n2, h, w)
|
| 975 |
+
|
| 976 |
+
"""
|
| 977 |
+
|
| 978 |
+
depth1_flat = depth1.reshape(-1)
|
| 979 |
+
depth2_flat = depth2.reshape(-1)
|
| 980 |
+
conf1_flat = conf1.reshape(-1)
|
| 981 |
+
conf2_flat = conf2.reshape(-1)
|
| 982 |
+
|
| 983 |
+
conf_threshold = max(
|
| 984 |
+
np.median(conf1_flat) * conf_threshold_ratio,
|
| 985 |
+
np.median(conf2_flat) * conf_threshold_ratio,
|
| 986 |
+
1e-6,
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
valid_mask = (
|
| 990 |
+
(conf1_flat > conf_threshold)
|
| 991 |
+
& (conf2_flat > conf_threshold)
|
| 992 |
+
& (depth1_flat > 1e-3)
|
| 993 |
+
& (depth2_flat > 1e-3)
|
| 994 |
+
& (depth1_flat < 100)
|
| 995 |
+
& (depth2_flat < 100)
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
if np.sum(valid_mask) < 100:
|
| 999 |
+
print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0")
|
| 1000 |
+
return 1.0, 0.0
|
| 1001 |
+
|
| 1002 |
+
valid_depth1 = depth1_flat[valid_mask]
|
| 1003 |
+
valid_depth2 = depth2_flat[valid_mask]
|
| 1004 |
+
|
| 1005 |
+
if len(valid_depth1) > max_samples:
|
| 1006 |
+
indices = np.random.choice(len(valid_depth1), max_samples, replace=False)
|
| 1007 |
+
valid_depth1 = valid_depth1[indices]
|
| 1008 |
+
valid_depth2 = valid_depth2[indices]
|
| 1009 |
+
|
| 1010 |
+
X = valid_depth2.reshape(-1, 1)
|
| 1011 |
+
y = valid_depth1
|
| 1012 |
+
|
| 1013 |
+
base_estimator = LinearRegression(fit_intercept=False)
|
| 1014 |
+
ransac = RANSACRegressor(
|
| 1015 |
+
estimator=base_estimator,
|
| 1016 |
+
max_trials=1000,
|
| 1017 |
+
min_samples=max(10, len(X) // 100),
|
| 1018 |
+
residual_threshold=0.1,
|
| 1019 |
+
random_state=42,
|
| 1020 |
+
)
|
| 1021 |
+
|
| 1022 |
+
ransac.fit(X, y)
|
| 1023 |
+
scale_factor = ransac.estimator_.coef_[0]
|
| 1024 |
+
inlier_mask = ransac.inlier_mask_
|
| 1025 |
+
inlier_ratio = np.sum(inlier_mask) / len(inlier_mask)
|
| 1026 |
+
|
| 1027 |
+
print(f"RANSAC scale: {scale_factor:.6f}, inlier ratio: {inlier_ratio:.4f}")
|
| 1028 |
+
|
| 1029 |
+
if 0.1 < scale_factor < 10.0:
|
| 1030 |
+
return scale_factor, inlier_ratio
|
| 1031 |
+
else:
|
| 1032 |
+
print(f"Warning: Unreasonable scale {scale_factor}, using 1.0")
|
| 1033 |
+
return 1.0, inlier_ratio
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
def compute_scale_weighted(
|
| 1037 |
+
depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, weight_power=2.0, robust_quantile=0.9
|
| 1038 |
+
):
|
| 1039 |
+
"""
|
| 1040 |
+
Args:
|
| 1041 |
+
depth1: (n1, h, w)
|
| 1042 |
+
depth2: (n2, h, w)
|
| 1043 |
+
conf1: (n1, h, w)
|
| 1044 |
+
conf2: (n2, h, w)
|
| 1045 |
+
"""
|
| 1046 |
+
depth1_flat = depth1.reshape(-1)
|
| 1047 |
+
depth2_flat = depth2.reshape(-1)
|
| 1048 |
+
conf1_flat = conf1.reshape(-1)
|
| 1049 |
+
conf2_flat = conf2.reshape(-1)
|
| 1050 |
+
|
| 1051 |
+
conf_threshold = max(
|
| 1052 |
+
np.median(conf1_flat) * conf_threshold_ratio,
|
| 1053 |
+
np.median(conf2_flat) * conf_threshold_ratio,
|
| 1054 |
+
1e-6,
|
| 1055 |
+
)
|
| 1056 |
+
|
| 1057 |
+
valid_mask = (
|
| 1058 |
+
(conf1_flat > conf_threshold)
|
| 1059 |
+
& (conf2_flat > conf_threshold)
|
| 1060 |
+
& (depth1_flat > 1e-3)
|
| 1061 |
+
& (depth2_flat > 1e-3)
|
| 1062 |
+
& (depth1_flat < 100)
|
| 1063 |
+
& (depth2_flat < 100)
|
| 1064 |
+
)
|
| 1065 |
+
|
| 1066 |
+
if np.sum(valid_mask) < 100:
|
| 1067 |
+
print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0")
|
| 1068 |
+
return 1.0, 0.0
|
| 1069 |
+
|
| 1070 |
+
valid_depth1 = depth1_flat[valid_mask]
|
| 1071 |
+
valid_depth2 = depth2_flat[valid_mask]
|
| 1072 |
+
valid_conf1 = conf1_flat[valid_mask]
|
| 1073 |
+
valid_conf2 = conf2_flat[valid_mask]
|
| 1074 |
+
|
| 1075 |
+
combined_weights = (valid_conf1 * valid_conf2) ** weight_power
|
| 1076 |
+
|
| 1077 |
+
combined_weights = combined_weights / (np.sum(combined_weights) + 1e-8)
|
| 1078 |
+
|
| 1079 |
+
ratios = valid_depth1 / (valid_depth2 + 1e-8)
|
| 1080 |
+
|
| 1081 |
+
sorted_indices = np.argsort(ratios)
|
| 1082 |
+
sorted_ratios = ratios[sorted_indices]
|
| 1083 |
+
sorted_weights = combined_weights[sorted_indices]
|
| 1084 |
+
|
| 1085 |
+
cumulative_weights = np.cumsum(sorted_weights)
|
| 1086 |
+
median_idx = np.searchsorted(cumulative_weights, 0.5)
|
| 1087 |
+
scale_median = sorted_ratios[median_idx] if median_idx < len(sorted_ratios) else 1.0
|
| 1088 |
+
|
| 1089 |
+
quantile_idx = np.searchsorted(cumulative_weights, robust_quantile)
|
| 1090 |
+
scale_quantile = (
|
| 1091 |
+
sorted_ratios[quantile_idx] if quantile_idx < len(sorted_ratios) else scale_median
|
| 1092 |
+
)
|
| 1093 |
+
|
| 1094 |
+
weight_entropy = -np.sum(combined_weights * np.log(combined_weights + 1e-8))
|
| 1095 |
+
max_entropy = np.log(len(combined_weights))
|
| 1096 |
+
confidence_score = 1.0 - (weight_entropy / max_entropy) if max_entropy > 0 else 0.0
|
| 1097 |
+
|
| 1098 |
+
print(f"Weighted scale: {scale_quantile:.6f}, confidence: {confidence_score:.4f}")
|
| 1099 |
+
|
| 1100 |
+
if 0.1 < scale_quantile < 10.0:
|
| 1101 |
+
return scale_quantile, confidence_score
|
| 1102 |
+
else:
|
| 1103 |
+
print(f"Warning: Unreasonable scale {scale_quantile}, using 1.0")
|
| 1104 |
+
return 1.0, confidence_score
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
def compute_chunk_scale_advanced(depth1, depth2, conf1, conf2, method="auto"):
|
| 1108 |
+
"""
|
| 1109 |
+
method: 'auto', 'ransac', 'weighted'
|
| 1110 |
+
"""
|
| 1111 |
+
if method == "ransac":
|
| 1112 |
+
scale, score = compute_scale_ransac(depth1, depth2, conf1, conf2)
|
| 1113 |
+
return scale, score, "ransac"
|
| 1114 |
+
|
| 1115 |
+
elif method == "weighted":
|
| 1116 |
+
scale, score = compute_scale_weighted(depth1, depth2, conf1, conf2)
|
| 1117 |
+
return scale, score, "weighted"
|
| 1118 |
+
|
| 1119 |
+
elif method == "auto":
|
| 1120 |
+
scale_ransac, inlier_ratio = compute_scale_ransac(depth1, depth2, conf1, conf2)
|
| 1121 |
+
scale_weighted, conf_score = compute_scale_weighted(depth1, depth2, conf1, conf2)
|
| 1122 |
+
|
| 1123 |
+
ransac_quality = inlier_ratio
|
| 1124 |
+
weighted_quality = conf_score
|
| 1125 |
+
|
| 1126 |
+
print(f"RANSAC quality: {ransac_quality:.4f}, Weighted quality: {weighted_quality:.4f}")
|
| 1127 |
+
|
| 1128 |
+
if ransac_quality > 0.7 and weighted_quality > 0.7:
|
| 1129 |
+
# both method are good, we take both of them by average
|
| 1130 |
+
final_scale = (scale_ransac + scale_weighted) / 2
|
| 1131 |
+
final_method = "average"
|
| 1132 |
+
elif ransac_quality > weighted_quality:
|
| 1133 |
+
final_scale = scale_ransac
|
| 1134 |
+
final_method = "ransac"
|
| 1135 |
+
else:
|
| 1136 |
+
final_scale = scale_weighted
|
| 1137 |
+
final_method = "weighted"
|
| 1138 |
+
|
| 1139 |
+
final_quality = max(ransac_quality, weighted_quality)
|
| 1140 |
+
return final_scale, final_quality, final_method
|
| 1141 |
+
|
| 1142 |
+
|
| 1143 |
+
def precompute_scale_chunks_with_depth(
|
| 1144 |
+
chunk1_depth, chunk1_conf, chunk2_depth, chunk2_conf, method="auto"
|
| 1145 |
+
):
|
| 1146 |
+
"""
|
| 1147 |
+
Args:
|
| 1148 |
+
chunk1_depth: (n1, h, w)
|
| 1149 |
+
chunk1_conf: (n1, h, w)
|
| 1150 |
+
chunk2_depth: (n2, h, w)
|
| 1151 |
+
chunk2_conf: (n2, h, w)
|
| 1152 |
+
method: 'auto', 'ransac', 'weighted'
|
| 1153 |
+
"""
|
| 1154 |
+
|
| 1155 |
+
scale_factor, quality_score, method_used = compute_chunk_scale_advanced(
|
| 1156 |
+
chunk1_depth, chunk2_depth, chunk1_conf, chunk2_conf, method
|
| 1157 |
+
)
|
| 1158 |
+
|
| 1159 |
+
print(f"Final scale: {scale_factor:.6f}, quality: {quality_score:.4f}, method: {method_used}")
|
| 1160 |
+
|
| 1161 |
+
return scale_factor, quality_score, method_used
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
# ===== Scale precompute end =====
|
| 1165 |
+
|
| 1166 |
+
|
| 1167 |
+
def weighted_align_point_maps(
|
| 1168 |
+
point_map1, conf1, point_map2, conf2, conf_threshold, config, precompute_scale=None
|
| 1169 |
+
):
|
| 1170 |
+
"""point_map2 -> point_map1"""
|
| 1171 |
+
b1, _, _, _ = point_map1.shape
|
| 1172 |
+
b2, _, _, _ = point_map2.shape
|
| 1173 |
+
b = min(b1, b2)
|
| 1174 |
+
|
| 1175 |
+
if precompute_scale is not None: # meaning we are using align method 'scale+se3'
|
| 1176 |
+
point_map2 *= precompute_scale
|
| 1177 |
+
|
| 1178 |
+
aligned_points1 = []
|
| 1179 |
+
aligned_points2 = []
|
| 1180 |
+
confidence_weights = []
|
| 1181 |
+
|
| 1182 |
+
for i in range(b):
|
| 1183 |
+
mask1 = conf1[i] > conf_threshold
|
| 1184 |
+
mask2 = conf2[i] > conf_threshold
|
| 1185 |
+
valid_mask = mask1 & mask2
|
| 1186 |
+
|
| 1187 |
+
idx = np.where(valid_mask)
|
| 1188 |
+
if len(idx[0]) == 0:
|
| 1189 |
+
continue
|
| 1190 |
+
|
| 1191 |
+
pts1 = point_map1[i][idx]
|
| 1192 |
+
pts2 = point_map2[i][idx]
|
| 1193 |
+
|
| 1194 |
+
combined_conf = np.sqrt(conf1[i][idx] * conf2[i][idx])
|
| 1195 |
+
|
| 1196 |
+
aligned_points1.append(pts1)
|
| 1197 |
+
aligned_points2.append(pts2)
|
| 1198 |
+
confidence_weights.append(combined_conf)
|
| 1199 |
+
|
| 1200 |
+
if len(aligned_points1) == 0:
|
| 1201 |
+
raise ValueError("No matching point pairs were found!")
|
| 1202 |
+
|
| 1203 |
+
all_pts1 = np.concatenate(aligned_points1, axis=0)
|
| 1204 |
+
all_pts2 = np.concatenate(aligned_points2, axis=0)
|
| 1205 |
+
all_weights = np.concatenate(confidence_weights, axis=0)
|
| 1206 |
+
|
| 1207 |
+
print(f"The number of corresponding points matched: {all_pts1.shape[0]}")
|
| 1208 |
+
|
| 1209 |
+
if config["Model"]["align_lib"] == "numba":
|
| 1210 |
+
s, R, t = robust_weighted_estimate_sim3_numba(
|
| 1211 |
+
all_pts2,
|
| 1212 |
+
all_pts1,
|
| 1213 |
+
all_weights,
|
| 1214 |
+
delta=config["Model"]["IRLS"]["delta"],
|
| 1215 |
+
max_iters=config["Model"]["IRLS"]["max_iters"],
|
| 1216 |
+
tol=eval(config["Model"]["IRLS"]["tol"]),
|
| 1217 |
+
align_method=config["Model"]["align_method"],
|
| 1218 |
+
)
|
| 1219 |
+
elif config["Model"]["align_lib"] == "numpy": # numpy
|
| 1220 |
+
s, R, t = robust_weighted_estimate_sim3(
|
| 1221 |
+
all_pts2,
|
| 1222 |
+
all_pts1,
|
| 1223 |
+
all_weights,
|
| 1224 |
+
delta=config["Model"]["IRLS"]["delta"],
|
| 1225 |
+
max_iters=config["Model"]["IRLS"]["max_iters"],
|
| 1226 |
+
tol=eval(config["Model"]["IRLS"]["tol"]),
|
| 1227 |
+
align_method=config["Model"]["align_method"],
|
| 1228 |
+
)
|
| 1229 |
+
elif config["Model"]["align_lib"] == "torch": # torch
|
| 1230 |
+
s, R, t = robust_weighted_estimate_sim3_torch(
|
| 1231 |
+
all_pts2,
|
| 1232 |
+
all_pts1,
|
| 1233 |
+
all_weights,
|
| 1234 |
+
delta=config["Model"]["IRLS"]["delta"],
|
| 1235 |
+
max_iters=config["Model"]["IRLS"]["max_iters"],
|
| 1236 |
+
tol=eval(config["Model"]["IRLS"]["tol"]),
|
| 1237 |
+
align_method=config["Model"]["align_method"],
|
| 1238 |
+
)
|
| 1239 |
+
elif config["Model"]["align_lib"] == "triton": # triton
|
| 1240 |
+
s, R, t = robust_weighted_estimate_sim3_triton(
|
| 1241 |
+
all_pts2,
|
| 1242 |
+
all_pts1,
|
| 1243 |
+
all_weights,
|
| 1244 |
+
delta=config["Model"]["IRLS"]["delta"],
|
| 1245 |
+
max_iters=config["Model"]["IRLS"]["max_iters"],
|
| 1246 |
+
tol=eval(config["Model"]["IRLS"]["tol"]),
|
| 1247 |
+
align_method=config["Model"]["align_method"],
|
| 1248 |
+
)
|
| 1249 |
+
else:
|
| 1250 |
+
raise ValueError(f"Unknown align_lib: {config['Model']['align_lib']}")
|
| 1251 |
+
|
| 1252 |
+
if precompute_scale is not None: # meaning we are using align method 'scale+se3'
|
| 1253 |
+
# we need this precompute_scale for loop align
|
| 1254 |
+
s = precompute_scale
|
| 1255 |
+
|
| 1256 |
+
mean_error = compute_alignment_error(
|
| 1257 |
+
point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t
|
| 1258 |
+
)
|
| 1259 |
+
print(f"Mean error: {mean_error}")
|
| 1260 |
+
|
| 1261 |
+
return s, R, t
|
Depth-Anything-3/da3_streaming/scripts/download_weights.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
mkdir weights
|
| 4 |
+
cd ./weights
|
| 5 |
+
|
| 6 |
+
# SALAD (~ 340 MiB)
|
| 7 |
+
echo "Downloading SALAD weights (~ 340 MiB) ..."
|
| 8 |
+
SALAD_URL="https://github.com/serizba/salad/releases/download/v1.0.0/dino_salad.ckpt"
|
| 9 |
+
curl -L "$SALAD_URL" -o "./dino_salad.ckpt"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# DA3NESTED-GIANT-LARGE-1.1
|
| 13 |
+
echo "Downloading DA3NESTED-GIANT-LARGE-1.1 weights and config (~ 6.76 GiB)..."
|
| 14 |
+
BASE_URL="https://huggingface.co/depth-anything/DA3NESTED-GIANT-LARGE-1.1/resolve/main"
|
| 15 |
+
|
| 16 |
+
# download config.json (~ 3.1 KiB)
|
| 17 |
+
curl -L "$BASE_URL/config.json" -o "./config.json"
|
| 18 |
+
|
| 19 |
+
# download model.safetensors (~ 6.76 GiB)
|
| 20 |
+
curl -L "$BASE_URL/model.safetensors" -o "./model.safetensors"
|
Depth-Anything-3/docs/API.md
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 📚 DepthAnything3 API Documentation
|
| 2 |
+
|
| 3 |
+
## 📑 Table of Contents
|
| 4 |
+
|
| 5 |
+
1. [📖 Overview](#overview)
|
| 6 |
+
2. [💡 Usage Examples](#usage-examples)
|
| 7 |
+
3. [🔧 Core API](#core-api)
|
| 8 |
+
- [DepthAnything3 Class](#depthanything3-class)
|
| 9 |
+
- [inference() Method](#inference-method)
|
| 10 |
+
4. [⚙️ Parameters](#parameters)
|
| 11 |
+
- [Input Parameters](#input-parameters)
|
| 12 |
+
- [Pose Alignment Parameters](#pose-alignment-parameters)
|
| 13 |
+
- [Feature Export Parameters](#feature-export-parameters)
|
| 14 |
+
- [Rendering Parameters](#rendering-parameters)
|
| 15 |
+
- [Processing Parameters](#processing-parameters)
|
| 16 |
+
- [Export Parameters](#export-parameters)
|
| 17 |
+
5. [📤 Export Formats](#export-formats)
|
| 18 |
+
6. [↩️ Return Value](#return-value)
|
| 19 |
+
|
| 20 |
+
## 📖 Overview
|
| 21 |
+
|
| 22 |
+
This documentation provides comprehensive API reference for DepthAnything3, including usage examples, parameter specifications, export formats, and advanced features. It covers both basic pose and depth estimation workflows and advanced pose-conditioned processing with multiple export capabilities.
|
| 23 |
+
|
| 24 |
+
## 💡 Usage Examples
|
| 25 |
+
|
| 26 |
+
Here are quick examples to get you started:
|
| 27 |
+
|
| 28 |
+
### 🚀 Basic Depth Estimation
|
| 29 |
+
```python
|
| 30 |
+
from depth_anything_3.api import DepthAnything3
|
| 31 |
+
|
| 32 |
+
# Initialize and run inference
|
| 33 |
+
model = DepthAnything3.from_pretrained("depth-anything/DA3NESTED-GIANT-LARGE").to("cuda")
|
| 34 |
+
prediction = model.inference(["image1.jpg", "image2.jpg"])
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
### 📷 Pose-Conditioned Depth Estimation
|
| 38 |
+
```python
|
| 39 |
+
import numpy as np
|
| 40 |
+
|
| 41 |
+
# With camera parameters for better consistency
|
| 42 |
+
prediction = model.inference(
|
| 43 |
+
image=["image1.jpg", "image2.jpg"],
|
| 44 |
+
extrinsics=extrinsics_array, # (N, 4, 4)
|
| 45 |
+
intrinsics=intrinsics_array # (N, 3, 3)
|
| 46 |
+
)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
### 📤 Export Results
|
| 50 |
+
```python
|
| 51 |
+
# Export depth data and 3D visualization
|
| 52 |
+
prediction = model.inference(
|
| 53 |
+
image=image_paths,
|
| 54 |
+
export_dir="./output",
|
| 55 |
+
export_format="mini_npz-glb"
|
| 56 |
+
)
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
### 🔍 Feature Extraction
|
| 60 |
+
```python
|
| 61 |
+
# Export intermediate features from specific layers
|
| 62 |
+
prediction = model.inference(
|
| 63 |
+
image=image_paths,
|
| 64 |
+
export_dir="./output",
|
| 65 |
+
export_format="feat_vis",
|
| 66 |
+
export_feat_layers=[0, 1, 2] # Export features from layers 0, 1, 2
|
| 67 |
+
)
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
### ✨ Advanced Export with Gaussian Splatting
|
| 71 |
+
```python
|
| 72 |
+
# Export multiple formats including Gaussian Splatting
|
| 73 |
+
# Note: infer_gs=True requires da3-giant or da3nested-giant-large model
|
| 74 |
+
model = DepthAnything3(model_name="da3-giant").to("cuda")
|
| 75 |
+
|
| 76 |
+
prediction = model.inference(
|
| 77 |
+
image=image_paths,
|
| 78 |
+
extrinsics=extrinsics_array,
|
| 79 |
+
intrinsics=intrinsics_array,
|
| 80 |
+
export_dir="./output",
|
| 81 |
+
export_format="npz-glb-gs_ply-gs_video",
|
| 82 |
+
align_to_input_ext_scale=True,
|
| 83 |
+
infer_gs=True, # Required for gs_ply and gs_video exports
|
| 84 |
+
)
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### 🎨 Advanced Export with Feature Visualization
|
| 88 |
+
```python
|
| 89 |
+
# Export with intermediate feature visualization
|
| 90 |
+
prediction = model.inference(
|
| 91 |
+
image=image_paths,
|
| 92 |
+
export_dir="./output",
|
| 93 |
+
export_format="mini_npz-glb-depth_vis-feat_vis",
|
| 94 |
+
export_feat_layers=[0, 5, 10, 15, 20],
|
| 95 |
+
feat_vis_fps=30,
|
| 96 |
+
)
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
### 📐 Using Ray-Based Pose Estimation
|
| 100 |
+
```python
|
| 101 |
+
# Use ray-based pose estimation instead of camera decoder
|
| 102 |
+
prediction = model.inference(
|
| 103 |
+
image=image_paths,
|
| 104 |
+
export_dir="./output",
|
| 105 |
+
export_format="glb",
|
| 106 |
+
use_ray_pose=True, # Enable ray-based pose estimation
|
| 107 |
+
)
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
### 🎯 Reference View Selection
|
| 111 |
+
```python
|
| 112 |
+
# For multi-view inputs, automatically select the best reference view
|
| 113 |
+
prediction = model.inference(
|
| 114 |
+
image=image_paths,
|
| 115 |
+
ref_view_strategy="saddle_balanced", # Default: balanced selection
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# For video sequences, use middle frame as reference
|
| 119 |
+
prediction = model.inference(
|
| 120 |
+
image=video_frames,
|
| 121 |
+
ref_view_strategy="middle", # Good for temporally ordered inputs
|
| 122 |
+
)
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
## 🔧 Core API
|
| 126 |
+
|
| 127 |
+
### 🔨 DepthAnything3 Class
|
| 128 |
+
|
| 129 |
+
The main API class that provides depth estimation capabilities with optional pose conditioning.
|
| 130 |
+
|
| 131 |
+
#### 🎯 Initialization
|
| 132 |
+
|
| 133 |
+
```python
|
| 134 |
+
from depth_anything_3 import DepthAnything3
|
| 135 |
+
|
| 136 |
+
# Initialize the model with a model name
|
| 137 |
+
model = DepthAnything3(model_name="da3-large")
|
| 138 |
+
model = model.to("cuda") # Move to GPU
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
**Parameters:**
|
| 142 |
+
- `model_name` (str, default: "da3-large"): The name of the model preset to use.
|
| 143 |
+
- **Available models:**
|
| 144 |
+
- 🦾 `"da3-giant"` - 1.15B params, any-view model with GS support
|
| 145 |
+
- ⭐ `"da3-large"` - 0.35B params, any-view model (recommended for most use cases)
|
| 146 |
+
- 📦 `"da3-base"` - 0.12B params, any-view model
|
| 147 |
+
- 🪶 `"da3-small"` - 0.08B params, any-view model
|
| 148 |
+
- 👁️ `"da3mono-large"` - 0.35B params, monocular depth only
|
| 149 |
+
- 📏 `"da3metric-large"` - 0.35B params, metric depth with sky segmentation
|
| 150 |
+
- 🎯 `"da3nested-giant-large"` - 1.40B params, nested model with all features
|
| 151 |
+
|
| 152 |
+
### 🚀 inference() Method
|
| 153 |
+
|
| 154 |
+
The primary inference method that processes images and returns depth predictions.
|
| 155 |
+
|
| 156 |
+
```python
|
| 157 |
+
prediction = model.inference(
|
| 158 |
+
image=image_list,
|
| 159 |
+
extrinsics=extrinsics_array, # Optional
|
| 160 |
+
intrinsics=intrinsics_array, # Optional
|
| 161 |
+
align_to_input_ext_scale=True, # Whether to align predicted poses to input scale
|
| 162 |
+
infer_gs=True, # Enable Gaussian branch for gs exports
|
| 163 |
+
use_ray_pose=False, # Use ray-based pose estimation instead of camera decoder
|
| 164 |
+
ref_view_strategy="saddle_balanced", # Reference view selection strategy
|
| 165 |
+
render_exts=render_extrinsics, # Optional renders for gs_video
|
| 166 |
+
render_ixts=render_intrinsics, # Optional renders for gs_video
|
| 167 |
+
render_hw=(height, width), # Optional renders for gs_video
|
| 168 |
+
process_res=504,
|
| 169 |
+
process_res_method="upper_bound_resize",
|
| 170 |
+
export_dir="output_directory", # Optional
|
| 171 |
+
export_format="mini_npz",
|
| 172 |
+
export_feat_layers=[], # List of layer indices to export features from
|
| 173 |
+
conf_thresh_percentile=40.0, # Confidence threshold percentile for depth map in GLB export
|
| 174 |
+
num_max_points=1_000_000, # Maximum number of points to export in GLB export
|
| 175 |
+
show_cameras=True, # Whether to show cameras in GLB export
|
| 176 |
+
feat_vis_fps=15, # Frames per second for feature visualization in feat_vis export
|
| 177 |
+
export_kwargs={} # Optional, additional arguments to export functions. export_format:key:val, see 'Parameters/Export Parameters' for details
|
| 178 |
+
)
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
## ⚙️ Parameters
|
| 182 |
+
|
| 183 |
+
### 📸 Input Parameters
|
| 184 |
+
|
| 185 |
+
#### `image` (required)
|
| 186 |
+
- **Type**: `List[Union[np.ndarray, Image.Image, str]]`
|
| 187 |
+
- **Description**: List of input images. Can be numpy arrays, PIL Images, or file paths.
|
| 188 |
+
- **Example**:
|
| 189 |
+
```python
|
| 190 |
+
# From file paths
|
| 191 |
+
image = ["image1.jpg", "image2.jpg", "image3.jpg"]
|
| 192 |
+
|
| 193 |
+
# From numpy arrays
|
| 194 |
+
image = [np.array(img1), np.array(img2)]
|
| 195 |
+
|
| 196 |
+
# From PIL Images
|
| 197 |
+
image = [Image.open("image1.jpg"), Image.open("image2.jpg")]
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
#### `extrinsics` (optional)
|
| 201 |
+
- **Type**: `Optional[np.ndarray]`
|
| 202 |
+
- **Shape**: `(N, 4, 4)` where N is the number of input images
|
| 203 |
+
- **Description**: Camera extrinsic matrices (world-to-camera transformation). When provided, enables pose-conditioned depth estimation mode.
|
| 204 |
+
- **Note**: If not provided, the model operates in standard depth estimation mode.
|
| 205 |
+
|
| 206 |
+
#### `intrinsics` (optional)
|
| 207 |
+
- **Type**: `Optional[np.ndarray]`
|
| 208 |
+
- **Shape**: `(N, 3, 3)` where N is the number of input images
|
| 209 |
+
- **Description**: Camera intrinsic matrices containing focal length and principal point information. When provided, enables pose-conditioned depth estimation mode.
|
| 210 |
+
|
| 211 |
+
### 🎯 Pose Alignment Parameters
|
| 212 |
+
|
| 213 |
+
#### `align_to_input_ext_scale` (default: True)
|
| 214 |
+
- **Type**: `bool`
|
| 215 |
+
- **Description**: When True the predicted extrinsics are replaced with the input
|
| 216 |
+
ones and the depth maps are rescaled to match their metric scale. When False the
|
| 217 |
+
function returns the internally aligned poses computed via Umeyama alignment.
|
| 218 |
+
|
| 219 |
+
#### `infer_gs` (default: False)
|
| 220 |
+
- **Type**: `bool`
|
| 221 |
+
- **Description**: Enable Gaussian Splatting branch for gaussian splatting exports. Required when using `gs_ply` or `gs_video` export formats.
|
| 222 |
+
|
| 223 |
+
#### `use_ray_pose` (default: False)
|
| 224 |
+
- **Type**: `bool`
|
| 225 |
+
- **Description**: Use ray-based pose estimation instead of camera decoder for pose prediction. When True, the model uses ray prediction heads to estimate camera poses; when False, it uses the camera decoder approach.
|
| 226 |
+
|
| 227 |
+
#### `ref_view_strategy` (default: "saddle_balanced")
|
| 228 |
+
- **Type**: `str`
|
| 229 |
+
- **Description**: Strategy for selecting the reference view from multiple input views. Options: `"first"`, `"middle"`, `"saddle_balanced"`, `"saddle_sim_range"`. Only applied when number of views ≥ 3. See [detailed documentation](funcs/ref_view_strategy.md) for strategy comparisons.
|
| 230 |
+
- **Available strategies**:
|
| 231 |
+
- `"saddle_balanced"`: Selects view with balanced features across multiple metrics (recommended default)
|
| 232 |
+
- `"saddle_sim_range"`: Selects view with largest similarity range
|
| 233 |
+
- `"first"`: Always uses first view (not recommended, equivalent to no reordering for views < 3)
|
| 234 |
+
- `"middle"`: Uses middle view (recommended for video sequences)
|
| 235 |
+
|
| 236 |
+
### 🔍 Feature Export Parameters
|
| 237 |
+
|
| 238 |
+
#### `export_feat_layers` (default: [])
|
| 239 |
+
- **Type**: `List[int]`
|
| 240 |
+
- **Description**: List of layer indices to export intermediate features from. Features are stored in the `aux` dictionary of the Prediction object with keys like `feat_layer_0`, `feat_layer_1`, etc.
|
| 241 |
+
|
| 242 |
+
### 🎥 Rendering Parameters
|
| 243 |
+
|
| 244 |
+
These arguments are only used when exporting Gaussian-splatting videos (include
|
| 245 |
+
`"gs_video"` in `export_format`). They describe an auxiliary camera trajectory
|
| 246 |
+
with ``M`` views.
|
| 247 |
+
|
| 248 |
+
#### `render_exts` (optional)
|
| 249 |
+
- **Type**: `Optional[np.ndarray]`
|
| 250 |
+
- **Shape**: `(M, 4, 4)`
|
| 251 |
+
- **Description**: Camera extrinsics for the synthesized trajectory. If omitted,
|
| 252 |
+
the exporter falls back to the predicted poses.
|
| 253 |
+
|
| 254 |
+
#### `render_ixts` (optional)
|
| 255 |
+
- **Type**: `Optional[np.ndarray]`
|
| 256 |
+
- **Shape**: `(M, 3, 3)`
|
| 257 |
+
- **Description**: Camera intrinsics for each rendered frame. Leave `None` to
|
| 258 |
+
reuse the input intrinsics.
|
| 259 |
+
|
| 260 |
+
#### `render_hw` (optional)
|
| 261 |
+
- **Type**: `Optional[Tuple[int, int]]`
|
| 262 |
+
- **Description**: Explicit output resolution `(height, width)` for the rendered
|
| 263 |
+
frames. Defaults to the input resolution when not provided.
|
| 264 |
+
|
| 265 |
+
### ⚡ Processing Parameters
|
| 266 |
+
|
| 267 |
+
#### `process_res` (default: 504)
|
| 268 |
+
- **Type**: `int`
|
| 269 |
+
- **Description**: Base resolution for processing. The model will resize images to this resolution for inference.
|
| 270 |
+
|
| 271 |
+
#### `process_res_method` (default: "upper_bound_resize")
|
| 272 |
+
- **Type**: `str`
|
| 273 |
+
- **Description**: Method for resizing images to the target resolution.
|
| 274 |
+
- **Options**:
|
| 275 |
+
- `"upper_bound_resize"`: Resize so that the specified dimension (504) becomes the longer side
|
| 276 |
+
- `"lower_bound_resize"`: Resize so that the specified dimension (504) becomes the shorter side
|
| 277 |
+
- **Example**:
|
| 278 |
+
- Input: 1200×1600 → Output: 378×504 (with `process_res=504`, `process_res_method="upper_bound_resize"`)
|
| 279 |
+
- Input: 504×672 → Output: 504×672 (no change needed)
|
| 280 |
+
|
| 281 |
+
### 📦 Export Parameters
|
| 282 |
+
|
| 283 |
+
#### `export_dir` (optional)
|
| 284 |
+
- **Type**: `Optional[str]`
|
| 285 |
+
- **Description**: Directory path where exported files will be saved. If not provided, no files will be exported.
|
| 286 |
+
|
| 287 |
+
#### `export_format` (default: "mini_npz")
|
| 288 |
+
- **Type**: `str`
|
| 289 |
+
- **Description**: Format for exporting results. Supports multiple formats separated by `-`.
|
| 290 |
+
- **Example**: `"mini_npz-glb"` exports both mini_npz and glb formats.
|
| 291 |
+
|
| 292 |
+
#### 🌐 GLB Export Parameters
|
| 293 |
+
|
| 294 |
+
These parameters are passed directly to the `inference()` method and only apply when `export_format` includes `"glb"`.
|
| 295 |
+
|
| 296 |
+
##### `conf_thresh_percentile` (default: 40.0)
|
| 297 |
+
- **Type**: `float`
|
| 298 |
+
- **Description**: Lower percentile for adaptive confidence threshold. Points below this confidence percentile will be filtered out from the point cloud.
|
| 299 |
+
|
| 300 |
+
##### `num_max_points` (default: 1,000,000)
|
| 301 |
+
- **Type**: `int`
|
| 302 |
+
- **Description**: Maximum number of points in the exported point cloud. If the point cloud exceeds this limit, it will be downsampled.
|
| 303 |
+
|
| 304 |
+
##### `show_cameras` (default: True)
|
| 305 |
+
- **Type**: `bool`
|
| 306 |
+
- **Description**: Whether to include camera wireframes in the exported GLB file for visualization.
|
| 307 |
+
|
| 308 |
+
#### 🎨 Feature Visualization Parameters
|
| 309 |
+
|
| 310 |
+
These parameters are passed directly to the `inference()` method and only apply when `export_format` includes `"feat_vis"`.
|
| 311 |
+
|
| 312 |
+
##### `feat_vis_fps` (default: 15)
|
| 313 |
+
- **Type**: `int`
|
| 314 |
+
- **Description**: Frame rate for the output video when visualizing features across multiple images.
|
| 315 |
+
|
| 316 |
+
#### ✨🎥 3DGS and 3DGS Video Parameters
|
| 317 |
+
|
| 318 |
+
These parameters are passed directly to the `inference()` method and only apply when `export_format` includes `"gs_ply"` or `"gs_video"`.
|
| 319 |
+
|
| 320 |
+
##### `export_kwargs` (default: `{}`)
|
| 321 |
+
- Type: `dict[str, dict[str, Any]]`
|
| 322 |
+
- Description: Per-format extra arguments passed to export functions, mainly for `"gs_ply"` and `"gs_video"`.
|
| 323 |
+
- Access pattern: `export_kwargs[export_format][key] = value`
|
| 324 |
+
- Example:
|
| 325 |
+
```python
|
| 326 |
+
{
|
| 327 |
+
"gs_ply": {
|
| 328 |
+
"gs_views_interval": 1,
|
| 329 |
+
},
|
| 330 |
+
"gs_video": {
|
| 331 |
+
"trj_mode": "interpolate_smooth",
|
| 332 |
+
"chunk_size": 1,
|
| 333 |
+
"vis_depth": None,
|
| 334 |
+
},
|
| 335 |
+
}
|
| 336 |
+
```
|
| 337 |
+
|
| 338 |
+
## 📤 Export Formats
|
| 339 |
+
|
| 340 |
+
The API supports multiple export formats for different use cases:
|
| 341 |
+
|
| 342 |
+
### 📊 `mini_npz`
|
| 343 |
+
- **Description**: Minimal NPZ format containing essential data
|
| 344 |
+
- **Contents**: `depth`, `conf`, `exts`, `ixts`
|
| 345 |
+
- **Use case**: Lightweight storage for depth data with camera parameters
|
| 346 |
+
|
| 347 |
+
### 📦 `npz`
|
| 348 |
+
- **Description**: Full NPZ format with comprehensive data
|
| 349 |
+
- **Contents**: `depth`, `conf`, `exts`, `ixts`, `image`, etc.
|
| 350 |
+
- **Use case**: Complete data export for advanced processing
|
| 351 |
+
|
| 352 |
+
### 🌐 `glb`
|
| 353 |
+
- **Description**: 3D visualization format with point cloud and camera poses
|
| 354 |
+
- **Contents**:
|
| 355 |
+
- Point cloud with colors from original images
|
| 356 |
+
- Camera wireframes for visualization
|
| 357 |
+
- Confidence-based filtering and downsampling
|
| 358 |
+
- **Use case**: 3D visualization, inspection, and analysis
|
| 359 |
+
- **Features**:
|
| 360 |
+
- Automatic sky depth handling
|
| 361 |
+
- Confidence threshold filtering
|
| 362 |
+
- Background filtering (black/white)
|
| 363 |
+
- Scene scale normalization
|
| 364 |
+
- **Parameters** (passed via `inference()` method directly):
|
| 365 |
+
- `conf_thresh_percentile` (float, default: 40.0): Lower percentile for adaptive confidence threshold. Points below this confidence percentile will be filtered out.
|
| 366 |
+
- `num_max_points` (int, default: 1,000,000): Maximum number of points in the exported point cloud. If exceeded, points will be downsampled.
|
| 367 |
+
- `show_cameras` (bool, default: True): Whether to include camera wireframes in the exported GLB file for visualization.
|
| 368 |
+
|
| 369 |
+
### ✨ `gs_ply`
|
| 370 |
+
- **Description**: Gaussian Splatting point cloud format
|
| 371 |
+
- **Contents**: 3DGS data in PLY format. Compatible with standard 3DGS viewers such as [SuperSplat](https://superspl.at/editor) (recommended), [SPARK](https://sparkjs.dev/viewer/).
|
| 372 |
+
- **Use case**: Gaussian Splatting reconstruction
|
| 373 |
+
- **Requirements**: Must set `infer_gs=True` when calling `inference()`. Only supported by `da3-giant` and `da3nested-giant-large` models.
|
| 374 |
+
- **Additional configs**, provided via `export_kwargs` (see [Export Parameters](#export-parameters)):
|
| 375 |
+
- `gs_views_interval`: Export to 3DGS every N views, default: `1`.
|
| 376 |
+
|
| 377 |
+
### 🎥 `gs_video`
|
| 378 |
+
- **Description**: Rasterized 3DGS to obtain videos
|
| 379 |
+
- **Contents**: A video of 3DGS-rasterized views using either provided viewpoints or a predefined camera trajectory.
|
| 380 |
+
- **Use case**: Video rendering for Gaussian Splatting
|
| 381 |
+
- **Requirements**: Must set `infer_gs=True` when calling `inference()`. Only supported by `da3-giant` and `da3nested-giant-large` models.
|
| 382 |
+
- **Note**: Can optionally use `render_exts`, `render_ixts`, and `render_hw` parameters in `inference()` method to specify novel viewpoints.
|
| 383 |
+
- **Additional configs**, provided via `export_kwargs` (see [Export Parameters](#export-parameters)):
|
| 384 |
+
- `extrinsics`: Optional world-to-camera poses for novel views. Falls back to the predicted poses of input views if not provided. (Alternatively, use `render_exts` parameter in `inference()`)
|
| 385 |
+
- `intrinsics`: Optional camera intrinsics for novel views. Falls back to the predicted intrinsics of input views if not provided. (Alternatively, use `render_ixts` parameter in `inference()`)
|
| 386 |
+
- `out_image_hw`: Optional output resolution `H x W`. Falls back to input resolution if not provided. (Alternatively, use `render_hw` parameter in `inference()`)
|
| 387 |
+
- `chunk_size`: Number of views rasterized per batch. Default: `8`.
|
| 388 |
+
- `trj_mode`: Predefined camera trajectory for novel-view rendering.
|
| 389 |
+
- `color_mode`: Same as `render_mode` in [gsplat](https://docs.gsplat.studio/main/apis/rasterization.html#gsplat.rasterization).
|
| 390 |
+
- `vis_depth`: How depth is combined with RGB. Default: `hcat` (horizontal concatenation).
|
| 391 |
+
- `enable_tqdm`: Whether to display a tqdm progress bar during rendering.
|
| 392 |
+
- `output_name`: File name of the rendered video.
|
| 393 |
+
- `video_quality`: Video quality to save. Default: `high`.
|
| 394 |
+
- `high`: High quality video (default)
|
| 395 |
+
- `medium`: Medium quality video (balance of storage space and quality)
|
| 396 |
+
- `low`: Low quality video (fewer storage space)
|
| 397 |
+
|
| 398 |
+
### 🔍 `feat_vis`
|
| 399 |
+
- **Description**: Feature visualization format
|
| 400 |
+
- **Contents**: PCA-visualized intermediate features from specified layers
|
| 401 |
+
- **Use case**: Model interpretability and feature analysis
|
| 402 |
+
- **Note**: Requires `export_feat_layers` to be specified
|
| 403 |
+
- **Parameters** (passed via `inference()` method directly):
|
| 404 |
+
- `feat_vis_fps` (int, default: 15): Frame rate for the output video when visualizing features across multiple images.
|
| 405 |
+
|
| 406 |
+
### 🎨 `depth_vis`
|
| 407 |
+
- **Description**: Depth visualization format
|
| 408 |
+
- **Contents**: Color-coded depth maps alongside original images
|
| 409 |
+
- **Use case**: Visual inspection of depth estimation quality
|
| 410 |
+
|
| 411 |
+
### 🔗 Multiple Format Export
|
| 412 |
+
You can export multiple formats simultaneously by separating them with `-`:
|
| 413 |
+
|
| 414 |
+
```python
|
| 415 |
+
# Export both mini_npz and glb formats
|
| 416 |
+
export_format = "mini_npz-glb"
|
| 417 |
+
|
| 418 |
+
# Export multiple formats
|
| 419 |
+
export_format = "npz-glb-gs_ply"
|
| 420 |
+
```
|
| 421 |
+
|
| 422 |
+
## ↩️ Return Value
|
| 423 |
+
|
| 424 |
+
The `inference()` method returns a `Prediction` object with the following attributes:
|
| 425 |
+
|
| 426 |
+
### 📊 Core Outputs
|
| 427 |
+
|
| 428 |
+
- **depth**: `np.ndarray` - Estimated depth maps with shape `(N, H, W)` where N is the number of images, H is height, and W is width.
|
| 429 |
+
- **conf**: `np.ndarray` - Confidence maps with shape `(N, H, W)` indicating prediction reliability (optional, depends on model).
|
| 430 |
+
|
| 431 |
+
### 📷 Camera Parameters
|
| 432 |
+
|
| 433 |
+
- **extrinsics**: `np.ndarray` - Camera extrinsic matrices with shape `(N, 3, 4)` representing world-to-camera transformations. Only present if camera poses were estimated or provided as input.
|
| 434 |
+
- **intrinsics**: `np.ndarray` - Camera intrinsic matrices with shape `(N, 3, 3)` containing focal length and principal point information. Only present if poses were estimated or provided as input.
|
| 435 |
+
|
| 436 |
+
### 🎁 Additional Outputs
|
| 437 |
+
|
| 438 |
+
- **processed_images**: `np.ndarray` - Preprocessed input images with shape `(N, H, W, 3)` in RGB format (0-255 uint8).
|
| 439 |
+
- **aux**: `dict` - Auxiliary outputs including:
|
| 440 |
+
- `feat_layer_X`: Intermediate features from layer X (if `export_feat_layers` was specified)
|
| 441 |
+
- `gaussians`: 3D Gaussian Splats data (if `infer_gs=True`)
|
| 442 |
+
|
| 443 |
+
### 💻 Usage Example
|
| 444 |
+
|
| 445 |
+
```python
|
| 446 |
+
prediction = model.inference(image=["img1.jpg", "img2.jpg"])
|
| 447 |
+
|
| 448 |
+
# Access depth maps
|
| 449 |
+
depth_maps = prediction.depth # shape: (2, H, W)
|
| 450 |
+
|
| 451 |
+
# Access confidence
|
| 452 |
+
if hasattr(prediction, 'conf'):
|
| 453 |
+
confidence = prediction.conf
|
| 454 |
+
|
| 455 |
+
# Access camera parameters (if available)
|
| 456 |
+
if hasattr(prediction, 'extrinsics'):
|
| 457 |
+
camera_poses = prediction.extrinsics # shape: (2, 4, 4)
|
| 458 |
+
|
| 459 |
+
if hasattr(prediction, 'intrinsics'):
|
| 460 |
+
camera_intrinsics = prediction.intrinsics # shape: (2, 3, 3)
|
| 461 |
+
|
| 462 |
+
# Access intermediate features (if export_feat_layers was set)
|
| 463 |
+
if hasattr(prediction, 'aux') and 'feat_layer_0' in prediction.aux:
|
| 464 |
+
features = prediction.aux['feat_layer_0']
|
| 465 |
+
```
|
Depth-Anything-3/docs/BENCHMARK.md
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 📏 Visual Geometry Benchmark
|
| 2 |
+
|
| 3 |
+
This document provides comprehensive instructions for running benchmark evaluation on Depth Anything 3.
|
| 4 |
+
|
| 5 |
+
## ✨ Highlights
|
| 6 |
+
|
| 7 |
+
- 🗂️ **Diverse and Challenging Datasets**: 5 datasets (ETH3D, 7Scenes, ScanNet++, HiRoom, DTU) covering from objects to indoor and outdoor scenes. Part of datasets are recalibrated for high accuracy (see [ScanNet++](#scannet) details). All preprocessed datasets are uploaded to [depth-anything/DA3-BENCH](https://huggingface.co/datasets/depth-anything/DA3-BENCH).
|
| 8 |
+
- 🔧 **Robust Evaluation Pipeline**: Standardized pipeline featuring RANSAC-based pose alignment for better coordinate system alignment, TSDF fusion for directly reflecting depth 3D consistency.
|
| 9 |
+
- 📊 **Standardized Metrics**: Performance measured using established metrics: AUC for pose accuracy, F1-score and Chamfer Distance for reconstruction.
|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## 📑 Table of Contents
|
| 14 |
+
|
| 15 |
+
- [🚀 Quick Start](#quick-start)
|
| 16 |
+
- [📥 Dataset Download](#dataset-download)
|
| 17 |
+
- [⚙️ Evaluation Pipeline](#evaluation-pipeline)
|
| 18 |
+
- [🔧 Configuration](#configuration)
|
| 19 |
+
- [📊 Metrics](#metrics)
|
| 20 |
+
- [🗂️ Dataset Details](#dataset-details)
|
| 21 |
+
- [💻 Command Reference](#command-reference)
|
| 22 |
+
- [🔍 Troubleshooting](#troubleshooting)
|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
## 🚀 Quick Start
|
| 27 |
+
|
| 28 |
+
### 1. Download Benchmark Data
|
| 29 |
+
|
| 30 |
+
> 💡 **Note:** Install HuggingFace CLI first: `pip install -U huggingface_hub[cli]`
|
| 31 |
+
>
|
| 32 |
+
> 🌐 **Mirror:** If download is slow, try: `export HF_ENDPOINT=https://hf-mirror.com`
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
cd da3_release
|
| 36 |
+
|
| 37 |
+
# Create directory and download from HuggingFace
|
| 38 |
+
mkdir -p workspace/benchmark_dataset
|
| 39 |
+
hf download depth-anything/DA3-BENCH \
|
| 40 |
+
--local-dir workspace/benchmark_dataset \
|
| 41 |
+
--repo-type dataset
|
| 42 |
+
|
| 43 |
+
# Extract all datasets
|
| 44 |
+
cd workspace/benchmark_dataset
|
| 45 |
+
for f in *.zip; do unzip -q "$f"; done
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### 2. Run Evaluation
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
# Set model (default: depth-anything/DA3-GIANT)
|
| 52 |
+
MODEL=depth-anything/DA3-GIANT
|
| 53 |
+
|
| 54 |
+
# Full evaluation (all datasets, all modes)
|
| 55 |
+
python -m depth_anything_3.bench.evaluator model.path=$MODEL
|
| 56 |
+
|
| 57 |
+
# View results
|
| 58 |
+
python -m depth_anything_3.bench.evaluator eval.print_only=true
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
---
|
| 62 |
+
|
| 63 |
+
## 📥 Dataset Download
|
| 64 |
+
|
| 65 |
+
All benchmark datasets are hosted on HuggingFace: **[depth-anything/DA3-BENCH](https://huggingface.co/datasets/depth-anything/DA3-BENCH)**
|
| 66 |
+
|
| 67 |
+
| Dataset | File | Size | Description |
|
| 68 |
+
|---------|------|------|-------------|
|
| 69 |
+
| ETH3D | `eth3d.zip` | ~14.1 GB | High-resolution multi-view stereo (indoor/outdoor) |
|
| 70 |
+
| ScanNet++ | `scannetpp.zip` | ~10.1 GB | High-quality RGB-D indoor scenes |
|
| 71 |
+
| DTU-49 | `dtu.zip` | ~8.3 GB | Multi-view stereo benchmark (22 scenes × 49 views) |
|
| 72 |
+
| 7Scenes | `7scenes.zip` | ~3.3 GB | RGB-D indoor localization |
|
| 73 |
+
| DTU-64 | `dtu64.zip` | ~1.7 GB | DTU subset for pose evaluation (13 scenes × 64 views) |
|
| 74 |
+
| HiRoom | `hiroom.zip` | ~0.7 GB | High-resolution indoor rooms |
|
| 75 |
+
|
| 76 |
+
### Download Options
|
| 77 |
+
|
| 78 |
+
**Option 1: Download All (Recommended)**
|
| 79 |
+
```bash
|
| 80 |
+
hf download depth-anything/DA3-BENCH \
|
| 81 |
+
--local-dir workspace/benchmark_dataset \
|
| 82 |
+
--repo-type dataset
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
**Option 2: Download Specific Dataset**
|
| 86 |
+
```bash
|
| 87 |
+
# Download only HiRoom
|
| 88 |
+
hf download depth-anything/DA3-BENCH hiroom.zip \
|
| 89 |
+
--local-dir workspace/benchmark_dataset \
|
| 90 |
+
--repo-type dataset
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
**Option 3: Manual Download**
|
| 94 |
+
|
| 95 |
+
Visit [https://huggingface.co/datasets/depth-anything/DA3-BENCH](https://huggingface.co/datasets/depth-anything/DA3-BENCH) and download the zip files manually.
|
| 96 |
+
|
| 97 |
+
### Extract Datasets
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
cd workspace/benchmark_dataset
|
| 101 |
+
|
| 102 |
+
# Extract all
|
| 103 |
+
for f in *.zip; do unzip -q "$f"; done
|
| 104 |
+
|
| 105 |
+
# Or extract specific dataset
|
| 106 |
+
unzip hiroom.zip
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
### Expected Directory Structure
|
| 110 |
+
|
| 111 |
+
After extraction, your directory should look like:
|
| 112 |
+
```
|
| 113 |
+
workspace/benchmark_dataset/
|
| 114 |
+
├── eth3d/
|
| 115 |
+
│ ├── courtyard/
|
| 116 |
+
│ ├── electro/
|
| 117 |
+
│ └── ...
|
| 118 |
+
├── 7scenes/
|
| 119 |
+
│ └── 7Scenes/
|
| 120 |
+
│ ├── chess/
|
| 121 |
+
│ └── ...
|
| 122 |
+
├── scannetpp/
|
| 123 |
+
│ ├── 09c1414f1b/
|
| 124 |
+
│ └── ...
|
| 125 |
+
├── hiroom/
|
| 126 |
+
│ ├── data/
|
| 127 |
+
│ ├── fused_pcd/
|
| 128 |
+
│ └── selected_scene_list_val.txt
|
| 129 |
+
├── dtu/
|
| 130 |
+
│ ├── Rectified/
|
| 131 |
+
│ ├── Cameras/
|
| 132 |
+
│ ├── Points/
|
| 133 |
+
│ ├── SampleSet/
|
| 134 |
+
│ └── depth_raw/
|
| 135 |
+
└── dtu64/
|
| 136 |
+
├── Cameras/
|
| 137 |
+
├── scan105/
|
| 138 |
+
└── ...
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
---
|
| 142 |
+
|
| 143 |
+
## ⚙️ Evaluation Pipeline
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
### Evaluation Modes
|
| 148 |
+
|
| 149 |
+
| Mode | Description | Metrics |
|
| 150 |
+
|------|-------------|---------|
|
| 151 |
+
| `pose` | Camera pose estimation | AUC@3°, AUC@30° |
|
| 152 |
+
| `recon_unposed` | 3D reconstruction with **predicted** poses | F-score, Overall |
|
| 153 |
+
| `recon_posed` | 3D reconstruction with **GT** poses | F-score, Overall |
|
| 154 |
+
|
| 155 |
+
### Basic Usage
|
| 156 |
+
|
| 157 |
+
```bash
|
| 158 |
+
cd da3_release
|
| 159 |
+
MODEL=depth-anything/DA3-GIANT
|
| 160 |
+
|
| 161 |
+
# Full evaluation (inference + evaluation + print results)
|
| 162 |
+
python -m depth_anything_3.bench.evaluator model.path=$MODEL
|
| 163 |
+
|
| 164 |
+
# Skip inference, only evaluate existing predictions
|
| 165 |
+
python -m depth_anything_3.bench.evaluator eval.eval_only=true
|
| 166 |
+
|
| 167 |
+
# Only print saved metrics
|
| 168 |
+
python -m depth_anything_3.bench.evaluator eval.print_only=true
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
### Selective Evaluation
|
| 172 |
+
|
| 173 |
+
```bash
|
| 174 |
+
# Evaluate specific datasets
|
| 175 |
+
python -m depth_anything_3.bench.evaluator model.path=$MODEL eval.datasets=[hiroom]
|
| 176 |
+
|
| 177 |
+
# Evaluate specific modes
|
| 178 |
+
python -m depth_anything_3.bench.evaluator model.path=$MODEL eval.modes=[pose,recon_unposed]
|
| 179 |
+
|
| 180 |
+
# Combine dataset and mode selection
|
| 181 |
+
python -m depth_anything_3.bench.evaluator model.path=$MODEL \
|
| 182 |
+
eval.datasets=[hiroom] \
|
| 183 |
+
eval.modes=[pose]
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
### 🖥️ Multi-GPU Inference
|
| 187 |
+
|
| 188 |
+
The evaluator automatically distributes inference across available GPUs:
|
| 189 |
+
|
| 190 |
+
```bash
|
| 191 |
+
# Use 4 GPUs
|
| 192 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m depth_anything_3.bench.evaluator model.path=$MODEL
|
| 193 |
+
|
| 194 |
+
# Use all available GPUs (default)
|
| 195 |
+
python -m depth_anything_3.bench.evaluator model.path=$MODEL
|
| 196 |
+
|
| 197 |
+
# Single GPU
|
| 198 |
+
CUDA_VISIBLE_DEVICES=0 python -m depth_anything_3.bench.evaluator model.path=$MODEL
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
## 🔧 Configuration
|
| 204 |
+
|
| 205 |
+
### Config File
|
| 206 |
+
|
| 207 |
+
Default config: `src/depth_anything_3/bench/configs/eval_bench.yaml`
|
| 208 |
+
|
| 209 |
+
```yaml
|
| 210 |
+
# Model path
|
| 211 |
+
model:
|
| 212 |
+
path: depth-anything/DA3-GIANT
|
| 213 |
+
|
| 214 |
+
# Workspace directory
|
| 215 |
+
workspace:
|
| 216 |
+
work_dir: ./workspace/evaluation
|
| 217 |
+
|
| 218 |
+
# Evaluation settings
|
| 219 |
+
eval:
|
| 220 |
+
datasets: [eth3d, 7scenes, scannetpp, hiroom, dtu, dtu64]
|
| 221 |
+
modes: [pose, recon_unposed, recon_posed]
|
| 222 |
+
max_frames: 100 # Max frames per scene (-1 = no limit)
|
| 223 |
+
scenes: null # Specific scenes (null = all)
|
| 224 |
+
|
| 225 |
+
# Inference settings
|
| 226 |
+
inference:
|
| 227 |
+
num_fusion_workers: 4
|
| 228 |
+
debug: false
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
### Output Structure
|
| 232 |
+
|
| 233 |
+
```
|
| 234 |
+
workspace/evaluation/
|
| 235 |
+
├── model_results/ # Inference outputs
|
| 236 |
+
│ ├── eth3d/
|
| 237 |
+
│ │ └── {scene}/
|
| 238 |
+
│ │ ├── unposed/ # Predictions for recon_unposed
|
| 239 |
+
│ │ └── posed/ # Predictions for recon_posed
|
| 240 |
+
│ ├── 7scenes/
|
| 241 |
+
│ ├── scannetpp/
|
| 242 |
+
│ ├── hiroom/
|
| 243 |
+
│ ├── dtu/
|
| 244 |
+
│ └── dtu64/
|
| 245 |
+
└── metric_results/ # Evaluation metrics (JSON)
|
| 246 |
+
├── eth3d_pose.json
|
| 247 |
+
├── eth3d_recon_unposed.json
|
| 248 |
+
├── eth3d_recon_posed.json
|
| 249 |
+
└── ...
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
---
|
| 253 |
+
|
| 254 |
+
## 📊 Metrics
|
| 255 |
+
|
| 256 |
+
### 🎯 Pose Estimation
|
| 257 |
+
|
| 258 |
+
| Metric | Description |
|
| 259 |
+
|--------|-------------|
|
| 260 |
+
| **Auc3** | Area Under Curve at 3° angular error threshold |
|
| 261 |
+
| **Auc30** | Area Under Curve at 30° angular error threshold |
|
| 262 |
+
|
| 263 |
+
### 🏗️ 3D Reconstruction
|
| 264 |
+
|
| 265 |
+
| Metric | Description | Note |
|
| 266 |
+
|--------|-------------|------|
|
| 267 |
+
| **F-score** | Harmonic mean of Precision and Recall | Higher is better |
|
| 268 |
+
| **Overall** | (Accuracy + Completeness) / 2 | Lower is better (error in meters/mm) |
|
| 269 |
+
|
| 270 |
+
> **Note:** DTU reports Overall in millimeters; other datasets report in meters.
|
| 271 |
+
|
| 272 |
+
### Expected Results for DA3-GIANT
|
| 273 |
+
|
| 274 |
+
If your setup is correct, you should get the following results when evaluating the **DA3-GIANT** model:
|
| 275 |
+
|
| 276 |
+
```
|
| 277 |
+
========================================================
|
| 278 |
+
📊 SUMMARY
|
| 279 |
+
========================================================
|
| 280 |
+
|
| 281 |
+
🎯 POSE ESTIMATION
|
| 282 |
+
---------------------------------------------------------------------------------------
|
| 283 |
+
Metric Avg HiRoom ETH3D DTU-64 7Scenes ScanNet++
|
| 284 |
+
---------------------------------------------------------------------------------------
|
| 285 |
+
Auc3 0.6705 0.8030 0.4872 0.9408 0.2744 0.8470
|
| 286 |
+
Auc30 0.9436 0.9592 0.9153 0.9939 0.8668 0.9827
|
| 287 |
+
|
| 288 |
+
🏗️ RECON_UNPOSED (Pred Pose)
|
| 289 |
+
---------------------------------------------------------------------------------------
|
| 290 |
+
Metric Avg* HiRoom ETH3D DTU 7Scenes ScanNet++
|
| 291 |
+
---------------------------------------------------------------------------------------
|
| 292 |
+
F-score 0.7345 0.8629 0.7876 N/A 0.5043 0.7831
|
| 293 |
+
Overall 0.1682 0.0457 0.4366 1.7927 0.1230 0.0676
|
| 294 |
+
|
| 295 |
+
🏗️ RECON_POSED (GT Pose)
|
| 296 |
+
---------------------------------------------------------------------------------------
|
| 297 |
+
Metric Avg* HiRoom ETH3D DTU 7Scenes ScanNet++
|
| 298 |
+
---------------------------------------------------------------------------------------
|
| 299 |
+
F-score 0.7978 0.9546 0.8685 N/A 0.5635 0.8045
|
| 300 |
+
Overall 0.1408 0.0213 0.3679 1.7488 0.1092 0.0649
|
| 301 |
+
|
| 302 |
+
* Avg F-score / Overall = average over HiRoom, ETH3D, 7Scenes, ScanNet++ (4 datasets)
|
| 303 |
+
```
|
| 304 |
+
|
| 305 |
+
---
|
| 306 |
+
|
| 307 |
+
## 🗂️ Dataset Details
|
| 308 |
+
|
| 309 |
+
### ETH3D
|
| 310 |
+
|
| 311 |
+
High-resolution multi-view stereo benchmark with laser-scanned ground truth.
|
| 312 |
+
|
| 313 |
+
- **Scenes:** 11 (courtyard, electro, kicker, pipes, relief, delivery_area, facade, office, playground, relief_2, terrains)
|
| 314 |
+
- **Resolution:** Variable (high-res DSLR images)
|
| 315 |
+
- **GT:** Laser-scanned meshes + depth maps
|
| 316 |
+
|
| 317 |
+
> **⚠️ Image Filtering:** Some images with unusual camera rotations are filtered out for stable evaluation. See `ETH3D_FILTER_KEYS` in `constants.py`.
|
| 318 |
+
|
| 319 |
+
### 7Scenes
|
| 320 |
+
|
| 321 |
+
RGB-D dataset for camera relocalization.
|
| 322 |
+
|
| 323 |
+
- **Scenes:** 7 (chess, fire, heads, office, pumpkin, redkitchen, stairs)
|
| 324 |
+
- **Resolution:** 640×480
|
| 325 |
+
- **GT:** Poses from KinectFusion, meshes from TSDF fusion
|
| 326 |
+
|
| 327 |
+
### ScanNet++
|
| 328 |
+
|
| 329 |
+
High-quality indoor RGB-D dataset with dense annotations.
|
| 330 |
+
|
| 331 |
+
- **Scenes:** 20 validation scenes
|
| 332 |
+
- **Resolution:** 768×1024 (after undistortion)
|
| 333 |
+
- **GT:** High-quality meshes from FARO scanner
|
| 334 |
+
|
| 335 |
+
> **⚠️ Camera Pose Re-calibration:** The default ScanNet++ poses are often inaccurate due to motion blur and textureless frames from iPhone captures. We re-ran COLMAP with the following improvements:
|
| 336 |
+
> - **Frame filtering:** Removed blurry images during frame extraction
|
| 337 |
+
> - **Fisheye calibration:** Jointly calibrated fisheye camera for wider FOV and better accuracy
|
| 338 |
+
> - **Exhaustive matching:** Used COLMAP's exhaustive matcher and mapper for reliable poses (takes several days per scene but necessary for quality)
|
| 339 |
+
> - All processed scenes are available at [haotongl/scannetpp_zipnerf](https://huggingface.co/datasets/haotongl/scannetpp_zipnerf)
|
| 340 |
+
|
| 341 |
+
### HiRoom
|
| 342 |
+
|
| 343 |
+
Indoor room scenes with high-resolution RGB-D data.
|
| 344 |
+
|
| 345 |
+
- **Scenes:** 24 validation scenes
|
| 346 |
+
- **GT:** Fused point clouds
|
| 347 |
+
|
| 348 |
+
### DTU-49 (Reconstruction Only)
|
| 349 |
+
|
| 350 |
+
Multi-view stereo benchmark following MVSNet evaluation protocol.
|
| 351 |
+
|
| 352 |
+
- **Scenes:** 22 evaluation scenes
|
| 353 |
+
- **Views:** 49 images per scene
|
| 354 |
+
- **GT:** Laser-scanned point clouds with observation masks
|
| 355 |
+
- **Metrics:** Overall only (accuracy + completeness in mm)
|
| 356 |
+
|
| 357 |
+
### DTU-64 (Pose Only)
|
| 358 |
+
|
| 359 |
+
DTU subset for pose estimation evaluation.
|
| 360 |
+
|
| 361 |
+
- **Scenes:** 13 scenes
|
| 362 |
+
- **Views:** 64 images per scene
|
| 363 |
+
- **Metrics:** AUC@3°, AUC@30°
|
| 364 |
+
|
| 365 |
+
> **Why two DTU settings?**
|
| 366 |
+
> - **DTU-64** (pose): More views = more challenging pose estimation
|
| 367 |
+
> - **DTU-49** (recon): Standard MVSNet protocol for fair comparison with MVS methods
|
| 368 |
+
|
| 369 |
+
---
|
| 370 |
+
|
| 371 |
+
## 💻 Command Reference
|
| 372 |
+
|
| 373 |
+
```
|
| 374 |
+
python -m depth_anything_3.bench.evaluator [OPTIONS] [KEY=VALUE ...]
|
| 375 |
+
|
| 376 |
+
Configuration:
|
| 377 |
+
--config PATH Config YAML file (default: bench/configs/eval_bench.yaml)
|
| 378 |
+
|
| 379 |
+
Config Overrides (using dotlist notation):
|
| 380 |
+
model.path=VALUE Model path or HuggingFace ID
|
| 381 |
+
workspace.work_dir=VALUE Working directory for outputs
|
| 382 |
+
eval.datasets=[dataset1,dataset2] Datasets to evaluate (eth3d,7scenes,scannetpp,hiroom,dtu,dtu64)
|
| 383 |
+
eval.modes=[mode1,mode2] Evaluation modes (pose,recon_unposed,recon_posed)
|
| 384 |
+
eval.scenes=[scene1,scene2] Specific scenes to evaluate (null=all)
|
| 385 |
+
eval.max_frames=VALUE Max frames per scene (-1=no limit, default: 100)
|
| 386 |
+
eval.ref_view_strategy=VALUE Reference view strategy (default: first)
|
| 387 |
+
eval.eval_only=VALUE Only run evaluation (skip inference) (true/false)
|
| 388 |
+
eval.print_only=VALUE Only print saved metrics (true/false)
|
| 389 |
+
inference.num_fusion_workers=VALUE Number of parallel workers (default: 4)
|
| 390 |
+
inference.debug=VALUE Enable debug mode (true/false)
|
| 391 |
+
|
| 392 |
+
Special Flags:
|
| 393 |
+
--help, -h Show this help message
|
| 394 |
+
|
| 395 |
+
Multi-GPU:
|
| 396 |
+
Use CUDA_VISIBLE_DEVICES to specify GPUs (auto-detected and distributed)
|
| 397 |
+
```
|
| 398 |
+
|
| 399 |
+
### Examples
|
| 400 |
+
|
| 401 |
+
```bash
|
| 402 |
+
MODEL=depth-anything/DA3-GIANT
|
| 403 |
+
|
| 404 |
+
# Full evaluation
|
| 405 |
+
python -m depth_anything_3.bench.evaluator model.path=$MODEL
|
| 406 |
+
|
| 407 |
+
# Quick test on HiRoom only
|
| 408 |
+
python -m depth_anything_3.bench.evaluator \
|
| 409 |
+
model.path=$MODEL \
|
| 410 |
+
eval.datasets=[hiroom] \
|
| 411 |
+
eval.modes=[pose]
|
| 412 |
+
|
| 413 |
+
# Pose-only evaluation (all 5 pose datasets)
|
| 414 |
+
python -m depth_anything_3.bench.evaluator \
|
| 415 |
+
model.path=$MODEL \
|
| 416 |
+
eval.datasets=[eth3d,7scenes,scannetpp,hiroom,dtu64] \
|
| 417 |
+
eval.modes=[pose]
|
| 418 |
+
|
| 419 |
+
# Recon-only evaluation (all 5 recon datasets)
|
| 420 |
+
python -m depth_anything_3.bench.evaluator \
|
| 421 |
+
model.path=$MODEL \
|
| 422 |
+
eval.datasets=[eth3d,7scenes,scannetpp,hiroom,dtu] \
|
| 423 |
+
eval.modes=[recon_unposed,recon_posed]
|
| 424 |
+
|
| 425 |
+
# Debug specific scenes
|
| 426 |
+
python -m depth_anything_3.bench.evaluator \
|
| 427 |
+
model.path=$MODEL \
|
| 428 |
+
eval.datasets=[eth3d] \
|
| 429 |
+
eval.scenes=[courtyard] \
|
| 430 |
+
inference.debug=true
|
| 431 |
+
|
| 432 |
+
# Re-evaluate without re-running inference
|
| 433 |
+
python -m depth_anything_3.bench.evaluator eval.eval_only=true
|
| 434 |
+
|
| 435 |
+
# Just view results
|
| 436 |
+
python -m depth_anything_3.bench.evaluator eval.print_only=true
|
| 437 |
+
```
|
| 438 |
+
|
| 439 |
+
---
|
| 440 |
+
|
| 441 |
+
## 🔍 Troubleshooting
|
| 442 |
+
|
| 443 |
+
### Data Path Issues
|
| 444 |
+
|
| 445 |
+
Ensure dataset paths in `src/depth_anything_3/utils/constants.py` are correct:
|
| 446 |
+
|
| 447 |
+
```python
|
| 448 |
+
# Default paths (relative to project root)
|
| 449 |
+
ETH3D_EVAL_DATA_ROOT = "workspace/benchmark_dataset/eth3d"
|
| 450 |
+
SEVENSCENES_EVAL_DATA_ROOT = "workspace/benchmark_dataset/7scenes"
|
| 451 |
+
SCANNETPP_EVAL_DATA_ROOT = "workspace/benchmark_dataset/scannetpp"
|
| 452 |
+
HIROOM_EVAL_DATA_ROOT = "workspace/benchmark_dataset/hiroom/data"
|
| 453 |
+
DTU_EVAL_DATA_ROOT = "workspace/benchmark_dataset/dtu"
|
| 454 |
+
DTU64_EVAL_DATA_ROOT = "workspace/benchmark_dataset/dtu64"
|
| 455 |
+
```
|
| 456 |
+
|
| 457 |
+
---
|
| 458 |
+
|
| 459 |
+
## 📝 Citation
|
| 460 |
+
|
| 461 |
+
If you find this benchmark useful, please cite:
|
| 462 |
+
|
| 463 |
+
```
|
| 464 |
+
@article{depthanything3,
|
| 465 |
+
title={Depth Anything 3: Recovering the visual space from any views},
|
| 466 |
+
author={Haotong Lin and Sili Chen and Jun Hao Liew and Donny Y. Chen and Zhenyu Li and Guang Shi and Jiashi Feng and Bingyi Kang},
|
| 467 |
+
journal={arXiv preprint arXiv:2511.10647},
|
| 468 |
+
year={2025}
|
| 469 |
+
}
|
| 470 |
+
```
|
| 471 |
+
|
| 472 |
+
Please also cite the original dataset papers for each benchmark you use.
|
| 473 |
+
|
| 474 |
+
---
|
| 475 |
+
|
| 476 |
+
## 📄 License
|
| 477 |
+
|
| 478 |
+
The benchmark datasets are provided for research purposes only. Users must follow the original licenses of each dataset:
|
| 479 |
+
|
| 480 |
+
- **ETH3D:** [https://www.eth3d.net/](https://www.eth3d.net/)
|
| 481 |
+
- **7Scenes:** [Microsoft Research](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/)
|
| 482 |
+
- **ScanNet++:** [http://www.scan-net.org/](http://www.scan-net.org/)
|
| 483 |
+
- **DTU:** [https://roboimagedata.compute.dtu.dk/](https://roboimagedata.compute.dtu.dk/)
|
| 484 |
+
- **HiRoom:** [SVLightVerse](https://jerrypiglet.github.io/SVLightVerse/)
|
Depth-Anything-3/docs/CLI.md
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 Depth Anything 3 Command Line Interface
|
| 2 |
+
|
| 3 |
+
## 📋 Table of Contents
|
| 4 |
+
|
| 5 |
+
- [📖 Overview](#overview)
|
| 6 |
+
- [⚡ Quick Start](#quick-start)
|
| 7 |
+
- [📚 Command Reference](#command-reference)
|
| 8 |
+
- [🤖 auto - Auto Mode](#auto---auto-mode)
|
| 9 |
+
- [🖼️ image - Single Image Processing](#image---single-image-processing)
|
| 10 |
+
- [🗂️ images - Image Directory Processing](#images---image-directory-processing)
|
| 11 |
+
- [🎬 video - Video Processing](#video---video-processing)
|
| 12 |
+
- [📐 colmap - COLMAP Dataset Processing](#colmap---colmap-dataset-processing)
|
| 13 |
+
- [🔧 backend - Backend Service](#backend---backend-service)
|
| 14 |
+
- [🎨 gradio - Gradio Application](#gradio---gradio-application)
|
| 15 |
+
- [🖼️ gallery - Gallery Server](#gallery---gallery-server)
|
| 16 |
+
- [⚙️ Parameter Details](#parameter-details)
|
| 17 |
+
- [💡 Usage Examples](#usage-examples)
|
| 18 |
+
|
| 19 |
+
## 📖 Overview
|
| 20 |
+
|
| 21 |
+
The Depth Anything 3 CLI provides a comprehensive command-line toolkit supporting image depth estimation, video processing, COLMAP dataset handling, and web applications.
|
| 22 |
+
|
| 23 |
+
The backend service enables cache model to GPU so that we do not need to reload model for each command.
|
| 24 |
+
|
| 25 |
+
## ⚡ Quick Start
|
| 26 |
+
|
| 27 |
+
The CLI can run fully offline or connect to the backend for cached weights and task scheduling:
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
# 🔧 Start backend service (optional, keeps model resident in GPU memory)
|
| 31 |
+
da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE
|
| 32 |
+
|
| 33 |
+
# 🚀 Use auto mode to process input
|
| 34 |
+
da3 auto path/to/input --export-dir ./workspace/scene001
|
| 35 |
+
|
| 36 |
+
# ♻️ Reuse backend for next job
|
| 37 |
+
da3 auto path/to/video.mp4 \
|
| 38 |
+
--export-dir ./workspace/scene002 \
|
| 39 |
+
--use-backend \
|
| 40 |
+
--backend-url http://localhost:8008
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
Each export directory contains `scene.glb`, `scene.jpg`, and optional extras such as `depth_vis/` or `gs_video/` depending on the requested format.
|
| 44 |
+
|
| 45 |
+
## 📚 Command Reference
|
| 46 |
+
|
| 47 |
+
### 🤖 auto - Auto Mode
|
| 48 |
+
|
| 49 |
+
Automatically detect input type and dispatch to the appropriate handler.
|
| 50 |
+
|
| 51 |
+
**Usage:**
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
da3 auto INPUT_PATH [OPTIONS]
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
**Input Type Detection:**
|
| 58 |
+
- 🖼️ Single image file (.jpg, .png, .jpeg, .webp, .bmp, .tiff, .tif)
|
| 59 |
+
- 📁 Image directory
|
| 60 |
+
- 🎬 Video file (.mp4, .avi, .mov, .mkv, .flv, .wmv, .webm, .m4v)
|
| 61 |
+
- 📐 COLMAP directory (containing `images/` and `sparse/` subdirectories)
|
| 62 |
+
|
| 63 |
+
**Parameters:**
|
| 64 |
+
|
| 65 |
+
| Parameter | Type | Default | Description |
|
| 66 |
+
|-----------|------|---------|-------------|
|
| 67 |
+
| `INPUT_PATH` | str | Required | Input path (image, directory, video, or COLMAP) |
|
| 68 |
+
| `--model-dir` | str | Default model | Model directory path |
|
| 69 |
+
| `--export-dir` | str | `debug` | Export directory |
|
| 70 |
+
| `--export-format` | str | `glb` | Export format (supports `mini_npz`, `glb`, `feat_vis`, etc., can be combined with hyphens) |
|
| 71 |
+
| `--device` | str | `cuda` | Device to use |
|
| 72 |
+
| `--use-backend` | bool | `False` | Use backend service for inference |
|
| 73 |
+
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
|
| 74 |
+
| `--process-res` | int | `504` | Processing resolution |
|
| 75 |
+
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
|
| 76 |
+
| `--export-feat` | str | `""` | Export features from specified layers, comma-separated (e.g., `"0,1,2"`) |
|
| 77 |
+
| `--auto-cleanup` | bool | `False` | Automatically clean export directory without confirmation |
|
| 78 |
+
| `--fps` | float | `1.0` | [Video] Frame sampling FPS |
|
| 79 |
+
| `--sparse-subdir` | str | `""` | [COLMAP] Sparse reconstruction subdirectory (e.g., `"0"` for `sparse/0/`) |
|
| 80 |
+
| `--align-to-input-ext-scale` | bool | `True` | [COLMAP] Align prediction to input extrinsics scale |
|
| 81 |
+
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
|
| 82 |
+
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy: `first`, `middle`, `saddle_balanced`, `saddle_sim_range`. See [docs](funcs/ref_view_strategy.md) |
|
| 83 |
+
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Lower percentile for adaptive confidence threshold |
|
| 84 |
+
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points in the point cloud |
|
| 85 |
+
| `--show-cameras` | bool | `True` | [GLB] Show camera wireframes in the exported scene |
|
| 86 |
+
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Frame rate for output video |
|
| 87 |
+
|
| 88 |
+
**Examples:**
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
# 🖼️ Auto-process an image
|
| 92 |
+
da3 auto path/to/image.jpg --export-dir ./output
|
| 93 |
+
|
| 94 |
+
# 🎬 Auto-process a video
|
| 95 |
+
da3 auto path/to/video.mp4 --fps 2.0 --export-dir ./output
|
| 96 |
+
|
| 97 |
+
# 🔧 Use backend service
|
| 98 |
+
da3 auto path/to/input \
|
| 99 |
+
--export-format mini_npz-glb \
|
| 100 |
+
--use-backend \
|
| 101 |
+
--backend-url http://localhost:8008 \
|
| 102 |
+
--export-dir ./output
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
---
|
| 106 |
+
|
| 107 |
+
### 🖼️ image - Single Image Processing
|
| 108 |
+
|
| 109 |
+
Process a single image for camera pose and depth estimation.
|
| 110 |
+
|
| 111 |
+
**Usage:**
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
da3 image IMAGE_PATH [OPTIONS]
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
**Parameters:**
|
| 118 |
+
|
| 119 |
+
| Parameter | Type | Default | Description |
|
| 120 |
+
|-----------|------|---------|-------------|
|
| 121 |
+
| `IMAGE_PATH` | str | Required | Input image file path |
|
| 122 |
+
| `--model-dir` | str | Default model | Model directory path |
|
| 123 |
+
| `--export-dir` | str | `debug` | Export directory |
|
| 124 |
+
| `--export-format` | str | `glb` | Export format |
|
| 125 |
+
| `--device` | str | `cuda` | Device to use |
|
| 126 |
+
| `--use-backend` | bool | `False` | Use backend service for inference |
|
| 127 |
+
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
|
| 128 |
+
| `--process-res` | int | `504` | Processing resolution |
|
| 129 |
+
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
|
| 130 |
+
| `--export-feat` | str | `""` | Export feature layer indices (comma-separated) |
|
| 131 |
+
| `--auto-cleanup` | bool | `False` | Automatically clean export directory |
|
| 132 |
+
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
|
| 133 |
+
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
|
| 134 |
+
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
|
| 135 |
+
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
|
| 136 |
+
| `--show-cameras` | bool | `True` | [GLB] Show cameras |
|
| 137 |
+
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
|
| 138 |
+
|
| 139 |
+
**Examples:**
|
| 140 |
+
|
| 141 |
+
```bash
|
| 142 |
+
# ✨ Basic usage
|
| 143 |
+
da3 image path/to/image.png --export-dir ./output
|
| 144 |
+
|
| 145 |
+
# ⚡ With backend acceleration
|
| 146 |
+
da3 image path/to/image.png \
|
| 147 |
+
--use-backend \
|
| 148 |
+
--backend-url http://localhost:8008 \
|
| 149 |
+
--export-dir ./output
|
| 150 |
+
|
| 151 |
+
# 🔍 Export feature visualization
|
| 152 |
+
da3 image image.jpg \
|
| 153 |
+
--export-format feat_vis \
|
| 154 |
+
--export-feat "9,19,29,39" \
|
| 155 |
+
--export-dir ./results
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
---
|
| 159 |
+
|
| 160 |
+
### 🗂️ images - Image Directory Processing
|
| 161 |
+
|
| 162 |
+
Process a directory of images for batch depth estimation.
|
| 163 |
+
|
| 164 |
+
**Usage:**
|
| 165 |
+
|
| 166 |
+
```bash
|
| 167 |
+
da3 images IMAGES_DIR [OPTIONS]
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
**Parameters:**
|
| 171 |
+
|
| 172 |
+
| Parameter | Type | Default | Description |
|
| 173 |
+
|-----------|------|---------|-------------|
|
| 174 |
+
| `IMAGES_DIR` | str | Required | Directory path containing images |
|
| 175 |
+
| `--image-extensions` | str | `png,jpg,jpeg` | Image file extensions to process (comma-separated) |
|
| 176 |
+
| `--model-dir` | str | Default model | Model directory path |
|
| 177 |
+
| `--export-dir` | str | `debug` | Export directory |
|
| 178 |
+
| `--export-format` | str | `glb` | Export format |
|
| 179 |
+
| `--device` | str | `cuda` | Device to use |
|
| 180 |
+
| `--use-backend` | bool | `False` | Use backend service for inference |
|
| 181 |
+
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
|
| 182 |
+
| `--process-res` | int | `504` | Processing resolution |
|
| 183 |
+
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
|
| 184 |
+
| `--export-feat` | str | `""` | Export feature layer indices |
|
| 185 |
+
| `--auto-cleanup` | bool | `False` | Automatically clean export directory |
|
| 186 |
+
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
|
| 187 |
+
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
|
| 188 |
+
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
|
| 189 |
+
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
|
| 190 |
+
| `--show-cameras` | bool | `True` | [GLB] Show cameras |
|
| 191 |
+
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
|
| 192 |
+
|
| 193 |
+
**Examples:**
|
| 194 |
+
|
| 195 |
+
```bash
|
| 196 |
+
# 📁 Process directory (defaults to png/jpg/jpeg)
|
| 197 |
+
da3 images ./image_folder --export-dir ./output
|
| 198 |
+
|
| 199 |
+
# 🎯 Custom extensions
|
| 200 |
+
da3 images ./dataset --image-extensions "png,jpg,webp" --export-dir ./output
|
| 201 |
+
|
| 202 |
+
# 🔧 Use backend service
|
| 203 |
+
da3 images ./dataset \
|
| 204 |
+
--use-backend \
|
| 205 |
+
--backend-url http://localhost:8008 \
|
| 206 |
+
--export-dir ./output
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
### 🎬 video - Video Processing
|
| 212 |
+
|
| 213 |
+
Process video by extracting frames for depth estimation.
|
| 214 |
+
|
| 215 |
+
**Usage:**
|
| 216 |
+
|
| 217 |
+
```bash
|
| 218 |
+
da3 video VIDEO_PATH [OPTIONS]
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
**Parameters:**
|
| 222 |
+
|
| 223 |
+
| Parameter | Type | Default | Description |
|
| 224 |
+
|-----------|------|---------|-------------|
|
| 225 |
+
| `VIDEO_PATH` | str | Required | Input video file path |
|
| 226 |
+
| `--fps` | float | `1.0` | Frame extraction sampling FPS |
|
| 227 |
+
| `--model-dir` | str | Default model | Model directory path |
|
| 228 |
+
| `--export-dir` | str | `debug` | Export directory |
|
| 229 |
+
| `--export-format` | str | `glb` | Export format |
|
| 230 |
+
| `--device` | str | `cuda` | Device to use |
|
| 231 |
+
| `--use-backend` | bool | `False` | Use backend service for inference |
|
| 232 |
+
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
|
| 233 |
+
| `--process-res` | int | `504` | Processing resolution |
|
| 234 |
+
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
|
| 235 |
+
| `--export-feat` | str | `""` | Export feature layer indices |
|
| 236 |
+
| `--auto-cleanup` | bool | `False` | Automatically clean export directory |
|
| 237 |
+
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
|
| 238 |
+
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
|
| 239 |
+
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
|
| 240 |
+
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
|
| 241 |
+
| `--show-cameras` | bool | `True` | [GLB] Show cameras |
|
| 242 |
+
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
|
| 243 |
+
|
| 244 |
+
**Examples:**
|
| 245 |
+
|
| 246 |
+
```bash
|
| 247 |
+
# ��� Basic video processing
|
| 248 |
+
da3 video path/to/video.mp4 --export-dir ./output
|
| 249 |
+
|
| 250 |
+
# ⚙️ Control frame sampling and resolution
|
| 251 |
+
da3 video path/to/video.mp4 \
|
| 252 |
+
--fps 2.0 \
|
| 253 |
+
--process-res 1024 \
|
| 254 |
+
--export-dir ./output
|
| 255 |
+
|
| 256 |
+
# 🔧 Use backend service
|
| 257 |
+
da3 video path/to/video.mp4 \
|
| 258 |
+
--use-backend \
|
| 259 |
+
--backend-url http://localhost:8008 \
|
| 260 |
+
--export-dir ./output
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
---
|
| 264 |
+
|
| 265 |
+
### 📐 colmap - COLMAP Dataset Processing
|
| 266 |
+
|
| 267 |
+
Run pose-conditioned depth estimation on COLMAP data.
|
| 268 |
+
|
| 269 |
+
**Usage:**
|
| 270 |
+
|
| 271 |
+
```bash
|
| 272 |
+
da3 colmap COLMAP_DIR [OPTIONS]
|
| 273 |
+
```
|
| 274 |
+
|
| 275 |
+
**Parameters:**
|
| 276 |
+
|
| 277 |
+
| Parameter | Type | Default | Description |
|
| 278 |
+
|-----------|------|---------|-------------|
|
| 279 |
+
| `COLMAP_DIR` | str | Required | COLMAP directory containing `images/` and `sparse/` subdirectories |
|
| 280 |
+
| `--sparse-subdir` | str | `""` | Sparse reconstruction subdirectory (e.g., `"0"` for `sparse/0/`) |
|
| 281 |
+
| `--align-to-input-ext-scale` | bool | `True` | Align prediction to input extrinsics scale |
|
| 282 |
+
| `--model-dir` | str | Default model | Model directory path |
|
| 283 |
+
| `--export-dir` | str | `debug` | Export directory |
|
| 284 |
+
| `--export-format` | str | `glb` | Export format |
|
| 285 |
+
| `--device` | str | `cuda` | Device to use |
|
| 286 |
+
| `--use-backend` | bool | `False` | Use backend service for inference |
|
| 287 |
+
| `--backend-url` | str | `http://localhost:8008` | Backend service URL |
|
| 288 |
+
| `--process-res` | int | `504` | Processing resolution |
|
| 289 |
+
| `--process-res-method` | str | `upper_bound_resize` | Processing resolution method |
|
| 290 |
+
| `--export-feat` | str | `""` | Export feature layer indices |
|
| 291 |
+
| `--auto-cleanup` | bool | `False` | Automatically clean export directory |
|
| 292 |
+
| `--use-ray-pose` | bool | `False` | Use ray-based pose estimation instead of camera decoder |
|
| 293 |
+
| `--ref-view-strategy` | str | `saddle_balanced` | Reference view selection strategy. See [docs](funcs/ref_view_strategy.md) |
|
| 294 |
+
| `--conf-thresh-percentile` | float | `40.0` | [GLB] Confidence threshold percentile |
|
| 295 |
+
| `--num-max-points` | int | `1000000` | [GLB] Maximum number of points |
|
| 296 |
+
| `--show-cameras` | bool | `True` | [GLB] Show cameras |
|
| 297 |
+
| `--feat-vis-fps` | int | `15` | [FEAT_VIS] Video frame rate |
|
| 298 |
+
|
| 299 |
+
**Examples:**
|
| 300 |
+
|
| 301 |
+
```bash
|
| 302 |
+
# 📐 Process COLMAP dataset
|
| 303 |
+
da3 colmap ./colmap_dataset --export-dir ./output
|
| 304 |
+
|
| 305 |
+
# 🎯 Use specific sparse subdirectory and align scale
|
| 306 |
+
da3 colmap ./colmap_dataset \
|
| 307 |
+
--sparse-subdir 0 \
|
| 308 |
+
--align-to-input-ext-scale \
|
| 309 |
+
--export-dir ./output
|
| 310 |
+
|
| 311 |
+
# 🔧 Use backend service
|
| 312 |
+
da3 colmap ./colmap_dataset \
|
| 313 |
+
--use-backend \
|
| 314 |
+
--backend-url http://localhost:8008 \
|
| 315 |
+
--export-dir ./output
|
| 316 |
+
```
|
| 317 |
+
|
| 318 |
+
---
|
| 319 |
+
|
| 320 |
+
### 🔧 backend - Backend Service
|
| 321 |
+
|
| 322 |
+
Start model backend service with integrated gallery.
|
| 323 |
+
|
| 324 |
+
**Usage:**
|
| 325 |
+
|
| 326 |
+
```bash
|
| 327 |
+
da3 backend [OPTIONS]
|
| 328 |
+
```
|
| 329 |
+
|
| 330 |
+
**Parameters:**
|
| 331 |
+
|
| 332 |
+
| Parameter | Type | Default | Description |
|
| 333 |
+
|-----------|------|---------|-------------|
|
| 334 |
+
| `--model-dir` | str | Default model | Model directory path |
|
| 335 |
+
| `--device` | str | `cuda` | Device to use |
|
| 336 |
+
| `--host` | str | `127.0.0.1` | Host address to bind to |
|
| 337 |
+
| `--port` | int | `8008` | Port number to bind to |
|
| 338 |
+
| `--gallery-dir` | str | Default gallery dir | Gallery directory path (optional) |
|
| 339 |
+
|
| 340 |
+
**Features:**
|
| 341 |
+
- 🎯 Keeps model resident in GPU memory
|
| 342 |
+
- 🔌 Provides REST inference API
|
| 343 |
+
- 📊 Integrated dashboard and status monitoring
|
| 344 |
+
- 🖼️ Optional gallery browser (if `--gallery-dir` is provided)
|
| 345 |
+
|
| 346 |
+
**Available Endpoints:**
|
| 347 |
+
- 🏠 `/` - Home page
|
| 348 |
+
- 📊 `/dashboard` - Dashboard
|
| 349 |
+
- ✅ `/status` - API status
|
| 350 |
+
- 🖼️ `/gallery/` - Gallery browser (if enabled)
|
| 351 |
+
|
| 352 |
+
**Examples:**
|
| 353 |
+
|
| 354 |
+
```bash
|
| 355 |
+
# 🚀 Basic backend service
|
| 356 |
+
da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE
|
| 357 |
+
|
| 358 |
+
# 🖼️ Backend with gallery
|
| 359 |
+
da3 backend \
|
| 360 |
+
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
|
| 361 |
+
--device cuda \
|
| 362 |
+
--host 0.0.0.0 \
|
| 363 |
+
--port 8008 \
|
| 364 |
+
--gallery-dir ./workspace
|
| 365 |
+
|
| 366 |
+
# 💻 Use CPU
|
| 367 |
+
da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE --device cpu
|
| 368 |
+
```
|
| 369 |
+
|
| 370 |
+
---
|
| 371 |
+
|
| 372 |
+
### 🎨 gradio - Gradio Application
|
| 373 |
+
|
| 374 |
+
Launch Depth Anything 3 Gradio interactive web application.
|
| 375 |
+
|
| 376 |
+
**Usage:**
|
| 377 |
+
|
| 378 |
+
```bash
|
| 379 |
+
da3 gradio [OPTIONS]
|
| 380 |
+
```
|
| 381 |
+
|
| 382 |
+
**Parameters:**
|
| 383 |
+
|
| 384 |
+
| Parameter | Type | Default | Description |
|
| 385 |
+
|-----------|------|---------|-------------|
|
| 386 |
+
| `--model-dir` | str | Required | Model directory path |
|
| 387 |
+
| `--workspace-dir` | str | Required | Workspace directory path |
|
| 388 |
+
| `--gallery-dir` | str | Required | Gallery directory path |
|
| 389 |
+
| `--host` | str | `127.0.0.1` | Host address to bind to |
|
| 390 |
+
| `--port` | int | `7860` | Port number to bind to |
|
| 391 |
+
| `--share` | bool | `False` | Create a public link |
|
| 392 |
+
| `--debug` | bool | `False` | Enable debug mode |
|
| 393 |
+
| `--cache-examples` | bool | `False` | Pre-cache all example scenes at startup |
|
| 394 |
+
| `--cache-gs-tag` | str | `""` | Tag to match scene names for high-res+3DGS caching |
|
| 395 |
+
|
| 396 |
+
**Examples:**
|
| 397 |
+
|
| 398 |
+
```bash
|
| 399 |
+
# 🎨 Basic Gradio application
|
| 400 |
+
da3 gradio \
|
| 401 |
+
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
|
| 402 |
+
--workspace-dir ./workspace \
|
| 403 |
+
--gallery-dir ./gallery
|
| 404 |
+
|
| 405 |
+
# 🌐 Enable sharing and debug
|
| 406 |
+
da3 gradio \
|
| 407 |
+
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
|
| 408 |
+
--workspace-dir ./workspace \
|
| 409 |
+
--gallery-dir ./gallery \
|
| 410 |
+
--share \
|
| 411 |
+
--debug
|
| 412 |
+
|
| 413 |
+
# ⚡ Pre-cache examples
|
| 414 |
+
da3 gradio \
|
| 415 |
+
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
|
| 416 |
+
--workspace-dir ./workspace \
|
| 417 |
+
--gallery-dir ./gallery \
|
| 418 |
+
--cache-examples \
|
| 419 |
+
--cache-gs-tag "dl3dv"
|
| 420 |
+
```
|
| 421 |
+
|
| 422 |
+
---
|
| 423 |
+
|
| 424 |
+
### 🖼️ gallery - Gallery Server
|
| 425 |
+
|
| 426 |
+
Launch standalone Depth Anything 3 Gallery server.
|
| 427 |
+
|
| 428 |
+
**Usage:**
|
| 429 |
+
|
| 430 |
+
```bash
|
| 431 |
+
da3 gallery [OPTIONS]
|
| 432 |
+
```
|
| 433 |
+
|
| 434 |
+
**Parameters:**
|
| 435 |
+
|
| 436 |
+
| Parameter | Type | Default | Description |
|
| 437 |
+
|-----------|------|---------|-------------|
|
| 438 |
+
| `--gallery-dir` | str | Default gallery dir | Gallery root directory |
|
| 439 |
+
| `--host` | str | `127.0.0.1` | Host address to bind to |
|
| 440 |
+
| `--port` | int | `8007` | Port number to bind to |
|
| 441 |
+
| `--open-browser` | bool | `False` | Open browser after launch |
|
| 442 |
+
|
| 443 |
+
**Note:**
|
| 444 |
+
The gallery expects each scene folder to contain at least `scene.glb` and `scene.jpg`, with optional subfolders such as `depth_vis/` or `gs_video/`.
|
| 445 |
+
|
| 446 |
+
**Examples:**
|
| 447 |
+
|
| 448 |
+
```bash
|
| 449 |
+
# 🖼️ Basic gallery server
|
| 450 |
+
da3 gallery --gallery-dir ./workspace
|
| 451 |
+
|
| 452 |
+
# 🌐 Custom host and port
|
| 453 |
+
da3 gallery \
|
| 454 |
+
--gallery-dir ./workspace \
|
| 455 |
+
--host 0.0.0.0 \
|
| 456 |
+
--port 8007
|
| 457 |
+
|
| 458 |
+
# 🚀 Auto-open browser
|
| 459 |
+
da3 gallery --gallery-dir ./workspace --open-browser
|
| 460 |
+
```
|
| 461 |
+
|
| 462 |
+
---
|
| 463 |
+
|
| 464 |
+
## ⚙️ Parameter Details
|
| 465 |
+
|
| 466 |
+
### 🔧 Common Parameters
|
| 467 |
+
|
| 468 |
+
- **`--export-dir`**: Output directory, defaults to `debug`
|
| 469 |
+
- **`--export-format`**: Export format, supports combining multiple formats with hyphens:
|
| 470 |
+
- 📦 `mini_npz`: Compressed NumPy format
|
| 471 |
+
- 🎨 `glb`: glTF binary format (3D scene)
|
| 472 |
+
- 🔍 `feat_vis`: Feature visualization
|
| 473 |
+
- Example: `mini_npz-glb` exports both formats
|
| 474 |
+
|
| 475 |
+
- **`--process-res`** / **`--process-res-method`**: Control preprocessing resolution strategy
|
| 476 |
+
- `process-res`: Target resolution (default 504)
|
| 477 |
+
- `process-res-method`: Resize method (default `upper_bound_resize`)
|
| 478 |
+
|
| 479 |
+
- **`--auto-cleanup`**: Remove existing export directory without confirmation
|
| 480 |
+
|
| 481 |
+
- **`--use-backend`** / **`--backend-url`**: Reuse running backend service
|
| 482 |
+
- ⚡ Reduces model loading time
|
| 483 |
+
- 🌐 Supports distributed processing
|
| 484 |
+
|
| 485 |
+
- **`--export-feat`**: Layer indices for exporting intermediate features (comma-separated)
|
| 486 |
+
- Example: `"9,19,29,39"`
|
| 487 |
+
|
| 488 |
+
### 🎨 GLB Export Parameters
|
| 489 |
+
|
| 490 |
+
- **`--conf-thresh-percentile`**: Lower percentile for adaptive confidence threshold (default 40.0)
|
| 491 |
+
- Used to filter low-confidence points
|
| 492 |
+
|
| 493 |
+
- **`--num-max-points`**: Maximum number of points in point cloud (default 1,000,000)
|
| 494 |
+
- Controls output file size and performance
|
| 495 |
+
|
| 496 |
+
- **`--show-cameras`**: Show camera wireframes in exported scene (default True)
|
| 497 |
+
|
| 498 |
+
### 🔍 Feature Visualization Parameters
|
| 499 |
+
|
| 500 |
+
- **`--feat-vis-fps`**: Frame rate for feature visualization output video (default 15)
|
| 501 |
+
|
| 502 |
+
### 🎬 Video-Specific Parameters
|
| 503 |
+
|
| 504 |
+
- **`--fps`**: Video frame extraction sampling rate (default 1.0 FPS)
|
| 505 |
+
- Higher values extract more frames
|
| 506 |
+
|
| 507 |
+
### 📐 COLMAP-Specific Parameters
|
| 508 |
+
|
| 509 |
+
- **`--sparse-subdir`**: Sparse reconstruction subdirectory
|
| 510 |
+
- Empty string uses `sparse/` directory
|
| 511 |
+
- `"0"` uses `sparse/0/` directory
|
| 512 |
+
|
| 513 |
+
- **`--align-to-input-ext-scale`**: Align prediction to input extrinsics scale (default True)
|
| 514 |
+
- Ensures depth estimation is consistent with COLMAP scale
|
| 515 |
+
|
| 516 |
+
---
|
| 517 |
+
|
| 518 |
+
## 💡 Usage Examples
|
| 519 |
+
|
| 520 |
+
### 1️⃣ Basic Workflow
|
| 521 |
+
|
| 522 |
+
```bash
|
| 523 |
+
# 🔧 Start backend service
|
| 524 |
+
da3 backend --model-dir depth-anything/DA3NESTED-GIANT-LARGE --host 0.0.0.0 --port 8008
|
| 525 |
+
|
| 526 |
+
# 🖼️ Process single image
|
| 527 |
+
da3 image image.jpg --export-dir ./output1 --use-backend
|
| 528 |
+
|
| 529 |
+
# 🎬 Process video
|
| 530 |
+
da3 video video.mp4 --fps 2.0 --export-dir ./output2 --use-backend
|
| 531 |
+
|
| 532 |
+
# 📐 Process COLMAP dataset
|
| 533 |
+
da3 colmap ./colmap_data --export-dir ./output3 --use-backend
|
| 534 |
+
```
|
| 535 |
+
|
| 536 |
+
### 2️⃣ Using Auto Mode
|
| 537 |
+
|
| 538 |
+
```bash
|
| 539 |
+
# 🤖 Auto-detect and process
|
| 540 |
+
da3 auto ./unknown_input --export-dir ./output
|
| 541 |
+
|
| 542 |
+
# ⚡ With backend acceleration
|
| 543 |
+
da3 auto ./unknown_input \
|
| 544 |
+
--use-backend \
|
| 545 |
+
--backend-url http://localhost:8008 \
|
| 546 |
+
--export-dir ./output
|
| 547 |
+
```
|
| 548 |
+
|
| 549 |
+
### 3️⃣ Multi-Format Export
|
| 550 |
+
|
| 551 |
+
```bash
|
| 552 |
+
# 📦 Export both NPZ and GLB formats
|
| 553 |
+
da3 auto assets/examples/SOH \
|
| 554 |
+
--export-format mini_npz-glb \
|
| 555 |
+
--export-dir ./workspace/soh
|
| 556 |
+
|
| 557 |
+
# 🔍 Export feature visualization
|
| 558 |
+
da3 image image.jpg \
|
| 559 |
+
--export-format feat_vis \
|
| 560 |
+
--export-feat "9,19,29,39" \
|
| 561 |
+
--export-dir ./results
|
| 562 |
+
```
|
| 563 |
+
|
| 564 |
+
### 4️⃣ Advanced Configuration
|
| 565 |
+
|
| 566 |
+
```bash
|
| 567 |
+
# ⚙️ Custom resolution and point cloud density
|
| 568 |
+
da3 image image.jpg \
|
| 569 |
+
--process-res 1024 \
|
| 570 |
+
--num-max-points 2000000 \
|
| 571 |
+
--conf-thresh-percentile 30.0 \
|
| 572 |
+
--export-dir ./output
|
| 573 |
+
|
| 574 |
+
# 📐 COLMAP advanced options
|
| 575 |
+
da3 colmap ./colmap_data \
|
| 576 |
+
--sparse-subdir 0 \
|
| 577 |
+
--align-to-input-ext-scale \
|
| 578 |
+
--process-res 756 \
|
| 579 |
+
--export-dir ./output
|
| 580 |
+
```
|
| 581 |
+
|
| 582 |
+
### 5️⃣ Batch Processing Workflow
|
| 583 |
+
|
| 584 |
+
```bash
|
| 585 |
+
# 🔧 Start backend
|
| 586 |
+
da3 backend \
|
| 587 |
+
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
|
| 588 |
+
--device cuda \
|
| 589 |
+
--host 0.0.0.0 \
|
| 590 |
+
--port 8008 \
|
| 591 |
+
--gallery-dir ./workspace
|
| 592 |
+
|
| 593 |
+
# 🔄 Batch process multiple scenes
|
| 594 |
+
for scene in scene1 scene2 scene3; do
|
| 595 |
+
da3 auto ./data/$scene \
|
| 596 |
+
--export-dir ./workspace/$scene \
|
| 597 |
+
--use-backend \
|
| 598 |
+
--auto-cleanup
|
| 599 |
+
done
|
| 600 |
+
|
| 601 |
+
# 🖼️ Launch gallery to view results
|
| 602 |
+
da3 gallery --gallery-dir ./workspace --open-browser
|
| 603 |
+
```
|
| 604 |
+
|
| 605 |
+
### 6️⃣ Web Applications
|
| 606 |
+
|
| 607 |
+
```bash
|
| 608 |
+
# 🎨 Launch Gradio application
|
| 609 |
+
da3 gradio \
|
| 610 |
+
--model-dir depth-anything/DA3NESTED-GIANT-LARGE \
|
| 611 |
+
--workspace-dir workspace/gradio \
|
| 612 |
+
--gallery-dir ./gallery \
|
| 613 |
+
--host 0.0.0.0 \
|
| 614 |
+
--port 7860 \
|
| 615 |
+
--share
|
| 616 |
+
```
|
| 617 |
+
|
| 618 |
+
### 7️⃣ Transformer Feature Visualization
|
| 619 |
+
|
| 620 |
+
```bash
|
| 621 |
+
# 🔍 Export Transformer features
|
| 622 |
+
# 📦 Combined with numerical output
|
| 623 |
+
da3 auto video.mp4 \
|
| 624 |
+
--export-format glb-feat_vis \
|
| 625 |
+
--export-feat "11,21,31" \
|
| 626 |
+
--export-dir ./debug \
|
| 627 |
+
--use-backend
|
| 628 |
+
```
|
| 629 |
+
|
| 630 |
+
---
|
| 631 |
+
|
| 632 |
+
## 📝 Notes
|
| 633 |
+
|
| 634 |
+
1. **🔧 Backend Service**: Recommended for processing multiple tasks to improve efficiency
|
| 635 |
+
2. **💾 GPU Memory**: Be mindful of GPU memory usage when processing high-resolution inputs
|
| 636 |
+
3. **📁 Export Directory**: Use `--auto-cleanup` to avoid manual confirmation for deletion
|
| 637 |
+
4. **🔀 Format Combination**: Multiple export formats can be combined with hyphens (e.g., `mini_npz-glb-feat_vis`)
|
| 638 |
+
5. **📐 COLMAP Data**: Ensure COLMAP directory structure is correct (contains `images/` and `sparse/` subdirectories)
|
| 639 |
+
|
| 640 |
+
---
|
| 641 |
+
|
| 642 |
+
## ❓ Getting Help
|
| 643 |
+
|
| 644 |
+
View detailed help for any command:
|
| 645 |
+
|
| 646 |
+
```bash
|
| 647 |
+
# 📖 View main help
|
| 648 |
+
da3 --help
|
| 649 |
+
|
| 650 |
+
# 🔍 View specific command help
|
| 651 |
+
da3 auto --help
|
| 652 |
+
da3 image --help
|
| 653 |
+
da3 backend --help
|
| 654 |
+
```
|
Depth-Anything-3/docs/funcs/ref_view_strategy.md
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 📐 Reference View Selection Strategy
|
| 2 |
+
|
| 3 |
+
## 📖 Overview
|
| 4 |
+
|
| 5 |
+
Reference view selection is a component in multi-view depth estimation. When processing multiple input views, the model needs to determine which view should serve as the primary reference frame for depth prediction, defining the world coordinate system.
|
| 6 |
+
|
| 7 |
+
Different reference view will leads to different reconstruction results. This is a known consideration in multi-view geometry and was analyzed in [PI3](https://arxiv.org/abs/2507.13347). The choice of reference view can affect the quality and consistency of depth predictions across the scene.
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
## 🚀 Our Simple Solution: Automatic Reference View Selection
|
| 11 |
+
|
| 12 |
+
DA3 provides a simple approach to address this through **automatic reference view selection** based on **class tokens**. Instead of relying on heuristics or manual selection, the model analyzes the class token features from all input views and intelligently selects the most suitable reference frame.
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
## 🎨 Available Strategies
|
| 17 |
+
|
| 18 |
+
### 1. ⚖️ `saddle_balanced` (Recommended, Default)
|
| 19 |
+
|
| 20 |
+
**Philosophy:**
|
| 21 |
+
Select a view that achieves balance across multiple feature metrics. This strategy looks for a "middle ground" view that is neither too similar nor too different from other views, making it a stable reference point.
|
| 22 |
+
|
| 23 |
+
**How it works:**
|
| 24 |
+
1. Extracts and normalizes class tokens from all views
|
| 25 |
+
2. Computes three complementary metrics for each view:
|
| 26 |
+
- **Similarity score**: Average cosine similarity with other views
|
| 27 |
+
- **Feature norm**: L2 norm of the original features
|
| 28 |
+
- **Feature variance**: Variance across feature dimensions
|
| 29 |
+
3. Normalizes each metric to [0, 1] range
|
| 30 |
+
4. Selects the view closest to 0.5 (median) across all three metrics
|
| 31 |
+
|
| 32 |
+
### 2. 🎢 `saddle_sim_range`
|
| 33 |
+
|
| 34 |
+
**Philosophy:**
|
| 35 |
+
Select a view with the largest similarity range to other views. This identifies "saddle point" views that are highly similar to some views but dissimilar to others, making them information-rich anchor points.
|
| 36 |
+
|
| 37 |
+
**How it works:**
|
| 38 |
+
1. Computes pairwise cosine similarity between all views
|
| 39 |
+
2. For each view, calculates the range (max - min) of similarities to other views
|
| 40 |
+
3. Selects the view with the maximum similarity range
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
### 3. 1️⃣ `first` (Not Recommended)
|
| 45 |
+
|
| 46 |
+
**Philosophy:**
|
| 47 |
+
Always use the first view in the input sequence as the reference.
|
| 48 |
+
|
| 49 |
+
**How it works:**
|
| 50 |
+
Simply returns index 0.
|
| 51 |
+
|
| 52 |
+
**When to use:**
|
| 53 |
+
- ⛔ **Not recommended** in general
|
| 54 |
+
- 🔧 Only use when you have manually pre-sorted your views and know the first view is optimal
|
| 55 |
+
- 🐛 Debugging or baseline comparisons
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
### 4. ⏸️ `middle`
|
| 60 |
+
|
| 61 |
+
**Philosophy:**
|
| 62 |
+
Select the view in the middle of the input sequence.
|
| 63 |
+
|
| 64 |
+
**How it works:**
|
| 65 |
+
Returns the view at index `S // 2` where S is the number of views.
|
| 66 |
+
|
| 67 |
+
**When to use:**
|
| 68 |
+
- ⏱️ **Only recommended when input images are temporally ordered**
|
| 69 |
+
- 🎬 Video sequences (e.g., **DA3-LONG** setting)
|
| 70 |
+
- 📹 Sequential captures where the middle frame likely has the most stable viewpoint
|
| 71 |
+
|
| 72 |
+
**Specific use case: DA3-LONG** 🎬
|
| 73 |
+
In video-based depth estimation scenarios (like DA3-LONG), where inputs are consecutive frames, `middle` is often the **optimal choice** because that it has maximum overlap with all other frames.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## 💻 Usage
|
| 77 |
+
|
| 78 |
+
### 🐍 Python API
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
from depth_anything_3 import DepthAnything3
|
| 82 |
+
|
| 83 |
+
model = DepthAnything3.from_pretrained("depth-anything/DA3NESTED-GIANT-LARGE")
|
| 84 |
+
|
| 85 |
+
# Use default (saddle_balanced)
|
| 86 |
+
prediction = model.inference(
|
| 87 |
+
images,
|
| 88 |
+
ref_view_strategy="saddle_balanced"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# For video sequences, consider using middle
|
| 92 |
+
prediction = model.inference(
|
| 93 |
+
video_frames,
|
| 94 |
+
ref_view_strategy="middle" # Good for temporal sequences
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# For complex scenes with wide baselines
|
| 98 |
+
prediction = model.inference(
|
| 99 |
+
images,
|
| 100 |
+
ref_view_strategy="saddle_sim_range"
|
| 101 |
+
)
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
### 🖥️ Command Line Interface
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
# Default (saddle_balanced)
|
| 108 |
+
da3 auto input/ --export-dir output/
|
| 109 |
+
|
| 110 |
+
# Explicitly specify strategy
|
| 111 |
+
da3 auto input/ --ref-view-strategy saddle_balanced
|
| 112 |
+
|
| 113 |
+
# For video processing
|
| 114 |
+
da3 video input.mp4 --ref-view-strategy middle
|
| 115 |
+
|
| 116 |
+
# For wide-baseline multi-view
|
| 117 |
+
da3 images captures/ --ref-view-strategy saddle_sim_range
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
### 🎯 When Selection Is Applied
|
| 123 |
+
|
| 124 |
+
Reference view selection is applied when:
|
| 125 |
+
- 3️⃣ Number of views S ≥ 3
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
|
| 129 |
+
## 💡 Recommendations
|
| 130 |
+
|
| 131 |
+
### 📋 Quick Guide
|
| 132 |
+
|
| 133 |
+
| Scenario | Recommended Strategy | Rationale |
|
| 134 |
+
|----------|---------------------|-----------|
|
| 135 |
+
| **Default / Unknown** | `saddle_balanced` | Robust, balanced, works well across diverse scenarios |
|
| 136 |
+
| **Video frames** | `middle` | Temporal coherence, stable middle frame |
|
| 137 |
+
| **Wide-baseline multi-view** | `saddle_sim_range` | Maximizes information coverage |
|
| 138 |
+
| **Pre-sorted inputs** | `first` | Use only if you've manually optimized ordering |
|
| 139 |
+
| **Single image** | `first` | Automatically used (no reordering needed for S ≤ 2) |
|
| 140 |
+
|
| 141 |
+
### ✨ Best Practices
|
| 142 |
+
|
| 143 |
+
1. 🎯 **Start with defaults**: `saddle_balanced` works well in most cases
|
| 144 |
+
2. 🎬 **Consider your input type**: Use `middle` for videos, `saddle_balanced` for photos
|
| 145 |
+
3. 🔬 **Experiment if needed**: Try different strategies if results are suboptimal
|
| 146 |
+
4. 📊 **Monitor performance**: Check `glb` quality and consistency across views.
|
| 147 |
+
|
| 148 |
+
---
|
| 149 |
+
|
| 150 |
+
## 🔧 Technical Details
|
| 151 |
+
|
| 152 |
+
### 🎚️ Selection Threshold
|
| 153 |
+
|
| 154 |
+
The reference view selection is only triggered when:
|
| 155 |
+
```python
|
| 156 |
+
num_views >= 3 # At least 3 views required
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
For 1-2 views, no reordering is performed (equivalent to using `first`).
|
| 160 |
+
|
| 161 |
+
### ⚙️ Implementation
|
| 162 |
+
|
| 163 |
+
The selection happens at layer `alt_start - 1` in the vision transformer, before the first global attention layer. This ensures the selected reference view influences the entire depth prediction pipeline.
|
| 164 |
+
|
| 165 |
+
---
|
| 166 |
+
|
| 167 |
+
## ❓ FAQ
|
| 168 |
+
|
| 169 |
+
**Q: 🤔 Why is this feature provided?**
|
| 170 |
+
A: The model can handle any view order, but this feature provides automatic optimization for reference view selection, which can help improve depth prediction quality in multi-view scenarios.
|
| 171 |
+
|
| 172 |
+
**Q: ⏱️ Does this add computational cost?**
|
| 173 |
+
A: The overhead is totally negligible.
|
| 174 |
+
|
| 175 |
+
**Q: 🎮 Can I manually specify which view to use as reference?**
|
| 176 |
+
A: Not directly through this parameter. You can pre-sort your input images to place your preferred reference view first and use `ref_view_strategy="first"`.
|
| 177 |
+
|
| 178 |
+
**Q: ⚙️ What happens if I don't specify this parameter?**
|
| 179 |
+
A: The default `saddle_balanced` strategy is used automatically.
|
| 180 |
+
|
| 181 |
+
**Q: 📊 Is this feature used in the DA3 paper benchmarks?**
|
| 182 |
+
A: No, the paper used `first` as the default strategy for all multi-view experiments. The current default has been updated to `saddle_balanced` for better robustness.
|
| 183 |
+
|
Depth-Anything-3/notebooks/da3.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Depth-Anything-3/src/depth_anything_3/api.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Depth Anything 3 API module.
|
| 16 |
+
|
| 17 |
+
This module provides the main API for Depth Anything 3, including model loading,
|
| 18 |
+
inference, and export capabilities. It supports both single and nested model architectures.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import time
|
| 24 |
+
from typing import Optional, Sequence
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 29 |
+
from PIL import Image
|
| 30 |
+
|
| 31 |
+
from depth_anything_3.cfg import create_object, load_config
|
| 32 |
+
from depth_anything_3.registry import MODEL_REGISTRY
|
| 33 |
+
from depth_anything_3.specs import Prediction
|
| 34 |
+
from depth_anything_3.utils.export import export
|
| 35 |
+
from depth_anything_3.utils.geometry import affine_inverse
|
| 36 |
+
from depth_anything_3.utils.io.input_processor import InputProcessor
|
| 37 |
+
from depth_anything_3.utils.io.output_processor import OutputProcessor
|
| 38 |
+
from depth_anything_3.utils.logger import logger
|
| 39 |
+
from depth_anything_3.utils.pose_align import align_poses_umeyama
|
| 40 |
+
|
| 41 |
+
torch.backends.cudnn.benchmark = False
|
| 42 |
+
# logger.info("CUDNN Benchmark Disabled")
|
| 43 |
+
|
| 44 |
+
SAFETENSORS_NAME = "model.safetensors"
|
| 45 |
+
CONFIG_NAME = "config.json"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DepthAnything3(nn.Module, PyTorchModelHubMixin):
|
| 49 |
+
"""
|
| 50 |
+
Depth Anything 3 main API class.
|
| 51 |
+
|
| 52 |
+
This class provides a high-level interface for depth estimation using Depth Anything 3.
|
| 53 |
+
It supports both single and nested model architectures with metric scaling capabilities.
|
| 54 |
+
|
| 55 |
+
Features:
|
| 56 |
+
- Hugging Face Hub integration via PyTorchModelHubMixin
|
| 57 |
+
- Support for multiple model presets (vitb, vitg, nested variants)
|
| 58 |
+
- Automatic mixed precision inference
|
| 59 |
+
- Export capabilities for various formats (GLB, PLY, NPZ, etc.)
|
| 60 |
+
- Camera pose estimation and metric depth scaling
|
| 61 |
+
|
| 62 |
+
Usage:
|
| 63 |
+
# Load from Hugging Face Hub
|
| 64 |
+
model = DepthAnything3.from_pretrained("huggingface/model-name")
|
| 65 |
+
|
| 66 |
+
# Or create with specific preset
|
| 67 |
+
model = DepthAnything3(preset="vitg")
|
| 68 |
+
|
| 69 |
+
# Run inference
|
| 70 |
+
prediction = model.inference(images, export_dir="output", export_format="glb")
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
_commit_hash: str | None = None # Set by mixin when loading from Hub
|
| 74 |
+
|
| 75 |
+
def __init__(self, model_name: str = "da3-large", **kwargs):
|
| 76 |
+
"""
|
| 77 |
+
Initialize DepthAnything3 with specified preset.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
model_name: The name of the model preset to use.
|
| 81 |
+
Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'.
|
| 82 |
+
**kwargs: Additional keyword arguments (currently unused).
|
| 83 |
+
"""
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.model_name = model_name
|
| 86 |
+
|
| 87 |
+
# Build the underlying network
|
| 88 |
+
self.config = load_config(MODEL_REGISTRY[self.model_name])
|
| 89 |
+
self.model = create_object(self.config)
|
| 90 |
+
self.model.eval()
|
| 91 |
+
|
| 92 |
+
# Initialize processors
|
| 93 |
+
self.input_processor = InputProcessor()
|
| 94 |
+
self.output_processor = OutputProcessor()
|
| 95 |
+
|
| 96 |
+
# Device management (set by user)
|
| 97 |
+
self.device = None
|
| 98 |
+
|
| 99 |
+
@torch.inference_mode()
|
| 100 |
+
def forward(
|
| 101 |
+
self,
|
| 102 |
+
image: torch.Tensor,
|
| 103 |
+
extrinsics: torch.Tensor | None = None,
|
| 104 |
+
intrinsics: torch.Tensor | None = None,
|
| 105 |
+
export_feat_layers: list[int] | None = None,
|
| 106 |
+
infer_gs: bool = False,
|
| 107 |
+
use_ray_pose: bool = False,
|
| 108 |
+
ref_view_strategy: str = "saddle_balanced",
|
| 109 |
+
) -> dict[str, torch.Tensor]:
|
| 110 |
+
"""
|
| 111 |
+
Forward pass through the model.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
image: Input batch with shape ``(B, N, 3, H, W)`` on the model device.
|
| 115 |
+
extrinsics: Optional camera extrinsics with shape ``(B, N, 4, 4)``.
|
| 116 |
+
intrinsics: Optional camera intrinsics with shape ``(B, N, 3, 3)``.
|
| 117 |
+
export_feat_layers: Layer indices to return intermediate features for.
|
| 118 |
+
infer_gs: Enable Gaussian Splatting branch.
|
| 119 |
+
use_ray_pose: Use ray-based pose estimation instead of camera decoder.
|
| 120 |
+
ref_view_strategy: Strategy for selecting reference view from multiple views.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Dictionary containing model predictions
|
| 124 |
+
"""
|
| 125 |
+
# Determine optimal autocast dtype
|
| 126 |
+
autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 127 |
+
with torch.no_grad():
|
| 128 |
+
with torch.autocast(device_type=image.device.type, dtype=autocast_dtype):
|
| 129 |
+
return self.model(
|
| 130 |
+
image, extrinsics, intrinsics, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def inference(
|
| 134 |
+
self,
|
| 135 |
+
image: list[np.ndarray | Image.Image | str],
|
| 136 |
+
extrinsics: np.ndarray | None = None,
|
| 137 |
+
intrinsics: np.ndarray | None = None,
|
| 138 |
+
align_to_input_ext_scale: bool = True,
|
| 139 |
+
infer_gs: bool = False,
|
| 140 |
+
use_ray_pose: bool = False,
|
| 141 |
+
ref_view_strategy: str = "saddle_balanced",
|
| 142 |
+
render_exts: np.ndarray | None = None,
|
| 143 |
+
render_ixts: np.ndarray | None = None,
|
| 144 |
+
render_hw: tuple[int, int] | None = None,
|
| 145 |
+
process_res: int = 504,
|
| 146 |
+
process_res_method: str = "upper_bound_resize",
|
| 147 |
+
export_dir: str | None = None,
|
| 148 |
+
export_format: str = "mini_npz",
|
| 149 |
+
export_feat_layers: Sequence[int] | None = None,
|
| 150 |
+
# GLB export parameters
|
| 151 |
+
conf_thresh_percentile: float = 40.0,
|
| 152 |
+
num_max_points: int = 1_000_000,
|
| 153 |
+
show_cameras: bool = True,
|
| 154 |
+
# Feat_vis export parameters
|
| 155 |
+
feat_vis_fps: int = 15,
|
| 156 |
+
# Other export parameters, e.g., gs_ply, gs_video
|
| 157 |
+
export_kwargs: Optional[dict] = {},
|
| 158 |
+
) -> Prediction:
|
| 159 |
+
"""
|
| 160 |
+
Run inference on input images.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
image: List of input images (numpy arrays, PIL Images, or file paths)
|
| 164 |
+
extrinsics: Camera extrinsics (N, 4, 4)
|
| 165 |
+
intrinsics: Camera intrinsics (N, 3, 3)
|
| 166 |
+
align_to_input_ext_scale: whether to align the input pose scale to the prediction
|
| 167 |
+
infer_gs: Enable the 3D Gaussian branch (needed for `gs_ply`/`gs_video` exports)
|
| 168 |
+
use_ray_pose: Use ray-based pose estimation instead of camera decoder (default: False)
|
| 169 |
+
ref_view_strategy: Strategy for selecting reference view from multiple views.
|
| 170 |
+
Options: "first", "middle", "saddle_balanced", "saddle_sim_range".
|
| 171 |
+
Default: "saddle_balanced". For single view input (S ≤ 2), no reordering is performed.
|
| 172 |
+
render_exts: Optional render extrinsics for Gaussian video export
|
| 173 |
+
render_ixts: Optional render intrinsics for Gaussian video export
|
| 174 |
+
render_hw: Optional render resolution for Gaussian video export
|
| 175 |
+
process_res: Processing resolution
|
| 176 |
+
process_res_method: Resize method for processing
|
| 177 |
+
export_dir: Directory to export results
|
| 178 |
+
export_format: Export format (mini_npz, npz, glb, ply, gs, gs_video)
|
| 179 |
+
export_feat_layers: Layer indices to export intermediate features from
|
| 180 |
+
conf_thresh_percentile: [GLB] Lower percentile for adaptive confidence threshold (default: 40.0) # noqa: E501
|
| 181 |
+
num_max_points: [GLB] Maximum number of points in the point cloud (default: 1,000,000)
|
| 182 |
+
show_cameras: [GLB] Show camera wireframes in the exported scene (default: True)
|
| 183 |
+
feat_vis_fps: [FEAT_VIS] Frame rate for output video (default: 15)
|
| 184 |
+
export_kwargs: additional arguments to export functions.
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
Prediction object containing depth maps and camera parameters
|
| 188 |
+
"""
|
| 189 |
+
if "gs" in export_format:
|
| 190 |
+
assert infer_gs, "must set `infer_gs=True` to perform gs-related export."
|
| 191 |
+
|
| 192 |
+
if "colmap" in export_format:
|
| 193 |
+
assert isinstance(image[0], str), "`image` must be image paths for COLMAP export."
|
| 194 |
+
|
| 195 |
+
# Preprocess images
|
| 196 |
+
imgs_cpu, extrinsics, intrinsics = self._preprocess_inputs(
|
| 197 |
+
image, extrinsics, intrinsics, process_res, process_res_method
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Prepare tensors for model
|
| 201 |
+
imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics)
|
| 202 |
+
|
| 203 |
+
# Normalize extrinsics
|
| 204 |
+
ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None)
|
| 205 |
+
|
| 206 |
+
# Run model forward pass
|
| 207 |
+
export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else []
|
| 208 |
+
|
| 209 |
+
raw_output = self._run_model_forward(
|
| 210 |
+
imgs, ex_t_norm, in_t, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# Convert raw output to prediction
|
| 214 |
+
prediction = self._convert_to_prediction(raw_output)
|
| 215 |
+
|
| 216 |
+
# Align prediction to extrinsincs
|
| 217 |
+
prediction = self._align_to_input_extrinsics_intrinsics(
|
| 218 |
+
extrinsics, intrinsics, prediction, align_to_input_ext_scale
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# Add processed images for visualization
|
| 222 |
+
prediction = self._add_processed_images(prediction, imgs_cpu)
|
| 223 |
+
|
| 224 |
+
# Export if requested
|
| 225 |
+
if export_dir is not None:
|
| 226 |
+
|
| 227 |
+
if "gs" in export_format:
|
| 228 |
+
if infer_gs and "gs_video" not in export_format:
|
| 229 |
+
export_format = f"{export_format}-gs_video"
|
| 230 |
+
if "gs_video" in export_format:
|
| 231 |
+
if "gs_video" not in export_kwargs:
|
| 232 |
+
export_kwargs["gs_video"] = {}
|
| 233 |
+
export_kwargs["gs_video"].update(
|
| 234 |
+
{
|
| 235 |
+
"extrinsics": render_exts,
|
| 236 |
+
"intrinsics": render_ixts,
|
| 237 |
+
"out_image_hw": render_hw,
|
| 238 |
+
}
|
| 239 |
+
)
|
| 240 |
+
# Add GLB export parameters
|
| 241 |
+
if "glb" in export_format:
|
| 242 |
+
if "glb" not in export_kwargs:
|
| 243 |
+
export_kwargs["glb"] = {}
|
| 244 |
+
export_kwargs["glb"].update(
|
| 245 |
+
{
|
| 246 |
+
"conf_thresh_percentile": conf_thresh_percentile,
|
| 247 |
+
"num_max_points": num_max_points,
|
| 248 |
+
"show_cameras": show_cameras,
|
| 249 |
+
}
|
| 250 |
+
)
|
| 251 |
+
# Add Feat_vis export parameters
|
| 252 |
+
if "feat_vis" in export_format:
|
| 253 |
+
if "feat_vis" not in export_kwargs:
|
| 254 |
+
export_kwargs["feat_vis"] = {}
|
| 255 |
+
export_kwargs["feat_vis"].update(
|
| 256 |
+
{
|
| 257 |
+
"fps": feat_vis_fps,
|
| 258 |
+
}
|
| 259 |
+
)
|
| 260 |
+
# Add COLMAP export parameters
|
| 261 |
+
if "colmap" in export_format:
|
| 262 |
+
if "colmap" not in export_kwargs:
|
| 263 |
+
export_kwargs["colmap"] = {}
|
| 264 |
+
export_kwargs["colmap"].update(
|
| 265 |
+
{
|
| 266 |
+
"image_paths": image,
|
| 267 |
+
"conf_thresh_percentile": conf_thresh_percentile,
|
| 268 |
+
"process_res_method": process_res_method,
|
| 269 |
+
}
|
| 270 |
+
)
|
| 271 |
+
self._export_results(prediction, export_format, export_dir, **export_kwargs)
|
| 272 |
+
|
| 273 |
+
return prediction
|
| 274 |
+
|
| 275 |
+
def _preprocess_inputs(
|
| 276 |
+
self,
|
| 277 |
+
image: list[np.ndarray | Image.Image | str],
|
| 278 |
+
extrinsics: np.ndarray | None = None,
|
| 279 |
+
intrinsics: np.ndarray | None = None,
|
| 280 |
+
process_res: int = 504,
|
| 281 |
+
process_res_method: str = "upper_bound_resize",
|
| 282 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
| 283 |
+
"""Preprocess input images using input processor."""
|
| 284 |
+
start_time = time.time()
|
| 285 |
+
imgs_cpu, extrinsics, intrinsics = self.input_processor(
|
| 286 |
+
image,
|
| 287 |
+
extrinsics.copy() if extrinsics is not None else None,
|
| 288 |
+
intrinsics.copy() if intrinsics is not None else None,
|
| 289 |
+
process_res,
|
| 290 |
+
process_res_method,
|
| 291 |
+
)
|
| 292 |
+
end_time = time.time()
|
| 293 |
+
logger.info(
|
| 294 |
+
"Processed Images Done taking",
|
| 295 |
+
end_time - start_time,
|
| 296 |
+
"seconds. Shape: ",
|
| 297 |
+
imgs_cpu.shape,
|
| 298 |
+
)
|
| 299 |
+
return imgs_cpu, extrinsics, intrinsics
|
| 300 |
+
|
| 301 |
+
def _prepare_model_inputs(
|
| 302 |
+
self,
|
| 303 |
+
imgs_cpu: torch.Tensor,
|
| 304 |
+
extrinsics: torch.Tensor | None,
|
| 305 |
+
intrinsics: torch.Tensor | None,
|
| 306 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
| 307 |
+
"""Prepare tensors for model input."""
|
| 308 |
+
device = self._get_model_device()
|
| 309 |
+
|
| 310 |
+
# Move images to model device
|
| 311 |
+
imgs = imgs_cpu.to(device, non_blocking=True)[None].float()
|
| 312 |
+
|
| 313 |
+
# Convert camera parameters to tensors
|
| 314 |
+
ex_t = (
|
| 315 |
+
extrinsics.to(device, non_blocking=True)[None].float()
|
| 316 |
+
if extrinsics is not None
|
| 317 |
+
else None
|
| 318 |
+
)
|
| 319 |
+
in_t = (
|
| 320 |
+
intrinsics.to(device, non_blocking=True)[None].float()
|
| 321 |
+
if intrinsics is not None
|
| 322 |
+
else None
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
return imgs, ex_t, in_t
|
| 326 |
+
|
| 327 |
+
def _normalize_extrinsics(self, ex_t: torch.Tensor | None) -> torch.Tensor | None:
|
| 328 |
+
"""Normalize extrinsics"""
|
| 329 |
+
if ex_t is None:
|
| 330 |
+
return None
|
| 331 |
+
transform = affine_inverse(ex_t[:, :1])
|
| 332 |
+
ex_t_norm = ex_t @ transform
|
| 333 |
+
c2ws = affine_inverse(ex_t_norm)
|
| 334 |
+
translations = c2ws[..., :3, 3]
|
| 335 |
+
dists = translations.norm(dim=-1)
|
| 336 |
+
median_dist = torch.median(dists)
|
| 337 |
+
median_dist = torch.clamp(median_dist, min=1e-1)
|
| 338 |
+
ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist
|
| 339 |
+
return ex_t_norm
|
| 340 |
+
|
| 341 |
+
def _align_to_input_extrinsics_intrinsics(
|
| 342 |
+
self,
|
| 343 |
+
extrinsics: torch.Tensor | None,
|
| 344 |
+
intrinsics: torch.Tensor | None,
|
| 345 |
+
prediction: Prediction,
|
| 346 |
+
align_to_input_ext_scale: bool = True,
|
| 347 |
+
ransac_view_thresh: int = 10,
|
| 348 |
+
) -> Prediction:
|
| 349 |
+
"""Align depth map to input extrinsics"""
|
| 350 |
+
if extrinsics is None:
|
| 351 |
+
return prediction
|
| 352 |
+
prediction.intrinsics = intrinsics.numpy()
|
| 353 |
+
_, _, scale, aligned_extrinsics = align_poses_umeyama(
|
| 354 |
+
prediction.extrinsics,
|
| 355 |
+
extrinsics.numpy(),
|
| 356 |
+
ransac=len(extrinsics) >= ransac_view_thresh,
|
| 357 |
+
return_aligned=True,
|
| 358 |
+
random_state=42,
|
| 359 |
+
)
|
| 360 |
+
if align_to_input_ext_scale:
|
| 361 |
+
prediction.extrinsics = extrinsics[..., :3, :].numpy()
|
| 362 |
+
prediction.depth /= scale
|
| 363 |
+
else:
|
| 364 |
+
prediction.extrinsics = aligned_extrinsics
|
| 365 |
+
return prediction
|
| 366 |
+
|
| 367 |
+
def _run_model_forward(
|
| 368 |
+
self,
|
| 369 |
+
imgs: torch.Tensor,
|
| 370 |
+
ex_t: torch.Tensor | None,
|
| 371 |
+
in_t: torch.Tensor | None,
|
| 372 |
+
export_feat_layers: Sequence[int] | None = None,
|
| 373 |
+
infer_gs: bool = False,
|
| 374 |
+
use_ray_pose: bool = False,
|
| 375 |
+
ref_view_strategy: str = "saddle_balanced",
|
| 376 |
+
) -> dict[str, torch.Tensor]:
|
| 377 |
+
"""Run model forward pass."""
|
| 378 |
+
device = imgs.device
|
| 379 |
+
need_sync = device.type == "cuda"
|
| 380 |
+
if need_sync:
|
| 381 |
+
torch.cuda.synchronize(device)
|
| 382 |
+
start_time = time.time()
|
| 383 |
+
feat_layers = list(export_feat_layers) if export_feat_layers is not None else None
|
| 384 |
+
output = self.forward(imgs, ex_t, in_t, feat_layers, infer_gs, use_ray_pose, ref_view_strategy)
|
| 385 |
+
if need_sync:
|
| 386 |
+
torch.cuda.synchronize(device)
|
| 387 |
+
end_time = time.time()
|
| 388 |
+
logger.info(f"Model Forward Pass Done. Time: {end_time - start_time} seconds")
|
| 389 |
+
return output
|
| 390 |
+
|
| 391 |
+
def _convert_to_prediction(self, raw_output: dict[str, torch.Tensor]) -> Prediction:
|
| 392 |
+
"""Convert raw model output to Prediction object."""
|
| 393 |
+
start_time = time.time()
|
| 394 |
+
output = self.output_processor(raw_output)
|
| 395 |
+
end_time = time.time()
|
| 396 |
+
logger.info(f"Conversion to Prediction Done. Time: {end_time - start_time} seconds")
|
| 397 |
+
return output
|
| 398 |
+
|
| 399 |
+
def _add_processed_images(self, prediction: Prediction, imgs_cpu: torch.Tensor) -> Prediction:
|
| 400 |
+
"""Add processed images to prediction for visualization."""
|
| 401 |
+
# Convert from (N, 3, H, W) to (N, H, W, 3) and denormalize
|
| 402 |
+
processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy() # (N, H, W, 3)
|
| 403 |
+
|
| 404 |
+
# Denormalize from ImageNet normalization
|
| 405 |
+
mean = np.array([0.485, 0.456, 0.406])
|
| 406 |
+
std = np.array([0.229, 0.224, 0.225])
|
| 407 |
+
processed_imgs = processed_imgs * std + mean
|
| 408 |
+
processed_imgs = np.clip(processed_imgs, 0, 1)
|
| 409 |
+
processed_imgs = (processed_imgs * 255).astype(np.uint8)
|
| 410 |
+
|
| 411 |
+
prediction.processed_images = processed_imgs
|
| 412 |
+
return prediction
|
| 413 |
+
|
| 414 |
+
def _export_results(
|
| 415 |
+
self, prediction: Prediction, export_format: str, export_dir: str, **kwargs
|
| 416 |
+
) -> None:
|
| 417 |
+
"""Export results to specified format and directory."""
|
| 418 |
+
start_time = time.time()
|
| 419 |
+
export(prediction, export_format, export_dir, **kwargs)
|
| 420 |
+
end_time = time.time()
|
| 421 |
+
logger.info(f"Export Results Done. Time: {end_time - start_time} seconds")
|
| 422 |
+
|
| 423 |
+
def _get_model_device(self) -> torch.device:
|
| 424 |
+
"""
|
| 425 |
+
Get the device where the model is located.
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
Device where the model parameters are located
|
| 429 |
+
|
| 430 |
+
Raises:
|
| 431 |
+
ValueError: If no tensors are found in the model
|
| 432 |
+
"""
|
| 433 |
+
if self.device is not None:
|
| 434 |
+
return self.device
|
| 435 |
+
|
| 436 |
+
# Find device from parameters
|
| 437 |
+
for param in self.parameters():
|
| 438 |
+
self.device = param.device
|
| 439 |
+
return param.device
|
| 440 |
+
|
| 441 |
+
# Find device from buffers
|
| 442 |
+
for buffer in self.buffers():
|
| 443 |
+
self.device = buffer.device
|
| 444 |
+
return buffer.device
|
| 445 |
+
|
| 446 |
+
raise ValueError("No tensor found in model")
|
Depth-Anything-3/src/depth_anything_3/app/css_and_html.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: E501
|
| 2 |
+
|
| 3 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
CSS and HTML content for the Depth Anything 3 Gradio application.
|
| 19 |
+
This module contains all the CSS styles and HTML content blocks
|
| 20 |
+
used in the Gradio interface.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
# CSS Styles for the Gradio interface
|
| 24 |
+
GRADIO_CSS = """
|
| 25 |
+
/* Add Font Awesome CDN with all styles including brands and colors */
|
| 26 |
+
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css');
|
| 27 |
+
|
| 28 |
+
/* Add custom styles for colored icons */
|
| 29 |
+
.fa-color-blue {
|
| 30 |
+
color: #3b82f6;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
.fa-color-purple {
|
| 34 |
+
color: #8b5cf6;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.fa-color-cyan {
|
| 38 |
+
color: #06b6d4;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
.fa-color-green {
|
| 42 |
+
color: #10b981;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
.fa-color-yellow {
|
| 46 |
+
color: #f59e0b;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
.fa-color-red {
|
| 50 |
+
color: #ef4444;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.link-btn {
|
| 54 |
+
display: inline-flex;
|
| 55 |
+
align-items: center;
|
| 56 |
+
gap: 8px;
|
| 57 |
+
text-decoration: none;
|
| 58 |
+
padding: 12px 24px;
|
| 59 |
+
border-radius: 50px;
|
| 60 |
+
font-weight: 500;
|
| 61 |
+
transition: all 0.3s ease;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
/* Dark mode tech theme */
|
| 65 |
+
@media (prefers-color-scheme: dark) {
|
| 66 |
+
html, body {
|
| 67 |
+
background: #1e293b;
|
| 68 |
+
color: #ffffff;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
.gradio-container {
|
| 72 |
+
background: #1e293b;
|
| 73 |
+
color: #ffffff;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
.link-btn {
|
| 77 |
+
background: rgba(255, 255, 255, 0.2);
|
| 78 |
+
color: white;
|
| 79 |
+
backdrop-filter: blur(10px);
|
| 80 |
+
border: 1px solid rgba(255, 255, 255, 0.3);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
.link-btn:hover {
|
| 84 |
+
background: rgba(255, 255, 255, 0.3);
|
| 85 |
+
transform: translateY(-2px);
|
| 86 |
+
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
.tech-bg {
|
| 90 |
+
background: linear-gradient(135deg, #0f172a, #1e293b); /* Darker colors */
|
| 91 |
+
position: relative;
|
| 92 |
+
overflow: hidden;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
.tech-bg::before {
|
| 96 |
+
content: '';
|
| 97 |
+
position: absolute;
|
| 98 |
+
top: 0;
|
| 99 |
+
left: 0;
|
| 100 |
+
right: 0;
|
| 101 |
+
bottom: 0;
|
| 102 |
+
background:
|
| 103 |
+
radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
|
| 104 |
+
radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
|
| 105 |
+
radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.1) 0%, transparent 50%); /* Reduced opacity */
|
| 106 |
+
animation: techPulse 8s ease-in-out infinite;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
.gradio-container .panel,
|
| 110 |
+
.gradio-container .block,
|
| 111 |
+
.gradio-container .form {
|
| 112 |
+
background: rgba(0, 0, 0, 0.3);
|
| 113 |
+
border: 1px solid rgba(59, 130, 246, 0.2);
|
| 114 |
+
border-radius: 10px;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
.gradio-container * {
|
| 118 |
+
color: #ffffff;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
.gradio-container label {
|
| 122 |
+
color: #e0e0e0;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
.gradio-container .markdown {
|
| 126 |
+
color: #e0e0e0;
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/* Light mode tech theme */
|
| 131 |
+
@media (prefers-color-scheme: light) {
|
| 132 |
+
html, body {
|
| 133 |
+
background: #ffffff;
|
| 134 |
+
color: #1e293b;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
.gradio-container {
|
| 138 |
+
background: #ffffff;
|
| 139 |
+
color: #1e293b;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
.tech-bg {
|
| 143 |
+
background: linear-gradient(135deg, #ffffff, #f1f5f9);
|
| 144 |
+
position: relative;
|
| 145 |
+
overflow: hidden;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
.link-btn {
|
| 149 |
+
background: rgba(59, 130, 246, 0.15);
|
| 150 |
+
color: var(--body-text-color);
|
| 151 |
+
border: 1px solid rgba(59, 130, 246, 0.3);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
.link-btn:hover {
|
| 155 |
+
background: rgba(59, 130, 246, 0.25);
|
| 156 |
+
transform: translateY(-2px);
|
| 157 |
+
box-shadow: 0 8px 25px rgba(59, 130, 246, 0.2);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
.tech-bg::before {
|
| 161 |
+
content: '';
|
| 162 |
+
position: absolute;
|
| 163 |
+
top: 0;
|
| 164 |
+
left: 0;
|
| 165 |
+
right: 0;
|
| 166 |
+
bottom: 0;
|
| 167 |
+
background:
|
| 168 |
+
radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.1) 0%, transparent 50%),
|
| 169 |
+
radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.1) 0%, transparent 50%),
|
| 170 |
+
radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.08) 0%, transparent 50%);
|
| 171 |
+
animation: techPulse 8s ease-in-out infinite;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
.gradio-container .panel,
|
| 175 |
+
.gradio-container .block,
|
| 176 |
+
.gradio-container .form {
|
| 177 |
+
background: rgba(255, 255, 255, 0.8);
|
| 178 |
+
border: 1px solid rgba(59, 130, 246, 0.3);
|
| 179 |
+
border-radius: 10px;
|
| 180 |
+
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
.gradio-container * {
|
| 184 |
+
color: #1e293b;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
.gradio-container label {
|
| 188 |
+
color: #334155;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
.gradio-container .markdown {
|
| 192 |
+
color: #334155;
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@keyframes techPulse {
|
| 200 |
+
0%, 100% { opacity: 0.5; }
|
| 201 |
+
50% { opacity: 0.8; }
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
/* Custom log with tech gradient */
|
| 205 |
+
.custom-log * {
|
| 206 |
+
font-style: italic;
|
| 207 |
+
font-size: 22px !important;
|
| 208 |
+
background: linear-gradient(135deg, #3b82f6, #8b5cf6);
|
| 209 |
+
background-size: 400% 400%;
|
| 210 |
+
-webkit-background-clip: text;
|
| 211 |
+
background-clip: text;
|
| 212 |
+
font-weight: bold !important;
|
| 213 |
+
color: transparent !important;
|
| 214 |
+
text-align: center !important;
|
| 215 |
+
animation: techGradient 3s ease infinite;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
@keyframes techGradient {
|
| 219 |
+
0% { background-position: 0% 50%; }
|
| 220 |
+
50% { background-position: 100% 50%; }
|
| 221 |
+
100% { background-position: 0% 50%; }
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
@keyframes metricPulse {
|
| 225 |
+
0%, 100% { background-position: 0% 50%; }
|
| 226 |
+
50% { background-position: 100% 50%; }
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
@keyframes pointcloudPulse {
|
| 230 |
+
0%, 100% { background-position: 0% 50%; }
|
| 231 |
+
50% { background-position: 100% 50%; }
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
@keyframes camerasPulse {
|
| 235 |
+
0%, 100% { background-position: 0% 50%; }
|
| 236 |
+
50% { background-position: 100% 50%; }
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
@keyframes gaussiansPulse {
|
| 240 |
+
0%, 100% { background-position: 0% 50%; }
|
| 241 |
+
50% { background-position: 100% 50%; }
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
/* Special colors for key terms - Global styles */
|
| 245 |
+
.metric-text {
|
| 246 |
+
background: linear-gradient(45deg, #ff6b6b, #ff8e53, #ff6b6b);
|
| 247 |
+
background-size: 200% 200%;
|
| 248 |
+
-webkit-background-clip: text;
|
| 249 |
+
background-clip: text;
|
| 250 |
+
color: transparent !important;
|
| 251 |
+
animation: metricPulse 2s ease-in-out infinite;
|
| 252 |
+
font-weight: 700;
|
| 253 |
+
text-shadow: 0 0 10px rgba(255, 107, 107, 0.5);
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
.pointcloud-text {
|
| 257 |
+
background: linear-gradient(45deg, #4ecdc4, #44a08d, #4ecdc4);
|
| 258 |
+
background-size: 200% 200%;
|
| 259 |
+
-webkit-background-clip: text;
|
| 260 |
+
background-clip: text;
|
| 261 |
+
color: transparent !important;
|
| 262 |
+
animation: pointcloudPulse 2.5s ease-in-out infinite;
|
| 263 |
+
font-weight: 700;
|
| 264 |
+
text-shadow: 0 0 10px rgba(78, 205, 196, 0.5);
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
.cameras-text {
|
| 268 |
+
background: linear-gradient(45deg, #667eea, #764ba2, #667eea);
|
| 269 |
+
background-size: 200% 200%;
|
| 270 |
+
-webkit-background-clip: text;
|
| 271 |
+
background-clip: text;
|
| 272 |
+
color: transparent !important;
|
| 273 |
+
animation: camerasPulse 3s ease-in-out infinite;
|
| 274 |
+
font-weight: 700;
|
| 275 |
+
text-shadow: 0 0 10px rgba(102, 126, 234, 0.5);
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
.gaussians-text {
|
| 279 |
+
background: linear-gradient(45deg, #f093fb, #f5576c, #f093fb);
|
| 280 |
+
background-size: 200% 200%;
|
| 281 |
+
-webkit-background-clip: text;
|
| 282 |
+
background-clip: text;
|
| 283 |
+
color: transparent !important;
|
| 284 |
+
animation: gaussiansPulse 2.2s ease-in-out infinite;
|
| 285 |
+
font-weight: 700;
|
| 286 |
+
text-shadow: 0 0 10px rgba(240, 147, 251, 0.5);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
.example-log * {
|
| 290 |
+
font-style: italic;
|
| 291 |
+
font-size: 16px !important;
|
| 292 |
+
background: linear-gradient(135deg, #3b82f6, #8b5cf6);
|
| 293 |
+
-webkit-background-clip: text;
|
| 294 |
+
background-clip: text;
|
| 295 |
+
color: transparent !important;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
#my_radio .wrap {
|
| 299 |
+
display: flex;
|
| 300 |
+
flex-wrap: nowrap;
|
| 301 |
+
justify-content: center;
|
| 302 |
+
align-items: center;
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
#my_radio .wrap label {
|
| 306 |
+
display: flex;
|
| 307 |
+
width: 50%;
|
| 308 |
+
justify-content: center;
|
| 309 |
+
align-items: center;
|
| 310 |
+
margin: 0;
|
| 311 |
+
padding: 10px 0;
|
| 312 |
+
box-sizing: border-box;
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
/* Align navigation buttons with dropdown bottom */
|
| 316 |
+
.navigation-row {
|
| 317 |
+
display: flex !important;
|
| 318 |
+
align-items: flex-end !important;
|
| 319 |
+
gap: 8px !important;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
.navigation-row > div:nth-child(1),
|
| 323 |
+
.navigation-row > div:nth-child(3) {
|
| 324 |
+
align-self: flex-end !important;
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
.navigation-row > div:nth-child(2) {
|
| 328 |
+
flex: 1 !important;
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
/* Make thumbnails clickable with pointer cursor */
|
| 332 |
+
.clickable-thumbnail img {
|
| 333 |
+
cursor: pointer !important;
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
.clickable-thumbnail:hover img {
|
| 337 |
+
cursor: pointer !important;
|
| 338 |
+
opacity: 0.8;
|
| 339 |
+
transition: opacity 0.3s ease;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
/* Make thumbnail containers narrower horizontally */
|
| 343 |
+
.clickable-thumbnail {
|
| 344 |
+
padding: 5px 2px !important;
|
| 345 |
+
margin: 0 2px !important;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
.clickable-thumbnail .image-container {
|
| 349 |
+
margin: 0 !important;
|
| 350 |
+
padding: 0 !important;
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
.scene-info {
|
| 354 |
+
text-align: center !important;
|
| 355 |
+
padding: 5px 2px !important;
|
| 356 |
+
margin: 0 !important;
|
| 357 |
+
}
|
| 358 |
+
"""
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def get_header_html(logo_base64=None):
|
| 362 |
+
"""
|
| 363 |
+
Generate the main header HTML with logo and title.
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
logo_base64 (str, optional): Base64 encoded logo image
|
| 367 |
+
|
| 368 |
+
Returns:
|
| 369 |
+
str: HTML string for the header
|
| 370 |
+
"""
|
| 371 |
+
return """
|
| 372 |
+
<div class="tech-bg" style="text-align: center; margin-bottom: 5px; padding: 40px 20px; border-radius: 15px; position: relative; overflow: hidden;">
|
| 373 |
+
<div style="position: relative; z-index: 2;">
|
| 374 |
+
<h1 style="margin: 0; font-size: 3.5em; font-weight: 700;
|
| 375 |
+
background: linear-gradient(135deg, #3b82f6, #8b5cf6);
|
| 376 |
+
background-size: 400% 400%;
|
| 377 |
+
-webkit-background-clip: text;
|
| 378 |
+
background-clip: text;
|
| 379 |
+
color: transparent;
|
| 380 |
+
animation: techGradient 3s ease infinite;
|
| 381 |
+
text-shadow: 0 0 30px rgba(59, 130, 246, 0.5);
|
| 382 |
+
letter-spacing: 2px;">
|
| 383 |
+
Depth Anything 3
|
| 384 |
+
</h1>
|
| 385 |
+
<p style="margin: 15px 0 0 0; font-size: 2.16em; font-weight: 300;" class="header-subtitle">
|
| 386 |
+
Recovering the Visual Space from Any Views
|
| 387 |
+
</p>
|
| 388 |
+
<div style="margin-top: 20px;">
|
| 389 |
+
<!-- Revert buttons to original inline styles -->
|
| 390 |
+
<a href="https://depth-anything-3.github.io" target="_blank" class="link-btn">
|
| 391 |
+
<i class="fas fa-globe" style="margin-right: 8px;"></i> Project Page
|
| 392 |
+
</a>
|
| 393 |
+
<a href="https://arxiv.org/abs/2406.09414" target="_blank" class="link-btn">
|
| 394 |
+
<i class="fas fa-file-pdf" style="margin-right: 8px;"></i> Paper
|
| 395 |
+
</a>
|
| 396 |
+
<a href="https://github.com/ByteDance-Seed/Depth-Anything-3" target="_blank" class="link-btn">
|
| 397 |
+
<i class="fab fa-github" style="margin-right: 8px;"></i> Code
|
| 398 |
+
</a>
|
| 399 |
+
</div>
|
| 400 |
+
</div>
|
| 401 |
+
</div>
|
| 402 |
+
|
| 403 |
+
<style>
|
| 404 |
+
/* Ensure tech-bg class is properly applied in dark mode */
|
| 405 |
+
@media (prefers-color-scheme: dark) {
|
| 406 |
+
.header-subtitle {
|
| 407 |
+
color: #cbd5e1;
|
| 408 |
+
}
|
| 409 |
+
/* Increase priority to ensure background color is properly applied */
|
| 410 |
+
.tech-bg {
|
| 411 |
+
background: linear-gradient(135deg, #0f172a, #1e293b) !important;
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
@media (prefers-color-scheme: light) {
|
| 416 |
+
.header-subtitle {
|
| 417 |
+
color: #475569;
|
| 418 |
+
}
|
| 419 |
+
/* Also add explicit background color for light mode */
|
| 420 |
+
.tech-bg {
|
| 421 |
+
background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%) !important;
|
| 422 |
+
}
|
| 423 |
+
}
|
| 424 |
+
</style>
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def get_description_html():
|
| 429 |
+
"""
|
| 430 |
+
Generate the main description and getting started HTML.
|
| 431 |
+
|
| 432 |
+
Returns:
|
| 433 |
+
str: HTML string for the description
|
| 434 |
+
"""
|
| 435 |
+
return """
|
| 436 |
+
<div class="description-container" style="padding: 25px; border-radius: 15px; margin: 0 0 20px 0;">
|
| 437 |
+
<h2 class="description-title" style="margin-top: 0; font-size: 1.6em; text-align: center;">
|
| 438 |
+
<i class="fas fa-bullseye fa-color-red" style="margin-right: 8px;"></i> What This Demo Does
|
| 439 |
+
</h2>
|
| 440 |
+
<div class="description-content" style="padding: 20px; border-radius: 10px; margin: 15px 0; text-align: center;">
|
| 441 |
+
<p class="description-main" style="line-height: 1.6; margin: 0; font-size: 1.45em;">
|
| 442 |
+
<strong>Upload images or videos</strong> → <strong>Get <span class="metric-text">Metric</span> <span class="pointcloud-text">Point Clouds</span>, <span class="cameras-text">Cameras</span> and <span class="gaussians-text">Novel Views</span></strong> → <strong>Explore in 3D</strong>
|
| 443 |
+
</p>
|
| 444 |
+
</div>
|
| 445 |
+
|
| 446 |
+
<div style="text-align: center; margin-top: 15px;">
|
| 447 |
+
<p class="description-tip" style="font-style: italic; margin: 0;">
|
| 448 |
+
<i class="fas fa-lightbulb fa-color-yellow" style="margin-right: 8px;"></i> <strong>Tip:</strong> Landscape-oriented images or videos are preferred for best 3D recovering.
|
| 449 |
+
</p>
|
| 450 |
+
</div>
|
| 451 |
+
</div>
|
| 452 |
+
|
| 453 |
+
<style>
|
| 454 |
+
@media (prefers-color-scheme: dark) {
|
| 455 |
+
.description-container {
|
| 456 |
+
background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
|
| 457 |
+
border: 1px solid rgba(59, 130, 246, 0.2);
|
| 458 |
+
}
|
| 459 |
+
.description-title { color: #3b82f6; }
|
| 460 |
+
.description-content { background: rgba(0, 0, 0, 0.3); }
|
| 461 |
+
.description-main { color: #e0e0e0; }
|
| 462 |
+
.description-text { color: #cbd5e1; }
|
| 463 |
+
.description-tip { color: #cbd5e1; }
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
@media (prefers-color-scheme: light) {
|
| 467 |
+
.description-container {
|
| 468 |
+
background: linear-gradient(135deg, rgba(59, 130, 246, 0.05) 0%, rgba(139, 92, 246, 0.05) 100%);
|
| 469 |
+
border: 1px solid rgba(59, 130, 246, 0.3);
|
| 470 |
+
}
|
| 471 |
+
.description-title { color: #3b82f6; }
|
| 472 |
+
.description-content { background: transparent; }
|
| 473 |
+
.description-main { color: #1e293b; }
|
| 474 |
+
.description-text { color: #475569; }
|
| 475 |
+
.description-tip { color: #475569; }
|
| 476 |
+
}
|
| 477 |
+
</style>
|
| 478 |
+
"""
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def get_acknowledgements_html():
|
| 482 |
+
"""
|
| 483 |
+
Generate the acknowledgements section HTML.
|
| 484 |
+
|
| 485 |
+
Returns:
|
| 486 |
+
str: HTML string for the acknowledgements
|
| 487 |
+
"""
|
| 488 |
+
return """
|
| 489 |
+
<div style="background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
|
| 490 |
+
padding: 25px; border-radius: 15px; margin: 20px 0; border: 1px solid rgba(59, 130, 246, 0.2);">
|
| 491 |
+
<h3 style="color: #3b82f6; margin-top: 0; text-align: center; font-size: 1.4em;">
|
| 492 |
+
<i class="fas fa-trophy fa-color-yellow" style="margin-right: 8px;"></i> Research Credits & Acknowledgments
|
| 493 |
+
</h3>
|
| 494 |
+
|
| 495 |
+
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin: 15px 0;">
|
| 496 |
+
<!-- Original Research Section (Left) -->
|
| 497 |
+
<div style="text-align: center;">
|
| 498 |
+
<h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-flask fa-color-green" style="margin-right: 8px;"></i> Original Research</h4>
|
| 499 |
+
<p style="color: #e0e0e0; margin: 5px 0;">
|
| 500 |
+
<a href="https://depth-anything-3.github.io" target="_blank"
|
| 501 |
+
style="color: #3b82f6; text-decoration: none; font-weight: 600;">
|
| 502 |
+
Depth Anything 3
|
| 503 |
+
</a>
|
| 504 |
+
</p>
|
| 505 |
+
</div>
|
| 506 |
+
|
| 507 |
+
<!-- Previous Versions Section (Right) -->
|
| 508 |
+
<div style="text-align: center;">
|
| 509 |
+
<h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-history fa-color-blue" style="margin-right: 8px;"></i> Previous Versions</h4>
|
| 510 |
+
<div style="display: flex; flex-direction: row; gap: 15px; justify-content: center; align-items: center;">
|
| 511 |
+
<p style="color: #e0e0e0; margin: 0;">
|
| 512 |
+
<a href="https://huggingface.co/spaces/LiheYoung/Depth-Anything" target="_blank"
|
| 513 |
+
style="color: #3b82f6; text-decoration: none; font-weight: 600;">
|
| 514 |
+
Depth-Anything
|
| 515 |
+
</a>
|
| 516 |
+
</p>
|
| 517 |
+
<span style="color: #e0e0e0;">•</span>
|
| 518 |
+
<p style="color: #e0e0e0; margin: 0;">
|
| 519 |
+
<a href="https://huggingface.co/spaces/depth-anything/Depth-Anything-V2" target="_blank"
|
| 520 |
+
style="color: #3b82f6; text-decoration: none; font-weight: 600;">
|
| 521 |
+
Depth-Anything-V2
|
| 522 |
+
</a>
|
| 523 |
+
</p>
|
| 524 |
+
</div>
|
| 525 |
+
</div>
|
| 526 |
+
</div>
|
| 527 |
+
|
| 528 |
+
<!-- HF Demo Adapted from - Centered at the bottom of the whole block -->
|
| 529 |
+
<div style="margin-top: 20px; padding-top: 15px; border-top: 1px solid rgba(59, 130, 246, 0.3); text-align: center;">
|
| 530 |
+
<p style="color: #a0a0a0; font-size: 0.9em; margin: 0;">
|
| 531 |
+
<i class="fas fa-code-branch fa-color-gray" style="margin-right: 5px;"></i> HF demo adapted from <a href="https://huggingface.co/spaces/facebook/map-anything" target="_blank" style="color: inherit; text-decoration: none;">Map Anything</a>
|
| 532 |
+
</p>
|
| 533 |
+
</div>
|
| 534 |
+
</div>
|
| 535 |
+
"""
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def get_gradio_theme():
|
| 539 |
+
"""
|
| 540 |
+
Get the configured Gradio theme with adaptive tech colors.
|
| 541 |
+
|
| 542 |
+
Returns:
|
| 543 |
+
gr.themes.Base: Configured Gradio theme
|
| 544 |
+
"""
|
| 545 |
+
import gradio as gr
|
| 546 |
+
|
| 547 |
+
return gr.themes.Base(
|
| 548 |
+
primary_hue=gr.themes.Color(
|
| 549 |
+
c50="#eff6ff",
|
| 550 |
+
c100="#dbeafe",
|
| 551 |
+
c200="#bfdbfe",
|
| 552 |
+
c300="#93c5fd",
|
| 553 |
+
c400="#60a5fa",
|
| 554 |
+
c500="#3b82f6",
|
| 555 |
+
c600="#2563eb",
|
| 556 |
+
c700="#1d4ed8",
|
| 557 |
+
c800="#1e40af",
|
| 558 |
+
c900="#1e3a8a",
|
| 559 |
+
c950="#172554",
|
| 560 |
+
),
|
| 561 |
+
secondary_hue=gr.themes.Color(
|
| 562 |
+
c50="#f5f3ff",
|
| 563 |
+
c100="#ede9fe",
|
| 564 |
+
c200="#ddd6fe",
|
| 565 |
+
c300="#c4b5fd",
|
| 566 |
+
c400="#a78bfa",
|
| 567 |
+
c500="#8b5cf6",
|
| 568 |
+
c600="#7c3aed",
|
| 569 |
+
c700="#6d28d9",
|
| 570 |
+
c800="#5b21b6",
|
| 571 |
+
c900="#4c1d95",
|
| 572 |
+
c950="#2e1065",
|
| 573 |
+
),
|
| 574 |
+
neutral_hue=gr.themes.Color(
|
| 575 |
+
c50="#f8fafc",
|
| 576 |
+
c100="#f1f5f9",
|
| 577 |
+
c200="#e2e8f0",
|
| 578 |
+
c300="#cbd5e1",
|
| 579 |
+
c400="#94a3b8",
|
| 580 |
+
c500="#64748b",
|
| 581 |
+
c600="#475569",
|
| 582 |
+
c700="#334155",
|
| 583 |
+
c800="#1e293b",
|
| 584 |
+
c900="#0f172a",
|
| 585 |
+
c950="#020617",
|
| 586 |
+
),
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# Measure tab instructions HTML
|
| 591 |
+
MEASURE_INSTRUCTIONS_HTML = """
|
| 592 |
+
### Click points on the image to compute distance.
|
| 593 |
+
> <i class="fas fa-triangle-exclamation fa-color-red" style="margin-right: 5px;"></i> Metric scale estimation is difficult on aerial/drone images.
|
| 594 |
+
"""
|
Depth-Anything-3/src/depth_anything_3/app/gradio_app.py
ADDED
|
@@ -0,0 +1,724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Refactored Gradio App for Depth Anything 3.
|
| 17 |
+
|
| 18 |
+
This is the main application file that orchestrates all components.
|
| 19 |
+
The original functionality has been split into modular components for better maintainability.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import os
|
| 24 |
+
from typing import Any, Dict, List
|
| 25 |
+
import gradio as gr
|
| 26 |
+
|
| 27 |
+
from depth_anything_3.app.css_and_html import GRADIO_CSS, get_gradio_theme
|
| 28 |
+
from depth_anything_3.app.modules.event_handlers import EventHandlers
|
| 29 |
+
from depth_anything_3.app.modules.ui_components import UIComponents
|
| 30 |
+
|
| 31 |
+
# Set environment variables
|
| 32 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class DepthAnything3App:
|
| 36 |
+
"""
|
| 37 |
+
Main application class for Depth Anything 3 Gradio app.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, model_dir: str = None, workspace_dir: str = None, gallery_dir: str = None):
|
| 41 |
+
"""
|
| 42 |
+
Initialize the application.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
model_dir: Path to the model directory
|
| 46 |
+
workspace_dir: Path to the workspace directory
|
| 47 |
+
gallery_dir: Path to the gallery directory
|
| 48 |
+
"""
|
| 49 |
+
self.model_dir = model_dir
|
| 50 |
+
self.workspace_dir = workspace_dir
|
| 51 |
+
self.gallery_dir = gallery_dir
|
| 52 |
+
|
| 53 |
+
# Set environment variables for directories
|
| 54 |
+
if self.model_dir:
|
| 55 |
+
os.environ["DA3_MODEL_DIR"] = self.model_dir
|
| 56 |
+
if self.workspace_dir:
|
| 57 |
+
os.environ["DA3_WORKSPACE_DIR"] = self.workspace_dir
|
| 58 |
+
if self.gallery_dir:
|
| 59 |
+
os.environ["DA3_GALLERY_DIR"] = self.gallery_dir
|
| 60 |
+
|
| 61 |
+
self.event_handlers = EventHandlers()
|
| 62 |
+
self.ui_components = UIComponents()
|
| 63 |
+
|
| 64 |
+
def cache_examples(
|
| 65 |
+
self,
|
| 66 |
+
show_cam: bool = True,
|
| 67 |
+
filter_black_bg: bool = False,
|
| 68 |
+
filter_white_bg: bool = False,
|
| 69 |
+
save_percentage: float = 20.0,
|
| 70 |
+
num_max_points: int = 1000,
|
| 71 |
+
cache_gs_tag: str = "",
|
| 72 |
+
gs_trj_mode: str = "smooth",
|
| 73 |
+
gs_video_quality: str = "low",
|
| 74 |
+
) -> None:
|
| 75 |
+
"""
|
| 76 |
+
Pre-cache all example scenes at startup.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
show_cam: Whether to show camera in visualization
|
| 80 |
+
filter_black_bg: Whether to filter black background
|
| 81 |
+
filter_white_bg: Whether to filter white background
|
| 82 |
+
save_percentage: Filter percentage for point cloud
|
| 83 |
+
num_max_points: Maximum number of points
|
| 84 |
+
cache_gs_tag: Tag to match scene names for high-res+3DGS caching (e.g., "dl3dv")
|
| 85 |
+
gs_trj_mode: Trajectory mode for 3DGS
|
| 86 |
+
gs_video_quality: Video quality for 3DGS
|
| 87 |
+
"""
|
| 88 |
+
from depth_anything_3.app.modules.utils import get_scene_info
|
| 89 |
+
|
| 90 |
+
examples_dir = os.path.join(self.workspace_dir, "examples")
|
| 91 |
+
if not os.path.exists(examples_dir):
|
| 92 |
+
print(f"Examples directory not found: {examples_dir}")
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
scenes = get_scene_info(examples_dir)
|
| 96 |
+
if not scenes:
|
| 97 |
+
print("No example scenes found to cache.")
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
print(f"\n{'='*60}")
|
| 101 |
+
print(f"Caching {len(scenes)} example scenes...")
|
| 102 |
+
print(f"{'='*60}\n")
|
| 103 |
+
|
| 104 |
+
for i, scene in enumerate(scenes, 1):
|
| 105 |
+
scene_name = scene["name"]
|
| 106 |
+
|
| 107 |
+
# Check if scene name matches the gs tag for high-res+3DGS caching
|
| 108 |
+
use_high_res_gs = cache_gs_tag and cache_gs_tag.lower() in scene_name.lower()
|
| 109 |
+
|
| 110 |
+
if use_high_res_gs:
|
| 111 |
+
print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (HIGH-RES + 3DGS)")
|
| 112 |
+
print(f" - Number of images: {scene['num_images']}")
|
| 113 |
+
print(f" - Matched tag: '{cache_gs_tag}' - using high_res + 3DGS")
|
| 114 |
+
else:
|
| 115 |
+
print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (LOW-RES)")
|
| 116 |
+
print(f" - Number of images: {scene['num_images']}")
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
# Load example scene
|
| 120 |
+
_, target_dir, _, _, _, _, _, _, _ = self.event_handlers.load_example_scene(
|
| 121 |
+
scene_name
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
if target_dir and target_dir != "None":
|
| 125 |
+
# Run reconstruction with appropriate settings
|
| 126 |
+
print(" - Running reconstruction...")
|
| 127 |
+
result = self.event_handlers.gradio_demo(
|
| 128 |
+
target_dir=target_dir,
|
| 129 |
+
show_cam=show_cam,
|
| 130 |
+
filter_black_bg=filter_black_bg,
|
| 131 |
+
filter_white_bg=filter_white_bg,
|
| 132 |
+
process_res_method="high_res" if use_high_res_gs else "low_res",
|
| 133 |
+
save_percentage=save_percentage,
|
| 134 |
+
num_max_points=num_max_points,
|
| 135 |
+
infer_gs=use_high_res_gs,
|
| 136 |
+
ref_view_strategy="saddle_balanced",
|
| 137 |
+
gs_trj_mode=gs_trj_mode,
|
| 138 |
+
gs_video_quality=gs_video_quality,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Check if successful
|
| 142 |
+
if result[0] is not None: # reconstruction_output
|
| 143 |
+
print(f" ✓ Scene '{scene_name}' cached successfully")
|
| 144 |
+
else:
|
| 145 |
+
print(f" ✗ Scene '{scene_name}' caching failed: {result[1]}")
|
| 146 |
+
else:
|
| 147 |
+
print(f" ✗ Scene '{scene_name}' loading failed")
|
| 148 |
+
|
| 149 |
+
except Exception as e:
|
| 150 |
+
print(f" ✗ Error caching scene '{scene_name}': {str(e)}")
|
| 151 |
+
|
| 152 |
+
print()
|
| 153 |
+
|
| 154 |
+
print("=" * 60)
|
| 155 |
+
print("Example scene caching completed!")
|
| 156 |
+
print("=" * 60 + "\n")
|
| 157 |
+
|
| 158 |
+
def create_app(self) -> gr.Blocks:
|
| 159 |
+
"""
|
| 160 |
+
Create and configure the Gradio application.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
Configured Gradio Blocks interface
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
# Initialize theme
|
| 167 |
+
def get_theme():
|
| 168 |
+
return get_gradio_theme()
|
| 169 |
+
|
| 170 |
+
with gr.Blocks(theme=get_theme(), css=GRADIO_CSS) as demo:
|
| 171 |
+
# State variables for the tabbed interface
|
| 172 |
+
is_example = gr.Textbox(label="is_example", visible=False, value="None")
|
| 173 |
+
processed_data_state = gr.State(value=None)
|
| 174 |
+
measure_points_state = gr.State(value=[])
|
| 175 |
+
selected_image_index_state = gr.State(value=0) # Track selected image index
|
| 176 |
+
# current_view_index = gr.State(value=0) # noqa: F841 Track current view index
|
| 177 |
+
|
| 178 |
+
# Header and description
|
| 179 |
+
self.ui_components.create_header_section()
|
| 180 |
+
self.ui_components.create_description_section()
|
| 181 |
+
|
| 182 |
+
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
|
| 183 |
+
|
| 184 |
+
# Main content area
|
| 185 |
+
with gr.Row():
|
| 186 |
+
with gr.Column(scale=2):
|
| 187 |
+
# Upload section
|
| 188 |
+
(
|
| 189 |
+
input_video,
|
| 190 |
+
s_time_interval,
|
| 191 |
+
input_images,
|
| 192 |
+
image_gallery,
|
| 193 |
+
) = self.ui_components.create_upload_section()
|
| 194 |
+
|
| 195 |
+
with gr.Column(scale=4):
|
| 196 |
+
with gr.Column():
|
| 197 |
+
# gr.Markdown("**Metric 3D Reconstruction (Point Cloud and Camera Poses)**")
|
| 198 |
+
# Reconstruction control section (buttons) - moved below tabs
|
| 199 |
+
|
| 200 |
+
log_output = gr.Markdown(
|
| 201 |
+
"Please upload a video or images, then click Reconstruct.",
|
| 202 |
+
elem_classes=["custom-log"],
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# Tabbed interface
|
| 206 |
+
with gr.Tabs():
|
| 207 |
+
with gr.Tab("Point Cloud & Cameras"):
|
| 208 |
+
reconstruction_output = (
|
| 209 |
+
self.ui_components.create_3d_viewer_section()
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
with gr.Tab("Metric Depth"):
|
| 213 |
+
(
|
| 214 |
+
prev_measure_btn,
|
| 215 |
+
measure_view_selector,
|
| 216 |
+
next_measure_btn,
|
| 217 |
+
measure_image,
|
| 218 |
+
measure_depth_image,
|
| 219 |
+
measure_text,
|
| 220 |
+
) = self.ui_components.create_measure_section()
|
| 221 |
+
|
| 222 |
+
with gr.Tab("3DGS Rendered Novel Views"):
|
| 223 |
+
gs_video, gs_info = self.ui_components.create_nvs_video()
|
| 224 |
+
|
| 225 |
+
# Inference control section (before inference)
|
| 226 |
+
(process_res_method_dropdown, infer_gs, ref_view_strategy_dropdown) = (
|
| 227 |
+
self.ui_components.create_inference_control_section()
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Display control section - includes 3DGS options, buttons, and Visualization Options # noqa: E501
|
| 231 |
+
(
|
| 232 |
+
show_cam,
|
| 233 |
+
filter_black_bg,
|
| 234 |
+
filter_white_bg,
|
| 235 |
+
save_percentage,
|
| 236 |
+
num_max_points,
|
| 237 |
+
gs_trj_mode,
|
| 238 |
+
gs_video_quality,
|
| 239 |
+
submit_btn,
|
| 240 |
+
clear_btn,
|
| 241 |
+
) = self.ui_components.create_display_control_section()
|
| 242 |
+
|
| 243 |
+
# bind visibility of gs_trj_mode to infer_gs
|
| 244 |
+
infer_gs.change(
|
| 245 |
+
fn=lambda checked: (
|
| 246 |
+
gr.update(visible=checked),
|
| 247 |
+
gr.update(visible=checked),
|
| 248 |
+
gr.update(visible=checked),
|
| 249 |
+
gr.update(visible=(not checked)),
|
| 250 |
+
),
|
| 251 |
+
inputs=infer_gs,
|
| 252 |
+
outputs=[gs_trj_mode, gs_video_quality, gs_video, gs_info],
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# Example scenes section
|
| 256 |
+
gr.Markdown("## Example Scenes")
|
| 257 |
+
|
| 258 |
+
scenes = self.ui_components.create_example_scenes_section()
|
| 259 |
+
scene_components = self.ui_components.create_example_scene_grid(scenes)
|
| 260 |
+
|
| 261 |
+
# Set up event handlers
|
| 262 |
+
self._setup_event_handlers(
|
| 263 |
+
demo,
|
| 264 |
+
is_example,
|
| 265 |
+
processed_data_state,
|
| 266 |
+
measure_points_state,
|
| 267 |
+
target_dir_output,
|
| 268 |
+
input_video,
|
| 269 |
+
input_images,
|
| 270 |
+
s_time_interval,
|
| 271 |
+
image_gallery,
|
| 272 |
+
reconstruction_output,
|
| 273 |
+
log_output,
|
| 274 |
+
show_cam,
|
| 275 |
+
filter_black_bg,
|
| 276 |
+
filter_white_bg,
|
| 277 |
+
process_res_method_dropdown,
|
| 278 |
+
save_percentage,
|
| 279 |
+
submit_btn,
|
| 280 |
+
clear_btn,
|
| 281 |
+
num_max_points,
|
| 282 |
+
infer_gs,
|
| 283 |
+
ref_view_strategy_dropdown,
|
| 284 |
+
selected_image_index_state,
|
| 285 |
+
measure_view_selector,
|
| 286 |
+
measure_image,
|
| 287 |
+
measure_depth_image,
|
| 288 |
+
measure_text,
|
| 289 |
+
prev_measure_btn,
|
| 290 |
+
next_measure_btn,
|
| 291 |
+
scenes,
|
| 292 |
+
scene_components,
|
| 293 |
+
gs_video,
|
| 294 |
+
gs_info,
|
| 295 |
+
gs_trj_mode,
|
| 296 |
+
gs_video_quality,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# Acknowledgements
|
| 300 |
+
self.ui_components.create_acknowledgements_section()
|
| 301 |
+
|
| 302 |
+
return demo
|
| 303 |
+
|
| 304 |
+
def _setup_event_handlers(
|
| 305 |
+
self,
|
| 306 |
+
demo: gr.Blocks,
|
| 307 |
+
is_example: gr.Textbox,
|
| 308 |
+
processed_data_state: gr.State,
|
| 309 |
+
measure_points_state: gr.State,
|
| 310 |
+
target_dir_output: gr.Textbox,
|
| 311 |
+
input_video: gr.Video,
|
| 312 |
+
input_images: gr.File,
|
| 313 |
+
s_time_interval: gr.Slider,
|
| 314 |
+
image_gallery: gr.Gallery,
|
| 315 |
+
reconstruction_output: gr.Model3D,
|
| 316 |
+
log_output: gr.Markdown,
|
| 317 |
+
show_cam: gr.Checkbox,
|
| 318 |
+
filter_black_bg: gr.Checkbox,
|
| 319 |
+
filter_white_bg: gr.Checkbox,
|
| 320 |
+
process_res_method_dropdown: gr.Dropdown,
|
| 321 |
+
save_percentage: gr.Slider,
|
| 322 |
+
submit_btn: gr.Button,
|
| 323 |
+
clear_btn: gr.ClearButton,
|
| 324 |
+
num_max_points: gr.Slider,
|
| 325 |
+
infer_gs: gr.Checkbox,
|
| 326 |
+
ref_view_strategy_dropdown: gr.Dropdown,
|
| 327 |
+
selected_image_index_state: gr.State,
|
| 328 |
+
measure_view_selector: gr.Dropdown,
|
| 329 |
+
measure_image: gr.Image,
|
| 330 |
+
measure_depth_image: gr.Image,
|
| 331 |
+
measure_text: gr.Markdown,
|
| 332 |
+
prev_measure_btn: gr.Button,
|
| 333 |
+
next_measure_btn: gr.Button,
|
| 334 |
+
scenes: List[Dict[str, Any]],
|
| 335 |
+
scene_components: List[gr.Image],
|
| 336 |
+
gs_video: gr.Video,
|
| 337 |
+
gs_info: gr.Markdown,
|
| 338 |
+
gs_trj_mode: gr.Dropdown,
|
| 339 |
+
gs_video_quality: gr.Dropdown,
|
| 340 |
+
) -> None:
|
| 341 |
+
"""
|
| 342 |
+
Set up all event handlers for the application.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
demo: Gradio Blocks interface
|
| 346 |
+
All other arguments: Gradio components to connect
|
| 347 |
+
"""
|
| 348 |
+
# Configure clear button
|
| 349 |
+
clear_btn.add(
|
| 350 |
+
[
|
| 351 |
+
input_video,
|
| 352 |
+
input_images,
|
| 353 |
+
reconstruction_output,
|
| 354 |
+
log_output,
|
| 355 |
+
target_dir_output,
|
| 356 |
+
image_gallery,
|
| 357 |
+
gs_video,
|
| 358 |
+
]
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# Main reconstruction button
|
| 362 |
+
submit_btn.click(
|
| 363 |
+
fn=self.event_handlers.clear_fields, inputs=[], outputs=[reconstruction_output]
|
| 364 |
+
).then(fn=self.event_handlers.update_log, inputs=[], outputs=[log_output]).then(
|
| 365 |
+
fn=self.event_handlers.gradio_demo,
|
| 366 |
+
inputs=[
|
| 367 |
+
target_dir_output,
|
| 368 |
+
show_cam,
|
| 369 |
+
filter_black_bg,
|
| 370 |
+
filter_white_bg,
|
| 371 |
+
process_res_method_dropdown,
|
| 372 |
+
save_percentage,
|
| 373 |
+
# pass num_max_points
|
| 374 |
+
num_max_points,
|
| 375 |
+
infer_gs,
|
| 376 |
+
ref_view_strategy_dropdown,
|
| 377 |
+
gs_trj_mode,
|
| 378 |
+
gs_video_quality,
|
| 379 |
+
],
|
| 380 |
+
outputs=[
|
| 381 |
+
reconstruction_output,
|
| 382 |
+
log_output,
|
| 383 |
+
processed_data_state,
|
| 384 |
+
measure_image,
|
| 385 |
+
measure_depth_image,
|
| 386 |
+
measure_text,
|
| 387 |
+
measure_view_selector,
|
| 388 |
+
gs_video,
|
| 389 |
+
gs_video, # gs_video visibility
|
| 390 |
+
gs_info, # gs_info visibility
|
| 391 |
+
],
|
| 392 |
+
).then(
|
| 393 |
+
fn=lambda: "False",
|
| 394 |
+
inputs=[],
|
| 395 |
+
outputs=[is_example], # set is_example to "False"
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# Real-time visualization updates
|
| 399 |
+
self._setup_visualization_handlers(
|
| 400 |
+
show_cam,
|
| 401 |
+
filter_black_bg,
|
| 402 |
+
filter_white_bg,
|
| 403 |
+
process_res_method_dropdown,
|
| 404 |
+
target_dir_output,
|
| 405 |
+
is_example,
|
| 406 |
+
reconstruction_output,
|
| 407 |
+
log_output,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# File upload handlers
|
| 411 |
+
input_video.change(
|
| 412 |
+
fn=self.event_handlers.handle_uploads,
|
| 413 |
+
inputs=[input_video, input_images, s_time_interval],
|
| 414 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 415 |
+
)
|
| 416 |
+
input_images.change(
|
| 417 |
+
fn=self.event_handlers.handle_uploads,
|
| 418 |
+
inputs=[input_video, input_images, s_time_interval],
|
| 419 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
# Navigation handlers
|
| 423 |
+
self._setup_navigation_handlers(
|
| 424 |
+
prev_measure_btn,
|
| 425 |
+
next_measure_btn,
|
| 426 |
+
measure_view_selector,
|
| 427 |
+
measure_image,
|
| 428 |
+
measure_depth_image,
|
| 429 |
+
measure_points_state,
|
| 430 |
+
processed_data_state,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Measurement handler
|
| 434 |
+
measure_image.select(
|
| 435 |
+
fn=self.event_handlers.measure,
|
| 436 |
+
inputs=[processed_data_state, measure_points_state, measure_view_selector],
|
| 437 |
+
outputs=[measure_image, measure_depth_image, measure_points_state, measure_text],
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
# Example scene handlers
|
| 441 |
+
self._setup_example_scene_handlers(
|
| 442 |
+
scenes,
|
| 443 |
+
scene_components,
|
| 444 |
+
reconstruction_output,
|
| 445 |
+
target_dir_output,
|
| 446 |
+
image_gallery,
|
| 447 |
+
log_output,
|
| 448 |
+
is_example,
|
| 449 |
+
processed_data_state,
|
| 450 |
+
measure_view_selector,
|
| 451 |
+
measure_image,
|
| 452 |
+
measure_depth_image,
|
| 453 |
+
gs_video,
|
| 454 |
+
gs_info,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
def _setup_visualization_handlers(
|
| 458 |
+
self,
|
| 459 |
+
show_cam: gr.Checkbox,
|
| 460 |
+
filter_black_bg: gr.Checkbox,
|
| 461 |
+
filter_white_bg: gr.Checkbox,
|
| 462 |
+
process_res_method_dropdown: gr.Dropdown,
|
| 463 |
+
target_dir_output: gr.Textbox,
|
| 464 |
+
is_example: gr.Textbox,
|
| 465 |
+
reconstruction_output: gr.Model3D,
|
| 466 |
+
log_output: gr.Markdown,
|
| 467 |
+
) -> None:
|
| 468 |
+
"""Set up visualization update handlers."""
|
| 469 |
+
# Common inputs for visualization updates
|
| 470 |
+
viz_inputs = [
|
| 471 |
+
target_dir_output,
|
| 472 |
+
show_cam,
|
| 473 |
+
is_example,
|
| 474 |
+
filter_black_bg,
|
| 475 |
+
filter_white_bg,
|
| 476 |
+
process_res_method_dropdown,
|
| 477 |
+
]
|
| 478 |
+
|
| 479 |
+
# Set up change handlers for all visualization controls
|
| 480 |
+
for component in [show_cam, filter_black_bg, filter_white_bg]:
|
| 481 |
+
component.change(
|
| 482 |
+
fn=self.event_handlers.update_visualization,
|
| 483 |
+
inputs=viz_inputs,
|
| 484 |
+
outputs=[reconstruction_output, log_output],
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
def _setup_navigation_handlers(
|
| 488 |
+
self,
|
| 489 |
+
prev_measure_btn: gr.Button,
|
| 490 |
+
next_measure_btn: gr.Button,
|
| 491 |
+
measure_view_selector: gr.Dropdown,
|
| 492 |
+
measure_image: gr.Image,
|
| 493 |
+
measure_depth_image: gr.Image,
|
| 494 |
+
measure_points_state: gr.State,
|
| 495 |
+
processed_data_state: gr.State,
|
| 496 |
+
) -> None:
|
| 497 |
+
"""Set up navigation handlers for measure tab."""
|
| 498 |
+
# Measure tab navigation
|
| 499 |
+
prev_measure_btn.click(
|
| 500 |
+
fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
|
| 501 |
+
processed_data, current_selector, -1
|
| 502 |
+
),
|
| 503 |
+
inputs=[processed_data_state, measure_view_selector],
|
| 504 |
+
outputs=[
|
| 505 |
+
measure_view_selector,
|
| 506 |
+
measure_image,
|
| 507 |
+
measure_depth_image,
|
| 508 |
+
measure_points_state,
|
| 509 |
+
],
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
next_measure_btn.click(
|
| 513 |
+
fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
|
| 514 |
+
processed_data, current_selector, 1
|
| 515 |
+
),
|
| 516 |
+
inputs=[processed_data_state, measure_view_selector],
|
| 517 |
+
outputs=[
|
| 518 |
+
measure_view_selector,
|
| 519 |
+
measure_image,
|
| 520 |
+
measure_depth_image,
|
| 521 |
+
measure_points_state,
|
| 522 |
+
],
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
measure_view_selector.change(
|
| 526 |
+
fn=lambda processed_data, selector_value: (
|
| 527 |
+
self.event_handlers.update_measure_view(
|
| 528 |
+
processed_data, int(selector_value.split()[1]) - 1
|
| 529 |
+
)
|
| 530 |
+
if selector_value
|
| 531 |
+
else (None, None, [])
|
| 532 |
+
),
|
| 533 |
+
inputs=[processed_data_state, measure_view_selector],
|
| 534 |
+
outputs=[measure_image, measure_depth_image, measure_points_state],
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
def _setup_example_scene_handlers(
|
| 538 |
+
self,
|
| 539 |
+
scenes: List[Dict[str, Any]],
|
| 540 |
+
scene_components: List[gr.Image],
|
| 541 |
+
reconstruction_output: gr.Model3D,
|
| 542 |
+
target_dir_output: gr.Textbox,
|
| 543 |
+
image_gallery: gr.Gallery,
|
| 544 |
+
log_output: gr.Markdown,
|
| 545 |
+
is_example: gr.Textbox,
|
| 546 |
+
processed_data_state: gr.State,
|
| 547 |
+
measure_view_selector: gr.Dropdown,
|
| 548 |
+
measure_image: gr.Image,
|
| 549 |
+
measure_depth_image: gr.Image,
|
| 550 |
+
gs_video: gr.Video,
|
| 551 |
+
gs_info: gr.Markdown,
|
| 552 |
+
) -> None:
|
| 553 |
+
"""Set up example scene handlers."""
|
| 554 |
+
|
| 555 |
+
def load_and_update_measure(name):
|
| 556 |
+
result = self.event_handlers.load_example_scene(name)
|
| 557 |
+
# result = (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501
|
| 558 |
+
|
| 559 |
+
# Update measure view if processed_data is available
|
| 560 |
+
measure_img = None
|
| 561 |
+
measure_depth = None
|
| 562 |
+
if result[4] is not None: # processed_data exists
|
| 563 |
+
measure_img, measure_depth, _ = (
|
| 564 |
+
self.event_handlers.visualization_handler.update_measure_view(result[4], 0)
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
return result + ("True", measure_img, measure_depth)
|
| 568 |
+
|
| 569 |
+
for i, scene in enumerate(scenes):
|
| 570 |
+
if i < len(scene_components):
|
| 571 |
+
scene_components[i].select(
|
| 572 |
+
fn=lambda name=scene["name"]: load_and_update_measure(name),
|
| 573 |
+
outputs=[
|
| 574 |
+
reconstruction_output,
|
| 575 |
+
target_dir_output,
|
| 576 |
+
image_gallery,
|
| 577 |
+
log_output,
|
| 578 |
+
processed_data_state,
|
| 579 |
+
measure_view_selector,
|
| 580 |
+
gs_video,
|
| 581 |
+
gs_video, # gs_video_visibility
|
| 582 |
+
gs_info, # gs_info_visibility
|
| 583 |
+
is_example,
|
| 584 |
+
measure_image,
|
| 585 |
+
measure_depth_image,
|
| 586 |
+
],
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
def launch(self, host: str = "127.0.0.1", port: int = 7860, **kwargs) -> None:
|
| 590 |
+
"""
|
| 591 |
+
Launch the application.
|
| 592 |
+
|
| 593 |
+
Args:
|
| 594 |
+
host: Host address to bind to
|
| 595 |
+
port: Port number to bind to
|
| 596 |
+
**kwargs: Additional arguments for demo.launch()
|
| 597 |
+
"""
|
| 598 |
+
demo = self.create_app()
|
| 599 |
+
demo.queue(max_size=20).launch(
|
| 600 |
+
show_error=True, ssr_mode=False, server_name=host, server_port=port, **kwargs
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def main():
|
| 605 |
+
"""Main function to run the application."""
|
| 606 |
+
parser = argparse.ArgumentParser(
|
| 607 |
+
description="Depth Anything 3 Gradio Application",
|
| 608 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 609 |
+
epilog="""
|
| 610 |
+
Examples:
|
| 611 |
+
# Basic usage
|
| 612 |
+
python gradio_app.py --help
|
| 613 |
+
python gradio_app.py --host 0.0.0.0 --port 8080
|
| 614 |
+
python gradio_app.py --model-dir /path/to/model --workspace-dir /path/to/workspace
|
| 615 |
+
|
| 616 |
+
# Cache examples at startup (all low-res)
|
| 617 |
+
python gradio_app.py --cache-examples
|
| 618 |
+
|
| 619 |
+
# Cache with selective high-res+3DGS for scenes matching tag
|
| 620 |
+
python gradio_app.py --cache-examples --cache-gs-tag dl3dv
|
| 621 |
+
# This will use high-res + 3DGS for scenes containing "dl3dv" in their name,
|
| 622 |
+
# and low-res only for other scenes
|
| 623 |
+
""",
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
# Server configuration
|
| 627 |
+
parser.add_argument(
|
| 628 |
+
"--host", default="127.0.0.1", help="Host address to bind to (default: 127.0.0.1)"
|
| 629 |
+
)
|
| 630 |
+
parser.add_argument(
|
| 631 |
+
"--port", type=int, default=7860, help="Port number to bind to (default: 7860)"
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
# Directory configuration
|
| 635 |
+
parser.add_argument(
|
| 636 |
+
"--model-dir",
|
| 637 |
+
default="depth-anything/DA3NESTED-GIANT-LARGE",
|
| 638 |
+
help="Path to the model directory (default: depth-anything/DA3NESTED-GIANT-LARGE)",
|
| 639 |
+
)
|
| 640 |
+
parser.add_argument(
|
| 641 |
+
"--workspace-dir",
|
| 642 |
+
default="workspace/gradio", # noqa: E501
|
| 643 |
+
help="Path to the workspace directory (default: workspace/gradio)", # noqa: E501
|
| 644 |
+
)
|
| 645 |
+
parser.add_argument(
|
| 646 |
+
"--gallery-dir",
|
| 647 |
+
default="workspace/gallery",
|
| 648 |
+
help="Path to the gallery directory (default: workspace/gallery)", # noqa: E501
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
# Additional Gradio options
|
| 652 |
+
parser.add_argument("--share", action="store_true", help="Create a public link for the app")
|
| 653 |
+
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
| 654 |
+
|
| 655 |
+
# Example caching options
|
| 656 |
+
parser.add_argument(
|
| 657 |
+
"--cache-examples",
|
| 658 |
+
action="store_true",
|
| 659 |
+
help="Pre-cache all example scenes at startup for faster loading",
|
| 660 |
+
)
|
| 661 |
+
parser.add_argument(
|
| 662 |
+
"--cache-gs-tag",
|
| 663 |
+
type=str,
|
| 664 |
+
default="",
|
| 665 |
+
help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.", # noqa: E501
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
args = parser.parse_args()
|
| 669 |
+
|
| 670 |
+
# Create directories if they don't exist
|
| 671 |
+
os.makedirs(args.workspace_dir, exist_ok=True)
|
| 672 |
+
os.makedirs(args.gallery_dir, exist_ok=True)
|
| 673 |
+
|
| 674 |
+
# Initialize and launch the application
|
| 675 |
+
app = DepthAnything3App(
|
| 676 |
+
model_dir=args.model_dir, workspace_dir=args.workspace_dir, gallery_dir=args.gallery_dir
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
# Prepare launch arguments
|
| 680 |
+
launch_kwargs = {"share": args.share, "debug": args.debug}
|
| 681 |
+
|
| 682 |
+
print("Starting Depth Anything 3 Gradio App...")
|
| 683 |
+
print(f"Host: {args.host}")
|
| 684 |
+
print(f"Port: {args.port}")
|
| 685 |
+
print(f"Model Directory: {args.model_dir}")
|
| 686 |
+
print(f"Workspace Directory: {args.workspace_dir}")
|
| 687 |
+
print(f"Gallery Directory: {args.gallery_dir}")
|
| 688 |
+
print(f"Share: {args.share}")
|
| 689 |
+
print(f"Debug: {args.debug}")
|
| 690 |
+
print(f"Cache Examples: {args.cache_examples}")
|
| 691 |
+
if args.cache_examples:
|
| 692 |
+
if args.cache_gs_tag:
|
| 693 |
+
print(
|
| 694 |
+
f"Cache GS Tag: '{args.cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)" # noqa: E501
|
| 695 |
+
) # noqa: E501
|
| 696 |
+
else:
|
| 697 |
+
print("Cache GS Tag: None (all scenes will use low-res only)")
|
| 698 |
+
|
| 699 |
+
# Pre-cache examples if requested
|
| 700 |
+
if args.cache_examples:
|
| 701 |
+
print("\n" + "=" * 60)
|
| 702 |
+
print("Pre-caching mode enabled")
|
| 703 |
+
if args.cache_gs_tag:
|
| 704 |
+
print(f"Scenes containing '{args.cache_gs_tag}' will use HIGH-RES + 3DGS")
|
| 705 |
+
print("Other scenes will use LOW-RES only")
|
| 706 |
+
else:
|
| 707 |
+
print("All scenes will use LOW-RES only")
|
| 708 |
+
print("=" * 60)
|
| 709 |
+
app.cache_examples(
|
| 710 |
+
show_cam=True,
|
| 711 |
+
filter_black_bg=False,
|
| 712 |
+
filter_white_bg=False,
|
| 713 |
+
save_percentage=5.0,
|
| 714 |
+
num_max_points=1000,
|
| 715 |
+
cache_gs_tag=args.cache_gs_tag,
|
| 716 |
+
gs_trj_mode="smooth",
|
| 717 |
+
gs_video_quality="low",
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
app.launch(host=args.host, port=args.port, **launch_kwargs)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
if __name__ == "__main__":
|
| 724 |
+
main()
|
Depth-Anything-3/src/depth_anything_3/app/modules/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Modules package for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This package contains all the modular components for the Gradio application.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from depth_anything_3.app.modules.event_handlers import EventHandlers
|
| 22 |
+
from depth_anything_3.app.modules.file_handlers import FileHandler
|
| 23 |
+
from depth_anything_3.app.modules.model_inference import ModelInference
|
| 24 |
+
from depth_anything_3.app.modules.ui_components import UIComponents
|
| 25 |
+
from depth_anything_3.app.modules.utils import (
|
| 26 |
+
create_depth_visualization,
|
| 27 |
+
get_logo_base64,
|
| 28 |
+
get_scene_info,
|
| 29 |
+
save_to_gallery_func,
|
| 30 |
+
)
|
| 31 |
+
from depth_anything_3.app.modules.visualization import VisualizationHandler
|
| 32 |
+
|
| 33 |
+
__all__ = [
|
| 34 |
+
"ModelInference",
|
| 35 |
+
"FileHandler",
|
| 36 |
+
"VisualizationHandler",
|
| 37 |
+
"EventHandlers",
|
| 38 |
+
"UIComponents",
|
| 39 |
+
"create_depth_visualization",
|
| 40 |
+
"save_to_gallery_func",
|
| 41 |
+
"get_scene_info",
|
| 42 |
+
"get_logo_base64",
|
| 43 |
+
]
|
Depth-Anything-3/src/depth_anything_3/app/modules/event_handlers.py
ADDED
|
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Event handling module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module handles all event callbacks and user interactions.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import time
|
| 23 |
+
from glob import glob
|
| 24 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 25 |
+
import gradio as gr
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from depth_anything_3.app.modules.file_handlers import FileHandler
|
| 30 |
+
from depth_anything_3.app.modules.model_inference import ModelInference
|
| 31 |
+
from depth_anything_3.utils.memory import cleanup_cuda_memory
|
| 32 |
+
from depth_anything_3.app.modules.visualization import VisualizationHandler
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class EventHandlers:
|
| 36 |
+
"""
|
| 37 |
+
Handles all event callbacks and user interactions for the Gradio app.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self):
|
| 41 |
+
"""Initialize the event handlers."""
|
| 42 |
+
self.model_inference = ModelInference()
|
| 43 |
+
self.file_handler = FileHandler()
|
| 44 |
+
self.visualization_handler = VisualizationHandler()
|
| 45 |
+
|
| 46 |
+
def clear_fields(self) -> None:
|
| 47 |
+
"""
|
| 48 |
+
Clears the 3D viewer, the stored target_dir, and empties the gallery.
|
| 49 |
+
"""
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
def update_log(self) -> str:
|
| 53 |
+
"""
|
| 54 |
+
Display a quick log message while waiting.
|
| 55 |
+
"""
|
| 56 |
+
return "Loading and Reconstructing..."
|
| 57 |
+
|
| 58 |
+
def save_current_visualization(
|
| 59 |
+
self,
|
| 60 |
+
target_dir: str,
|
| 61 |
+
save_percentage: float,
|
| 62 |
+
show_cam: bool,
|
| 63 |
+
filter_black_bg: bool,
|
| 64 |
+
filter_white_bg: bool,
|
| 65 |
+
processed_data: Optional[Dict],
|
| 66 |
+
scene_name: str = "",
|
| 67 |
+
) -> str:
|
| 68 |
+
"""
|
| 69 |
+
Save current visualization results to gallery with specified save percentage.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
target_dir: Directory containing results
|
| 73 |
+
save_percentage: Percentage of points to save (0-100)
|
| 74 |
+
show_cam: Whether to show cameras
|
| 75 |
+
filter_black_bg: Whether to filter black background
|
| 76 |
+
filter_white_bg: Whether to filter white background
|
| 77 |
+
processed_data: Processed data from reconstruction
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Status message
|
| 81 |
+
"""
|
| 82 |
+
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
|
| 83 |
+
return "No reconstruction available. Please run 'Reconstruct' first."
|
| 84 |
+
|
| 85 |
+
if processed_data is None:
|
| 86 |
+
return "No processed data available. Please run 'Reconstruct' first."
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
# Add debug information
|
| 90 |
+
print("[DEBUG] save_current_visualization called with:")
|
| 91 |
+
print(f" target_dir: {target_dir}")
|
| 92 |
+
print(f" save_percentage: {save_percentage}")
|
| 93 |
+
print(f" show_cam: {show_cam}")
|
| 94 |
+
print(f" filter_black_bg: {filter_black_bg}")
|
| 95 |
+
print(f" filter_white_bg: {filter_white_bg}")
|
| 96 |
+
print(f" processed_data: {processed_data is not None}")
|
| 97 |
+
|
| 98 |
+
# Import the gallery save function
|
| 99 |
+
# Create gallery name with user input or auto-generated
|
| 100 |
+
import datetime
|
| 101 |
+
|
| 102 |
+
from .utils import save_to_gallery_func
|
| 103 |
+
|
| 104 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 105 |
+
if scene_name and scene_name.strip():
|
| 106 |
+
gallery_name = f"{scene_name.strip()}_{timestamp}_pct{save_percentage:.0f}"
|
| 107 |
+
else:
|
| 108 |
+
gallery_name = f"save_{timestamp}_pct{save_percentage:.0f}"
|
| 109 |
+
|
| 110 |
+
print(f"[DEBUG] Saving to gallery with name: {gallery_name}")
|
| 111 |
+
|
| 112 |
+
# Save entire process folder to gallery
|
| 113 |
+
success, message = save_to_gallery_func(
|
| 114 |
+
target_dir=target_dir, processed_data=processed_data, gallery_name=gallery_name
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if success:
|
| 118 |
+
print(f"[DEBUG] Gallery save completed successfully: {message}")
|
| 119 |
+
return (
|
| 120 |
+
"Successfully saved to gallery!\n"
|
| 121 |
+
f"Gallery name: {gallery_name}\n"
|
| 122 |
+
f"Save percentage: {save_percentage}%\n"
|
| 123 |
+
f"Show cameras: {show_cam}\n"
|
| 124 |
+
f"Filter black bg: {filter_black_bg}\n"
|
| 125 |
+
f"Filter white bg: {filter_white_bg}\n\n"
|
| 126 |
+
f"{message}"
|
| 127 |
+
)
|
| 128 |
+
else:
|
| 129 |
+
print(f"[DEBUG] Gallery save failed: {message}")
|
| 130 |
+
return f"Failed to save to gallery: {message}"
|
| 131 |
+
|
| 132 |
+
except Exception as e:
|
| 133 |
+
return f"Error saving visualization: {str(e)}"
|
| 134 |
+
|
| 135 |
+
def gradio_demo(
|
| 136 |
+
self,
|
| 137 |
+
target_dir: str,
|
| 138 |
+
show_cam: bool = True,
|
| 139 |
+
filter_black_bg: bool = False,
|
| 140 |
+
filter_white_bg: bool = False,
|
| 141 |
+
process_res_method: str = "upper_bound_resize",
|
| 142 |
+
save_percentage: float = 30.0,
|
| 143 |
+
num_max_points: int = 1_000_000,
|
| 144 |
+
infer_gs: bool = False,
|
| 145 |
+
ref_view_strategy: str = "saddle_balanced",
|
| 146 |
+
gs_trj_mode: str = "extend",
|
| 147 |
+
gs_video_quality: str = "high",
|
| 148 |
+
) -> Tuple[
|
| 149 |
+
Optional[str],
|
| 150 |
+
str,
|
| 151 |
+
Optional[Dict],
|
| 152 |
+
Optional[np.ndarray],
|
| 153 |
+
Optional[np.ndarray],
|
| 154 |
+
str,
|
| 155 |
+
gr.Dropdown,
|
| 156 |
+
Optional[str], # gs video path
|
| 157 |
+
gr.update, # gs video visibility update
|
| 158 |
+
gr.update, # gs info visibility update
|
| 159 |
+
]:
|
| 160 |
+
"""
|
| 161 |
+
Perform reconstruction using the already-created target_dir/images.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
target_dir: Directory containing images
|
| 165 |
+
show_cam: Whether to show camera
|
| 166 |
+
filter_black_bg: Whether to filter black background
|
| 167 |
+
filter_white_bg: Whether to filter white background
|
| 168 |
+
process_res_method: Method for resizing input images
|
| 169 |
+
save_percentage: Filter percentage for point cloud
|
| 170 |
+
num_max_points: Maximum number of points
|
| 171 |
+
infer_gs: Whether to infer 3D Gaussian Splatting
|
| 172 |
+
ref_view_strategy: Reference view selection strategy
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
Tuple of reconstruction results
|
| 176 |
+
"""
|
| 177 |
+
if not os.path.isdir(target_dir) or target_dir == "None":
|
| 178 |
+
return (
|
| 179 |
+
None,
|
| 180 |
+
"No valid target directory found. Please upload first.",
|
| 181 |
+
None,
|
| 182 |
+
None,
|
| 183 |
+
None,
|
| 184 |
+
"",
|
| 185 |
+
None,
|
| 186 |
+
None,
|
| 187 |
+
gr.update(visible=False), # gs_video
|
| 188 |
+
gr.update(visible=True), # gs_info
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
start_time = time.time()
|
| 192 |
+
cleanup_cuda_memory()
|
| 193 |
+
|
| 194 |
+
# Get image files for logging
|
| 195 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 196 |
+
all_files = (
|
| 197 |
+
sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
print("Running DepthAnything3 model...")
|
| 201 |
+
print(f"Reference view strategy: {ref_view_strategy}")
|
| 202 |
+
|
| 203 |
+
with torch.no_grad():
|
| 204 |
+
prediction, processed_data = self.model_inference.run_inference(
|
| 205 |
+
target_dir,
|
| 206 |
+
process_res_method=process_res_method,
|
| 207 |
+
show_camera=show_cam,
|
| 208 |
+
save_percentage=save_percentage,
|
| 209 |
+
num_max_points=int(num_max_points * 1000), # Convert K to actual count
|
| 210 |
+
infer_gs=infer_gs,
|
| 211 |
+
ref_view_strategy=ref_view_strategy,
|
| 212 |
+
gs_trj_mode=gs_trj_mode,
|
| 213 |
+
gs_video_quality=gs_video_quality,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# The GLB file is already generated by the API
|
| 217 |
+
glbfile = os.path.join(target_dir, "scene.glb")
|
| 218 |
+
|
| 219 |
+
# Handle 3DGS video based on infer_gs flag
|
| 220 |
+
gsvideo_path = None
|
| 221 |
+
gs_video_visible = False
|
| 222 |
+
gs_info_visible = True
|
| 223 |
+
|
| 224 |
+
if infer_gs:
|
| 225 |
+
try:
|
| 226 |
+
gsvideo_path = sorted(glob(os.path.join(target_dir, "gs_video", "*.mp4")))[-1]
|
| 227 |
+
gs_video_visible = True
|
| 228 |
+
gs_info_visible = False
|
| 229 |
+
except IndexError:
|
| 230 |
+
gsvideo_path = None
|
| 231 |
+
print("3DGS video not found, but infer_gs was enabled")
|
| 232 |
+
|
| 233 |
+
# Cleanup
|
| 234 |
+
cleanup_cuda_memory()
|
| 235 |
+
|
| 236 |
+
end_time = time.time()
|
| 237 |
+
print(f"Total time: {end_time - start_time:.2f} seconds")
|
| 238 |
+
log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
|
| 239 |
+
|
| 240 |
+
# Populate visualization tabs with processed data
|
| 241 |
+
depth_vis, measure_img, measure_depth_vis, measure_pts = (
|
| 242 |
+
self.visualization_handler.populate_visualization_tabs(processed_data)
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# Update view selectors based on available views
|
| 246 |
+
depth_selector, measure_selector = self.visualization_handler.update_view_selectors(
|
| 247 |
+
processed_data
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
return (
|
| 251 |
+
glbfile,
|
| 252 |
+
log_msg,
|
| 253 |
+
processed_data,
|
| 254 |
+
measure_img, # measure_image
|
| 255 |
+
measure_depth_vis, # measure_depth_image
|
| 256 |
+
"", # measure_text (empty initially)
|
| 257 |
+
measure_selector, # measure_view_selector
|
| 258 |
+
gsvideo_path,
|
| 259 |
+
gr.update(visible=gs_video_visible), # gs_video visibility
|
| 260 |
+
gr.update(visible=gs_info_visible), # gs_info visibility
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
def update_visualization(
|
| 264 |
+
self,
|
| 265 |
+
target_dir: str,
|
| 266 |
+
show_cam: bool,
|
| 267 |
+
is_example: str,
|
| 268 |
+
filter_black_bg: bool = False,
|
| 269 |
+
filter_white_bg: bool = False,
|
| 270 |
+
process_res_method: str = "upper_bound_resize",
|
| 271 |
+
) -> Tuple[gr.update, str]:
|
| 272 |
+
"""
|
| 273 |
+
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
|
| 274 |
+
and return it for the 3D viewer.
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
target_dir: Directory containing results
|
| 278 |
+
show_cam: Whether to show camera
|
| 279 |
+
is_example: Whether this is an example scene
|
| 280 |
+
filter_black_bg: Whether to filter black background
|
| 281 |
+
filter_white_bg: Whether to filter white background
|
| 282 |
+
process_res_method: Method for resizing input images
|
| 283 |
+
|
| 284 |
+
Returns:
|
| 285 |
+
Tuple of (glb_file, log_message)
|
| 286 |
+
"""
|
| 287 |
+
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
|
| 288 |
+
return (
|
| 289 |
+
gr.update(),
|
| 290 |
+
"No reconstruction available. Please click the Reconstruct button first.",
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Check if GLB exists (could be cached example or reconstructed scene)
|
| 294 |
+
glbfile = os.path.join(target_dir, "scene.glb")
|
| 295 |
+
if os.path.exists(glbfile):
|
| 296 |
+
return (
|
| 297 |
+
glbfile,
|
| 298 |
+
(
|
| 299 |
+
"Visualization loaded from cache."
|
| 300 |
+
if is_example == "True"
|
| 301 |
+
else "Visualization updated."
|
| 302 |
+
),
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# If no GLB but it's an example that hasn't been reconstructed yet
|
| 306 |
+
if is_example == "True":
|
| 307 |
+
return (
|
| 308 |
+
gr.update(),
|
| 309 |
+
"No reconstruction available. Please click the Reconstruct button first.",
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# For non-examples, check predictions.npz
|
| 313 |
+
predictions_path = os.path.join(target_dir, "predictions.npz")
|
| 314 |
+
if not os.path.exists(predictions_path):
|
| 315 |
+
error_message = (
|
| 316 |
+
f"No reconstruction available at {predictions_path}. "
|
| 317 |
+
"Please run 'Reconstruct' first."
|
| 318 |
+
)
|
| 319 |
+
return gr.update(), error_message
|
| 320 |
+
|
| 321 |
+
loaded = np.load(predictions_path, allow_pickle=True)
|
| 322 |
+
predictions = {key: loaded[key] for key in loaded.keys()} # noqa: F841
|
| 323 |
+
|
| 324 |
+
return (
|
| 325 |
+
glbfile,
|
| 326 |
+
"Visualization updated.",
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
def handle_uploads(
|
| 330 |
+
self,
|
| 331 |
+
input_video: Optional[str],
|
| 332 |
+
input_images: Optional[List],
|
| 333 |
+
s_time_interval: float = 10.0,
|
| 334 |
+
) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
|
| 335 |
+
"""
|
| 336 |
+
Handle file uploads and update gallery.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
input_video: Path to input video file
|
| 340 |
+
input_images: List of input image files
|
| 341 |
+
s_time_interval: Sampling FPS (frames per second) for frame extraction
|
| 342 |
+
|
| 343 |
+
Returns:
|
| 344 |
+
Tuple of (reconstruction_output, target_dir, image_paths, log_message)
|
| 345 |
+
"""
|
| 346 |
+
return self.file_handler.update_gallery_on_upload(
|
| 347 |
+
input_video, input_images, s_time_interval
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
def load_example_scene(self, scene_name: str, examples_dir: str = None) -> Tuple[
|
| 351 |
+
Optional[str],
|
| 352 |
+
Optional[str],
|
| 353 |
+
Optional[List],
|
| 354 |
+
str,
|
| 355 |
+
Optional[Dict],
|
| 356 |
+
gr.Dropdown,
|
| 357 |
+
Optional[str],
|
| 358 |
+
gr.update,
|
| 359 |
+
gr.update,
|
| 360 |
+
]:
|
| 361 |
+
"""
|
| 362 |
+
Load a scene from examples directory.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
scene_name: Name of the scene to load
|
| 366 |
+
examples_dir: Path to examples directory (if None, uses workspace_dir/examples)
|
| 367 |
+
|
| 368 |
+
Returns:
|
| 369 |
+
Tuple of (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501
|
| 370 |
+
"""
|
| 371 |
+
if examples_dir is None:
|
| 372 |
+
# Get workspace directory from environment variable
|
| 373 |
+
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
|
| 374 |
+
examples_dir = os.path.join(workspace_dir, "examples")
|
| 375 |
+
|
| 376 |
+
reconstruction_output, target_dir, image_paths, log_message = (
|
| 377 |
+
self.file_handler.load_example_scene(scene_name, examples_dir)
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# Try to load cached processed data if available
|
| 381 |
+
processed_data = None
|
| 382 |
+
measure_view_selector = gr.Dropdown(choices=["View 1"], value="View 1")
|
| 383 |
+
gs_video_path = None
|
| 384 |
+
gs_video_visible = False
|
| 385 |
+
gs_info_visible = True
|
| 386 |
+
|
| 387 |
+
if target_dir and target_dir != "None":
|
| 388 |
+
predictions_path = os.path.join(target_dir, "predictions.npz")
|
| 389 |
+
if os.path.exists(predictions_path):
|
| 390 |
+
try:
|
| 391 |
+
# Load predictions from cache
|
| 392 |
+
loaded = np.load(predictions_path, allow_pickle=True)
|
| 393 |
+
predictions = {key: loaded[key] for key in loaded.keys()}
|
| 394 |
+
|
| 395 |
+
# Reconstruct processed_data structure
|
| 396 |
+
num_images = len(predictions.get("images", []))
|
| 397 |
+
processed_data = {}
|
| 398 |
+
|
| 399 |
+
for i in range(num_images):
|
| 400 |
+
processed_data[i] = {
|
| 401 |
+
"image": predictions["images"][i] if "images" in predictions else None,
|
| 402 |
+
"depth": predictions["depths"][i] if "depths" in predictions else None,
|
| 403 |
+
"depth_image": os.path.join(
|
| 404 |
+
target_dir, "depth_vis", f"{i:04d}.jpg" # Fixed: use .jpg not .png
|
| 405 |
+
),
|
| 406 |
+
"intrinsics": (
|
| 407 |
+
predictions["intrinsics"][i]
|
| 408 |
+
if "intrinsics" in predictions
|
| 409 |
+
and i < len(predictions["intrinsics"])
|
| 410 |
+
else None
|
| 411 |
+
),
|
| 412 |
+
"mask": None,
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
# Update measure view selector
|
| 416 |
+
choices = [f"View {i + 1}" for i in range(num_images)]
|
| 417 |
+
measure_view_selector = gr.Dropdown(choices=choices, value=choices[0])
|
| 418 |
+
|
| 419 |
+
except Exception as e:
|
| 420 |
+
print(f"Error loading cached data: {e}")
|
| 421 |
+
|
| 422 |
+
# Check for cached 3DGS video
|
| 423 |
+
gs_video_dir = os.path.join(target_dir, "gs_video")
|
| 424 |
+
if os.path.exists(gs_video_dir):
|
| 425 |
+
try:
|
| 426 |
+
from glob import glob
|
| 427 |
+
|
| 428 |
+
gs_videos = sorted(glob(os.path.join(gs_video_dir, "*.mp4")))
|
| 429 |
+
if gs_videos:
|
| 430 |
+
gs_video_path = gs_videos[-1]
|
| 431 |
+
gs_video_visible = True
|
| 432 |
+
gs_info_visible = False
|
| 433 |
+
print(f"Loaded cached 3DGS video: {gs_video_path}")
|
| 434 |
+
except Exception as e:
|
| 435 |
+
print(f"Error loading cached 3DGS video: {e}")
|
| 436 |
+
|
| 437 |
+
return (
|
| 438 |
+
reconstruction_output,
|
| 439 |
+
target_dir,
|
| 440 |
+
image_paths,
|
| 441 |
+
log_message,
|
| 442 |
+
processed_data,
|
| 443 |
+
measure_view_selector,
|
| 444 |
+
gs_video_path,
|
| 445 |
+
gr.update(visible=gs_video_visible),
|
| 446 |
+
gr.update(visible=gs_info_visible),
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
def navigate_depth_view(
|
| 450 |
+
self,
|
| 451 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 452 |
+
current_selector: str,
|
| 453 |
+
direction: int,
|
| 454 |
+
) -> Tuple[str, Optional[str]]:
|
| 455 |
+
"""
|
| 456 |
+
Navigate depth view.
|
| 457 |
+
|
| 458 |
+
Args:
|
| 459 |
+
processed_data: Processed data dictionary
|
| 460 |
+
current_selector: Current selector value
|
| 461 |
+
direction: Direction to navigate
|
| 462 |
+
|
| 463 |
+
Returns:
|
| 464 |
+
Tuple of (new_selector_value, depth_vis)
|
| 465 |
+
"""
|
| 466 |
+
return self.visualization_handler.navigate_depth_view(
|
| 467 |
+
processed_data, current_selector, direction
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
def update_depth_view(
|
| 471 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 472 |
+
) -> Optional[str]:
|
| 473 |
+
"""
|
| 474 |
+
Update depth view for a specific view index.
|
| 475 |
+
|
| 476 |
+
Args:
|
| 477 |
+
processed_data: Processed data dictionary
|
| 478 |
+
view_index: Index of the view to update
|
| 479 |
+
|
| 480 |
+
Returns:
|
| 481 |
+
Path to depth visualization image or None
|
| 482 |
+
"""
|
| 483 |
+
return self.visualization_handler.update_depth_view(processed_data, view_index)
|
| 484 |
+
|
| 485 |
+
def navigate_measure_view(
|
| 486 |
+
self,
|
| 487 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 488 |
+
current_selector: str,
|
| 489 |
+
direction: int,
|
| 490 |
+
) -> Tuple[str, Optional[np.ndarray], Optional[np.ndarray], List]:
|
| 491 |
+
"""
|
| 492 |
+
Navigate measure view.
|
| 493 |
+
|
| 494 |
+
Args:
|
| 495 |
+
processed_data: Processed data dictionary
|
| 496 |
+
current_selector: Current selector value
|
| 497 |
+
direction: Direction to navigate
|
| 498 |
+
|
| 499 |
+
Returns:
|
| 500 |
+
Tuple of (new_selector_value, measure_image, depth_right_half, measure_points)
|
| 501 |
+
"""
|
| 502 |
+
return self.visualization_handler.navigate_measure_view(
|
| 503 |
+
processed_data, current_selector, direction
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
def update_measure_view(
|
| 507 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 508 |
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
|
| 509 |
+
"""
|
| 510 |
+
Update measure view for a specific view index.
|
| 511 |
+
|
| 512 |
+
Args:
|
| 513 |
+
processed_data: Processed data dictionary
|
| 514 |
+
view_index: Index of the view to update
|
| 515 |
+
|
| 516 |
+
Returns:
|
| 517 |
+
Tuple of (measure_image, depth_right_half, measure_points)
|
| 518 |
+
"""
|
| 519 |
+
return self.visualization_handler.update_measure_view(processed_data, view_index)
|
| 520 |
+
|
| 521 |
+
def measure(
|
| 522 |
+
self,
|
| 523 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 524 |
+
measure_points: List,
|
| 525 |
+
current_view_selector: str,
|
| 526 |
+
event: gr.SelectData,
|
| 527 |
+
) -> List:
|
| 528 |
+
"""
|
| 529 |
+
Handle measurement on images.
|
| 530 |
+
|
| 531 |
+
Args:
|
| 532 |
+
processed_data: Processed data dictionary
|
| 533 |
+
measure_points: List of current measure points
|
| 534 |
+
current_view_selector: Current view selector value
|
| 535 |
+
event: Gradio select event
|
| 536 |
+
|
| 537 |
+
Returns:
|
| 538 |
+
List of [image, depth_right_half, measure_points, text]
|
| 539 |
+
"""
|
| 540 |
+
return self.visualization_handler.measure(
|
| 541 |
+
processed_data, measure_points, current_view_selector, event
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
def select_first_frame(
|
| 545 |
+
self, image_gallery: List, selected_index: int = 0
|
| 546 |
+
) -> Tuple[List, str, str]:
|
| 547 |
+
"""
|
| 548 |
+
Select the first frame from the image gallery.
|
| 549 |
+
|
| 550 |
+
Args:
|
| 551 |
+
image_gallery: List of images in the gallery
|
| 552 |
+
selected_index: Index of the selected image (default: 0)
|
| 553 |
+
|
| 554 |
+
Returns:
|
| 555 |
+
Tuple of (updated_image_gallery, log_message, selected_frame_path)
|
| 556 |
+
"""
|
| 557 |
+
try:
|
| 558 |
+
if not image_gallery or len(image_gallery) == 0:
|
| 559 |
+
return image_gallery, "No images available to select as first frame.", ""
|
| 560 |
+
|
| 561 |
+
# Handle None or invalid selected_index
|
| 562 |
+
if (
|
| 563 |
+
selected_index is None
|
| 564 |
+
or selected_index < 0
|
| 565 |
+
or selected_index >= len(image_gallery)
|
| 566 |
+
):
|
| 567 |
+
selected_index = 0
|
| 568 |
+
print(f"Invalid selected_index: {selected_index}, using default: 0")
|
| 569 |
+
|
| 570 |
+
# Get the selected image based on index
|
| 571 |
+
selected_image = image_gallery[selected_index]
|
| 572 |
+
print(f"Selected image index: {selected_index}")
|
| 573 |
+
print(f"Total images: {len(image_gallery)}")
|
| 574 |
+
|
| 575 |
+
# Extract the file path from the selected image
|
| 576 |
+
selected_frame_path = ""
|
| 577 |
+
print(f"Selected image type: {type(selected_image)}")
|
| 578 |
+
print(f"Selected image: {selected_image}")
|
| 579 |
+
|
| 580 |
+
if isinstance(selected_image, tuple):
|
| 581 |
+
# Gradio Gallery returns tuple (path, None)
|
| 582 |
+
selected_frame_path = selected_image[0]
|
| 583 |
+
elif isinstance(selected_image, str):
|
| 584 |
+
selected_frame_path = selected_image
|
| 585 |
+
elif hasattr(selected_image, "name"):
|
| 586 |
+
selected_frame_path = selected_image.name
|
| 587 |
+
elif isinstance(selected_image, dict):
|
| 588 |
+
if "name" in selected_image:
|
| 589 |
+
selected_frame_path = selected_image["name"]
|
| 590 |
+
elif "path" in selected_image:
|
| 591 |
+
selected_frame_path = selected_image["path"]
|
| 592 |
+
elif "src" in selected_image:
|
| 593 |
+
selected_frame_path = selected_image["src"]
|
| 594 |
+
else:
|
| 595 |
+
# Try to convert to string
|
| 596 |
+
selected_frame_path = str(selected_image)
|
| 597 |
+
|
| 598 |
+
print(f"Extracted path: {selected_frame_path}")
|
| 599 |
+
|
| 600 |
+
# Extract filename from the path for matching
|
| 601 |
+
import os
|
| 602 |
+
|
| 603 |
+
selected_filename = os.path.basename(selected_frame_path)
|
| 604 |
+
print(f"Selected filename: {selected_filename}")
|
| 605 |
+
|
| 606 |
+
# Move the selected image to the front
|
| 607 |
+
updated_gallery = [selected_image] + [
|
| 608 |
+
img for img in image_gallery if img != selected_image
|
| 609 |
+
]
|
| 610 |
+
|
| 611 |
+
log_message = (
|
| 612 |
+
f"Selected frame: {selected_filename}. "
|
| 613 |
+
f"Moved to first position. Total frames: {len(updated_gallery)}"
|
| 614 |
+
)
|
| 615 |
+
return updated_gallery, log_message, selected_filename
|
| 616 |
+
|
| 617 |
+
except Exception as e:
|
| 618 |
+
print(f"Error selecting first frame: {e}")
|
| 619 |
+
return image_gallery, f"Error selecting first frame: {e}", ""
|
Depth-Anything-3/src/depth_anything_3/app/modules/file_handlers.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
File handling module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module handles file uploads, video processing, and file operations.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import shutil
|
| 23 |
+
import time
|
| 24 |
+
from datetime import datetime
|
| 25 |
+
from typing import List, Optional, Tuple
|
| 26 |
+
import cv2
|
| 27 |
+
from PIL import Image
|
| 28 |
+
from pillow_heif import register_heif_opener
|
| 29 |
+
|
| 30 |
+
register_heif_opener()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class FileHandler:
|
| 34 |
+
"""
|
| 35 |
+
Handles file uploads and processing for the Gradio app.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self):
|
| 39 |
+
"""Initialize the file handler."""
|
| 40 |
+
|
| 41 |
+
def handle_uploads(
|
| 42 |
+
self,
|
| 43 |
+
input_video: Optional[str],
|
| 44 |
+
input_images: Optional[List],
|
| 45 |
+
s_time_interval: float = 10.0,
|
| 46 |
+
) -> Tuple[str, List[str]]:
|
| 47 |
+
"""
|
| 48 |
+
Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
|
| 49 |
+
images or extracted frames from video into it.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
input_video: Path to input video file
|
| 53 |
+
input_images: List of input image files
|
| 54 |
+
s_time_interval: Sampling FPS (frames per second) for frame extraction
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
Tuple of (target_dir, image_paths)
|
| 58 |
+
"""
|
| 59 |
+
start_time = time.time()
|
| 60 |
+
|
| 61 |
+
# Get workspace directory from environment variable or use default
|
| 62 |
+
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
|
| 63 |
+
if not os.path.exists(workspace_dir):
|
| 64 |
+
os.makedirs(workspace_dir)
|
| 65 |
+
|
| 66 |
+
# Create input_images subdirectory
|
| 67 |
+
input_images_dir = os.path.join(workspace_dir, "input_images")
|
| 68 |
+
if not os.path.exists(input_images_dir):
|
| 69 |
+
os.makedirs(input_images_dir)
|
| 70 |
+
|
| 71 |
+
# Create a unique folder name within input_images
|
| 72 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 73 |
+
target_dir = os.path.join(input_images_dir, f"session_{timestamp}")
|
| 74 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 75 |
+
|
| 76 |
+
# Clean up if somehow that folder already exists
|
| 77 |
+
if os.path.exists(target_dir):
|
| 78 |
+
shutil.rmtree(target_dir)
|
| 79 |
+
os.makedirs(target_dir)
|
| 80 |
+
os.makedirs(target_dir_images)
|
| 81 |
+
|
| 82 |
+
image_paths = []
|
| 83 |
+
|
| 84 |
+
# Handle images
|
| 85 |
+
if input_images is not None:
|
| 86 |
+
image_paths.extend(self._process_images(input_images, target_dir_images))
|
| 87 |
+
|
| 88 |
+
# Handle video
|
| 89 |
+
if input_video is not None:
|
| 90 |
+
image_paths.extend(
|
| 91 |
+
self._process_video(input_video, target_dir_images, s_time_interval)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Sort final images for gallery
|
| 95 |
+
image_paths = sorted(image_paths)
|
| 96 |
+
|
| 97 |
+
end_time = time.time()
|
| 98 |
+
print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
|
| 99 |
+
return target_dir, image_paths
|
| 100 |
+
|
| 101 |
+
def _process_images(self, input_images: List, target_dir_images: str) -> List[str]:
|
| 102 |
+
"""
|
| 103 |
+
Process uploaded images.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
input_images: List of input image files
|
| 107 |
+
target_dir_images: Target directory for images
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
List of processed image paths
|
| 111 |
+
"""
|
| 112 |
+
image_paths = []
|
| 113 |
+
|
| 114 |
+
for file_data in input_images:
|
| 115 |
+
if isinstance(file_data, dict) and "name" in file_data:
|
| 116 |
+
file_path = file_data["name"]
|
| 117 |
+
else:
|
| 118 |
+
file_path = file_data
|
| 119 |
+
|
| 120 |
+
# Check if the file is a HEIC image
|
| 121 |
+
file_ext = os.path.splitext(file_path)[1].lower()
|
| 122 |
+
if file_ext in [".heic", ".heif"]:
|
| 123 |
+
# Convert HEIC to JPEG for better gallery compatibility
|
| 124 |
+
try:
|
| 125 |
+
with Image.open(file_path) as img:
|
| 126 |
+
# Convert to RGB if necessary (HEIC can have different color modes)
|
| 127 |
+
if img.mode not in ("RGB", "L"):
|
| 128 |
+
img = img.convert("RGB")
|
| 129 |
+
|
| 130 |
+
# Create JPEG filename
|
| 131 |
+
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
| 132 |
+
dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
|
| 133 |
+
|
| 134 |
+
# Save as JPEG with high quality
|
| 135 |
+
img.save(dst_path, "JPEG", quality=95)
|
| 136 |
+
image_paths.append(dst_path)
|
| 137 |
+
print(
|
| 138 |
+
f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> "
|
| 139 |
+
f"{os.path.basename(dst_path)}"
|
| 140 |
+
)
|
| 141 |
+
except Exception as e:
|
| 142 |
+
print(f"Error converting HEIC file {file_path}: {e}")
|
| 143 |
+
# Fall back to copying as is
|
| 144 |
+
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
|
| 145 |
+
shutil.copy(file_path, dst_path)
|
| 146 |
+
image_paths.append(dst_path)
|
| 147 |
+
else:
|
| 148 |
+
# Regular image files - copy as is
|
| 149 |
+
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
|
| 150 |
+
shutil.copy(file_path, dst_path)
|
| 151 |
+
image_paths.append(dst_path)
|
| 152 |
+
|
| 153 |
+
return image_paths
|
| 154 |
+
|
| 155 |
+
def _process_video(
|
| 156 |
+
self, input_video: str, target_dir_images: str, s_time_interval: float
|
| 157 |
+
) -> List[str]:
|
| 158 |
+
"""
|
| 159 |
+
Process video file and extract frames.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
input_video: Path to input video file
|
| 163 |
+
target_dir_images: Target directory for extracted frames
|
| 164 |
+
s_time_interval: Sampling FPS (frames per second) for frame extraction
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
List of extracted frame paths
|
| 168 |
+
"""
|
| 169 |
+
image_paths = []
|
| 170 |
+
|
| 171 |
+
if isinstance(input_video, dict) and "name" in input_video:
|
| 172 |
+
video_path = input_video["name"]
|
| 173 |
+
else:
|
| 174 |
+
video_path = input_video
|
| 175 |
+
|
| 176 |
+
vs = cv2.VideoCapture(video_path)
|
| 177 |
+
fps = vs.get(cv2.CAP_PROP_FPS)
|
| 178 |
+
frame_interval = max(1, int(fps / s_time_interval)) # Convert FPS to frame interval
|
| 179 |
+
|
| 180 |
+
count = 0
|
| 181 |
+
video_frame_num = 0
|
| 182 |
+
while True:
|
| 183 |
+
gotit, frame = vs.read()
|
| 184 |
+
if not gotit:
|
| 185 |
+
break
|
| 186 |
+
count += 1
|
| 187 |
+
if count % frame_interval == 0:
|
| 188 |
+
image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
|
| 189 |
+
cv2.imwrite(image_path, frame)
|
| 190 |
+
image_paths.append(image_path)
|
| 191 |
+
video_frame_num += 1
|
| 192 |
+
|
| 193 |
+
return image_paths
|
| 194 |
+
|
| 195 |
+
def update_gallery_on_upload(
|
| 196 |
+
self,
|
| 197 |
+
input_video: Optional[str],
|
| 198 |
+
input_images: Optional[List],
|
| 199 |
+
s_time_interval: float = 10.0,
|
| 200 |
+
) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
|
| 201 |
+
"""
|
| 202 |
+
Handle file uploads and update gallery.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
input_video: Path to input video file
|
| 206 |
+
input_images: List of input image files
|
| 207 |
+
s_time_interval: Sampling FPS (frames per second) for frame extraction
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
Tuple of (reconstruction_output, target_dir, image_paths, log_message)
|
| 211 |
+
"""
|
| 212 |
+
if not input_video and not input_images:
|
| 213 |
+
return None, None, None, None
|
| 214 |
+
|
| 215 |
+
target_dir, image_paths = self.handle_uploads(input_video, input_images, s_time_interval)
|
| 216 |
+
return (
|
| 217 |
+
None,
|
| 218 |
+
target_dir,
|
| 219 |
+
image_paths,
|
| 220 |
+
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
def load_example_scene(
|
| 224 |
+
self, scene_name: str, examples_dir: str = "examples"
|
| 225 |
+
) -> Tuple[Optional[str], Optional[str], Optional[List], str]:
|
| 226 |
+
"""
|
| 227 |
+
Load a scene from examples directory.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
scene_name: Name of the scene to load
|
| 231 |
+
examples_dir: Path to examples directory
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
Tuple of (reconstruction_output, target_dir, image_paths, log_message)
|
| 235 |
+
"""
|
| 236 |
+
from depth_anything_3.app.modules.utils import get_scene_info
|
| 237 |
+
|
| 238 |
+
scenes = get_scene_info(examples_dir)
|
| 239 |
+
|
| 240 |
+
# Find the selected scene
|
| 241 |
+
selected_scene = None
|
| 242 |
+
for scene in scenes:
|
| 243 |
+
if scene["name"] == scene_name:
|
| 244 |
+
selected_scene = scene
|
| 245 |
+
break
|
| 246 |
+
|
| 247 |
+
if selected_scene is None:
|
| 248 |
+
return None, None, None, "Scene not found"
|
| 249 |
+
|
| 250 |
+
# Use fixed directory name for examples (not timestamp-based)
|
| 251 |
+
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
|
| 252 |
+
input_images_dir = os.path.join(workspace_dir, "input_images")
|
| 253 |
+
if not os.path.exists(input_images_dir):
|
| 254 |
+
os.makedirs(input_images_dir)
|
| 255 |
+
|
| 256 |
+
# Create a fixed folder name based on scene name
|
| 257 |
+
target_dir = os.path.join(input_images_dir, f"example_{scene_name}")
|
| 258 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 259 |
+
|
| 260 |
+
# Check if already cached (GLB file exists)
|
| 261 |
+
glb_path = os.path.join(target_dir, "scene.glb")
|
| 262 |
+
is_cached = os.path.exists(glb_path)
|
| 263 |
+
|
| 264 |
+
# Create directory if it doesn't exist
|
| 265 |
+
if not os.path.exists(target_dir):
|
| 266 |
+
os.makedirs(target_dir)
|
| 267 |
+
os.makedirs(target_dir_images)
|
| 268 |
+
|
| 269 |
+
# Copy images if directory is new or empty
|
| 270 |
+
if not os.path.exists(target_dir_images) or len(os.listdir(target_dir_images)) == 0:
|
| 271 |
+
os.makedirs(target_dir_images, exist_ok=True)
|
| 272 |
+
image_paths = []
|
| 273 |
+
for file_path in selected_scene["image_files"]:
|
| 274 |
+
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
|
| 275 |
+
shutil.copy(file_path, dst_path)
|
| 276 |
+
image_paths.append(dst_path)
|
| 277 |
+
else:
|
| 278 |
+
# Use existing images
|
| 279 |
+
image_paths = sorted(
|
| 280 |
+
[
|
| 281 |
+
os.path.join(target_dir_images, f)
|
| 282 |
+
for f in os.listdir(target_dir_images)
|
| 283 |
+
if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"))
|
| 284 |
+
]
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Return cached GLB if available
|
| 288 |
+
if is_cached:
|
| 289 |
+
return (
|
| 290 |
+
glb_path, # Return cached reconstruction
|
| 291 |
+
target_dir, # Set target directory
|
| 292 |
+
image_paths, # Set gallery
|
| 293 |
+
f"Loaded cached scene '{scene_name}' with {selected_scene['num_images']} images.",
|
| 294 |
+
)
|
| 295 |
+
else:
|
| 296 |
+
return (
|
| 297 |
+
None, # No cached reconstruction
|
| 298 |
+
target_dir, # Set target directory
|
| 299 |
+
image_paths, # Set gallery
|
| 300 |
+
(
|
| 301 |
+
f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. "
|
| 302 |
+
"Click 'Reconstruct' to begin 3D processing."
|
| 303 |
+
),
|
| 304 |
+
)
|
Depth-Anything-3/src/depth_anything_3/app/modules/model_inference.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Model inference module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module handles all model-related operations including inference,
|
| 19 |
+
data processing, and result preparation.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import glob
|
| 23 |
+
import os
|
| 24 |
+
from typing import Any, Dict, Optional, Tuple
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
from depth_anything_3.api import DepthAnything3
|
| 29 |
+
from depth_anything_3.utils.memory import cleanup_cuda_memory
|
| 30 |
+
from depth_anything_3.utils.export.glb import export_to_glb
|
| 31 |
+
from depth_anything_3.utils.export.gs import export_to_gs_video
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ModelInference:
|
| 35 |
+
"""
|
| 36 |
+
Handles model inference and data processing for Depth Anything 3.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self):
|
| 40 |
+
"""Initialize the model inference handler."""
|
| 41 |
+
self.model = None
|
| 42 |
+
|
| 43 |
+
def initialize_model(self, device: str = "cuda") -> None:
|
| 44 |
+
"""
|
| 45 |
+
Initialize the DepthAnything3 model.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
device: Device to load the model on
|
| 49 |
+
"""
|
| 50 |
+
if self.model is None:
|
| 51 |
+
# Get model directory from environment variable or use default
|
| 52 |
+
model_dir = os.environ.get(
|
| 53 |
+
"DA3_MODEL_DIR", "/dev/shm/da3_models/DA3HF-VITG-METRIC_VITL"
|
| 54 |
+
)
|
| 55 |
+
self.model = DepthAnything3.from_pretrained(model_dir)
|
| 56 |
+
self.model = self.model.to(device)
|
| 57 |
+
else:
|
| 58 |
+
self.model = self.model.to(device)
|
| 59 |
+
|
| 60 |
+
self.model.eval()
|
| 61 |
+
|
| 62 |
+
def run_inference(
|
| 63 |
+
self,
|
| 64 |
+
target_dir: str,
|
| 65 |
+
filter_black_bg: bool = False,
|
| 66 |
+
filter_white_bg: bool = False,
|
| 67 |
+
process_res_method: str = "upper_bound_resize",
|
| 68 |
+
show_camera: bool = True,
|
| 69 |
+
save_percentage: float = 30.0,
|
| 70 |
+
num_max_points: int = 1_000_000,
|
| 71 |
+
infer_gs: bool = False,
|
| 72 |
+
ref_view_strategy: str = "saddle_balanced",
|
| 73 |
+
gs_trj_mode: str = "extend",
|
| 74 |
+
gs_video_quality: str = "high",
|
| 75 |
+
) -> Tuple[Any, Dict[int, Dict[str, Any]]]:
|
| 76 |
+
"""
|
| 77 |
+
Run DepthAnything3 model inference on images.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
target_dir: Directory containing images
|
| 81 |
+
filter_black_bg: Whether to filter black background
|
| 82 |
+
filter_white_bg: Whether to filter white background
|
| 83 |
+
process_res_method: Method for resizing input images
|
| 84 |
+
show_camera: Whether to show camera in 3D view
|
| 85 |
+
save_percentage: Percentage of points to save (0-100)
|
| 86 |
+
num_max_points: Maximum number of points in point cloud
|
| 87 |
+
infer_gs: Whether to infer 3D Gaussian Splatting
|
| 88 |
+
ref_view_strategy: Reference view selection strategy
|
| 89 |
+
gs_trj_mode: Trajectory mode for 3DGS
|
| 90 |
+
gs_video_quality: Video quality for 3DGS
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
Tuple of (prediction, processed_data)
|
| 94 |
+
"""
|
| 95 |
+
print(f"Processing images from {target_dir}")
|
| 96 |
+
|
| 97 |
+
# Device check
|
| 98 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 99 |
+
device = torch.device(device)
|
| 100 |
+
|
| 101 |
+
# Initialize model if needed
|
| 102 |
+
self.initialize_model(device)
|
| 103 |
+
|
| 104 |
+
# Get image paths
|
| 105 |
+
print("Loading images...")
|
| 106 |
+
image_folder_path = os.path.join(target_dir, "images")
|
| 107 |
+
all_image_paths = sorted(glob.glob(os.path.join(image_folder_path, "*")))
|
| 108 |
+
|
| 109 |
+
# Filter for image files
|
| 110 |
+
image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"]
|
| 111 |
+
all_image_paths = [
|
| 112 |
+
path
|
| 113 |
+
for path in all_image_paths
|
| 114 |
+
if any(path.lower().endswith(ext) for ext in image_extensions)
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
print(f"Found {len(all_image_paths)} images")
|
| 118 |
+
print(f"All image paths: {all_image_paths}")
|
| 119 |
+
|
| 120 |
+
# Use sorted image order (reference view will be selected automatically)
|
| 121 |
+
image_paths = all_image_paths
|
| 122 |
+
print(f"Reference view selection strategy: {ref_view_strategy}")
|
| 123 |
+
|
| 124 |
+
if len(image_paths) == 0:
|
| 125 |
+
raise ValueError("No images found. Check your upload.")
|
| 126 |
+
|
| 127 |
+
# Map UI options to actual method names
|
| 128 |
+
method_mapping = {"high_res": "lower_bound_resize", "low_res": "upper_bound_resize"}
|
| 129 |
+
actual_method = method_mapping.get(process_res_method, "upper_bound_crop")
|
| 130 |
+
|
| 131 |
+
# Run model inference
|
| 132 |
+
print(f"Running inference with method: {actual_method}")
|
| 133 |
+
with torch.no_grad():
|
| 134 |
+
prediction = self.model.inference(
|
| 135 |
+
image_paths,
|
| 136 |
+
export_dir=None,
|
| 137 |
+
process_res_method=actual_method,
|
| 138 |
+
infer_gs=infer_gs,
|
| 139 |
+
ref_view_strategy=ref_view_strategy,
|
| 140 |
+
)
|
| 141 |
+
# num_max_points: int = 1_000_000,
|
| 142 |
+
export_to_glb(
|
| 143 |
+
prediction,
|
| 144 |
+
filter_black_bg=filter_black_bg,
|
| 145 |
+
filter_white_bg=filter_white_bg,
|
| 146 |
+
export_dir=target_dir,
|
| 147 |
+
show_cameras=show_camera,
|
| 148 |
+
conf_thresh_percentile=save_percentage,
|
| 149 |
+
num_max_points=int(num_max_points),
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# export to gs video if needed
|
| 153 |
+
if infer_gs:
|
| 154 |
+
mode_mapping = {"extend": "extend", "smooth": "interpolate_smooth"}
|
| 155 |
+
print(f"GS mode: {gs_trj_mode}; Backend mode: {mode_mapping[gs_trj_mode]}")
|
| 156 |
+
export_to_gs_video(
|
| 157 |
+
prediction,
|
| 158 |
+
export_dir=target_dir,
|
| 159 |
+
chunk_size=4,
|
| 160 |
+
trj_mode=mode_mapping.get(gs_trj_mode, "extend"),
|
| 161 |
+
enable_tqdm=True,
|
| 162 |
+
vis_depth="hcat",
|
| 163 |
+
video_quality=gs_video_quality,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# Save predictions.npz for caching metric depth data
|
| 167 |
+
self._save_predictions_cache(target_dir, prediction)
|
| 168 |
+
|
| 169 |
+
# Process results
|
| 170 |
+
processed_data = self._process_results(target_dir, prediction, image_paths)
|
| 171 |
+
|
| 172 |
+
# Clean up using centralized memory utilities for consistency with backend
|
| 173 |
+
cleanup_cuda_memory()
|
| 174 |
+
|
| 175 |
+
return prediction, processed_data
|
| 176 |
+
|
| 177 |
+
def _save_predictions_cache(self, target_dir: str, prediction: Any) -> None:
|
| 178 |
+
"""
|
| 179 |
+
Save predictions data to predictions.npz for caching.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
target_dir: Directory to save the cache
|
| 183 |
+
prediction: Model prediction object
|
| 184 |
+
"""
|
| 185 |
+
try:
|
| 186 |
+
output_file = os.path.join(target_dir, "predictions.npz")
|
| 187 |
+
|
| 188 |
+
# Build save dict with prediction data
|
| 189 |
+
save_dict = {}
|
| 190 |
+
|
| 191 |
+
# Save processed images if available
|
| 192 |
+
if prediction.processed_images is not None:
|
| 193 |
+
save_dict["images"] = prediction.processed_images
|
| 194 |
+
|
| 195 |
+
# Save depth data
|
| 196 |
+
if prediction.depth is not None:
|
| 197 |
+
save_dict["depths"] = np.round(prediction.depth, 6)
|
| 198 |
+
|
| 199 |
+
# Save confidence if available
|
| 200 |
+
if prediction.conf is not None:
|
| 201 |
+
save_dict["conf"] = np.round(prediction.conf, 2)
|
| 202 |
+
|
| 203 |
+
# Save camera parameters
|
| 204 |
+
if prediction.extrinsics is not None:
|
| 205 |
+
save_dict["extrinsics"] = prediction.extrinsics
|
| 206 |
+
if prediction.intrinsics is not None:
|
| 207 |
+
save_dict["intrinsics"] = prediction.intrinsics
|
| 208 |
+
|
| 209 |
+
# Save to file
|
| 210 |
+
np.savez_compressed(output_file, **save_dict)
|
| 211 |
+
print(f"Saved predictions cache to: {output_file}")
|
| 212 |
+
|
| 213 |
+
except Exception as e:
|
| 214 |
+
print(f"Warning: Failed to save predictions cache: {e}")
|
| 215 |
+
|
| 216 |
+
def _process_results(
|
| 217 |
+
self, target_dir: str, prediction: Any, image_paths: list
|
| 218 |
+
) -> Dict[int, Dict[str, Any]]:
|
| 219 |
+
"""
|
| 220 |
+
Process model results into structured data.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
target_dir: Directory containing results
|
| 224 |
+
prediction: Model prediction object
|
| 225 |
+
image_paths: List of input image paths
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
Dictionary containing processed data for each view
|
| 229 |
+
"""
|
| 230 |
+
processed_data = {}
|
| 231 |
+
|
| 232 |
+
# Read generated depth visualization files
|
| 233 |
+
depth_vis_dir = os.path.join(target_dir, "depth_vis")
|
| 234 |
+
|
| 235 |
+
if os.path.exists(depth_vis_dir):
|
| 236 |
+
depth_files = sorted(glob.glob(os.path.join(depth_vis_dir, "*.jpg")))
|
| 237 |
+
for i, depth_file in enumerate(depth_files):
|
| 238 |
+
# Use processed images directly from API
|
| 239 |
+
processed_image = None
|
| 240 |
+
if prediction.processed_images is not None and i < len(
|
| 241 |
+
prediction.processed_images
|
| 242 |
+
):
|
| 243 |
+
processed_image = prediction.processed_images[i]
|
| 244 |
+
|
| 245 |
+
processed_data[i] = {
|
| 246 |
+
"depth_image": depth_file,
|
| 247 |
+
"image": processed_image,
|
| 248 |
+
"original_image_path": image_paths[i] if i < len(image_paths) else None,
|
| 249 |
+
"depth": prediction.depth[i] if i < len(prediction.depth) else None,
|
| 250 |
+
"intrinsics": (
|
| 251 |
+
prediction.intrinsics[i]
|
| 252 |
+
if prediction.intrinsics is not None and i < len(prediction.intrinsics)
|
| 253 |
+
else None
|
| 254 |
+
),
|
| 255 |
+
"mask": None, # No mask information available
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
return processed_data
|
| 259 |
+
|
| 260 |
+
# cleanup() removed: call cleanup_cuda_memory() directly where needed.
|
Depth-Anything-3/src/depth_anything_3/app/modules/ui_components.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
UI components module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module contains UI component definitions and layout functions.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from typing import Any, Dict, List, Tuple
|
| 23 |
+
import gradio as gr
|
| 24 |
+
|
| 25 |
+
from depth_anything_3.app.modules.utils import get_logo_base64, get_scene_info
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class UIComponents:
|
| 29 |
+
"""
|
| 30 |
+
Handles UI component creation and layout for the Gradio app.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self):
|
| 34 |
+
"""Initialize the UI components handler."""
|
| 35 |
+
|
| 36 |
+
def create_upload_section(self) -> Tuple[gr.Video, gr.Slider, gr.File, gr.Gallery]:
|
| 37 |
+
"""
|
| 38 |
+
Create the upload section with video, images, and gallery components.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
A tuple of Gradio components: (input_video, s_time_interval, input_images, image_gallery).
|
| 42 |
+
"""
|
| 43 |
+
input_video = gr.Video(label="Upload Video", interactive=True)
|
| 44 |
+
s_time_interval = gr.Slider(
|
| 45 |
+
minimum=0.1,
|
| 46 |
+
maximum=60,
|
| 47 |
+
value=10,
|
| 48 |
+
step=0.1,
|
| 49 |
+
label="Sampling FPS (Frames Per Second)",
|
| 50 |
+
interactive=True,
|
| 51 |
+
visible=True,
|
| 52 |
+
)
|
| 53 |
+
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
|
| 54 |
+
image_gallery = gr.Gallery(
|
| 55 |
+
label="Preview",
|
| 56 |
+
columns=4,
|
| 57 |
+
height="300px",
|
| 58 |
+
show_download_button=True,
|
| 59 |
+
object_fit="contain",
|
| 60 |
+
preview=True,
|
| 61 |
+
interactive=False,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
return input_video, s_time_interval, input_images, image_gallery
|
| 65 |
+
|
| 66 |
+
def create_3d_viewer_section(self) -> gr.Model3D:
|
| 67 |
+
"""
|
| 68 |
+
Create the 3D viewer component.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
3D model viewer component
|
| 72 |
+
"""
|
| 73 |
+
return gr.Model3D(
|
| 74 |
+
height=520,
|
| 75 |
+
zoom_speed=0.5,
|
| 76 |
+
pan_speed=0.5,
|
| 77 |
+
clear_color=[0.0, 0.0, 0.0, 0.0],
|
| 78 |
+
key="persistent_3d_viewer",
|
| 79 |
+
elem_id="reconstruction_3d_viewer",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
def create_nvs_video(self) -> Tuple[gr.Video, gr.Markdown]:
|
| 83 |
+
"""
|
| 84 |
+
Create the 3DGS rendered video display component and info message.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
Tuple of (video component, info message component)
|
| 88 |
+
"""
|
| 89 |
+
with gr.Column():
|
| 90 |
+
gs_info = gr.Markdown(
|
| 91 |
+
(
|
| 92 |
+
"‼️ **3D Gaussian Splatting rendering is currently DISABLED.** <br><br><br>"
|
| 93 |
+
"To render novel views from 3DGS, "
|
| 94 |
+
"enable **Infer 3D Gaussian Splatting** below. <br>"
|
| 95 |
+
"Next, in **Visualization Options**, "
|
| 96 |
+
"*optionally* configure the **rendering trajectory** (default: smooth) "
|
| 97 |
+
"and **video quality** (default: low), "
|
| 98 |
+
"then click **Reconstruct**."
|
| 99 |
+
),
|
| 100 |
+
visible=True,
|
| 101 |
+
height=520,
|
| 102 |
+
)
|
| 103 |
+
gs_video = gr.Video(
|
| 104 |
+
height=520,
|
| 105 |
+
label="3DGS Rendered NVS Video (depth shown for reference only)",
|
| 106 |
+
interactive=False,
|
| 107 |
+
visible=False,
|
| 108 |
+
)
|
| 109 |
+
return gs_video, gs_info
|
| 110 |
+
|
| 111 |
+
def create_depth_section(self) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image]:
|
| 112 |
+
"""
|
| 113 |
+
Create the depth visualization section.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
A tuple of (prev_depth_btn, depth_view_selector, next_depth_btn, depth_map)
|
| 117 |
+
"""
|
| 118 |
+
with gr.Row(elem_classes=["navigation-row"]):
|
| 119 |
+
prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
|
| 120 |
+
depth_view_selector = gr.Dropdown(
|
| 121 |
+
choices=["View 1"],
|
| 122 |
+
value="View 1",
|
| 123 |
+
label="Select View",
|
| 124 |
+
scale=2,
|
| 125 |
+
interactive=True,
|
| 126 |
+
allow_custom_value=True,
|
| 127 |
+
)
|
| 128 |
+
next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
|
| 129 |
+
depth_map = gr.Image(
|
| 130 |
+
type="numpy",
|
| 131 |
+
label="Colorized Depth Map",
|
| 132 |
+
format="png",
|
| 133 |
+
interactive=False,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
return prev_depth_btn, depth_view_selector, next_depth_btn, depth_map
|
| 137 |
+
|
| 138 |
+
def create_measure_section(
|
| 139 |
+
self,
|
| 140 |
+
) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image, gr.Image, gr.Markdown]:
|
| 141 |
+
"""
|
| 142 |
+
Create the measurement section.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
A tuple of (prev_measure_btn, measure_view_selector, next_measure_btn, measure_image,
|
| 146 |
+
measure_depth_image, measure_text)
|
| 147 |
+
"""
|
| 148 |
+
from depth_anything_3.app.css_and_html import MEASURE_INSTRUCTIONS_HTML
|
| 149 |
+
|
| 150 |
+
gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
|
| 151 |
+
with gr.Row(elem_classes=["navigation-row"]):
|
| 152 |
+
prev_measure_btn = gr.Button("◀ Previous", size="sm", scale=1)
|
| 153 |
+
measure_view_selector = gr.Dropdown(
|
| 154 |
+
choices=["View 1"],
|
| 155 |
+
value="View 1",
|
| 156 |
+
label="Select View",
|
| 157 |
+
scale=2,
|
| 158 |
+
interactive=True,
|
| 159 |
+
allow_custom_value=True,
|
| 160 |
+
)
|
| 161 |
+
next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
|
| 162 |
+
with gr.Row():
|
| 163 |
+
measure_image = gr.Image(
|
| 164 |
+
type="numpy",
|
| 165 |
+
show_label=False,
|
| 166 |
+
format="webp",
|
| 167 |
+
interactive=False,
|
| 168 |
+
sources=[],
|
| 169 |
+
label="RGB Image",
|
| 170 |
+
scale=1,
|
| 171 |
+
height=275,
|
| 172 |
+
)
|
| 173 |
+
measure_depth_image = gr.Image(
|
| 174 |
+
type="numpy",
|
| 175 |
+
show_label=False,
|
| 176 |
+
format="webp",
|
| 177 |
+
interactive=False,
|
| 178 |
+
sources=[],
|
| 179 |
+
label="Depth Visualization (Right Half)",
|
| 180 |
+
scale=1,
|
| 181 |
+
height=275,
|
| 182 |
+
)
|
| 183 |
+
gr.Markdown(
|
| 184 |
+
"**Note:** Images have been adjusted to model processing size. "
|
| 185 |
+
"Click two points on the RGB image to measure distance."
|
| 186 |
+
)
|
| 187 |
+
measure_text = gr.Markdown("")
|
| 188 |
+
|
| 189 |
+
return (
|
| 190 |
+
prev_measure_btn,
|
| 191 |
+
measure_view_selector,
|
| 192 |
+
next_measure_btn,
|
| 193 |
+
measure_image,
|
| 194 |
+
measure_depth_image,
|
| 195 |
+
measure_text,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
def create_inference_control_section(self) -> Tuple[gr.Dropdown, gr.Checkbox, gr.Dropdown]:
|
| 199 |
+
"""
|
| 200 |
+
Create the inference control section (before inference).
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Tuple of (process_res_method_dropdown, infer_gs, ref_view_strategy)
|
| 204 |
+
"""
|
| 205 |
+
with gr.Row():
|
| 206 |
+
process_res_method_dropdown = gr.Dropdown(
|
| 207 |
+
choices=["high_res", "low_res"],
|
| 208 |
+
value="low_res",
|
| 209 |
+
label="Image Processing Method",
|
| 210 |
+
info="low_res for much more images",
|
| 211 |
+
scale=1,
|
| 212 |
+
)
|
| 213 |
+
# Modify line 220, add color class
|
| 214 |
+
infer_gs = gr.Checkbox(
|
| 215 |
+
label="Infer 3D Gaussian Splatting",
|
| 216 |
+
value=False,
|
| 217 |
+
info=(
|
| 218 |
+
'Enable novel view rendering from 3DGS (<i class="fas fa-triangle-exclamation '
|
| 219 |
+
'fa-color-red"></i> requires extra processing time)'
|
| 220 |
+
),
|
| 221 |
+
scale=1,
|
| 222 |
+
)
|
| 223 |
+
ref_view_strategy = gr.Dropdown(
|
| 224 |
+
choices=["saddle_balanced", "saddle_sim_range", "first", "middle"],
|
| 225 |
+
value="saddle_balanced",
|
| 226 |
+
label="Reference View Strategy",
|
| 227 |
+
info="Strategy for selecting reference view from multiple inputs",
|
| 228 |
+
scale=1,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
return (process_res_method_dropdown, infer_gs, ref_view_strategy)
|
| 232 |
+
|
| 233 |
+
def create_display_control_section(
|
| 234 |
+
self,
|
| 235 |
+
) -> Tuple[
|
| 236 |
+
gr.Checkbox,
|
| 237 |
+
gr.Checkbox,
|
| 238 |
+
gr.Checkbox,
|
| 239 |
+
gr.Slider,
|
| 240 |
+
gr.Slider,
|
| 241 |
+
gr.Dropdown,
|
| 242 |
+
gr.Dropdown,
|
| 243 |
+
gr.Button,
|
| 244 |
+
gr.ClearButton,
|
| 245 |
+
]:
|
| 246 |
+
"""
|
| 247 |
+
Create the display control section (options for visualization).
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
Tuple of display control components including buttons
|
| 251 |
+
"""
|
| 252 |
+
with gr.Column():
|
| 253 |
+
# 3DGS options at the top
|
| 254 |
+
with gr.Row():
|
| 255 |
+
gs_trj_mode = gr.Dropdown(
|
| 256 |
+
choices=["smooth", "extend"],
|
| 257 |
+
value="smooth",
|
| 258 |
+
label=("Rendering trajectory for 3DGS viewpoints (requires n_views ≥ 2)"),
|
| 259 |
+
info=("'smooth' for view interpolation; 'extend' for longer trajectory"),
|
| 260 |
+
visible=False, # initially hidden
|
| 261 |
+
)
|
| 262 |
+
gs_video_quality = gr.Dropdown(
|
| 263 |
+
choices=["low", "medium", "high"],
|
| 264 |
+
value="low",
|
| 265 |
+
label=("Video quality for 3DGS rendered outputs"),
|
| 266 |
+
info=("'low' for faster loading speed; 'high' for better visual quality"),
|
| 267 |
+
visible=False, # initially hidden
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# Reconstruct and Clear buttons (before Visualization Options)
|
| 271 |
+
with gr.Row():
|
| 272 |
+
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
| 273 |
+
clear_btn = gr.ClearButton(scale=1)
|
| 274 |
+
|
| 275 |
+
gr.Markdown("### Visualization Options: (Click Reconstruct to update)")
|
| 276 |
+
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
| 277 |
+
filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
|
| 278 |
+
filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
|
| 279 |
+
save_percentage = gr.Slider(
|
| 280 |
+
minimum=0,
|
| 281 |
+
maximum=100,
|
| 282 |
+
value=10,
|
| 283 |
+
step=1,
|
| 284 |
+
label="Filter Percentage",
|
| 285 |
+
info="Confidence Threshold (%): Higher values filter more points.",
|
| 286 |
+
)
|
| 287 |
+
num_max_points = gr.Slider(
|
| 288 |
+
minimum=1000,
|
| 289 |
+
maximum=100000,
|
| 290 |
+
value=1000,
|
| 291 |
+
step=1000,
|
| 292 |
+
label="Max Points (K points)",
|
| 293 |
+
info="Maximum number of points to export to GLB (in thousands)",
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
return (
|
| 297 |
+
show_cam,
|
| 298 |
+
filter_black_bg,
|
| 299 |
+
filter_white_bg,
|
| 300 |
+
save_percentage,
|
| 301 |
+
num_max_points,
|
| 302 |
+
gs_trj_mode,
|
| 303 |
+
gs_video_quality,
|
| 304 |
+
submit_btn,
|
| 305 |
+
clear_btn,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
def create_control_section(
|
| 309 |
+
self,
|
| 310 |
+
) -> Tuple[
|
| 311 |
+
gr.Button,
|
| 312 |
+
gr.ClearButton,
|
| 313 |
+
gr.Dropdown,
|
| 314 |
+
gr.Checkbox,
|
| 315 |
+
gr.Checkbox,
|
| 316 |
+
gr.Checkbox,
|
| 317 |
+
gr.Checkbox,
|
| 318 |
+
gr.Checkbox,
|
| 319 |
+
gr.Dropdown,
|
| 320 |
+
gr.Checkbox,
|
| 321 |
+
gr.Textbox,
|
| 322 |
+
]:
|
| 323 |
+
"""
|
| 324 |
+
Create the control section with buttons and options.
|
| 325 |
+
|
| 326 |
+
Returns:
|
| 327 |
+
Tuple of control components
|
| 328 |
+
"""
|
| 329 |
+
with gr.Row():
|
| 330 |
+
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
| 331 |
+
clear_btn = gr.ClearButton(
|
| 332 |
+
scale=1,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
with gr.Row():
|
| 336 |
+
frame_filter = gr.Dropdown(
|
| 337 |
+
choices=["All"], value="All", label="Show Points from Frame"
|
| 338 |
+
)
|
| 339 |
+
with gr.Column():
|
| 340 |
+
gr.Markdown("### Visualization Option: (Click Reconstruct to update)")
|
| 341 |
+
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
| 342 |
+
show_mesh = gr.Checkbox(label="Show Mesh", value=True)
|
| 343 |
+
filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
|
| 344 |
+
filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
|
| 345 |
+
gr.Markdown("### Reconstruction Options: (updated on next run)")
|
| 346 |
+
apply_mask_checkbox = gr.Checkbox(
|
| 347 |
+
label="Apply mask for predicted ambiguous depth classes & edges",
|
| 348 |
+
value=True,
|
| 349 |
+
)
|
| 350 |
+
process_res_method_dropdown = gr.Dropdown(
|
| 351 |
+
choices=[
|
| 352 |
+
"upper_bound_resize",
|
| 353 |
+
"upper_bound_crop",
|
| 354 |
+
"lower_bound_resize",
|
| 355 |
+
"lower_bound_crop",
|
| 356 |
+
],
|
| 357 |
+
value="upper_bound_resize",
|
| 358 |
+
label="Image Processing Method",
|
| 359 |
+
info="Method for resizing input images",
|
| 360 |
+
)
|
| 361 |
+
save_to_gallery_checkbox = gr.Checkbox(
|
| 362 |
+
label="Save to Gallery",
|
| 363 |
+
value=False,
|
| 364 |
+
info="Save current reconstruction results to gallery directory",
|
| 365 |
+
)
|
| 366 |
+
gallery_name_input = gr.Textbox(
|
| 367 |
+
label="Gallery Name",
|
| 368 |
+
placeholder="Enter a name for the gallery folder",
|
| 369 |
+
value="",
|
| 370 |
+
info="Leave empty for auto-generated name with timestamp",
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
return (
|
| 374 |
+
submit_btn,
|
| 375 |
+
clear_btn,
|
| 376 |
+
frame_filter,
|
| 377 |
+
show_cam,
|
| 378 |
+
show_mesh,
|
| 379 |
+
filter_black_bg,
|
| 380 |
+
filter_white_bg,
|
| 381 |
+
apply_mask_checkbox,
|
| 382 |
+
process_res_method_dropdown,
|
| 383 |
+
save_to_gallery_checkbox,
|
| 384 |
+
gallery_name_input,
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
def create_example_scenes_section(self) -> List[Dict[str, Any]]:
|
| 388 |
+
"""
|
| 389 |
+
Create the example scenes section.
|
| 390 |
+
|
| 391 |
+
Returns:
|
| 392 |
+
List of scene information dictionaries
|
| 393 |
+
"""
|
| 394 |
+
# Get workspace directory from environment variable
|
| 395 |
+
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
|
| 396 |
+
examples_dir = os.path.join(workspace_dir, "examples")
|
| 397 |
+
|
| 398 |
+
# Get scene information
|
| 399 |
+
scenes = get_scene_info(examples_dir)
|
| 400 |
+
|
| 401 |
+
return scenes
|
| 402 |
+
|
| 403 |
+
def create_example_scene_grid(self, scenes: List[Dict[str, Any]]) -> List[gr.Image]:
|
| 404 |
+
"""
|
| 405 |
+
Create the example scene grid.
|
| 406 |
+
|
| 407 |
+
Args:
|
| 408 |
+
scenes: List of scene information dictionaries
|
| 409 |
+
|
| 410 |
+
Returns:
|
| 411 |
+
List of scene image components
|
| 412 |
+
"""
|
| 413 |
+
scene_components = []
|
| 414 |
+
|
| 415 |
+
if scenes:
|
| 416 |
+
for i in range(0, len(scenes), 4): # Process 4 scenes per row
|
| 417 |
+
with gr.Row():
|
| 418 |
+
for j in range(4):
|
| 419 |
+
scene_idx = i + j
|
| 420 |
+
if scene_idx < len(scenes):
|
| 421 |
+
scene = scenes[scene_idx]
|
| 422 |
+
with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
|
| 423 |
+
# Clickable thumbnail
|
| 424 |
+
scene_img = gr.Image(
|
| 425 |
+
value=scene["thumbnail"],
|
| 426 |
+
height=150,
|
| 427 |
+
interactive=False,
|
| 428 |
+
show_label=False,
|
| 429 |
+
elem_id=f"scene_thumb_{scene['name']}",
|
| 430 |
+
sources=[],
|
| 431 |
+
)
|
| 432 |
+
scene_components.append(scene_img)
|
| 433 |
+
|
| 434 |
+
# Scene name and image count as text below thumbnail
|
| 435 |
+
gr.Markdown(
|
| 436 |
+
f"**{scene['name']}** \n {scene['num_images']} images",
|
| 437 |
+
elem_classes=["scene-info"],
|
| 438 |
+
)
|
| 439 |
+
else:
|
| 440 |
+
# Empty column to maintain grid structure
|
| 441 |
+
with gr.Column(scale=1):
|
| 442 |
+
pass
|
| 443 |
+
|
| 444 |
+
return scene_components
|
| 445 |
+
|
| 446 |
+
def create_header_section(self) -> gr.HTML:
|
| 447 |
+
"""
|
| 448 |
+
Create the header section with logo and title.
|
| 449 |
+
|
| 450 |
+
Returns:
|
| 451 |
+
Header HTML component
|
| 452 |
+
"""
|
| 453 |
+
from depth_anything_3.app.css_and_html import get_header_html
|
| 454 |
+
|
| 455 |
+
return gr.HTML(get_header_html(get_logo_base64()))
|
| 456 |
+
|
| 457 |
+
def create_description_section(self) -> gr.HTML:
|
| 458 |
+
"""
|
| 459 |
+
Create the description section.
|
| 460 |
+
|
| 461 |
+
Returns:
|
| 462 |
+
Description HTML component
|
| 463 |
+
"""
|
| 464 |
+
from depth_anything_3.app.css_and_html import get_description_html
|
| 465 |
+
|
| 466 |
+
return gr.HTML(get_description_html())
|
| 467 |
+
|
| 468 |
+
def create_acknowledgements_section(self) -> gr.HTML:
|
| 469 |
+
"""
|
| 470 |
+
Create the acknowledgements section.
|
| 471 |
+
|
| 472 |
+
Returns:
|
| 473 |
+
Acknowledgements HTML component
|
| 474 |
+
"""
|
| 475 |
+
from depth_anything_3.app.css_and_html import get_acknowledgements_html
|
| 476 |
+
|
| 477 |
+
return gr.HTML(get_acknowledgements_html())
|
Depth-Anything-3/src/depth_anything_3/app/modules/utils.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Utility functions for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module contains helper functions for data processing, visualization,
|
| 19 |
+
and file operations.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
import shutil
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
def create_depth_visualization(depth: np.ndarray) -> Optional[np.ndarray]:
|
| 31 |
+
"""
|
| 32 |
+
Create a colored depth visualization.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
depth: Depth array
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
Colored depth visualization or None
|
| 39 |
+
"""
|
| 40 |
+
if depth is None:
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
+
# Normalize depth to 0-1 range
|
| 44 |
+
depth_min = depth[depth > 0].min() if (depth > 0).any() else 0
|
| 45 |
+
depth_max = depth.max()
|
| 46 |
+
|
| 47 |
+
if depth_max <= depth_min:
|
| 48 |
+
return None
|
| 49 |
+
|
| 50 |
+
# Normalize depth
|
| 51 |
+
depth_norm = (depth - depth_min) / (depth_max - depth_min)
|
| 52 |
+
depth_norm = np.clip(depth_norm, 0, 1)
|
| 53 |
+
|
| 54 |
+
# Apply colormap (using matplotlib's viridis colormap)
|
| 55 |
+
import matplotlib.cm as cm
|
| 56 |
+
|
| 57 |
+
# Convert to colored image
|
| 58 |
+
depth_colored = cm.viridis(depth_norm)[:, :, :3] # Remove alpha channel
|
| 59 |
+
depth_colored = (depth_colored * 255).astype(np.uint8)
|
| 60 |
+
|
| 61 |
+
return depth_colored
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def save_to_gallery_func(
|
| 65 |
+
target_dir: str, processed_data: Dict[int, Dict[str, Any]], gallery_name: Optional[str] = None
|
| 66 |
+
) -> Tuple[bool, str]:
|
| 67 |
+
"""
|
| 68 |
+
Save the current reconstruction results to the gallery directory.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
target_dir: Source directory containing reconstruction results
|
| 72 |
+
processed_data: Processed data dictionary
|
| 73 |
+
gallery_name: Name for the gallery folder
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
Tuple of (success, message)
|
| 77 |
+
"""
|
| 78 |
+
try:
|
| 79 |
+
# Get gallery directory from environment variable or use default
|
| 80 |
+
gallery_dir = os.environ.get(
|
| 81 |
+
"DA3_GALLERY_DIR",
|
| 82 |
+
"workspace/gallery",
|
| 83 |
+
)
|
| 84 |
+
if not os.path.exists(gallery_dir):
|
| 85 |
+
os.makedirs(gallery_dir)
|
| 86 |
+
|
| 87 |
+
# Use provided name or create a unique name
|
| 88 |
+
if gallery_name is None or gallery_name.strip() == "":
|
| 89 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 90 |
+
gallery_name = f"reconstruction_{timestamp}"
|
| 91 |
+
|
| 92 |
+
gallery_path = os.path.join(gallery_dir, gallery_name)
|
| 93 |
+
|
| 94 |
+
# Check if directory already exists
|
| 95 |
+
if os.path.exists(gallery_path):
|
| 96 |
+
return False, f"Save failed: folder '{gallery_name}' already exists"
|
| 97 |
+
|
| 98 |
+
# Create the gallery directory
|
| 99 |
+
os.makedirs(gallery_path, exist_ok=True)
|
| 100 |
+
|
| 101 |
+
# Copy GLB file
|
| 102 |
+
glb_source = os.path.join(target_dir, "scene.glb")
|
| 103 |
+
glb_dest = os.path.join(gallery_path, "scene.glb")
|
| 104 |
+
if os.path.exists(glb_source):
|
| 105 |
+
shutil.copy2(glb_source, glb_dest)
|
| 106 |
+
|
| 107 |
+
# Copy depth visualization images
|
| 108 |
+
depth_vis_dir = os.path.join(target_dir, "depth_vis")
|
| 109 |
+
if os.path.exists(depth_vis_dir):
|
| 110 |
+
gallery_depth_vis = os.path.join(gallery_path, "depth_vis")
|
| 111 |
+
shutil.copytree(depth_vis_dir, gallery_depth_vis)
|
| 112 |
+
|
| 113 |
+
# Copy original images
|
| 114 |
+
images_source = os.path.join(target_dir, "images")
|
| 115 |
+
if os.path.exists(images_source):
|
| 116 |
+
gallery_images = os.path.join(gallery_path, "images")
|
| 117 |
+
shutil.copytree(images_source, gallery_images)
|
| 118 |
+
|
| 119 |
+
scene_preview_source = os.path.join(target_dir, "scene.jpg")
|
| 120 |
+
scene_preview_dest = os.path.join(gallery_path, "scene.jpg")
|
| 121 |
+
shutil.copy2(scene_preview_source, scene_preview_dest)
|
| 122 |
+
|
| 123 |
+
# Save metadata
|
| 124 |
+
metadata = {
|
| 125 |
+
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
| 126 |
+
"num_images": len(processed_data) if processed_data else 0,
|
| 127 |
+
"gallery_name": gallery_name,
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
with open(os.path.join(gallery_path, "metadata.json"), "w") as f:
|
| 131 |
+
json.dump(metadata, f, indent=2)
|
| 132 |
+
|
| 133 |
+
print(f"Saved reconstruction to gallery: {gallery_path}")
|
| 134 |
+
return True, f"Save successful: saved to {gallery_path}"
|
| 135 |
+
|
| 136 |
+
except Exception as e:
|
| 137 |
+
print(f"Error saving to gallery: {e}")
|
| 138 |
+
return False, f"Save failed: {str(e)}"
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_scene_info(examples_dir: str) -> List[Dict[str, Any]]:
|
| 142 |
+
"""
|
| 143 |
+
Get information about scenes in the examples directory.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
examples_dir: Path to examples directory
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
List of scene information dictionaries
|
| 150 |
+
"""
|
| 151 |
+
import glob
|
| 152 |
+
|
| 153 |
+
scenes = []
|
| 154 |
+
if not os.path.exists(examples_dir):
|
| 155 |
+
return scenes
|
| 156 |
+
|
| 157 |
+
for scene_folder in sorted(os.listdir(examples_dir)):
|
| 158 |
+
scene_path = os.path.join(examples_dir, scene_folder)
|
| 159 |
+
if os.path.isdir(scene_path):
|
| 160 |
+
# Find all image files in the scene folder
|
| 161 |
+
image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
|
| 162 |
+
image_files = []
|
| 163 |
+
for ext in image_extensions:
|
| 164 |
+
image_files.extend(glob.glob(os.path.join(scene_path, ext)))
|
| 165 |
+
image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
|
| 166 |
+
|
| 167 |
+
if image_files:
|
| 168 |
+
# Sort images and get the first one for thumbnail
|
| 169 |
+
image_files = sorted(image_files)
|
| 170 |
+
first_image = image_files[0]
|
| 171 |
+
num_images = len(image_files)
|
| 172 |
+
|
| 173 |
+
scenes.append(
|
| 174 |
+
{
|
| 175 |
+
"name": scene_folder,
|
| 176 |
+
"path": scene_path,
|
| 177 |
+
"thumbnail": first_image,
|
| 178 |
+
"num_images": num_images,
|
| 179 |
+
"image_files": image_files,
|
| 180 |
+
}
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
return scenes
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# NOTE: cleanup was moved to a single canonical helper in
|
| 187 |
+
# `depth_anything_3.utils.memory.cleanup_cuda_memory`.
|
| 188 |
+
# Callers should import and call that directly instead of using this module.
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_logo_base64() -> Optional[str]:
|
| 192 |
+
"""
|
| 193 |
+
Convert WAI logo to base64 for embedding in HTML.
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
Base64 encoded logo string or None
|
| 197 |
+
"""
|
| 198 |
+
import base64
|
| 199 |
+
|
| 200 |
+
logo_path = "examples/WAI-Logo/wai_logo.png"
|
| 201 |
+
try:
|
| 202 |
+
with open(logo_path, "rb") as img_file:
|
| 203 |
+
img_data = img_file.read()
|
| 204 |
+
base64_str = base64.b64encode(img_data).decode()
|
| 205 |
+
return f"data:image/png;base64,{base64_str}"
|
| 206 |
+
except FileNotFoundError:
|
| 207 |
+
return None
|
Depth-Anything-3/src/depth_anything_3/app/modules/visualization.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Visualization module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module handles visualization updates, navigation, and measurement functionality.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 23 |
+
import cv2
|
| 24 |
+
import gradio as gr
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class VisualizationHandler:
|
| 29 |
+
"""
|
| 30 |
+
Handles visualization updates and navigation for the Gradio app.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self):
|
| 34 |
+
"""Initialize the visualization handler."""
|
| 35 |
+
|
| 36 |
+
def update_view_selectors(
|
| 37 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]]
|
| 38 |
+
) -> Tuple[gr.Dropdown, gr.Dropdown]:
|
| 39 |
+
"""
|
| 40 |
+
Update view selector dropdowns based on available views.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
processed_data: Processed data dictionary
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Tuple of (depth_view_selector, measure_view_selector)
|
| 47 |
+
"""
|
| 48 |
+
if processed_data is None or len(processed_data) == 0:
|
| 49 |
+
choices = ["View 1"]
|
| 50 |
+
else:
|
| 51 |
+
num_views = len(processed_data)
|
| 52 |
+
choices = [f"View {i + 1}" for i in range(num_views)]
|
| 53 |
+
|
| 54 |
+
return (
|
| 55 |
+
gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
|
| 56 |
+
gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def get_view_data_by_index(
|
| 60 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 61 |
+
) -> Optional[Dict[str, Any]]:
|
| 62 |
+
"""
|
| 63 |
+
Get view data by index, handling bounds.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
processed_data: Processed data dictionary
|
| 67 |
+
view_index: Index of the view to get
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
View data dictionary or None
|
| 71 |
+
"""
|
| 72 |
+
if processed_data is None or len(processed_data) == 0:
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
view_keys = list(processed_data.keys())
|
| 76 |
+
if view_index < 0 or view_index >= len(view_keys):
|
| 77 |
+
view_index = 0
|
| 78 |
+
|
| 79 |
+
return processed_data[view_keys[view_index]]
|
| 80 |
+
|
| 81 |
+
def update_depth_view(
|
| 82 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 83 |
+
) -> Optional[str]:
|
| 84 |
+
"""
|
| 85 |
+
Update depth view for a specific view index.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
processed_data: Processed data dictionary
|
| 89 |
+
view_index: Index of the view to update
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Path to depth visualization image or None
|
| 93 |
+
"""
|
| 94 |
+
view_data = self.get_view_data_by_index(processed_data, view_index)
|
| 95 |
+
if view_data is None or view_data.get("depth_image") is None:
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
# Return the depth visualization image directly
|
| 99 |
+
return view_data["depth_image"]
|
| 100 |
+
|
| 101 |
+
def navigate_depth_view(
|
| 102 |
+
self,
|
| 103 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 104 |
+
current_selector_value: str,
|
| 105 |
+
direction: int,
|
| 106 |
+
) -> Tuple[str, Optional[str]]:
|
| 107 |
+
"""
|
| 108 |
+
Navigate depth view (direction: -1 for previous, +1 for next).
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
processed_data: Processed data dictionary
|
| 112 |
+
current_selector_value: Current selector value
|
| 113 |
+
direction: Direction to navigate (-1 for previous, +1 for next)
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
Tuple of (new_selector_value, depth_vis)
|
| 117 |
+
"""
|
| 118 |
+
if processed_data is None or len(processed_data) == 0:
|
| 119 |
+
return "View 1", None
|
| 120 |
+
|
| 121 |
+
# Parse current view number
|
| 122 |
+
try:
|
| 123 |
+
current_view = int(current_selector_value.split()[1]) - 1
|
| 124 |
+
except: # noqa
|
| 125 |
+
current_view = 0
|
| 126 |
+
|
| 127 |
+
num_views = len(processed_data)
|
| 128 |
+
new_view = (current_view + direction) % num_views
|
| 129 |
+
|
| 130 |
+
new_selector_value = f"View {new_view + 1}"
|
| 131 |
+
depth_vis = self.update_depth_view(processed_data, new_view)
|
| 132 |
+
|
| 133 |
+
return new_selector_value, depth_vis
|
| 134 |
+
|
| 135 |
+
def update_measure_view(
|
| 136 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 137 |
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
|
| 138 |
+
"""
|
| 139 |
+
Update measure view for a specific view index.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
processed_data: Processed data dictionary
|
| 143 |
+
view_index: Index of the view to update
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Tuple of (measure_image, depth_right_half, measure_points)
|
| 147 |
+
"""
|
| 148 |
+
view_data = self.get_view_data_by_index(processed_data, view_index)
|
| 149 |
+
if view_data is None:
|
| 150 |
+
return None, None, [] # image, depth_right_half, measure_points
|
| 151 |
+
|
| 152 |
+
# Get the processed (resized) image
|
| 153 |
+
if "image" in view_data and view_data["image"] is not None:
|
| 154 |
+
image = view_data["image"].copy()
|
| 155 |
+
else:
|
| 156 |
+
return None, None, []
|
| 157 |
+
|
| 158 |
+
# Ensure image is in uint8 format
|
| 159 |
+
if image.dtype != np.uint8:
|
| 160 |
+
if image.max() <= 1.0:
|
| 161 |
+
image = (image * 255).astype(np.uint8)
|
| 162 |
+
else:
|
| 163 |
+
image = image.astype(np.uint8)
|
| 164 |
+
|
| 165 |
+
# Extract right half of the depth visualization (pure depth part)
|
| 166 |
+
depth_image_path = view_data.get("depth_image", None)
|
| 167 |
+
depth_right_half = None
|
| 168 |
+
|
| 169 |
+
if depth_image_path and os.path.exists(depth_image_path):
|
| 170 |
+
try:
|
| 171 |
+
# Load the combined depth visualization image
|
| 172 |
+
depth_combined = cv2.imread(depth_image_path)
|
| 173 |
+
depth_combined = cv2.cvtColor(depth_combined, cv2.COLOR_BGR2RGB)
|
| 174 |
+
if depth_combined is not None:
|
| 175 |
+
height, width = depth_combined.shape[:2]
|
| 176 |
+
# Extract right half (depth visualization part)
|
| 177 |
+
depth_right_half = depth_combined[:, width // 2 :]
|
| 178 |
+
except Exception as e:
|
| 179 |
+
print(f"Error extracting depth right half: {e}")
|
| 180 |
+
|
| 181 |
+
return image, depth_right_half, []
|
| 182 |
+
|
| 183 |
+
def navigate_measure_view(
|
| 184 |
+
self,
|
| 185 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 186 |
+
current_selector_value: str,
|
| 187 |
+
direction: int,
|
| 188 |
+
) -> Tuple[str, Optional[np.ndarray], Optional[str], List]:
|
| 189 |
+
"""
|
| 190 |
+
Navigate measure view (direction: -1 for previous, +1 for next).
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
processed_data: Processed data dictionary
|
| 194 |
+
current_selector_value: Current selector value
|
| 195 |
+
direction: Direction to navigate (-1 for previous, +1 for next)
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
Tuple of (new_selector_value, measure_image, depth_image_path, measure_points)
|
| 199 |
+
"""
|
| 200 |
+
if processed_data is None or len(processed_data) == 0:
|
| 201 |
+
return "View 1", None, None, []
|
| 202 |
+
|
| 203 |
+
# Parse current view number
|
| 204 |
+
try:
|
| 205 |
+
current_view = int(current_selector_value.split()[1]) - 1
|
| 206 |
+
except: # noqa
|
| 207 |
+
current_view = 0
|
| 208 |
+
|
| 209 |
+
num_views = len(processed_data)
|
| 210 |
+
new_view = (current_view + direction) % num_views
|
| 211 |
+
|
| 212 |
+
new_selector_value = f"View {new_view + 1}"
|
| 213 |
+
measure_image, depth_right_half, measure_points = self.update_measure_view(
|
| 214 |
+
processed_data, new_view
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
return new_selector_value, measure_image, depth_right_half, measure_points
|
| 218 |
+
|
| 219 |
+
def populate_visualization_tabs(
|
| 220 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]]
|
| 221 |
+
) -> Tuple[Optional[str], Optional[np.ndarray], Optional[str], List]:
|
| 222 |
+
"""
|
| 223 |
+
Populate the depth and measure tabs with processed data.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
processed_data: Processed data dictionary
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
Tuple of (depth_vis, measure_img, depth_image_path, measure_points)
|
| 230 |
+
"""
|
| 231 |
+
if processed_data is None or len(processed_data) == 0:
|
| 232 |
+
return None, None, None, []
|
| 233 |
+
|
| 234 |
+
# Use update function to get depth visualization
|
| 235 |
+
depth_vis = self.update_depth_view(processed_data, 0)
|
| 236 |
+
measure_img, depth_right_half, _ = self.update_measure_view(processed_data, 0)
|
| 237 |
+
|
| 238 |
+
return depth_vis, measure_img, depth_right_half, []
|
| 239 |
+
|
| 240 |
+
def reset_measure(
|
| 241 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]]
|
| 242 |
+
) -> Tuple[Optional[np.ndarray], List, str]:
|
| 243 |
+
"""
|
| 244 |
+
Reset measure points.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
processed_data: Processed data dictionary
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
Tuple of (image, measure_points, text)
|
| 251 |
+
"""
|
| 252 |
+
if processed_data is None or len(processed_data) == 0:
|
| 253 |
+
return None, [], ""
|
| 254 |
+
|
| 255 |
+
# Return the first view image
|
| 256 |
+
first_view = list(processed_data.values())[0]
|
| 257 |
+
return first_view["image"], [], ""
|
| 258 |
+
|
| 259 |
+
def measure(
|
| 260 |
+
self,
|
| 261 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 262 |
+
measure_points: List,
|
| 263 |
+
current_view_selector: str,
|
| 264 |
+
event: gr.SelectData,
|
| 265 |
+
) -> List:
|
| 266 |
+
"""
|
| 267 |
+
Handle measurement on images.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
processed_data: Processed data dictionary
|
| 271 |
+
measure_points: List of current measure points
|
| 272 |
+
current_view_selector: Current view selector value
|
| 273 |
+
event: Gradio select event
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
List of [image, depth_right_half, measure_points, text]
|
| 277 |
+
"""
|
| 278 |
+
try:
|
| 279 |
+
print(f"Measure function called with selector: {current_view_selector}")
|
| 280 |
+
|
| 281 |
+
if processed_data is None or len(processed_data) == 0:
|
| 282 |
+
return [None, [], "No data available"]
|
| 283 |
+
|
| 284 |
+
# Use the currently selected view instead of always using the first view
|
| 285 |
+
try:
|
| 286 |
+
current_view_index = int(current_view_selector.split()[1]) - 1
|
| 287 |
+
except: # noqa
|
| 288 |
+
current_view_index = 0
|
| 289 |
+
|
| 290 |
+
print(f"Using view index: {current_view_index}")
|
| 291 |
+
|
| 292 |
+
# Get view data safely
|
| 293 |
+
if current_view_index < 0 or current_view_index >= len(processed_data):
|
| 294 |
+
current_view_index = 0
|
| 295 |
+
|
| 296 |
+
view_keys = list(processed_data.keys())
|
| 297 |
+
current_view = processed_data[view_keys[current_view_index]]
|
| 298 |
+
|
| 299 |
+
if current_view is None:
|
| 300 |
+
return [None, [], "No view data available"]
|
| 301 |
+
|
| 302 |
+
point2d = event.index[0], event.index[1]
|
| 303 |
+
print(f"Clicked point: {point2d}")
|
| 304 |
+
|
| 305 |
+
measure_points.append(point2d)
|
| 306 |
+
|
| 307 |
+
# Get image and depth visualization
|
| 308 |
+
image, depth_right_half, _ = self.update_measure_view(
|
| 309 |
+
processed_data, current_view_index
|
| 310 |
+
)
|
| 311 |
+
if image is None:
|
| 312 |
+
return [None, [], "No image available"]
|
| 313 |
+
|
| 314 |
+
image = image.copy()
|
| 315 |
+
|
| 316 |
+
# Ensure image is in uint8 format for proper cv2 operations
|
| 317 |
+
try:
|
| 318 |
+
if image.dtype != np.uint8:
|
| 319 |
+
if image.max() <= 1.0:
|
| 320 |
+
# Image is in [0, 1] range, convert to [0, 255]
|
| 321 |
+
image = (image * 255).astype(np.uint8)
|
| 322 |
+
else:
|
| 323 |
+
# Image is already in [0, 255] range
|
| 324 |
+
image = image.astype(np.uint8)
|
| 325 |
+
except Exception as e:
|
| 326 |
+
print(f"Image conversion error: {e}")
|
| 327 |
+
return [None, [], f"Image conversion error: {e}"]
|
| 328 |
+
|
| 329 |
+
# Draw circles for points
|
| 330 |
+
try:
|
| 331 |
+
for p in measure_points:
|
| 332 |
+
if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
|
| 333 |
+
image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
|
| 334 |
+
except Exception as e:
|
| 335 |
+
print(f"Drawing error: {e}")
|
| 336 |
+
return [None, [], f"Drawing error: {e}"]
|
| 337 |
+
|
| 338 |
+
# Get depth information from processed_data
|
| 339 |
+
depth_text = ""
|
| 340 |
+
try:
|
| 341 |
+
for i, p in enumerate(measure_points):
|
| 342 |
+
if (
|
| 343 |
+
current_view["depth"] is not None
|
| 344 |
+
and 0 <= p[1] < current_view["depth"].shape[0]
|
| 345 |
+
and 0 <= p[0] < current_view["depth"].shape[1]
|
| 346 |
+
):
|
| 347 |
+
d = current_view["depth"][p[1], p[0]]
|
| 348 |
+
depth_text += f"- **P{i + 1} depth: {d:.2f}m**\n"
|
| 349 |
+
else:
|
| 350 |
+
depth_text += f"- **P{i + 1}: Click position ({p[0]}, {p[1]}) - No depth information**\n" # noqa: E501
|
| 351 |
+
except Exception as e:
|
| 352 |
+
print(f"Depth text error: {e}")
|
| 353 |
+
depth_text = f"Error computing depth: {e}\n"
|
| 354 |
+
|
| 355 |
+
if len(measure_points) == 2:
|
| 356 |
+
try:
|
| 357 |
+
point1, point2 = measure_points
|
| 358 |
+
# Draw line
|
| 359 |
+
if (
|
| 360 |
+
0 <= point1[0] < image.shape[1]
|
| 361 |
+
and 0 <= point1[1] < image.shape[0]
|
| 362 |
+
and 0 <= point2[0] < image.shape[1]
|
| 363 |
+
and 0 <= point2[1] < image.shape[0]
|
| 364 |
+
):
|
| 365 |
+
image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
|
| 366 |
+
|
| 367 |
+
# Compute 3D distance using depth information and camera intrinsics
|
| 368 |
+
distance_text = "- **Distance: Unable to calculate 3D distance**"
|
| 369 |
+
if (
|
| 370 |
+
current_view["depth"] is not None
|
| 371 |
+
and 0 <= point1[1] < current_view["depth"].shape[0]
|
| 372 |
+
and 0 <= point1[0] < current_view["depth"].shape[1]
|
| 373 |
+
and 0 <= point2[1] < current_view["depth"].shape[0]
|
| 374 |
+
and 0 <= point2[0] < current_view["depth"].shape[1]
|
| 375 |
+
):
|
| 376 |
+
try:
|
| 377 |
+
# Get depth values at the two points
|
| 378 |
+
d1 = current_view["depth"][point1[1], point1[0]]
|
| 379 |
+
d2 = current_view["depth"][point2[1], point2[0]]
|
| 380 |
+
|
| 381 |
+
# Convert 2D pixel coordinates to 3D world coordinates
|
| 382 |
+
if current_view["intrinsics"] is not None:
|
| 383 |
+
# Get camera intrinsics
|
| 384 |
+
K = current_view["intrinsics"] # 3x3 intrinsic matrix
|
| 385 |
+
fx, fy = K[0, 0], K[1, 1] # focal lengths
|
| 386 |
+
cx, cy = K[0, 2], K[1, 2] # principal point
|
| 387 |
+
|
| 388 |
+
# Convert pixel coordinates to normalized camera coordinates
|
| 389 |
+
# Point 1: (u1, v1) -> (x1, y1, z1)
|
| 390 |
+
u1, v1 = point1[0], point1[1]
|
| 391 |
+
x1 = (u1 - cx) * d1 / fx
|
| 392 |
+
y1 = (v1 - cy) * d1 / fy
|
| 393 |
+
z1 = d1
|
| 394 |
+
|
| 395 |
+
# Point 2: (u2, v2) -> (x2, y2, z2)
|
| 396 |
+
u2, v2 = point2[0], point2[1]
|
| 397 |
+
x2 = (u2 - cx) * d2 / fx
|
| 398 |
+
y2 = (v2 - cy) * d2 / fy
|
| 399 |
+
z2 = d2
|
| 400 |
+
|
| 401 |
+
# Calculate 3D Euclidean distance
|
| 402 |
+
p1_3d = np.array([x1, y1, z1])
|
| 403 |
+
p2_3d = np.array([x2, y2, z2])
|
| 404 |
+
distance_3d = np.linalg.norm(p1_3d - p2_3d)
|
| 405 |
+
|
| 406 |
+
distance_text = f"- **Distance: {distance_3d:.2f}m**"
|
| 407 |
+
else:
|
| 408 |
+
# Fallback to simplified calculation if no intrinsics
|
| 409 |
+
pixel_distance = np.sqrt(
|
| 410 |
+
(point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
|
| 411 |
+
)
|
| 412 |
+
avg_depth = (d1 + d2) / 2
|
| 413 |
+
scale_factor = avg_depth / 1000 # Rough scaling factor
|
| 414 |
+
estimated_3d_distance = pixel_distance * scale_factor
|
| 415 |
+
distance_text = f"- **Distance: {estimated_3d_distance:.2f}m (estimated, no intrinsics)**" # noqa: E501
|
| 416 |
+
|
| 417 |
+
except Exception as e:
|
| 418 |
+
print(f"Distance computation error: {e}")
|
| 419 |
+
distance_text = f"- **Distance computation error: {e}**"
|
| 420 |
+
|
| 421 |
+
measure_points = []
|
| 422 |
+
text = depth_text + distance_text
|
| 423 |
+
print(f"Measurement complete: {text}")
|
| 424 |
+
return [image, depth_right_half, measure_points, text]
|
| 425 |
+
except Exception as e:
|
| 426 |
+
print(f"Final measurement error: {e}")
|
| 427 |
+
return [None, [], f"Measurement error: {e}"]
|
| 428 |
+
else:
|
| 429 |
+
print(f"Single point measurement: {depth_text}")
|
| 430 |
+
return [image, depth_right_half, measure_points, depth_text]
|
| 431 |
+
|
| 432 |
+
except Exception as e:
|
| 433 |
+
print(f"Overall measure function error: {e}")
|
| 434 |
+
return [None, [], f"Measure function error: {e}"]
|
Depth-Anything-3/src/depth_anything_3/bench/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Depth Anything 3 Benchmark Evaluation Module.
|
| 17 |
+
|
| 18 |
+
This module provides tools for evaluating DepthAnything3 model on various benchmark datasets.
|
| 19 |
+
Currently supported datasets:
|
| 20 |
+
- DTU (3D Reconstruction)
|
| 21 |
+
- DTU-64 (Pose Evaluation Only)
|
| 22 |
+
- ETH3D (3D Reconstruction)
|
| 23 |
+
- 7Scenes (3D Reconstruction)
|
| 24 |
+
- ScanNet++ (3D Reconstruction)
|
| 25 |
+
- HiRoom (3D Reconstruction)
|
| 26 |
+
|
| 27 |
+
Supported evaluation modes:
|
| 28 |
+
- pose: Camera pose estimation evaluation
|
| 29 |
+
- recon_unposed: 3D reconstruction with predicted poses
|
| 30 |
+
- recon_posed: 3D reconstruction with ground truth poses
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
from depth_anything_3.bench.registries import MV_REGISTRY, MONO_REGISTRY
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def __getattr__(name):
|
| 37 |
+
"""Lazy import to avoid circular import when running as __main__."""
|
| 38 |
+
if name == "Evaluator":
|
| 39 |
+
from depth_anything_3.bench.evaluator import Evaluator
|
| 40 |
+
return Evaluator
|
| 41 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
__all__ = ["Evaluator", "MV_REGISTRY", "MONO_REGISTRY"]
|
| 45 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/configs/eval_bench.yaml
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DepthAnything3 Benchmark Evaluation Configuration
|
| 2 |
+
#
|
| 3 |
+
# This config can be loaded and overridden via command line.
|
| 4 |
+
# Example: python -m depth_anything_3.bench.evaluator --model /path/to/model --work_dir /path/to/workspace
|
| 5 |
+
#
|
| 6 |
+
# See depth_anything_3.cfg for config utility functions.
|
| 7 |
+
|
| 8 |
+
# ==============================================================================
|
| 9 |
+
# Model Configuration
|
| 10 |
+
# ==============================================================================
|
| 11 |
+
model:
|
| 12 |
+
# Path to model checkpoint or HuggingFace model ID
|
| 13 |
+
path: depth-anything/DA3-GIANT
|
| 14 |
+
|
| 15 |
+
# ==============================================================================
|
| 16 |
+
# Workspace Configuration
|
| 17 |
+
# ==============================================================================
|
| 18 |
+
workspace:
|
| 19 |
+
# Working directory for outputs (model results, metrics, etc.)
|
| 20 |
+
work_dir: ./workspace/evaluation
|
| 21 |
+
|
| 22 |
+
# ==============================================================================
|
| 23 |
+
# Evaluation Configuration
|
| 24 |
+
# ==============================================================================
|
| 25 |
+
eval:
|
| 26 |
+
# Datasets to evaluate
|
| 27 |
+
# Options: dtu, dtu64, eth3d, 7scenes (sevenscenes), scannetpp, hiroom
|
| 28 |
+
datasets:
|
| 29 |
+
- eth3d
|
| 30 |
+
- 7scenes
|
| 31 |
+
- scannetpp
|
| 32 |
+
- hiroom
|
| 33 |
+
- dtu
|
| 34 |
+
- dtu64
|
| 35 |
+
|
| 36 |
+
# Evaluation modes
|
| 37 |
+
# Options: pose, recon_unposed, recon_posed, view_syn
|
| 38 |
+
modes:
|
| 39 |
+
- pose
|
| 40 |
+
- recon_unposed
|
| 41 |
+
- recon_posed
|
| 42 |
+
|
| 43 |
+
# Reference view selection strategy for inference
|
| 44 |
+
# Options: first, saddle_balanced, auto, mid
|
| 45 |
+
ref_view_strategy: "first"
|
| 46 |
+
|
| 47 |
+
# Specific scenes to evaluate (null = all scenes)
|
| 48 |
+
# Example: [courtyard, relief] for eth3d
|
| 49 |
+
scenes: null
|
| 50 |
+
|
| 51 |
+
# Maximum number of frames per scene (for sampling)
|
| 52 |
+
# If a scene has more frames, randomly sample to this limit.
|
| 53 |
+
# Set to -1 to disable sampling.
|
| 54 |
+
max_frames: 100
|
| 55 |
+
|
| 56 |
+
# Only run evaluation (skip inference)
|
| 57 |
+
eval_only: false
|
| 58 |
+
|
| 59 |
+
# Only print saved metrics (skip inference and evaluation)
|
| 60 |
+
print_only: false
|
| 61 |
+
|
| 62 |
+
# ==============================================================================
|
| 63 |
+
# Inference Configuration
|
| 64 |
+
# ==============================================================================
|
| 65 |
+
inference:
|
| 66 |
+
# Number of parallel workers for TSDF fusion
|
| 67 |
+
num_fusion_workers: 4
|
| 68 |
+
|
| 69 |
+
# Enable debug mode with verbose output
|
| 70 |
+
debug: false
|
| 71 |
+
|
| 72 |
+
# ==============================================================================
|
| 73 |
+
# Preset Configurations
|
| 74 |
+
# ==============================================================================
|
| 75 |
+
# These can be activated via command line: --preset full_eval
|
| 76 |
+
|
| 77 |
+
presets:
|
| 78 |
+
# Full evaluation on all 6 datasets
|
| 79 |
+
full_eval:
|
| 80 |
+
datasets: [eth3d, 7scenes, scannetpp, hiroom, dtu, dtu64]
|
| 81 |
+
modes: [pose, recon_unposed, recon_posed]
|
| 82 |
+
|
| 83 |
+
# Pose-only evaluation
|
| 84 |
+
pose_only:
|
| 85 |
+
datasets: [eth3d, 7scenes, scannetpp, hiroom, dtu64]
|
| 86 |
+
modes: [pose]
|
| 87 |
+
|
| 88 |
+
# Reconstruction-only evaluation (5 datasets, excluding dtu64)
|
| 89 |
+
recon_only:
|
| 90 |
+
datasets: [eth3d, 7scenes, scannetpp, hiroom, dtu]
|
| 91 |
+
modes: [recon_unposed, recon_posed]
|
| 92 |
+
|
| 93 |
+
# Quick test (single scene per dataset)
|
| 94 |
+
quick_test:
|
| 95 |
+
datasets: [eth3d]
|
| 96 |
+
modes: [pose, recon_unposed]
|
| 97 |
+
scenes: [courtyard]
|
| 98 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/dataset.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Base dataset class for benchmark evaluation.
|
| 17 |
+
|
| 18 |
+
All dataset implementations should inherit from this class and implement
|
| 19 |
+
the required abstract methods.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import time
|
| 24 |
+
from abc import abstractmethod
|
| 25 |
+
from typing import Dict as TDict
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
from addict import Dict
|
| 30 |
+
|
| 31 |
+
from depth_anything_3.bench.utils import compute_pose
|
| 32 |
+
from depth_anything_3.utils.geometry import as_homogeneous
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _wait_for_file_ready(path: str, timeout: float = 3.0, interval: float = 0.2) -> None:
|
| 36 |
+
"""Wait until file size stabilizes for 2 consecutive checks."""
|
| 37 |
+
last_size = -1
|
| 38 |
+
stable_count = 0
|
| 39 |
+
start = time.time()
|
| 40 |
+
while time.time() - start < timeout:
|
| 41 |
+
time.sleep(interval)
|
| 42 |
+
size = os.path.getsize(path)
|
| 43 |
+
if size == last_size and size > 0:
|
| 44 |
+
stable_count += 1
|
| 45 |
+
if stable_count >= 2: # Need 2 consecutive stable checks
|
| 46 |
+
return
|
| 47 |
+
else:
|
| 48 |
+
stable_count = 0
|
| 49 |
+
last_size = size
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Dataset:
|
| 53 |
+
"""
|
| 54 |
+
Base class for all benchmark datasets.
|
| 55 |
+
|
| 56 |
+
Subclasses must implement:
|
| 57 |
+
- SCENES: List of scene identifiers
|
| 58 |
+
- data_root: Path to dataset root
|
| 59 |
+
- get_data(scene): Return scene data (images, intrinsics, extrinsics, etc.)
|
| 60 |
+
- eval3d(scene, fuse_path): Evaluate 3D reconstruction
|
| 61 |
+
- fuse3d(scene, result_path, fuse_path, mode): Fuse depth maps into point cloud
|
| 62 |
+
|
| 63 |
+
Optional overrides:
|
| 64 |
+
- eval_pose(scene, result_path): Evaluate pose estimation (default provided)
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
# Subclasses should define these
|
| 68 |
+
SCENES: list = []
|
| 69 |
+
data_root: str = ""
|
| 70 |
+
|
| 71 |
+
def __init__(self):
|
| 72 |
+
pass
|
| 73 |
+
|
| 74 |
+
def eval_pose(self, scene: str, result_path: str) -> TDict[str, float]:
|
| 75 |
+
"""
|
| 76 |
+
Evaluate camera pose estimation accuracy.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
scene: Scene identifier
|
| 80 |
+
result_path: Path to .npz file containing predicted extrinsics
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Dict with pose metrics (auc30, auc15, auc05, auc03)
|
| 84 |
+
"""
|
| 85 |
+
_wait_for_file_ready(result_path)
|
| 86 |
+
pred = np.load(result_path)
|
| 87 |
+
gt = self.get_data(scene)
|
| 88 |
+
return compute_pose(
|
| 89 |
+
torch.from_numpy(as_homogeneous(pred["extrinsics"])),
|
| 90 |
+
torch.from_numpy(as_homogeneous(gt["extrinsics"])),
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
@abstractmethod
|
| 94 |
+
def get_data(self, scene: str) -> Dict:
|
| 95 |
+
"""
|
| 96 |
+
Get scene data including images, camera parameters, and auxiliary info.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
scene: Scene identifier
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
Dict with:
|
| 103 |
+
- image_files: List[str] - paths to images
|
| 104 |
+
- extrinsics: np.ndarray [N, 4, 4] - camera extrinsics (world-to-camera)
|
| 105 |
+
- intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
|
| 106 |
+
- aux: Dict - auxiliary data (masks, GT paths, etc.)
|
| 107 |
+
"""
|
| 108 |
+
raise NotImplementedError
|
| 109 |
+
|
| 110 |
+
@abstractmethod
|
| 111 |
+
def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
|
| 112 |
+
"""
|
| 113 |
+
Evaluate 3D reconstruction quality against ground truth.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
scene: Scene identifier
|
| 117 |
+
fuse_path: Path to fused point cloud (.ply)
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
Dict with reconstruction metrics (e.g., acc, comp, overall)
|
| 121 |
+
"""
|
| 122 |
+
raise NotImplementedError
|
| 123 |
+
|
| 124 |
+
@abstractmethod
|
| 125 |
+
def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
|
| 126 |
+
"""
|
| 127 |
+
Fuse per-view depth maps into a single point cloud.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
scene: Scene identifier
|
| 131 |
+
result_path: Path to .npz file with predicted depths and poses
|
| 132 |
+
fuse_path: Output path for fused point cloud (.ply)
|
| 133 |
+
mode: Fusion mode ("recon_unposed" or "recon_posed")
|
| 134 |
+
"""
|
| 135 |
+
raise NotImplementedError
|
| 136 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/datasets/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Benchmark dataset implementations.
|
| 17 |
+
|
| 18 |
+
Datasets are auto-registered via decorators when imported.
|
| 19 |
+
Add new dataset files here and they will be automatically discovered.
|
| 20 |
+
"""
|
| 21 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/datasets/dtu.py
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
DTU Benchmark dataset implementation.
|
| 17 |
+
|
| 18 |
+
DTU is a multi-view stereo benchmark for 3D reconstruction evaluation.
|
| 19 |
+
Reference: https://roboimagedata.compute.dtu.dk/
|
| 20 |
+
|
| 21 |
+
Note: DepthAnything3 was never trained on any images from DTU.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import glob
|
| 25 |
+
import os
|
| 26 |
+
from typing import Dict as TDict, List
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import open3d as o3d
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
from addict import Dict
|
| 33 |
+
from PIL import Image
|
| 34 |
+
from plyfile import PlyData
|
| 35 |
+
from scipy.io import loadmat
|
| 36 |
+
from sklearn import neighbors as skln
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
|
| 39 |
+
from depth_anything_3.bench.dataset import Dataset
|
| 40 |
+
from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
|
| 41 |
+
from depth_anything_3.utils.constants import (
|
| 42 |
+
DTU_DIST_THRESH,
|
| 43 |
+
DTU_EVAL_DATA_ROOT,
|
| 44 |
+
DTU_MAX_POINTS,
|
| 45 |
+
DTU_NUM_CONSIST,
|
| 46 |
+
DTU_SCENES,
|
| 47 |
+
)
|
| 48 |
+
from depth_anything_3.utils.pose_align import align_poses_umeyama
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@MV_REGISTRY.register(name="dtu")
|
| 52 |
+
@MONO_REGISTRY.register(name="dtu")
|
| 53 |
+
class DTU(Dataset):
|
| 54 |
+
"""
|
| 55 |
+
DTU Benchmark dataset wrapper for DepthAnything3 evaluation.
|
| 56 |
+
|
| 57 |
+
Supports:
|
| 58 |
+
- Camera pose estimation evaluation (AUC metrics)
|
| 59 |
+
- 3D reconstruction evaluation (accuracy, completeness, overall)
|
| 60 |
+
- Point cloud fusion from depth maps
|
| 61 |
+
|
| 62 |
+
The dataset uses MVSNet evaluation protocol:
|
| 63 |
+
https://drive.google.com/file/d/1rX0EXlUL4prRxrRu2DgLJv2j7-tpUD4D/view
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
data_root = DTU_EVAL_DATA_ROOT
|
| 67 |
+
SCENES = DTU_SCENES
|
| 68 |
+
|
| 69 |
+
# Evaluation/triangulation hyperparameters from constants
|
| 70 |
+
dist_thresh = DTU_DIST_THRESH
|
| 71 |
+
num_consist = DTU_NUM_CONSIST
|
| 72 |
+
|
| 73 |
+
# ------------------------------
|
| 74 |
+
# Public API
|
| 75 |
+
# ------------------------------
|
| 76 |
+
|
| 77 |
+
def read_cam_file(self, filename: str) -> tuple:
|
| 78 |
+
"""
|
| 79 |
+
Read DTU camera file containing extrinsics and intrinsics.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
filename: Path to camera text file
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Tuple of (intrinsics [3,3], extrinsics [4,4])
|
| 86 |
+
"""
|
| 87 |
+
with open(filename) as f:
|
| 88 |
+
lines = [line.rstrip() for line in f.readlines()]
|
| 89 |
+
extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ").reshape((4, 4))
|
| 90 |
+
intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ").reshape((3, 3))
|
| 91 |
+
return intrinsics, extrinsics
|
| 92 |
+
|
| 93 |
+
def get_data(self, scene: str) -> Dict:
|
| 94 |
+
"""
|
| 95 |
+
Collect per-view image paths, intrinsics/extrinsics, and GT masks.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
scene: Scene identifier (e.g., "scan1")
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Dict with:
|
| 102 |
+
- image_files: List[str] - paths to images
|
| 103 |
+
- extrinsics: np.ndarray [N, 4, 4]
|
| 104 |
+
- intrinsics: np.ndarray [N, 3, 3]
|
| 105 |
+
- aux.mask_files: List[str] - paths to depth masks
|
| 106 |
+
"""
|
| 107 |
+
rgb_folder = os.path.join(self.data_root, "Rectified", scene)
|
| 108 |
+
camera_folder = os.path.join(self.data_root, "Cameras")
|
| 109 |
+
|
| 110 |
+
files = sorted(glob.glob(os.path.join(rgb_folder, "*.png")))
|
| 111 |
+
# Reorder: place index 33 first (reference view convention)
|
| 112 |
+
files = [files[33]] + files[:33] + files[34:]
|
| 113 |
+
|
| 114 |
+
out = Dict(
|
| 115 |
+
{
|
| 116 |
+
"image_files": files,
|
| 117 |
+
"extrinsics": [],
|
| 118 |
+
"intrinsics": [],
|
| 119 |
+
"aux": Dict({"mask_files": []}),
|
| 120 |
+
}
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
for rgb_file in files:
|
| 124 |
+
basename = os.path.basename(rgb_file)
|
| 125 |
+
file_idx = basename.split("_")[1]
|
| 126 |
+
cam_idx = depth_idx = int(file_idx) - 1
|
| 127 |
+
|
| 128 |
+
mask_file = self._depth_mask_path(scene, depth_idx)
|
| 129 |
+
proj_mat_filename = os.path.join(camera_folder, f"{cam_idx:0>8}_cam.txt")
|
| 130 |
+
|
| 131 |
+
ixt, ext = self.read_cam_file(proj_mat_filename)
|
| 132 |
+
out.extrinsics.append(ext)
|
| 133 |
+
out.intrinsics.append(ixt)
|
| 134 |
+
out.aux.mask_files.append(mask_file)
|
| 135 |
+
|
| 136 |
+
out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
|
| 137 |
+
out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
|
| 138 |
+
return out
|
| 139 |
+
|
| 140 |
+
def get_3dgtpath(self, scene: str) -> str:
|
| 141 |
+
"""Get path to ground truth point cloud for a scene."""
|
| 142 |
+
scene_id = int(scene[4:])
|
| 143 |
+
return os.path.join(self.data_root, f"Points/stl/stl{scene_id:03}_total.ply")
|
| 144 |
+
|
| 145 |
+
def eval3d(self, scene: str, fuse_path: str, use_gpu: bool = False) -> TDict[str, float]:
|
| 146 |
+
"""
|
| 147 |
+
Evaluate fused point cloud against DTU GT with ObsMask/Plane.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
scene: Scene identifier
|
| 151 |
+
fuse_path: Path to fused point cloud
|
| 152 |
+
use_gpu: If True, use GPU-accelerated distance computation (faster but may have minor numerical differences)
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
Dict with metrics: {"comp": float, "acc": float, "overall": float}
|
| 156 |
+
"""
|
| 157 |
+
scene_id = int(scene[4:])
|
| 158 |
+
gt_ply = os.path.join(self.data_root, f"Points/stl/stl{scene_id:03}_total.ply")
|
| 159 |
+
mask_file = os.path.join(
|
| 160 |
+
self.data_root, f"SampleSet/mvs_data/ObsMask/ObsMask{scene_id}_10.mat"
|
| 161 |
+
)
|
| 162 |
+
plane_file = os.path.join(
|
| 163 |
+
self.data_root, f"SampleSet/mvs_data/ObsMask/Plane{scene_id}.mat"
|
| 164 |
+
)
|
| 165 |
+
result = self._evaluate_reconstruction(
|
| 166 |
+
scene, fuse_path, gt_ply, mask_file, plane_file, use_gpu=use_gpu
|
| 167 |
+
)
|
| 168 |
+
return {"comp": result[0], "acc": result[1], "overall": result[2]}
|
| 169 |
+
|
| 170 |
+
def load_masks(self, mask_files: List[str]) -> np.ndarray:
|
| 171 |
+
"""
|
| 172 |
+
Load DTU depth validity masks.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
mask_files: List of paths to mask images
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
Boolean array [N, H, W] indicating valid depth regions
|
| 179 |
+
"""
|
| 180 |
+
masks = []
|
| 181 |
+
for mask_file in mask_files:
|
| 182 |
+
mask = Image.open(mask_file)
|
| 183 |
+
mask = np.array(mask, dtype=np.float32)
|
| 184 |
+
masks.append(mask > 10)
|
| 185 |
+
return np.asarray(masks)
|
| 186 |
+
|
| 187 |
+
def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
|
| 188 |
+
"""
|
| 189 |
+
Fuse per-view depths into a point cloud and save to PLY.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
scene: Scene identifier (e.g., "scan114")
|
| 193 |
+
result_path: Path to npz file containing predicted depths/poses
|
| 194 |
+
fuse_path: Output path for fused point cloud (.ply)
|
| 195 |
+
mode: "recon_unposed" or "recon_posed"
|
| 196 |
+
"""
|
| 197 |
+
gt_data = self.get_data(scene)
|
| 198 |
+
pred_data = Dict({k: v for k, v in np.load(result_path).items()})
|
| 199 |
+
masks = self.load_masks(gt_data.aux.mask_files)
|
| 200 |
+
|
| 201 |
+
if mode == "recon_unposed":
|
| 202 |
+
depths, intrinsics, extrinsics = self._prep_unposed(pred_data, gt_data, masks)
|
| 203 |
+
elif mode == "recon_posed":
|
| 204 |
+
depths, intrinsics, extrinsics = self._prep_posed(pred_data, gt_data, masks)
|
| 205 |
+
else:
|
| 206 |
+
raise ValueError(f"Invalid mode: {mode}")
|
| 207 |
+
|
| 208 |
+
proj_mat = self._build_proj_mats(intrinsics, extrinsics)
|
| 209 |
+
|
| 210 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 211 |
+
dtype = torch.float32
|
| 212 |
+
depths_t = torch.from_numpy(depths).to(device=device, dtype=dtype).unsqueeze(1)
|
| 213 |
+
proj_t = torch.from_numpy(proj_mat).to(device=device, dtype=dtype)
|
| 214 |
+
height, width = depths_t.shape[-2:]
|
| 215 |
+
|
| 216 |
+
points: List[np.ndarray] = []
|
| 217 |
+
for idx in range(len(gt_data.image_files)):
|
| 218 |
+
if mode == "recon_unposed":
|
| 219 |
+
# Simple unfiltered back-projection per frame
|
| 220 |
+
cur_p_pcd = self._generate_points_from_depth(
|
| 221 |
+
depths_t[idx : idx + 1], proj_t[idx : idx + 1]
|
| 222 |
+
)
|
| 223 |
+
mask = (depths_t[idx : idx + 1] > 0.001).squeeze()
|
| 224 |
+
cur_p_pcd = cur_p_pcd[:, :, mask]
|
| 225 |
+
no_filter_pc = cur_p_pcd.squeeze(0).permute(1, 0).cpu().numpy()
|
| 226 |
+
points.append(no_filter_pc)
|
| 227 |
+
else: # recon_posed
|
| 228 |
+
final_pc = self._fuse_consistent_points(depths_t, proj_t, idx, height, width)
|
| 229 |
+
points.append(final_pc)
|
| 230 |
+
|
| 231 |
+
# Concatenate and optionally downsample to hard cap
|
| 232 |
+
points_np = np.concatenate(points, axis=0)
|
| 233 |
+
points_np = self._cap_points(points_np, max_points=DTU_MAX_POINTS)
|
| 234 |
+
|
| 235 |
+
os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
|
| 236 |
+
pcd = o3d.geometry.PointCloud()
|
| 237 |
+
pcd.points = o3d.utility.Vector3dVector(points_np)
|
| 238 |
+
o3d.io.write_point_cloud(fuse_path, pcd)
|
| 239 |
+
|
| 240 |
+
# ------------------------------
|
| 241 |
+
# Geometry helpers
|
| 242 |
+
# ------------------------------
|
| 243 |
+
|
| 244 |
+
def _generate_points_from_depth(
|
| 245 |
+
self, depth: torch.Tensor, proj: torch.Tensor
|
| 246 |
+
) -> torch.Tensor:
|
| 247 |
+
"""
|
| 248 |
+
Back-project depth map into 3D world coordinates.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
depth: Depth tensor [B, 1, H, W]
|
| 252 |
+
proj: Projection matrix [B, 4, 4] = [[K@R, K@t], [0,0,0,1]]
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
Point cloud tensor [B, 3, H, W]
|
| 256 |
+
"""
|
| 257 |
+
batch, height, width = depth.shape[0], depth.shape[2], depth.shape[3]
|
| 258 |
+
inv_proj = torch.inverse(proj)
|
| 259 |
+
rot = inv_proj[:, :3, :3]
|
| 260 |
+
trans = inv_proj[:, :3, 3:4]
|
| 261 |
+
|
| 262 |
+
y, x = torch.meshgrid(
|
| 263 |
+
[
|
| 264 |
+
torch.arange(0, height, dtype=torch.float32, device=depth.device),
|
| 265 |
+
torch.arange(0, width, dtype=torch.float32, device=depth.device),
|
| 266 |
+
],
|
| 267 |
+
indexing="ij",
|
| 268 |
+
)
|
| 269 |
+
y, x = y.contiguous(), x.contiguous()
|
| 270 |
+
y, x = y.view(height * width), x.view(height * width)
|
| 271 |
+
xyz = torch.stack((x, y, torch.ones_like(x)))
|
| 272 |
+
xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)
|
| 273 |
+
rot_xyz = torch.matmul(rot, xyz)
|
| 274 |
+
rot_depth_xyz = rot_xyz * depth.view(batch, 1, -1)
|
| 275 |
+
proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1)
|
| 276 |
+
return proj_xyz.view(batch, 3, height, width)
|
| 277 |
+
|
| 278 |
+
def _homo_warping(
|
| 279 |
+
self,
|
| 280 |
+
src_fea: torch.Tensor,
|
| 281 |
+
src_proj: torch.Tensor,
|
| 282 |
+
ref_proj: torch.Tensor,
|
| 283 |
+
depth_values: torch.Tensor,
|
| 284 |
+
) -> torch.Tensor:
|
| 285 |
+
"""
|
| 286 |
+
Homography warping for multi-view consistency checking.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
src_fea: Source features [B, C, H, W]
|
| 290 |
+
src_proj: Source projection [B, 4, 4]
|
| 291 |
+
ref_proj: Reference projection [B, 4, 4]
|
| 292 |
+
depth_values: Depth values [B, Ndepth] or [B, Ndepth, H, W]
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
Warped features [B, C, H, W]
|
| 296 |
+
"""
|
| 297 |
+
batch, channels = src_fea.shape[0], src_fea.shape[1]
|
| 298 |
+
height, width = src_fea.shape[2], src_fea.shape[3]
|
| 299 |
+
|
| 300 |
+
with torch.no_grad():
|
| 301 |
+
proj = torch.matmul(src_proj, torch.inverse(ref_proj))
|
| 302 |
+
rot = proj[:, :3, :3]
|
| 303 |
+
trans = proj[:, :3, 3:4]
|
| 304 |
+
|
| 305 |
+
y, x = torch.meshgrid(
|
| 306 |
+
[
|
| 307 |
+
torch.arange(0, height, dtype=torch.float32, device=src_fea.device),
|
| 308 |
+
torch.arange(0, width, dtype=torch.float32, device=src_fea.device),
|
| 309 |
+
],
|
| 310 |
+
indexing="ij",
|
| 311 |
+
)
|
| 312 |
+
y, x = y.contiguous(), x.contiguous()
|
| 313 |
+
y, x = y.view(height * width), x.view(height * width)
|
| 314 |
+
xyz = torch.stack((x, y, torch.ones_like(x)))
|
| 315 |
+
xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)
|
| 316 |
+
rot_xyz = torch.matmul(rot, xyz)
|
| 317 |
+
|
| 318 |
+
rot_depth_xyz = rot_xyz.unsqueeze(2) * depth_values.view(-1, 1, 1, height * width)
|
| 319 |
+
proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)
|
| 320 |
+
proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]
|
| 321 |
+
proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
|
| 322 |
+
proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
|
| 323 |
+
grid = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)
|
| 324 |
+
|
| 325 |
+
warped_src_fea = F.grid_sample(
|
| 326 |
+
src_fea,
|
| 327 |
+
grid.view(batch, height, width, 2),
|
| 328 |
+
mode="bilinear",
|
| 329 |
+
padding_mode="zeros",
|
| 330 |
+
align_corners=True,
|
| 331 |
+
)
|
| 332 |
+
return warped_src_fea.view(batch, channels, height, width)
|
| 333 |
+
|
| 334 |
+
def _filter_depth(
|
| 335 |
+
self,
|
| 336 |
+
ref_depth: torch.Tensor,
|
| 337 |
+
src_depths: torch.Tensor,
|
| 338 |
+
ref_proj: torch.Tensor,
|
| 339 |
+
src_projs: torch.Tensor,
|
| 340 |
+
) -> tuple:
|
| 341 |
+
"""
|
| 342 |
+
Compute geometric consistency between reference and source depths.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
ref_depth: Reference depth [1, 1, H, W]
|
| 346 |
+
src_depths: Source depths [B, 1, H, W]
|
| 347 |
+
ref_proj: Reference projection [1, 4, 4]
|
| 348 |
+
src_projs: Source projections [B, 4, 4]
|
| 349 |
+
|
| 350 |
+
Returns:
|
| 351 |
+
Tuple of (ref_pc, aligned_pcs, dist)
|
| 352 |
+
"""
|
| 353 |
+
ref_pc = self._generate_points_from_depth(ref_depth, ref_proj)
|
| 354 |
+
src_pcs = self._generate_points_from_depth(src_depths, src_projs)
|
| 355 |
+
aligned_pcs = self._homo_warping(src_pcs, src_projs, ref_proj, ref_depth)
|
| 356 |
+
x_2 = (ref_pc[:, 0] - aligned_pcs[:, 0]) ** 2
|
| 357 |
+
y_2 = (ref_pc[:, 1] - aligned_pcs[:, 1]) ** 2
|
| 358 |
+
z_2 = (ref_pc[:, 2] - aligned_pcs[:, 2]) ** 2
|
| 359 |
+
dist = torch.sqrt(x_2 + y_2 + z_2).unsqueeze(1)
|
| 360 |
+
return ref_pc, aligned_pcs, dist
|
| 361 |
+
|
| 362 |
+
def _extract_points(
|
| 363 |
+
self, pc: torch.Tensor, mask: torch.Tensor, rgb: np.ndarray = None
|
| 364 |
+
) -> np.ndarray:
|
| 365 |
+
"""Extract masked points from a dense grid."""
|
| 366 |
+
pc = pc.cpu().numpy()
|
| 367 |
+
mask = mask.cpu().numpy().reshape(-1)
|
| 368 |
+
pc = pc.reshape(-1, 3)
|
| 369 |
+
points = pc[np.where(mask)]
|
| 370 |
+
if rgb is not None:
|
| 371 |
+
rgb = rgb.reshape(-1, 3)
|
| 372 |
+
colors = rgb[np.where(mask)]
|
| 373 |
+
return np.concatenate([points, colors], axis=1)
|
| 374 |
+
return points
|
| 375 |
+
|
| 376 |
+
# ------------------------------
|
| 377 |
+
# 3D Reconstruction Evaluation
|
| 378 |
+
# ------------------------------
|
| 379 |
+
|
| 380 |
+
def _evaluate_reconstruction(
|
| 381 |
+
self,
|
| 382 |
+
scanid: str,
|
| 383 |
+
pred_ply: str,
|
| 384 |
+
gt_ply: str,
|
| 385 |
+
mask_file: str,
|
| 386 |
+
plane_file: str,
|
| 387 |
+
down_dense: float = 0.2,
|
| 388 |
+
patch: int = 60,
|
| 389 |
+
max_dist: int = 20,
|
| 390 |
+
use_gpu: bool = False,
|
| 391 |
+
) -> tuple:
|
| 392 |
+
"""
|
| 393 |
+
Compute accuracy, completeness, and overall metrics for one scan.
|
| 394 |
+
|
| 395 |
+
Args:
|
| 396 |
+
scanid: Scan identifier
|
| 397 |
+
pred_ply: Predicted point cloud path or array
|
| 398 |
+
gt_ply: Ground truth point cloud path or array
|
| 399 |
+
mask_file: ObsMask file path
|
| 400 |
+
plane_file: Plane file path
|
| 401 |
+
down_dense: Downsample density (min distance between points)
|
| 402 |
+
patch: Patch size for boundary
|
| 403 |
+
max_dist: Outlier threshold in mm
|
| 404 |
+
use_gpu: If True, use GPU-accelerated distance computation
|
| 405 |
+
|
| 406 |
+
Returns:
|
| 407 |
+
Tuple of (mean_d2s, mean_s2d, overall)
|
| 408 |
+
"""
|
| 409 |
+
thresh = down_dense
|
| 410 |
+
|
| 411 |
+
# Load and downsample predicted point cloud
|
| 412 |
+
data_pcd = self._read_ply(pred_ply) if isinstance(pred_ply, str) else pred_ply
|
| 413 |
+
# Use fixed seed for reproducibility
|
| 414 |
+
shuffle_rng = np.random.default_rng(seed=42)
|
| 415 |
+
shuffle_rng.shuffle(data_pcd, axis=0)
|
| 416 |
+
|
| 417 |
+
# Downsample point cloud
|
| 418 |
+
nn_engine = skln.NearestNeighbors(
|
| 419 |
+
n_neighbors=1, radius=thresh, algorithm="kd_tree", n_jobs=-1
|
| 420 |
+
)
|
| 421 |
+
nn_engine.fit(data_pcd)
|
| 422 |
+
rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False)
|
| 423 |
+
mask = np.ones(data_pcd.shape[0], dtype=np.bool_)
|
| 424 |
+
for curr, idxs in enumerate(rnn_idxs):
|
| 425 |
+
if mask[curr]:
|
| 426 |
+
mask[idxs] = 0
|
| 427 |
+
mask[curr] = 1
|
| 428 |
+
data_down = data_pcd[mask]
|
| 429 |
+
|
| 430 |
+
# Restrict to observed volume (ObsMask)
|
| 431 |
+
obs_mask_file = loadmat(mask_file)
|
| 432 |
+
ObsMask, BB, Res = (obs_mask_file[attr] for attr in ["ObsMask", "BB", "Res"])
|
| 433 |
+
BB = BB.astype(np.float32)
|
| 434 |
+
|
| 435 |
+
inbound = ((data_down >= BB[:1] - patch) & (data_down < BB[1:] + patch * 2)).sum(
|
| 436 |
+
axis=-1
|
| 437 |
+
) == 3
|
| 438 |
+
data_in = data_down[inbound]
|
| 439 |
+
|
| 440 |
+
data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32)
|
| 441 |
+
grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(
|
| 442 |
+
axis=-1
|
| 443 |
+
) == 3
|
| 444 |
+
data_grid_in = data_grid[grid_inbound]
|
| 445 |
+
in_obs = ObsMask[data_grid_in[:, 0], data_grid_in[:, 1], data_grid_in[:, 2]].astype(
|
| 446 |
+
np.bool_
|
| 447 |
+
)
|
| 448 |
+
data_in_obs = data_in[grid_inbound][in_obs]
|
| 449 |
+
|
| 450 |
+
# Compute accuracy (pred -> GT) and completeness (GT -> pred)
|
| 451 |
+
stl = self._read_ply(gt_ply) if isinstance(gt_ply, str) else gt_ply
|
| 452 |
+
|
| 453 |
+
if use_gpu and torch.cuda.is_available():
|
| 454 |
+
# GPU-accelerated distance computation
|
| 455 |
+
mean_d2s = self._knn_dist_gpu(data_in_obs, stl, max_dist)
|
| 456 |
+
else:
|
| 457 |
+
# CPU version (original, for exact reproduction)
|
| 458 |
+
nn_engine.fit(stl)
|
| 459 |
+
dist_d2s, _ = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True)
|
| 460 |
+
mean_d2s = dist_d2s[dist_d2s < max_dist].mean()
|
| 461 |
+
|
| 462 |
+
ground_plane = loadmat(plane_file)["P"]
|
| 463 |
+
stl_hom = np.concatenate([stl, np.ones_like(stl[:, :1])], -1)
|
| 464 |
+
above = (ground_plane.reshape((1, 4)) * stl_hom).sum(-1) > 0
|
| 465 |
+
stl_above = stl[above]
|
| 466 |
+
|
| 467 |
+
if use_gpu and torch.cuda.is_available():
|
| 468 |
+
# GPU-accelerated distance computation
|
| 469 |
+
mean_s2d = self._knn_dist_gpu(stl_above, data_in, max_dist)
|
| 470 |
+
else:
|
| 471 |
+
# CPU version (original, for exact reproduction)
|
| 472 |
+
nn_engine.fit(data_in)
|
| 473 |
+
dist_s2d, _ = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True)
|
| 474 |
+
mean_s2d = dist_s2d[dist_s2d < max_dist].mean()
|
| 475 |
+
|
| 476 |
+
overall = (mean_d2s + mean_s2d) / 2
|
| 477 |
+
return mean_d2s, mean_s2d, overall
|
| 478 |
+
|
| 479 |
+
def _knn_dist_gpu(
|
| 480 |
+
self,
|
| 481 |
+
query: np.ndarray,
|
| 482 |
+
target: np.ndarray,
|
| 483 |
+
max_dist: float,
|
| 484 |
+
batch_size: int = 8192,
|
| 485 |
+
target_batch_size: int = 50000,
|
| 486 |
+
) -> float:
|
| 487 |
+
"""
|
| 488 |
+
GPU-accelerated nearest neighbor distance computation.
|
| 489 |
+
|
| 490 |
+
Args:
|
| 491 |
+
query: Query points [N, 3]
|
| 492 |
+
target: Target points [M, 3]
|
| 493 |
+
max_dist: Outlier threshold
|
| 494 |
+
batch_size: Batch size for query to avoid OOM (tuned for 16GB GPU)
|
| 495 |
+
target_batch_size: Batch size for target to avoid OOM
|
| 496 |
+
|
| 497 |
+
Returns:
|
| 498 |
+
Mean distance (excluding outliers)
|
| 499 |
+
"""
|
| 500 |
+
device = torch.device("cuda")
|
| 501 |
+
|
| 502 |
+
all_min_dists = []
|
| 503 |
+
n_query_batches = (len(query) + batch_size - 1) // batch_size
|
| 504 |
+
n_target_batches = (len(target) + target_batch_size - 1) // target_batch_size
|
| 505 |
+
|
| 506 |
+
# Pre-load target batches to GPU to avoid repeated transfers
|
| 507 |
+
# Memory: ~50000 pts * 3 coords * 4 bytes * n_batches
|
| 508 |
+
target_batches = []
|
| 509 |
+
for j in range(0, len(target), target_batch_size):
|
| 510 |
+
target_batch = target[j : j + target_batch_size]
|
| 511 |
+
target_t = torch.from_numpy(target_batch).float().to(device)
|
| 512 |
+
target_batches.append(target_t)
|
| 513 |
+
|
| 514 |
+
with tqdm(total=n_query_batches, desc=" GPU KNN", leave=False, ncols=100) as pbar:
|
| 515 |
+
for i in range(0, len(query), batch_size):
|
| 516 |
+
batch = query[i : i + batch_size]
|
| 517 |
+
query_t = torch.from_numpy(batch).float().to(device)
|
| 518 |
+
|
| 519 |
+
# Compute distances to all target batches
|
| 520 |
+
# Memory peak: query_batch × target_batch_size × 4 bytes
|
| 521 |
+
# = 8192 × 50000 × 4 = ~1.6 GB per cdist call
|
| 522 |
+
batch_min_dists = []
|
| 523 |
+
for target_t in target_batches:
|
| 524 |
+
dists = torch.cdist(query_t, target_t)
|
| 525 |
+
batch_min_dists.append(dists.min(dim=1).values)
|
| 526 |
+
del dists # Free immediately
|
| 527 |
+
|
| 528 |
+
# Get minimum distance across all target batches
|
| 529 |
+
min_dists = torch.stack(batch_min_dists, dim=1).min(dim=1).values
|
| 530 |
+
all_min_dists.append(min_dists.cpu().numpy())
|
| 531 |
+
|
| 532 |
+
del query_t, min_dists, batch_min_dists
|
| 533 |
+
pbar.update(1)
|
| 534 |
+
|
| 535 |
+
# Clean up target batches
|
| 536 |
+
for target_t in target_batches:
|
| 537 |
+
del target_t
|
| 538 |
+
torch.cuda.empty_cache()
|
| 539 |
+
|
| 540 |
+
all_min_dists = np.concatenate(all_min_dists)
|
| 541 |
+
return all_min_dists[all_min_dists < max_dist].mean()
|
| 542 |
+
|
| 543 |
+
def _read_ply(self, file: str) -> np.ndarray:
|
| 544 |
+
"""Read point cloud from PLY file."""
|
| 545 |
+
data = PlyData.read(file)
|
| 546 |
+
vertex = data["vertex"]
|
| 547 |
+
return np.stack([vertex["x"], vertex["y"], vertex["z"]], axis=-1)
|
| 548 |
+
|
| 549 |
+
# ------------------------------
|
| 550 |
+
# Private helpers
|
| 551 |
+
# ------------------------------
|
| 552 |
+
|
| 553 |
+
def _depth_mask_path(self, scene: str, depth_idx: int) -> str:
|
| 554 |
+
"""Get path to depth mask for a scene and frame."""
|
| 555 |
+
return os.path.join(
|
| 556 |
+
self.data_root, "depth_raw", "Depths", scene, f"depth_visual_{depth_idx:04d}.png"
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
def _prep_unposed(
|
| 560 |
+
self, pred_data: Dict, gt_data: Dict, masks: np.ndarray
|
| 561 |
+
) -> tuple:
|
| 562 |
+
"""
|
| 563 |
+
Prepare depths/intrinsics/extrinsics for recon_unposed mode.
|
| 564 |
+
|
| 565 |
+
Applies Umeyama scale, rescales intrinsics if depth resolution differs,
|
| 566 |
+
and zeroes invalid-mask depths (nearest interpolation as in paper).
|
| 567 |
+
"""
|
| 568 |
+
_, _, scale, extrinsics = align_poses_umeyama(
|
| 569 |
+
gt_data.extrinsics.copy(),
|
| 570 |
+
pred_data.extrinsics.copy(),
|
| 571 |
+
ransac=True,
|
| 572 |
+
return_aligned=True,
|
| 573 |
+
random_state=42,
|
| 574 |
+
)
|
| 575 |
+
depths = pred_data.depth * scale
|
| 576 |
+
intrinsics = pred_data.intrinsics.copy()
|
| 577 |
+
|
| 578 |
+
if depths.shape[-2:] != masks.shape[-2:]:
|
| 579 |
+
# When resizing depths to mask size, adjust intrinsics accordingly
|
| 580 |
+
sx = masks.shape[-1] / depths.shape[-1]
|
| 581 |
+
sy = masks.shape[-2] / depths.shape[-2]
|
| 582 |
+
intrinsics[:, 0:1] *= sx
|
| 583 |
+
intrinsics[:, 1:2] *= sy
|
| 584 |
+
depths = F.interpolate(
|
| 585 |
+
torch.from_numpy(depths)[None].float(),
|
| 586 |
+
size=(masks.shape[-2], masks.shape[-1]),
|
| 587 |
+
mode="nearest",
|
| 588 |
+
)[0].numpy()
|
| 589 |
+
depths[masks == False] = 0.0 # noqa: E712
|
| 590 |
+
|
| 591 |
+
return depths, intrinsics, extrinsics
|
| 592 |
+
|
| 593 |
+
def _prep_posed(
|
| 594 |
+
self, pred_data: Dict, gt_data: Dict, masks: np.ndarray
|
| 595 |
+
) -> tuple:
|
| 596 |
+
"""
|
| 597 |
+
Prepare depths/intrinsics/extrinsics for recon_posed mode.
|
| 598 |
+
|
| 599 |
+
Uses GT intrinsics/extrinsics but aligns scale via Umeyama.
|
| 600 |
+
Same mask order as other datasets: mask BEFORE scale.
|
| 601 |
+
"""
|
| 602 |
+
_, _, scale, _ = align_poses_umeyama(
|
| 603 |
+
gt_data.extrinsics.copy(),
|
| 604 |
+
pred_data.extrinsics.copy(),
|
| 605 |
+
ransac=True,
|
| 606 |
+
return_aligned=True,
|
| 607 |
+
random_state=42,
|
| 608 |
+
)
|
| 609 |
+
depths = pred_data.depth.copy()
|
| 610 |
+
intrinsics = gt_data.intrinsics.copy()
|
| 611 |
+
extrinsics = gt_data.extrinsics.copy()
|
| 612 |
+
|
| 613 |
+
if depths.shape[-2:] != masks.shape[-2:]:
|
| 614 |
+
depths = F.interpolate(
|
| 615 |
+
torch.from_numpy(depths)[None].float(),
|
| 616 |
+
size=(masks.shape[-2], masks.shape[-1]),
|
| 617 |
+
mode="nearest",
|
| 618 |
+
)[0].numpy()
|
| 619 |
+
|
| 620 |
+
# Mask BEFORE scale (same as other datasets)
|
| 621 |
+
depths[masks == False] = 0.0 # noqa: E712
|
| 622 |
+
depths = depths * scale
|
| 623 |
+
|
| 624 |
+
return depths, intrinsics, extrinsics
|
| 625 |
+
|
| 626 |
+
def _build_proj_mats(
|
| 627 |
+
self, intrinsics: np.ndarray, extrinsics: np.ndarray
|
| 628 |
+
) -> np.ndarray:
|
| 629 |
+
"""Compute per-view 4x4 projection matrices from K and [R|t]."""
|
| 630 |
+
proj_mat_list = []
|
| 631 |
+
for i in range(len(intrinsics)):
|
| 632 |
+
proj_mat = np.eye(4, dtype=np.float32)
|
| 633 |
+
proj_mat[:3, :4] = np.dot(intrinsics[i], extrinsics[i][:3])
|
| 634 |
+
proj_mat_list.append(proj_mat)
|
| 635 |
+
return np.stack(proj_mat_list, axis=0)
|
| 636 |
+
|
| 637 |
+
def _fuse_consistent_points(
|
| 638 |
+
self,
|
| 639 |
+
depths_t: torch.Tensor,
|
| 640 |
+
proj_t: torch.Tensor,
|
| 641 |
+
idx: int,
|
| 642 |
+
H: int,
|
| 643 |
+
W: int,
|
| 644 |
+
) -> np.ndarray:
|
| 645 |
+
"""Fuse points consistent across multiple source views for a reference index."""
|
| 646 |
+
device, dtype = depths_t.device, depths_t.dtype
|
| 647 |
+
pc_buff = torch.zeros((3, H, W), device=device, dtype=dtype)
|
| 648 |
+
val_cnt = torch.zeros((1, H, W), device=device, dtype=dtype)
|
| 649 |
+
|
| 650 |
+
j = 0
|
| 651 |
+
batch_size = 20
|
| 652 |
+
tot_frame = depths_t.shape[0]
|
| 653 |
+
while True:
|
| 654 |
+
ref_pc, pcs, dist = self._filter_depth(
|
| 655 |
+
ref_depth=depths_t[idx : idx + 1],
|
| 656 |
+
src_depths=depths_t[j : min(j + batch_size, tot_frame)],
|
| 657 |
+
ref_proj=proj_t[idx : idx + 1],
|
| 658 |
+
src_projs=proj_t[j : min(j + batch_size, tot_frame)],
|
| 659 |
+
)
|
| 660 |
+
masks = (dist < self.dist_thresh).float()
|
| 661 |
+
masked_pc = pcs * masks
|
| 662 |
+
pc_buff += masked_pc.sum(dim=0, keepdim=False)
|
| 663 |
+
val_cnt += masks.sum(dim=0, keepdim=False)
|
| 664 |
+
j += batch_size
|
| 665 |
+
if j >= tot_frame:
|
| 666 |
+
break
|
| 667 |
+
|
| 668 |
+
final_mask = (val_cnt >= self.num_consist).squeeze(0)
|
| 669 |
+
avg_points = torch.div(pc_buff, val_cnt).permute(1, 2, 0)
|
| 670 |
+
final_pc = self._extract_points(avg_points, final_mask)
|
| 671 |
+
return final_pc
|
| 672 |
+
|
| 673 |
+
def _cap_points(self, points: np.ndarray, max_points: int) -> np.ndarray:
|
| 674 |
+
"""Downsample points if exceeding max count."""
|
| 675 |
+
if len(points) <= max_points:
|
| 676 |
+
return points
|
| 677 |
+
# Use fixed seed for reproducibility
|
| 678 |
+
rng = np.random.default_rng(seed=42)
|
| 679 |
+
random_idx = rng.choice(len(points), max_points, replace=False)
|
| 680 |
+
return points[random_idx]
|
| 681 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/datasets/dtu64.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
DTU-64 Dataset implementation for POSE EVALUATION ONLY.
|
| 17 |
+
|
| 18 |
+
This is a subset of DTU with 64 images per scene, specifically designed for
|
| 19 |
+
camera pose estimation evaluation. It does NOT support 3D reconstruction.
|
| 20 |
+
|
| 21 |
+
Note: GT depth loading is not implemented as it's not needed for pose evaluation.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import glob
|
| 25 |
+
import os
|
| 26 |
+
from typing import Dict as TDict
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
from addict import Dict
|
| 30 |
+
|
| 31 |
+
from depth_anything_3.bench.dataset import Dataset
|
| 32 |
+
from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
|
| 33 |
+
from depth_anything_3.utils.constants import (
|
| 34 |
+
DTU64_CAMERA_ROOT,
|
| 35 |
+
DTU64_EVAL_DATA_ROOT,
|
| 36 |
+
DTU64_SCENES,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@MV_REGISTRY.register(name="dtu64")
|
| 41 |
+
@MONO_REGISTRY.register(name="dtu64")
|
| 42 |
+
class DTU64(Dataset):
|
| 43 |
+
"""
|
| 44 |
+
DTU-64 Dataset wrapper for DepthAnything3 POSE EVALUATION ONLY.
|
| 45 |
+
|
| 46 |
+
This dataset is a subset of DTU with 64 images per scene.
|
| 47 |
+
It is specifically designed for camera pose estimation evaluation
|
| 48 |
+
and does NOT support 3D reconstruction evaluation.
|
| 49 |
+
|
| 50 |
+
Dataset structure:
|
| 51 |
+
DTU/scans/
|
| 52 |
+
├── {scene}/
|
| 53 |
+
│ └── image/ # RGB images (64 per scene)
|
| 54 |
+
└── Cameras/
|
| 55 |
+
└── {idx}_cam.txt # Camera parameters
|
| 56 |
+
|
| 57 |
+
Supported modes:
|
| 58 |
+
- pose: Camera pose estimation evaluation
|
| 59 |
+
|
| 60 |
+
NOT supported:
|
| 61 |
+
- recon_unposed: 3D reconstruction (no GT depth available)
|
| 62 |
+
- recon_posed: 3D reconstruction (no GT depth available)
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
data_root = DTU64_EVAL_DATA_ROOT
|
| 66 |
+
camera_root = DTU64_CAMERA_ROOT
|
| 67 |
+
SCENES = DTU64_SCENES
|
| 68 |
+
|
| 69 |
+
def __init__(self):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self._scene_cache = {}
|
| 72 |
+
|
| 73 |
+
# ------------------------------
|
| 74 |
+
# Camera file parsing
|
| 75 |
+
# ------------------------------
|
| 76 |
+
|
| 77 |
+
def read_cam_file(self, filename: str) -> tuple:
|
| 78 |
+
"""
|
| 79 |
+
Read DTU camera file containing extrinsics and intrinsics.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
filename: Path to camera text file
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Tuple of (intrinsics [3,3], extrinsics [4,4])
|
| 86 |
+
"""
|
| 87 |
+
with open(filename) as f:
|
| 88 |
+
lines = [line.rstrip() for line in f.readlines()]
|
| 89 |
+
# extrinsics: line [1,5), 4x4 matrix
|
| 90 |
+
extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ").reshape((4, 4))
|
| 91 |
+
# intrinsics: line [7-10), 3x3 matrix
|
| 92 |
+
intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ").reshape((3, 3))
|
| 93 |
+
return intrinsics, extrinsics
|
| 94 |
+
|
| 95 |
+
# ------------------------------
|
| 96 |
+
# Public API
|
| 97 |
+
# ------------------------------
|
| 98 |
+
|
| 99 |
+
def get_data(self, scene: str) -> Dict:
|
| 100 |
+
"""
|
| 101 |
+
Collect per-view image paths, intrinsics/extrinsics for a scene.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
scene: Scene identifier (e.g., "scan105")
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Dict with:
|
| 108 |
+
- image_files: List[str] - paths to images (64 per scene)
|
| 109 |
+
- extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
|
| 110 |
+
- intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
|
| 111 |
+
- aux: Dict (empty for this dataset)
|
| 112 |
+
"""
|
| 113 |
+
if scene in self._scene_cache:
|
| 114 |
+
return self._scene_cache[scene]
|
| 115 |
+
|
| 116 |
+
rgb_folder = os.path.join(self.data_root, scene, "image")
|
| 117 |
+
|
| 118 |
+
# Get all PNG files sorted
|
| 119 |
+
files = sorted(glob.glob(os.path.join(rgb_folder, "*.png")))
|
| 120 |
+
|
| 121 |
+
# Reorder: place index 33 first (reference view convention)
|
| 122 |
+
if len(files) > 33:
|
| 123 |
+
files = [files[33]] + files[:33] + files[34:]
|
| 124 |
+
|
| 125 |
+
out = Dict({
|
| 126 |
+
"image_files": [],
|
| 127 |
+
"extrinsics": [],
|
| 128 |
+
"intrinsics": [],
|
| 129 |
+
"aux": Dict({}),
|
| 130 |
+
})
|
| 131 |
+
|
| 132 |
+
for rgb_file in files:
|
| 133 |
+
basename = os.path.basename(rgb_file)
|
| 134 |
+
# File naming: "00000033.png" -> cam_idx = 33
|
| 135 |
+
file_idx = basename.split(".")[0]
|
| 136 |
+
cam_idx = int(file_idx)
|
| 137 |
+
|
| 138 |
+
# Camera file path
|
| 139 |
+
cam_file = os.path.join(self.camera_root, f"{cam_idx:0>8}_cam.txt")
|
| 140 |
+
|
| 141 |
+
if not os.path.exists(cam_file):
|
| 142 |
+
print(f"[DTU-64] Warning: Camera file not found: {cam_file}")
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
intrinsics, extrinsics = self.read_cam_file(cam_file)
|
| 146 |
+
|
| 147 |
+
out.image_files.append(rgb_file)
|
| 148 |
+
out.extrinsics.append(extrinsics)
|
| 149 |
+
out.intrinsics.append(intrinsics)
|
| 150 |
+
|
| 151 |
+
out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
|
| 152 |
+
out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
|
| 153 |
+
|
| 154 |
+
print(f"[DTU-64] {scene}: {len(out.image_files)} images (pose evaluation only)")
|
| 155 |
+
|
| 156 |
+
self._scene_cache[scene] = out
|
| 157 |
+
return out
|
| 158 |
+
|
| 159 |
+
def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
|
| 160 |
+
"""
|
| 161 |
+
NOT SUPPORTED for DTU-64.
|
| 162 |
+
|
| 163 |
+
DTU-64 is only for pose evaluation, not 3D reconstruction.
|
| 164 |
+
"""
|
| 165 |
+
raise NotImplementedError(
|
| 166 |
+
"DTU-64 dataset is for POSE EVALUATION ONLY. "
|
| 167 |
+
"3D reconstruction evaluation is not supported. "
|
| 168 |
+
"Use the standard 'dtu' dataset for 3D reconstruction evaluation."
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
|
| 172 |
+
"""
|
| 173 |
+
NOT SUPPORTED for DTU-64.
|
| 174 |
+
|
| 175 |
+
DTU-64 is only for pose evaluation, not 3D reconstruction.
|
| 176 |
+
"""
|
| 177 |
+
raise NotImplementedError(
|
| 178 |
+
"DTU-64 dataset is for POSE EVALUATION ONLY. "
|
| 179 |
+
"3D reconstruction (fuse3d) is not supported. "
|
| 180 |
+
"Use the standard 'dtu' dataset for 3D reconstruction."
|
| 181 |
+
)
|
| 182 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/datasets/eth3d.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
ETH3D Benchmark dataset implementation.
|
| 17 |
+
|
| 18 |
+
ETH3D is a multi-view stereo benchmark with high-resolution images and
|
| 19 |
+
accurate ground truth geometry from laser scanning.
|
| 20 |
+
Reference: https://www.eth3d.net/
|
| 21 |
+
|
| 22 |
+
Evaluation metrics:
|
| 23 |
+
- 3D reconstruction: Accuracy, Completeness, F-score
|
| 24 |
+
- Camera pose estimation: AUC metrics
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import glob
|
| 28 |
+
import os
|
| 29 |
+
from typing import Dict as TDict, List, Optional
|
| 30 |
+
|
| 31 |
+
import cv2
|
| 32 |
+
import numpy as np
|
| 33 |
+
import open3d as o3d
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
from addict import Dict
|
| 37 |
+
from PIL import Image
|
| 38 |
+
|
| 39 |
+
from depth_anything_3.bench.dataset import Dataset, _wait_for_file_ready
|
| 40 |
+
from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
|
| 41 |
+
from depth_anything_3.bench.utils import (
|
| 42 |
+
create_tsdf_volume,
|
| 43 |
+
evaluate_3d_reconstruction,
|
| 44 |
+
fuse_depth_to_tsdf,
|
| 45 |
+
quat2rotmat,
|
| 46 |
+
sample_points_from_mesh,
|
| 47 |
+
)
|
| 48 |
+
from depth_anything_3.utils.constants import (
|
| 49 |
+
ETH3D_DOWN_SAMPLE,
|
| 50 |
+
ETH3D_EVAL_DATA_ROOT,
|
| 51 |
+
ETH3D_EVAL_THRESHOLD,
|
| 52 |
+
ETH3D_FILTER_KEYS,
|
| 53 |
+
ETH3D_MAX_DEPTH,
|
| 54 |
+
ETH3D_SAMPLING_NUMBER,
|
| 55 |
+
ETH3D_SCENES,
|
| 56 |
+
ETH3D_SDF_TRUNC,
|
| 57 |
+
ETH3D_VOXEL_LENGTH,
|
| 58 |
+
)
|
| 59 |
+
from depth_anything_3.utils.pose_align import align_poses_umeyama
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@MV_REGISTRY.register(name="eth3d")
|
| 63 |
+
@MONO_REGISTRY.register(name="eth3d")
|
| 64 |
+
class ETH3D(Dataset):
|
| 65 |
+
"""
|
| 66 |
+
ETH3D Benchmark dataset wrapper for DepthAnything3 evaluation.
|
| 67 |
+
|
| 68 |
+
Supports:
|
| 69 |
+
- Camera pose estimation evaluation (AUC metrics)
|
| 70 |
+
- 3D reconstruction evaluation (Accuracy, Completeness, F-score)
|
| 71 |
+
- TSDF-based point cloud fusion
|
| 72 |
+
|
| 73 |
+
Dataset structure:
|
| 74 |
+
eth3d/multiview/
|
| 75 |
+
├── scene_name/
|
| 76 |
+
│ ├── images/ # RGB images
|
| 77 |
+
│ ├── dslr_calibration_jpg/
|
| 78 |
+
│ │ ├── cameras.txt # Camera intrinsics
|
| 79 |
+
│ │ └── images.txt # Camera poses
|
| 80 |
+
│ ├── combined_mesh.ply # Ground truth mesh
|
| 81 |
+
│ └── ground_truth_depth/ # GT depth maps (optional)
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
data_root = ETH3D_EVAL_DATA_ROOT
|
| 85 |
+
SCENES = ETH3D_SCENES
|
| 86 |
+
|
| 87 |
+
# Evaluation hyperparameters from constants
|
| 88 |
+
max_depth = ETH3D_MAX_DEPTH
|
| 89 |
+
sampling_number = ETH3D_SAMPLING_NUMBER
|
| 90 |
+
voxel_length = ETH3D_VOXEL_LENGTH
|
| 91 |
+
sdf_trunc = ETH3D_SDF_TRUNC
|
| 92 |
+
eval_threshold = ETH3D_EVAL_THRESHOLD
|
| 93 |
+
down_sample = ETH3D_DOWN_SAMPLE
|
| 94 |
+
|
| 95 |
+
def __init__(self):
|
| 96 |
+
super().__init__()
|
| 97 |
+
# Pre-load scene data for efficiency
|
| 98 |
+
self._scene_cache = {}
|
| 99 |
+
|
| 100 |
+
# ------------------------------
|
| 101 |
+
# Camera file parsing
|
| 102 |
+
# ------------------------------
|
| 103 |
+
|
| 104 |
+
def _parse_cameras_txt(self, filepath: str) -> dict:
|
| 105 |
+
"""
|
| 106 |
+
Parse COLMAP-style cameras.txt file.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
Dict mapping camera_id to intrinsic parameters
|
| 110 |
+
"""
|
| 111 |
+
camera_dict = {}
|
| 112 |
+
with open(filepath) as f:
|
| 113 |
+
lines = f.readlines()
|
| 114 |
+
for line in lines[3:]: # Skip header
|
| 115 |
+
line = line.strip()
|
| 116 |
+
if not line or line.startswith("#"):
|
| 117 |
+
continue
|
| 118 |
+
parts = line.split()
|
| 119 |
+
if len(parts) < 8:
|
| 120 |
+
continue
|
| 121 |
+
cam_id = parts[0]
|
| 122 |
+
# Format: ID, MODEL, WIDTH, HEIGHT, fx, fy, cx, cy, [distortion params...]
|
| 123 |
+
camera_dict[cam_id] = {
|
| 124 |
+
"width": float(parts[2]),
|
| 125 |
+
"height": float(parts[3]),
|
| 126 |
+
"fx": float(parts[4]),
|
| 127 |
+
"fy": float(parts[5]),
|
| 128 |
+
"cx": float(parts[6]),
|
| 129 |
+
"cy": float(parts[7]),
|
| 130 |
+
}
|
| 131 |
+
return camera_dict
|
| 132 |
+
|
| 133 |
+
def _parse_images_txt(self, filepath: str) -> dict:
|
| 134 |
+
"""
|
| 135 |
+
Parse COLMAP-style images.txt file.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
Dict mapping image path to pose parameters
|
| 139 |
+
"""
|
| 140 |
+
pose_dict = {}
|
| 141 |
+
with open(filepath) as f:
|
| 142 |
+
lines = f.readlines()
|
| 143 |
+
for idx, line in enumerate(lines[4:]): # Skip header
|
| 144 |
+
line = line.strip()
|
| 145 |
+
if not line or line.startswith("#"):
|
| 146 |
+
continue
|
| 147 |
+
# Every other line contains pose info
|
| 148 |
+
if idx % 2 == 0:
|
| 149 |
+
parts = line.split()
|
| 150 |
+
if len(parts) < 10:
|
| 151 |
+
continue
|
| 152 |
+
# Format: IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME
|
| 153 |
+
image_id = parts[0]
|
| 154 |
+
qw, qx, qy, qz = float(parts[1]), float(parts[2]), float(parts[3]), float(parts[4])
|
| 155 |
+
tx, ty, tz = float(parts[5]), float(parts[6]), float(parts[7])
|
| 156 |
+
camera_id = parts[8]
|
| 157 |
+
name = parts[9]
|
| 158 |
+
pose_dict[name] = {
|
| 159 |
+
"image_id": image_id,
|
| 160 |
+
"quat": [qw, qx, qy, qz],
|
| 161 |
+
"trans": [tx, ty, tz],
|
| 162 |
+
"camera_id": camera_id,
|
| 163 |
+
}
|
| 164 |
+
return pose_dict
|
| 165 |
+
|
| 166 |
+
def _should_filter_image(self, scene: str, image_name: str) -> bool:
|
| 167 |
+
"""Check if image should be filtered out based on known problematic views."""
|
| 168 |
+
filter_keys = ETH3D_FILTER_KEYS.get(scene, [])
|
| 169 |
+
for key in filter_keys:
|
| 170 |
+
if image_name.endswith(key):
|
| 171 |
+
return True
|
| 172 |
+
return False
|
| 173 |
+
|
| 174 |
+
# ------------------------------
|
| 175 |
+
# Public API
|
| 176 |
+
# ------------------------------
|
| 177 |
+
|
| 178 |
+
def get_data(self, scene: str) -> Dict:
|
| 179 |
+
"""
|
| 180 |
+
Collect per-view image paths, intrinsics/extrinsics for a scene.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
scene: Scene identifier (e.g., "courtyard")
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
Dict with:
|
| 187 |
+
- image_files: List[str] - paths to images
|
| 188 |
+
- extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
|
| 189 |
+
- intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
|
| 190 |
+
- aux: Dict with gt_mesh_path
|
| 191 |
+
"""
|
| 192 |
+
# Check cache
|
| 193 |
+
if scene in self._scene_cache:
|
| 194 |
+
return self._scene_cache[scene]
|
| 195 |
+
|
| 196 |
+
scene_dir = os.path.join(self.data_root, scene)
|
| 197 |
+
|
| 198 |
+
# Parse camera files
|
| 199 |
+
cameras_file = os.path.join(scene_dir, "dslr_calibration_jpg", "cameras.txt")
|
| 200 |
+
images_file = os.path.join(scene_dir, "dslr_calibration_jpg", "images.txt")
|
| 201 |
+
camera_dict = self._parse_cameras_txt(cameras_file)
|
| 202 |
+
pose_dict = self._parse_images_txt(images_file)
|
| 203 |
+
|
| 204 |
+
# Ground truth mesh path
|
| 205 |
+
gt_mesh_path = os.path.join(scene_dir, "combined_mesh.ply")
|
| 206 |
+
|
| 207 |
+
out = Dict({
|
| 208 |
+
"image_files": [],
|
| 209 |
+
"extrinsics": [],
|
| 210 |
+
"intrinsics": [],
|
| 211 |
+
"aux": Dict({
|
| 212 |
+
"gt_mesh_path": gt_mesh_path,
|
| 213 |
+
"heights": [],
|
| 214 |
+
"widths": [],
|
| 215 |
+
}),
|
| 216 |
+
})
|
| 217 |
+
|
| 218 |
+
# Process each image (preserve original order from images.txt)
|
| 219 |
+
filtered_count = 0
|
| 220 |
+
for image_name, pose_info in pose_dict.items():
|
| 221 |
+
# Filter problematic views
|
| 222 |
+
if self._should_filter_image(scene, image_name):
|
| 223 |
+
filtered_count += 1
|
| 224 |
+
continue
|
| 225 |
+
|
| 226 |
+
image_path = os.path.join(scene_dir, "images", image_name)
|
| 227 |
+
if not os.path.exists(image_path):
|
| 228 |
+
continue
|
| 229 |
+
|
| 230 |
+
cam_info = camera_dict.get(pose_info["camera_id"])
|
| 231 |
+
if cam_info is None:
|
| 232 |
+
continue
|
| 233 |
+
|
| 234 |
+
# Build intrinsics matrix
|
| 235 |
+
ixt = np.array([
|
| 236 |
+
[cam_info["fx"], 0, cam_info["cx"]],
|
| 237 |
+
[0, cam_info["fy"], cam_info["cy"]],
|
| 238 |
+
[0, 0, 1],
|
| 239 |
+
], dtype=np.float32)
|
| 240 |
+
|
| 241 |
+
# Build extrinsics matrix (world-to-camera)
|
| 242 |
+
# COLMAP format: world point -> camera point
|
| 243 |
+
rot = quat2rotmat(pose_info["quat"])
|
| 244 |
+
ext = np.eye(4, dtype=np.float32)
|
| 245 |
+
ext[:3, :3] = rot
|
| 246 |
+
ext[:3, 3] = pose_info["trans"]
|
| 247 |
+
|
| 248 |
+
out.image_files.append(image_path)
|
| 249 |
+
out.extrinsics.append(ext)
|
| 250 |
+
out.intrinsics.append(ixt)
|
| 251 |
+
out.aux.heights.append(cam_info["height"])
|
| 252 |
+
out.aux.widths.append(cam_info["width"])
|
| 253 |
+
|
| 254 |
+
out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
|
| 255 |
+
out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
|
| 256 |
+
|
| 257 |
+
# Print scene info
|
| 258 |
+
total_images = len(pose_dict)
|
| 259 |
+
used_images = len(out.image_files)
|
| 260 |
+
print(f"[ETH3D] {scene}: {used_images}/{total_images} images "
|
| 261 |
+
f"(filtered {filtered_count}, missing {total_images - used_images - filtered_count})")
|
| 262 |
+
|
| 263 |
+
if used_images < 3:
|
| 264 |
+
print(f"[ETH3D] ⚠️ WARNING: {scene} has only {used_images} images - evaluation may fail!")
|
| 265 |
+
|
| 266 |
+
# Cache result
|
| 267 |
+
self._scene_cache[scene] = out
|
| 268 |
+
return out
|
| 269 |
+
|
| 270 |
+
def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
|
| 271 |
+
"""
|
| 272 |
+
Evaluate fused point cloud against ETH3D ground truth mesh.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
scene: Scene identifier
|
| 276 |
+
fuse_path: Path to fused point cloud (.ply)
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
Dict with metrics: acc, comp, overall, precision, recall, fscore
|
| 280 |
+
"""
|
| 281 |
+
gt_data = self.get_data(scene)
|
| 282 |
+
gt_mesh_path = gt_data.aux.gt_mesh_path
|
| 283 |
+
|
| 284 |
+
# Load and sample ground truth mesh
|
| 285 |
+
gt_mesh = o3d.io.read_triangle_mesh(gt_mesh_path)
|
| 286 |
+
gt_pcd = sample_points_from_mesh(gt_mesh, self.sampling_number)
|
| 287 |
+
|
| 288 |
+
# Load predicted point cloud
|
| 289 |
+
pred_pcd = o3d.io.read_point_cloud(fuse_path)
|
| 290 |
+
|
| 291 |
+
# Evaluate using shared utility function
|
| 292 |
+
metrics = evaluate_3d_reconstruction(
|
| 293 |
+
pred_pcd,
|
| 294 |
+
gt_pcd,
|
| 295 |
+
threshold=self.eval_threshold,
|
| 296 |
+
down_sample=self.down_sample,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
return metrics
|
| 300 |
+
|
| 301 |
+
def _load_gt_meta(self, result_path: str) -> Dict:
|
| 302 |
+
"""
|
| 303 |
+
Load saved GT meta (extrinsics, intrinsics, image_files) for fusion.
|
| 304 |
+
|
| 305 |
+
This is needed when frames are sampled, so fuse3d uses the correct
|
| 306 |
+
(sampled) GT instead of full dataset GT.
|
| 307 |
+
|
| 308 |
+
Args:
|
| 309 |
+
result_path: Path to npz file (used to derive gt_meta.npz path)
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
Dict with GT data, or None if gt_meta.npz doesn't exist
|
| 313 |
+
"""
|
| 314 |
+
# gt_meta.npz is in the same exports/ directory as results.npz
|
| 315 |
+
export_dir = os.path.dirname(result_path) # exports/mini_npz/
|
| 316 |
+
gt_meta_path = os.path.join(os.path.dirname(export_dir), "gt_meta.npz")
|
| 317 |
+
|
| 318 |
+
if os.path.exists(gt_meta_path):
|
| 319 |
+
data = np.load(gt_meta_path, allow_pickle=True)
|
| 320 |
+
return Dict({
|
| 321 |
+
"extrinsics": data["extrinsics"],
|
| 322 |
+
"intrinsics": data["intrinsics"],
|
| 323 |
+
"image_files": data["image_files"] if "image_files" in data else None,
|
| 324 |
+
})
|
| 325 |
+
return None
|
| 326 |
+
|
| 327 |
+
def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
|
| 328 |
+
"""
|
| 329 |
+
Fuse per-view depths into a point cloud using TSDF fusion.
|
| 330 |
+
|
| 331 |
+
Pipeline:
|
| 332 |
+
1. Load original images (keep original size)
|
| 333 |
+
2. Resize depth to original image size (nearest interpolation)
|
| 334 |
+
3. Adjust intrinsics to original image size
|
| 335 |
+
4. Apply scale alignment and mask invalid depths
|
| 336 |
+
5. TSDF fusion
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
scene: Scene identifier
|
| 340 |
+
result_path: Path to npz file with predicted depths/poses
|
| 341 |
+
fuse_path: Output path for fused point cloud (.ply)
|
| 342 |
+
mode: "recon_unposed" or "recon_posed"
|
| 343 |
+
"""
|
| 344 |
+
# Try to load saved GT meta (handles frame sampling)
|
| 345 |
+
gt_meta = self._load_gt_meta(result_path)
|
| 346 |
+
if gt_meta is not None:
|
| 347 |
+
gt_data = gt_meta
|
| 348 |
+
else:
|
| 349 |
+
gt_data = self.get_data(scene)
|
| 350 |
+
_wait_for_file_ready(result_path)
|
| 351 |
+
pred_data = Dict({k: v for k, v in np.load(result_path).items()})
|
| 352 |
+
|
| 353 |
+
# Load original images (keep original size)
|
| 354 |
+
images = []
|
| 355 |
+
orig_sizes = [] # (H, W) for each image
|
| 356 |
+
for img_path in gt_data.image_files:
|
| 357 |
+
img = cv2.imread(img_path)
|
| 358 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 359 |
+
images.append(img)
|
| 360 |
+
orig_sizes.append((img.shape[0], img.shape[1]))
|
| 361 |
+
|
| 362 |
+
# Prepare depths, intrinsics, extrinsics with resize to original size
|
| 363 |
+
if mode == "recon_unposed":
|
| 364 |
+
depths, intrinsics, extrinsics = self._prep_unposed(
|
| 365 |
+
pred_data, gt_data, orig_sizes, scene=scene
|
| 366 |
+
)
|
| 367 |
+
elif mode == "recon_posed":
|
| 368 |
+
depths, intrinsics, extrinsics = self._prep_posed(
|
| 369 |
+
pred_data, gt_data, orig_sizes, scene=scene
|
| 370 |
+
)
|
| 371 |
+
else:
|
| 372 |
+
raise ValueError(f"Invalid mode: {mode}")
|
| 373 |
+
|
| 374 |
+
images = np.stack(images, axis=0)
|
| 375 |
+
|
| 376 |
+
# Create TSDF volume and fuse
|
| 377 |
+
volume = create_tsdf_volume(
|
| 378 |
+
voxel_length=self.voxel_length,
|
| 379 |
+
sdf_trunc=self.sdf_trunc,
|
| 380 |
+
)
|
| 381 |
+
mesh = fuse_depth_to_tsdf(
|
| 382 |
+
volume, depths, images, intrinsics, extrinsics, max_depth=self.max_depth
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# Sample points from mesh
|
| 386 |
+
pcd = sample_points_from_mesh(mesh, self.sampling_number)
|
| 387 |
+
|
| 388 |
+
# Save point cloud
|
| 389 |
+
os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
|
| 390 |
+
o3d.io.write_point_cloud(fuse_path, pcd)
|
| 391 |
+
|
| 392 |
+
# ------------------------------
|
| 393 |
+
# Private helpers
|
| 394 |
+
# ------------------------------
|
| 395 |
+
|
| 396 |
+
def _prep_unposed(
|
| 397 |
+
self, pred_data: Dict, gt_data: Dict, orig_sizes: list, scene: str = None
|
| 398 |
+
) -> tuple:
|
| 399 |
+
"""
|
| 400 |
+
Prepare depths/intrinsics/extrinsics for recon_unposed mode.
|
| 401 |
+
|
| 402 |
+
Pipeline:
|
| 403 |
+
1. Umeyama scale alignment
|
| 404 |
+
2. Load GT mask for each frame
|
| 405 |
+
3. Resize depth to original image size (nearest)
|
| 406 |
+
4. Apply GT mask BEFORE scale
|
| 407 |
+
5. Apply scale
|
| 408 |
+
6. Adjust intrinsics to original image size
|
| 409 |
+
"""
|
| 410 |
+
# Scale alignment with fixed random_state for reproducibility
|
| 411 |
+
_, _, scale, extrinsics = align_poses_umeyama(
|
| 412 |
+
gt_data.extrinsics.copy(),
|
| 413 |
+
pred_data.extrinsics.copy(),
|
| 414 |
+
return_aligned=True,
|
| 415 |
+
ransac=True,
|
| 416 |
+
random_state=42,
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Get model output size
|
| 420 |
+
model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
|
| 421 |
+
|
| 422 |
+
# Process each frame
|
| 423 |
+
depths_out = []
|
| 424 |
+
intrinsics_out = []
|
| 425 |
+
for i in range(len(pred_data.depth)):
|
| 426 |
+
orig_h, orig_w = orig_sizes[i]
|
| 427 |
+
image_name = os.path.basename(gt_data.image_files[i])
|
| 428 |
+
|
| 429 |
+
# Resize depth to original image size (nearest interpolation)
|
| 430 |
+
depth = cv2.resize(
|
| 431 |
+
pred_data.depth[i],
|
| 432 |
+
(orig_w, orig_h),
|
| 433 |
+
interpolation=cv2.INTER_NEAREST,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
# Load GT mask (apply BEFORE scale)
|
| 437 |
+
gt_zero_mask = None
|
| 438 |
+
if scene is not None:
|
| 439 |
+
gt_zero_mask = self._load_gt_mask(scene, image_name, (orig_h, orig_w))
|
| 440 |
+
|
| 441 |
+
# Mask invalid depths BEFORE scale
|
| 442 |
+
depth = self._mask_invalid_depth(depth, gt_zero_mask)
|
| 443 |
+
|
| 444 |
+
# Apply scale AFTER mask
|
| 445 |
+
depth = depth * scale
|
| 446 |
+
|
| 447 |
+
# Adjust intrinsics to original image size
|
| 448 |
+
h_ratio = orig_h / model_h
|
| 449 |
+
w_ratio = orig_w / model_w
|
| 450 |
+
ixt = pred_data.intrinsics[i].copy()
|
| 451 |
+
ixt[0, :] *= w_ratio # fx, 0, cx
|
| 452 |
+
ixt[1, :] *= h_ratio # 0, fy, cy
|
| 453 |
+
|
| 454 |
+
depths_out.append(depth)
|
| 455 |
+
intrinsics_out.append(ixt)
|
| 456 |
+
|
| 457 |
+
return np.stack(depths_out), np.stack(intrinsics_out), extrinsics
|
| 458 |
+
|
| 459 |
+
def _prep_posed(
|
| 460 |
+
self, pred_data: Dict, gt_data: Dict, orig_sizes: list, scene: str = None
|
| 461 |
+
) -> tuple:
|
| 462 |
+
"""
|
| 463 |
+
Prepare depths/intrinsics/extrinsics for recon_posed mode.
|
| 464 |
+
|
| 465 |
+
Uses GT intrinsics/extrinsics but aligns depth scale via Umeyama.
|
| 466 |
+
Depth is resized to original image size.
|
| 467 |
+
"""
|
| 468 |
+
# Scale alignment with fixed random_state for reproducibility
|
| 469 |
+
_, _, scale, _ = align_poses_umeyama(
|
| 470 |
+
gt_data.extrinsics.copy(),
|
| 471 |
+
pred_data.extrinsics.copy(),
|
| 472 |
+
return_aligned=True,
|
| 473 |
+
ransac=True,
|
| 474 |
+
random_state=42,
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# Process each frame
|
| 478 |
+
depths_out = []
|
| 479 |
+
for i in range(len(pred_data.depth)):
|
| 480 |
+
orig_h, orig_w = orig_sizes[i]
|
| 481 |
+
image_name = os.path.basename(gt_data.image_files[i])
|
| 482 |
+
|
| 483 |
+
# Resize depth to original image size (nearest interpolation)
|
| 484 |
+
depth = cv2.resize(
|
| 485 |
+
pred_data.depth[i],
|
| 486 |
+
(orig_w, orig_h),
|
| 487 |
+
interpolation=cv2.INTER_NEAREST,
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
# Load GT mask (apply BEFORE scale)
|
| 491 |
+
gt_zero_mask = None
|
| 492 |
+
if scene is not None:
|
| 493 |
+
gt_zero_mask = self._load_gt_mask(scene, image_name, (orig_h, orig_w))
|
| 494 |
+
|
| 495 |
+
# Mask invalid depths BEFORE scale
|
| 496 |
+
depth = self._mask_invalid_depth(depth, gt_zero_mask)
|
| 497 |
+
|
| 498 |
+
# Apply scale AFTER mask
|
| 499 |
+
depth = depth * scale
|
| 500 |
+
|
| 501 |
+
depths_out.append(depth)
|
| 502 |
+
|
| 503 |
+
# Use GT intrinsics and extrinsics (already at original image size)
|
| 504 |
+
return np.stack(depths_out), gt_data.intrinsics.copy(), gt_data.extrinsics.copy()
|
| 505 |
+
|
| 506 |
+
def _load_gt_mask(self, scene: str, image_name: str, shape: tuple) -> np.ndarray:
|
| 507 |
+
"""
|
| 508 |
+
Load GT mask for masking invalid regions.
|
| 509 |
+
|
| 510 |
+
GT mask marks occluded or invalid regions that should be excluded
|
| 511 |
+
from depth fusion and evaluation.
|
| 512 |
+
|
| 513 |
+
Args:
|
| 514 |
+
scene: Scene identifier
|
| 515 |
+
image_name: Image filename (e.g., "DSC_0307.JPG")
|
| 516 |
+
shape: (height, width) of the image
|
| 517 |
+
|
| 518 |
+
Returns:
|
| 519 |
+
Boolean mask where True = valid region to keep
|
| 520 |
+
"""
|
| 521 |
+
h, w = shape
|
| 522 |
+
|
| 523 |
+
# GT mask file path
|
| 524 |
+
gt_mask_path = os.path.join(
|
| 525 |
+
self.data_root, scene, "masks_for_images", "dslr_images",
|
| 526 |
+
image_name.replace(".JPG", ".png")
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
# GT depth file path (used to determine valid depth regions)
|
| 530 |
+
gt_depth_path = os.path.join(
|
| 531 |
+
self.data_root, scene, "ground_truth_depth", "dslr_images", image_name
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# Load GT depth
|
| 535 |
+
if os.path.exists(gt_depth_path):
|
| 536 |
+
gt_depth = np.fromfile(gt_depth_path, dtype=np.float32).reshape(h, w)
|
| 537 |
+
else:
|
| 538 |
+
gt_depth = np.ones((h, w), dtype=np.float32)
|
| 539 |
+
|
| 540 |
+
# Load GT mask
|
| 541 |
+
if os.path.exists(gt_mask_path):
|
| 542 |
+
gt_mask = cv2.imread(gt_mask_path, cv2.IMREAD_GRAYSCALE)
|
| 543 |
+
gt_mask = np.asarray(gt_mask)
|
| 544 |
+
else:
|
| 545 |
+
gt_mask = np.zeros((h, w), dtype=np.uint8)
|
| 546 |
+
|
| 547 |
+
# Compute zero_mask
|
| 548 |
+
# gt_mask == 1 means occluded/invalid region
|
| 549 |
+
invalid_mask_from_gt = gt_mask == 1
|
| 550 |
+
gt_depth_copy = gt_depth.copy()
|
| 551 |
+
gt_depth_copy[gt_mask == 1] = 0
|
| 552 |
+
|
| 553 |
+
invalid_mask_from_gt_depth = np.logical_or(gt_depth_copy == 0, gt_depth_copy == np.inf)
|
| 554 |
+
|
| 555 |
+
# zero_mask: valid region that should be kept
|
| 556 |
+
zero_mask = np.logical_and(
|
| 557 |
+
np.logical_not(invalid_mask_from_gt),
|
| 558 |
+
np.logical_not(invalid_mask_from_gt_depth)
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
return zero_mask
|
| 562 |
+
|
| 563 |
+
def _mask_invalid_depth(
|
| 564 |
+
self, depth: np.ndarray, gt_zero_mask: np.ndarray = None
|
| 565 |
+
) -> np.ndarray:
|
| 566 |
+
"""
|
| 567 |
+
Mask invalid depth values by setting them to 0.
|
| 568 |
+
|
| 569 |
+
Logic:
|
| 570 |
+
1. Apply GT mask (if provided) - marks occluded/invalid regions
|
| 571 |
+
2. Mask pred invalid values (nan, inf)
|
| 572 |
+
|
| 573 |
+
Args:
|
| 574 |
+
depth: Depth map to mask
|
| 575 |
+
gt_zero_mask: Optional GT mask (True = valid region)
|
| 576 |
+
|
| 577 |
+
Returns:
|
| 578 |
+
Masked depth map with invalid regions set to 0
|
| 579 |
+
"""
|
| 580 |
+
depth = depth.copy()
|
| 581 |
+
|
| 582 |
+
# Apply GT mask first (before scale)
|
| 583 |
+
if gt_zero_mask is not None:
|
| 584 |
+
# Also mask out invalid pred depth
|
| 585 |
+
pred_invalid = np.isnan(depth) | np.isinf(depth)
|
| 586 |
+
combined_mask = np.logical_and(gt_zero_mask, np.logical_not(pred_invalid))
|
| 587 |
+
depth = depth * combined_mask.astype(np.float32)
|
| 588 |
+
else:
|
| 589 |
+
# Fallback: only mask pred invalid values
|
| 590 |
+
invalid_mask = np.isnan(depth) | np.isinf(depth) | (depth <= 0)
|
| 591 |
+
depth[invalid_mask] = 0.0
|
| 592 |
+
|
| 593 |
+
return depth
|
| 594 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/datasets/hiroom.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
HiRoom Dataset implementation.
|
| 17 |
+
|
| 18 |
+
HiRoom is an indoor RGB-D dataset containing ground truth camera poses,
|
| 19 |
+
depth maps, and fused point clouds.
|
| 20 |
+
|
| 21 |
+
Evaluation metrics:
|
| 22 |
+
- 3D reconstruction: Accuracy, Completeness, F-score
|
| 23 |
+
- Camera pose estimation: AUC metrics
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
from typing import Dict as TDict, List
|
| 28 |
+
|
| 29 |
+
import cv2
|
| 30 |
+
import numpy as np
|
| 31 |
+
import open3d as o3d
|
| 32 |
+
from addict import Dict
|
| 33 |
+
|
| 34 |
+
from depth_anything_3.bench.dataset import Dataset, _wait_for_file_ready
|
| 35 |
+
from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
|
| 36 |
+
from depth_anything_3.bench.utils import (
|
| 37 |
+
create_tsdf_volume,
|
| 38 |
+
evaluate_3d_reconstruction,
|
| 39 |
+
fuse_depth_to_tsdf,
|
| 40 |
+
sample_points_from_mesh,
|
| 41 |
+
)
|
| 42 |
+
from depth_anything_3.utils.constants import (
|
| 43 |
+
HIROOM_DOWN_SAMPLE,
|
| 44 |
+
HIROOM_EVAL_DATA_ROOT,
|
| 45 |
+
HIROOM_EVAL_THRESHOLD,
|
| 46 |
+
HIROOM_GT_ROOT_PATH,
|
| 47 |
+
HIROOM_MAX_DEPTH,
|
| 48 |
+
HIROOM_SAMPLING_NUMBER,
|
| 49 |
+
HIROOM_SCENE_LIST_PATH,
|
| 50 |
+
HIROOM_SDF_TRUNC,
|
| 51 |
+
HIROOM_VOXEL_LENGTH,
|
| 52 |
+
)
|
| 53 |
+
from depth_anything_3.utils.pose_align import align_poses_umeyama
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _load_scene_list() -> List[str]:
|
| 57 |
+
"""Load scene list from file."""
|
| 58 |
+
if os.path.exists(HIROOM_SCENE_LIST_PATH):
|
| 59 |
+
with open(HIROOM_SCENE_LIST_PATH, "r") as f:
|
| 60 |
+
return f.read().splitlines()
|
| 61 |
+
return []
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@MV_REGISTRY.register(name="hiroom")
|
| 65 |
+
@MONO_REGISTRY.register(name="hiroom")
|
| 66 |
+
class HiRoomDataset(Dataset):
|
| 67 |
+
"""
|
| 68 |
+
HiRoom Dataset wrapper for DepthAnything3 evaluation.
|
| 69 |
+
|
| 70 |
+
Supports:
|
| 71 |
+
- Camera pose estimation evaluation (AUC metrics)
|
| 72 |
+
- 3D reconstruction evaluation (Accuracy, Completeness, F-score)
|
| 73 |
+
- TSDF-based point cloud fusion
|
| 74 |
+
|
| 75 |
+
Dataset structure:
|
| 76 |
+
HiRoom/
|
| 77 |
+
├── {scene_path}/
|
| 78 |
+
│ ├── image/ # RGB images
|
| 79 |
+
│ ├── depth/ # GT depth maps
|
| 80 |
+
│ ├── pose/ # Camera poses (.npy)
|
| 81 |
+
│ ├── cam_K.npy # Camera intrinsics
|
| 82 |
+
│ └── aliasing_mask/ # Aliasing masks
|
| 83 |
+
|
| 84 |
+
fused_pcd/
|
| 85 |
+
└── {scene_name}.ply # Ground truth fused point cloud
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
data_root = HIROOM_EVAL_DATA_ROOT
|
| 89 |
+
gt_root_path = HIROOM_GT_ROOT_PATH
|
| 90 |
+
SCENES = _load_scene_list()
|
| 91 |
+
|
| 92 |
+
# Evaluation hyperparameters from constants
|
| 93 |
+
max_depth = HIROOM_MAX_DEPTH
|
| 94 |
+
sampling_number = HIROOM_SAMPLING_NUMBER
|
| 95 |
+
voxel_length = HIROOM_VOXEL_LENGTH
|
| 96 |
+
sdf_trunc = HIROOM_SDF_TRUNC
|
| 97 |
+
eval_threshold = HIROOM_EVAL_THRESHOLD
|
| 98 |
+
down_sample = HIROOM_DOWN_SAMPLE
|
| 99 |
+
|
| 100 |
+
def __init__(self):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self._scene_cache = {}
|
| 103 |
+
|
| 104 |
+
# ------------------------------
|
| 105 |
+
# Public API
|
| 106 |
+
# ------------------------------
|
| 107 |
+
|
| 108 |
+
def get_data(self, scene: str) -> Dict:
|
| 109 |
+
"""
|
| 110 |
+
Collect per-view image paths, intrinsics/extrinsics for a scene.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
scene: Scene path (e.g., "xxx/yyy/zzz")
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
Dict with:
|
| 117 |
+
- image_files: List[str] - paths to images
|
| 118 |
+
- extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
|
| 119 |
+
- intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
|
| 120 |
+
- aux: Dict with gt_pcd_path, gt_depth_files, aliasing_mask_files
|
| 121 |
+
"""
|
| 122 |
+
if scene in self._scene_cache:
|
| 123 |
+
return self._scene_cache[scene]
|
| 124 |
+
|
| 125 |
+
scene_dir = os.path.join(self.data_root, scene)
|
| 126 |
+
image_dir = os.path.join(scene_dir, "image")
|
| 127 |
+
|
| 128 |
+
# Get scene name for GT point cloud
|
| 129 |
+
scene_name = "-".join(scene.split("/")[-3:])
|
| 130 |
+
gt_pcd_path = os.path.join(self.gt_root_path, f"{scene_name}.ply")
|
| 131 |
+
|
| 132 |
+
# Load shared camera intrinsics
|
| 133 |
+
intrin_path = os.path.join(scene_dir, "cam_K.npy")
|
| 134 |
+
ixt_shared = np.load(intrin_path).astype(np.float32)
|
| 135 |
+
|
| 136 |
+
# Get all image names sorted
|
| 137 |
+
image_names = sorted(os.listdir(image_dir))
|
| 138 |
+
|
| 139 |
+
out = Dict({
|
| 140 |
+
"image_files": [],
|
| 141 |
+
"extrinsics": [],
|
| 142 |
+
"intrinsics": [],
|
| 143 |
+
"aux": Dict({
|
| 144 |
+
"gt_pcd_path": gt_pcd_path,
|
| 145 |
+
"gt_depth_files": [],
|
| 146 |
+
"aliasing_mask_files": [],
|
| 147 |
+
}),
|
| 148 |
+
})
|
| 149 |
+
|
| 150 |
+
for img_name in image_names:
|
| 151 |
+
img_path = os.path.join(image_dir, img_name)
|
| 152 |
+
frame_name = img_name.split(".")[0]
|
| 153 |
+
|
| 154 |
+
# Depth and pose paths
|
| 155 |
+
depth_path = os.path.join(scene_dir, "depth", f"{frame_name}.png")
|
| 156 |
+
pose_path = os.path.join(scene_dir, "pose", f"{frame_name}.npy")
|
| 157 |
+
aliasing_mask_path = os.path.join(scene_dir, "aliasing_mask", f"{frame_name}.png")
|
| 158 |
+
|
| 159 |
+
if not os.path.exists(pose_path):
|
| 160 |
+
continue
|
| 161 |
+
|
| 162 |
+
# Load extrinsics (world-to-camera)
|
| 163 |
+
ext = np.load(pose_path).astype(np.float32)
|
| 164 |
+
|
| 165 |
+
out.image_files.append(img_path)
|
| 166 |
+
out.extrinsics.append(ext)
|
| 167 |
+
out.intrinsics.append(ixt_shared.copy())
|
| 168 |
+
out.aux.gt_depth_files.append(depth_path)
|
| 169 |
+
out.aux.aliasing_mask_files.append(aliasing_mask_path)
|
| 170 |
+
|
| 171 |
+
out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
|
| 172 |
+
out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
|
| 173 |
+
|
| 174 |
+
print(f"[HiRoom] {scene}: {len(out.image_files)} images")
|
| 175 |
+
|
| 176 |
+
self._scene_cache[scene] = out
|
| 177 |
+
return out
|
| 178 |
+
|
| 179 |
+
def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
|
| 180 |
+
"""
|
| 181 |
+
Evaluate fused point cloud against HiRoom ground truth point cloud.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
scene: Scene identifier
|
| 185 |
+
fuse_path: Path to fused point cloud (.ply)
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
Dict with metrics: acc, comp, overall, precision, recall, fscore
|
| 189 |
+
"""
|
| 190 |
+
gt_data = self.get_data(scene)
|
| 191 |
+
gt_pcd_path = gt_data.aux.gt_pcd_path
|
| 192 |
+
|
| 193 |
+
# Load ground truth point cloud
|
| 194 |
+
gt_pcd = o3d.io.read_point_cloud(gt_pcd_path)
|
| 195 |
+
|
| 196 |
+
# Load predicted point cloud
|
| 197 |
+
pred_pcd = o3d.io.read_point_cloud(fuse_path)
|
| 198 |
+
|
| 199 |
+
# Evaluate using shared utility function
|
| 200 |
+
metrics = evaluate_3d_reconstruction(
|
| 201 |
+
pred_pcd,
|
| 202 |
+
gt_pcd,
|
| 203 |
+
threshold=self.eval_threshold,
|
| 204 |
+
down_sample=self.down_sample,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
return metrics
|
| 208 |
+
|
| 209 |
+
def _load_gt_meta(self, result_path: str) -> Dict:
|
| 210 |
+
"""Load saved GT meta for fusion."""
|
| 211 |
+
export_dir = os.path.dirname(result_path)
|
| 212 |
+
gt_meta_path = os.path.join(os.path.dirname(export_dir), "gt_meta.npz")
|
| 213 |
+
|
| 214 |
+
if os.path.exists(gt_meta_path):
|
| 215 |
+
data = np.load(gt_meta_path, allow_pickle=True)
|
| 216 |
+
image_files = list(data["image_files"])
|
| 217 |
+
return Dict({
|
| 218 |
+
"extrinsics": data["extrinsics"],
|
| 219 |
+
"intrinsics": data["intrinsics"],
|
| 220 |
+
"image_files": image_files,
|
| 221 |
+
})
|
| 222 |
+
return None
|
| 223 |
+
|
| 224 |
+
def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
|
| 225 |
+
"""
|
| 226 |
+
Fuse per-view depths into a point cloud using TSDF fusion.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
scene: Scene identifier
|
| 230 |
+
result_path: Path to npz file with predicted depths/poses
|
| 231 |
+
fuse_path: Output path for fused point cloud (.ply)
|
| 232 |
+
mode: "recon_unposed" or "recon_posed"
|
| 233 |
+
"""
|
| 234 |
+
# Get full GT data
|
| 235 |
+
full_gt_data = self.get_data(scene)
|
| 236 |
+
|
| 237 |
+
# Try to load saved GT meta (handles frame sampling)
|
| 238 |
+
gt_meta = self._load_gt_meta(result_path)
|
| 239 |
+
if gt_meta is not None:
|
| 240 |
+
gt_data = gt_meta
|
| 241 |
+
image_indices = [
|
| 242 |
+
full_gt_data.image_files.index(f)
|
| 243 |
+
for f in gt_data.image_files
|
| 244 |
+
if f in full_gt_data.image_files
|
| 245 |
+
]
|
| 246 |
+
else:
|
| 247 |
+
gt_data = full_gt_data
|
| 248 |
+
image_indices = list(range(len(full_gt_data.image_files)))
|
| 249 |
+
|
| 250 |
+
_wait_for_file_ready(result_path)
|
| 251 |
+
pred_data = Dict({k: v for k, v in np.load(result_path).items()})
|
| 252 |
+
|
| 253 |
+
# Load images
|
| 254 |
+
images = []
|
| 255 |
+
orig_sizes = []
|
| 256 |
+
for img_idx in image_indices:
|
| 257 |
+
img_path = full_gt_data.image_files[img_idx]
|
| 258 |
+
img = cv2.imread(img_path)
|
| 259 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 260 |
+
images.append(img)
|
| 261 |
+
orig_sizes.append((img.shape[0], img.shape[1]))
|
| 262 |
+
|
| 263 |
+
images = np.stack(images, axis=0)
|
| 264 |
+
|
| 265 |
+
# Prepare depths, intrinsics, extrinsics
|
| 266 |
+
if mode == "recon_unposed":
|
| 267 |
+
depths, intrinsics, extrinsics = self._prep_unposed(
|
| 268 |
+
pred_data, gt_data, full_gt_data, image_indices, orig_sizes, scene=scene
|
| 269 |
+
)
|
| 270 |
+
elif mode == "recon_posed":
|
| 271 |
+
depths, intrinsics, extrinsics = self._prep_posed(
|
| 272 |
+
pred_data, gt_data, full_gt_data, image_indices, orig_sizes, scene=scene
|
| 273 |
+
)
|
| 274 |
+
else:
|
| 275 |
+
raise ValueError(f"Invalid mode: {mode}")
|
| 276 |
+
|
| 277 |
+
# Create TSDF volume and fuse
|
| 278 |
+
volume = create_tsdf_volume(
|
| 279 |
+
voxel_length=self.voxel_length,
|
| 280 |
+
sdf_trunc=self.sdf_trunc,
|
| 281 |
+
)
|
| 282 |
+
mesh = fuse_depth_to_tsdf(
|
| 283 |
+
volume, depths, images, intrinsics, extrinsics, max_depth=self.max_depth
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Sample points from mesh
|
| 287 |
+
pcd = sample_points_from_mesh(mesh, self.sampling_number)
|
| 288 |
+
|
| 289 |
+
# Save point cloud
|
| 290 |
+
os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
|
| 291 |
+
o3d.io.write_point_cloud(fuse_path, pcd)
|
| 292 |
+
|
| 293 |
+
# ------------------------------
|
| 294 |
+
# Private helpers
|
| 295 |
+
# ------------------------------
|
| 296 |
+
|
| 297 |
+
def _prep_unposed(
|
| 298 |
+
self, pred_data: Dict, gt_data: Dict, full_gt_data: Dict,
|
| 299 |
+
image_indices: list, orig_sizes: list, scene: str = None
|
| 300 |
+
) -> tuple:
|
| 301 |
+
"""Prepare depths/intrinsics/extrinsics for recon_unposed mode."""
|
| 302 |
+
# Scale alignment with fixed random_state for reproducibility
|
| 303 |
+
_, _, scale, extrinsics = align_poses_umeyama(
|
| 304 |
+
gt_data.extrinsics.copy(),
|
| 305 |
+
pred_data.extrinsics.copy(),
|
| 306 |
+
return_aligned=True,
|
| 307 |
+
ransac=True,
|
| 308 |
+
random_state=42,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
|
| 312 |
+
|
| 313 |
+
depths_out = []
|
| 314 |
+
intrinsics_out = []
|
| 315 |
+
for i in range(len(pred_data.depth)):
|
| 316 |
+
orig_h, orig_w = orig_sizes[i]
|
| 317 |
+
img_idx = image_indices[i]
|
| 318 |
+
|
| 319 |
+
# Resize depth to original image size
|
| 320 |
+
depth = cv2.resize(
|
| 321 |
+
pred_data.depth[i],
|
| 322 |
+
(orig_w, orig_h),
|
| 323 |
+
interpolation=cv2.INTER_NEAREST,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
# Load GT mask
|
| 327 |
+
gt_zero_mask = self._load_gt_mask(
|
| 328 |
+
full_gt_data.aux.gt_depth_files[img_idx],
|
| 329 |
+
full_gt_data.aux.aliasing_mask_files[img_idx],
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# Mask invalid depths BEFORE scale
|
| 333 |
+
depth = self._mask_invalid_depth(depth, gt_zero_mask)
|
| 334 |
+
|
| 335 |
+
# Apply scale AFTER mask
|
| 336 |
+
depth = depth * scale
|
| 337 |
+
|
| 338 |
+
# Adjust intrinsics to original image size
|
| 339 |
+
h_ratio = orig_h / model_h
|
| 340 |
+
w_ratio = orig_w / model_w
|
| 341 |
+
ixt = pred_data.intrinsics[i].copy()
|
| 342 |
+
ixt[0, :] *= w_ratio
|
| 343 |
+
ixt[1, :] *= h_ratio
|
| 344 |
+
|
| 345 |
+
depths_out.append(depth)
|
| 346 |
+
intrinsics_out.append(ixt)
|
| 347 |
+
|
| 348 |
+
return np.stack(depths_out), np.stack(intrinsics_out), extrinsics
|
| 349 |
+
|
| 350 |
+
def _prep_posed(
|
| 351 |
+
self, pred_data: Dict, gt_data: Dict, full_gt_data: Dict,
|
| 352 |
+
image_indices: list, orig_sizes: list, scene: str = None
|
| 353 |
+
) -> tuple:
|
| 354 |
+
"""Prepare depths/intrinsics/extrinsics for recon_posed mode."""
|
| 355 |
+
# Scale alignment
|
| 356 |
+
_, _, scale, _ = align_poses_umeyama(
|
| 357 |
+
gt_data.extrinsics.copy(),
|
| 358 |
+
pred_data.extrinsics.copy(),
|
| 359 |
+
return_aligned=True,
|
| 360 |
+
ransac=True,
|
| 361 |
+
random_state=42,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
depths_out = []
|
| 365 |
+
for i in range(len(pred_data.depth)):
|
| 366 |
+
orig_h, orig_w = orig_sizes[i]
|
| 367 |
+
img_idx = image_indices[i]
|
| 368 |
+
|
| 369 |
+
# Resize depth to original image size
|
| 370 |
+
depth = cv2.resize(
|
| 371 |
+
pred_data.depth[i],
|
| 372 |
+
(orig_w, orig_h),
|
| 373 |
+
interpolation=cv2.INTER_NEAREST,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
# Load GT mask
|
| 377 |
+
gt_zero_mask = self._load_gt_mask(
|
| 378 |
+
full_gt_data.aux.gt_depth_files[img_idx],
|
| 379 |
+
full_gt_data.aux.aliasing_mask_files[img_idx],
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
# Mask invalid depths BEFORE scale
|
| 383 |
+
depth = self._mask_invalid_depth(depth, gt_zero_mask)
|
| 384 |
+
|
| 385 |
+
# Apply scale AFTER mask
|
| 386 |
+
depth = depth * scale
|
| 387 |
+
|
| 388 |
+
depths_out.append(depth)
|
| 389 |
+
|
| 390 |
+
# Use GT intrinsics and extrinsics
|
| 391 |
+
gt_intrinsics = np.stack([full_gt_data.intrinsics[idx] for idx in image_indices])
|
| 392 |
+
gt_extrinsics = np.stack([full_gt_data.extrinsics[idx] for idx in image_indices])
|
| 393 |
+
|
| 394 |
+
return np.stack(depths_out), gt_intrinsics, gt_extrinsics
|
| 395 |
+
|
| 396 |
+
def _load_gt_mask(self, gt_depth_path: str, aliasing_mask_path: str) -> np.ndarray:
|
| 397 |
+
"""
|
| 398 |
+
Load GT depth and aliasing mask to create valid mask.
|
| 399 |
+
|
| 400 |
+
For HiRoom:
|
| 401 |
+
- GT depth is stored as 16-bit PNG, scaled to 100m range
|
| 402 |
+
- Aliasing mask marks regions to exclude
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
Boolean mask where True = valid region to keep
|
| 406 |
+
"""
|
| 407 |
+
# Load GT depth
|
| 408 |
+
if os.path.exists(gt_depth_path):
|
| 409 |
+
gt_depth = cv2.imread(gt_depth_path, -1) / 65535.0 * 100.0
|
| 410 |
+
else:
|
| 411 |
+
return None
|
| 412 |
+
|
| 413 |
+
# Load aliasing mask
|
| 414 |
+
aliasing_mask = None
|
| 415 |
+
if os.path.exists(aliasing_mask_path):
|
| 416 |
+
aliasing_mask = cv2.imread(aliasing_mask_path, -1) > 0
|
| 417 |
+
|
| 418 |
+
# Valid mask: depth > 0 and not in aliasing region
|
| 419 |
+
valid_mask = gt_depth > 0
|
| 420 |
+
if aliasing_mask is not None:
|
| 421 |
+
valid_mask = np.logical_and(valid_mask, np.logical_not(aliasing_mask))
|
| 422 |
+
|
| 423 |
+
return valid_mask
|
| 424 |
+
|
| 425 |
+
def _mask_invalid_depth(
|
| 426 |
+
self, depth: np.ndarray, gt_zero_mask: np.ndarray = None
|
| 427 |
+
) -> np.ndarray:
|
| 428 |
+
"""Mask invalid depth values by setting them to 0."""
|
| 429 |
+
depth = depth.copy()
|
| 430 |
+
|
| 431 |
+
if gt_zero_mask is not None:
|
| 432 |
+
pred_invalid = np.isnan(depth) | np.isinf(depth)
|
| 433 |
+
combined_mask = np.logical_and(gt_zero_mask, np.logical_not(pred_invalid))
|
| 434 |
+
depth = depth * combined_mask.astype(np.float32)
|
| 435 |
+
else:
|
| 436 |
+
invalid_mask = np.isnan(depth) | np.isinf(depth) | (depth <= 0)
|
| 437 |
+
depth[invalid_mask] = 0.0
|
| 438 |
+
|
| 439 |
+
return depth
|
| 440 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/datasets/scannetpp.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
ScanNet++ Benchmark dataset implementation.
|
| 17 |
+
|
| 18 |
+
ScanNet++ is a high-quality indoor RGB-D dataset with iPhone and DSLR images,
|
| 19 |
+
ground truth camera poses from COLMAP, and high-resolution 3D meshes.
|
| 20 |
+
Reference: https://kaldir.vc.in.tum.de/scannetpp/
|
| 21 |
+
|
| 22 |
+
Evaluation metrics:
|
| 23 |
+
- 3D reconstruction: Accuracy, Completeness, F-score
|
| 24 |
+
- Camera pose estimation: AUC metrics
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import os
|
| 28 |
+
from typing import Dict as TDict
|
| 29 |
+
|
| 30 |
+
import cv2
|
| 31 |
+
import imageio
|
| 32 |
+
import numpy as np
|
| 33 |
+
import open3d as o3d
|
| 34 |
+
from addict import Dict
|
| 35 |
+
|
| 36 |
+
from depth_anything_3.bench.dataset import Dataset, _wait_for_file_ready
|
| 37 |
+
from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
|
| 38 |
+
from depth_anything_3.bench.utils import (
|
| 39 |
+
create_tsdf_volume,
|
| 40 |
+
fuse_depth_to_tsdf,
|
| 41 |
+
nn_correspondance,
|
| 42 |
+
sample_points_from_mesh,
|
| 43 |
+
)
|
| 44 |
+
from depth_anything_3.utils.constants import (
|
| 45 |
+
SCANNETPP_DOWN_SAMPLE,
|
| 46 |
+
SCANNETPP_EVAL_DATA_ROOT,
|
| 47 |
+
SCANNETPP_EVAL_THRESHOLD,
|
| 48 |
+
SCANNETPP_INPUT_H,
|
| 49 |
+
SCANNETPP_INPUT_W,
|
| 50 |
+
SCANNETPP_MAX_DEPTH,
|
| 51 |
+
SCANNETPP_SAMPLING_NUMBER,
|
| 52 |
+
SCANNETPP_SCENES,
|
| 53 |
+
SCANNETPP_SDF_TRUNC,
|
| 54 |
+
SCANNETPP_VOXEL_LENGTH,
|
| 55 |
+
)
|
| 56 |
+
from depth_anything_3.utils.pose_align import align_poses_umeyama
|
| 57 |
+
from depth_anything_3.utils.read_write_model import read_model
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@MV_REGISTRY.register(name="scannetpp")
|
| 61 |
+
@MONO_REGISTRY.register(name="scannetpp")
|
| 62 |
+
class ScanNetPP(Dataset):
|
| 63 |
+
"""
|
| 64 |
+
ScanNet++ Benchmark dataset wrapper for DepthAnything3 evaluation.
|
| 65 |
+
|
| 66 |
+
Supports:
|
| 67 |
+
- Camera pose estimation evaluation (AUC metrics)
|
| 68 |
+
- 3D reconstruction evaluation (Accuracy, Completeness, F-score)
|
| 69 |
+
- TSDF-based point cloud fusion
|
| 70 |
+
|
| 71 |
+
Dataset structure:
|
| 72 |
+
scannetpp/data/
|
| 73 |
+
├── {scene_id}/
|
| 74 |
+
│ ├── merge_dslr_iphone/
|
| 75 |
+
│ │ ├── colmap/sparse_render_rgb/ # COLMAP reconstruction
|
| 76 |
+
│ │ ├── images/ # RGB images
|
| 77 |
+
│ │ └── render_depth/ # GT depth maps
|
| 78 |
+
│ └── scans/
|
| 79 |
+
│ └── mesh_aligned_0.05.ply # Ground truth mesh
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
data_root = SCANNETPP_EVAL_DATA_ROOT
|
| 83 |
+
SCENES = SCANNETPP_SCENES
|
| 84 |
+
|
| 85 |
+
# Input resolution after undistortion and resize
|
| 86 |
+
input_h = SCANNETPP_INPUT_H
|
| 87 |
+
input_w = SCANNETPP_INPUT_W
|
| 88 |
+
|
| 89 |
+
# Evaluation hyperparameters from constants
|
| 90 |
+
max_depth = SCANNETPP_MAX_DEPTH
|
| 91 |
+
sampling_number = SCANNETPP_SAMPLING_NUMBER
|
| 92 |
+
voxel_length = SCANNETPP_VOXEL_LENGTH
|
| 93 |
+
sdf_trunc = SCANNETPP_SDF_TRUNC
|
| 94 |
+
eval_threshold = SCANNETPP_EVAL_THRESHOLD
|
| 95 |
+
down_sample = SCANNETPP_DOWN_SAMPLE
|
| 96 |
+
|
| 97 |
+
def __init__(self):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self._scene_cache = {}
|
| 100 |
+
|
| 101 |
+
# ------------------------------
|
| 102 |
+
# Public API
|
| 103 |
+
# ------------------------------
|
| 104 |
+
|
| 105 |
+
def get_data(self, scene: str) -> Dict:
|
| 106 |
+
"""
|
| 107 |
+
Collect per-view image paths, intrinsics/extrinsics for a scene.
|
| 108 |
+
|
| 109 |
+
Only uses iPhone images (not DSLR).
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
scene: Scene identifier (e.g., "09c1414f1b")
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Dict with:
|
| 116 |
+
- image_files: List[str] - paths to images
|
| 117 |
+
- extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
|
| 118 |
+
- intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
|
| 119 |
+
- aux: Dict with gt_mesh_path, dist, roi, cam_hw, etc.
|
| 120 |
+
"""
|
| 121 |
+
if scene in self._scene_cache:
|
| 122 |
+
return self._scene_cache[scene]
|
| 123 |
+
|
| 124 |
+
input_path = os.path.join(self.data_root, scene, "merge_dslr_iphone")
|
| 125 |
+
colmap_path = os.path.join(input_path, "colmap/sparse_render_rgb")
|
| 126 |
+
image_path = os.path.join(input_path, "images")
|
| 127 |
+
depth_path_dir = os.path.join(input_path, "render_depth")
|
| 128 |
+
|
| 129 |
+
# Read COLMAP model
|
| 130 |
+
cams, images, points3d = read_model(colmap_path)
|
| 131 |
+
|
| 132 |
+
# Map image names to IDs
|
| 133 |
+
name2id = {image.name: k for k, image in images.items()}
|
| 134 |
+
names = sorted([image.name for k, image in images.items()])
|
| 135 |
+
# Only use iPhone images
|
| 136 |
+
names = [name for name in names if "iphone" in name]
|
| 137 |
+
|
| 138 |
+
gt_mesh_path = os.path.join(
|
| 139 |
+
input_path.replace("merge_dslr_iphone", "scans"), "mesh_aligned_0.05.ply"
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
out = Dict({
|
| 143 |
+
"image_files": [],
|
| 144 |
+
"extrinsics": [],
|
| 145 |
+
"intrinsics": [],
|
| 146 |
+
"aux": Dict({
|
| 147 |
+
"gt_mesh_path": gt_mesh_path,
|
| 148 |
+
"dist_list": [],
|
| 149 |
+
"roi_list": [],
|
| 150 |
+
"cam_hw_list": [],
|
| 151 |
+
"ixt_raw_list": [],
|
| 152 |
+
"gt_depth_files": [],
|
| 153 |
+
}),
|
| 154 |
+
})
|
| 155 |
+
|
| 156 |
+
for name in names:
|
| 157 |
+
image = images[name2id[name]]
|
| 158 |
+
img_path = os.path.join(image_path, name)
|
| 159 |
+
|
| 160 |
+
if not os.path.exists(img_path):
|
| 161 |
+
continue
|
| 162 |
+
|
| 163 |
+
# Build extrinsics (world-to-camera)
|
| 164 |
+
ext = np.eye(4, dtype=np.float32)
|
| 165 |
+
ext[:3, :3] = image.qvec2rotmat()
|
| 166 |
+
ext[:3, 3] = image.tvec
|
| 167 |
+
|
| 168 |
+
# Get camera parameters
|
| 169 |
+
cam_id = image.camera_id
|
| 170 |
+
camera = cams[cam_id]
|
| 171 |
+
cam_height, cam_width = camera.height, camera.width
|
| 172 |
+
|
| 173 |
+
# Build intrinsics
|
| 174 |
+
ixt = np.eye(3, dtype=np.float32)
|
| 175 |
+
ixt[0, 0], ixt[1, 1], ixt[0, 2], ixt[1, 2] = camera.params[:4]
|
| 176 |
+
ixt[:2, 2] -= 0.5 # COLMAP convention adjustment
|
| 177 |
+
ixt_raw = ixt.copy()
|
| 178 |
+
|
| 179 |
+
# Handle distortion (OPENCV model)
|
| 180 |
+
dist = np.zeros(5, dtype=np.float32)
|
| 181 |
+
roi = (0, 0, cam_width, cam_height)
|
| 182 |
+
if camera.model == "OPENCV":
|
| 183 |
+
dist[:4] = camera.params[4:]
|
| 184 |
+
ixt, roi = cv2.getOptimalNewCameraMatrix(
|
| 185 |
+
ixt, dist, (cam_width, cam_height), 1, (cam_width, cam_height)
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Depth file path
|
| 189 |
+
frame_name = os.path.basename(name)[:-4] # Remove .jpg
|
| 190 |
+
depth_file = os.path.join(depth_path_dir, f"{frame_name}.png")
|
| 191 |
+
|
| 192 |
+
out.image_files.append(img_path)
|
| 193 |
+
out.extrinsics.append(ext)
|
| 194 |
+
out.intrinsics.append(ixt)
|
| 195 |
+
out.aux.dist_list.append(dist)
|
| 196 |
+
out.aux.roi_list.append(roi)
|
| 197 |
+
out.aux.cam_hw_list.append((cam_height, cam_width))
|
| 198 |
+
out.aux.ixt_raw_list.append(ixt_raw)
|
| 199 |
+
out.aux.gt_depth_files.append(depth_file)
|
| 200 |
+
|
| 201 |
+
out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
|
| 202 |
+
out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
|
| 203 |
+
|
| 204 |
+
print(f"[ScanNet++] {scene}: {len(out.image_files)} images")
|
| 205 |
+
|
| 206 |
+
self._scene_cache[scene] = out
|
| 207 |
+
return out
|
| 208 |
+
|
| 209 |
+
def load_image(self, img_path: str, idx: int, aux: Dict) -> np.ndarray:
|
| 210 |
+
"""
|
| 211 |
+
Load and preprocess image with undistortion and cropping.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
img_path: Path to image file
|
| 215 |
+
idx: Index of the image in the dataset
|
| 216 |
+
aux: Auxiliary data from get_data
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
Preprocessed RGB image
|
| 220 |
+
"""
|
| 221 |
+
image = imageio.imread(img_path).astype(np.uint8)
|
| 222 |
+
ixt_raw = aux.ixt_raw_list[idx]
|
| 223 |
+
ixt = aux.intrinsics[idx] if hasattr(aux, 'intrinsics') else None
|
| 224 |
+
dist = aux.dist_list[idx]
|
| 225 |
+
roi = aux.roi_list[idx]
|
| 226 |
+
|
| 227 |
+
# Undistort using raw intrinsics
|
| 228 |
+
# Use the stored intrinsics from get_data for newCameraMatrix
|
| 229 |
+
stored_ixt = self._scene_cache.get(aux.scene, {}).get('intrinsics', [None])[idx] if hasattr(aux, 'scene') else None
|
| 230 |
+
if stored_ixt is None:
|
| 231 |
+
# Recompute optimal camera matrix for undistortion
|
| 232 |
+
cam_h, cam_w = aux.cam_hw_list[idx]
|
| 233 |
+
ixt_for_undistort = ixt_raw.copy()
|
| 234 |
+
ixt_for_undistort, _ = cv2.getOptimalNewCameraMatrix(
|
| 235 |
+
ixt_raw, dist, (cam_w, cam_h), 1, (cam_w, cam_h)
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
ixt_for_undistort = stored_ixt
|
| 239 |
+
|
| 240 |
+
image = cv2.undistort(image, ixt_raw, dist, newCameraMatrix=ixt_for_undistort)
|
| 241 |
+
|
| 242 |
+
# Crop to ROI
|
| 243 |
+
x, y, w, h = roi
|
| 244 |
+
image = image[y:y+h, x:x+w]
|
| 245 |
+
|
| 246 |
+
# Resize to target resolution
|
| 247 |
+
image = cv2.resize(image, (self.input_w, self.input_h), interpolation=cv2.INTER_AREA)
|
| 248 |
+
|
| 249 |
+
return image
|
| 250 |
+
|
| 251 |
+
def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
|
| 252 |
+
"""
|
| 253 |
+
Evaluate fused point cloud against ScanNet++ ground truth mesh.
|
| 254 |
+
|
| 255 |
+
Uses AABB cropping to only evaluate points within GT bounding box.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
scene: Scene identifier
|
| 259 |
+
fuse_path: Path to fused point cloud (.ply)
|
| 260 |
+
|
| 261 |
+
Returns:
|
| 262 |
+
Dict with metrics: acc, comp, overall, precision, recall, fscore
|
| 263 |
+
"""
|
| 264 |
+
gt_data = self.get_data(scene)
|
| 265 |
+
gt_mesh_path = gt_data.aux.gt_mesh_path
|
| 266 |
+
|
| 267 |
+
# Load ground truth mesh and sample points
|
| 268 |
+
gt_mesh = o3d.io.read_triangle_mesh(gt_mesh_path)
|
| 269 |
+
gt_pcd = sample_points_from_mesh(gt_mesh, self.sampling_number)
|
| 270 |
+
|
| 271 |
+
# Load predicted point cloud
|
| 272 |
+
pred_pcd = o3d.io.read_point_cloud(fuse_path)
|
| 273 |
+
|
| 274 |
+
# Crop prediction to GT bounding box (with 0.1m margin)
|
| 275 |
+
aabb = gt_pcd.get_axis_aligned_bounding_box()
|
| 276 |
+
points = np.asarray(pred_pcd.points)
|
| 277 |
+
inside_mask = (
|
| 278 |
+
(points[:, 0] >= aabb.min_bound[0] - 0.1) &
|
| 279 |
+
(points[:, 0] <= aabb.max_bound[0] + 0.1) &
|
| 280 |
+
(points[:, 1] >= aabb.min_bound[1] - 0.1) &
|
| 281 |
+
(points[:, 1] <= aabb.max_bound[1] + 0.1) &
|
| 282 |
+
(points[:, 2] >= aabb.min_bound[2] - 0.1) &
|
| 283 |
+
(points[:, 2] <= aabb.max_bound[2] + 0.1)
|
| 284 |
+
)
|
| 285 |
+
pred_pcd = pred_pcd.select_by_index(inside_mask.nonzero()[0])
|
| 286 |
+
|
| 287 |
+
# Downsample
|
| 288 |
+
if self.down_sample > 0:
|
| 289 |
+
pred_pcd = pred_pcd.voxel_down_sample(self.down_sample)
|
| 290 |
+
gt_pcd = gt_pcd.voxel_down_sample(self.down_sample)
|
| 291 |
+
|
| 292 |
+
verts_pred = np.asarray(pred_pcd.points)
|
| 293 |
+
verts_gt = np.asarray(gt_pcd.points)
|
| 294 |
+
|
| 295 |
+
if len(verts_pred) == 0 or len(verts_gt) == 0:
|
| 296 |
+
return {
|
| 297 |
+
"acc": float("inf"),
|
| 298 |
+
"comp": float("inf"),
|
| 299 |
+
"overall": float("inf"),
|
| 300 |
+
"precision": 0.0,
|
| 301 |
+
"recall": 0.0,
|
| 302 |
+
"fscore": 0.0,
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
# Compute distances
|
| 306 |
+
dist_pred_to_gt = nn_correspondance(verts_gt, verts_pred)
|
| 307 |
+
dist_gt_to_pred = nn_correspondance(verts_pred, verts_gt)
|
| 308 |
+
|
| 309 |
+
# Compute metrics
|
| 310 |
+
accuracy = float(np.mean(dist_pred_to_gt))
|
| 311 |
+
completeness = float(np.mean(dist_gt_to_pred))
|
| 312 |
+
overall = (accuracy + completeness) / 2
|
| 313 |
+
|
| 314 |
+
precision = float(np.mean((dist_pred_to_gt < self.eval_threshold).astype(float)))
|
| 315 |
+
recall = float(np.mean((dist_gt_to_pred < self.eval_threshold).astype(float)))
|
| 316 |
+
|
| 317 |
+
if precision + recall > 0:
|
| 318 |
+
fscore = 2 * precision * recall / (precision + recall)
|
| 319 |
+
else:
|
| 320 |
+
fscore = 0.0
|
| 321 |
+
|
| 322 |
+
return {
|
| 323 |
+
"acc": accuracy,
|
| 324 |
+
"comp": completeness,
|
| 325 |
+
"overall": overall,
|
| 326 |
+
"precision": precision,
|
| 327 |
+
"recall": recall,
|
| 328 |
+
"fscore": fscore,
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
def _load_gt_meta(self, result_path: str) -> Dict:
|
| 332 |
+
"""Load saved GT meta for fusion."""
|
| 333 |
+
export_dir = os.path.dirname(result_path)
|
| 334 |
+
gt_meta_path = os.path.join(os.path.dirname(export_dir), "gt_meta.npz")
|
| 335 |
+
|
| 336 |
+
if os.path.exists(gt_meta_path):
|
| 337 |
+
data = np.load(gt_meta_path, allow_pickle=True)
|
| 338 |
+
image_files = list(data["image_files"])
|
| 339 |
+
|
| 340 |
+
# Reconstruct aux data from image files
|
| 341 |
+
return Dict({
|
| 342 |
+
"extrinsics": data["extrinsics"],
|
| 343 |
+
"intrinsics": data["intrinsics"],
|
| 344 |
+
"image_files": image_files,
|
| 345 |
+
})
|
| 346 |
+
return None
|
| 347 |
+
|
| 348 |
+
def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
|
| 349 |
+
"""
|
| 350 |
+
Fuse per-view depths into a point cloud using TSDF fusion.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
scene: Scene identifier
|
| 354 |
+
result_path: Path to npz file with predicted depths/poses
|
| 355 |
+
fuse_path: Output path for fused point cloud (.ply)
|
| 356 |
+
mode: "recon_unposed" or "recon_posed"
|
| 357 |
+
"""
|
| 358 |
+
# Get GT data
|
| 359 |
+
full_gt_data = self.get_data(scene)
|
| 360 |
+
|
| 361 |
+
# Try to load saved GT meta (handles frame sampling)
|
| 362 |
+
gt_meta = self._load_gt_meta(result_path)
|
| 363 |
+
if gt_meta is not None:
|
| 364 |
+
gt_data = gt_meta
|
| 365 |
+
# Need to rebuild aux from full GT data based on image indices
|
| 366 |
+
image_indices = [
|
| 367 |
+
full_gt_data.image_files.index(f)
|
| 368 |
+
for f in gt_data.image_files
|
| 369 |
+
if f in full_gt_data.image_files
|
| 370 |
+
]
|
| 371 |
+
else:
|
| 372 |
+
gt_data = full_gt_data
|
| 373 |
+
image_indices = list(range(len(full_gt_data.image_files)))
|
| 374 |
+
|
| 375 |
+
_wait_for_file_ready(result_path)
|
| 376 |
+
pred_data = Dict({k: v for k, v in np.load(result_path).items()})
|
| 377 |
+
|
| 378 |
+
# Load and preprocess images
|
| 379 |
+
images = []
|
| 380 |
+
for idx, img_idx in enumerate(image_indices):
|
| 381 |
+
img_path = full_gt_data.image_files[img_idx]
|
| 382 |
+
image = imageio.imread(img_path).astype(np.uint8)
|
| 383 |
+
|
| 384 |
+
# Undistort and crop
|
| 385 |
+
ixt_raw = full_gt_data.aux.ixt_raw_list[img_idx]
|
| 386 |
+
ixt = full_gt_data.intrinsics[img_idx]
|
| 387 |
+
dist = full_gt_data.aux.dist_list[img_idx]
|
| 388 |
+
roi = full_gt_data.aux.roi_list[img_idx]
|
| 389 |
+
|
| 390 |
+
image = cv2.undistort(image, ixt_raw, dist, newCameraMatrix=ixt)
|
| 391 |
+
x, y, w, h = roi
|
| 392 |
+
image = image[y:y+h, x:x+w]
|
| 393 |
+
image = cv2.resize(image, (self.input_w, self.input_h), interpolation=cv2.INTER_AREA)
|
| 394 |
+
|
| 395 |
+
images.append(image)
|
| 396 |
+
|
| 397 |
+
images = np.stack(images, axis=0)
|
| 398 |
+
|
| 399 |
+
# Prepare depths, intrinsics, extrinsics
|
| 400 |
+
if mode == "recon_unposed":
|
| 401 |
+
depths, intrinsics, extrinsics = self._prep_unposed(
|
| 402 |
+
pred_data, gt_data, full_gt_data, image_indices, scene=scene
|
| 403 |
+
)
|
| 404 |
+
elif mode == "recon_posed":
|
| 405 |
+
depths, intrinsics, extrinsics = self._prep_posed(
|
| 406 |
+
pred_data, gt_data, full_gt_data, image_indices, scene=scene
|
| 407 |
+
)
|
| 408 |
+
else:
|
| 409 |
+
raise ValueError(f"Invalid mode: {mode}")
|
| 410 |
+
|
| 411 |
+
# Create TSDF volume and fuse
|
| 412 |
+
volume = create_tsdf_volume(
|
| 413 |
+
voxel_length=self.voxel_length,
|
| 414 |
+
sdf_trunc=self.sdf_trunc,
|
| 415 |
+
)
|
| 416 |
+
mesh = fuse_depth_to_tsdf(
|
| 417 |
+
volume, depths, images, intrinsics, extrinsics, max_depth=self.max_depth
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# Sample points from mesh
|
| 421 |
+
pcd = sample_points_from_mesh(mesh, self.sampling_number)
|
| 422 |
+
|
| 423 |
+
# Save point cloud
|
| 424 |
+
os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
|
| 425 |
+
o3d.io.write_point_cloud(fuse_path, pcd)
|
| 426 |
+
|
| 427 |
+
# ------------------------------
|
| 428 |
+
# Private helpers
|
| 429 |
+
# ------------------------------
|
| 430 |
+
|
| 431 |
+
def _prep_unposed(
|
| 432 |
+
self, pred_data: Dict, gt_data: Dict, full_gt_data: Dict,
|
| 433 |
+
image_indices: list, scene: str = None
|
| 434 |
+
) -> tuple:
|
| 435 |
+
"""Prepare depths/intrinsics/extrinsics for recon_unposed mode."""
|
| 436 |
+
# Scale alignment with fixed random_state for reproducibility
|
| 437 |
+
_, _, scale, extrinsics = align_poses_umeyama(
|
| 438 |
+
gt_data.extrinsics.copy(),
|
| 439 |
+
pred_data.extrinsics.copy(),
|
| 440 |
+
return_aligned=True,
|
| 441 |
+
ransac=True,
|
| 442 |
+
random_state=42,
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
|
| 446 |
+
|
| 447 |
+
depths_out = []
|
| 448 |
+
intrinsics_out = []
|
| 449 |
+
for i in range(len(pred_data.depth)):
|
| 450 |
+
img_idx = image_indices[i]
|
| 451 |
+
|
| 452 |
+
# Get original image size (after undistort+crop, before resize to input_h/w)
|
| 453 |
+
orig_h, orig_w = full_gt_data.aux.cam_hw_list[img_idx]
|
| 454 |
+
|
| 455 |
+
# Step 1: nearest resize to original image size
|
| 456 |
+
depth = cv2.resize(
|
| 457 |
+
pred_data.depth[i],
|
| 458 |
+
(orig_w, orig_h),
|
| 459 |
+
interpolation=cv2.INTER_NEAREST,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
# Step 2: linear resize to target resolution
|
| 463 |
+
depth = cv2.resize(
|
| 464 |
+
depth,
|
| 465 |
+
(self.input_w, self.input_h),
|
| 466 |
+
interpolation=cv2.INTER_LINEAR,
|
| 467 |
+
).astype(np.float32)
|
| 468 |
+
|
| 469 |
+
# Load GT depth for masking
|
| 470 |
+
gt_zero_mask = self._load_gt_mask(full_gt_data.aux.gt_depth_files[img_idx])
|
| 471 |
+
|
| 472 |
+
# Mask invalid depths BEFORE scale
|
| 473 |
+
depth = self._mask_invalid_depth(depth, gt_zero_mask)
|
| 474 |
+
|
| 475 |
+
# Apply scale AFTER mask
|
| 476 |
+
depth = depth * scale
|
| 477 |
+
|
| 478 |
+
# Adjust intrinsics to target resolution
|
| 479 |
+
h_ratio = self.input_h / model_h
|
| 480 |
+
w_ratio = self.input_w / model_w
|
| 481 |
+
ixt = pred_data.intrinsics[i].copy()
|
| 482 |
+
ixt[0, :] *= w_ratio
|
| 483 |
+
ixt[1, :] *= h_ratio
|
| 484 |
+
|
| 485 |
+
depths_out.append(depth)
|
| 486 |
+
intrinsics_out.append(ixt)
|
| 487 |
+
|
| 488 |
+
return np.stack(depths_out), np.stack(intrinsics_out), extrinsics
|
| 489 |
+
|
| 490 |
+
def _prep_posed(
|
| 491 |
+
self, pred_data: Dict, gt_data: Dict, full_gt_data: Dict,
|
| 492 |
+
image_indices: list, scene: str = None
|
| 493 |
+
) -> tuple:
|
| 494 |
+
"""Prepare depths/intrinsics/extrinsics for recon_posed mode."""
|
| 495 |
+
# Scale alignment
|
| 496 |
+
_, _, scale, _ = align_poses_umeyama(
|
| 497 |
+
gt_data.extrinsics.copy(),
|
| 498 |
+
pred_data.extrinsics.copy(),
|
| 499 |
+
return_aligned=True,
|
| 500 |
+
ransac=True,
|
| 501 |
+
random_state=42,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
depths_out = []
|
| 505 |
+
intrinsics_out = []
|
| 506 |
+
extrinsics_out = []
|
| 507 |
+
|
| 508 |
+
for i in range(len(pred_data.depth)):
|
| 509 |
+
img_idx = image_indices[i]
|
| 510 |
+
|
| 511 |
+
# Get original image size (after undistort+crop, before resize to input_h/w)
|
| 512 |
+
orig_h, orig_w = full_gt_data.aux.cam_hw_list[img_idx]
|
| 513 |
+
|
| 514 |
+
# Step 1: nearest resize to original image size
|
| 515 |
+
depth = cv2.resize(
|
| 516 |
+
pred_data.depth[i],
|
| 517 |
+
(orig_w, orig_h),
|
| 518 |
+
interpolation=cv2.INTER_NEAREST,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# Step 2: linear resize to target resolution
|
| 522 |
+
depth = cv2.resize(
|
| 523 |
+
depth,
|
| 524 |
+
(self.input_w, self.input_h),
|
| 525 |
+
interpolation=cv2.INTER_LINEAR,
|
| 526 |
+
).astype(np.float32)
|
| 527 |
+
|
| 528 |
+
# Load GT depth for masking
|
| 529 |
+
gt_zero_mask = self._load_gt_mask(full_gt_data.aux.gt_depth_files[img_idx])
|
| 530 |
+
|
| 531 |
+
# Mask invalid depths BEFORE scale
|
| 532 |
+
depth = self._mask_invalid_depth(depth, gt_zero_mask)
|
| 533 |
+
|
| 534 |
+
# Apply scale AFTER mask
|
| 535 |
+
depth = depth * scale
|
| 536 |
+
|
| 537 |
+
depths_out.append(depth)
|
| 538 |
+
|
| 539 |
+
# Get GT intrinsics and scale to target resolution
|
| 540 |
+
ixt = full_gt_data.intrinsics[img_idx].copy()
|
| 541 |
+
cam_h, cam_w = full_gt_data.aux.cam_hw_list[img_idx]
|
| 542 |
+
ixt[:2, 2] += 0.5 # Undo COLMAP convention
|
| 543 |
+
ixt[0, :] *= self.input_w / cam_w
|
| 544 |
+
ixt[1, :] *= self.input_h / cam_h
|
| 545 |
+
intrinsics_out.append(ixt)
|
| 546 |
+
|
| 547 |
+
extrinsics_out.append(full_gt_data.extrinsics[img_idx])
|
| 548 |
+
|
| 549 |
+
return np.stack(depths_out), np.stack(intrinsics_out), np.stack(extrinsics_out)
|
| 550 |
+
|
| 551 |
+
def _load_gt_mask(self, gt_depth_path: str) -> np.ndarray:
|
| 552 |
+
"""
|
| 553 |
+
Load GT depth and create valid mask.
|
| 554 |
+
|
| 555 |
+
For ScanNet++, GT depth is stored as 16-bit PNG in millimeters.
|
| 556 |
+
|
| 557 |
+
Returns:
|
| 558 |
+
Boolean mask where True = valid region to keep
|
| 559 |
+
"""
|
| 560 |
+
if not os.path.exists(gt_depth_path):
|
| 561 |
+
return None
|
| 562 |
+
|
| 563 |
+
gt_depth = imageio.imread(gt_depth_path) / 1000.0 # mm to meters
|
| 564 |
+
|
| 565 |
+
# Resize to target resolution
|
| 566 |
+
gt_depth = cv2.resize(
|
| 567 |
+
gt_depth,
|
| 568 |
+
(self.input_w, self.input_h),
|
| 569 |
+
interpolation=cv2.INTER_LINEAR,
|
| 570 |
+
).astype(np.float32)
|
| 571 |
+
|
| 572 |
+
# Valid mask: depth > 0 and not inf
|
| 573 |
+
valid_mask = np.logical_and(gt_depth > 0, gt_depth != np.inf)
|
| 574 |
+
return valid_mask
|
| 575 |
+
|
| 576 |
+
def _mask_invalid_depth(
|
| 577 |
+
self, depth: np.ndarray, gt_zero_mask: np.ndarray = None
|
| 578 |
+
) -> np.ndarray:
|
| 579 |
+
"""Mask invalid depth values by setting them to 0."""
|
| 580 |
+
depth = depth.copy()
|
| 581 |
+
|
| 582 |
+
if gt_zero_mask is not None:
|
| 583 |
+
pred_invalid = np.isnan(depth) | np.isinf(depth)
|
| 584 |
+
combined_mask = np.logical_and(gt_zero_mask, np.logical_not(pred_invalid))
|
| 585 |
+
depth = depth * combined_mask.astype(np.float32)
|
| 586 |
+
else:
|
| 587 |
+
invalid_mask = np.isnan(depth) | np.isinf(depth) | (depth <= 0)
|
| 588 |
+
depth[invalid_mask] = 0.0
|
| 589 |
+
|
| 590 |
+
return depth
|
| 591 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/datasets/sevenscenes.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
7Scenes Benchmark dataset implementation.
|
| 17 |
+
|
| 18 |
+
7Scenes is an indoor RGB-D dataset with ground truth camera poses and 3D meshes.
|
| 19 |
+
Reference: https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/
|
| 20 |
+
|
| 21 |
+
Evaluation metrics:
|
| 22 |
+
- 3D reconstruction: Accuracy, Completeness, F-score
|
| 23 |
+
- Camera pose estimation: AUC metrics
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
from typing import Dict as TDict
|
| 28 |
+
|
| 29 |
+
import cv2
|
| 30 |
+
import numpy as np
|
| 31 |
+
import open3d as o3d
|
| 32 |
+
from addict import Dict
|
| 33 |
+
|
| 34 |
+
from depth_anything_3.bench.dataset import Dataset, _wait_for_file_ready
|
| 35 |
+
from depth_anything_3.bench.registries import MONO_REGISTRY, MV_REGISTRY
|
| 36 |
+
from depth_anything_3.bench.utils import (
|
| 37 |
+
create_tsdf_volume,
|
| 38 |
+
evaluate_3d_reconstruction,
|
| 39 |
+
fuse_depth_to_tsdf,
|
| 40 |
+
sample_points_from_mesh,
|
| 41 |
+
)
|
| 42 |
+
from depth_anything_3.utils.constants import (
|
| 43 |
+
SEVENSCENES_CX,
|
| 44 |
+
SEVENSCENES_CY,
|
| 45 |
+
SEVENSCENES_DOWN_SAMPLE,
|
| 46 |
+
SEVENSCENES_EVAL_DATA_ROOT,
|
| 47 |
+
SEVENSCENES_EVAL_THRESHOLD,
|
| 48 |
+
SEVENSCENES_FX,
|
| 49 |
+
SEVENSCENES_FY,
|
| 50 |
+
SEVENSCENES_MAX_DEPTH,
|
| 51 |
+
SEVENSCENES_SAMPLING_NUMBER,
|
| 52 |
+
SEVENSCENES_SCENES,
|
| 53 |
+
SEVENSCENES_SDF_TRUNC,
|
| 54 |
+
SEVENSCENES_VOXEL_LENGTH,
|
| 55 |
+
)
|
| 56 |
+
from depth_anything_3.utils.pose_align import align_poses_umeyama
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@MV_REGISTRY.register(name="7scenes")
|
| 60 |
+
@MONO_REGISTRY.register(name="7scenes")
|
| 61 |
+
class SevenScenes(Dataset):
|
| 62 |
+
"""
|
| 63 |
+
7Scenes Benchmark dataset wrapper for DepthAnything3 evaluation.
|
| 64 |
+
|
| 65 |
+
Supports:
|
| 66 |
+
- Camera pose estimation evaluation (AUC metrics)
|
| 67 |
+
- 3D reconstruction evaluation (Accuracy, Completeness, F-score)
|
| 68 |
+
- TSDF-based point cloud fusion
|
| 69 |
+
|
| 70 |
+
Dataset structure:
|
| 71 |
+
7scenes/
|
| 72 |
+
├── 7Scenes/
|
| 73 |
+
│ ├── {scene}/
|
| 74 |
+
│ │ └── seq-01/ (or seq-02 for stairs)
|
| 75 |
+
│ │ ├── frame-XXXXXX.color.png
|
| 76 |
+
│ │ ├── frame-XXXXXX.depth.png
|
| 77 |
+
│ │ └── frame-XXXXXX.pose.txt
|
| 78 |
+
│ └── meshes/
|
| 79 |
+
│ └── {scene}.ply # Ground truth mesh
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
data_root = SEVENSCENES_EVAL_DATA_ROOT
|
| 83 |
+
SCENES = SEVENSCENES_SCENES
|
| 84 |
+
|
| 85 |
+
# Evaluation hyperparameters from constants
|
| 86 |
+
max_depth = SEVENSCENES_MAX_DEPTH
|
| 87 |
+
sampling_number = SEVENSCENES_SAMPLING_NUMBER
|
| 88 |
+
voxel_length = SEVENSCENES_VOXEL_LENGTH
|
| 89 |
+
sdf_trunc = SEVENSCENES_SDF_TRUNC
|
| 90 |
+
eval_threshold = SEVENSCENES_EVAL_THRESHOLD
|
| 91 |
+
down_sample = SEVENSCENES_DOWN_SAMPLE
|
| 92 |
+
|
| 93 |
+
# Fixed camera intrinsics for all 7Scenes images
|
| 94 |
+
fx = SEVENSCENES_FX
|
| 95 |
+
fy = SEVENSCENES_FY
|
| 96 |
+
cx = SEVENSCENES_CX
|
| 97 |
+
cy = SEVENSCENES_CY
|
| 98 |
+
|
| 99 |
+
def __init__(self):
|
| 100 |
+
super().__init__()
|
| 101 |
+
self._scene_cache = {}
|
| 102 |
+
|
| 103 |
+
# ------------------------------
|
| 104 |
+
# Public API
|
| 105 |
+
# ------------------------------
|
| 106 |
+
|
| 107 |
+
def get_data(self, scene: str) -> Dict:
|
| 108 |
+
"""
|
| 109 |
+
Collect per-view image paths, intrinsics/extrinsics for a scene.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
scene: Scene identifier (e.g., "chess")
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Dict with:
|
| 116 |
+
- image_files: List[str] - paths to images
|
| 117 |
+
- extrinsics: np.ndarray [N, 4, 4] - world-to-camera transforms
|
| 118 |
+
- intrinsics: np.ndarray [N, 3, 3] - camera intrinsics
|
| 119 |
+
- aux: Dict with gt_mesh_path, gt_depth_files
|
| 120 |
+
"""
|
| 121 |
+
if scene in self._scene_cache:
|
| 122 |
+
return self._scene_cache[scene]
|
| 123 |
+
|
| 124 |
+
# Different sequence for stairs scene
|
| 125 |
+
if scene == "stairs":
|
| 126 |
+
data_folder = os.path.join(self.data_root, "7Scenes", scene, "seq-02")
|
| 127 |
+
n_imgs = 500
|
| 128 |
+
else:
|
| 129 |
+
data_folder = os.path.join(self.data_root, "7Scenes", scene, "seq-01")
|
| 130 |
+
n_imgs = 1000
|
| 131 |
+
|
| 132 |
+
gt_mesh_path = os.path.join(self.data_root, "7Scenes", "meshes", f"{scene}.ply")
|
| 133 |
+
|
| 134 |
+
# Fixed intrinsics for all images
|
| 135 |
+
ixt = np.array([
|
| 136 |
+
[self.fx, 0, self.cx],
|
| 137 |
+
[0, self.fy, self.cy],
|
| 138 |
+
[0, 0, 1],
|
| 139 |
+
], dtype=np.float32)
|
| 140 |
+
|
| 141 |
+
out = Dict({
|
| 142 |
+
"image_files": [],
|
| 143 |
+
"extrinsics": [],
|
| 144 |
+
"intrinsics": [],
|
| 145 |
+
"aux": Dict({
|
| 146 |
+
"gt_mesh_path": gt_mesh_path,
|
| 147 |
+
"gt_depth_files": [],
|
| 148 |
+
}),
|
| 149 |
+
})
|
| 150 |
+
|
| 151 |
+
for i in range(0, n_imgs, 1):
|
| 152 |
+
img_path = os.path.join(data_folder, f"frame-{i:06d}.color.png")
|
| 153 |
+
pose_path = os.path.join(data_folder, f"frame-{i:06d}.pose.txt")
|
| 154 |
+
depth_path = os.path.join(data_folder, f"frame-{i:06d}.depth.png")
|
| 155 |
+
|
| 156 |
+
if not os.path.exists(img_path) or not os.path.exists(pose_path):
|
| 157 |
+
continue
|
| 158 |
+
|
| 159 |
+
# Load camera-to-world pose and convert to world-to-camera (extrinsic)
|
| 160 |
+
c2w = np.loadtxt(pose_path)
|
| 161 |
+
ext = np.linalg.inv(c2w).astype(np.float32)
|
| 162 |
+
|
| 163 |
+
out.image_files.append(img_path)
|
| 164 |
+
out.extrinsics.append(ext)
|
| 165 |
+
out.intrinsics.append(ixt.copy())
|
| 166 |
+
out.aux.gt_depth_files.append(depth_path)
|
| 167 |
+
|
| 168 |
+
out.extrinsics = np.asarray(out.extrinsics, dtype=np.float32)
|
| 169 |
+
out.intrinsics = np.asarray(out.intrinsics, dtype=np.float32)
|
| 170 |
+
|
| 171 |
+
print(f"[7Scenes] {scene}: {len(out.image_files)} images")
|
| 172 |
+
|
| 173 |
+
self._scene_cache[scene] = out
|
| 174 |
+
return out
|
| 175 |
+
|
| 176 |
+
def eval3d(self, scene: str, fuse_path: str) -> TDict[str, float]:
|
| 177 |
+
"""
|
| 178 |
+
Evaluate fused point cloud against 7Scenes ground truth mesh.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
scene: Scene identifier
|
| 182 |
+
fuse_path: Path to fused point cloud (.ply)
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
Dict with metrics: acc, comp, overall, precision, recall, fscore
|
| 186 |
+
"""
|
| 187 |
+
gt_data = self.get_data(scene)
|
| 188 |
+
gt_mesh_path = gt_data.aux.gt_mesh_path
|
| 189 |
+
|
| 190 |
+
# Load and sample ground truth mesh
|
| 191 |
+
gt_mesh = o3d.io.read_triangle_mesh(gt_mesh_path)
|
| 192 |
+
gt_pcd = sample_points_from_mesh(gt_mesh, self.sampling_number)
|
| 193 |
+
|
| 194 |
+
# Load predicted point cloud
|
| 195 |
+
pred_pcd = o3d.io.read_point_cloud(fuse_path)
|
| 196 |
+
|
| 197 |
+
# Evaluate using shared utility function
|
| 198 |
+
metrics = evaluate_3d_reconstruction(
|
| 199 |
+
pred_pcd,
|
| 200 |
+
gt_pcd,
|
| 201 |
+
threshold=self.eval_threshold,
|
| 202 |
+
down_sample=self.down_sample,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
return metrics
|
| 206 |
+
|
| 207 |
+
def _load_gt_meta(self, result_path: str) -> Dict:
|
| 208 |
+
"""
|
| 209 |
+
Load saved GT meta (extrinsics, intrinsics, image_files) for fusion.
|
| 210 |
+
|
| 211 |
+
This is needed when frames are sampled, so fuse3d uses the correct
|
| 212 |
+
(sampled) GT instead of full dataset GT.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
result_path: Path to npz file (used to derive gt_meta.npz path)
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
Dict with GT data, or None if gt_meta.npz doesn't exist
|
| 219 |
+
"""
|
| 220 |
+
export_dir = os.path.dirname(result_path) # exports/mini_npz/
|
| 221 |
+
gt_meta_path = os.path.join(os.path.dirname(export_dir), "gt_meta.npz")
|
| 222 |
+
|
| 223 |
+
if os.path.exists(gt_meta_path):
|
| 224 |
+
data = np.load(gt_meta_path, allow_pickle=True)
|
| 225 |
+
# Build aux with gt_depth_files derived from image_files
|
| 226 |
+
image_files = list(data["image_files"])
|
| 227 |
+
gt_depth_files = [
|
| 228 |
+
img_path.replace("color", "depth").replace(".color.", ".depth.")
|
| 229 |
+
for img_path in image_files
|
| 230 |
+
]
|
| 231 |
+
return Dict({
|
| 232 |
+
"extrinsics": data["extrinsics"],
|
| 233 |
+
"intrinsics": data["intrinsics"],
|
| 234 |
+
"image_files": image_files,
|
| 235 |
+
"aux": Dict({"gt_depth_files": gt_depth_files}),
|
| 236 |
+
})
|
| 237 |
+
return None
|
| 238 |
+
|
| 239 |
+
def fuse3d(self, scene: str, result_path: str, fuse_path: str, mode: str) -> None:
|
| 240 |
+
"""
|
| 241 |
+
Fuse per-view depths into a point cloud using TSDF fusion.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
scene: Scene identifier
|
| 245 |
+
result_path: Path to npz file with predicted depths/poses
|
| 246 |
+
fuse_path: Output path for fused point cloud (.ply)
|
| 247 |
+
mode: "recon_unposed" or "recon_posed"
|
| 248 |
+
"""
|
| 249 |
+
# Try to load saved GT meta (handles frame sampling)
|
| 250 |
+
gt_meta = self._load_gt_meta(result_path)
|
| 251 |
+
if gt_meta is not None:
|
| 252 |
+
gt_data = gt_meta
|
| 253 |
+
else:
|
| 254 |
+
gt_data = self.get_data(scene)
|
| 255 |
+
_wait_for_file_ready(result_path)
|
| 256 |
+
pred_data = Dict({k: v for k, v in np.load(result_path).items()})
|
| 257 |
+
|
| 258 |
+
# Load original images (keep original size)
|
| 259 |
+
images = []
|
| 260 |
+
orig_sizes = []
|
| 261 |
+
for img_path in gt_data.image_files:
|
| 262 |
+
img = cv2.imread(img_path)
|
| 263 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 264 |
+
images.append(img)
|
| 265 |
+
orig_sizes.append((img.shape[0], img.shape[1]))
|
| 266 |
+
|
| 267 |
+
# Prepare depths, intrinsics, extrinsics
|
| 268 |
+
if mode == "recon_unposed":
|
| 269 |
+
depths, intrinsics, extrinsics = self._prep_unposed(
|
| 270 |
+
pred_data, gt_data, orig_sizes, scene=scene
|
| 271 |
+
)
|
| 272 |
+
elif mode == "recon_posed":
|
| 273 |
+
depths, intrinsics, extrinsics = self._prep_posed(
|
| 274 |
+
pred_data, gt_data, orig_sizes, scene=scene
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
raise ValueError(f"Invalid mode: {mode}")
|
| 278 |
+
|
| 279 |
+
images = np.stack(images, axis=0)
|
| 280 |
+
|
| 281 |
+
# Create TSDF volume and fuse
|
| 282 |
+
volume = create_tsdf_volume(
|
| 283 |
+
voxel_length=self.voxel_length,
|
| 284 |
+
sdf_trunc=self.sdf_trunc,
|
| 285 |
+
)
|
| 286 |
+
mesh = fuse_depth_to_tsdf(
|
| 287 |
+
volume, depths, images, intrinsics, extrinsics, max_depth=self.max_depth
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# Sample points from mesh
|
| 291 |
+
pcd = sample_points_from_mesh(mesh, self.sampling_number)
|
| 292 |
+
|
| 293 |
+
# Save point cloud
|
| 294 |
+
os.makedirs(os.path.dirname(fuse_path), exist_ok=True)
|
| 295 |
+
o3d.io.write_point_cloud(fuse_path, pcd)
|
| 296 |
+
|
| 297 |
+
# ------------------------------
|
| 298 |
+
# Private helpers
|
| 299 |
+
# ------------------------------
|
| 300 |
+
|
| 301 |
+
def _prep_unposed(
|
| 302 |
+
self, pred_data: Dict, gt_data: Dict, orig_sizes: list, scene: str
|
| 303 |
+
) -> tuple:
|
| 304 |
+
"""
|
| 305 |
+
Prepare depths/intrinsics/extrinsics for recon_unposed mode.
|
| 306 |
+
|
| 307 |
+
Similar to ETH3D but uses GT depth for masking instead of separate mask files.
|
| 308 |
+
"""
|
| 309 |
+
# Scale alignment with fixed random_state for reproducibility
|
| 310 |
+
_, _, scale, extrinsics = align_poses_umeyama(
|
| 311 |
+
gt_data.extrinsics.copy(),
|
| 312 |
+
pred_data.extrinsics.copy(),
|
| 313 |
+
return_aligned=True,
|
| 314 |
+
ransac=True,
|
| 315 |
+
random_state=42,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
|
| 319 |
+
|
| 320 |
+
depths_out = []
|
| 321 |
+
intrinsics_out = []
|
| 322 |
+
for i in range(len(pred_data.depth)):
|
| 323 |
+
orig_h, orig_w = orig_sizes[i]
|
| 324 |
+
|
| 325 |
+
# Resize depth to original image size (nearest interpolation)
|
| 326 |
+
depth = cv2.resize(
|
| 327 |
+
pred_data.depth[i],
|
| 328 |
+
(orig_w, orig_h),
|
| 329 |
+
interpolation=cv2.INTER_NEAREST,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# Load GT depth for masking
|
| 333 |
+
gt_zero_mask = self._load_gt_mask(gt_data.aux.gt_depth_files[i])
|
| 334 |
+
|
| 335 |
+
# Mask invalid depths BEFORE scale
|
| 336 |
+
depth = self._mask_invalid_depth(depth, gt_zero_mask)
|
| 337 |
+
|
| 338 |
+
# Apply scale AFTER mask
|
| 339 |
+
depth = depth * scale
|
| 340 |
+
|
| 341 |
+
# Adjust intrinsics to original image size
|
| 342 |
+
h_ratio = orig_h / model_h
|
| 343 |
+
w_ratio = orig_w / model_w
|
| 344 |
+
ixt = pred_data.intrinsics[i].copy()
|
| 345 |
+
ixt[0, :] *= w_ratio
|
| 346 |
+
ixt[1, :] *= h_ratio
|
| 347 |
+
|
| 348 |
+
depths_out.append(depth)
|
| 349 |
+
intrinsics_out.append(ixt)
|
| 350 |
+
|
| 351 |
+
return np.stack(depths_out), np.stack(intrinsics_out), extrinsics
|
| 352 |
+
|
| 353 |
+
def _prep_posed(
|
| 354 |
+
self, pred_data: Dict, gt_data: Dict, orig_sizes: list, scene: str
|
| 355 |
+
) -> tuple:
|
| 356 |
+
"""
|
| 357 |
+
Prepare depths/intrinsics/extrinsics for recon_posed mode.
|
| 358 |
+
Uses GT intrinsics/extrinsics but aligns depth scale via Umeyama.
|
| 359 |
+
"""
|
| 360 |
+
# Scale alignment with fixed random_state
|
| 361 |
+
_, _, scale, _ = align_poses_umeyama(
|
| 362 |
+
gt_data.extrinsics.copy(),
|
| 363 |
+
pred_data.extrinsics.copy(),
|
| 364 |
+
return_aligned=True,
|
| 365 |
+
ransac=True,
|
| 366 |
+
random_state=42,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
model_h, model_w = pred_data.depth.shape[1], pred_data.depth.shape[2]
|
| 370 |
+
|
| 371 |
+
depths_out = []
|
| 372 |
+
for i in range(len(pred_data.depth)):
|
| 373 |
+
orig_h, orig_w = orig_sizes[i]
|
| 374 |
+
|
| 375 |
+
# Resize depth to original image size
|
| 376 |
+
depth = cv2.resize(
|
| 377 |
+
pred_data.depth[i],
|
| 378 |
+
(orig_w, orig_h),
|
| 379 |
+
interpolation=cv2.INTER_NEAREST,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
# Load GT depth for masking
|
| 383 |
+
gt_zero_mask = self._load_gt_mask(gt_data.aux.gt_depth_files[i])
|
| 384 |
+
|
| 385 |
+
# Mask invalid depths BEFORE scale
|
| 386 |
+
depth = self._mask_invalid_depth(depth, gt_zero_mask)
|
| 387 |
+
|
| 388 |
+
# Apply scale AFTER mask
|
| 389 |
+
depth = depth * scale
|
| 390 |
+
|
| 391 |
+
depths_out.append(depth)
|
| 392 |
+
|
| 393 |
+
# Use GT intrinsics and extrinsics
|
| 394 |
+
return np.stack(depths_out), gt_data.intrinsics.copy(), gt_data.extrinsics.copy()
|
| 395 |
+
|
| 396 |
+
def _load_gt_mask(self, gt_depth_path: str) -> np.ndarray:
|
| 397 |
+
"""
|
| 398 |
+
Load GT depth and create valid mask.
|
| 399 |
+
|
| 400 |
+
For 7Scenes, GT depth is stored as 16-bit PNG in millimeters.
|
| 401 |
+
Value 65535 indicates invalid depth.
|
| 402 |
+
|
| 403 |
+
Returns:
|
| 404 |
+
Boolean mask where True = valid region to keep
|
| 405 |
+
"""
|
| 406 |
+
if not os.path.exists(gt_depth_path):
|
| 407 |
+
return None
|
| 408 |
+
|
| 409 |
+
gt_depth = cv2.imread(gt_depth_path, -1)
|
| 410 |
+
if gt_depth is None:
|
| 411 |
+
return None
|
| 412 |
+
|
| 413 |
+
# 65535 is invalid depth marker in 7Scenes
|
| 414 |
+
gt_depth[gt_depth == 65535] = 0
|
| 415 |
+
# Convert to meters
|
| 416 |
+
gt_depth = gt_depth / 1000.0
|
| 417 |
+
|
| 418 |
+
# Valid mask: depth > 0
|
| 419 |
+
valid_mask = gt_depth > 0
|
| 420 |
+
return valid_mask
|
| 421 |
+
|
| 422 |
+
def _mask_invalid_depth(
|
| 423 |
+
self, depth: np.ndarray, gt_zero_mask: np.ndarray = None
|
| 424 |
+
) -> np.ndarray:
|
| 425 |
+
"""
|
| 426 |
+
Mask invalid depth values by setting them to 0.
|
| 427 |
+
|
| 428 |
+
Args:
|
| 429 |
+
depth: Depth map to mask
|
| 430 |
+
gt_zero_mask: Optional GT mask (True = valid region)
|
| 431 |
+
|
| 432 |
+
Returns:
|
| 433 |
+
Masked depth map with invalid regions set to 0
|
| 434 |
+
"""
|
| 435 |
+
depth = depth.copy()
|
| 436 |
+
|
| 437 |
+
if gt_zero_mask is not None:
|
| 438 |
+
# Also mask out invalid pred depth
|
| 439 |
+
pred_invalid = np.isnan(depth) | np.isinf(depth)
|
| 440 |
+
combined_mask = np.logical_and(gt_zero_mask, np.logical_not(pred_invalid))
|
| 441 |
+
depth = depth * combined_mask.astype(np.float32)
|
| 442 |
+
else:
|
| 443 |
+
# Fallback: only mask pred invalid values
|
| 444 |
+
invalid_mask = np.isnan(depth) | np.isinf(depth) | (depth <= 0)
|
| 445 |
+
depth[invalid_mask] = 0.0
|
| 446 |
+
|
| 447 |
+
return depth
|
| 448 |
+
|
| 449 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/evaluator.py
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Main Evaluator class for DepthAnything3 benchmark evaluation.
|
| 17 |
+
|
| 18 |
+
Supports multiple datasets and evaluation modes:
|
| 19 |
+
- pose: Camera pose estimation (AUC metrics)
|
| 20 |
+
- recon_unposed: 3D reconstruction with predicted poses
|
| 21 |
+
- recon_posed: 3D reconstruction with GT poses
|
| 22 |
+
- view_syn: Novel view synthesis (TODO)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
import random
|
| 28 |
+
from typing import Dict as TDict, Iterable, List
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import torch
|
| 32 |
+
from addict import Dict
|
| 33 |
+
from tqdm import tqdm
|
| 34 |
+
|
| 35 |
+
from depth_anything_3.bench.print_metrics import MetricsPrinter
|
| 36 |
+
from depth_anything_3.utils.parallel_utils import parallel_execution
|
| 37 |
+
from depth_anything_3.bench.registries import MV_REGISTRY
|
| 38 |
+
from depth_anything_3.utils.constants import EVAL_REF_VIEW_STRATEGY
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Evaluator:
|
| 42 |
+
"""
|
| 43 |
+
Main evaluation orchestrator for DepthAnything3 benchmarks.
|
| 44 |
+
|
| 45 |
+
Usage:
|
| 46 |
+
evaluator = Evaluator(
|
| 47 |
+
work_dir="./eval_workspace",
|
| 48 |
+
datas=["dtu"],
|
| 49 |
+
modes=["pose", "recon_unposed", "recon_posed"],
|
| 50 |
+
)
|
| 51 |
+
api = DepthAnything3.from_pretrained("...")
|
| 52 |
+
evaluator.infer(api)
|
| 53 |
+
metrics = evaluator.eval()
|
| 54 |
+
evaluator.print_metrics()
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
VALID_MODES = {"pose", "recon_unposed", "recon_posed", "view_syn"}
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
work_dir: str = "./eval_workspace",
|
| 62 |
+
datas: List[str] = ("dtu",),
|
| 63 |
+
modes: List[str] = ("recon_unposed",),
|
| 64 |
+
ref_view_strategy: str = EVAL_REF_VIEW_STRATEGY,
|
| 65 |
+
scenes: List[str] = None,
|
| 66 |
+
debug: bool = False,
|
| 67 |
+
num_fusion_workers: int = 4,
|
| 68 |
+
max_frames: int = 100,
|
| 69 |
+
gpu_id: int = 0,
|
| 70 |
+
total_gpus: int = 1,
|
| 71 |
+
):
|
| 72 |
+
"""
|
| 73 |
+
Initialize the evaluator.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
work_dir: Base directory for model outputs and metric files
|
| 77 |
+
datas: List of dataset names (must be registered in MV_REGISTRY)
|
| 78 |
+
modes: List of evaluation modes to run
|
| 79 |
+
ref_view_strategy: Reference view selection strategy for inference
|
| 80 |
+
("first", "saddle_balanced", etc.)
|
| 81 |
+
scenes: Specific scenes to evaluate (None = all scenes)
|
| 82 |
+
debug: Enable verbose debug output
|
| 83 |
+
num_fusion_workers: Number of parallel workers for TSDF fusion (default: 4)
|
| 84 |
+
max_frames: Maximum number of frames per scene (default: 100).
|
| 85 |
+
If a scene has more frames, randomly sample to this limit.
|
| 86 |
+
Set to -1 to disable sampling.
|
| 87 |
+
gpu_id: GPU index for multi-GPU (0-indexed)
|
| 88 |
+
total_gpus: Total number of GPUs for task distribution
|
| 89 |
+
"""
|
| 90 |
+
self.work_dir = work_dir
|
| 91 |
+
self.datas = list(datas)
|
| 92 |
+
self.modes = set(modes)
|
| 93 |
+
self.ref_view_strategy = ref_view_strategy
|
| 94 |
+
self.scenes_filter = scenes
|
| 95 |
+
self.debug = debug
|
| 96 |
+
self.num_fusion_workers = num_fusion_workers
|
| 97 |
+
self.max_frames = max_frames
|
| 98 |
+
self.gpu_id = gpu_id
|
| 99 |
+
self.total_gpus = total_gpus
|
| 100 |
+
|
| 101 |
+
# Validate modes
|
| 102 |
+
unknown = self.modes - self.VALID_MODES
|
| 103 |
+
if unknown:
|
| 104 |
+
raise ValueError(f"Unknown modes: {unknown}. Valid: {sorted(self.VALID_MODES)}")
|
| 105 |
+
|
| 106 |
+
os.makedirs(self.work_dir, exist_ok=True)
|
| 107 |
+
|
| 108 |
+
# Initialize datasets
|
| 109 |
+
self.datasets = Dict()
|
| 110 |
+
for data in self.datas:
|
| 111 |
+
if not MV_REGISTRY.has(data):
|
| 112 |
+
available = list(MV_REGISTRY.all().keys())
|
| 113 |
+
raise ValueError(f"Dataset '{data}' not found. Available: {available}")
|
| 114 |
+
self.datasets[data] = MV_REGISTRY.get(data)()
|
| 115 |
+
|
| 116 |
+
# Initialize metrics printer
|
| 117 |
+
self._printer = MetricsPrinter()
|
| 118 |
+
|
| 119 |
+
# -------------------- Public APIs -------------------- #
|
| 120 |
+
|
| 121 |
+
def all(self, api) -> TDict[str, dict]:
|
| 122 |
+
"""
|
| 123 |
+
Run complete evaluation pipeline: inference + evaluation.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
api: DepthAnything3 API instance
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
Combined metrics dictionary
|
| 130 |
+
"""
|
| 131 |
+
self.infer(api)
|
| 132 |
+
return self.eval()
|
| 133 |
+
|
| 134 |
+
def _get_scenes(self, dataset) -> List[str]:
|
| 135 |
+
"""Get list of scenes to evaluate, optionally filtered."""
|
| 136 |
+
all_scenes = dataset.SCENES
|
| 137 |
+
if self.scenes_filter:
|
| 138 |
+
scenes = [s for s in all_scenes if s in self.scenes_filter]
|
| 139 |
+
if self.debug:
|
| 140 |
+
print(f"[DEBUG] Filtered scenes: {scenes} (from {len(all_scenes)} total)")
|
| 141 |
+
return scenes
|
| 142 |
+
return all_scenes
|
| 143 |
+
|
| 144 |
+
def infer(self, api, model_path: str = None) -> None:
|
| 145 |
+
"""
|
| 146 |
+
Run inference according to requested modes.
|
| 147 |
+
|
| 148 |
+
- Unposed export if 'pose' or 'recon_unposed' is in modes
|
| 149 |
+
- Posed export if 'recon_posed' or 'view_syn' is in modes
|
| 150 |
+
|
| 151 |
+
Multi-GPU: Use --gpu_id and --total_gpus to distribute tasks.
|
| 152 |
+
Example: Launch 4 processes with gpu_id=0,1,2,3 and total_gpus=4
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
api: DepthAnything3 API instance
|
| 156 |
+
model_path: Model path (unused, kept for API compatibility)
|
| 157 |
+
"""
|
| 158 |
+
need_unposed = {"pose", "recon_unposed"} & self.modes
|
| 159 |
+
need_posed = {"recon_posed", "view_syn"} & self.modes
|
| 160 |
+
export_format = "mini_npz-glb" if self.debug else "mini_npz"
|
| 161 |
+
|
| 162 |
+
# Collect all tasks
|
| 163 |
+
all_tasks = []
|
| 164 |
+
for data in self.datas:
|
| 165 |
+
dataset = self.datasets[data]
|
| 166 |
+
for scene in self._get_scenes(dataset):
|
| 167 |
+
all_tasks.append((data, scene))
|
| 168 |
+
|
| 169 |
+
# Distribute tasks across GPUs
|
| 170 |
+
if self.total_gpus > 1:
|
| 171 |
+
tasks = [t for i, t in enumerate(all_tasks) if i % self.total_gpus == self.gpu_id]
|
| 172 |
+
print(f"[INFO] GPU {self.gpu_id}/{self.total_gpus}: {len(tasks)}/{len(all_tasks)} tasks")
|
| 173 |
+
else:
|
| 174 |
+
tasks = all_tasks
|
| 175 |
+
print(f"[INFO] Total inference tasks: {len(tasks)}")
|
| 176 |
+
|
| 177 |
+
for data, scene in tqdm(tasks, desc=f"Inference (GPU {self.gpu_id})"):
|
| 178 |
+
dataset = self.datasets[data]
|
| 179 |
+
scene_data = dataset.get_data(scene)
|
| 180 |
+
scene_data = self._sample_frames(scene_data, scene)
|
| 181 |
+
|
| 182 |
+
if need_unposed:
|
| 183 |
+
export_dir = self._export_dir(data, scene, posed=False)
|
| 184 |
+
api.inference(
|
| 185 |
+
scene_data.image_files,
|
| 186 |
+
export_dir=export_dir,
|
| 187 |
+
export_format=export_format,
|
| 188 |
+
ref_view_strategy=self.ref_view_strategy,
|
| 189 |
+
)
|
| 190 |
+
self._save_gt_meta(export_dir, scene_data)
|
| 191 |
+
|
| 192 |
+
if need_posed:
|
| 193 |
+
export_dir = self._export_dir(data, scene, posed=True)
|
| 194 |
+
api.inference(
|
| 195 |
+
scene_data.image_files,
|
| 196 |
+
scene_data.extrinsics,
|
| 197 |
+
scene_data.intrinsics,
|
| 198 |
+
export_dir=export_dir,
|
| 199 |
+
export_format=export_format,
|
| 200 |
+
ref_view_strategy=self.ref_view_strategy,
|
| 201 |
+
)
|
| 202 |
+
self._save_gt_meta(export_dir, scene_data)
|
| 203 |
+
|
| 204 |
+
def eval(self) -> TDict[str, dict]:
|
| 205 |
+
"""
|
| 206 |
+
Evaluate for all configured modes and write JSON files.
|
| 207 |
+
|
| 208 |
+
Evaluation order by mode (all datasets per mode):
|
| 209 |
+
1. pose - all datasets
|
| 210 |
+
2. recon_unposed - all datasets
|
| 211 |
+
3. recon_posed - all datasets
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
Summary mapping: {"<data>_<mode>": metrics_dict}
|
| 215 |
+
"""
|
| 216 |
+
summary: TDict[str, dict] = {}
|
| 217 |
+
|
| 218 |
+
# Evaluate by mode (all datasets per mode)
|
| 219 |
+
if "pose" in self.modes:
|
| 220 |
+
print(f"\n{'='*60}")
|
| 221 |
+
print(f"📊 Evaluating POSE for all datasets...")
|
| 222 |
+
print(f"{'='*60}")
|
| 223 |
+
for data, result in self._eval_pose():
|
| 224 |
+
summary[f"{data}_pose"] = result
|
| 225 |
+
|
| 226 |
+
if "recon_unposed" in self.modes:
|
| 227 |
+
print(f"\n{'='*60}")
|
| 228 |
+
print(f"📊 Evaluating RECON_UNPOSED for all datasets...")
|
| 229 |
+
print(f"{'='*60}")
|
| 230 |
+
for data, result in self._eval_reconstruction("recon_unposed"):
|
| 231 |
+
summary[f"{data}_recon_unposed"] = result
|
| 232 |
+
|
| 233 |
+
if "recon_posed" in self.modes:
|
| 234 |
+
print(f"\n{'='*60}")
|
| 235 |
+
print(f"📊 Evaluating RECON_POSED for all datasets...")
|
| 236 |
+
print(f"{'='*60}")
|
| 237 |
+
for data, result in self._eval_reconstruction("recon_posed"):
|
| 238 |
+
summary[f"{data}_recon_posed"] = result
|
| 239 |
+
|
| 240 |
+
if "view_syn" in self.modes:
|
| 241 |
+
# TODO: Add view synthesis metrics here when available
|
| 242 |
+
pass
|
| 243 |
+
|
| 244 |
+
return summary
|
| 245 |
+
|
| 246 |
+
def print_metrics(self, metrics: TDict[str, dict] = None) -> None:
|
| 247 |
+
"""
|
| 248 |
+
Print evaluation metrics in a beautiful tabular format.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
metrics: Metrics dictionary. If None, loads from saved JSON files.
|
| 252 |
+
"""
|
| 253 |
+
if metrics is None:
|
| 254 |
+
metrics = self._load_metrics()
|
| 255 |
+
|
| 256 |
+
self._printer.print_results(metrics)
|
| 257 |
+
|
| 258 |
+
# -------------------- Evaluation Methods -------------------- #
|
| 259 |
+
|
| 260 |
+
def _eval_pose(self) -> Iterable[tuple]:
|
| 261 |
+
"""Compute pose-estimation metrics for each dataset and scene."""
|
| 262 |
+
os.makedirs(self._metric_dir, exist_ok=True)
|
| 263 |
+
|
| 264 |
+
for data in tqdm(self.datas, desc="Datasets (pose eval)"):
|
| 265 |
+
dataset = self.datasets[data]
|
| 266 |
+
dataset_results = Dict()
|
| 267 |
+
scenes = self._get_scenes(dataset)
|
| 268 |
+
|
| 269 |
+
for scene in tqdm(scenes, desc=f"{data} scenes", leave=False):
|
| 270 |
+
export_dir = self._export_dir(data, scene, posed=False)
|
| 271 |
+
result_path = os.path.join(export_dir, "exports", "mini_npz", "results.npz")
|
| 272 |
+
|
| 273 |
+
# Check if result file exists and is valid
|
| 274 |
+
if not os.path.exists(result_path):
|
| 275 |
+
print(f"\n[ERROR] Result file not found: {result_path}")
|
| 276 |
+
print(f"[ERROR] CWD: {os.getcwd()}")
|
| 277 |
+
print(f"[ERROR] Please run inference first (remove --eval_only)")
|
| 278 |
+
continue
|
| 279 |
+
|
| 280 |
+
try:
|
| 281 |
+
# Use saved GT meta (handles frame sampling correctly)
|
| 282 |
+
gt_meta = self._load_gt_meta(export_dir)
|
| 283 |
+
if gt_meta is not None:
|
| 284 |
+
result = self._compute_pose_with_gt(result_path, gt_meta)
|
| 285 |
+
else:
|
| 286 |
+
# Fallback to dataset GT (no sampling was done)
|
| 287 |
+
result = dataset.eval_pose(scene, result_path)
|
| 288 |
+
dataset_results[scene] = self._to_float_dict(result)
|
| 289 |
+
except Exception as e:
|
| 290 |
+
print(f"\n[ERROR] Failed to evaluate pose for {data}/{scene}: {e}")
|
| 291 |
+
print(f"[ERROR] File path: {os.path.abspath(result_path)}")
|
| 292 |
+
if self.debug:
|
| 293 |
+
import traceback
|
| 294 |
+
traceback.print_exc()
|
| 295 |
+
continue
|
| 296 |
+
|
| 297 |
+
if not dataset_results:
|
| 298 |
+
print(f"[WARNING] No valid results for {data}")
|
| 299 |
+
continue
|
| 300 |
+
|
| 301 |
+
dataset_results["mean"] = self._mean_of_dicts(dataset_results.values())
|
| 302 |
+
out_path = os.path.join(self._metric_dir, f"{data}_pose.json")
|
| 303 |
+
self._dump_json(out_path, dataset_results)
|
| 304 |
+
yield data, dataset_results
|
| 305 |
+
|
| 306 |
+
def _eval_reconstruction(self, mode: str) -> Iterable[tuple]:
|
| 307 |
+
"""
|
| 308 |
+
Compute reconstruction metrics for each dataset and scene.
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
mode: "recon_unposed" or "recon_posed"
|
| 312 |
+
"""
|
| 313 |
+
assert mode in {"recon_unposed", "recon_posed"}
|
| 314 |
+
os.makedirs(self._metric_dir, exist_ok=True)
|
| 315 |
+
|
| 316 |
+
posed_flag = mode == "recon_posed"
|
| 317 |
+
|
| 318 |
+
# Filter out datasets that don't support reconstruction (e.g., dtu64)
|
| 319 |
+
recon_datas = [d for d in self.datas if d != "dtu64"]
|
| 320 |
+
|
| 321 |
+
for data in tqdm(recon_datas, desc=f"Datasets ({mode} eval)"):
|
| 322 |
+
dataset = self.datasets[data]
|
| 323 |
+
dataset_results = Dict()
|
| 324 |
+
scenes = self._get_scenes(dataset)
|
| 325 |
+
|
| 326 |
+
# Prepare paths for all scenes
|
| 327 |
+
scene_list = []
|
| 328 |
+
result_paths = []
|
| 329 |
+
fuse_paths = []
|
| 330 |
+
for scene in scenes:
|
| 331 |
+
export_dir = self._export_dir(data, scene, posed=posed_flag)
|
| 332 |
+
result_path = os.path.join(export_dir, "exports", "mini_npz", "results.npz")
|
| 333 |
+
fuse_path = os.path.join(export_dir, "exports", "fuse", "pcd.ply")
|
| 334 |
+
scene_list.append(scene)
|
| 335 |
+
result_paths.append(result_path)
|
| 336 |
+
fuse_paths.append(fuse_path)
|
| 337 |
+
|
| 338 |
+
# Parallel fusion (default 4 workers)
|
| 339 |
+
# DTU uses CUDA operations in fusion, which doesn't work well with ThreadPool
|
| 340 |
+
use_sequential = (data == "dtu")
|
| 341 |
+
parallel_execution(
|
| 342 |
+
scene_list,
|
| 343 |
+
result_paths,
|
| 344 |
+
fuse_paths,
|
| 345 |
+
action=lambda s, rp, fp: dataset.fuse3d(s, rp, fp, mode),
|
| 346 |
+
num_processes=self.num_fusion_workers,
|
| 347 |
+
print_progress=True,
|
| 348 |
+
desc=f"{data} fusion",
|
| 349 |
+
sequential=use_sequential,
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
# Sequential evaluation (fast, no need to parallelize)
|
| 353 |
+
for scene, fuse_path in zip(scene_list, fuse_paths):
|
| 354 |
+
# DTU supports CPU-based evaluation
|
| 355 |
+
if data == "dtu" and hasattr(dataset, "eval3d"):
|
| 356 |
+
result = dataset.eval3d(scene, fuse_path)
|
| 357 |
+
else:
|
| 358 |
+
result = dataset.eval3d(scene, fuse_path)
|
| 359 |
+
dataset_results[scene] = self._to_float_dict(result)
|
| 360 |
+
print(f" {mode} | {data} | {scene}: {result}")
|
| 361 |
+
|
| 362 |
+
dataset_results["mean"] = self._mean_of_dicts(dataset_results.values())
|
| 363 |
+
out_path = os.path.join(self._metric_dir, f"{data}_{mode}.json")
|
| 364 |
+
self._dump_json(out_path, dataset_results)
|
| 365 |
+
yield data, dataset_results
|
| 366 |
+
|
| 367 |
+
# -------------------- Helpers -------------------- #
|
| 368 |
+
|
| 369 |
+
def _save_gt_meta(self, export_dir: str, scene_data: Dict) -> None:
|
| 370 |
+
"""
|
| 371 |
+
Save GT extrinsics/intrinsics/image_files for evaluation.
|
| 372 |
+
|
| 373 |
+
This is needed when frames are sampled, so eval_pose and fuse3d can use
|
| 374 |
+
the correct (sampled) GT instead of full dataset GT.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
export_dir: Export directory for the scene
|
| 378 |
+
scene_data: Sampled scene data
|
| 379 |
+
"""
|
| 380 |
+
meta_path = os.path.join(export_dir, "exports", "gt_meta.npz")
|
| 381 |
+
os.makedirs(os.path.dirname(meta_path), exist_ok=True)
|
| 382 |
+
np.savez_compressed(
|
| 383 |
+
meta_path,
|
| 384 |
+
extrinsics=scene_data.extrinsics,
|
| 385 |
+
intrinsics=scene_data.intrinsics,
|
| 386 |
+
image_files=np.array(scene_data.image_files, dtype=object),
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
def _load_gt_meta(self, export_dir: str) -> Dict:
|
| 390 |
+
"""
|
| 391 |
+
Load saved GT extrinsics/intrinsics for evaluation.
|
| 392 |
+
|
| 393 |
+
Returns:
|
| 394 |
+
Dict with extrinsics and intrinsics, or None if not found
|
| 395 |
+
"""
|
| 396 |
+
meta_path = os.path.join(export_dir, "exports", "gt_meta.npz")
|
| 397 |
+
if os.path.exists(meta_path):
|
| 398 |
+
data = np.load(meta_path)
|
| 399 |
+
return Dict({
|
| 400 |
+
"extrinsics": data["extrinsics"],
|
| 401 |
+
"intrinsics": data["intrinsics"],
|
| 402 |
+
})
|
| 403 |
+
return None
|
| 404 |
+
|
| 405 |
+
def _compute_pose_with_gt(self, result_path: str, gt_meta: Dict) -> TDict[str, float]:
|
| 406 |
+
"""
|
| 407 |
+
Compute pose metrics using saved GT meta (handles frame sampling).
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
result_path: Path to npz with predicted extrinsics
|
| 411 |
+
gt_meta: Dict with GT extrinsics from saved meta
|
| 412 |
+
|
| 413 |
+
Returns:
|
| 414 |
+
Dict with pose metrics
|
| 415 |
+
"""
|
| 416 |
+
from depth_anything_3.bench.dataset import _wait_for_file_ready
|
| 417 |
+
from depth_anything_3.bench.utils import compute_pose
|
| 418 |
+
from depth_anything_3.utils.geometry import as_homogeneous
|
| 419 |
+
|
| 420 |
+
_wait_for_file_ready(result_path)
|
| 421 |
+
pred = np.load(result_path)
|
| 422 |
+
return compute_pose(
|
| 423 |
+
torch.from_numpy(as_homogeneous(pred["extrinsics"])),
|
| 424 |
+
torch.from_numpy(as_homogeneous(gt_meta["extrinsics"])),
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
def _sample_frames(self, scene_data: Dict, scene: str) -> Dict:
|
| 428 |
+
"""
|
| 429 |
+
Sample frames if scene has more than max_frames.
|
| 430 |
+
|
| 431 |
+
Uses fixed random seed (42) for reproducibility.
|
| 432 |
+
|
| 433 |
+
Args:
|
| 434 |
+
scene_data: Scene data dict with image_files, extrinsics, intrinsics, aux
|
| 435 |
+
scene: Scene name (for logging)
|
| 436 |
+
|
| 437 |
+
Returns:
|
| 438 |
+
Sampled scene_data if num_frames > max_frames, otherwise original
|
| 439 |
+
"""
|
| 440 |
+
if self.max_frames <= 0:
|
| 441 |
+
return scene_data
|
| 442 |
+
|
| 443 |
+
num_frames = len(scene_data.image_files)
|
| 444 |
+
if num_frames <= self.max_frames:
|
| 445 |
+
return scene_data
|
| 446 |
+
|
| 447 |
+
# Sample with fixed seed for reproducibility
|
| 448 |
+
random.seed(42)
|
| 449 |
+
indices = list(range(num_frames))
|
| 450 |
+
random.shuffle(indices)
|
| 451 |
+
sampled_indices = sorted(indices[:self.max_frames])
|
| 452 |
+
|
| 453 |
+
print(f" [Sampling] {scene}: {num_frames} -> {self.max_frames} frames")
|
| 454 |
+
|
| 455 |
+
# Create new scene_data with sampled frames
|
| 456 |
+
sampled = Dict()
|
| 457 |
+
sampled.image_files = [scene_data.image_files[i] for i in sampled_indices]
|
| 458 |
+
sampled.extrinsics = scene_data.extrinsics[sampled_indices]
|
| 459 |
+
sampled.intrinsics = scene_data.intrinsics[sampled_indices]
|
| 460 |
+
|
| 461 |
+
# Copy aux data, sampling lists if needed
|
| 462 |
+
sampled.aux = Dict()
|
| 463 |
+
for key, val in scene_data.aux.items():
|
| 464 |
+
if isinstance(val, list) and len(val) == num_frames:
|
| 465 |
+
sampled.aux[key] = [val[i] for i in sampled_indices]
|
| 466 |
+
elif isinstance(val, np.ndarray) and len(val) == num_frames:
|
| 467 |
+
sampled.aux[key] = val[sampled_indices]
|
| 468 |
+
else:
|
| 469 |
+
sampled.aux[key] = val
|
| 470 |
+
|
| 471 |
+
return sampled
|
| 472 |
+
|
| 473 |
+
@property
|
| 474 |
+
def _metric_dir(self) -> str:
|
| 475 |
+
"""Directory for storing metric JSON files."""
|
| 476 |
+
return os.path.join(self.work_dir, "metric_results")
|
| 477 |
+
|
| 478 |
+
def _export_dir(self, data: str, scene: str, posed: bool) -> str:
|
| 479 |
+
"""
|
| 480 |
+
Get export directory path.
|
| 481 |
+
|
| 482 |
+
Structure: .../model_results/{data}/{scene}/{posed|unposed}
|
| 483 |
+
"""
|
| 484 |
+
suffix = "posed" if posed else "unposed"
|
| 485 |
+
export_dir = os.path.join(self.work_dir, "model_results", data, scene, suffix)
|
| 486 |
+
os.makedirs(export_dir, exist_ok=True)
|
| 487 |
+
return export_dir
|
| 488 |
+
|
| 489 |
+
@staticmethod
|
| 490 |
+
def _to_float_dict(d: TDict[str, float]) -> dict:
|
| 491 |
+
"""Convert numpy scalars to plain Python floats for JSON safety."""
|
| 492 |
+
return {k: float(v) for k, v in d.items()}
|
| 493 |
+
|
| 494 |
+
@staticmethod
|
| 495 |
+
def _mean_of_dicts(dicts: Iterable[dict]) -> dict:
|
| 496 |
+
"""Compute elementwise mean across a list of homogeneous metric dicts."""
|
| 497 |
+
dicts = list(dicts)
|
| 498 |
+
if not dicts:
|
| 499 |
+
return {}
|
| 500 |
+
keys = dicts[0].keys()
|
| 501 |
+
return {k: float(np.mean([d[k] for d in dicts]).item()) for k in keys}
|
| 502 |
+
|
| 503 |
+
@staticmethod
|
| 504 |
+
def _dump_json(path: str, obj: dict, indent: int = 4) -> None:
|
| 505 |
+
"""Write JSON with UTF-8 and pretty indentation."""
|
| 506 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 507 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 508 |
+
json.dump(obj, f, indent=indent, ensure_ascii=False)
|
| 509 |
+
|
| 510 |
+
def _load_metrics(self) -> TDict[str, dict]:
|
| 511 |
+
"""Load evaluation metrics from JSON files."""
|
| 512 |
+
metrics = {}
|
| 513 |
+
metric_dir = self._metric_dir
|
| 514 |
+
|
| 515 |
+
if not os.path.exists(metric_dir):
|
| 516 |
+
return metrics
|
| 517 |
+
|
| 518 |
+
for filename in os.listdir(metric_dir):
|
| 519 |
+
if filename.endswith(".json"):
|
| 520 |
+
filepath = os.path.join(metric_dir, filename)
|
| 521 |
+
try:
|
| 522 |
+
with open(filepath, encoding="utf-8") as f:
|
| 523 |
+
data = json.load(f)
|
| 524 |
+
key = filename[:-5] # Remove .json extension
|
| 525 |
+
metrics[key] = data
|
| 526 |
+
except Exception as e:
|
| 527 |
+
print(f"Warning: Failed to read metrics file: {filename} - {e}")
|
| 528 |
+
|
| 529 |
+
return metrics
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
# -------------------- CLI Entry Point -------------------- #
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
if __name__ == "__main__":
|
| 536 |
+
import sys
|
| 537 |
+
from omegaconf import OmegaConf
|
| 538 |
+
from depth_anything_3.cfg import load_config
|
| 539 |
+
|
| 540 |
+
# Get default config path (relative to this file)
|
| 541 |
+
_default_config = os.path.join(
|
| 542 |
+
os.path.dirname(__file__), "configs", "eval_bench.yaml"
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
# Check for help flag first (we need to handle this before OmegaConf)
|
| 546 |
+
if "--help" in sys.argv or "-h" in sys.argv:
|
| 547 |
+
pass # Will handle after config loading
|
| 548 |
+
|
| 549 |
+
# Set up argv for OmegaConf processing
|
| 550 |
+
argv = sys.argv[1:]
|
| 551 |
+
|
| 552 |
+
# Check if user provides custom config
|
| 553 |
+
config_path = _default_config
|
| 554 |
+
if "--config" in argv:
|
| 555 |
+
config_idx = argv.index("--config")
|
| 556 |
+
if config_idx + 1 < len(argv):
|
| 557 |
+
config_path = argv[config_idx + 1]
|
| 558 |
+
# Remove --config and its value
|
| 559 |
+
argv = argv[:config_idx] + argv[config_idx + 2:]
|
| 560 |
+
|
| 561 |
+
# Print help if requested
|
| 562 |
+
if "--help" in sys.argv or "-h" in sys.argv:
|
| 563 |
+
print("""
|
| 564 |
+
DepthAnything3 Benchmark Evaluation
|
| 565 |
+
|
| 566 |
+
Usage:
|
| 567 |
+
python -m depth_anything_3.bench.evaluator [OPTIONS] [KEY=VALUE ...]
|
| 568 |
+
|
| 569 |
+
Configuration:
|
| 570 |
+
--config PATH Config YAML file (default: bench/configs/eval_bench.yaml)
|
| 571 |
+
|
| 572 |
+
Config Overrides (using dotlist notation):
|
| 573 |
+
model.path=VALUE Model path or HuggingFace ID
|
| 574 |
+
workspace.work_dir=VALUE Working directory for outputs
|
| 575 |
+
eval.datasets=[dataset1,dataset2] Datasets to evaluate (eth3d,7scenes,scannetpp,hiroom,dtu,dtu64)
|
| 576 |
+
eval.modes=[mode1,mode2] Evaluation modes (pose,recon_unposed,recon_posed)
|
| 577 |
+
eval.scenes=[scene1,scene2] Specific scenes to evaluate (null=all)
|
| 578 |
+
eval.max_frames=VALUE Max frames per scene (-1=no limit, default: 100)
|
| 579 |
+
eval.ref_view_strategy=VALUE Reference view strategy (default: first)
|
| 580 |
+
eval.eval_only=VALUE Only run evaluation (skip inference) (true/false)
|
| 581 |
+
eval.print_only=VALUE Only print saved metrics (true/false)
|
| 582 |
+
inference.num_fusion_workers=VALUE Number of parallel workers (default: 4)
|
| 583 |
+
inference.debug=VALUE Enable debug mode (true/false)
|
| 584 |
+
|
| 585 |
+
Special Flags:
|
| 586 |
+
--help, -h Show this help message
|
| 587 |
+
|
| 588 |
+
Multi-GPU:
|
| 589 |
+
Use CUDA_VISIBLE_DEVICES to specify GPUs (auto-detected and distributed)
|
| 590 |
+
|
| 591 |
+
Examples:
|
| 592 |
+
# Use default config
|
| 593 |
+
python -m depth_anything_3.bench.evaluator
|
| 594 |
+
|
| 595 |
+
# Override model path
|
| 596 |
+
python -m depth_anything_3.bench.evaluator model.path=depth-anything/DA3-LARGE
|
| 597 |
+
|
| 598 |
+
# Evaluate specific datasets and modes
|
| 599 |
+
python -m depth_anything_3.bench.evaluator \\
|
| 600 |
+
eval.datasets=[eth3d,hiroom] \\
|
| 601 |
+
eval.modes=[pose]
|
| 602 |
+
|
| 603 |
+
# Use custom config with overrides
|
| 604 |
+
python -m depth_anything_3.bench.evaluator \\
|
| 605 |
+
--config my_config.yaml \\
|
| 606 |
+
model.path=/path/to/model \\
|
| 607 |
+
eval.max_frames=50
|
| 608 |
+
|
| 609 |
+
# Multi-GPU inference (auto-distributed)
|
| 610 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m depth_anything_3.bench.evaluator
|
| 611 |
+
|
| 612 |
+
# Debug specific scenes
|
| 613 |
+
python -m depth_anything_3.bench.evaluator \\
|
| 614 |
+
eval.datasets=[eth3d] \\
|
| 615 |
+
eval.scenes=[courtyard] \\
|
| 616 |
+
inference.debug=true
|
| 617 |
+
|
| 618 |
+
# Only evaluate (skip inference)
|
| 619 |
+
python -m depth_anything_3.bench.evaluator eval.eval_only=true
|
| 620 |
+
|
| 621 |
+
# Only print saved metrics
|
| 622 |
+
python -m depth_anything_3.bench.evaluator eval.print_only=true
|
| 623 |
+
|
| 624 |
+
""")
|
| 625 |
+
sys.exit(0)
|
| 626 |
+
|
| 627 |
+
# Load config with CLI overrides using OmegaConf dotlist
|
| 628 |
+
# Example: python evaluator.py model.path=/path/to/model eval.datasets=[eth3d,dtu]
|
| 629 |
+
config = load_config(config_path, argv=argv)
|
| 630 |
+
|
| 631 |
+
# Extract config values
|
| 632 |
+
work_dir = config.workspace.work_dir
|
| 633 |
+
model_path = config.model.path
|
| 634 |
+
datasets = config.eval.datasets
|
| 635 |
+
modes = config.eval.modes
|
| 636 |
+
ref_view_strategy = config.eval.ref_view_strategy
|
| 637 |
+
scenes = config.eval.scenes
|
| 638 |
+
max_frames = config.eval.max_frames
|
| 639 |
+
eval_only = config.eval.eval_only
|
| 640 |
+
print_only = config.eval.print_only
|
| 641 |
+
debug = config.inference.debug
|
| 642 |
+
num_fusion_workers = config.inference.num_fusion_workers
|
| 643 |
+
|
| 644 |
+
# GPU settings: parse from CLI dotlist args (gpu_id=X total_gpus=Y)
|
| 645 |
+
# These are passed by the main process when spawning workers
|
| 646 |
+
gpu_id = 0
|
| 647 |
+
total_gpus = 1
|
| 648 |
+
for arg in argv:
|
| 649 |
+
if arg.startswith("gpu_id="):
|
| 650 |
+
gpu_id = int(arg.split("=")[1])
|
| 651 |
+
elif arg.startswith("total_gpus="):
|
| 652 |
+
total_gpus = int(arg.split("=")[1])
|
| 653 |
+
|
| 654 |
+
# Override dataset scenes if specified
|
| 655 |
+
if scenes:
|
| 656 |
+
print(f"[INFO] Running on specific scenes: {scenes}")
|
| 657 |
+
|
| 658 |
+
evaluator = Evaluator(
|
| 659 |
+
work_dir=work_dir,
|
| 660 |
+
datas=datasets,
|
| 661 |
+
modes=modes,
|
| 662 |
+
ref_view_strategy=ref_view_strategy,
|
| 663 |
+
scenes=scenes,
|
| 664 |
+
debug=debug,
|
| 665 |
+
num_fusion_workers=num_fusion_workers,
|
| 666 |
+
max_frames=max_frames,
|
| 667 |
+
gpu_id=gpu_id,
|
| 668 |
+
total_gpus=total_gpus,
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
if print_only:
|
| 672 |
+
evaluator.print_metrics()
|
| 673 |
+
elif eval_only:
|
| 674 |
+
metrics = evaluator.eval()
|
| 675 |
+
evaluator.print_metrics(metrics)
|
| 676 |
+
else:
|
| 677 |
+
# Parse CUDA_VISIBLE_DEVICES to get GPU list
|
| 678 |
+
# If not set, use all available GPUs
|
| 679 |
+
cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
|
| 680 |
+
if cuda_devices is not None and cuda_devices.strip():
|
| 681 |
+
gpu_list = [g.strip() for g in cuda_devices.split(",") if g.strip()]
|
| 682 |
+
else:
|
| 683 |
+
# CUDA_VISIBLE_DEVICES not set, use all available GPUs
|
| 684 |
+
num_available = torch.cuda.device_count()
|
| 685 |
+
gpu_list = [str(i) for i in range(num_available)] if num_available > 0 else ["0"]
|
| 686 |
+
|
| 687 |
+
# Auto multi-GPU: if multiple GPUs and not a worker process
|
| 688 |
+
is_worker = os.environ.get("_DA3_WORKER") == "1"
|
| 689 |
+
|
| 690 |
+
if len(gpu_list) > 1 and not is_worker:
|
| 691 |
+
# Launch worker processes
|
| 692 |
+
import subprocess
|
| 693 |
+
|
| 694 |
+
num_gpus = len(gpu_list)
|
| 695 |
+
print(f"[INFO] Detected {num_gpus} GPUs: {gpu_list}")
|
| 696 |
+
print(f"[INFO] Launching {num_gpus} workers...")
|
| 697 |
+
|
| 698 |
+
# Build base command
|
| 699 |
+
base_cmd = [sys.executable, "-m", "depth_anything_3.bench.evaluator"]
|
| 700 |
+
# Pass config via dotlist instead of CLI args
|
| 701 |
+
if config_path != _default_config:
|
| 702 |
+
base_cmd += ["--config", config_path]
|
| 703 |
+
base_cmd += [f"model.path={model_path}"]
|
| 704 |
+
base_cmd += [f"workspace.work_dir={work_dir}"]
|
| 705 |
+
base_cmd += [f"eval.datasets=[{','.join(datasets)}]"]
|
| 706 |
+
base_cmd += [f"eval.modes=[{','.join(modes)}]"]
|
| 707 |
+
if scenes:
|
| 708 |
+
base_cmd += [f"eval.scenes=[{','.join(scenes)}]"]
|
| 709 |
+
base_cmd += [f"eval.max_frames={max_frames}"]
|
| 710 |
+
base_cmd += [f"eval.ref_view_strategy={ref_view_strategy}"]
|
| 711 |
+
base_cmd += [f"inference.debug={str(debug).lower()}"]
|
| 712 |
+
base_cmd += [f"inference.num_fusion_workers={num_fusion_workers}"]
|
| 713 |
+
|
| 714 |
+
# Launch workers
|
| 715 |
+
processes = []
|
| 716 |
+
for idx, gpu_id in enumerate(gpu_list):
|
| 717 |
+
env = os.environ.copy()
|
| 718 |
+
env["CUDA_VISIBLE_DEVICES"] = gpu_id
|
| 719 |
+
env["_DA3_WORKER"] = "1" # Mark as worker process
|
| 720 |
+
|
| 721 |
+
cmd = base_cmd.copy()
|
| 722 |
+
# GPU-specific worker config
|
| 723 |
+
cmd += [f"gpu_id={idx}", f"total_gpus={num_gpus}"]
|
| 724 |
+
|
| 725 |
+
print(f"[INFO] Starting worker {idx} on GPU {gpu_id}")
|
| 726 |
+
p = subprocess.Popen(cmd, env=env)
|
| 727 |
+
processes.append(p)
|
| 728 |
+
|
| 729 |
+
# Wait for all workers
|
| 730 |
+
for p in processes:
|
| 731 |
+
p.wait()
|
| 732 |
+
|
| 733 |
+
print(f"[INFO] All {num_gpus} workers completed")
|
| 734 |
+
|
| 735 |
+
# Run evaluation after all inference is done
|
| 736 |
+
metrics = evaluator.eval()
|
| 737 |
+
evaluator.print_metrics(metrics)
|
| 738 |
+
else:
|
| 739 |
+
# Single GPU or worker process
|
| 740 |
+
from depth_anything_3.api import DepthAnything3
|
| 741 |
+
|
| 742 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 743 |
+
api = DepthAnything3.from_pretrained(model_path)
|
| 744 |
+
api = api.to(device)
|
| 745 |
+
|
| 746 |
+
evaluator.infer(api, model_path=model_path)
|
| 747 |
+
|
| 748 |
+
# Only run eval if single GPU mode (workers don't eval)
|
| 749 |
+
if not is_worker:
|
| 750 |
+
metrics = evaluator.eval()
|
| 751 |
+
evaluator.print_metrics(metrics)
|
| 752 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/print_metrics.py
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Beautiful metrics printing utilities for benchmark evaluation.
|
| 17 |
+
|
| 18 |
+
Provides colorized, well-formatted tabular output for evaluation results.
|
| 19 |
+
Supports highlighting best/worst values and grouping by dataset/mode.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
from typing import Dict as TDict, List, Optional
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ANSI color codes for terminal output
|
| 30 |
+
class Colors:
|
| 31 |
+
"""ANSI escape codes for terminal colors."""
|
| 32 |
+
|
| 33 |
+
RESET = "\033[0m"
|
| 34 |
+
BOLD = "\033[1m"
|
| 35 |
+
RED = "\033[31m"
|
| 36 |
+
GREEN = "\033[32m"
|
| 37 |
+
YELLOW = "\033[33m"
|
| 38 |
+
BLUE = "\033[34m"
|
| 39 |
+
MAGENTA = "\033[35m"
|
| 40 |
+
CYAN = "\033[36m"
|
| 41 |
+
WHITE = "\033[37m"
|
| 42 |
+
|
| 43 |
+
# Bold variants
|
| 44 |
+
BOLD_RED = "\033[1;31m"
|
| 45 |
+
BOLD_GREEN = "\033[1;32m"
|
| 46 |
+
BOLD_YELLOW = "\033[1;33m"
|
| 47 |
+
BOLD_BLUE = "\033[1;34m"
|
| 48 |
+
BOLD_MAGENTA = "\033[1;35m"
|
| 49 |
+
BOLD_CYAN = "\033[1;36m"
|
| 50 |
+
|
| 51 |
+
# Background
|
| 52 |
+
BG_DARK = "\033[48;5;236m"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def strip_ansi(text: str) -> str:
|
| 56 |
+
"""Remove ANSI escape sequences from string for length calculation."""
|
| 57 |
+
ansi_escape = re.compile(r"\x1b\[[0-9;]*m")
|
| 58 |
+
return ansi_escape.sub("", text)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def colorize_value(
|
| 62 |
+
value: str,
|
| 63 |
+
is_best: bool = False,
|
| 64 |
+
is_worst: bool = False,
|
| 65 |
+
lower_is_better: bool = False,
|
| 66 |
+
) -> str:
|
| 67 |
+
"""
|
| 68 |
+
Apply color to a metric value based on whether it's best/worst.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
value: String representation of the value
|
| 72 |
+
is_best: Whether this is the best value in its column
|
| 73 |
+
is_worst: Whether this is the worst value in its column
|
| 74 |
+
lower_is_better: If True, lower values are better (e.g., error metrics)
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Colorized string
|
| 78 |
+
"""
|
| 79 |
+
if lower_is_better:
|
| 80 |
+
# For metrics like error/distance, lower is better
|
| 81 |
+
if is_best:
|
| 82 |
+
return f"{Colors.BOLD_GREEN}{value}{Colors.RESET}"
|
| 83 |
+
elif is_worst:
|
| 84 |
+
return f"{Colors.BOLD_RED}{value}{Colors.RESET}"
|
| 85 |
+
else:
|
| 86 |
+
# For metrics like accuracy/AUC, higher is better
|
| 87 |
+
if is_best:
|
| 88 |
+
return f"{Colors.BOLD_GREEN}{value}{Colors.RESET}"
|
| 89 |
+
elif is_worst:
|
| 90 |
+
return f"{Colors.BOLD_RED}{value}{Colors.RESET}"
|
| 91 |
+
return value
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class MetricsPrinter:
|
| 95 |
+
"""
|
| 96 |
+
Beautiful tabular metrics printer with color support.
|
| 97 |
+
|
| 98 |
+
Features:
|
| 99 |
+
- Colorized best/worst values
|
| 100 |
+
- Grouped by dataset and evaluation mode
|
| 101 |
+
- Automatic column width calculation
|
| 102 |
+
- Support for multiple input directories comparison
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
# Metrics where lower values are better
|
| 106 |
+
LOWER_IS_BETTER = {"comp", "acc", "overall", "error", "loss", "rmse", "mae"}
|
| 107 |
+
|
| 108 |
+
def __init__(self, use_color: bool = True):
|
| 109 |
+
"""
|
| 110 |
+
Initialize the printer.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
use_color: Whether to use ANSI colors in output
|
| 114 |
+
"""
|
| 115 |
+
self.use_color = use_color
|
| 116 |
+
|
| 117 |
+
def print_results(self, metrics: TDict[str, dict], summary_only: bool = True) -> None:
|
| 118 |
+
"""
|
| 119 |
+
Print evaluation metrics in a beautiful tabular format.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
metrics: Dictionary mapping "dataset_mode" to metric results
|
| 123 |
+
summary_only: If True, only print summary table. If False, print per-dataset details too.
|
| 124 |
+
"""
|
| 125 |
+
if not metrics:
|
| 126 |
+
print(f"\n{Colors.BOLD_RED}❌ No evaluation metrics found{Colors.RESET}")
|
| 127 |
+
return
|
| 128 |
+
|
| 129 |
+
if not summary_only:
|
| 130 |
+
self._print_header()
|
| 131 |
+
grouped = self._group_by_dataset(metrics)
|
| 132 |
+
|
| 133 |
+
for dataset, modes_data in grouped.items():
|
| 134 |
+
self._print_dataset_section(dataset, modes_data)
|
| 135 |
+
|
| 136 |
+
# Print summary table with average metrics across datasets
|
| 137 |
+
self._print_summary(metrics)
|
| 138 |
+
|
| 139 |
+
self._print_footer()
|
| 140 |
+
|
| 141 |
+
def print_comparison(
|
| 142 |
+
self,
|
| 143 |
+
metrics_list: List[TDict[str, dict]],
|
| 144 |
+
labels: List[str],
|
| 145 |
+
) -> None:
|
| 146 |
+
"""
|
| 147 |
+
Print comparison table for multiple evaluation runs.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
metrics_list: List of metrics dictionaries
|
| 151 |
+
labels: Labels for each metrics dictionary
|
| 152 |
+
"""
|
| 153 |
+
if not metrics_list or not all(metrics_list):
|
| 154 |
+
print(f"\n{Colors.BOLD_RED}❌ No metrics to compare{Colors.RESET}")
|
| 155 |
+
return
|
| 156 |
+
|
| 157 |
+
# Collect all datasets and modes
|
| 158 |
+
all_keys = set()
|
| 159 |
+
for metrics in metrics_list:
|
| 160 |
+
all_keys.update(metrics.keys())
|
| 161 |
+
|
| 162 |
+
self._print_header("COMPARISON")
|
| 163 |
+
|
| 164 |
+
for key in sorted(all_keys):
|
| 165 |
+
parts = key.rsplit("_", 1)
|
| 166 |
+
if len(parts) == 2:
|
| 167 |
+
dataset, mode = parts[0], parts[1]
|
| 168 |
+
else:
|
| 169 |
+
dataset, mode = key, "unknown"
|
| 170 |
+
|
| 171 |
+
print(f"\n{Colors.BOLD_CYAN}📊 {dataset.upper()} - {mode.upper()}{Colors.RESET}")
|
| 172 |
+
print("-" * 100)
|
| 173 |
+
|
| 174 |
+
# Collect metrics from all runs
|
| 175 |
+
all_metric_names = set()
|
| 176 |
+
for metrics in metrics_list:
|
| 177 |
+
if key in metrics and "mean" in metrics[key]:
|
| 178 |
+
all_metric_names.update(metrics[key]["mean"].keys())
|
| 179 |
+
|
| 180 |
+
if not all_metric_names:
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
# Build comparison table
|
| 184 |
+
metric_width = max(15, max(len(m) for m in all_metric_names) + 2)
|
| 185 |
+
label_width = max(15, max(len(l) for l in labels) + 2)
|
| 186 |
+
|
| 187 |
+
# Header
|
| 188 |
+
header = f"{'Metric':<{metric_width}}"
|
| 189 |
+
for label in labels:
|
| 190 |
+
header += f"{label:<{label_width}}"
|
| 191 |
+
print(header)
|
| 192 |
+
print("-" * len(strip_ansi(header)))
|
| 193 |
+
|
| 194 |
+
# Collect values for highlighting
|
| 195 |
+
for metric_name in sorted(all_metric_names):
|
| 196 |
+
values = []
|
| 197 |
+
for metrics in metrics_list:
|
| 198 |
+
if key in metrics and "mean" in metrics[key]:
|
| 199 |
+
val = metrics[key]["mean"].get(metric_name)
|
| 200 |
+
values.append(val if val is not None else float("nan"))
|
| 201 |
+
else:
|
| 202 |
+
values.append(float("nan"))
|
| 203 |
+
|
| 204 |
+
# Find best/worst
|
| 205 |
+
valid_values = [v for v in values if not (v != v)] # Filter NaN
|
| 206 |
+
if valid_values:
|
| 207 |
+
lower_better = any(
|
| 208 |
+
lb in metric_name.lower() for lb in self.LOWER_IS_BETTER
|
| 209 |
+
)
|
| 210 |
+
best_val = min(valid_values) if lower_better else max(valid_values)
|
| 211 |
+
worst_val = max(valid_values) if lower_better else min(valid_values)
|
| 212 |
+
else:
|
| 213 |
+
best_val = worst_val = None
|
| 214 |
+
|
| 215 |
+
# Print row
|
| 216 |
+
row = f"{metric_name:<{metric_width}}"
|
| 217 |
+
for val in values:
|
| 218 |
+
if val != val: # NaN check
|
| 219 |
+
val_str = "N/A"
|
| 220 |
+
else:
|
| 221 |
+
val_str = f"{val:.4f}"
|
| 222 |
+
if self.use_color and len(valid_values) > 1:
|
| 223 |
+
lower_better = any(
|
| 224 |
+
lb in metric_name.lower() for lb in self.LOWER_IS_BETTER
|
| 225 |
+
)
|
| 226 |
+
is_best = abs(val - best_val) < 1e-8 if best_val else False
|
| 227 |
+
is_worst = abs(val - worst_val) < 1e-8 if worst_val else False
|
| 228 |
+
val_str_padded = f"{val_str:<{label_width}}"
|
| 229 |
+
val_str = colorize_value(
|
| 230 |
+
val_str_padded, is_best, is_worst, lower_better
|
| 231 |
+
)
|
| 232 |
+
row += val_str
|
| 233 |
+
continue
|
| 234 |
+
row += f"{val_str:<{label_width}}"
|
| 235 |
+
print(row)
|
| 236 |
+
|
| 237 |
+
self._print_footer()
|
| 238 |
+
|
| 239 |
+
def _print_header(self, title: str = "EVALUATION RESULTS") -> None:
|
| 240 |
+
"""Print report header."""
|
| 241 |
+
width = 100
|
| 242 |
+
print()
|
| 243 |
+
print("=" * width)
|
| 244 |
+
print(f"{Colors.BOLD_CYAN}📊 DEPTH ANYTHING 3 {title}{Colors.RESET}")
|
| 245 |
+
print("=" * width)
|
| 246 |
+
|
| 247 |
+
def _print_footer(self) -> None:
|
| 248 |
+
"""Print report footer."""
|
| 249 |
+
width = 100
|
| 250 |
+
print()
|
| 251 |
+
print("=" * width)
|
| 252 |
+
print(f"{Colors.BOLD_GREEN}✅ Evaluation Complete{Colors.RESET}")
|
| 253 |
+
print("=" * width)
|
| 254 |
+
print()
|
| 255 |
+
|
| 256 |
+
def _group_by_dataset(self, metrics: TDict[str, dict]) -> TDict[str, dict]:
|
| 257 |
+
"""Group metrics by dataset."""
|
| 258 |
+
grouped = {}
|
| 259 |
+
for key, data in metrics.items():
|
| 260 |
+
if not isinstance(data, dict) or "mean" not in data:
|
| 261 |
+
continue
|
| 262 |
+
# Parse key format: "dataset_mode" (e.g., "dtu_recon_unposed")
|
| 263 |
+
parts = key.split("_", 1)
|
| 264 |
+
if len(parts) == 2:
|
| 265 |
+
dataset, mode = parts
|
| 266 |
+
if dataset not in grouped:
|
| 267 |
+
grouped[dataset] = {}
|
| 268 |
+
grouped[dataset][mode] = data
|
| 269 |
+
return grouped
|
| 270 |
+
|
| 271 |
+
def _print_dataset_section(self, dataset: str, modes_data: TDict[str, dict]) -> None:
|
| 272 |
+
"""Print metrics section for a single dataset."""
|
| 273 |
+
print(f"\n{Colors.BOLD_MAGENTA}🔍 {dataset.upper()}{Colors.RESET}")
|
| 274 |
+
print("-" * 100)
|
| 275 |
+
|
| 276 |
+
# Collect all unique metrics across all modes
|
| 277 |
+
all_metrics = set()
|
| 278 |
+
for mode_data in modes_data.values():
|
| 279 |
+
all_metrics.update(mode_data["mean"].keys())
|
| 280 |
+
all_metrics = sorted(list(all_metrics))
|
| 281 |
+
|
| 282 |
+
if not all_metrics:
|
| 283 |
+
print(" No metrics available")
|
| 284 |
+
return
|
| 285 |
+
|
| 286 |
+
# Calculate column widths
|
| 287 |
+
metric_width = max(18, max(len(m) for m in all_metrics) + 2)
|
| 288 |
+
mode_width = 18
|
| 289 |
+
modes = list(modes_data.keys())
|
| 290 |
+
|
| 291 |
+
# Print header
|
| 292 |
+
header = f"{'Metric':<{metric_width}}"
|
| 293 |
+
for mode in modes:
|
| 294 |
+
header += f"{mode.upper():<{mode_width}}"
|
| 295 |
+
print(f"{Colors.BOLD}{header}{Colors.RESET}")
|
| 296 |
+
print("-" * len(header))
|
| 297 |
+
|
| 298 |
+
# Print each metric row
|
| 299 |
+
for metric in all_metrics:
|
| 300 |
+
row = f"{metric:<{metric_width}}"
|
| 301 |
+
|
| 302 |
+
# Collect values for this metric across modes
|
| 303 |
+
values = []
|
| 304 |
+
for mode in modes:
|
| 305 |
+
if metric in modes_data[mode]["mean"]:
|
| 306 |
+
values.append(modes_data[mode]["mean"][metric])
|
| 307 |
+
else:
|
| 308 |
+
values.append(None)
|
| 309 |
+
|
| 310 |
+
# Find best/worst values
|
| 311 |
+
valid_values = [v for v in values if v is not None]
|
| 312 |
+
if valid_values:
|
| 313 |
+
lower_better = any(lb in metric.lower() for lb in self.LOWER_IS_BETTER)
|
| 314 |
+
best_val = min(valid_values) if lower_better else max(valid_values)
|
| 315 |
+
worst_val = max(valid_values) if lower_better else min(valid_values)
|
| 316 |
+
else:
|
| 317 |
+
best_val = worst_val = None
|
| 318 |
+
|
| 319 |
+
# Format each value
|
| 320 |
+
for val in values:
|
| 321 |
+
if val is None:
|
| 322 |
+
row += f"{'N/A':<{mode_width}}"
|
| 323 |
+
else:
|
| 324 |
+
val_str = f"{val:.4f}"
|
| 325 |
+
if self.use_color and len(valid_values) > 1:
|
| 326 |
+
is_best = abs(val - best_val) < 1e-8 if best_val else False
|
| 327 |
+
is_worst = abs(val - worst_val) < 1e-8 if worst_val else False
|
| 328 |
+
lower_better = any(
|
| 329 |
+
lb in metric.lower() for lb in self.LOWER_IS_BETTER
|
| 330 |
+
)
|
| 331 |
+
# Pad before colorizing to maintain alignment
|
| 332 |
+
val_str_padded = f"{val_str:<{mode_width}}"
|
| 333 |
+
row += colorize_value(
|
| 334 |
+
val_str_padded, is_best, is_worst, lower_better
|
| 335 |
+
)
|
| 336 |
+
else:
|
| 337 |
+
row += f"{val_str:<{mode_width}}"
|
| 338 |
+
print(row)
|
| 339 |
+
|
| 340 |
+
# Show scene counts
|
| 341 |
+
scene_info = []
|
| 342 |
+
for mode, mode_data in modes_data.items():
|
| 343 |
+
scene_count = len([k for k in mode_data.keys() if k != "mean"])
|
| 344 |
+
scene_info.append(f"{mode}: {scene_count} scenes")
|
| 345 |
+
print(f"\n{Colors.CYAN}📈 {' | '.join(scene_info)}{Colors.RESET}")
|
| 346 |
+
|
| 347 |
+
def _print_summary(self, metrics: TDict[str, dict]) -> None:
|
| 348 |
+
"""
|
| 349 |
+
Print summary table with key metrics across all datasets.
|
| 350 |
+
|
| 351 |
+
Format: One row per metric, datasets as columns.
|
| 352 |
+
Order: HiRoom, ETH3D, DTU, 7Scenes, ScanNet++, (DTU-64 for pose only)
|
| 353 |
+
"""
|
| 354 |
+
print(f"\n{Colors.BOLD_CYAN}{'=' * 120}{Colors.RESET}")
|
| 355 |
+
print(f"{Colors.BOLD_CYAN}📊 SUMMARY{Colors.RESET}")
|
| 356 |
+
print(f"{Colors.BOLD_CYAN}{'=' * 120}{Colors.RESET}")
|
| 357 |
+
|
| 358 |
+
# Dataset display order and names
|
| 359 |
+
DATASET_ORDER = ["hiroom", "eth3d", "dtu", "7scenes", "scannetpp", "dtu64"]
|
| 360 |
+
DATASET_DISPLAY = {
|
| 361 |
+
"hiroom": "HiRoom",
|
| 362 |
+
"eth3d": "ETH3D",
|
| 363 |
+
"dtu": "DTU",
|
| 364 |
+
"7scenes": "7Scenes",
|
| 365 |
+
"scannetpp": "ScanNet++",
|
| 366 |
+
"dtu64": "DTU-64",
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
# Collect all metrics into a structured dict
|
| 370 |
+
# metric_data[dataset][mode] = {"Auc_3": x, "Auc_30": x, "fscore": x, "overall": x}
|
| 371 |
+
metric_data = {}
|
| 372 |
+
for key, data in metrics.items():
|
| 373 |
+
if not isinstance(data, dict) or "mean" not in data:
|
| 374 |
+
continue
|
| 375 |
+
parts = key.split("_", 1)
|
| 376 |
+
if len(parts) != 2:
|
| 377 |
+
continue
|
| 378 |
+
dataset, mode = parts
|
| 379 |
+
dataset_lower = dataset.lower()
|
| 380 |
+
if dataset_lower not in metric_data:
|
| 381 |
+
metric_data[dataset_lower] = {}
|
| 382 |
+
metric_data[dataset_lower][mode] = data["mean"]
|
| 383 |
+
|
| 384 |
+
col_width = 12
|
| 385 |
+
|
| 386 |
+
def fmt_val(val):
|
| 387 |
+
"""Format value or return N/A."""
|
| 388 |
+
if val is None:
|
| 389 |
+
return "N/A"
|
| 390 |
+
return f"{val:.4f}"
|
| 391 |
+
|
| 392 |
+
def get_metric(dataset, mode, metric_name):
|
| 393 |
+
"""Get metric value or None."""
|
| 394 |
+
if dataset not in metric_data:
|
| 395 |
+
return None
|
| 396 |
+
if mode not in metric_data[dataset]:
|
| 397 |
+
return None
|
| 398 |
+
return metric_data[dataset][mode].get(metric_name)
|
| 399 |
+
|
| 400 |
+
# ============ POSE METRICS ============
|
| 401 |
+
print(f"\n{Colors.BOLD_MAGENTA}🎯 POSE ESTIMATION{Colors.RESET}")
|
| 402 |
+
|
| 403 |
+
# Pose: show all datasets except DTU (keep DTU-64 only)
|
| 404 |
+
# Order: HiRoom, ETH3D, DTU-64, 7Scenes, ScanNet++
|
| 405 |
+
pose_datasets = ["hiroom", "eth3d", "dtu64", "7scenes", "scannetpp"]
|
| 406 |
+
|
| 407 |
+
# Header: Avg first, then datasets
|
| 408 |
+
header = f"{'Metric':<15}{'Avg':<{col_width}}"
|
| 409 |
+
for ds in pose_datasets:
|
| 410 |
+
header += f"{DATASET_DISPLAY[ds]:<{col_width}}"
|
| 411 |
+
print("-" * len(strip_ansi(header)))
|
| 412 |
+
print(f"{Colors.BOLD}{header}{Colors.RESET}")
|
| 413 |
+
print("-" * len(strip_ansi(header)))
|
| 414 |
+
|
| 415 |
+
# Helper to get metric with fallback names
|
| 416 |
+
def get_pose_metric(dataset, metric_name):
|
| 417 |
+
"""Get pose metric with fallback for different naming conventions."""
|
| 418 |
+
# Try different naming conventions
|
| 419 |
+
names = {
|
| 420 |
+
"Auc3": ["Auc_3", "auc03", "auc_3", "AUC_3", "Auc3", "auc3"],
|
| 421 |
+
"Auc30": ["Auc_30", "auc30", "auc_30", "AUC_30", "Auc30"],
|
| 422 |
+
}
|
| 423 |
+
for name in names.get(metric_name, [metric_name]):
|
| 424 |
+
val = get_metric(dataset, "pose", name)
|
| 425 |
+
if val is not None:
|
| 426 |
+
return val
|
| 427 |
+
return None
|
| 428 |
+
|
| 429 |
+
# Auc3 row
|
| 430 |
+
values = []
|
| 431 |
+
for ds in pose_datasets:
|
| 432 |
+
val = get_pose_metric(ds, "Auc3")
|
| 433 |
+
if val is not None:
|
| 434 |
+
values.append(val)
|
| 435 |
+
avg = sum(values) / len(values) if values else None
|
| 436 |
+
row = f"{'Auc3':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
|
| 437 |
+
for ds in pose_datasets:
|
| 438 |
+
val = get_pose_metric(ds, "Auc3")
|
| 439 |
+
row += f"{fmt_val(val):<{col_width}}"
|
| 440 |
+
print(row)
|
| 441 |
+
|
| 442 |
+
# Auc30 row
|
| 443 |
+
values = []
|
| 444 |
+
for ds in pose_datasets:
|
| 445 |
+
val = get_pose_metric(ds, "Auc30")
|
| 446 |
+
if val is not None:
|
| 447 |
+
values.append(val)
|
| 448 |
+
avg = sum(values) / len(values) if values else None
|
| 449 |
+
row = f"{'Auc30':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
|
| 450 |
+
for ds in pose_datasets:
|
| 451 |
+
val = get_pose_metric(ds, "Auc30")
|
| 452 |
+
row += f"{fmt_val(val):<{col_width}}"
|
| 453 |
+
print(row)
|
| 454 |
+
|
| 455 |
+
# ============ RECON_UNPOSED METRICS ============
|
| 456 |
+
print(f"\n{Colors.BOLD_MAGENTA}🏗️ RECON_UNPOSED (Pred Pose){Colors.RESET}")
|
| 457 |
+
|
| 458 |
+
# For recon, exclude dtu64 from columns
|
| 459 |
+
recon_datasets = ["hiroom", "eth3d", "dtu", "7scenes", "scannetpp"]
|
| 460 |
+
avg_datasets = ["hiroom", "eth3d", "7scenes", "scannetpp"] # Exclude DTU from avg
|
| 461 |
+
|
| 462 |
+
# Header: Avg first, then datasets
|
| 463 |
+
header = f"{'Metric':<15}{'Avg*':<{col_width}}"
|
| 464 |
+
for ds in recon_datasets:
|
| 465 |
+
header += f"{DATASET_DISPLAY[ds]:<{col_width}}"
|
| 466 |
+
print("-" * len(strip_ansi(header)))
|
| 467 |
+
print(f"{Colors.BOLD}{header}{Colors.RESET}")
|
| 468 |
+
print("-" * len(strip_ansi(header)))
|
| 469 |
+
|
| 470 |
+
# F-score row (only metric for avg)
|
| 471 |
+
values = []
|
| 472 |
+
for ds in recon_datasets:
|
| 473 |
+
val = get_metric(ds, "recon_unposed", "fscore")
|
| 474 |
+
if val is not None and ds in avg_datasets:
|
| 475 |
+
values.append(val)
|
| 476 |
+
avg = sum(values) / len(values) if values else None
|
| 477 |
+
row = f"{'F-score':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
|
| 478 |
+
for ds in recon_datasets:
|
| 479 |
+
val = get_metric(ds, "recon_unposed", "fscore")
|
| 480 |
+
row += f"{fmt_val(val):<{col_width}}"
|
| 481 |
+
print(row)
|
| 482 |
+
|
| 483 |
+
# Overall row (avg over 4 datasets excluding DTU)
|
| 484 |
+
values = []
|
| 485 |
+
for ds in recon_datasets:
|
| 486 |
+
val = get_metric(ds, "recon_unposed", "overall")
|
| 487 |
+
if val is not None and ds in avg_datasets:
|
| 488 |
+
values.append(val)
|
| 489 |
+
avg = sum(values) / len(values) if values else None
|
| 490 |
+
row = f"{'Overall':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
|
| 491 |
+
for ds in recon_datasets:
|
| 492 |
+
val = get_metric(ds, "recon_unposed", "overall")
|
| 493 |
+
row += f"{fmt_val(val):<{col_width}}"
|
| 494 |
+
print(row)
|
| 495 |
+
|
| 496 |
+
# ============ RECON_POSED METRICS ============
|
| 497 |
+
print(f"\n{Colors.BOLD_MAGENTA}🏗️ RECON_POSED (GT Pose){Colors.RESET}")
|
| 498 |
+
|
| 499 |
+
# Header: Avg first, then datasets
|
| 500 |
+
header = f"{'Metric':<15}{'Avg*':<{col_width}}"
|
| 501 |
+
for ds in recon_datasets:
|
| 502 |
+
header += f"{DATASET_DISPLAY[ds]:<{col_width}}"
|
| 503 |
+
print("-" * len(strip_ansi(header)))
|
| 504 |
+
print(f"{Colors.BOLD}{header}{Colors.RESET}")
|
| 505 |
+
print("-" * len(strip_ansi(header)))
|
| 506 |
+
|
| 507 |
+
# F-score row (only metric for avg)
|
| 508 |
+
values = []
|
| 509 |
+
for ds in recon_datasets:
|
| 510 |
+
val = get_metric(ds, "recon_posed", "fscore")
|
| 511 |
+
if val is not None and ds in avg_datasets:
|
| 512 |
+
values.append(val)
|
| 513 |
+
avg = sum(values) / len(values) if values else None
|
| 514 |
+
row = f"{'F-score':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
|
| 515 |
+
for ds in recon_datasets:
|
| 516 |
+
val = get_metric(ds, "recon_posed", "fscore")
|
| 517 |
+
row += f"{fmt_val(val):<{col_width}}"
|
| 518 |
+
print(row)
|
| 519 |
+
|
| 520 |
+
# Overall row (avg over 4 datasets excluding DTU)
|
| 521 |
+
values = []
|
| 522 |
+
for ds in recon_datasets:
|
| 523 |
+
val = get_metric(ds, "recon_posed", "overall")
|
| 524 |
+
if val is not None and ds in avg_datasets:
|
| 525 |
+
values.append(val)
|
| 526 |
+
avg = sum(values) / len(values) if values else None
|
| 527 |
+
row = f"{'Overall':<15}{Colors.BOLD_GREEN}{fmt_val(avg):<{col_width}}{Colors.RESET}"
|
| 528 |
+
for ds in recon_datasets:
|
| 529 |
+
val = get_metric(ds, "recon_posed", "overall")
|
| 530 |
+
row += f"{fmt_val(val):<{col_width}}"
|
| 531 |
+
print(row)
|
| 532 |
+
|
| 533 |
+
print(f"\n{Colors.CYAN}* Avg F-score / Overall = average over HiRoom, ETH3D, 7Scenes, ScanNet++ (4 datasets){Colors.RESET}")
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def load_metrics_from_dir(metric_dir: str) -> TDict[str, dict]:
|
| 537 |
+
"""
|
| 538 |
+
Load all metrics JSON files from a directory.
|
| 539 |
+
|
| 540 |
+
Args:
|
| 541 |
+
metric_dir: Path to directory containing metric JSON files
|
| 542 |
+
|
| 543 |
+
Returns:
|
| 544 |
+
Dictionary mapping filename (without .json) to metric data
|
| 545 |
+
"""
|
| 546 |
+
metrics = {}
|
| 547 |
+
if not os.path.exists(metric_dir):
|
| 548 |
+
return metrics
|
| 549 |
+
|
| 550 |
+
for filename in os.listdir(metric_dir):
|
| 551 |
+
if filename.endswith(".json"):
|
| 552 |
+
filepath = os.path.join(metric_dir, filename)
|
| 553 |
+
try:
|
| 554 |
+
with open(filepath, encoding="utf-8") as f:
|
| 555 |
+
content = f.read()
|
| 556 |
+
# Handle trailing commas in JSON
|
| 557 |
+
content = re.sub(r",\s*([\]\}])", r"\1", content)
|
| 558 |
+
data = json.loads(content)
|
| 559 |
+
key = filename[:-5]
|
| 560 |
+
metrics[key] = data
|
| 561 |
+
except Exception as e:
|
| 562 |
+
print(f"Warning: Failed to load {filename}: {e}")
|
| 563 |
+
|
| 564 |
+
return metrics
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def main():
|
| 568 |
+
"""Command-line interface for metrics printing."""
|
| 569 |
+
parser = argparse.ArgumentParser(
|
| 570 |
+
description="Print DepthAnything3 benchmark evaluation metrics."
|
| 571 |
+
)
|
| 572 |
+
parser.add_argument(
|
| 573 |
+
"--input_dir",
|
| 574 |
+
type=str,
|
| 575 |
+
default="./eval_workspace/metric_results",
|
| 576 |
+
help="Directory containing metric JSON files (comma-separated for comparison)",
|
| 577 |
+
)
|
| 578 |
+
parser.add_argument(
|
| 579 |
+
"--no_color",
|
| 580 |
+
action="store_true",
|
| 581 |
+
help="Disable colored output",
|
| 582 |
+
)
|
| 583 |
+
parser.add_argument(
|
| 584 |
+
"--key",
|
| 585 |
+
type=str,
|
| 586 |
+
default=None,
|
| 587 |
+
help="Specific metric key to highlight",
|
| 588 |
+
)
|
| 589 |
+
args = parser.parse_args()
|
| 590 |
+
|
| 591 |
+
# Support multiple directories for comparison
|
| 592 |
+
input_dirs = [d.strip() for d in args.input_dir.split(",") if d.strip()]
|
| 593 |
+
|
| 594 |
+
printer = MetricsPrinter(use_color=not args.no_color)
|
| 595 |
+
|
| 596 |
+
if len(input_dirs) == 1:
|
| 597 |
+
# Single directory - simple print
|
| 598 |
+
metrics = load_metrics_from_dir(input_dirs[0])
|
| 599 |
+
printer.print_results(metrics)
|
| 600 |
+
else:
|
| 601 |
+
# Multiple directories - comparison mode
|
| 602 |
+
metrics_list = []
|
| 603 |
+
labels = []
|
| 604 |
+
for d in input_dirs:
|
| 605 |
+
metrics = load_metrics_from_dir(d)
|
| 606 |
+
if metrics:
|
| 607 |
+
metrics_list.append(metrics)
|
| 608 |
+
labels.append(os.path.basename(d.rstrip("/")))
|
| 609 |
+
|
| 610 |
+
if metrics_list:
|
| 611 |
+
printer.print_comparison(metrics_list, labels)
|
| 612 |
+
else:
|
| 613 |
+
print("No metrics found in specified directories")
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
if __name__ == "__main__":
|
| 617 |
+
main()
|
| 618 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/registries.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Auto-loading registry system for benchmark datasets.
|
| 17 |
+
|
| 18 |
+
This module provides registry classes that automatically discover and import
|
| 19 |
+
dataset implementations from the datasets subpackage on first access.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import importlib
|
| 23 |
+
import pkgutil
|
| 24 |
+
import threading
|
| 25 |
+
|
| 26 |
+
from depth_anything_3.utils.registry import Registry
|
| 27 |
+
|
| 28 |
+
__all__ = ["METRIC_REGISTRY", "MONO_REGISTRY", "MV_REGISTRY", "NVS_REGISTRY"]
|
| 29 |
+
|
| 30 |
+
# ---- Lazy import: Only scan and import all datasets submodules on first registry access ----
|
| 31 |
+
_loaded = False
|
| 32 |
+
_lock = threading.Lock()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _import_all_datasets_once():
|
| 36 |
+
"""
|
| 37 |
+
Scan and import all .py submodules under depth_anything_3.bench.datasets
|
| 38 |
+
(skip files/packages starting with underscore), to trigger @REGISTRY.register(...) in each module.
|
| 39 |
+
"""
|
| 40 |
+
global _loaded
|
| 41 |
+
if _loaded:
|
| 42 |
+
return
|
| 43 |
+
|
| 44 |
+
with _lock:
|
| 45 |
+
if _loaded:
|
| 46 |
+
return
|
| 47 |
+
|
| 48 |
+
pkg_name = "depth_anything_3.bench.datasets"
|
| 49 |
+
pkg = importlib.import_module(pkg_name)
|
| 50 |
+
pkg_paths = list(getattr(pkg, "__path__", []))
|
| 51 |
+
|
| 52 |
+
for finder, name, ispkg in pkgutil.walk_packages(pkg_paths, prefix=pkg_name + "."):
|
| 53 |
+
base = name.rsplit(".", 1)[-1]
|
| 54 |
+
if base.startswith("_"):
|
| 55 |
+
continue
|
| 56 |
+
try:
|
| 57 |
+
importlib.import_module(name)
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(f"[datasets auto-import] Failed to import {name}: {e}")
|
| 60 |
+
|
| 61 |
+
_loaded = True
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class AutoRegistry(Registry):
|
| 65 |
+
"""Registry that ensures all datasets are auto-discovered and imported on first use."""
|
| 66 |
+
|
| 67 |
+
def get(self, name):
|
| 68 |
+
_import_all_datasets_once()
|
| 69 |
+
return super().get(name)
|
| 70 |
+
|
| 71 |
+
def all(self):
|
| 72 |
+
_import_all_datasets_once()
|
| 73 |
+
return super().all()
|
| 74 |
+
|
| 75 |
+
def has(self, name):
|
| 76 |
+
_import_all_datasets_once()
|
| 77 |
+
return name in self._map
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Four auto-lazy registry instances for different evaluation types
|
| 81 |
+
METRIC_REGISTRY = AutoRegistry() # For metric depth evaluation
|
| 82 |
+
MONO_REGISTRY = AutoRegistry() # For monocular depth evaluation
|
| 83 |
+
MV_REGISTRY = AutoRegistry() # For multi-view evaluation
|
| 84 |
+
NVS_REGISTRY = AutoRegistry() # For novel view synthesis evaluation
|
| 85 |
+
|
Depth-Anything-3/src/depth_anything_3/bench/utils.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Utility functions for benchmark evaluation.
|
| 17 |
+
|
| 18 |
+
Contains:
|
| 19 |
+
- Pose evaluation metrics (AUC) and helper functions
|
| 20 |
+
- 3D reconstruction evaluation metrics (Acc/Comp/F-score)
|
| 21 |
+
- Geometry utilities (quaternion conversion, etc.)
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from typing import Dict as TDict, Optional, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import open3d as o3d
|
| 28 |
+
import torch
|
| 29 |
+
from addict import Dict
|
| 30 |
+
from scipy.spatial import KDTree
|
| 31 |
+
|
| 32 |
+
from depth_anything_3.utils.geometry import mat_to_quat
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# =============================================================================
|
| 36 |
+
# Geometry Utilities
|
| 37 |
+
# =============================================================================
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def quat2rotmat(qvec: list) -> np.ndarray:
|
| 41 |
+
"""
|
| 42 |
+
Convert quaternion (WXYZ order) to rotation matrix.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
qvec: Quaternion as [w, x, y, z]
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
3x3 rotation matrix
|
| 49 |
+
"""
|
| 50 |
+
rotmat = np.array(
|
| 51 |
+
[
|
| 52 |
+
1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
|
| 53 |
+
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
|
| 54 |
+
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
|
| 55 |
+
2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
|
| 56 |
+
1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
|
| 57 |
+
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
|
| 58 |
+
2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
|
| 59 |
+
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
|
| 60 |
+
1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
|
| 61 |
+
]
|
| 62 |
+
)
|
| 63 |
+
rotmat = rotmat.reshape(3, 3)
|
| 64 |
+
return rotmat
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# =============================================================================
|
| 68 |
+
# 3D Reconstruction Evaluation
|
| 69 |
+
# =============================================================================
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def nn_correspondance(verts1: np.ndarray, verts2: np.ndarray) -> np.ndarray:
|
| 73 |
+
"""
|
| 74 |
+
Compute nearest neighbor distances from verts2 to verts1 using KDTree.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
verts1: Reference point cloud [N, 3]
|
| 78 |
+
verts2: Query point cloud [M, 3]
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Distance array [M,] - distance from each point in verts2 to nearest in verts1
|
| 82 |
+
"""
|
| 83 |
+
if len(verts1) == 0 or len(verts2) == 0:
|
| 84 |
+
return np.array([])
|
| 85 |
+
|
| 86 |
+
kdtree = KDTree(verts1)
|
| 87 |
+
distances, _ = kdtree.query(verts2)
|
| 88 |
+
return distances.reshape(-1)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def evaluate_3d_reconstruction(
|
| 92 |
+
pcd_pred: Union[o3d.geometry.PointCloud, np.ndarray],
|
| 93 |
+
pcd_trgt: Union[o3d.geometry.PointCloud, np.ndarray],
|
| 94 |
+
threshold: float = 0.05,
|
| 95 |
+
down_sample: Optional[float] = None,
|
| 96 |
+
) -> TDict[str, float]:
|
| 97 |
+
"""
|
| 98 |
+
Evaluate 3D reconstruction quality using standard metrics.
|
| 99 |
+
|
| 100 |
+
This function computes:
|
| 101 |
+
- Accuracy: Mean distance from predicted points to GT surface
|
| 102 |
+
- Completeness: Mean distance from GT points to predicted surface
|
| 103 |
+
- Overall: Average of accuracy and completeness
|
| 104 |
+
- Precision: Fraction of predicted points within threshold of GT
|
| 105 |
+
- Recall: Fraction of GT points within threshold of prediction
|
| 106 |
+
- F-score: Harmonic mean of precision and recall
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
pcd_pred: Predicted point cloud (Open3D or numpy array)
|
| 110 |
+
pcd_trgt: Ground truth point cloud (Open3D or numpy array)
|
| 111 |
+
threshold: Distance threshold for precision/recall (meters)
|
| 112 |
+
down_sample: Voxel size for downsampling (None to skip)
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Dict with metrics: acc, comp, overall, precision, recall, fscore
|
| 116 |
+
"""
|
| 117 |
+
# Convert to Open3D if needed
|
| 118 |
+
if isinstance(pcd_pred, np.ndarray):
|
| 119 |
+
pcd_pred_o3d = o3d.geometry.PointCloud()
|
| 120 |
+
pcd_pred_o3d.points = o3d.utility.Vector3dVector(pcd_pred)
|
| 121 |
+
pcd_pred = pcd_pred_o3d
|
| 122 |
+
if isinstance(pcd_trgt, np.ndarray):
|
| 123 |
+
pcd_trgt_o3d = o3d.geometry.PointCloud()
|
| 124 |
+
pcd_trgt_o3d.points = o3d.utility.Vector3dVector(pcd_trgt)
|
| 125 |
+
pcd_trgt = pcd_trgt_o3d
|
| 126 |
+
|
| 127 |
+
# Downsample if requested
|
| 128 |
+
if down_sample is not None and down_sample > 0:
|
| 129 |
+
pcd_pred = pcd_pred.voxel_down_sample(down_sample)
|
| 130 |
+
pcd_trgt = pcd_trgt.voxel_down_sample(down_sample)
|
| 131 |
+
|
| 132 |
+
verts_pred = np.asarray(pcd_pred.points)
|
| 133 |
+
verts_trgt = np.asarray(pcd_trgt.points)
|
| 134 |
+
|
| 135 |
+
# Handle empty point clouds
|
| 136 |
+
if len(verts_pred) == 0 or len(verts_trgt) == 0:
|
| 137 |
+
return {
|
| 138 |
+
"acc": float("inf"),
|
| 139 |
+
"comp": float("inf"),
|
| 140 |
+
"overall": float("inf"),
|
| 141 |
+
"precision": 0.0,
|
| 142 |
+
"recall": 0.0,
|
| 143 |
+
"fscore": 0.0,
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
# Compute distances
|
| 147 |
+
dist_pred_to_gt = nn_correspondance(verts_trgt, verts_pred) # Accuracy
|
| 148 |
+
dist_gt_to_pred = nn_correspondance(verts_pred, verts_trgt) # Completeness
|
| 149 |
+
|
| 150 |
+
# Compute metrics
|
| 151 |
+
accuracy = float(np.mean(dist_pred_to_gt))
|
| 152 |
+
completeness = float(np.mean(dist_gt_to_pred))
|
| 153 |
+
overall = (accuracy + completeness) / 2
|
| 154 |
+
|
| 155 |
+
precision = float(np.mean((dist_pred_to_gt < threshold).astype(float)))
|
| 156 |
+
recall = float(np.mean((dist_gt_to_pred < threshold).astype(float)))
|
| 157 |
+
|
| 158 |
+
if precision + recall > 0:
|
| 159 |
+
fscore = 2 * precision * recall / (precision + recall)
|
| 160 |
+
else:
|
| 161 |
+
fscore = 0.0
|
| 162 |
+
|
| 163 |
+
return {
|
| 164 |
+
"acc": accuracy,
|
| 165 |
+
"comp": completeness,
|
| 166 |
+
"overall": overall,
|
| 167 |
+
"precision": precision,
|
| 168 |
+
"recall": recall,
|
| 169 |
+
"fscore": fscore,
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def create_tsdf_volume(
|
| 174 |
+
voxel_length: float = 4.0 / 512.0,
|
| 175 |
+
sdf_trunc: float = 0.04,
|
| 176 |
+
color_type: str = "RGB8",
|
| 177 |
+
) -> o3d.pipelines.integration.ScalableTSDFVolume:
|
| 178 |
+
"""
|
| 179 |
+
Create a scalable TSDF volume for depth fusion.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
voxel_length: Size of each voxel
|
| 183 |
+
sdf_trunc: Truncation distance for SDF
|
| 184 |
+
color_type: Color integration type ("RGB8" or "Gray32")
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
Initialized ScalableTSDFVolume
|
| 188 |
+
"""
|
| 189 |
+
if color_type == "RGB8":
|
| 190 |
+
color_enum = o3d.pipelines.integration.TSDFVolumeColorType.RGB8
|
| 191 |
+
else:
|
| 192 |
+
color_enum = o3d.pipelines.integration.TSDFVolumeColorType.Gray32
|
| 193 |
+
|
| 194 |
+
volume = o3d.pipelines.integration.ScalableTSDFVolume(
|
| 195 |
+
voxel_length=voxel_length,
|
| 196 |
+
sdf_trunc=sdf_trunc,
|
| 197 |
+
color_type=color_enum,
|
| 198 |
+
)
|
| 199 |
+
return volume
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def fuse_depth_to_tsdf(
|
| 203 |
+
volume: o3d.pipelines.integration.ScalableTSDFVolume,
|
| 204 |
+
depths: np.ndarray,
|
| 205 |
+
images: np.ndarray,
|
| 206 |
+
intrinsics: np.ndarray,
|
| 207 |
+
extrinsics: np.ndarray,
|
| 208 |
+
max_depth: float = 10.0,
|
| 209 |
+
) -> o3d.geometry.TriangleMesh:
|
| 210 |
+
"""
|
| 211 |
+
Fuse multiple depth maps into TSDF volume and extract mesh.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
volume: TSDF volume to integrate into
|
| 215 |
+
depths: Depth maps [N, H, W]
|
| 216 |
+
images: RGB images [N, H, W, 3]
|
| 217 |
+
intrinsics: Camera intrinsics [N, 3, 3]
|
| 218 |
+
extrinsics: Camera extrinsics (world-to-camera) [N, 4, 4]
|
| 219 |
+
max_depth: Maximum depth for truncation
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
Extracted triangle mesh
|
| 223 |
+
"""
|
| 224 |
+
for i in range(len(depths)):
|
| 225 |
+
depth = depths[i]
|
| 226 |
+
image = images[i]
|
| 227 |
+
ixt = intrinsics[i]
|
| 228 |
+
ext = extrinsics[i]
|
| 229 |
+
|
| 230 |
+
h, w = depth.shape[:2]
|
| 231 |
+
|
| 232 |
+
# Create RGBD image
|
| 233 |
+
depth_o3d = o3d.geometry.Image(depth.astype(np.float32))
|
| 234 |
+
color_o3d = o3d.geometry.Image(image.astype(np.uint8))
|
| 235 |
+
rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
|
| 236 |
+
color_o3d,
|
| 237 |
+
depth_o3d,
|
| 238 |
+
depth_trunc=max_depth,
|
| 239 |
+
convert_rgb_to_intensity=False,
|
| 240 |
+
depth_scale=1.0,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
# Create camera intrinsics
|
| 244 |
+
ixt_o3d = o3d.camera.PinholeCameraIntrinsic(
|
| 245 |
+
w, h, ixt[0, 0], ixt[1, 1], ixt[0, 2], ixt[1, 2]
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Integrate into volume
|
| 249 |
+
volume.integrate(rgbd, ixt_o3d, ext)
|
| 250 |
+
|
| 251 |
+
# Extract mesh
|
| 252 |
+
mesh = volume.extract_triangle_mesh()
|
| 253 |
+
return mesh
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def sample_points_from_mesh(
|
| 257 |
+
mesh: o3d.geometry.TriangleMesh,
|
| 258 |
+
num_points: int = 1000000,
|
| 259 |
+
) -> o3d.geometry.PointCloud:
|
| 260 |
+
"""
|
| 261 |
+
Uniformly sample points from a triangle mesh.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
mesh: Input triangle mesh
|
| 265 |
+
num_points: Number of points to sample
|
| 266 |
+
|
| 267 |
+
Returns:
|
| 268 |
+
Sampled point cloud
|
| 269 |
+
"""
|
| 270 |
+
try:
|
| 271 |
+
pcd = mesh.sample_points_uniformly(number_of_points=num_points)
|
| 272 |
+
# Clamp colors to valid range [0, 1] for Open3D PLY export
|
| 273 |
+
if pcd.has_colors():
|
| 274 |
+
colors = np.asarray(pcd.colors)
|
| 275 |
+
colors = np.clip(colors, 0.0, 1.0)
|
| 276 |
+
pcd.colors = o3d.utility.Vector3dVector(colors)
|
| 277 |
+
except Exception:
|
| 278 |
+
# Fallback: create random points if mesh is invalid (with fixed seed for reproducibility)
|
| 279 |
+
rng = np.random.default_rng(seed=42)
|
| 280 |
+
points = rng.uniform(-1, 1, size=(num_points, 3))
|
| 281 |
+
pcd = o3d.geometry.PointCloud()
|
| 282 |
+
pcd.points = o3d.utility.Vector3dVector(points)
|
| 283 |
+
return pcd
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# =============================================================================
|
| 287 |
+
# Pose Evaluation
|
| 288 |
+
# =============================================================================
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def build_pair_index(N: int, B: int = 1):
|
| 292 |
+
"""
|
| 293 |
+
Build indices for all possible pairs of frames.
|
| 294 |
+
|
| 295 |
+
Args:
|
| 296 |
+
N: Number of frames
|
| 297 |
+
B: Batch size
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
i1, i2: Indices for all possible pairs
|
| 301 |
+
"""
|
| 302 |
+
i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)
|
| 303 |
+
i1, i2 = ((i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_])
|
| 304 |
+
return i1, i2
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def compute_pose(pred_se3: torch.Tensor, gt_se3: torch.Tensor) -> Dict:
|
| 308 |
+
"""
|
| 309 |
+
Compute pose estimation metrics between predicted and ground truth trajectories.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
pred_se3: Predicted SE(3) transformations [N, 4, 4]
|
| 313 |
+
gt_se3: Ground truth SE(3) transformations [N, 4, 4]
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
Dict with AUC metrics at different thresholds (auc30, auc15, auc05, auc03)
|
| 317 |
+
"""
|
| 318 |
+
pred_se3 = align_to_first_camera(pred_se3)
|
| 319 |
+
gt_se3 = align_to_first_camera(gt_se3)
|
| 320 |
+
|
| 321 |
+
rel_rangle_deg, rel_tangle_deg = se3_to_relative_pose_error(pred_se3, gt_se3, len(pred_se3))
|
| 322 |
+
rError = rel_rangle_deg.cpu().numpy()
|
| 323 |
+
tError = rel_tangle_deg.cpu().numpy()
|
| 324 |
+
|
| 325 |
+
output = Dict()
|
| 326 |
+
output.auc30, _ = calculate_auc_np(rError, tError, max_threshold=30)
|
| 327 |
+
output.auc15, _ = calculate_auc_np(rError, tError, max_threshold=15)
|
| 328 |
+
output.auc05, _ = calculate_auc_np(rError, tError, max_threshold=5)
|
| 329 |
+
output.auc03, _ = calculate_auc_np(rError, tError, max_threshold=3)
|
| 330 |
+
return output
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def align_to_first_camera(camera_poses: torch.Tensor) -> torch.Tensor:
|
| 334 |
+
"""
|
| 335 |
+
Align all camera poses to the first camera's coordinate frame.
|
| 336 |
+
|
| 337 |
+
Args:
|
| 338 |
+
camera_poses: Camera poses as SE3 transformations [N, 4, 4]
|
| 339 |
+
|
| 340 |
+
Returns:
|
| 341 |
+
Aligned camera poses [N, 4, 4]
|
| 342 |
+
"""
|
| 343 |
+
first_cam_extrinsic_inv = closed_form_inverse_se3(camera_poses[0][None])
|
| 344 |
+
aligned_poses = torch.matmul(camera_poses, first_cam_extrinsic_inv)
|
| 345 |
+
return aligned_poses
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def rotation_angle(
|
| 349 |
+
rot_gt: torch.Tensor, rot_pred: torch.Tensor, batch_size: int = None, eps: float = 1e-15
|
| 350 |
+
) -> torch.Tensor:
|
| 351 |
+
"""
|
| 352 |
+
Calculate rotation angle error between ground truth and predicted rotations.
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
rot_gt: Ground truth rotation matrices
|
| 356 |
+
rot_pred: Predicted rotation matrices
|
| 357 |
+
batch_size: Batch size for reshaping the result
|
| 358 |
+
eps: Small value to avoid numerical issues
|
| 359 |
+
|
| 360 |
+
Returns:
|
| 361 |
+
Rotation angle error in degrees
|
| 362 |
+
"""
|
| 363 |
+
q_pred = mat_to_quat(rot_pred)
|
| 364 |
+
q_gt = mat_to_quat(rot_gt)
|
| 365 |
+
|
| 366 |
+
loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps)
|
| 367 |
+
err_q = torch.arccos(1 - 2 * loss_q)
|
| 368 |
+
|
| 369 |
+
rel_rangle_deg = err_q * 180 / np.pi
|
| 370 |
+
|
| 371 |
+
if batch_size is not None:
|
| 372 |
+
rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1)
|
| 373 |
+
|
| 374 |
+
return rel_rangle_deg
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def translation_angle(
|
| 378 |
+
tvec_gt: torch.Tensor,
|
| 379 |
+
tvec_pred: torch.Tensor,
|
| 380 |
+
batch_size: int = None,
|
| 381 |
+
ambiguity: bool = True,
|
| 382 |
+
) -> torch.Tensor:
|
| 383 |
+
"""
|
| 384 |
+
Calculate translation angle error between ground truth and predicted translations.
|
| 385 |
+
|
| 386 |
+
Args:
|
| 387 |
+
tvec_gt: Ground truth translation vectors
|
| 388 |
+
tvec_pred: Predicted translation vectors
|
| 389 |
+
batch_size: Batch size for reshaping the result
|
| 390 |
+
ambiguity: Whether to handle direction ambiguity
|
| 391 |
+
|
| 392 |
+
Returns:
|
| 393 |
+
Translation angle error in degrees
|
| 394 |
+
"""
|
| 395 |
+
rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred)
|
| 396 |
+
rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi
|
| 397 |
+
|
| 398 |
+
if ambiguity:
|
| 399 |
+
rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs())
|
| 400 |
+
|
| 401 |
+
if batch_size is not None:
|
| 402 |
+
rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1)
|
| 403 |
+
|
| 404 |
+
return rel_tangle_deg
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def compare_translation_by_angle(
|
| 408 |
+
t_gt: torch.Tensor, t: torch.Tensor, eps: float = 1e-15, default_err: float = 1e6
|
| 409 |
+
) -> torch.Tensor:
|
| 410 |
+
"""
|
| 411 |
+
Normalize the translation vectors and compute the angle between them.
|
| 412 |
+
|
| 413 |
+
Args:
|
| 414 |
+
t_gt: Ground truth translation vectors
|
| 415 |
+
t: Predicted translation vectors
|
| 416 |
+
eps: Small value to avoid division by zero
|
| 417 |
+
default_err: Default error value for invalid cases
|
| 418 |
+
|
| 419 |
+
Returns:
|
| 420 |
+
Angular error between translation vectors in radians
|
| 421 |
+
"""
|
| 422 |
+
t_norm = torch.norm(t, dim=1, keepdim=True)
|
| 423 |
+
t = t / (t_norm + eps)
|
| 424 |
+
|
| 425 |
+
t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True)
|
| 426 |
+
t_gt = t_gt / (t_gt_norm + eps)
|
| 427 |
+
|
| 428 |
+
loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps)
|
| 429 |
+
err_t = torch.acos(torch.sqrt(1 - loss_t))
|
| 430 |
+
|
| 431 |
+
err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err
|
| 432 |
+
return err_t
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def calculate_auc_np(
|
| 436 |
+
r_error: np.ndarray, t_error: np.ndarray, max_threshold: int = 30
|
| 437 |
+
) -> tuple:
|
| 438 |
+
"""
|
| 439 |
+
Calculate the Area Under the Curve (AUC) for the given error arrays.
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
r_error: Rotation error values in degrees
|
| 443 |
+
t_error: Translation error values in degrees
|
| 444 |
+
max_threshold: Maximum threshold value for binning
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
Tuple of (AUC value, normalized histogram)
|
| 448 |
+
"""
|
| 449 |
+
error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1)
|
| 450 |
+
max_errors = np.max(error_matrix, axis=1)
|
| 451 |
+
bins = np.arange(max_threshold + 1)
|
| 452 |
+
histogram, _ = np.histogram(max_errors, bins=bins)
|
| 453 |
+
num_pairs = float(len(max_errors))
|
| 454 |
+
normalized_histogram = histogram.astype(float) / num_pairs
|
| 455 |
+
return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def se3_to_relative_pose_error(
|
| 459 |
+
pred_se3: torch.Tensor, gt_se3: torch.Tensor, num_frames: int
|
| 460 |
+
) -> tuple:
|
| 461 |
+
"""
|
| 462 |
+
Compute rotation and translation errors between predicted and ground truth poses.
|
| 463 |
+
|
| 464 |
+
Args:
|
| 465 |
+
pred_se3: Predicted SE(3) transformations
|
| 466 |
+
gt_se3: Ground truth SE(3) transformations
|
| 467 |
+
num_frames: Number of frames
|
| 468 |
+
|
| 469 |
+
Returns:
|
| 470 |
+
Tuple of (rotation angle errors, translation angle errors) in degrees
|
| 471 |
+
"""
|
| 472 |
+
pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames)
|
| 473 |
+
|
| 474 |
+
# Compute relative camera poses between pairs using closed-form inverse
|
| 475 |
+
relative_pose_gt = closed_form_inverse_se3(gt_se3[pair_idx_i1]).bmm(gt_se3[pair_idx_i2])
|
| 476 |
+
relative_pose_pred = closed_form_inverse_se3(pred_se3[pair_idx_i1]).bmm(pred_se3[pair_idx_i2])
|
| 477 |
+
|
| 478 |
+
# Compute the difference in rotation and translation
|
| 479 |
+
rel_rangle_deg = rotation_angle(relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3])
|
| 480 |
+
rel_tangle_deg = translation_angle(relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3])
|
| 481 |
+
|
| 482 |
+
return rel_rangle_deg, rel_tangle_deg
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def closed_form_inverse_se3(
|
| 486 |
+
se3: torch.Tensor, R: torch.Tensor = None, T: torch.Tensor = None
|
| 487 |
+
) -> torch.Tensor:
|
| 488 |
+
"""
|
| 489 |
+
Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
|
| 490 |
+
|
| 491 |
+
Uses closed-form solution instead of torch.inverse() for numerical stability.
|
| 492 |
+
|
| 493 |
+
Args:
|
| 494 |
+
se3: Nx4x4 or Nx3x4 tensor of SE3 matrices
|
| 495 |
+
R: Optional Nx3x3 rotation matrices
|
| 496 |
+
T: Optional Nx3x1 translation vectors
|
| 497 |
+
|
| 498 |
+
Returns:
|
| 499 |
+
Inverted SE3 matrices with same shape as input
|
| 500 |
+
"""
|
| 501 |
+
is_numpy = isinstance(se3, np.ndarray)
|
| 502 |
+
|
| 503 |
+
if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
|
| 504 |
+
raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
|
| 505 |
+
|
| 506 |
+
if R is None:
|
| 507 |
+
R = se3[:, :3, :3]
|
| 508 |
+
if T is None:
|
| 509 |
+
T = se3[:, :3, 3:]
|
| 510 |
+
|
| 511 |
+
if is_numpy:
|
| 512 |
+
R_transposed = np.transpose(R, (0, 2, 1))
|
| 513 |
+
top_right = -np.matmul(R_transposed, T)
|
| 514 |
+
inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
|
| 515 |
+
else:
|
| 516 |
+
R_transposed = R.transpose(1, 2)
|
| 517 |
+
top_right = -torch.bmm(R_transposed, T)
|
| 518 |
+
inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
|
| 519 |
+
inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
|
| 520 |
+
|
| 521 |
+
inverted_matrix[:, :3, :3] = R_transposed
|
| 522 |
+
inverted_matrix[:, :3, 3:] = top_right
|
| 523 |
+
|
| 524 |
+
return inverted_matrix
|
| 525 |
+
|
Depth-Anything-3/src/depth_anything_3/cfg.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Configuration utility functions
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import importlib
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Any, Callable, List, Union
|
| 22 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
OmegaConf.register_new_resolver("eval", eval)
|
| 26 |
+
except Exception as e:
|
| 27 |
+
# if eval is not available, we can just pass
|
| 28 |
+
print(f"Error registering eval resolver: {e}")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
|
| 32 |
+
"""
|
| 33 |
+
Load a configuration. Will resolve inheritance.
|
| 34 |
+
Supports both file paths and module paths (e.g., depth_anything_3.configs.giant).
|
| 35 |
+
"""
|
| 36 |
+
# Check if path is a module path (contains dots but no slashes and doesn't end with .yaml)
|
| 37 |
+
if "." in path and "/" not in path and not path.endswith(".yaml"):
|
| 38 |
+
# It's a module path, load from package resources
|
| 39 |
+
path_parts = path.split(".")[1:]
|
| 40 |
+
config_path = Path(__file__).resolve().parent
|
| 41 |
+
for part in path_parts:
|
| 42 |
+
config_path = config_path.joinpath(part)
|
| 43 |
+
config_path = config_path.with_suffix(".yaml")
|
| 44 |
+
config = OmegaConf.load(str(config_path))
|
| 45 |
+
else:
|
| 46 |
+
# It's a file path (absolute, relative, or with .yaml extension)
|
| 47 |
+
config = OmegaConf.load(path)
|
| 48 |
+
|
| 49 |
+
if argv is not None:
|
| 50 |
+
config_argv = OmegaConf.from_dotlist(argv)
|
| 51 |
+
config = OmegaConf.merge(config, config_argv)
|
| 52 |
+
config = resolve_recursive(config, resolve_inheritance)
|
| 53 |
+
return config
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def resolve_recursive(
|
| 57 |
+
config: Any,
|
| 58 |
+
resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
|
| 59 |
+
) -> Any:
|
| 60 |
+
config = resolver(config)
|
| 61 |
+
if isinstance(config, DictConfig):
|
| 62 |
+
for k in config.keys():
|
| 63 |
+
v = config.get(k)
|
| 64 |
+
if isinstance(v, (DictConfig, ListConfig)):
|
| 65 |
+
config[k] = resolve_recursive(v, resolver)
|
| 66 |
+
if isinstance(config, ListConfig):
|
| 67 |
+
for i in range(len(config)):
|
| 68 |
+
v = config.get(i)
|
| 69 |
+
if isinstance(v, (DictConfig, ListConfig)):
|
| 70 |
+
config[i] = resolve_recursive(v, resolver)
|
| 71 |
+
return config
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
|
| 75 |
+
"""
|
| 76 |
+
Recursively resolve inheritance if the config contains:
|
| 77 |
+
__inherit__: path/to/parent.yaml or a ListConfig of such paths.
|
| 78 |
+
"""
|
| 79 |
+
if isinstance(config, DictConfig):
|
| 80 |
+
inherit = config.pop("__inherit__", None)
|
| 81 |
+
|
| 82 |
+
if inherit:
|
| 83 |
+
inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]
|
| 84 |
+
|
| 85 |
+
parent_config = None
|
| 86 |
+
for parent_path in inherit_list:
|
| 87 |
+
assert isinstance(parent_path, str)
|
| 88 |
+
parent_config = (
|
| 89 |
+
load_config(parent_path)
|
| 90 |
+
if parent_config is None
|
| 91 |
+
else OmegaConf.merge(parent_config, load_config(parent_path))
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
if len(config.keys()) > 0:
|
| 95 |
+
config = OmegaConf.merge(parent_config, config)
|
| 96 |
+
else:
|
| 97 |
+
config = parent_config
|
| 98 |
+
return config
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def import_item(path: str, name: str) -> Any:
|
| 102 |
+
"""
|
| 103 |
+
Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
|
| 104 |
+
"""
|
| 105 |
+
return getattr(importlib.import_module(path), name)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def create_object(config: DictConfig) -> Any:
|
| 109 |
+
"""
|
| 110 |
+
Create an object from config.
|
| 111 |
+
The config is expected to contains the following:
|
| 112 |
+
__object__:
|
| 113 |
+
path: path.to.module
|
| 114 |
+
name: MyClass
|
| 115 |
+
args: as_config | as_params (default to as_config)
|
| 116 |
+
"""
|
| 117 |
+
config = DictConfig(config)
|
| 118 |
+
item = import_item(
|
| 119 |
+
path=config.__object__.path,
|
| 120 |
+
name=config.__object__.name,
|
| 121 |
+
)
|
| 122 |
+
args = config.__object__.get("args", "as_config")
|
| 123 |
+
if args == "as_config":
|
| 124 |
+
return item(config)
|
| 125 |
+
if args == "as_params":
|
| 126 |
+
config = OmegaConf.to_object(config)
|
| 127 |
+
config.pop("__object__")
|
| 128 |
+
return item(**config)
|
| 129 |
+
raise NotImplementedError(f"Unknown args type: {args}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def create_dataset(path: str, *args, **kwargs) -> Any:
|
| 133 |
+
"""
|
| 134 |
+
Create a dataset. Requires the file to contain a "create_dataset" function.
|
| 135 |
+
"""
|
| 136 |
+
return import_item(path, "create_dataset")(*args, **kwargs)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def to_dict_recursive(config_obj):
|
| 140 |
+
if isinstance(config_obj, DictConfig):
|
| 141 |
+
return {k: to_dict_recursive(v) for k, v in config_obj.items()}
|
| 142 |
+
elif isinstance(config_obj, ListConfig):
|
| 143 |
+
return [to_dict_recursive(item) for item in config_obj]
|
| 144 |
+
return config_obj
|
Depth-Anything-3/src/depth_anything_3/cli.py
ADDED
|
@@ -0,0 +1,803 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: E402
|
| 2 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Refactored Depth Anything 3 CLI
|
| 17 |
+
Clean, modular command-line interface
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import typer
|
| 24 |
+
|
| 25 |
+
from depth_anything_3.services import start_server
|
| 26 |
+
from depth_anything_3.services.gallery import gallery as gallery_main
|
| 27 |
+
from depth_anything_3.services.inference_service import run_inference
|
| 28 |
+
from depth_anything_3.services.input_handlers import (
|
| 29 |
+
ColmapHandler,
|
| 30 |
+
ImageHandler,
|
| 31 |
+
ImagesHandler,
|
| 32 |
+
InputHandler,
|
| 33 |
+
VideoHandler,
|
| 34 |
+
parse_export_feat,
|
| 35 |
+
)
|
| 36 |
+
from depth_anything_3.utils.constants import (
|
| 37 |
+
DEFAULT_EXPORT_DIR,
|
| 38 |
+
DEFAULT_GALLERY_DIR,
|
| 39 |
+
DEFAULT_GRADIO_DIR,
|
| 40 |
+
DEFAULT_MODEL,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 44 |
+
|
| 45 |
+
app = typer.Typer(help="Depth Anything 3 - Video depth estimation CLI", add_completion=False)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ============================================================================
|
| 49 |
+
# Input type detection utilities
|
| 50 |
+
# ============================================================================
|
| 51 |
+
|
| 52 |
+
# Supported file extensions
|
| 53 |
+
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff", ".tif"}
|
| 54 |
+
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def detect_input_type(input_path: str) -> str:
|
| 58 |
+
"""
|
| 59 |
+
Detect input type from path.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
- "image": Single image file
|
| 63 |
+
- "images": Directory containing images
|
| 64 |
+
- "video": Video file
|
| 65 |
+
- "colmap": COLMAP directory structure
|
| 66 |
+
- "unknown": Cannot determine type
|
| 67 |
+
"""
|
| 68 |
+
if not os.path.exists(input_path):
|
| 69 |
+
return "unknown"
|
| 70 |
+
|
| 71 |
+
# Check if it's a file
|
| 72 |
+
if os.path.isfile(input_path):
|
| 73 |
+
ext = os.path.splitext(input_path)[1].lower()
|
| 74 |
+
if ext in IMAGE_EXTENSIONS:
|
| 75 |
+
return "image"
|
| 76 |
+
elif ext in VIDEO_EXTENSIONS:
|
| 77 |
+
return "video"
|
| 78 |
+
return "unknown"
|
| 79 |
+
|
| 80 |
+
# Check if it's a directory
|
| 81 |
+
if os.path.isdir(input_path):
|
| 82 |
+
# Check for COLMAP structure
|
| 83 |
+
images_dir = os.path.join(input_path, "images")
|
| 84 |
+
sparse_dir = os.path.join(input_path, "sparse")
|
| 85 |
+
|
| 86 |
+
if os.path.isdir(images_dir) and os.path.isdir(sparse_dir):
|
| 87 |
+
return "colmap"
|
| 88 |
+
|
| 89 |
+
# Check if directory contains image files
|
| 90 |
+
for item in os.listdir(input_path):
|
| 91 |
+
item_path = os.path.join(input_path, item)
|
| 92 |
+
if os.path.isfile(item_path):
|
| 93 |
+
ext = os.path.splitext(item)[1].lower()
|
| 94 |
+
if ext in IMAGE_EXTENSIONS:
|
| 95 |
+
return "images"
|
| 96 |
+
|
| 97 |
+
return "unknown"
|
| 98 |
+
|
| 99 |
+
return "unknown"
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ============================================================================
|
| 103 |
+
# Common parameters and configuration
|
| 104 |
+
# ============================================================================
|
| 105 |
+
|
| 106 |
+
# ============================================================================
|
| 107 |
+
# Inference commands
|
| 108 |
+
# ============================================================================
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@app.command()
|
| 112 |
+
def auto(
|
| 113 |
+
input_path: str = typer.Argument(
|
| 114 |
+
..., help="Path to input (image, directory, video, or COLMAP)"
|
| 115 |
+
),
|
| 116 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 117 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 118 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 119 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 120 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 121 |
+
backend_url: str = typer.Option(
|
| 122 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 123 |
+
),
|
| 124 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 125 |
+
process_res_method: str = typer.Option(
|
| 126 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 127 |
+
),
|
| 128 |
+
export_feat: str = typer.Option(
|
| 129 |
+
"",
|
| 130 |
+
help="[FEAT_VIS]Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 131 |
+
),
|
| 132 |
+
auto_cleanup: bool = typer.Option(
|
| 133 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 134 |
+
),
|
| 135 |
+
# Video-specific options
|
| 136 |
+
fps: float = typer.Option(1.0, help="[Video] Sampling FPS for frame extraction"),
|
| 137 |
+
# COLMAP-specific options
|
| 138 |
+
sparse_subdir: str = typer.Option(
|
| 139 |
+
"", help="[COLMAP] Sparse reconstruction subdirectory (e.g., '0' for sparse/0/)"
|
| 140 |
+
),
|
| 141 |
+
align_to_input_ext_scale: bool = typer.Option(
|
| 142 |
+
True, help="[COLMAP] Align prediction to input extrinsics scale"
|
| 143 |
+
),
|
| 144 |
+
# Pose estimation options
|
| 145 |
+
use_ray_pose: bool = typer.Option(
|
| 146 |
+
False, help="Use ray-based pose estimation instead of camera decoder"
|
| 147 |
+
),
|
| 148 |
+
ref_view_strategy: str = typer.Option(
|
| 149 |
+
"saddle_balanced",
|
| 150 |
+
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
|
| 151 |
+
),
|
| 152 |
+
# GLB export options
|
| 153 |
+
conf_thresh_percentile: float = typer.Option(
|
| 154 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 155 |
+
),
|
| 156 |
+
num_max_points: int = typer.Option(
|
| 157 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 158 |
+
),
|
| 159 |
+
show_cameras: bool = typer.Option(
|
| 160 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 161 |
+
),
|
| 162 |
+
# Feat_vis export options
|
| 163 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 164 |
+
):
|
| 165 |
+
"""
|
| 166 |
+
Automatically detect input type and run appropriate processing.
|
| 167 |
+
|
| 168 |
+
Supports:
|
| 169 |
+
- Single image file (.jpg, .png, etc.)
|
| 170 |
+
- Directory of images
|
| 171 |
+
- Video file (.mp4, .avi, etc.)
|
| 172 |
+
- COLMAP directory (with 'images' and 'sparse' subdirectories)
|
| 173 |
+
"""
|
| 174 |
+
# Detect input type
|
| 175 |
+
input_type = detect_input_type(input_path)
|
| 176 |
+
|
| 177 |
+
if input_type == "unknown":
|
| 178 |
+
typer.echo(f"❌ Error: Cannot determine input type for: {input_path}", err=True)
|
| 179 |
+
typer.echo("Supported inputs:", err=True)
|
| 180 |
+
typer.echo(" - Single image file (.jpg, .png, etc.)", err=True)
|
| 181 |
+
typer.echo(" - Directory containing images", err=True)
|
| 182 |
+
typer.echo(" - Video file (.mp4, .avi, etc.)", err=True)
|
| 183 |
+
typer.echo(" - COLMAP directory (with 'images/' and 'sparse/' subdirectories)", err=True)
|
| 184 |
+
raise typer.Exit(1)
|
| 185 |
+
|
| 186 |
+
# Display detected type
|
| 187 |
+
typer.echo(f"🔍 Detected input type: {input_type.upper()}")
|
| 188 |
+
typer.echo(f"📁 Input path: {input_path}")
|
| 189 |
+
typer.echo()
|
| 190 |
+
|
| 191 |
+
# Determine backend URL based on use_backend flag
|
| 192 |
+
final_backend_url = backend_url if use_backend else None
|
| 193 |
+
|
| 194 |
+
# Parse export_feat parameter
|
| 195 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 196 |
+
|
| 197 |
+
# Route to appropriate handler
|
| 198 |
+
if input_type == "image":
|
| 199 |
+
typer.echo("Processing single image...")
|
| 200 |
+
# Process input
|
| 201 |
+
image_files = ImageHandler.process(input_path)
|
| 202 |
+
|
| 203 |
+
# Handle export directory
|
| 204 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 205 |
+
|
| 206 |
+
# Run inference
|
| 207 |
+
run_inference(
|
| 208 |
+
image_paths=image_files,
|
| 209 |
+
export_dir=export_dir,
|
| 210 |
+
model_dir=model_dir,
|
| 211 |
+
device=device,
|
| 212 |
+
backend_url=final_backend_url,
|
| 213 |
+
export_format=export_format,
|
| 214 |
+
process_res=process_res,
|
| 215 |
+
process_res_method=process_res_method,
|
| 216 |
+
export_feat_layers=export_feat_layers,
|
| 217 |
+
use_ray_pose=use_ray_pose,
|
| 218 |
+
ref_view_strategy=ref_view_strategy,
|
| 219 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 220 |
+
num_max_points=num_max_points,
|
| 221 |
+
show_cameras=show_cameras,
|
| 222 |
+
feat_vis_fps=feat_vis_fps,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
elif input_type == "images":
|
| 226 |
+
typer.echo("Processing directory of images...")
|
| 227 |
+
# Process input - use default extensions
|
| 228 |
+
image_files = ImagesHandler.process(input_path, "png,jpg,jpeg")
|
| 229 |
+
|
| 230 |
+
# Handle export directory
|
| 231 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 232 |
+
|
| 233 |
+
# Run inference
|
| 234 |
+
run_inference(
|
| 235 |
+
image_paths=image_files,
|
| 236 |
+
export_dir=export_dir,
|
| 237 |
+
model_dir=model_dir,
|
| 238 |
+
device=device,
|
| 239 |
+
backend_url=final_backend_url,
|
| 240 |
+
export_format=export_format,
|
| 241 |
+
process_res=process_res,
|
| 242 |
+
process_res_method=process_res_method,
|
| 243 |
+
export_feat_layers=export_feat_layers,
|
| 244 |
+
use_ray_pose=use_ray_pose,
|
| 245 |
+
ref_view_strategy=ref_view_strategy,
|
| 246 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 247 |
+
num_max_points=num_max_points,
|
| 248 |
+
show_cameras=show_cameras,
|
| 249 |
+
feat_vis_fps=feat_vis_fps,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
elif input_type == "video":
|
| 253 |
+
typer.echo(f"Processing video with FPS={fps}...")
|
| 254 |
+
# Handle export directory
|
| 255 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 256 |
+
|
| 257 |
+
# Process input
|
| 258 |
+
image_files = VideoHandler.process(input_path, export_dir, fps)
|
| 259 |
+
|
| 260 |
+
# Run inference
|
| 261 |
+
run_inference(
|
| 262 |
+
image_paths=image_files,
|
| 263 |
+
export_dir=export_dir,
|
| 264 |
+
model_dir=model_dir,
|
| 265 |
+
device=device,
|
| 266 |
+
backend_url=final_backend_url,
|
| 267 |
+
export_format=export_format,
|
| 268 |
+
process_res=process_res,
|
| 269 |
+
process_res_method=process_res_method,
|
| 270 |
+
export_feat_layers=export_feat_layers,
|
| 271 |
+
use_ray_pose=use_ray_pose,
|
| 272 |
+
ref_view_strategy=ref_view_strategy,
|
| 273 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 274 |
+
num_max_points=num_max_points,
|
| 275 |
+
show_cameras=show_cameras,
|
| 276 |
+
feat_vis_fps=feat_vis_fps,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
elif input_type == "colmap":
|
| 280 |
+
typer.echo(
|
| 281 |
+
f"Processing COLMAP directory (sparse subdirectory: '{sparse_subdir or 'default'}')..."
|
| 282 |
+
)
|
| 283 |
+
# Process input
|
| 284 |
+
image_files, extrinsics, intrinsics = ColmapHandler.process(input_path, sparse_subdir)
|
| 285 |
+
|
| 286 |
+
# Handle export directory
|
| 287 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 288 |
+
|
| 289 |
+
# Run inference
|
| 290 |
+
run_inference(
|
| 291 |
+
image_paths=image_files,
|
| 292 |
+
export_dir=export_dir,
|
| 293 |
+
model_dir=model_dir,
|
| 294 |
+
device=device,
|
| 295 |
+
backend_url=final_backend_url,
|
| 296 |
+
export_format=export_format,
|
| 297 |
+
process_res=process_res,
|
| 298 |
+
process_res_method=process_res_method,
|
| 299 |
+
export_feat_layers=export_feat_layers,
|
| 300 |
+
extrinsics=extrinsics,
|
| 301 |
+
intrinsics=intrinsics,
|
| 302 |
+
align_to_input_ext_scale=align_to_input_ext_scale,
|
| 303 |
+
use_ray_pose=use_ray_pose,
|
| 304 |
+
ref_view_strategy=ref_view_strategy,
|
| 305 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 306 |
+
num_max_points=num_max_points,
|
| 307 |
+
show_cameras=show_cameras,
|
| 308 |
+
feat_vis_fps=feat_vis_fps,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
typer.echo()
|
| 312 |
+
typer.echo("✅ Processing completed successfully!")
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@app.command()
|
| 316 |
+
def image(
|
| 317 |
+
image_path: str = typer.Argument(..., help="Path to input image file"),
|
| 318 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 319 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 320 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 321 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 322 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 323 |
+
backend_url: str = typer.Option(
|
| 324 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 325 |
+
),
|
| 326 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 327 |
+
process_res_method: str = typer.Option(
|
| 328 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 329 |
+
),
|
| 330 |
+
export_feat: str = typer.Option(
|
| 331 |
+
"",
|
| 332 |
+
help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 333 |
+
),
|
| 334 |
+
auto_cleanup: bool = typer.Option(
|
| 335 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 336 |
+
),
|
| 337 |
+
# Pose estimation options
|
| 338 |
+
use_ray_pose: bool = typer.Option(
|
| 339 |
+
False, help="Use ray-based pose estimation instead of camera decoder"
|
| 340 |
+
),
|
| 341 |
+
ref_view_strategy: str = typer.Option(
|
| 342 |
+
"saddle_balanced",
|
| 343 |
+
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
|
| 344 |
+
),
|
| 345 |
+
# GLB export options
|
| 346 |
+
conf_thresh_percentile: float = typer.Option(
|
| 347 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 348 |
+
),
|
| 349 |
+
num_max_points: int = typer.Option(
|
| 350 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 351 |
+
),
|
| 352 |
+
show_cameras: bool = typer.Option(
|
| 353 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 354 |
+
),
|
| 355 |
+
# Feat_vis export options
|
| 356 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 357 |
+
):
|
| 358 |
+
"""Run camera pose and depth estimation on a single image."""
|
| 359 |
+
# Process input
|
| 360 |
+
image_files = ImageHandler.process(image_path)
|
| 361 |
+
|
| 362 |
+
# Handle export directory
|
| 363 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 364 |
+
|
| 365 |
+
# Parse export_feat parameter
|
| 366 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 367 |
+
|
| 368 |
+
# Determine backend URL based on use_backend flag
|
| 369 |
+
final_backend_url = backend_url if use_backend else None
|
| 370 |
+
|
| 371 |
+
# Run inference
|
| 372 |
+
run_inference(
|
| 373 |
+
image_paths=image_files,
|
| 374 |
+
export_dir=export_dir,
|
| 375 |
+
model_dir=model_dir,
|
| 376 |
+
device=device,
|
| 377 |
+
backend_url=final_backend_url,
|
| 378 |
+
export_format=export_format,
|
| 379 |
+
process_res=process_res,
|
| 380 |
+
process_res_method=process_res_method,
|
| 381 |
+
export_feat_layers=export_feat_layers,
|
| 382 |
+
use_ray_pose=use_ray_pose,
|
| 383 |
+
reference_view_strategy=reference_view_strategy,
|
| 384 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 385 |
+
num_max_points=num_max_points,
|
| 386 |
+
show_cameras=show_cameras,
|
| 387 |
+
feat_vis_fps=feat_vis_fps,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@app.command()
|
| 392 |
+
def images(
|
| 393 |
+
images_dir: str = typer.Argument(..., help="Path to directory containing input images"),
|
| 394 |
+
image_extensions: str = typer.Option(
|
| 395 |
+
"png,jpg,jpeg", help="Comma-separated image file extensions to process"
|
| 396 |
+
),
|
| 397 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 398 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 399 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 400 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 401 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 402 |
+
backend_url: str = typer.Option(
|
| 403 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 404 |
+
),
|
| 405 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 406 |
+
process_res_method: str = typer.Option(
|
| 407 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 408 |
+
),
|
| 409 |
+
export_feat: str = typer.Option(
|
| 410 |
+
"",
|
| 411 |
+
help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 412 |
+
),
|
| 413 |
+
auto_cleanup: bool = typer.Option(
|
| 414 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 415 |
+
),
|
| 416 |
+
# Pose estimation options
|
| 417 |
+
use_ray_pose: bool = typer.Option(
|
| 418 |
+
False, help="Use ray-based pose estimation instead of camera decoder"
|
| 419 |
+
),
|
| 420 |
+
ref_view_strategy: str = typer.Option(
|
| 421 |
+
"saddle_balanced",
|
| 422 |
+
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
|
| 423 |
+
),
|
| 424 |
+
# GLB export options
|
| 425 |
+
conf_thresh_percentile: float = typer.Option(
|
| 426 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 427 |
+
),
|
| 428 |
+
num_max_points: int = typer.Option(
|
| 429 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 430 |
+
),
|
| 431 |
+
show_cameras: bool = typer.Option(
|
| 432 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 433 |
+
),
|
| 434 |
+
# Feat_vis export options
|
| 435 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 436 |
+
):
|
| 437 |
+
"""Run camera pose and depth estimation on a directory of images."""
|
| 438 |
+
# Process input
|
| 439 |
+
image_files = ImagesHandler.process(images_dir, image_extensions)
|
| 440 |
+
|
| 441 |
+
# Handle export directory
|
| 442 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 443 |
+
|
| 444 |
+
# Parse export_feat parameter
|
| 445 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 446 |
+
|
| 447 |
+
# Determine backend URL based on use_backend flag
|
| 448 |
+
final_backend_url = backend_url if use_backend else None
|
| 449 |
+
|
| 450 |
+
# Run inference
|
| 451 |
+
run_inference(
|
| 452 |
+
image_paths=image_files,
|
| 453 |
+
export_dir=export_dir,
|
| 454 |
+
model_dir=model_dir,
|
| 455 |
+
device=device,
|
| 456 |
+
backend_url=final_backend_url,
|
| 457 |
+
export_format=export_format,
|
| 458 |
+
process_res=process_res,
|
| 459 |
+
process_res_method=process_res_method,
|
| 460 |
+
export_feat_layers=export_feat_layers,
|
| 461 |
+
use_ray_pose=use_ray_pose,
|
| 462 |
+
reference_view_strategy=reference_view_strategy,
|
| 463 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 464 |
+
num_max_points=num_max_points,
|
| 465 |
+
show_cameras=show_cameras,
|
| 466 |
+
feat_vis_fps=feat_vis_fps,
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
@app.command()
|
| 471 |
+
def colmap(
|
| 472 |
+
colmap_dir: str = typer.Argument(
|
| 473 |
+
..., help="Path to COLMAP directory containing 'images' and 'sparse' subdirectories"
|
| 474 |
+
),
|
| 475 |
+
sparse_subdir: str = typer.Option(
|
| 476 |
+
"", help="Sparse reconstruction subdirectory (e.g., '0' for sparse/0/, empty for sparse/)"
|
| 477 |
+
),
|
| 478 |
+
align_to_input_ext_scale: bool = typer.Option(
|
| 479 |
+
True, help="Align prediction to input extrinsics scale"
|
| 480 |
+
),
|
| 481 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 482 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 483 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 484 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 485 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 486 |
+
backend_url: str = typer.Option(
|
| 487 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 488 |
+
),
|
| 489 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 490 |
+
process_res_method: str = typer.Option(
|
| 491 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 492 |
+
),
|
| 493 |
+
export_feat: str = typer.Option(
|
| 494 |
+
"",
|
| 495 |
+
help="Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 496 |
+
),
|
| 497 |
+
auto_cleanup: bool = typer.Option(
|
| 498 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 499 |
+
),
|
| 500 |
+
# Pose estimation options
|
| 501 |
+
use_ray_pose: bool = typer.Option(
|
| 502 |
+
False, help="Use ray-based pose estimation instead of camera decoder"
|
| 503 |
+
),
|
| 504 |
+
ref_view_strategy: str = typer.Option(
|
| 505 |
+
"saddle_balanced",
|
| 506 |
+
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
|
| 507 |
+
),
|
| 508 |
+
# GLB export options
|
| 509 |
+
conf_thresh_percentile: float = typer.Option(
|
| 510 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 511 |
+
),
|
| 512 |
+
num_max_points: int = typer.Option(
|
| 513 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 514 |
+
),
|
| 515 |
+
show_cameras: bool = typer.Option(
|
| 516 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 517 |
+
),
|
| 518 |
+
# Feat_vis export options
|
| 519 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 520 |
+
):
|
| 521 |
+
"""Run pose conditioned depth estimation on COLMAP data."""
|
| 522 |
+
# Process input
|
| 523 |
+
image_files, extrinsics, intrinsics = ColmapHandler.process(colmap_dir, sparse_subdir)
|
| 524 |
+
|
| 525 |
+
# Handle export directory
|
| 526 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 527 |
+
|
| 528 |
+
# Parse export_feat parameter
|
| 529 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 530 |
+
|
| 531 |
+
# Determine backend URL based on use_backend flag
|
| 532 |
+
final_backend_url = backend_url if use_backend else None
|
| 533 |
+
|
| 534 |
+
# Run inference
|
| 535 |
+
run_inference(
|
| 536 |
+
image_paths=image_files,
|
| 537 |
+
export_dir=export_dir,
|
| 538 |
+
model_dir=model_dir,
|
| 539 |
+
device=device,
|
| 540 |
+
backend_url=final_backend_url,
|
| 541 |
+
export_format=export_format,
|
| 542 |
+
process_res=process_res,
|
| 543 |
+
process_res_method=process_res_method,
|
| 544 |
+
export_feat_layers=export_feat_layers,
|
| 545 |
+
extrinsics=extrinsics,
|
| 546 |
+
intrinsics=intrinsics,
|
| 547 |
+
align_to_input_ext_scale=align_to_input_ext_scale,
|
| 548 |
+
use_ray_pose=use_ray_pose,
|
| 549 |
+
reference_view_strategy=reference_view_strategy,
|
| 550 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 551 |
+
num_max_points=num_max_points,
|
| 552 |
+
show_cameras=show_cameras,
|
| 553 |
+
feat_vis_fps=feat_vis_fps,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
@app.command()
|
| 558 |
+
def video(
|
| 559 |
+
video_path: str = typer.Argument(..., help="Path to input video file"),
|
| 560 |
+
fps: float = typer.Option(1.0, help="Sampling FPS for frame extraction"),
|
| 561 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 562 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 563 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 564 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 565 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 566 |
+
backend_url: str = typer.Option(
|
| 567 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 568 |
+
),
|
| 569 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 570 |
+
process_res_method: str = typer.Option(
|
| 571 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 572 |
+
),
|
| 573 |
+
export_feat: str = typer.Option(
|
| 574 |
+
"",
|
| 575 |
+
help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 576 |
+
),
|
| 577 |
+
auto_cleanup: bool = typer.Option(
|
| 578 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 579 |
+
),
|
| 580 |
+
# Pose estimation options
|
| 581 |
+
use_ray_pose: bool = typer.Option(
|
| 582 |
+
False, help="Use ray-based pose estimation instead of camera decoder"
|
| 583 |
+
),
|
| 584 |
+
ref_view_strategy: str = typer.Option(
|
| 585 |
+
"saddle_balanced",
|
| 586 |
+
help="Reference view selection strategy: empty, first, middle, saddle_balanced, saddle_sim_range",
|
| 587 |
+
),
|
| 588 |
+
# GLB export options
|
| 589 |
+
conf_thresh_percentile: float = typer.Option(
|
| 590 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 591 |
+
),
|
| 592 |
+
num_max_points: int = typer.Option(
|
| 593 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 594 |
+
),
|
| 595 |
+
show_cameras: bool = typer.Option(
|
| 596 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 597 |
+
),
|
| 598 |
+
# Feat_vis export options
|
| 599 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 600 |
+
):
|
| 601 |
+
"""Run depth estimation on video by extracting frames and processing them."""
|
| 602 |
+
# Handle export directory
|
| 603 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 604 |
+
|
| 605 |
+
# Process input
|
| 606 |
+
image_files = VideoHandler.process(video_path, export_dir, fps)
|
| 607 |
+
|
| 608 |
+
# Parse export_feat parameter
|
| 609 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 610 |
+
|
| 611 |
+
# Determine backend URL based on use_backend flag
|
| 612 |
+
final_backend_url = backend_url if use_backend else None
|
| 613 |
+
|
| 614 |
+
# Run inference
|
| 615 |
+
run_inference(
|
| 616 |
+
image_paths=image_files,
|
| 617 |
+
export_dir=export_dir,
|
| 618 |
+
model_dir=model_dir,
|
| 619 |
+
device=device,
|
| 620 |
+
backend_url=final_backend_url,
|
| 621 |
+
export_format=export_format,
|
| 622 |
+
process_res=process_res,
|
| 623 |
+
process_res_method=process_res_method,
|
| 624 |
+
export_feat_layers=export_feat_layers,
|
| 625 |
+
use_ray_pose=use_ray_pose,
|
| 626 |
+
reference_view_strategy=reference_view_strategy,
|
| 627 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 628 |
+
num_max_points=num_max_points,
|
| 629 |
+
show_cameras=show_cameras,
|
| 630 |
+
feat_vis_fps=feat_vis_fps,
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
# ============================================================================
|
| 635 |
+
# Service management commands
|
| 636 |
+
# ============================================================================
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
@app.command()
|
| 640 |
+
def backend(
|
| 641 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 642 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 643 |
+
host: str = typer.Option("127.0.0.1", help="Host to bind to"),
|
| 644 |
+
port: int = typer.Option(8008, help="Port to bind to"),
|
| 645 |
+
gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path (optional)"),
|
| 646 |
+
):
|
| 647 |
+
"""Start model backend service with integrated gallery."""
|
| 648 |
+
typer.echo("=" * 60)
|
| 649 |
+
typer.echo("🚀 Starting Depth Anything 3 Backend Server")
|
| 650 |
+
typer.echo("=" * 60)
|
| 651 |
+
typer.echo(f"Model directory: {model_dir}")
|
| 652 |
+
typer.echo(f"Device: {device}")
|
| 653 |
+
|
| 654 |
+
# Check if gallery directory exists
|
| 655 |
+
if gallery_dir and os.path.exists(gallery_dir):
|
| 656 |
+
typer.echo(f"Gallery directory: {gallery_dir}")
|
| 657 |
+
else:
|
| 658 |
+
gallery_dir = None # Disable gallery if directory doesn't exist
|
| 659 |
+
|
| 660 |
+
typer.echo()
|
| 661 |
+
typer.echo("📡 Server URLs (Ctrl/CMD+Click to open):")
|
| 662 |
+
typer.echo(f" 🏠 Home: http://{host}:{port}")
|
| 663 |
+
typer.echo(f" 📊 Dashboard: http://{host}:{port}/dashboard")
|
| 664 |
+
typer.echo(f" 📈 API Status: http://{host}:{port}/status")
|
| 665 |
+
|
| 666 |
+
if gallery_dir:
|
| 667 |
+
typer.echo(f" 🎨 Gallery: http://{host}:{port}/gallery/")
|
| 668 |
+
|
| 669 |
+
typer.echo("=" * 60)
|
| 670 |
+
|
| 671 |
+
try:
|
| 672 |
+
start_server(model_dir, device, host, port, gallery_dir)
|
| 673 |
+
except KeyboardInterrupt:
|
| 674 |
+
typer.echo("\n👋 Backend server stopped.")
|
| 675 |
+
except Exception as e:
|
| 676 |
+
typer.echo(f"❌ Failed to start backend: {e}")
|
| 677 |
+
raise typer.Exit(1)
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
# ============================================================================
|
| 681 |
+
# Application launch commands
|
| 682 |
+
# ============================================================================
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
@app.command()
|
| 686 |
+
def gradio(
|
| 687 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 688 |
+
workspace_dir: str = typer.Option(DEFAULT_GRADIO_DIR, help="Workspace directory path"),
|
| 689 |
+
gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path"),
|
| 690 |
+
host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
|
| 691 |
+
port: int = typer.Option(7860, help="Port number to bind to"),
|
| 692 |
+
share: bool = typer.Option(False, help="Create a public link for the app"),
|
| 693 |
+
debug: bool = typer.Option(False, help="Enable debug mode"),
|
| 694 |
+
cache_examples: bool = typer.Option(
|
| 695 |
+
False, help="Pre-cache all example scenes at startup for faster loading"
|
| 696 |
+
),
|
| 697 |
+
cache_gs_tag: str = typer.Option(
|
| 698 |
+
"",
|
| 699 |
+
help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.",
|
| 700 |
+
),
|
| 701 |
+
):
|
| 702 |
+
"""Launch Depth Anything 3 Gradio interactive web application"""
|
| 703 |
+
from depth_anything_3.app.gradio_app import DepthAnything3App
|
| 704 |
+
|
| 705 |
+
# Create necessary directories
|
| 706 |
+
os.makedirs(workspace_dir, exist_ok=True)
|
| 707 |
+
os.makedirs(gallery_dir, exist_ok=True)
|
| 708 |
+
|
| 709 |
+
typer.echo("Launching Depth Anything 3 Gradio application...")
|
| 710 |
+
typer.echo(f"Model directory: {model_dir}")
|
| 711 |
+
typer.echo(f"Workspace directory: {workspace_dir}")
|
| 712 |
+
typer.echo(f"Gallery directory: {gallery_dir}")
|
| 713 |
+
typer.echo(f"Host: {host}")
|
| 714 |
+
typer.echo(f"Port: {port}")
|
| 715 |
+
typer.echo(f"Share: {share}")
|
| 716 |
+
typer.echo(f"Debug mode: {debug}")
|
| 717 |
+
typer.echo(f"Cache examples: {cache_examples}")
|
| 718 |
+
if cache_examples:
|
| 719 |
+
if cache_gs_tag:
|
| 720 |
+
typer.echo(
|
| 721 |
+
f"Cache GS Tag: '{cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)"
|
| 722 |
+
)
|
| 723 |
+
else:
|
| 724 |
+
typer.echo(f"Cache GS Tag: None (all scenes will use low-res only)")
|
| 725 |
+
|
| 726 |
+
try:
|
| 727 |
+
# Initialize and launch application
|
| 728 |
+
app = DepthAnything3App(
|
| 729 |
+
model_dir=model_dir, workspace_dir=workspace_dir, gallery_dir=gallery_dir
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
# Pre-cache examples if requested
|
| 733 |
+
if cache_examples:
|
| 734 |
+
typer.echo("\n" + "=" * 60)
|
| 735 |
+
typer.echo("Pre-caching mode enabled")
|
| 736 |
+
if cache_gs_tag:
|
| 737 |
+
typer.echo(f"Scenes containing '{cache_gs_tag}' will use HIGH-RES + 3DGS")
|
| 738 |
+
typer.echo(f"Other scenes will use LOW-RES only")
|
| 739 |
+
else:
|
| 740 |
+
typer.echo(f"All scenes will use LOW-RES only")
|
| 741 |
+
typer.echo("=" * 60)
|
| 742 |
+
app.cache_examples(
|
| 743 |
+
show_cam=True,
|
| 744 |
+
filter_black_bg=False,
|
| 745 |
+
filter_white_bg=False,
|
| 746 |
+
save_percentage=20.0,
|
| 747 |
+
num_max_points=1000,
|
| 748 |
+
cache_gs_tag=cache_gs_tag,
|
| 749 |
+
gs_trj_mode="smooth",
|
| 750 |
+
gs_video_quality="low",
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
# Prepare launch arguments
|
| 754 |
+
launch_kwargs = {"share": share, "debug": debug}
|
| 755 |
+
|
| 756 |
+
app.launch(host=host, port=port, **launch_kwargs)
|
| 757 |
+
|
| 758 |
+
except KeyboardInterrupt:
|
| 759 |
+
typer.echo("\nGradio application stopped.")
|
| 760 |
+
except Exception as e:
|
| 761 |
+
typer.echo(f"Failed to launch Gradio application: {e}")
|
| 762 |
+
raise typer.Exit(1)
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
@app.command()
|
| 766 |
+
def gallery(
|
| 767 |
+
gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery root directory"),
|
| 768 |
+
host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
|
| 769 |
+
port: int = typer.Option(8007, help="Port number to bind to"),
|
| 770 |
+
open_browser: bool = typer.Option(False, help="Open browser after launch"),
|
| 771 |
+
):
|
| 772 |
+
"""Launch Depth Anything 3 Gallery server"""
|
| 773 |
+
|
| 774 |
+
# Validate gallery directory
|
| 775 |
+
if not os.path.exists(gallery_dir):
|
| 776 |
+
raise typer.BadParameter(f"Gallery directory not found: {gallery_dir}")
|
| 777 |
+
|
| 778 |
+
typer.echo("Launching Depth Anything 3 Gallery server...")
|
| 779 |
+
typer.echo(f"Gallery directory: {gallery_dir}")
|
| 780 |
+
typer.echo(f"Host: {host}")
|
| 781 |
+
typer.echo(f"Port: {port}")
|
| 782 |
+
typer.echo(f"Auto-open browser: {open_browser}")
|
| 783 |
+
|
| 784 |
+
try:
|
| 785 |
+
# Set command line arguments
|
| 786 |
+
import sys
|
| 787 |
+
|
| 788 |
+
sys.argv = ["gallery", "--dir", gallery_dir, "--host", host, "--port", str(port)]
|
| 789 |
+
if open_browser:
|
| 790 |
+
sys.argv.append("--open")
|
| 791 |
+
|
| 792 |
+
# Launch gallery server
|
| 793 |
+
gallery_main()
|
| 794 |
+
|
| 795 |
+
except KeyboardInterrupt:
|
| 796 |
+
typer.echo("\nGallery server stopped.")
|
| 797 |
+
except Exception as e:
|
| 798 |
+
typer.echo(f"Failed to launch Gallery server: {e}")
|
| 799 |
+
raise typer.Exit(1)
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
if __name__ == "__main__":
|
| 803 |
+
app()
|
Depth-Anything-3/src/depth_anything_3/configs/da3-base.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitb
|
| 13 |
+
out_layers: [5, 7, 9, 11]
|
| 14 |
+
alt_start: 4
|
| 15 |
+
qknorm_start: 4
|
| 16 |
+
rope_start: 4
|
| 17 |
+
cat_token: True
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dualdpt
|
| 22 |
+
name: DualDPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: &head_dim_in 1536
|
| 26 |
+
output_dim: 2
|
| 27 |
+
features: &head_features 128
|
| 28 |
+
out_channels: &head_out_channels [96, 192, 384, 768]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
cam_enc:
|
| 32 |
+
__object__:
|
| 33 |
+
path: depth_anything_3.model.cam_enc
|
| 34 |
+
name: CameraEnc
|
| 35 |
+
args: as_params
|
| 36 |
+
|
| 37 |
+
dim_out: 768
|
| 38 |
+
|
| 39 |
+
cam_dec:
|
| 40 |
+
__object__:
|
| 41 |
+
path: depth_anything_3.model.cam_dec
|
| 42 |
+
name: CameraDec
|
| 43 |
+
args: as_params
|
| 44 |
+
|
| 45 |
+
dim_in: 1536
|
Depth-Anything-3/src/depth_anything_3/configs/da3-giant.yaml
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitg
|
| 13 |
+
out_layers: [19, 27, 33, 39]
|
| 14 |
+
alt_start: 13
|
| 15 |
+
qknorm_start: 13
|
| 16 |
+
rope_start: 13
|
| 17 |
+
cat_token: True
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dualdpt
|
| 22 |
+
name: DualDPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: &head_dim_in 3072
|
| 26 |
+
output_dim: 2
|
| 27 |
+
features: &head_features 256
|
| 28 |
+
out_channels: &head_out_channels [256, 512, 1024, 1024]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
cam_enc:
|
| 32 |
+
__object__:
|
| 33 |
+
path: depth_anything_3.model.cam_enc
|
| 34 |
+
name: CameraEnc
|
| 35 |
+
args: as_params
|
| 36 |
+
|
| 37 |
+
dim_out: 1536
|
| 38 |
+
|
| 39 |
+
cam_dec:
|
| 40 |
+
__object__:
|
| 41 |
+
path: depth_anything_3.model.cam_dec
|
| 42 |
+
name: CameraDec
|
| 43 |
+
args: as_params
|
| 44 |
+
|
| 45 |
+
dim_in: 3072
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
gs_head:
|
| 49 |
+
__object__:
|
| 50 |
+
path: depth_anything_3.model.gsdpt
|
| 51 |
+
name: GSDPT
|
| 52 |
+
args: as_params
|
| 53 |
+
|
| 54 |
+
dim_in: *head_dim_in
|
| 55 |
+
output_dim: 38 # should align with gs_adapter's setting, for gs params
|
| 56 |
+
features: *head_features
|
| 57 |
+
out_channels: *head_out_channels
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
gs_adapter:
|
| 61 |
+
__object__:
|
| 62 |
+
path: depth_anything_3.model.gs_adapter
|
| 63 |
+
name: GaussianAdapter
|
| 64 |
+
args: as_params
|
| 65 |
+
|
| 66 |
+
sh_degree: 2
|
| 67 |
+
pred_color: false # predict SH coefficient if false
|
| 68 |
+
pred_offset_depth: true
|
| 69 |
+
pred_offset_xy: true
|
| 70 |
+
gaussian_scale_min: 1e-5
|
| 71 |
+
gaussian_scale_max: 30.0
|
Depth-Anything-3/src/depth_anything_3/configs/da3-large.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitl
|
| 13 |
+
out_layers: [11, 15, 19, 23]
|
| 14 |
+
alt_start: 8
|
| 15 |
+
qknorm_start: 8
|
| 16 |
+
rope_start: 8
|
| 17 |
+
cat_token: True
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dualdpt
|
| 22 |
+
name: DualDPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: &head_dim_in 2048
|
| 26 |
+
output_dim: 2
|
| 27 |
+
features: &head_features 256
|
| 28 |
+
out_channels: &head_out_channels [256, 512, 1024, 1024]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
cam_enc:
|
| 32 |
+
__object__:
|
| 33 |
+
path: depth_anything_3.model.cam_enc
|
| 34 |
+
name: CameraEnc
|
| 35 |
+
args: as_params
|
| 36 |
+
|
| 37 |
+
dim_out: 1024
|
| 38 |
+
|
| 39 |
+
cam_dec:
|
| 40 |
+
__object__:
|
| 41 |
+
path: depth_anything_3.model.cam_dec
|
| 42 |
+
name: CameraDec
|
| 43 |
+
args: as_params
|
| 44 |
+
|
| 45 |
+
dim_in: 2048
|
Depth-Anything-3/src/depth_anything_3/configs/da3-small.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vits
|
| 13 |
+
out_layers: [5, 7, 9, 11]
|
| 14 |
+
alt_start: 4
|
| 15 |
+
qknorm_start: 4
|
| 16 |
+
rope_start: 4
|
| 17 |
+
cat_token: True
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dualdpt
|
| 22 |
+
name: DualDPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: &head_dim_in 768
|
| 26 |
+
output_dim: 2
|
| 27 |
+
features: &head_features 64
|
| 28 |
+
out_channels: &head_out_channels [48, 96, 192, 384]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
cam_enc:
|
| 32 |
+
__object__:
|
| 33 |
+
path: depth_anything_3.model.cam_enc
|
| 34 |
+
name: CameraEnc
|
| 35 |
+
args: as_params
|
| 36 |
+
|
| 37 |
+
dim_out: 384
|
| 38 |
+
|
| 39 |
+
cam_dec:
|
| 40 |
+
__object__:
|
| 41 |
+
path: depth_anything_3.model.cam_dec
|
| 42 |
+
name: CameraDec
|
| 43 |
+
args: as_params
|
| 44 |
+
|
| 45 |
+
dim_in: 768
|
Depth-Anything-3/src/depth_anything_3/configs/da3metric-large.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitl
|
| 13 |
+
out_layers: [4, 11, 17, 23]
|
| 14 |
+
alt_start: -1 # -1 means disable
|
| 15 |
+
qknorm_start: -1
|
| 16 |
+
rope_start: -1
|
| 17 |
+
cat_token: False
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dpt
|
| 22 |
+
name: DPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: 1024
|
| 26 |
+
output_dim: 1
|
| 27 |
+
features: 256
|
| 28 |
+
out_channels: [256, 512, 1024, 1024]
|
Depth-Anything-3/src/depth_anything_3/configs/da3mono-large.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitl
|
| 13 |
+
out_layers: [4, 11, 17, 23]
|
| 14 |
+
alt_start: -1 # -1 means disable
|
| 15 |
+
qknorm_start: -1
|
| 16 |
+
rope_start: -1
|
| 17 |
+
cat_token: False
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dpt
|
| 22 |
+
name: DPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: 1024
|
| 26 |
+
output_dim: 1
|
| 27 |
+
features: 256
|
| 28 |
+
out_channels: [256, 512, 1024, 1024]
|
Depth-Anything-3/src/depth_anything_3/configs/da3nested-giant-large.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: NestedDepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
anyview:
|
| 7 |
+
__inherit__: depth_anything_3.configs.da3-giant
|
| 8 |
+
|
| 9 |
+
metric:
|
| 10 |
+
__inherit__: depth_anything_3.configs.da3metric-large
|
Depth-Anything-3/src/depth_anything_3/model/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from depth_anything_3.model.da3 import DepthAnything3Net, NestedDepthAnything3Net
|
| 16 |
+
|
| 17 |
+
__export__ = [
|
| 18 |
+
NestedDepthAnything3Net,
|
| 19 |
+
DepthAnything3Net,
|
| 20 |
+
]
|
Depth-Anything-3/src/depth_anything_3/model/cam_dec.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CameraDec(nn.Module):
|
| 20 |
+
def __init__(self, dim_in=1536):
|
| 21 |
+
super().__init__()
|
| 22 |
+
output_dim = dim_in
|
| 23 |
+
self.backbone = nn.Sequential(
|
| 24 |
+
nn.Linear(output_dim, output_dim),
|
| 25 |
+
nn.ReLU(),
|
| 26 |
+
nn.Linear(output_dim, output_dim),
|
| 27 |
+
nn.ReLU(),
|
| 28 |
+
)
|
| 29 |
+
self.fc_t = nn.Linear(output_dim, 3)
|
| 30 |
+
self.fc_qvec = nn.Linear(output_dim, 4)
|
| 31 |
+
self.fc_fov = nn.Sequential(nn.Linear(output_dim, 2), nn.ReLU())
|
| 32 |
+
|
| 33 |
+
def forward(self, feat, camera_encoding=None, *args, **kwargs):
|
| 34 |
+
B, N = feat.shape[:2]
|
| 35 |
+
feat = feat.reshape(B * N, -1)
|
| 36 |
+
feat = self.backbone(feat)
|
| 37 |
+
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
|
| 38 |
+
if camera_encoding is None:
|
| 39 |
+
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
|
| 40 |
+
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
|
| 41 |
+
else:
|
| 42 |
+
out_qvec = camera_encoding[..., 3:7]
|
| 43 |
+
out_fov = camera_encoding[..., -2:]
|
| 44 |
+
pose_enc = torch.cat([out_t, out_qvec, out_fov], dim=-1)
|
| 45 |
+
return pose_enc
|