Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Quick visualization wrapper for GCT predictions using Viser. | |
| """ | |
| import time | |
| import threading | |
| from typing import List, Optional | |
| import numpy as np | |
| import viser | |
| import viser.transforms as tf | |
| from tqdm.auto import tqdm | |
| from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map | |
| from lingbot_map.vis.sky_segmentation import apply_sky_segmentation | |
| def viser_wrapper( | |
| pred_dict: dict, | |
| port: int = 8080, | |
| init_conf_threshold: float = 50.0, | |
| use_point_map: bool = False, | |
| background_mode: bool = False, | |
| mask_sky: bool = False, | |
| image_folder: Optional[str] = None, | |
| ): | |
| """ | |
| Visualize predicted 3D points and camera poses with viser. | |
| This is a simplified wrapper for quick visualization without the full | |
| PointCloudViewer controls. | |
| Args: | |
| pred_dict: Dictionary containing predictions with keys: | |
| - images: (S, 3, H, W) - Input images | |
| - world_points: (S, H, W, 3) | |
| - world_points_conf: (S, H, W) | |
| - depth: (S, H, W, 1) | |
| - depth_conf: (S, H, W) | |
| - extrinsic: (S, 3, 4) | |
| - intrinsic: (S, 3, 3) | |
| port: Port number for the viser server | |
| init_conf_threshold: Initial percentage of low-confidence points to filter out | |
| use_point_map: Whether to visualize world_points or use depth-based points | |
| background_mode: Whether to run the server in background thread | |
| mask_sky: Whether to apply sky segmentation to filter out sky points | |
| image_folder: Path to the folder containing input images (for sky segmentation) | |
| Returns: | |
| viser.ViserServer: The viser server instance | |
| """ | |
| print(f"Starting viser server on port {port}") | |
| server = viser.ViserServer(host="0.0.0.0", port=port) | |
| server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") | |
| # Unpack prediction dict | |
| images = pred_dict["images"] # (S, 3, H, W) | |
| world_points_map = pred_dict["world_points"] # (S, H, W, 3) | |
| conf_map = pred_dict["world_points_conf"] # (S, H, W) | |
| depth_map = pred_dict["depth"] # (S, H, W, 1) | |
| depth_conf = pred_dict["depth_conf"] # (S, H, W) | |
| extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4) | |
| intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3) | |
| # Compute world points from depth if not using the precomputed point map | |
| if not use_point_map: | |
| world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam) | |
| conf = depth_conf | |
| else: | |
| world_points = world_points_map | |
| conf = conf_map | |
| # Apply sky segmentation if enabled | |
| if mask_sky and image_folder is not None: | |
| conf = apply_sky_segmentation(conf, image_folder) | |
| # Convert images from (S, 3, H, W) to (S, H, W, 3) | |
| colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3) | |
| shape = world_points.shape | |
| S: int = shape[0] | |
| H: int = shape[1] | |
| W: int = shape[2] | |
| # Flatten | |
| points = world_points.reshape(-1, 3) | |
| colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8) | |
| conf_flat = conf.reshape(-1) | |
| # Random sample points if too many | |
| indices = None | |
| if points.shape[0] > 6000000: | |
| print(f"Too many points ({points.shape[0]}), randomly sampling 6M points") | |
| indices = np.random.choice(points.shape[0], size=6000000, replace=False) | |
| points = points[indices] | |
| colors_flat = colors_flat[indices] | |
| conf_flat = conf_flat[indices] | |
| cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam) | |
| cam_to_world = cam_to_world_mat[:, :3, :] | |
| # Compute scene center and recenter | |
| scene_center = np.mean(points, axis=0) | |
| points_centered = points - scene_center | |
| cam_to_world[..., -1] -= scene_center | |
| # Store frame indices for filtering | |
| frame_indices = ( | |
| np.repeat(np.arange(S), H * W)[indices] | |
| if indices is not None | |
| else np.repeat(np.arange(S), H * W) | |
| ) | |
| # Build the viser GUI | |
| gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True) | |
| gui_points_conf = server.gui.add_slider( | |
| "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold | |
| ) | |
| gui_frame_selector = server.gui.add_dropdown( | |
| "Show Points from Frames", | |
| options=["All"] + [str(i) for i in range(S)], | |
| initial_value="All" | |
| ) | |
| # Create the main point cloud | |
| init_threshold_val = np.percentile(conf_flat, init_conf_threshold) | |
| init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1) | |
| point_cloud = server.scene.add_point_cloud( | |
| name="viser_pcd", | |
| points=points_centered[init_conf_mask], | |
| colors=colors_flat[init_conf_mask], | |
| point_size=0.0005, | |
| point_shape="circle", | |
| ) | |
| frames: List[viser.FrameHandle] = [] | |
| frustums: List[viser.CameraFrustumHandle] = [] | |
| def visualize_frames(extrinsics, images_: np.ndarray) -> None: | |
| """Add camera frames and frustums to the scene.""" | |
| for f in frames: | |
| f.remove() | |
| frames.clear() | |
| for fr in frustums: | |
| fr.remove() | |
| frustums.clear() | |
| def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None: | |
| def _(_) -> None: | |
| for client in server.get_clients().values(): | |
| client.camera.wxyz = frame.wxyz | |
| client.camera.position = frame.position | |
| for img_id in tqdm(range(S)): | |
| cam2world_3x4 = extrinsics[img_id] | |
| T_world_camera = tf.SE3.from_matrix(cam2world_3x4) | |
| frame_axis = server.scene.add_frame( | |
| f"frame_{img_id}", | |
| wxyz=T_world_camera.rotation().wxyz, | |
| position=T_world_camera.translation(), | |
| axes_length=0.05, | |
| axes_radius=0.002, | |
| origin_radius=0.002, | |
| ) | |
| frames.append(frame_axis) | |
| img = images_[img_id] | |
| img = (img.transpose(1, 2, 0) * 255).astype(np.uint8) | |
| h, w = img.shape[:2] | |
| fy = 1.1 * h | |
| fov = 2 * np.arctan2(h / 2, fy) | |
| frustum_cam = server.scene.add_camera_frustum( | |
| f"frame_{img_id}/frustum", | |
| fov=fov, | |
| aspect=w / h, | |
| scale=0.05, | |
| image=img, | |
| line_width=1.0 | |
| ) | |
| frustums.append(frustum_cam) | |
| attach_callback(frustum_cam, frame_axis) | |
| def update_point_cloud() -> None: | |
| """Update point cloud based on current GUI selections.""" | |
| current_percentage = gui_points_conf.value | |
| threshold_val = np.percentile(conf_flat, current_percentage) | |
| print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%") | |
| conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5) | |
| if gui_frame_selector.value == "All": | |
| frame_mask = np.ones_like(conf_mask, dtype=bool) | |
| else: | |
| selected_idx = int(gui_frame_selector.value) | |
| frame_mask = frame_indices == selected_idx | |
| combined_mask = conf_mask & frame_mask | |
| point_cloud.points = points_centered[combined_mask] | |
| point_cloud.colors = colors_flat[combined_mask] | |
| def _(_) -> None: | |
| update_point_cloud() | |
| def _(_) -> None: | |
| update_point_cloud() | |
| def _(_) -> None: | |
| for f in frames: | |
| f.visible = gui_show_frames.value | |
| for fr in frustums: | |
| fr.visible = gui_show_frames.value | |
| # Add camera frames | |
| import torch | |
| if torch.is_tensor(cam_to_world): | |
| cam_to_world_np = cam_to_world.cpu().numpy() | |
| else: | |
| cam_to_world_np = cam_to_world | |
| visualize_frames(cam_to_world_np, images) | |
| print("Starting viser server...") | |
| if background_mode: | |
| def server_loop(): | |
| while True: | |
| time.sleep(0.001) | |
| thread = threading.Thread(target=server_loop, daemon=True) | |
| thread.start() | |
| else: | |
| while True: | |
| time.sleep(0.01) | |
| return server | |