# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long) import bisect import glob import os import numpy as np import trimesh from loop_utils.alignment_torch import robust_weighted_estimate_sim3_torch from loop_utils.alignment_triton import robust_weighted_estimate_sim3_triton from numba import njit from sklearn.linear_model import LinearRegression, RANSACRegressor def accumulate_sim3_transforms(transforms): """ Accumulate adjacent SIM(3) transforms into transforms from the initial frame to each subsequent frame. Args: transforms: list, each element is a tuple (R, s, t) R: 3x3 rotation matrix (np.array) s: scale factor (scalar) t: 3x1 translation vector (np.array) Returns: Cumulative transforms list, each element is (R_cum, s_cum, t_cum) representing the transform from frame 0 to frame k """ if not transforms: return [] cumulative_transforms = [transforms[0]] for i in range(1, len(transforms)): s_cum_prev, R_cum_prev, t_cum_prev = cumulative_transforms[i - 1] s_next, R_next, t_next = transforms[i] R_cum_new = R_cum_prev @ R_next s_cum_new = s_cum_prev * s_next t_cum_new = s_cum_prev * (R_cum_prev @ t_next) + t_cum_prev cumulative_transforms.append((s_cum_new, R_cum_new, t_cum_new)) return cumulative_transforms def estimate_sim3(source_points, target_points): mu_src = np.mean(source_points, axis=0) mu_tgt = np.mean(target_points, axis=0) src_centered = source_points - mu_src tgt_centered = target_points - mu_tgt scale_src = np.sqrt((src_centered**2).sum(axis=1).mean()) scale_tgt = np.sqrt((tgt_centered**2).sum(axis=1).mean()) s = scale_tgt / scale_src src_scaled = src_centered * s H = src_scaled.T @ tgt_centered U, _, Vt = np.linalg.svd(H) R = Vt.T @ U.T if np.linalg.det(R) < 0: Vt[2, :] *= -1 R = Vt.T @ U.T t = mu_tgt - s * R @ mu_src return s, R, t def align_point_maps(point_map1, conf1, point_map2, conf2, conf_threshold): """point_map2 -> point_map1""" b1, _, _, _ = point_map1.shape b2, _, _, _ = point_map2.shape b = min(b1, b2) aligned_points1 = [] aligned_points2 = [] for i in range(b): mask1 = conf1[i] > conf_threshold mask2 = conf2[i] > conf_threshold valid_mask = mask1 & mask2 idx = np.where(valid_mask) if len(idx[0]) == 0: continue pts1 = point_map1[i][idx] pts2 = point_map2[i][idx] aligned_points1.append(pts1) aligned_points2.append(pts2) if len(aligned_points1) == 0: raise ValueError("No matching point pairs were found!") all_pts1 = np.concatenate(aligned_points1, axis=0) all_pts2 = np.concatenate(aligned_points2, axis=0) print(f"The number of corresponding points matched: {all_pts1.shape[0]}") s, R, t = estimate_sim3(all_pts2, all_pts1) mean_error = compute_alignment_error( point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t ) print(f"Mean error: {mean_error}") return s, R, t def apply_sim3(points, s, R, t): return (s * (R @ points.T)).T + t def apply_sim3_direct(point_maps, s, R, t): # point_maps: (b, h, w, 3) -> (b, h, w, 3, 1) point_maps_expanded = point_maps[..., np.newaxis] # (b, h, w, 3, 1) # R: (3, 3) -> (b, h, w, 3, 1) = (3, 3) @ (3, 1) rotated = np.matmul(R, point_maps_expanded) # (b, h, w, 3, 1) rotated = rotated.squeeze(-1) # (b, h, w, 3) transformed = s * rotated + t # (b, h, w, 3) return transformed def compute_alignment_error(point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t): """ Compute the average point alignment error (using only original inputs) Args: point_map1: target point map (b, h, w, 3) conf1: target confidence map (b, h, w) point_map2: source point map (b, h, w, 3) conf2: source confidence map (b, h, w) conf_threshold: confidence threshold s, R, t: transformation parameters """ b1, h1, w1, _ = point_map1.shape b2, h2, w2, _ = point_map2.shape b = min(b1, b2) h = min(h1, h2) w = min(w1, w2) target_points = [] source_points = [] for i in range(b): mask1 = conf1[i, :h, :w] > conf_threshold mask2 = conf2[i, :h, :w] > conf_threshold valid_mask = mask1 & mask2 idx = np.where(valid_mask) if len(idx[0]) == 0: continue t_pts = point_map1[i, :h, :w][idx] s_pts = point_map2[i, :h, :w][idx] target_points.append(t_pts) source_points.append(s_pts) if len(target_points) == 0: print("Warning: No matching point pairs found for error calculation") return np.nan all_target = np.concatenate(target_points, axis=0) all_source = np.concatenate(source_points, axis=0) transformed = (s * (R @ all_source.T)).T + t errors = np.linalg.norm(transformed - all_target, axis=1) mean_error = np.mean(errors) std_error = np.std(errors) median_error = np.median(errors) max_error = np.max(errors) print( f"Alignment error statistics [using {len(errors)} points]: " f"mean={mean_error:.4f}, std={std_error:.4f}, " f"median={median_error:.4f}, max={max_error:.4f}" ) return mean_error def save_confident_pointcloud( points, colors, confs, output_path, conf_threshold, sample_ratio=1.0 ): """ Filter points based on confidence threshold and save as PLY file, with optional random sampling ratio. Args: - points: np.ndarray, shape (H, W, 3) or (N, 3) - colors: np.ndarray, shape (H, W, 3) or (N, 3) - confs: np.ndarray, shape (H, W) or (N,) - output_path: str, output PLY file path - conf_threshold: float, confidence threshold for point filtering - sample_ratio: float, sampling ratio (0 < sample_ratio <= 1.0) """ points = points.reshape(-1, 3).astype(np.float32, copy=False) colors = colors.reshape(-1, 3).astype(np.uint8, copy=False) confs = confs.reshape(-1).astype(np.float32, copy=False) conf_mask = (confs >= conf_threshold) & (confs > 1e-5) points = points[conf_mask] colors = colors[conf_mask] if 0 < sample_ratio < 1.0 and len(points) > 0: num_samples = int(len(points) * sample_ratio) indices = np.random.choice(len(points), num_samples, replace=False) points = points[indices] colors = colors[indices] os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) print(f"shape of sampled point: {points.shape}") trimesh.PointCloud(points, colors=colors).export(output_path) print(f"Saved point cloud with {len(points)} points to {output_path}") def save_confident_pointcloud_batch( points, colors, confs, output_path, conf_threshold, sample_ratio=1.0, batch_size=1000000 ): """ - points: np.ndarray, (b, H, W, 3) / (N, 3) - colors: np.ndarray, (b, H, W, 3) / (N, 3) - confs: np.ndarray, (b, H, W) / (N,) - output_path: str - conf_threshold: float, - sample_ratio: float (0 < sample_ratio <= 1.0) - batch_size: int """ if points.ndim == 2: b = 1 points = points[np.newaxis, ...] colors = colors[np.newaxis, ...] confs = confs[np.newaxis, ...] elif points.ndim == 4: b = points.shape[0] else: raise ValueError("Unsupported points dimension. Must be 2 (N,3) or 4 (b,H,W,3)") total_valid = 0 for i in range(b): cfs = confs[i].reshape(-1) total_valid += np.count_nonzero((cfs >= conf_threshold) & (cfs > 1e-5)) num_samples = int(total_valid * sample_ratio) if sample_ratio < 1.0 else total_valid if num_samples == 0: save_ply(np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8), output_path) return if sample_ratio == 1.0: with open(output_path, "wb") as f: write_ply_header(f, num_samples) for i in range(b): pts = points[i].reshape(-1, 3).astype(np.float32) cls = colors[i].reshape(-1, 3).astype(np.uint8) cfs = confs[i].reshape(-1).astype(np.float32) mask = (cfs >= conf_threshold) & (cfs > 1e-5) valid_pts = pts[mask] valid_cls = cls[mask] for j in range(0, len(valid_pts), batch_size): batch_pts = valid_pts[j : j + batch_size] batch_cls = valid_cls[j : j + batch_size] write_ply_batch(f, batch_pts, batch_cls) else: reservoir_pts = np.zeros((num_samples, 3), dtype=np.float32) reservoir_clr = np.zeros((num_samples, 3), dtype=np.uint8) count = 0 for i in range(b): pts = points[i].reshape(-1, 3).astype(np.float32) cls = colors[i].reshape(-1, 3).astype(np.uint8) cfs = confs[i].reshape(-1).astype(np.float32) mask = (cfs >= conf_threshold) & (cfs > 1e-5) valid_pts = pts[mask] valid_cls = cls[mask] n_valid = len(valid_pts) if count < num_samples: fill_count = min(num_samples - count, n_valid) reservoir_pts[count : count + fill_count] = valid_pts[:fill_count] reservoir_clr[count : count + fill_count] = valid_cls[:fill_count] count += fill_count if fill_count < n_valid: remaining_pts = valid_pts[fill_count:] remaining_cls = valid_cls[fill_count:] count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling( remaining_pts, remaining_cls, count, reservoir_pts, reservoir_clr ) else: count, reservoir_pts, reservoir_clr = optimized_vectorized_reservoir_sampling( valid_pts, valid_cls, count, reservoir_pts, reservoir_clr ) save_ply(reservoir_pts, reservoir_clr, output_path) """ The following function is deprecated""" # def vectorized_reservoir_sampling(new_pts, new_cls, current_count, reservoir_pts, reservoir_clr): # """ # - new_pts: (M, 3) # - new_cls: (M, 3) # - current_count # - reservoir_pts: (K, 3) # - reservoir_clr: (K, 3) # """ # k = len(reservoir_pts) # n_new = len(new_pts) # rand_indices = np.random.randint(0, current_count + n_new, size=n_new) # replace_mask = rand_indices < k # replace_indices = rand_indices[replace_mask] # replace_pts = new_pts[replace_mask] # replace_cls = new_cls[replace_mask] # reservoir_pts[replace_indices] = replace_pts # reservoir_clr[replace_indices] = replace_cls # return current_count + n_new, reservoir_pts, reservoir_clr """ Function `vectorized_reservoir_sampling` is not mathematically accurate in sampling. This leads to inconsistent density in the downsampled point clouds. The `optimized_vectorized_reservoir_sampling` function has fixed this bug. Special thanks to @Horace89 for the detailed analysis and code assistance. See https://github.com/DengKaiCQ/VGGT-Long/issues/28 for details """ def optimized_vectorized_reservoir_sampling( new_points: np.ndarray, new_colors: np.ndarray, current_count: int, reservoir_points: np.ndarray, reservoir_colors: np.ndarray, ) -> tuple[int, np.ndarray, np.ndarray]: """ Optimized vectorized reservoir sampling with batch probability calculations. This maintains mathematical correctness while improving performance through vectorized operations where possible. Args: new_points: New point coordinates to consider, shape (M, 3) new_colors: New point colors to consider, shape (M, 3) current_count: Number of elements seen so far reservoir_points: Current reservoir of sampled points, shape (K, 3) reservoir_colors: Current reservoir of sampled colors, shape (K, 3) Returns: Tuple of (updated_count, updated_reservoir_points, updated_reservoir_colors) """ random_gen = np.random reservoir_size = len(reservoir_points) num_new_points = len(new_points) if num_new_points == 0: return current_count, reservoir_points, reservoir_colors # Calculate sequential indices for each new point point_indices = np.arange(current_count + 1, current_count + num_new_points + 1) # Generate random numbers for each point random_values = random_gen.randint(0, point_indices, size=num_new_points) # Determine which points should replace reservoir elements replacement_mask = random_values < reservoir_size replacement_positions = random_values[replacement_mask] # Apply replacements if np.any(replacement_mask): points_to_replace = new_points[replacement_mask] colors_to_replace = new_colors[replacement_mask] reservoir_points[replacement_positions] = points_to_replace reservoir_colors[replacement_positions] = colors_to_replace return current_count + num_new_points, reservoir_points, reservoir_colors def write_ply_header(f, num_vertices): header = [ "ply", "format binary_little_endian 1.0", f"element vertex {num_vertices}", "property float x", "property float y", "property float z", "property uchar red", "property uchar green", "property uchar blue", "end_header", ] f.write("\n".join(header).encode() + b"\n") def write_ply_batch(f, points, colors): structured = np.zeros( len(points), dtype=[ ("x", np.float32), ("y", np.float32), ("z", np.float32), ("red", np.uint8), ("green", np.uint8), ("blue", np.uint8), ], ) structured["x"] = points[:, 0] structured["y"] = points[:, 1] structured["z"] = points[:, 2] structured["red"] = colors[:, 0] structured["green"] = colors[:, 1] structured["blue"] = colors[:, 2] f.write(structured.tobytes()) def save_ply(points, colors, filename): with open(filename, "wb") as f: write_ply_header(f, len(points)) write_ply_batch(f, points, colors) def find_chunk_index(chunks, idx): """ Find the 0-based chunk index that contains the given index idx. chunks: List of (begin_idx, end_idx). idx: The index to search for. Returns the 0-based chunk index. """ starts = [chunk[0] for chunk in chunks] pos = bisect.bisect_right(starts, idx) - 1 # Find position of idx in starts if pos < 0 or pos >= len(chunks): raise ValueError(f"Index {idx} not found in any chunk") chunk_begin, chunk_end = chunks[pos] if idx < chunk_begin or idx > chunk_end: raise ValueError(f"Index {idx} not found in any chunk") return pos def get_frame_range(chunk, idx, half_window=10): """ Calculate the frame range centered at idx with half_window frames on each side within chunk boundaries. If near boundaries, take 2 * half_window frames starting from the boundary. chunk: (begin_idx, end_idx). idx: Center index. half_window: Number of frames to take on each side of center index. Returns (start, end). """ begin, end = chunk window_size = 2 * half_window if idx - half_window < begin: start = begin end_candidate = begin + window_size end = min(end, end_candidate) elif idx + half_window > end: end_candidate = end start_candidate = end - window_size start = max(begin, start_candidate) else: start = idx - half_window end = idx + half_window return (start, end) def process_loop_list(chunk_index, loop_list, half_window=10): """ Process loop_list and return chunk indices and frame ranges for each (idx1, idx2) pair. chunk_index: List of (begin_idx, end_idx) tuples. loop_list: List of (idx1, idx2) tuples. half_window: Number of frames to take on each side of center index (default 10). Returns list of (chunk_idx1, range1, chunk_idx2, range2) tuples where: - chunk_idx1, chunk_idx2: Chunk indices (1-based). - range1, range2: Frame range tuples (start, end). """ results = [] for idx1, idx2 in loop_list: try: chunk_idx1_0based = find_chunk_index(chunk_index, idx1) chunk1 = chunk_index[chunk_idx1_0based] range1 = get_frame_range(chunk1, idx1, half_window) chunk_idx2_0based = find_chunk_index(chunk_index, idx2) chunk2 = chunk_index[chunk_idx2_0based] range2 = get_frame_range(chunk2, idx2, half_window) result = (chunk_idx1_0based, range1, chunk_idx2_0based, range2) results.append(result) except ValueError as e: print(f"Skipping pair ({idx1}, {idx2}): {e}") return results def compute_sim3_ab(S_a, S_b): s_a, R_a, T_a = S_a s_b, R_b, T_b = S_b s_ab = s_b / s_a R_ab = R_b @ R_a.T T_ab = T_b - s_ab * (R_ab @ T_a) return (s_ab, R_ab, T_ab) def merge_ply_files(input_dir, output_path): """ Merge all PLY files in a directory into one file (without loading into memory) Args: - input_dir: Input directory containing multiple '{idx}_pcd.ply' files - output_path: Output file path (e.g., 'combined.ply') """ print("Merging PLY files...") input_files = sorted(glob.glob(os.path.join(input_dir, "*_pcd.ply"))) if not input_files: print("No PLY files found") return idx_file = 0 len(input_files) total_vertices = 0 for file in input_files: # Count total vertices with open(file, "rb") as f: for line in f: if line.startswith(b"element vertex"): vertex_count = int(line.split()[-1]) total_vertices += vertex_count elif line.startswith(b"end_header"): break with open(output_path, "wb") as out_f: # Write new header out_f.write(b"ply\n") out_f.write(b"format binary_little_endian 1.0\n") out_f.write(f"element vertex {total_vertices}\n".encode()) out_f.write(b"property float x\n") out_f.write(b"property float y\n") out_f.write(b"property float z\n") out_f.write(b"property uchar red\n") out_f.write(b"property uchar green\n") out_f.write(b"property uchar blue\n") out_f.write(b"end_header\n") for file in input_files: print(f"Processing {idx_file}/{len(input_files)}: {file}") idx_file += 1 with open(file, "rb") as in_f: # Skip the head in_header = True while in_header: line = in_f.readline() if line.startswith(b"end_header"): in_header = False data = in_f.read() out_f.write(data) print(f"Merge completed! Total points: {total_vertices}") print(f"Output file: {output_path}") def weighted_estimate_se3(source_points, target_points, weights): """ source_points: (Nx3) target_points: (Nx3) :weights: (N,) [0,1] """ total_weight = np.sum(weights) if total_weight < 1e-6: raise ValueError("Total weight too small for meaningful estimation") normalized_weights = weights / total_weight mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0) mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0) src_centered = source_points - mu_src tgt_centered = target_points - mu_tgt weighted_src = src_centered * np.sqrt(normalized_weights)[:, None] weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None] H = weighted_src.T @ weighted_tgt U, _, Vt = np.linalg.svd(H) R = Vt.T @ U.T if np.linalg.det(R) < 0: Vt[2, :] *= -1 R = Vt.T @ U.T t = mu_tgt - R @ mu_src return 1.0, R, t def weighted_estimate_sim3(source_points, target_points, weights): """ source_points: (Nx3) target_points: (Nx3) :weights: (N,) [0,1] """ total_weight = np.sum(weights) if total_weight < 1e-6: raise ValueError("Total weight too small for meaningful estimation") normalized_weights = weights / total_weight mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0) mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0) src_centered = source_points - mu_src tgt_centered = target_points - mu_tgt scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1))) scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1))) s = scale_tgt / scale_src weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None] weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None] H = weighted_src.T @ weighted_tgt U, _, Vt = np.linalg.svd(H) R = Vt.T @ U.T if np.linalg.det(R) < 0: Vt[2, :] *= -1 R = Vt.T @ U.T t = mu_tgt - s * R @ mu_src return s, R, t def huber_loss(r, delta): abs_r = np.abs(r) return np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta)) def robust_weighted_estimate_sim3( src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3" ): """ src: (Nx3) tgt: (Nx3) init_weights: (N,) """ if align_method == "sim3": s, R, t = weighted_estimate_sim3(src, tgt, init_weights) elif align_method == "se3" or align_method == "scale+se3": s, R, t = weighted_estimate_se3(src, tgt, init_weights) prev_error = float("inf") for iter in range(max_iters): transformed = s * (src @ R.T) + t residuals = np.linalg.norm(tgt - transformed, axis=1) # (N,) print(f"Residuals: {np.mean(residuals)}") abs_res = np.abs(residuals) huber_weights = np.ones_like(residuals) large_res_mask = abs_res > delta huber_weights[large_res_mask] = delta / abs_res[large_res_mask] combined_weights = init_weights * huber_weights combined_weights /= np.sum(combined_weights) + 1e-12 if align_method == "se3": s_new, R_new, t_new = weighted_estimate_se3(src, tgt, combined_weights) elif align_method == "sim3" or align_method == "scale+se3": s_new, R_new, t_new = weighted_estimate_sim3(src, tgt, combined_weights) param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t) rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2))) current_error = np.sum(huber_loss(residuals, delta) * init_weights) if (param_change < tol and rot_angle < np.radians(0.1)) or ( abs(prev_error - current_error) < tol * prev_error ): break s, R, t = s_new, R_new, t_new prev_error = current_error return s, R, t # ===== Speed Up Begin ===== @njit(cache=True) def _weighted_estimate_se3_numba(source_points, target_points, weights): # Ensure float32 source_points = source_points.astype(np.float32) target_points = target_points.astype(np.float32) weights = weights.astype(np.float32) total_weight = np.sum(weights) if total_weight < 1e-6: return ( 1.0, np.zeros(3, dtype=np.float32), np.zeros(3, dtype=np.float32), np.zeros((3, 3), dtype=np.float32), ) normalized_weights = weights / total_weight mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0) mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0) src_centered = source_points - mu_src tgt_centered = target_points - mu_tgt weighted_src = src_centered * np.sqrt(normalized_weights)[:, None] weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None] H = weighted_src.T @ weighted_tgt return 1.0, mu_src, mu_tgt, H @njit(cache=True) def _weighted_estimate_sim3_numba(source_points, target_points, weights): # Ensure float32 source_points = source_points.astype(np.float32) target_points = target_points.astype(np.float32) weights = weights.astype(np.float32) total_weight = np.sum(weights) if total_weight < 1e-6: return ( -1.0, np.zeros(3, dtype=np.float32), np.zeros(3, dtype=np.float32), np.zeros((3, 3), dtype=np.float32), ) normalized_weights = weights / total_weight mu_src = np.sum(normalized_weights[:, None] * source_points, axis=0) mu_tgt = np.sum(normalized_weights[:, None] * target_points, axis=0) src_centered = source_points - mu_src tgt_centered = target_points - mu_tgt scale_src = np.sqrt(np.sum(normalized_weights * np.sum(src_centered**2, axis=1))) scale_tgt = np.sqrt(np.sum(normalized_weights * np.sum(tgt_centered**2, axis=1))) s = scale_tgt / scale_src weighted_src = (s * src_centered) * np.sqrt(normalized_weights)[:, None] weighted_tgt = tgt_centered * np.sqrt(normalized_weights)[:, None] H = weighted_src.T @ weighted_tgt return s, mu_src, mu_tgt, H def weighted_estimate_sim3_numba(source_points, target_points, weights, align_method="sim3"): if align_method == "sim3": s, mu_src, mu_tgt, H = _weighted_estimate_sim3_numba(source_points, target_points, weights) elif align_method == "se3" or align_method == "scale+se3": s, mu_src, mu_tgt, H = _weighted_estimate_se3_numba(source_points, target_points, weights) if s < 0: raise ValueError("Total weight too small for meaningful estimation") # Ensure float32 H = H.astype(np.float32) U, _, Vt = np.linalg.svd(H.astype(np.float32)) # float32 SVD R = Vt.T @ U.T if np.linalg.det(R) < 0: Vt[2, :] *= -1 R = Vt.T @ U.T if align_method == "se3" or align_method == "scale+se3": t = mu_tgt - R @ mu_src else: t = mu_tgt - s * R @ mu_src return s, R, t @njit(cache=True) def huber_loss_numba(r, delta): r = r.astype(np.float32) delta = np.float32(delta) abs_r = np.abs(r) result = np.where(abs_r <= delta, 0.5 * r**2, delta * (abs_r - 0.5 * delta)) return result.astype(np.float32) @njit(cache=True) def compute_residuals_numba(tgt, transformed): residuals = np.empty(tgt.shape[0], dtype=np.float32) for i in range(tgt.shape[0]): diff = tgt[i] - transformed[i] residuals[i] = np.sqrt(np.sum(diff**2)) return residuals @njit(cache=True) def compute_huber_weights_numba(residuals, delta): weights = np.ones(residuals.shape, dtype=np.float32) for i in range(residuals.shape[0]): r = residuals[i] if r > delta: weights[i] = delta / r return weights @njit(cache=True) def apply_transformation_numba(src, s, R, t): transformed = np.empty_like(src) for i in range(src.shape[0]): p = src[i] transformed[i] = s * (R @ p) + t return transformed def robust_weighted_estimate_sim3_numba( src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3" ): src = src.astype(np.float32) tgt = tgt.astype(np.float32) init_weights = init_weights.astype(np.float32) s, R, t = weighted_estimate_sim3_numba(src, tgt, init_weights, align_method=align_method) prev_error = float("inf") for iter in range(max_iters): transformed = apply_transformation_numba(src, s, R, t) residuals = compute_residuals_numba(tgt, transformed) print(f"Residuals: {np.mean(residuals)}") huber_weights = compute_huber_weights_numba(residuals, delta) combined_weights = init_weights * huber_weights combined_weights /= np.sum(combined_weights) + 1e-12 s_new, R_new, t_new = weighted_estimate_sim3_numba( src, tgt, combined_weights, align_method=align_method ) param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t) rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2))) current_error = np.sum(huber_loss_numba(residuals, delta) * init_weights) if (param_change < tol and rot_angle < np.radians(0.1)) or ( abs(prev_error - current_error) < tol * prev_error ): break s, R, t = s_new, R_new, t_new prev_error = current_error return s, R, t def warmup_numba(): print("\nWarming up Numba JIT-compiled functions...") src = np.random.randn(50000, 3).astype(np.float32) tgt = np.random.randn(50000, 3).astype(np.float32) weights = np.ones(50000, dtype=np.float32) residuals = np.abs(np.random.randn(50000).astype(np.float32)) R = np.eye(3, dtype=np.float32) t = np.zeros(3, dtype=np.float32) s = np.float32(1.0) delta = np.float32(1.0) try: _ = _weighted_estimate_sim3_numba(src, tgt, weights) print(" - _weighted_estimate_sim3_numba warmed up.") except Exception as e: print(" ! Failed to warm up _weighted_estimate_sim3_numba:", e) try: _ = _weighted_estimate_se3_numba(src, tgt, weights) print(" - _weighted_estimate_se3_numba warmed up.") except Exception as e: print(" ! Failed to warm up _weighted_estimate_se3_numba:", e) try: _ = huber_loss_numba(residuals, delta) print(" - huber_loss_numba warmed up.") except Exception as e: print(" ! Failed to warm up huber_loss_numba:", e) try: _ = compute_huber_weights_numba(residuals, delta) print(" - compute_huber_weights_numba warmed up.") except Exception as e: print(" ! Failed to warm up compute_huber_weights_numba:", e) try: _ = compute_residuals_numba(tgt, src) print(" - compute_residuals_numba warmed up.") except Exception as e: print(" ! Failed to warm up compute_residuals_numba:", e) try: _ = apply_transformation_numba(src, s, R, t) print(" - apply_transformation_numba warmed up.") except Exception as e: print(" ! Failed to warm up apply_transformation_numba:", e) print("Numba warm-up complete.\n") # ===== Speed Up End ===== # ===== Scale precompute begin ===== def compute_scale_ransac( depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, max_samples=10000 ): """ Args: depth1: (n1, h, w) depth2: (n2, h, w) conf1: (n1, h, w) conf2: (n2, h, w) """ depth1_flat = depth1.reshape(-1) depth2_flat = depth2.reshape(-1) conf1_flat = conf1.reshape(-1) conf2_flat = conf2.reshape(-1) conf_threshold = max( np.median(conf1_flat) * conf_threshold_ratio, np.median(conf2_flat) * conf_threshold_ratio, 1e-6, ) valid_mask = ( (conf1_flat > conf_threshold) & (conf2_flat > conf_threshold) & (depth1_flat > 1e-3) & (depth2_flat > 1e-3) & (depth1_flat < 100) & (depth2_flat < 100) ) if np.sum(valid_mask) < 100: print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0") return 1.0, 0.0 valid_depth1 = depth1_flat[valid_mask] valid_depth2 = depth2_flat[valid_mask] if len(valid_depth1) > max_samples: indices = np.random.choice(len(valid_depth1), max_samples, replace=False) valid_depth1 = valid_depth1[indices] valid_depth2 = valid_depth2[indices] X = valid_depth2.reshape(-1, 1) y = valid_depth1 base_estimator = LinearRegression(fit_intercept=False) ransac = RANSACRegressor( estimator=base_estimator, max_trials=1000, min_samples=max(10, len(X) // 100), residual_threshold=0.1, random_state=42, ) ransac.fit(X, y) scale_factor = ransac.estimator_.coef_[0] inlier_mask = ransac.inlier_mask_ inlier_ratio = np.sum(inlier_mask) / len(inlier_mask) print(f"RANSAC scale: {scale_factor:.6f}, inlier ratio: {inlier_ratio:.4f}") if 0.1 < scale_factor < 10.0: return scale_factor, inlier_ratio else: print(f"Warning: Unreasonable scale {scale_factor}, using 1.0") return 1.0, inlier_ratio def compute_scale_weighted( depth1, depth2, conf1, conf2, conf_threshold_ratio=0.1, weight_power=2.0, robust_quantile=0.9 ): """ Args: depth1: (n1, h, w) depth2: (n2, h, w) conf1: (n1, h, w) conf2: (n2, h, w) """ depth1_flat = depth1.reshape(-1) depth2_flat = depth2.reshape(-1) conf1_flat = conf1.reshape(-1) conf2_flat = conf2.reshape(-1) conf_threshold = max( np.median(conf1_flat) * conf_threshold_ratio, np.median(conf2_flat) * conf_threshold_ratio, 1e-6, ) valid_mask = ( (conf1_flat > conf_threshold) & (conf2_flat > conf_threshold) & (depth1_flat > 1e-3) & (depth2_flat > 1e-3) & (depth1_flat < 100) & (depth2_flat < 100) ) if np.sum(valid_mask) < 100: print(f"Warning: Only {np.sum(valid_mask)} valid points, using default scale 1.0") return 1.0, 0.0 valid_depth1 = depth1_flat[valid_mask] valid_depth2 = depth2_flat[valid_mask] valid_conf1 = conf1_flat[valid_mask] valid_conf2 = conf2_flat[valid_mask] combined_weights = (valid_conf1 * valid_conf2) ** weight_power combined_weights = combined_weights / (np.sum(combined_weights) + 1e-8) ratios = valid_depth1 / (valid_depth2 + 1e-8) sorted_indices = np.argsort(ratios) sorted_ratios = ratios[sorted_indices] sorted_weights = combined_weights[sorted_indices] cumulative_weights = np.cumsum(sorted_weights) median_idx = np.searchsorted(cumulative_weights, 0.5) scale_median = sorted_ratios[median_idx] if median_idx < len(sorted_ratios) else 1.0 quantile_idx = np.searchsorted(cumulative_weights, robust_quantile) scale_quantile = ( sorted_ratios[quantile_idx] if quantile_idx < len(sorted_ratios) else scale_median ) weight_entropy = -np.sum(combined_weights * np.log(combined_weights + 1e-8)) max_entropy = np.log(len(combined_weights)) confidence_score = 1.0 - (weight_entropy / max_entropy) if max_entropy > 0 else 0.0 print(f"Weighted scale: {scale_quantile:.6f}, confidence: {confidence_score:.4f}") if 0.1 < scale_quantile < 10.0: return scale_quantile, confidence_score else: print(f"Warning: Unreasonable scale {scale_quantile}, using 1.0") return 1.0, confidence_score def compute_chunk_scale_advanced(depth1, depth2, conf1, conf2, method="auto"): """ method: 'auto', 'ransac', 'weighted' """ if method == "ransac": scale, score = compute_scale_ransac(depth1, depth2, conf1, conf2) return scale, score, "ransac" elif method == "weighted": scale, score = compute_scale_weighted(depth1, depth2, conf1, conf2) return scale, score, "weighted" elif method == "auto": scale_ransac, inlier_ratio = compute_scale_ransac(depth1, depth2, conf1, conf2) scale_weighted, conf_score = compute_scale_weighted(depth1, depth2, conf1, conf2) ransac_quality = inlier_ratio weighted_quality = conf_score print(f"RANSAC quality: {ransac_quality:.4f}, Weighted quality: {weighted_quality:.4f}") if ransac_quality > 0.7 and weighted_quality > 0.7: # both method are good, we take both of them by average final_scale = (scale_ransac + scale_weighted) / 2 final_method = "average" elif ransac_quality > weighted_quality: final_scale = scale_ransac final_method = "ransac" else: final_scale = scale_weighted final_method = "weighted" final_quality = max(ransac_quality, weighted_quality) return final_scale, final_quality, final_method def precompute_scale_chunks_with_depth( chunk1_depth, chunk1_conf, chunk2_depth, chunk2_conf, method="auto" ): """ Args: chunk1_depth: (n1, h, w) chunk1_conf: (n1, h, w) chunk2_depth: (n2, h, w) chunk2_conf: (n2, h, w) method: 'auto', 'ransac', 'weighted' """ scale_factor, quality_score, method_used = compute_chunk_scale_advanced( chunk1_depth, chunk2_depth, chunk1_conf, chunk2_conf, method ) print(f"Final scale: {scale_factor:.6f}, quality: {quality_score:.4f}, method: {method_used}") return scale_factor, quality_score, method_used # ===== Scale precompute end ===== def weighted_align_point_maps( point_map1, conf1, point_map2, conf2, conf_threshold, config, precompute_scale=None ): """point_map2 -> point_map1""" b1, _, _, _ = point_map1.shape b2, _, _, _ = point_map2.shape b = min(b1, b2) if precompute_scale is not None: # meaning we are using align method 'scale+se3' point_map2 *= precompute_scale aligned_points1 = [] aligned_points2 = [] confidence_weights = [] for i in range(b): mask1 = conf1[i] > conf_threshold mask2 = conf2[i] > conf_threshold valid_mask = mask1 & mask2 idx = np.where(valid_mask) if len(idx[0]) == 0: continue pts1 = point_map1[i][idx] pts2 = point_map2[i][idx] combined_conf = np.sqrt(conf1[i][idx] * conf2[i][idx]) aligned_points1.append(pts1) aligned_points2.append(pts2) confidence_weights.append(combined_conf) if len(aligned_points1) == 0: raise ValueError("No matching point pairs were found!") all_pts1 = np.concatenate(aligned_points1, axis=0) all_pts2 = np.concatenate(aligned_points2, axis=0) all_weights = np.concatenate(confidence_weights, axis=0) print(f"The number of corresponding points matched: {all_pts1.shape[0]}") if config["Model"]["align_lib"] == "numba": s, R, t = robust_weighted_estimate_sim3_numba( all_pts2, all_pts1, all_weights, delta=config["Model"]["IRLS"]["delta"], max_iters=config["Model"]["IRLS"]["max_iters"], tol=eval(config["Model"]["IRLS"]["tol"]), align_method=config["Model"]["align_method"], ) elif config["Model"]["align_lib"] == "numpy": # numpy s, R, t = robust_weighted_estimate_sim3( all_pts2, all_pts1, all_weights, delta=config["Model"]["IRLS"]["delta"], max_iters=config["Model"]["IRLS"]["max_iters"], tol=eval(config["Model"]["IRLS"]["tol"]), align_method=config["Model"]["align_method"], ) elif config["Model"]["align_lib"] == "torch": # torch s, R, t = robust_weighted_estimate_sim3_torch( all_pts2, all_pts1, all_weights, delta=config["Model"]["IRLS"]["delta"], max_iters=config["Model"]["IRLS"]["max_iters"], tol=eval(config["Model"]["IRLS"]["tol"]), align_method=config["Model"]["align_method"], ) elif config["Model"]["align_lib"] == "triton": # triton s, R, t = robust_weighted_estimate_sim3_triton( all_pts2, all_pts1, all_weights, delta=config["Model"]["IRLS"]["delta"], max_iters=config["Model"]["IRLS"]["max_iters"], tol=eval(config["Model"]["IRLS"]["tol"]), align_method=config["Model"]["align_method"], ) else: raise ValueError(f"Unknown align_lib: {config['Model']['align_lib']}") if precompute_scale is not None: # meaning we are using align method 'scale+se3' # we need this precompute_scale for loop align s = precompute_scale mean_error = compute_alignment_error( point_map1, conf1, point_map2, conf2, conf_threshold, s, R, t ) print(f"Mean error: {mean_error}") return s, R, t