# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long)
import bisect
import glob
import os
import numpy as np
import trimesh
from loop_utils.alignment_torch import robust_weighted_estimate_sim3_torch
from loop_utils.alignment_triton import robust_weighted_estimate_sim3_triton
from numba import njit
from sklearn.linear_model import LinearRegression, RANSACRegressor
def accumulate_sim3_transforms(transforms):
"""
Accumulate adjacent SIM(3) transforms into transforms
from the initial frame to each subsequent frame.
Args:
        transforms: list, each element is a tuple (s, R, t)
            s: scale factor (scalar)
            R: 3x3 rotation matrix (np.array)
            t: 3x1 translation vector (np.array)
Returns:
        Cumulative transforms list, each element is (s_cum, R_cum, t_cum)
representing the transform from frame 0 to frame k
"""
if not transforms:
return []
cumulative_transforms = [transforms[0]]
for i in range(1, len(transforms)):
s_cum_prev, R_cum_prev, t_cum_prev = cumulative_transforms[i - 1]
s_next, R_next, t_next = transforms[i]
R_cum_new = R_cum_prev @ R_next
s_cum_new = s_cum_prev * s_next
t_cum_new = s_cum_prev * (R_cum_prev @ t_next) + t_cum_prev
cumulative_transforms.append((s_cum_new, R_cum_new, t_cum_new))
return cumulative_transforms
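# Illustrative usage (a minimal sketch, not part of the original pipeline). With
# identity rotations, chaining (s=2, t=[1,0,0]) and (s=0.5, t=[0,2,0]) gives the
# frame-0-to-frame-2 transform s=1.0, t=[1,4,0], matching the recursion above.
#
#   R = np.eye(3)
#   chain = accumulate_sim3_transforms(
#       [(2.0, R, np.array([1.0, 0.0, 0.0])), (0.5, R, np.array([0.0, 2.0, 0.0]))]
#   )
#   s02, R02, t02 = chain[1]   # s02 == 1.0, t02 == array([1., 4., 0.])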
def estimate_sim3(source_points, target_points):
mu_src = np.mean(source_points, axis=0)
mu_tgt = np.mean(target_points, axis=0)
src_centered = source_points - mu_src
tgt_centered = target_points - mu_tgt
scale_src = np.sqrt((src_centered**2).sum(axis=1).mean())
scale_tgt = np.sqrt((tgt_centered**2).sum(axis=1).mean())
s = scale_tgt / scale_src
src_scaled = src_centered * s
H = src_scaled.T @ tgt_centered
U, _, Vt = np.linalg.svd(H)
R = Vt.T @ U.T
if np.linalg.det(R) < 0:
Vt[2, :] *= -1
R = Vt.T @ U.T
t = mu_tgt - s * R @ mu_src
return s, R, t
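# Sanity check (an illustrative sketch): with noise-free correspondences,
# estimate_sim3 should recover the ground-truth similarity transform.
#
#   rng = np.random.default_rng(0)
#   src = rng.standard_normal((100, 3))
#   s_true, t_true = 2.0, np.array([0.5, -1.0, 3.0])
#   tgt = s_true * src + t_true            # R_true = identity here
#   s, R, t = estimate_sim3(src, tgt)
#   # np.allclose(s, 2.0) and np.allclose(R, np.eye(3)) and np.allclose(t, t_true)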
def align_point_maps(point_map1, conf1, point_map2, conf2, conf_threshold):
"""point_map2 -> point_map1"""
b1, _, _, _ = point_map1.shape
b2, _, _, _ = point_map2.shape
b = min(b1, b2)
aligned_points1 = []
aligned_points2 = []
for i in range(b):
mask1 = conf1[i] > conf_threshold
mask2 = conf2[i] > conf_threshold
valid_mask = mask1 & mask2
idx = np.where(valid_mask)
if len(idx[0]) == 0:
continue
pts1 = point_map1[i][idx]
pts2 = point_map2[i][idx]
aligned_points1.append(pts1)
aligned_points2.append(pts2)
if len(aligned_points1) == 0:
raise ValueError("No matching point pairs were found!")
all_pts1 = np.concatenate(aligned_points1, axis=0)
all_pts2 = np.concatenate(aligned_points2, axis=0)
print(f"The number of corresponding points matched: {all_pts1.shape[0]}")
s, R, t = estimate_sim3(all_pts2, all_pts1)
mean_error = compute_alignment_error(
point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t
)
print(f"Mean error: {mean_error}")
return s, R, t
def apply_sim3(points, s, R, t):
return (s * (R @ points.T)).T + t
def apply_sim3_direct(point_maps, s, R, t):
# point_maps: (b, h, w, 3) -> (b, h, w, 3, 1)
point_maps_expanded = point_maps[..., np.newaxis] # (b, h, w, 3, 1)
    # Broadcast R over the batch: (3, 3) @ (b, h, w, 3, 1) -> (b, h, w, 3, 1)
rotated = np.matmul(R, point_maps_expanded) # (b, h, w, 3, 1)
rotated = rotated.squeeze(-1) # (b, h, w, 3)
transformed = s * rotated + t # (b, h, w, 3)
return transformed
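# Equivalence note (an illustrative sketch): apply_sim3_direct on a (b, h, w, 3)
# point map should agree with flattening to (N, 3) and calling apply_sim3.
#
#   pm = np.random.rand(2, 4, 5, 3)
#   s, R, t = 1.5, np.eye(3), np.array([0.1, 0.2, 0.3])
#   a = apply_sim3_direct(pm, s, R, t)
#   b = apply_sim3(pm.reshape(-1, 3), s, R, t).reshape(pm.shape)
#   # np.allclose(a, b) -> True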
def compute_alignment_error(point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t):
"""
Compute the average point alignment error (using only original inputs)
Args:
point_map1: target point map (b, h, w, 3)
conf1: target confidence map (b, h, w)
point_map2: source point map (b, h, w, 3)
conf2: source confidence map (b, h, w)
conf_threshold: confidence threshold
s, R, t: transformation parameters
"""
b1, h1, w1, _ = point_map1.shape
b2, h2, w2, _ = point_map2.shape
b = min(b1, b2)
h = min(h1, h2)
w = min(w1, w2)
target_points = []
source_points = []
for i in range(b):
mask1 = conf1[i, :h, :w] > conf_threshold
mask2 = conf2[i, :h, :w] > conf_threshold
valid_mask = mask1 & mask2
idx = np.where(valid_mask)
if len(idx[0]) == 0:
continue
t_pts = point_map1[i, :h, :w][idx]
s_pts = point_map2[i, :h, :w][idx]
target_points.append(t_pts)
source_points.append(s_pts)
if len(target_points) == 0:
print("Warning: No matching point pairs found for error calculation")
return np.nan
all_target = np.concatenate(target_points, axis=0)
all_source = np.concatenate(source_points, axis=0)
transformed = (s * (R @ all_source.T)).T + t
errors = np.linalg.norm(transformed - all_target, axis=1)
mean_error = np.mean(errors)
std_error = np.std(errors)
median_error = np.median(errors)
max_error = np.max(errors)
print(
f"Alignment error statistics [using {len(errors)} points]: "
f"mean={mean_error:.4f}, std={std_error:.4f}, "
f"median={median_error:.4f}, max={max_error:.4f}"
)
return mean_error
def save_confident_pointcloud(
points, colors, confs, output_path, conf_threshold, sample_ratio=1.0
):
"""
Filter points based on confidence threshold
and save as PLY file, with optional random sampling ratio.
Args:
- points: np.ndarray, shape (H, W, 3) or (N, 3)
- colors: np.ndarray, shape (H, W, 3) or (N, 3)
- confs: np.ndarray, shape (H, W) or (N,)
- output_path: str, output PLY file path
- conf_threshold: float, confidence threshold for point filtering
- sample_ratio: float, sampling ratio (0 < sample_ratio <= 1.0)
"""
points = points.reshape(-1, 3).astype(np.float32, copy=False)
colors = colors.reshape(-1, 3).astype(np.uint8, copy=False)
confs = confs.reshape(-1).astype(np.float32, copy=False)
conf_mask = (confs >= conf_threshold) & (confs > 1e-5)
points = points[conf_mask]
colors = colors[conf_mask]
if 0 < sample_ratio < 1.0 and len(points) > 0:
num_samples = int(len(points) * sample_ratio)
indices = np.random.choice(len(points), num_samples, replace=False)
points = points[indices]
colors = colors[indices]
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
print(f"shape of sampled point: {points.shape}")
trimesh.PointCloud(points, colors=colors).export(output_path)
print(f"Saved point cloud with {len(points)} points to {output_path}")
def save_confident_pointcloud_batch(
points, colors, confs, output_path, conf_threshold, sample_ratio=1.0, batch_size=1000000
):
"""
- points: np.ndarray, (b, H, W, 3) / (N, 3)
- colors: np.ndarray, (b, H, W, 3) / (N, 3)
- confs: np.ndarray, (b, H, W) / (N,)
- output_path: str
- conf_threshold: float,
- sample_ratio: float (0 < sample_ratio <= 1.0)
- batch_size: int
"""
if points.ndim == 2:
b = 1
points = points[np.newaxis, ...]
colors = colors[np.newaxis, ...]
confs = confs[np.newaxis, ...]
elif points.ndim == 4:
b = points.shape[0]
else:
raise ValueError("Unsupported points dimension. Must be 2 (N,3) or 4 (b,H,W,3)")
total_valid = 0
for i in range(b):
cfs = confs[i].reshape(-1)
total_valid += np.count_nonzero((cfs >= conf_threshold) & (cfs > 1e-5))
num_samples = int(total_valid * sample_ratio) if sample_ratio < 1.0 else total_valid
if num_samples == 0:
save_ply(np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8), output_path)
return
if sample_ratio == 1.0:
with open(output_path, "wb") as f:
write_ply_header(f, num_samples)
for i in range(b):
pts = points[i].reshape(-1, 3).astype(np.float32)
cls = colors[i].reshape(-1, 3).astype(np.uint8)
cfs = confs[i].reshape(-1).astype(np.float32)
mask = (cfs >= conf_threshold) & (cfs > 1e-5)
valid_pts = pts[mask]
valid_cls = cls[mask]
for j in range(0, len(valid_pts), batch_size):
batch_pts = valid_pts[j : j + batch_size]
batch_cls = valid_cls[j : j + batch_size]
write_ply_batch(f, batch_pts, batch_cls)
else:
reservoir_pts = np.zeros((num_samples, 3), dtype=np.float32)
reservoir_clr = np.zeros((num_samples, 3), dtype=np.uint8)
count = 0
for i in range(b):
pts = points[i].reshape(-1, 3).astype(np.float32)
cls = colors[i].reshape(-1, 3).astype(np.uint8)
cfs = confs[i].reshape(-1).astype(np.float32)
mask = (cfs >= conf_threshold) & (cfs > 1e-5)
valid_pts = pts[mask]
valid_cls = cls[mask]
n_valid = len(valid_pts)
if count < num_samples:
fill_count = min(num_samples - count, n_valid)
reservoir_pts[count : count + fill_count] = valid_pts[:fill_count]
reservoir_clr[count : count + fill_count] = valid_cls[:fill_count]
count += fill_count
if fill_count < n_valid:
remaining_pts = valid_pts[fill_count:]
remaining_cls = valid_cls[fill_count:]
count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling(
remaining_pts, remaining_cls, count, reservoir_pts, reservoir_clr
)
else:
count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling(
valid_pts, valid_cls, count, reservoir_pts, reservoir_clr
)
save_ply(reservoir_pts, reservoir_clr, output_path)
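# Illustrative call (a minimal sketch; the array shapes below are made up):
#
#   pts = np.random.rand(4, 294, 518, 3).astype(np.float32)
#   cols = (np.random.rand(4, 294, 518, 3) * 255).astype(np.uint8)
#   conf = np.random.rand(4, 294, 518).astype(np.float32)
#   save_confident_pointcloud_batch(
#       pts, cols, conf, "pcd.ply", conf_threshold=0.5, sample_ratio=0.1
#   )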
""" The following function is deprecated"""
# def vectorized_reservoir_sampling(new_pts, new_cls, current_count, reservoir_pts, reservoir_clr):
# """
# - new_pts: (M, 3)
# - new_cls: (M, 3)
# - current_count
# - reservoir_pts: (K, 3)
# - reservoir_clr: (K, 3)
# """
# k = len(reservoir_pts)
# n_new = len(new_pts)
# rand_indices = np.random.randint(0, current_count + n_new, size=n_new)
# replace_mask = rand_indices < k
# replace_indices = rand_indices[replace_mask]
# replace_pts = new_pts[replace_mask]
# replace_cls = new_cls[replace_mask]
# reservoir_pts[replace_indices] = replace_pts
# reservoir_clr[replace_indices] = replace_cls
# return current_count + n_new, reservoir_pts, reservoir_clr
"""
Function `vectorized_reservoir_sampling` does not sample with mathematically correct
probabilities, which leads to inconsistent density in the downsampled point clouds.
The `optimized_vectorized_reservoir_sampling` function has fixed this bug.
Special thanks to @Horace89 for the detailed analysis and code assistance.
See https://github.com/DengKaiCQ/VGGT-Long/issues/28 for details
"""
def optimized_vectorized_reservoir_sampling(
new_points: np.ndarray,
new_colors: np.ndarray,
current_count: int,
reservoir_points: np.ndarray,
reservoir_colors: np.ndarray,
) -> tuple[int, np.ndarray, np.ndarray]:
"""
Optimized vectorized reservoir sampling with batch probability calculations.
This maintains mathematical correctness while improving performance through
vectorized operations where possible.
Args:
new_points: New point coordinates to consider, shape (M, 3)
new_colors: New point colors to consider, shape (M, 3)
current_count: Number of elements seen so far
reservoir_points: Current reservoir of sampled points, shape (K, 3)
reservoir_colors: Current reservoir of sampled colors, shape (K, 3)
Returns:
Tuple of (updated_count, updated_reservoir_points, updated_reservoir_colors)
"""
random_gen = np.random
reservoir_size = len(reservoir_points)
num_new_points = len(new_points)
if num_new_points == 0:
return current_count, reservoir_points, reservoir_colors
# Calculate sequential indices for each new point
point_indices = np.arange(current_count + 1, current_count + num_new_points + 1)
# Generate random numbers for each point
random_values = random_gen.randint(0, point_indices, size=num_new_points)
# Determine which points should replace reservoir elements
replacement_mask = random_values < reservoir_size
replacement_positions = random_values[replacement_mask]
# Apply replacements
if np.any(replacement_mask):
points_to_replace = new_points[replacement_mask]
colors_to_replace = new_colors[replacement_mask]
reservoir_points[replacement_positions] = points_to_replace
reservoir_colors[replacement_positions] = colors_to_replace
return current_count + num_new_points, reservoir_points, reservoir_colors
def write_ply_header(f, num_vertices):
header = [
"ply",
"format binary_little_endian 1.0",
f"element vertex {num_vertices}",
"property float x",
"property float y",
"property float z",
"property uchar red",
"property uchar green",
"property uchar blue",
"end_header",
]
f.write("\n".join(header).encode() + b"\n")
def write_ply_batch(f, points, colors):
structured = np.zeros(
len(points),
dtype=[
("x", np.float32),
("y", np.float32),
("z", np.float32),
("red", np.uint8),
("green", np.uint8),
("blue", np.uint8),
],
)
structured["x"] = points[:, 0]
structured["y"] = points[:, 1]
structured["z"] = points[:, 2]
structured["red"] = colors[:, 0]
structured["green"] = colors[:, 1]
structured["blue"] = colors[:, 2]
f.write(structured.tobytes())
def save_ply(points, colors, filename):
with open(filename, "wb") as f:
write_ply_header(f, len(points))
write_ply_batch(f, points, colors)
def find_chunk_index(chunks, idx):
"""
Find the 0-based chunk index that contains the given index idx.
chunks: List of (begin_idx, end_idx).
idx: The index to search for.
Returns the 0-based chunk index.
"""
starts = [chunk[0] for chunk in chunks]
pos = bisect.bisect_right(starts, idx) - 1 # Find position of idx in starts
if pos < 0 or pos >= len(chunks):
raise ValueError(f"Index {idx} not found in any chunk")
chunk_begin, chunk_end = chunks[pos]
if idx < chunk_begin or idx > chunk_end:
raise ValueError(f"Index {idx} not found in any chunk")
return pos
def get_frame_range(chunk, idx, half_window=10):
"""
Calculate the frame range centered at idx with half_window
frames on each side within chunk boundaries.
If near boundaries, take 2 * half_window frames starting from the boundary.
chunk: (begin_idx, end_idx).
idx: Center index.
half_window: Number of frames to take on each side of center index.
Returns (start, end).
"""
begin, end = chunk
window_size = 2 * half_window
if idx - half_window < begin:
start = begin
end_candidate = begin + window_size
end = min(end, end_candidate)
elif idx + half_window > end:
end_candidate = end
start_candidate = end - window_size
start = max(begin, start_candidate)
else:
start = idx - half_window
end = idx + half_window
return (start, end)
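# Illustrative usage (a minimal sketch): with two overlapping chunks, an index is
# assigned to the chunk with the largest start not exceeding it, and the frame
# window is clamped at the chunk boundaries.
#
#   chunks = [(0, 59), (50, 119)]
#   find_chunk_index(chunks, 55)        # -> 1 (start 50 is the largest start <= 55)
#   get_frame_range((0, 59), 55, 10)    # -> (39, 59), clamped at the chunk end
#   get_frame_range((0, 59), 3, 10)     # -> (0, 20), clamped at the chunk start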
def process_loop_list(chunk_index, loop_list, half_window=10):
"""
Process loop_list and return chunk indices and frame ranges for each (idx1, idx2) pair.
chunk_index: List of (begin_idx, end_idx) tuples.
loop_list: List of (idx1, idx2) tuples.
half_window: Number of frames to take on each side of center index (default 10).
Returns list of (chunk_idx1, range1, chunk_idx2, range2) tuples where:
        - chunk_idx1, chunk_idx2: Chunk indices (0-based).
- range1, range2: Frame range tuples (start, end).
"""
results = []
for idx1, idx2 in loop_list:
try:
chunk_idx1_0based = find_chunk_index(chunk_index, idx1)
chunk1 = chunk_index[chunk_idx1_0based]
range1 = get_frame_range(chunk1, idx1, half_window)
chunk_idx2_0based = find_chunk_index(chunk_index, idx2)
chunk2 = chunk_index[chunk_idx2_0based]
range2 = get_frame_range(chunk2, idx2, half_window)
result = (chunk_idx1_0based, range1, chunk_idx2_0based, range2)
results.append(result)
except ValueError as e:
print(f"Skipping pair ({idx1}, {idx2}): {e}")
return results
def compute_sim3_ab(S_a, S_b):
s_a, R_a, T_a = S_a
s_b, R_b, T_b = S_b
s_ab = s_b / s_a
R_ab = R_b @ R_a.T
T_ab = T_b - s_ab * (R_ab @ T_a)
return (s_ab, R_ab, T_ab)
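# Derivation sketch: reading each S = (s, R, T) as the map x -> s * R @ x + T,
# the returned triple is the composition S_b o S_a^{-1}:
#   S_b(S_a^{-1}(y)) = (s_b / s_a) * (R_b @ R_a.T) @ y + (T_b - s_ab * R_ab @ T_a)
# i.e. s_ab = s_b / s_a, R_ab = R_b @ R_a.T, T_ab = T_b - s_ab * (R_ab @ T_a).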
def merge_ply_files(input_dir, output_path):
"""
Merge all PLY files in a directory into one file (without loading into memory)
Args:
- input_dir: Input directory containing multiple '{idx}_pcd.ply' files
- output_path: Output file path (e.g., 'combined.ply')
"""
print("Merging PLY files...")
input_files = sorted(glob.glob(os.path.join(input_dir, "*_pcd.ply")))
if not input_files:
print("No PLY files found")
return
    idx_file = 0
total_vertices = 0
for file in input_files: # Count total vertices
with open(file, "rb") as f:
for line in f:
if line.startswith(b"element vertex"):
vertex_count = int(line.split()[-1])
total_vertices += vertex_count
elif line.startswith(b"end_header"):
break
with open(output_path, "wb") as out_f:
# Write new header
out_f.write(b"ply\n")
out_f.write(b"format binary_little_endian 1.0\n")
out_f.write(f"element vertex {total_vertices}\n".encode())
out_f.write(b"property float x\n")
out_f.write(b"property float y\n")
out_f.write(b"property float z\n")
out_f.write(b"property uchar red\n")
out_f.write(b"property uchar green\n")
out_f.write(b"property uchar blue\n")
out_f.write(b"end_header\n")
for file in input_files:
print(f"Processing {idx_file}/{len(input_files)}: {file}")
idx_file += 1
with open(file, "rb") as in_f:
# Skip the head
in_header = True
while in_header:
line = in_f.readline()
if line.startswith(b"end_header"):
in_header = False
data = in_f.read()
out_f.write(data)
print(f"Merge completed! Total points: {total_vertices}")
print(f"Output file: {output_path}")
def weighted_estimate_se3(source_points, target_points, weights):
"""
source_points: (Nx3)
target_points: (Nx3)
    weights: (N,), values in [0, 1]
"""
total_weight = np.sum(weights)
if total_weight < 1e-6:
raise ValueError("Total weight too small for meaningful estimation")
normalized_weights = weights / total_weight
mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
src_centered = source_points - mu_src
tgt_centered = target_points - mu_tgt
weighted_src = src_centered * np.sqrt(normalized_weights)[:, None]
weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
H = weighted_src.T @ weighted_tgt
U, _, Vt = np.linalg.svd(H)
R = Vt.T @ U.T
if np.linalg.det(R) < 0:
Vt[2, :] *= -1
R = Vt.T @ U.T
t = mu_tgt - R @ mu_src
return 1.0, R, t
def weighted_estimate_sim3(source_points, target_points, weights):
"""
source_points: (Nx3)
target_points: (Nx3)
    weights: (N,), values in [0, 1]
"""
total_weight = np.sum(weights)
if total_weight < 1e-6:
raise ValueError("Total weight too small for meaningful estimation")
normalized_weights = weights / total_weight
mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
src_centered = source_points - mu_src
tgt_centered = target_points - mu_tgt
scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1)))
scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1)))
s = scale_tgt / scale_src
weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None]
weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
H = weighted_src.T @ weighted_tgt
U, _, Vt = np.linalg.svd(H)
R = Vt.T @ U.T
if np.linalg.det(R) < 0:
Vt[2, :] *= -1
R = Vt.T @ U.T
t = mu_tgt - s * R @ mu_src
return s, R, t
def huber_loss(r, delta):
abs_r = np.abs(r)
return np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta))
def robust_weighted_estimate_sim3(
src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
):
"""
src: (Nx3)
tgt: (Nx3)
init_weights: (N,)
"""
    if align_method == "sim3":
        s, R, t = weighted_estimate_sim3(src, tgt, init_weights)
    elif align_method == "se3" or align_method == "scale+se3":
        s, R, t = weighted_estimate_se3(src, tgt, init_weights)
    else:
        raise ValueError(f"Unknown align_method: {align_method}")
prev_error = float("inf")
for iter in range(max_iters):
transformed = s * (src @ R.T) + t
residuals = np.linalg.norm(tgt - transformed, axis=1) # (N,)
print(f"Residuals: {np.mean(residuals)}")
abs_res = np.abs(residuals)
huber_weights = np.ones_like(residuals)
large_res_mask = abs_res > delta
huber_weights[large_res_mask] = delta / abs_res[large_res_mask]
combined_weights = init_weights * huber_weights
combined_weights /= np.sum(combined_weights) + 1e-12
if align_method == "se3":
s_new, R_new, t_new = weighted_estimate_se3(src, tgt, combined_weights)
elif align_method == "sim3" or align_method == "scale+se3":
s_new, R_new, t_new = weighted_estimate_sim3(src, tgt, combined_weights)
param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
current_error = np.sum(huber_loss(residuals, delta) * init_weights)
if (param_change < tol and rot_angle < np.radians(0.1)) or (
abs(prev_error - current_error) < tol * prev_error
):
break
s, R, t = s_new, R_new, t_new
prev_error = current_error
return s, R, t
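# Illustrative usage (a minimal sketch): the Huber-reweighted IRLS loop should be
# far less sensitive to a handful of gross outliers than a single weighted fit.
#
#   rng = np.random.default_rng(1)
#   src = rng.standard_normal((500, 3))
#   tgt = 1.5 * src + np.array([0.2, -0.3, 0.4])   # ground truth: s=1.5, R=I
#   tgt[:10] += 50.0                               # inject 10 gross outliers
#   w = np.ones(len(src))
#   s, R, t = robust_weighted_estimate_sim3(src, tgt, w, delta=0.1, max_iters=20)
#   # s should end up close to 1.5 despite the corrupted correspondences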
# ===== Speed Up Begin =====
@njit(cache=True)
def _weighted_estimate_se3_numba(source_points, target_points, weights):
# Ensure float32
source_points = source_points.astype(np.float32)
target_points = target_points.astype(np.float32)
weights = weights.astype(np.float32)
total_weight = np.sum(weights)
if total_weight < 1e-6:
return (
            -1.0,  # sentinel: lets the caller's `s < 0` check catch the degenerate case
np.zeros(3, dtype=np.float32),
np.zeros(3, dtype=np.float32),
np.zeros((3, 3), dtype=np.float32),
)
normalized_weights = weights / total_weight
mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
src_centered = source_points - mu_src
tgt_centered = target_points - mu_tgt
weighted_src = src_centered * np.sqrt(normalized_weights)[:, None]
weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
H = weighted_src.T @ weighted_tgt
return 1.0, mu_src, mu_tgt, H
@njit(cache=True)
def _weighted_estimate_sim3_numba(source_points, target_points, weights):
# Ensure float32
source_points = source_points.astype(np.float32)
target_points = target_points.astype(np.float32)
weights = weights.astype(np.float32)
total_weight = np.sum(weights)
if total_weight < 1e-6:
return (
-1.0,
np.zeros(3, dtype=np.float32),
np.zeros(3, dtype=np.float32),
np.zeros((3, 3), dtype=np.float32),
)
normalized_weights = weights / total_weight
mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0)
mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0)
src_centered = source_points - mu_src
tgt_centered = target_points - mu_tgt
scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1)))
scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1)))
s = scale_tgt / scale_src
weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None]
weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None]
H = weighted_src.T @ weighted_tgt
return s, mu_src, mu_tgt, H
def weighted_estimate_sim3_numba(source_points, target_points, weights, align_method="sim3"):
    if align_method == "sim3":
        s, mu_src, mu_tgt, H = _weighted_estimate_sim3_numba(source_points, target_points, weights)
    elif align_method == "se3" or align_method == "scale+se3":
        s, mu_src, mu_tgt, H = _weighted_estimate_se3_numba(source_points, target_points, weights)
    else:
        raise ValueError(f"Unknown align_method: {align_method}")
if s < 0:
raise ValueError("Total weight too small for meaningful estimation")
# Ensure float32
H = H.astype(np.float32)
U, _, Vt = np.linalg.svd(H.astype(np.float32)) # float32 SVD
R = Vt.T @ U.T
if np.linalg.det(R) < 0:
Vt[2, :] *= -1
R = Vt.T @ U.T
if align_method == "se3" or align_method == "scale+se3":
t = mu_tgt - R @ mu_src
else:
t = mu_tgt - s * R @ mu_src
return s, R, t
@njit(cache=True)
def huber_loss_numba(r, delta):
r = r.astype(np.float32)
delta = np.float32(delta)
abs_r = np.abs(r)
result = np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta))
return result.astype(np.float32)
@njit(cache=True)
def compute_residuals_numba(tgt, transformed):
residuals = np.empty(tgt.shape[0], dtype=np.float32)
for i in range(tgt.shape[0]):
diff = tgt[i] - transformed[i]
residuals[i] = np.sqrt(np.sum(diff**2))
return residuals
@njit(cache=True)
def compute_huber_weights_numba(residuals, delta):
weights = np.ones(residuals.shape, dtype=np.float32)
for i in range(residuals.shape[0]):
r = residuals[i]
if r > delta:
weights[i] = delta / r
return weights
@njit(cache=True)
def apply_transformation_numba(src, s, R, t):
transformed = np.empty_like(src)
for i in range(src.shape[0]):
p = src[i]
transformed[i] = s * (R @ p) + t
return transformed
def robust_weighted_estimate_sim3_numba(
src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3"
):
src = src.astype(np.float32)
tgt = tgt.astype(np.float32)
init_weights = init_weights.astype(np.float32)
s, R, t = weighted_estimate_sim3_numba(src, tgt, init_weights, align_method=align_method)
prev_error = float("inf")
for iter in range(max_iters):
transformed = apply_transformation_numba(src, s, R, t)
residuals = compute_residuals_numba(tgt, transformed)
print(f"Residuals: {np.mean(residuals)}")
huber_weights = compute_huber_weights_numba(residuals, delta)
combined_weights = init_weights * huber_weights
combined_weights /= np.sum(combined_weights) + 1e-12
s_new, R_new, t_new = weighted_estimate_sim3_numba(
src, tgt, combined_weights, align_method=align_method
)
param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t)
rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2)))
current_error = np.sum(huber_loss_numba(residuals, delta) * init_weights)
if (param_change < tol and rot_angle < np.radians(0.1)) or (
abs(prev_error - current_error) < tol * prev_error
):
break
s, R, t = s_new, R_new, t_new
prev_error = current_error
return s, R, t
def warmup_numba():
print("\nWarming up Numba JIT-compiled functions...")
src = np.random.randn(50000, 3).astype(np.float32)
tgt = np.random.randn(50000, 3).astype(np.float32)
weights = np.ones(50000, dtype=np.float32)
residuals = np.abs(np.random.randn(50000).astype(np.float32))
R = np.eye(3, dtype=np.float32)
t = np.zeros(3, dtype=np.float32)
s = np.float32(1.0)
delta = np.float32(1.0)
try:
_ = _weighted_estimate_sim3_numba(src, tgt, weights)
print(" - _weighted_estimate_sim3_numba warmed up.")
except Exception as e:
print(" ! Failed to warm up _weighted_estimate_sim3_numba:", e)
try:
_ = _weighted_estimate_se3_numba(src, tgt, weights)
print(" - _weighted_estimate_se3_numba warmed up.")
except Exception as e:
print(" ! Failed to warm up _weighted_estimate_se3_numba:", e)
try:
_ = huber_loss_numba(residuals, delta)
print(" - huber_loss_numba warmed up.")
except Exception as e:
print(" ! Failed to warm up huber_loss_numba:", e)
try:
_ = compute_huber_weights_numba(residuals, delta)
print(" - compute_huber_weights_numba warmed up.")
except Exception as e:
print(" ! Failed to warm up compute_huber_weights_numba:", e)
try:
_ = compute_residuals_numba(tgt, src)
print(" - compute_residuals_numba warmed up.")
except Exception as e:
print(" ! Failed to warm up compute_residuals_numba:", e)
try:
_ = apply_transformation_numba(src, s, R, t)
print(" - apply_transformation_numba warmed up.")
except Exception as e:
print(" ! Failed to warm up apply_transformation_numba:", e)
print("Numba warm-up complete.\n")
# ===== Speed Up End =====
# ===== Scale precompute begin =====
def compute_scale_ransac(
depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, max_samples=10000
):
"""
Args:
depth1: (n1, h, w)
depth2: (n2, h, w)
conf1: (n1, h, w)
conf2: (n2, h, w)
"""
depth1_flat = depth1.reshape(-1)
depth2_flat = depth2.reshape(-1)
conf1_flat = conf1.reshape(-1)
conf2_flat = conf2.reshape(-1)
conf_threshold = max(
np.median(conf1_flat) * conf_threshold_ratio,
np.median(conf2_flat) * conf_threshold_ratio,
1e-6,
)
valid_mask = (
(conf1_flat > conf_threshold)
& (conf2_flat > conf_threshold)
& (depth1_flat > 1e-3)
& (depth2_flat > 1e-3)
& (depth1_flat < 100)
& (depth2_flat < 100)
)
if np.sum(valid_mask) < 100:
print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0")
return 1.0, 0.0
valid_depth1 = depth1_flat[valid_mask]
valid_depth2 = depth2_flat[valid_mask]
if len(valid_depth1) > max_samples:
indices = np.random.choice(len(valid_depth1), max_samples, replace=False)
valid_depth1 = valid_depth1[indices]
valid_depth2 = valid_depth2[indices]
X = valid_depth2.reshape(-1, 1)
y = valid_depth1
base_estimator = LinearRegression(fit_intercept=False)
ransac = RANSACRegressor(
estimator=base_estimator,
max_trials=1000,
min_samples=max(10, len(X) // 100),
residual_threshold=0.1,
random_state=42,
)
ransac.fit(X, y)
scale_factor = ransac.estimator_.coef_[0]
inlier_mask = ransac.inlier_mask_
inlier_ratio = np.sum(inlier_mask) / len(inlier_mask)
print(f"RANSAC scale: {scale_factor:.6f}, inlier ratio: {inlier_ratio:.4f}")
if 0.1 < scale_factor < 10.0:
return scale_factor, inlier_ratio
else:
print(f"Warning: Unreasonable scale {scale_factor}, using 1.0")
return 1.0, inlier_ratio
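# Sanity check (an illustrative sketch): when one chunk's depths are exactly a
# scaled copy of the other's, RANSAC should recover that scale.
#
#   d2 = np.random.rand(2, 32, 32).astype(np.float32) + 0.5
#   d1 = 2.0 * d2
#   c = np.ones_like(d2)
#   scale, inlier_ratio = compute_scale_ransac(d1, d2, c, c)
#   # scale ~ 2.0 (d1 = scale * d2), inlier_ratio ~ 1.0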
def compute_scale_weighted(
depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, weight_power=2.0, robust_quantile=0.9
):
"""
Args:
depth1: (n1, h, w)
depth2: (n2, h, w)
conf1: (n1, h, w)
conf2: (n2, h, w)
"""
depth1_flat = depth1.reshape(-1)
depth2_flat = depth2.reshape(-1)
conf1_flat = conf1.reshape(-1)
conf2_flat = conf2.reshape(-1)
conf_threshold = max(
np.median(conf1_flat) * conf_threshold_ratio,
np.median(conf2_flat) * conf_threshold_ratio,
1e-6,
)
valid_mask = (
(conf1_flat > conf_threshold)
& (conf2_flat > conf_threshold)
& (depth1_flat > 1e-3)
& (depth2_flat > 1e-3)
& (depth1_flat < 100)
& (depth2_flat < 100)
)
if np.sum(valid_mask) < 100:
print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0")
return 1.0, 0.0
valid_depth1 = depth1_flat[valid_mask]
valid_depth2 = depth2_flat[valid_mask]
valid_conf1 = conf1_flat[valid_mask]
valid_conf2 = conf2_flat[valid_mask]
combined_weights = (valid_conf1 * valid_conf2) ** weight_power
combined_weights = combined_weights / (np.sum(combined_weights) + 1e-8)
ratios = valid_depth1 / (valid_depth2 + 1e-8)
sorted_indices = np.argsort(ratios)
sorted_ratios = ratios[sorted_indices]
sorted_weights = combined_weights[sorted_indices]
cumulative_weights = np.cumsum(sorted_weights)
median_idx = np.searchsorted(cumulative_weights, 0.5)
scale_median = sorted_ratios[median_idx] if median_idx < len(sorted_ratios) else 1.0
quantile_idx = np.searchsorted(cumulative_weights, robust_quantile)
scale_quantile = (
sorted_ratios[quantile_idx] if quantile_idx < len(sorted_ratios) else scale_median
)
weight_entropy = -np.sum(combined_weights * np.log(combined_weights + 1e-8))
max_entropy = np.log(len(combined_weights))
confidence_score = 1.0 - (weight_entropy / max_entropy) if max_entropy > 0 else 0.0
print(f"Weighted scale: {scale_quantile:.6f}, confidence: {confidence_score:.4f}")
if 0.1 < scale_quantile < 10.0:
return scale_quantile, confidence_score
else:
print(f"Warning: Unreasonable scale {scale_quantile}, using 1.0")
return 1.0, confidence_score
def compute_chunk_scale_advanced(depth1, depth2, conf1, conf2, method="auto"):
"""
method: 'auto', 'ransac', 'weighted'
"""
if method == "ransac":
scale, score = compute_scale_ransac(depth1, depth2, conf1, conf2)
return scale, score, "ransac"
elif method == "weighted":
scale, score = compute_scale_weighted(depth1, depth2, conf1, conf2)
return scale, score, "weighted"
elif method == "auto":
scale_ransac, inlier_ratio = compute_scale_ransac(depth1, depth2, conf1, conf2)
scale_weighted, conf_score = compute_scale_weighted(depth1, depth2, conf1, conf2)
ransac_quality = inlier_ratio
weighted_quality = conf_score
print(f"RANSAC quality: {ransac_quality:.4f}, Weighted quality: {weighted_quality:.4f}")
if ransac_quality > 0.7 and weighted_quality > 0.7:
            # both methods are good, so take the average of the two estimates
final_scale = (scale_ransac + scale_weighted) / 2
final_method = "average"
elif ransac_quality > weighted_quality:
final_scale = scale_ransac
final_method = "ransac"
else:
final_scale = scale_weighted
final_method = "weighted"
        final_quality = max(ransac_quality, weighted_quality)
        return final_scale, final_quality, final_method
    else:
        raise ValueError(f"Unknown scale estimation method: {method}")
def precompute_scale_chunks_with_depth(
chunk1_depth, chunk1_conf, chunk2_depth, chunk2_conf, method="auto"
):
"""
Args:
chunk1_depth: (n1, h, w)
chunk1_conf: (n1, h, w)
chunk2_depth: (n2, h, w)
chunk2_conf: (n2, h, w)
method: 'auto', 'ransac', 'weighted'
"""
scale_factor, quality_score, method_used = compute_chunk_scale_advanced(
chunk1_depth, chunk2_depth, chunk1_conf, chunk2_conf, method
)
print(f"Final scale: {scale_factor:.6f}, quality: {quality_score:.4f}, method: {method_used}")
return scale_factor, quality_score, method_used
# ===== Scale precompute end =====
def weighted_align_point_maps(
point_map1, conf1, point_map2, conf2, conf_threshold, config, precompute_scale=None
):
"""point_map2 -> point_map1"""
b1, _, _, _ = point_map1.shape
b2, _, _, _ = point_map2.shape
b = min(b1, b2)
if precompute_scale is not None: # meaning we are using align method 'scale+se3'
point_map2 *= precompute_scale
aligned_points1 = []
aligned_points2 = []
confidence_weights = []
for i in range(b):
mask1 = conf1[i] > conf_threshold
mask2 = conf2[i] > conf_threshold
valid_mask = mask1 & mask2
idx = np.where(valid_mask)
if len(idx[0]) == 0:
continue
pts1 = point_map1[i][idx]
pts2 = point_map2[i][idx]
combined_conf = np.sqrt(conf1[i][idx] * conf2[i][idx])
aligned_points1.append(pts1)
aligned_points2.append(pts2)
confidence_weights.append(combined_conf)
if len(aligned_points1) == 0:
raise ValueError("No matching point pairs were found!")
all_pts1 = np.concatenate(aligned_points1, axis=0)
all_pts2 = np.concatenate(aligned_points2, axis=0)
all_weights = np.concatenate(confidence_weights, axis=0)
print(f"The number of corresponding points matched: {all_pts1.shape[0]}")
if config["Model"]["align_lib"] == "numba":
s, R, t = robust_weighted_estimate_sim3_numba(
all_pts2,
all_pts1,
all_weights,
delta=config["Model"]["IRLS"]["delta"],
max_iters=config["Model"]["IRLS"]["max_iters"],
tol=eval(config["Model"]["IRLS"]["tol"]),
align_method=config["Model"]["align_method"],
)
elif config["Model"]["align_lib"] == "numpy": # numpy
s, R, t = robust_weighted_estimate_sim3(
all_pts2,
all_pts1,
all_weights,
delta=config["Model"]["IRLS"]["delta"],
max_iters=config["Model"]["IRLS"]["max_iters"],
tol=eval(config["Model"]["IRLS"]["tol"]),
align_method=config["Model"]["align_method"],
)
elif config["Model"]["align_lib"] == "torch": # torch
s, R, t = robust_weighted_estimate_sim3_torch(
all_pts2,
all_pts1,
all_weights,
delta=config["Model"]["IRLS"]["delta"],
max_iters=config["Model"]["IRLS"]["max_iters"],
tol=eval(config["Model"]["IRLS"]["tol"]),
align_method=config["Model"]["align_method"],
)
elif config["Model"]["align_lib"] == "triton": # triton
s, R, t = robust_weighted_estimate_sim3_triton(
all_pts2,
all_pts1,
all_weights,
delta=config["Model"]["IRLS"]["delta"],
max_iters=config["Model"]["IRLS"]["max_iters"],
tol=eval(config["Model"]["IRLS"]["tol"]),
align_method=config["Model"]["align_method"],
)
else:
raise ValueError(f"Unknown align_lib: {config['Model']['align_lib']}")
if precompute_scale is not None: # meaning we are using align method 'scale+se3'
# we need this precompute_scale for loop align
s = precompute_scale
mean_error = compute_alignment_error(
point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t
)
print(f"Mean error: {mean_error}")
return s, R, t
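# Illustrative config (a minimal sketch; the real values come from the pipeline's
# YAML config, and these key names are only inferred from the lookups above):
#
#   config = {
#       "Model": {
#           "align_lib": "numpy",       # one of: "numba", "numpy", "torch", "triton"
#           "align_method": "sim3",     # "sim3", "se3", or "scale+se3"
#           "IRLS": {"delta": 0.1, "max_iters": 20, "tol": "1e-9"},  # tol is eval()'d
#       }
#   }
#   s, R, t = weighted_align_point_maps(pm1, conf1, pm2, conf2,
#                                       conf_threshold=1.5, config=config)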