Spaces:
Running on Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Interactive 3D Point Cloud Viewer using Viser. | |
| This module provides the PointCloudViewer class for visualizing 3D reconstruction results, | |
| including point clouds, camera poses, and animated playback. | |
| """ | |
| import os | |
| import time | |
| import threading | |
| import subprocess | |
| import tempfile | |
| import shutil | |
| from typing import List, Optional, Dict, Any, Tuple | |
| import numpy as np | |
| import torch | |
| import cv2 | |
| import matplotlib.cm as cm | |
| from tqdm.auto import tqdm | |
| import viser | |
| import viser.transforms as tf | |
| from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map | |
| from lingbot_map.vis.utils import CameraState | |
| from lingbot_map.vis.sky_segmentation import apply_sky_segmentation | |
| class PointCloudViewer: | |
| """ | |
| Interactive 3D point cloud viewer with camera visualization. | |
| Features: | |
| - Point cloud visualization with confidence-based filtering | |
| - Camera frustum visualization with gradient colors | |
| - Frame-by-frame playback animation (3D/4D modes) | |
| - Range-based and recent-N-frames visualization modes | |
| - Video export with FFmpeg | |
| Args: | |
| model: Optional model for interactive inference | |
| state_args: Optional state arguments | |
| pc_list: List of point clouds per frame | |
| color_list: List of colors per frame | |
| conf_list: List of confidence scores per frame | |
| cam_dict: Camera dictionary with focal, pp, R, t | |
| image_mask: Optional image mask | |
| edge_color_list: Optional edge colors | |
| device: Device for computation | |
| port: Viser server port | |
| show_camera: Whether to show camera frustums | |
| vis_threshold: Visibility threshold for filtering | |
| size: Image size | |
| downsample_factor: Point cloud downsample factor | |
| point_size: Initial point size | |
| pred_dict: Prediction dictionary (alternative to pc_list/color_list/conf_list) | |
| init_conf_threshold: Initial confidence threshold percentage | |
| use_point_map: Use point map instead of depth-based points | |
| mask_sky: Apply sky segmentation | |
| image_folder: Path to image folder (for sky segmentation) | |
| """ | |
    def __init__(
        self,
        model=None,
        state_args=None,
        pc_list=None,
        color_list=None,
        conf_list=None,
        cam_dict=None,
        image_mask=None,
        edge_color_list=None,
        device: str = "cpu",
        port: int = 8080,
        show_camera: bool = True,
        vis_threshold: float = 1.0,
        size: int = 512,
        downsample_factor: int = 10,
        point_size: float = 0.00001,
        pred_dict: Optional[Dict] = None,
        init_conf_threshold: float = 50.0,
        use_point_map: bool = False,
        mask_sky: bool = False,
        image_folder: Optional[str] = None,
        sky_mask_dir: Optional[str] = None,
        sky_mask_visualization_dir: Optional[str] = None,
        depth_stride: int = 1,
    ):
        """Initialize viewer state, start the viser server, and build the GUI.

        See the class docstring for parameter descriptions.

        NOTE(review): ``downsample_factor`` and ``init_conf_threshold`` are
        accepted but not referenced anywhere in this constructor — confirm
        whether they should feed the corresponding GUI sliders.
        """
        self.model = model
        self.size = size
        self.state_args = state_args
        # Bind on all interfaces so the viewer is reachable from other hosts.
        self.server = viser.ViserServer(host="0.0.0.0", port=port)
        self.server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
        self.device = device
        self.conf_list = conf_list
        self.vis_threshold = vis_threshold
        self.point_size = point_size
        # Convenience: numpy array -> float torch tensor on the target device.
        self.tt = lambda x: torch.from_numpy(x).float().to(device)
        # Process the prediction dictionary to create pc_list, color_list, conf_list
        if pred_dict is not None:
            # _process_pred_dict also populates self.original_images.
            pc_list, color_list, conf_list, cam_dict = self._process_pred_dict(
                pred_dict, use_point_map, mask_sky, image_folder,
                sky_mask_dir=sky_mask_dir,
                sky_mask_visualization_dir=sky_mask_visualization_dir,
                depth_stride=depth_stride,
            )
        else:
            self.original_images = []
        self.pcs, self.all_steps = self.read_data(
            pc_list, color_list, conf_list, edge_color_list
        )
        self.cam_dict = cam_dict
        self.num_frames = len(self.all_steps)
        self.image_mask = image_mask
        self.show_camera = show_camera
        # Playback / replay state flags.
        self.on_replay = False
        self.vis_pts_list = []
        self.traj_list = []
        # First color entry of each non-empty frame, kept for video export.
        self.orig_img_list = [x[0] for x in color_list if len(x) > 0] if color_list else []
        self.via_points = []
        self._setup_gui()
        self.server.on_client_connect(self._connect_client)
    def _process_pred_dict(
        self,
        pred_dict: Dict,
        use_point_map: bool,
        mask_sky: bool,
        image_folder: Optional[str],
        sky_mask_dir: Optional[str] = None,
        sky_mask_visualization_dir: Optional[str] = None,
        depth_stride: int = 1,
    ) -> Tuple[List, List, List, Dict]:
        """Process prediction dictionary to extract visualization data.

        Also populates ``self.original_images`` with uint8 HWC frames used by
        the camera-frustum / GUI image display.

        Args:
            pred_dict: Model prediction dictionary.
            use_point_map: Use point map instead of depth-based projection.
            mask_sky: Apply sky segmentation to filter sky points.
            image_folder: Path to images for sky segmentation.
            sky_mask_dir: Directory for cached sky masks.
            sky_mask_visualization_dir: Directory for sky mask visualization images.
            depth_stride: Only project depth to point cloud every N frames.
                Frames not projected will have empty point clouds but still
                show camera frustums and images. 1 = every frame (default).

        Returns:
            Tuple ``(pc_list, color_list, conf_list, cam_dict)`` with one
            entry per frame.
        """
        images = pred_dict["images"]  # (S, 3, H, W)
        depth_map = pred_dict.get("depth")  # (S, H, W, 1)
        depth_conf = pred_dict.get("depth_conf")  # (S, H, W)
        extrinsics_cam = pred_dict["extrinsic"]  # (S, 3, 4)
        intrinsics_cam = pred_dict["intrinsic"]  # (S, 3, 3)
        # Compute world points from depth if not using the precomputed point map
        if not use_point_map:
            world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
            conf = depth_conf
        else:
            world_points = pred_dict["world_points"]  # (S, H, W, 3)
            conf = pred_dict.get("world_points_conf", depth_conf)  # (S, H, W)
        # Apply sky segmentation if enabled
        if mask_sky:
            # Sky pixels get their confidence adjusted so later filtering can drop them.
            conf = apply_sky_segmentation(
                conf, image_folder=image_folder, images=images,
                sky_mask_dir=sky_mask_dir,
                sky_mask_visualization_dir=sky_mask_visualization_dir,
            )
        # Convert images from (S, 3, H, W) to (S, H, W, 3)
        colors = images.transpose(0, 2, 3, 1)  # now (S, H, W, 3)
        S = world_points.shape[0]
        # Store original images for camera frustum display
        self.original_images = []
        for i in range(S):
            img = images[i]  # shape (3, H, W)
            # CHW -> HWC uint8; assumes pixel values are in [0, 1] — TODO confirm.
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            self.original_images.append(img)
        # Create lists - apply depth_stride to skip frames for point projection
        # NOTE(review): H, W are computed but unused below.
        H, W = world_points.shape[1], world_points.shape[2]
        pc_list = []
        color_list = []
        conf_list = []
        skipped = 0
        for i in range(S):
            if depth_stride > 1 and i % depth_stride != 0:
                # Empty point cloud for skipped frames
                pc_list.append(np.zeros((0, 0, 3), dtype=np.float32))
                color_list.append(np.zeros((0, 0, 3), dtype=np.float32))
                conf_list.append(np.zeros((0, 0), dtype=np.float32))
                skipped += 1
            else:
                pc_list.append(world_points[i])
                color_list.append(colors[i])
                if conf is not None:
                    conf_list.append(conf[i])
                else:
                    # No confidence available anywhere: treat all points as fully confident.
                    conf_list.append(np.ones(world_points[i].shape[:2], dtype=np.float32))
        if depth_stride > 1:
            print(f' depth_stride={depth_stride}: projecting {S - skipped}/{S} frames, skipping {skipped}')
        # Create camera dictionary (all frames keep cameras)
        # Invert the world->camera extrinsics into camera->world poses.
        cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
        cam_dict = {
            "focal": [intrinsics_cam[i, 0, 0] for i in range(S)],
            "pp": [(intrinsics_cam[i, 0, 2], intrinsics_cam[i, 1, 2]) for i in range(S)],
            "R": [cam_to_world_mat[i, :3, :3] for i in range(S)],
            "t": [cam_to_world_mat[i, :3, 3] for i in range(S)],
        }
        return pc_list, color_list, conf_list, cam_dict
| def _compute_scene_center_and_scale(self) -> Tuple[np.ndarray, float]: | |
| """Compute scene center and scale from camera positions and point clouds. | |
| Returns: | |
| Tuple of (center as 3D array, scale as float distance). | |
| """ | |
| # Use camera positions as primary reference (more reliable than noisy points) | |
| if self.cam_dict is not None and "t" in self.cam_dict: | |
| cam_positions = np.array([self.cam_dict["t"][s] for s in self.all_steps]) | |
| center = np.mean(cam_positions, axis=0) | |
| if len(cam_positions) > 1: | |
| extent = np.ptp(cam_positions, axis=0) # range per axis | |
| scale = np.linalg.norm(extent) | |
| else: | |
| scale = 1.0 | |
| else: | |
| # Fallback: use point cloud data | |
| all_pts = [] | |
| for step in self.all_steps: | |
| pc = self.pcs[step]["pc"].reshape(-1, 3) | |
| # subsample for speed | |
| if len(pc) > 1000: | |
| pc = pc[::len(pc) // 1000] | |
| all_pts.append(pc) | |
| all_pts = np.concatenate(all_pts, axis=0) | |
| center = np.median(all_pts, axis=0) | |
| extent = np.percentile(all_pts, 95, axis=0) - np.percentile(all_pts, 5, axis=0) | |
| scale = np.linalg.norm(extent) | |
| return center, max(scale, 0.1) | |
| def _reset_view_to_direction( | |
| self, | |
| direction: np.ndarray, | |
| up: np.ndarray = np.array([0.0, -1.0, 0.0]), | |
| distance_scale: float = 1.5, | |
| smooth: bool = True, | |
| ): | |
| """Reset the viewer camera to look at scene center from a given direction. | |
| Args: | |
| direction: Unit vector pointing FROM the scene center TO the camera. | |
| up: Up vector for the camera. | |
| distance_scale: Multiplier on scene scale for camera distance. | |
| smooth: Whether to smoothly transition. | |
| """ | |
| center, scale = self._compute_scene_center_and_scale() | |
| distance = scale * distance_scale | |
| position = center + direction * distance | |
| for client in self.server.get_clients().values(): | |
| if smooth: | |
| self._smooth_camera_transition( | |
| client, | |
| target_position=position, | |
| target_look_at=center, | |
| target_up=up, | |
| duration=0.4, | |
| ) | |
| else: | |
| client.camera.up_direction = tuple(up) | |
| client.camera.position = tuple(position) | |
| client.camera.look_at = tuple(center) | |
    def _setup_gui(self):
        """Setup GUI controls.

        Builds all viser GUI widgets: view-reset buttons, display sliders,
        screenshot / GLB-export / video-saving panels, and their handlers.
        """
        gui_reset_up = self.server.gui.add_button(
            "Reset up direction",
            hint="Set the camera control 'up' direction to the current camera's 'up'.",
        )

        # NOTE(review): every handler below is defined as `def _(...)` but is
        # never registered with the control it follows — the registration
        # decorators (e.g. `@gui_reset_up.on_click`, `@slider.on_update`)
        # appear to have been lost; confirm against the upstream source.
        def _(event: viser.GuiEvent) -> None:
            # Reset up direction: align control 'up' with the camera's -Y axis.
            client = event.client
            assert client is not None
            client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array(
                [0.0, -1.0, 0.0]
            )

        # Video frame display controls — kept at top so the current frame is always visible
        with self.server.gui.add_folder("Video Display"):
            self.show_video_checkbox = self.server.gui.add_checkbox("Show Current Frame", initial_value=True)
            if hasattr(self, 'original_images') and len(self.original_images) > 0:
                self.current_frame_image = self.server.gui.add_image(
                    self.original_images[0], label="Current Frame"
                )
            else:
                # No source frames available (e.g. viewer built from raw pc_list).
                self.current_frame_image = None

        # Preset view direction buttons
        with self.server.gui.add_folder("Reset View Direction"):
            btn_look_at_center = self.server.gui.add_button(
                "Look At Scene Center",
                hint="Reset orbit center to the scene center (fixes orbit after dragging).",
            )
            btn_overview = self.server.gui.add_button(
                "Overview",
                hint="Reset to a 3/4 overview of the scene.",
            )
            btn_front = self.server.gui.add_button(
                "Front (+Z)",
                hint="View scene from the front.",
            )
            btn_back = self.server.gui.add_button(
                "Back (-Z)",
                hint="View scene from the back.",
            )
            btn_top = self.server.gui.add_button(
                "Top (-Y)",
                hint="View scene from above (bird's eye).",
            )
            btn_left = self.server.gui.add_button(
                "Left (-X)",
                hint="View scene from the left.",
            )
            btn_right = self.server.gui.add_button(
                "Right (+X)",
                hint="View scene from the right.",
            )
            btn_first_cam = self.server.gui.add_button(
                "First Camera",
                hint="Reset to the first camera's viewpoint.",
            )

        # Handlers for the preset-view buttons above (in declaration order).
        def _(_) -> None:
            # Look At Scene Center: re-center the orbit target only.
            center, _ = self._compute_scene_center_and_scale()
            for client in self.server.get_clients().values():
                client.camera.look_at = tuple(center)

        def _(_) -> None:
            # Overview: normalized 3/4 viewing direction.
            d = np.array([0.5, -0.6, 0.6])
            self._reset_view_to_direction(d / np.linalg.norm(d))

        def _(_) -> None:
            # Front (+Z)
            self._reset_view_to_direction(np.array([0.0, 0.0, 1.0]))

        def _(_) -> None:
            # Back (-Z)
            self._reset_view_to_direction(np.array([0.0, 0.0, -1.0]))

        def _(_) -> None:
            # Top (-Y): look straight down with +Z as the on-screen up.
            self._reset_view_to_direction(
                np.array([0.0, -1.0, 0.0]),
                up=np.array([0.0, 0.0, 1.0]),
            )

        def _(_) -> None:
            # Left (-X)
            self._reset_view_to_direction(np.array([-1.0, 0.0, 0.0]))

        def _(_) -> None:
            # Right (+X)
            self._reset_view_to_direction(np.array([1.0, 0.0, 0.0]))

        def _(_) -> None:
            # First Camera
            self._move_to_camera(0, smooth=True)

        button3 = self.server.gui.add_button("4D (Only Show Current Frame)")
        button4 = self.server.gui.add_button("3D (Show All Frames)")
        self.is_render = False
        # fourd: when True, playback shows only the current frame (see
        # update_frame_visibility).
        self.fourd = False

        def _(event: viser.GuiEvent) -> None:
            # 4D mode button
            self.fourd = True

        def _(event: viser.GuiEvent) -> None:
            # 3D mode button
            self.fourd = False

        self.focal_slider = self.server.gui.add_slider(
            "Focal Length", min=0.1, max=99999, step=1, initial_value=533
        )
        self.psize_slider = self.server.gui.add_slider(
            "Point Size", min=0.00001, max=0.1, step=0.00001, initial_value=self.point_size
        )
        self.camsize_slider = self.server.gui.add_slider(
            "Camera Size", min=0.01, max=0.5, step=0.01, initial_value=0.1
        )
        self.downsample_slider = self.server.gui.add_slider(
            "Downsample Factor", min=1, max=1000, step=1, initial_value=10
        )
        self.show_camera_checkbox = self.server.gui.add_checkbox(
            "Show Camera", initial_value=self.show_camera
        )
        self.vis_threshold_slider = self.server.gui.add_slider(
            "Visibility Threshold", min=1.0, max=5.0, step=0.01,
            initial_value=self.vis_threshold,
        )
        self.camera_downsample_slider = self.server.gui.add_slider(
            "Camera Downsample Factor", min=1, max=50, step=1, initial_value=1
        )

        # Screenshot controls
        with self.server.gui.add_folder("Screenshot"):
            self.screenshot_button = self.server.gui.add_button("Take Screenshot")
            self.screenshot_resolution = self.server.gui.add_dropdown(
                "Resolution",
                options=["1920x1080", "2560x1440", "3840x2160", "Current"],
                initial_value="1920x1080",
            )
            self.screenshot_path = self.server.gui.add_text(
                "Save Path", initial_value="screenshot.png"
            )
            self.screenshot_status = self.server.gui.add_text(
                "Status", initial_value="Ready"
            )

        def _(event: viser.GuiEvent) -> None:
            # Take Screenshot button
            self._take_screenshot(event.client)

        # GLB export controls
        with self.server.gui.add_folder("Export GLB"):
            self.glb_output_path = self.server.gui.add_text(
                "Output Path", initial_value="export.glb"
            )
            self.glb_show_cam_checkbox = self.server.gui.add_checkbox(
                "Include Cameras", initial_value=True,
            )
            self.glb_cam_scale_slider = self.server.gui.add_slider(
                "Camera Scale", min=0.01, max=5.0, step=0.01, initial_value=1.0,
                hint="Scale factor for camera size in GLB.",
            )
            self.glb_frustum_thickness_slider = self.server.gui.add_slider(
                "Frustum Thickness", min=1.0, max=10.0, step=0.5, initial_value=3.0,
                hint="Thickness multiplier for camera frustum edges.",
            )
            self.glb_trajectory_checkbox = self.server.gui.add_checkbox(
                "Show Trajectory", initial_value=True,
                hint="Connect cameras with a trajectory line.",
            )
            self.glb_trajectory_radius_slider = self.server.gui.add_slider(
                "Trajectory Radius", min=0.001, max=0.05, step=0.001, initial_value=0.005,
                hint="Radius of the trajectory tube.",
            )
            self.glb_mode_dropdown = self.server.gui.add_dropdown(
                "Export Mode",
                options=["Points", "Spheres"],
                initial_value="Points",
                hint="Points: raw (fast). Spheres: each point becomes a small sphere (prettier, slower).",
            )
            # Sphere-only controls start disabled; toggled by the mode handler below.
            self.glb_sphere_radius_slider = self.server.gui.add_slider(
                "Sphere Radius", min=0.001, max=0.1, step=0.001, initial_value=0.005,
                hint="Radius of each sphere in Spheres mode.",
                disabled=True,
            )
            self.glb_max_sphere_pts_slider = self.server.gui.add_slider(
                "Max Sphere Points", min=10000, max=500000, step=10000, initial_value=100000,
                hint="Cap point count for Spheres mode to keep file size manageable.",
                disabled=True,
            )
            self.glb_opacity_slider = self.server.gui.add_slider(
                "Opacity", min=0.0, max=1.0, step=0.05, initial_value=1.0,
                hint="Point/sphere opacity (alpha). <1.0 = semi-transparent.",
            )
            self.glb_saturation_slider = self.server.gui.add_slider(
                "Saturation Boost", min=0.0, max=2.0, step=0.1, initial_value=1.0,
                hint="Color saturation multiplier. >1 = more vivid, <1 = washed out.",
            )
            self.glb_brightness_slider = self.server.gui.add_slider(
                "Brightness Boost", min=0.5, max=2.0, step=0.1, initial_value=1.0,
                hint="Color brightness multiplier.",
            )
            self.glb_export_button = self.server.gui.add_button(
                "Export GLB",
                hint="Export current filtered point clouds and cameras as GLB.",
            )
            self.glb_status = self.server.gui.add_text("Status", initial_value="Ready")

        def _(_) -> None:
            # Export Mode dropdown: enable sphere-specific sliders only in Spheres mode.
            is_sphere = self.glb_mode_dropdown.value == "Spheres"
            self.glb_sphere_radius_slider.disabled = not is_sphere
            self.glb_max_sphere_pts_slider.disabled = not is_sphere

        def _(_) -> None:
            # Export GLB button
            self._export_glb()

        # Video saving controls
        with self.server.gui.add_folder("Video Saving"):
            self.save_video_button = self.server.gui.add_button("Save Video", disabled=False)
            self.video_output_path = self.server.gui.add_text("Output Path", initial_value="output_pointcloud.mp4")
            self.video_save_fps = self.server.gui.add_slider("Video FPS", min=10, max=60, step=1, initial_value=30)
            self.video_resolution = self.server.gui.add_dropdown(
                "Resolution", options=["1920x1080", "1280x720", "3840x2160"], initial_value="1920x1080"
            )
            self.save_original_video_checkbox = self.server.gui.add_checkbox("Also Save Original Video", initial_value=True)
            self.video_status = self.server.gui.add_text("Status", initial_value="Ready to save")

        def _(_) -> None:
            # Save Video button
            self.save_video(
                output_path=self.video_output_path.value,
                fps=self.video_save_fps.value,
                resolution=self.video_resolution.value,
                save_original_video=self.save_original_video_checkbox.value
            )

        def _(_) -> None:
            # Show Current Frame checkbox toggles the GUI image visibility.
            if self.current_frame_image is not None:
                self.current_frame_image.visible = self.show_video_checkbox.value

        # Scene handle registries, filled by the regeneration helpers.
        self.pc_handles = []
        self.cam_handles = []

        def _(_) -> None:
            # Point Size slider: update in place, no regeneration needed.
            for handle in self.pc_handles:
                handle.point_size = self.psize_slider.value

        def _(_) -> None:
            # Camera Size slider: keep edge thickness proportional to scale.
            for handle in self.cam_handles:
                handle.scale = self.camsize_slider.value
                handle.line_thickness = 0.03 * handle.scale

        def _(_) -> None:
            # Downsample Factor slider
            self._regenerate_point_clouds()

        def _(_) -> None:
            # Show Camera checkbox
            self.show_camera = self.show_camera_checkbox.value
            if self.show_camera:
                self._regenerate_cameras()
            else:
                for handle in self.cam_handles:
                    handle.visible = False

        def _(_) -> None:
            # Visibility Threshold slider
            self.vis_threshold = self.vis_threshold_slider.value
            self._regenerate_point_clouds()

        def _(_) -> None:
            # Camera Downsample Factor slider
            self._regenerate_cameras()
| def _regenerate_point_clouds(self): | |
| """Regenerate all point clouds with current settings.""" | |
| if not hasattr(self, 'frame_nodes'): | |
| return | |
| for handle in self.pc_handles: | |
| try: | |
| handle.remove() | |
| except (KeyError, AttributeError): | |
| pass | |
| self.pc_handles.clear() | |
| self.vis_pts_list.clear() | |
| for i, step in enumerate(self.all_steps): | |
| pc = self.pcs[step]["pc"] | |
| color = self.pcs[step]["color"] | |
| conf = self.pcs[step]["conf"] | |
| edge_color = self.pcs[step].get("edge_color", None) | |
| pred_pts, pc_color = self.parse_pc_data( | |
| pc, color, conf, edge_color, set_border_color=True, | |
| downsample_factor=self.downsample_slider.value | |
| ) | |
| self.vis_pts_list.append(pred_pts) | |
| handle = self.server.scene.add_point_cloud( | |
| name=f"/frames/{step}/pred_pts", | |
| points=pred_pts, | |
| colors=pc_color, | |
| point_size=self.psize_slider.value, | |
| ) | |
| self.pc_handles.append(handle) | |
| def _regenerate_cameras(self): | |
| """Regenerate camera visualizations with current settings.""" | |
| if not hasattr(self, 'frame_nodes'): | |
| return | |
| for handle in self.cam_handles: | |
| try: | |
| handle.remove() | |
| except (KeyError, AttributeError): | |
| pass | |
| self.cam_handles.clear() | |
| if self.show_camera: | |
| downsample_factor = int(self.camera_downsample_slider.value) | |
| for i, step in enumerate(self.all_steps): | |
| if i % downsample_factor == 0: | |
| self.add_camera(step) | |
    def _export_glb(self):
        """Export current filtered point clouds and cameras as a GLB file.

        Reads all export options from the GUI controls, writes the result to
        ``self.glb_output_path.value``, and reports progress via
        ``self.glb_status``.
        """
        try:
            import trimesh
        except ImportError:
            self.glb_status.value = "Error: pip install trimesh"
            return
        self.glb_status.value = "Collecting points..."
        print("Exporting GLB...")
        # Collect all currently visible, filtered points and colors
        all_points = []
        all_colors = []
        for step in self.all_steps:
            pc = self.pcs[step]["pc"]
            color = self.pcs[step]["color"]
            conf = self.pcs[step]["conf"]
            edge_color = self.pcs[step].get("edge_color", None)
            pts, cols = self.parse_pc_data(
                pc, color, conf, edge_color, set_border_color=False,
                downsample_factor=self.downsample_slider.value,
            )
            if len(pts) > 0:
                all_points.append(pts)
                # Normalize colors to uint8 RGB (trimesh expects 0-255).
                if cols.dtype != np.uint8:
                    cols = (np.clip(cols, 0, 1) * 255).astype(np.uint8)
                all_colors.append(cols)
        if not all_points:
            self.glb_status.value = "Error: no points to export"
            return
        vertices = np.concatenate(all_points, axis=0)
        colors_rgb = np.concatenate(all_colors, axis=0)
        # --- Color enhancement ---
        colors_float = colors_rgb.astype(np.float32) / 255.0
        sat_boost = self.glb_saturation_slider.value
        if sat_boost != 1.0:
            # Scale each channel's distance from its per-pixel gray value.
            gray = colors_float.mean(axis=1, keepdims=True)
            colors_float = gray + sat_boost * (colors_float - gray)
        bri_boost = self.glb_brightness_slider.value
        if bri_boost != 1.0:
            colors_float = colors_float * bri_boost
        colors_float = np.clip(colors_float, 0.0, 1.0)
        # --- Opacity ---
        # Simulate opacity by blending colors toward white (works in all viewers).
        # For Spheres mode, also set true alpha for viewers that support it.
        alpha = self.glb_opacity_slider.value
        if alpha < 1.0:
            bg = np.ones_like(colors_float)  # white background
            colors_float = colors_float * alpha + bg * (1.0 - alpha)
            colors_float = np.clip(colors_float, 0.0, 1.0)
        colors_u8 = (colors_float * 255).astype(np.uint8)
        colors_rgba = np.concatenate([
            colors_u8,
            np.full((len(colors_u8), 1), int(alpha * 255), dtype=np.uint8),
        ], axis=1)  # (N, 4)
        # Compute scene scale for camera sizing
        lo = np.percentile(vertices, 5, axis=0)
        hi = np.percentile(vertices, 95, axis=0)
        scene_scale = max(np.linalg.norm(hi - lo), 0.1)
        scene_3d = trimesh.Scene()
        # --- Export mode ---
        export_mode = self.glb_mode_dropdown.value
        if export_mode == "Spheres":
            self.glb_status.value = "Building spheres..."
            max_pts = int(self.glb_max_sphere_pts_slider.value)
            radius = self.glb_sphere_radius_slider.value
            # Subsample if too many points
            if len(vertices) > max_pts:
                idx = np.random.choice(len(vertices), max_pts, replace=False)
                idx.sort()
                vertices = vertices[idx]
                colors_rgba = colors_rgba[idx]
            # Instance one low-poly icosphere per point, merged into one mesh.
            sphere_template = trimesh.creation.icosphere(subdivisions=1, radius=radius)
            n_verts_per = len(sphere_template.vertices)
            n_faces_per = len(sphere_template.faces)
            all_verts = np.empty((len(vertices) * n_verts_per, 3), dtype=np.float32)
            all_faces = np.empty((len(vertices) * n_faces_per, 3), dtype=np.int64)
            all_face_colors = np.empty((len(vertices) * n_faces_per, 4), dtype=np.uint8)
            for i, (pt, rgba) in enumerate(zip(vertices, colors_rgba)):
                v_off = i * n_verts_per
                f_off = i * n_faces_per
                all_verts[v_off:v_off + n_verts_per] = sphere_template.vertices + pt
                # Face indices must be offset by this sphere's vertex base.
                all_faces[f_off:f_off + n_faces_per] = sphere_template.faces + v_off
                all_face_colors[f_off:f_off + n_faces_per] = rgba
            mesh = trimesh.Trimesh(vertices=all_verts, faces=all_faces)
            mesh.visual.face_colors = all_face_colors
            # Enable alpha blending in glTF material for true transparency
            if alpha < 1.0:
                mesh.visual.material.alphaMode = 'BLEND'
            scene_3d.add_geometry(mesh)
            print(f"Spheres mode: {len(vertices):,} spheres, {len(all_faces):,} faces")
        else:
            # Points mode (GLB viewers ignore alpha on points, so use blended RGB)
            scene_3d.add_geometry(trimesh.PointCloud(vertices=vertices, colors=colors_u8))
        # Add cameras and trajectory
        if self.glb_show_cam_checkbox.value and self.cam_dict is not None:
            from lingbot_map.vis.glb_export import integrate_camera_into_scene
            import matplotlib
            colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
            num_cameras = len(self.all_steps)
            cam_positions = []
            frustum_thickness = self.glb_frustum_thickness_slider.value
            effective_cam_scale = scene_scale * self.glb_cam_scale_slider.value
            for i, step in enumerate(self.all_steps):
                R = self.cam_dict["R"][step] if "R" in self.cam_dict else np.eye(3)
                t = self.cam_dict["t"][step] if "t" in self.cam_dict else np.zeros(3)
                c2w = np.eye(4)
                c2w[:3, :3] = R
                c2w[:3, 3] = t
                cam_positions.append(np.array(t, dtype=np.float64))
                # Rainbow gradient over camera index, matching the trajectory tube.
                rgba_c = colormap(i / max(num_cameras - 1, 1))
                cam_color = tuple(int(255 * x) for x in rgba_c[:3])
                integrate_camera_into_scene(
                    scene_3d, c2w, cam_color,
                    effective_cam_scale,
                    frustum_thickness=frustum_thickness,
                )
            # Add trajectory line as a tube connecting camera positions
            if self.glb_trajectory_checkbox.value and len(cam_positions) >= 2:
                traj_pts = np.array(cam_positions)
                traj_radius = self.glb_trajectory_radius_slider.value * self.glb_cam_scale_slider.value
                traj_mesh = self._build_trajectory_tube(
                    traj_pts, traj_radius, colormap, num_cameras
                )
                if traj_mesh is not None:
                    scene_3d.add_geometry(traj_mesh)
        # Align scene using first camera extrinsic
        if self.cam_dict is not None and len(self.all_steps) > 0:
            from lingbot_map.vis.glb_export import apply_scene_alignment
            step0 = self.all_steps[0]
            R0 = self.cam_dict["R"][step0] if "R" in self.cam_dict else np.eye(3)
            t0 = self.cam_dict["t"][step0] if "t" in self.cam_dict else np.zeros(3)
            c2w_0 = np.eye(4)
            c2w_0[:3, :3] = R0
            c2w_0[:3, 3] = t0
            w2c_0 = np.linalg.inv(c2w_0)
            extrinsics = np.expand_dims(w2c_0, 0)
            scene_3d = apply_scene_alignment(scene_3d, extrinsics)
        output_path = self.glb_output_path.value
        scene_3d.export(output_path)
        n_pts = len(vertices)
        mode_str = f"spheres r={self.glb_sphere_radius_slider.value}" if export_mode == "Spheres" else "points"
        self.glb_status.value = f"Saved: {output_path} ({n_pts:,} {mode_str})"
        print(f"GLB exported to {output_path} ({n_pts:,} {mode_str})")
| def _build_trajectory_tube(positions, radius, colormap, num_cameras): | |
| """Build a tube mesh following camera trajectory with per-segment color. | |
| Args: | |
| positions: (N, 3) camera positions. | |
| radius: Tube radius. | |
| colormap: Matplotlib colormap for gradient coloring. | |
| num_cameras: Total number of cameras (for color normalization). | |
| Returns: | |
| trimesh.Trimesh or None. | |
| """ | |
| import trimesh | |
| segments = [] | |
| for i in range(len(positions) - 1): | |
| p0, p1 = positions[i], positions[i + 1] | |
| seg_len = np.linalg.norm(p1 - p0) | |
| if seg_len < 1e-8: | |
| continue | |
| # Create cylinder along Z, then transform | |
| cyl = trimesh.creation.cylinder(radius=radius, height=seg_len, sections=8) | |
| # Direction vector | |
| direction = (p1 - p0) / seg_len | |
| mid = (p0 + p1) / 2.0 | |
| # Build rotation: default cylinder is along Z | |
| z_axis = np.array([0.0, 0.0, 1.0]) | |
| v = np.cross(z_axis, direction) | |
| c = np.dot(z_axis, direction) | |
| if np.linalg.norm(v) < 1e-8: | |
| rot = np.eye(3) if c > 0 else np.diag([1, -1, -1]) | |
| else: | |
| vx = np.array([[0, -v[2], v[1]], | |
| [v[2], 0, -v[0]], | |
| [-v[1], v[0], 0]]) | |
| rot = np.eye(3) + vx + vx @ vx / (1.0 + c) | |
| transform = np.eye(4) | |
| transform[:3, :3] = rot | |
| transform[:3, 3] = mid | |
| cyl.apply_transform(transform) | |
| # Color: midpoint index | |
| t_color = (i + 0.5) / max(num_cameras - 1, 1) | |
| rgba = colormap(t_color) | |
| color_rgb = tuple(int(255 * x) for x in rgba[:3]) | |
| cyl.visual.face_colors[:, :3] = color_rgb | |
| segments.append(cyl) | |
| if not segments: | |
| return None | |
| return trimesh.util.concatenate(segments) | |
| def update_frame_visibility(self): | |
| """Show all frames up to the current timestep (or only the current one in 4D mode).""" | |
| if not hasattr(self, 'frame_nodes') or not hasattr(self, 'gui_timestep'): | |
| return | |
| current_timestep = self.gui_timestep.value | |
| for i, frame_node in enumerate(self.frame_nodes): | |
| frame_node.visible = ( | |
| i <= current_timestep if not self.fourd else i == current_timestep | |
| ) | |
| def _move_to_camera(self, frame_idx: int, smooth: bool = True): | |
| """Move viewer camera to match reconstructed camera at given frame.""" | |
| if self.cam_dict is None: | |
| return | |
| step = self.all_steps[frame_idx] if frame_idx < len(self.all_steps) else self.all_steps[-1] | |
| R = self.cam_dict["R"][step] if "R" in self.cam_dict else np.eye(3) | |
| t = self.cam_dict["t"][step] if "t" in self.cam_dict else np.zeros(3) | |
| focal = self.cam_dict["focal"][step] if "focal" in self.cam_dict else 1.0 | |
| pp = self.cam_dict["pp"][step] if "pp" in self.cam_dict else (1.0, 1.0) | |
| offset = 0.5 | |
| viewing_dir = R[:, 2] # camera Z axis in world frame | |
| position = t - viewing_dir * offset | |
| look_at = t + viewing_dir * 0.5 # look slightly ahead of camera | |
| fov = 2 * np.arctan(pp[0] / focal) | |
| up = -R[:, 1] # camera -Y axis in world frame | |
| for client in self.server.get_clients().values(): | |
| if smooth: | |
| self._smooth_camera_transition( | |
| client, | |
| target_position=position, | |
| target_look_at=look_at, | |
| target_up=up, | |
| target_fov=fov, | |
| duration=0.3, | |
| ) | |
| else: | |
| client.camera.up_direction = tuple(up) | |
| client.camera.position = tuple(position) | |
| client.camera.look_at = tuple(look_at) | |
| if fov is not None: | |
| client.camera.fov = fov | |
    def _smooth_camera_transition(
        self,
        client,
        target_position,
        target_look_at=None,
        target_up=None,
        target_fov=None,
        duration=0.3,
    ):
        """Smoothly transition camera to target pose using look_at based control.

        The interpolation runs on a daemon thread so the caller returns
        immediately; the transition is split into 15 fixed time steps.

        Args:
            client: Viser client handle.
            target_position: Target camera position (3,).
            target_look_at: Target look-at point (3,). If None, keeps current.
            target_up: Target up direction (3,). If None, keeps current.
            target_fov: Target FOV. If None, keeps current.
            duration: Transition duration in seconds.
        """
        def interpolate():
            num_steps = 15
            dt = duration / num_steps
            # Snapshot the starting pose once, before any updates are applied.
            start_position = np.array(client.camera.position, dtype=np.float64)
            start_look_at = np.array(client.camera.look_at, dtype=np.float64)
            start_fov = client.camera.fov
            end_position = np.asarray(target_position, dtype=np.float64)
            end_look_at = np.asarray(target_look_at, dtype=np.float64) if target_look_at is not None else start_look_at
            # Set up direction once at the start (not interpolated to avoid flicker)
            if target_up is not None:
                client.camera.up_direction = tuple(np.asarray(target_up, dtype=np.float64))
            for i in range(num_steps + 1):
                alpha = i / num_steps
                # Smooth ease-in-out (smoothstep: 3a^2 - 2a^3; ends with zero velocity)
                alpha_smooth = alpha * alpha * (3 - 2 * alpha)
                interp_pos = start_position + (end_position - start_position) * alpha_smooth
                interp_look = start_look_at + (end_look_at - start_look_at) * alpha_smooth
                # Set position first (this auto-moves look_at), then override look_at
                client.camera.position = tuple(interp_pos)
                client.camera.look_at = tuple(interp_look)
                if target_fov is not None:
                    interp_fov = start_fov + (target_fov - start_fov) * alpha_smooth
                    client.camera.fov = interp_fov
                time.sleep(dt)
        # Daemon thread: dies with the process; not joined, so overlapping
        # transitions on the same client may interleave updates.
        thread = threading.Thread(target=interpolate, daemon=True)
        thread.start()
| def _slerp(self, q1, q2, t): | |
| """Spherical linear interpolation between quaternions.""" | |
| dot = np.dot(q1, q2) | |
| if abs(dot) > 0.9995: | |
| result = q1 + t * (q2 - q1) | |
| return result / np.linalg.norm(result) | |
| dot = np.clip(dot, -1.0, 1.0) | |
| theta_0 = np.arccos(dot) | |
| theta = theta_0 * t | |
| q2_orthogonal = q2 - q1 * dot | |
| q2_orthogonal = q2_orthogonal / np.linalg.norm(q2_orthogonal) | |
| return q1 * np.cos(theta) + q2_orthogonal * np.sin(theta) | |
| def get_camera_state(self, client: viser.ClientHandle) -> CameraState: | |
| """Get current camera state from client.""" | |
| camera = client.camera | |
| c2w = np.concatenate([ | |
| np.concatenate([tf.SO3(camera.wxyz).as_matrix(), camera.position[:, None]], 1), | |
| [[0, 0, 0, 1]], | |
| ], 0) | |
| return CameraState(fov=camera.fov, aspect=camera.aspect, c2w=c2w) | |
| def generate_pseudo_intrinsics(h: int, w: int) -> np.ndarray: | |
| """Generate pseudo intrinsics from image size.""" | |
| focal = (h**2 + w**2) ** 0.5 | |
| return np.array([[focal, 0, w // 2], [0, focal, h // 2], [0, 0, 1]]).astype(np.float32) | |
| def _connect_client(self, client: viser.ClientHandle): | |
| """Setup client connection callbacks.""" | |
| wxyz_panel = client.gui.add_text("wxyz:", f"{client.camera.wxyz}") | |
| position_panel = client.gui.add_text("position:", f"{client.camera.position}") | |
| fov_panel = client.gui.add_text( | |
| "fov:", f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}" | |
| ) | |
| aspect_panel = client.gui.add_text("aspect:", "1.0") | |
| def _(_: viser.CameraHandle): | |
| with self.server.atomic(): | |
| wxyz_panel.value = f"{client.camera.wxyz}" | |
| position_panel.value = f"{client.camera.position}" | |
| fov_panel.value = f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}" | |
| aspect_panel.value = "1.0" | |
| def set_color_border(image, border_width=5, color=[1, 0, 0]): | |
| """Add colored border to image.""" | |
| image[:border_width, :, 0] = color[0] | |
| image[:border_width, :, 1] = color[1] | |
| image[:border_width, :, 2] = color[2] | |
| image[-border_width:, :, 0] = color[0] | |
| image[-border_width:, :, 1] = color[1] | |
| image[-border_width:, :, 2] = color[2] | |
| image[:, :border_width, 0] = color[0] | |
| image[:, :border_width, 1] = color[1] | |
| image[:, :border_width, 2] = color[2] | |
| image[:, -border_width:, 0] = color[0] | |
| image[:, -border_width:, 1] = color[1] | |
| image[:, -border_width:, 2] = color[2] | |
| return image | |
| def read_data(self, pc_list, color_list, conf_list, edge_color_list=None): | |
| """Read and organize point cloud data.""" | |
| pcs = {} | |
| step_list = [] | |
| for i, pc in enumerate(pc_list): | |
| step = i | |
| pcs.update({ | |
| step: { | |
| "pc": pc, | |
| "color": color_list[i], | |
| "conf": conf_list[i], | |
| "edge_color": ( | |
| None if edge_color_list is None or edge_color_list[i] is None | |
| else edge_color_list[i] | |
| ), | |
| } | |
| }) | |
| step_list.append(step) | |
| # Generate camera gradient colors | |
| num_cameras = len(pc_list) | |
| if num_cameras > 1: | |
| normalized_indices = np.array(list(range(num_cameras))) / (num_cameras - 1) | |
| else: | |
| normalized_indices = np.array([0.0]) | |
| cmap = cm.get_cmap('viridis') | |
| self.camera_colors = cmap(normalized_indices) | |
| return pcs, step_list | |
| def parse_pc_data( | |
| self, | |
| pc, | |
| color, | |
| conf=None, | |
| edge_color=[0.251, 0.702, 0.902], | |
| set_border_color=False, | |
| downsample_factor=1, | |
| ): | |
| """Parse and filter point cloud data.""" | |
| pred_pts = pc.reshape(-1, 3) | |
| if set_border_color and edge_color is not None: | |
| color = self.set_color_border(color[0], color=edge_color) | |
| if np.isnan(color).any(): | |
| color = np.zeros((pred_pts.shape[0], 3)) | |
| color[:, 2] = 1 | |
| else: | |
| color = color.reshape(-1, 3) | |
| # Remove NaN / Inf points | |
| valid = np.isfinite(pred_pts).all(axis=1) | |
| if not valid.all(): | |
| pred_pts = pred_pts[valid] | |
| color = color[valid] | |
| if conf is not None: | |
| conf = conf.reshape(-1)[valid] | |
| # Confidence threshold filter | |
| if conf is not None: | |
| conf_flat = conf.reshape(-1) if conf.ndim > 1 else conf | |
| mask = conf_flat > self.vis_threshold | |
| pred_pts = pred_pts[mask] | |
| color = color[mask] | |
| if len(pred_pts) == 0: | |
| return pred_pts, color | |
| # Downsample | |
| if downsample_factor > 1 and len(pred_pts) > 0: | |
| indices = np.arange(0, len(pred_pts), downsample_factor) | |
| pred_pts = pred_pts[indices] | |
| color = color[indices] | |
| return pred_pts, color | |
| def add_pc(self, step): | |
| """Add point cloud for a frame.""" | |
| pc = self.pcs[step]["pc"] | |
| color = self.pcs[step]["color"] | |
| conf = self.pcs[step]["conf"] | |
| edge_color = self.pcs[step].get("edge_color", None) | |
| pred_pts, color = self.parse_pc_data( | |
| pc, color, conf, edge_color, set_border_color=True, | |
| downsample_factor=self.downsample_slider.value | |
| ) | |
| self.vis_pts_list.append(pred_pts) | |
| self.pc_handles.append( | |
| self.server.scene.add_point_cloud( | |
| name=f"/frames/{step}/pred_pts", | |
| points=pred_pts, | |
| colors=color, | |
| point_size=self.psize_slider.value, | |
| ) | |
| ) | |
| def add_camera(self, step): | |
| """Add camera visualization for a frame.""" | |
| cam = self.cam_dict | |
| focal = cam["focal"][step] if cam and "focal" in cam else 1.0 | |
| pp = cam["pp"][step] if cam and "pp" in cam else (1.0, 1.0) | |
| R = cam["R"][step] if cam and "R" in cam else np.eye(3) | |
| t = cam["t"][step] if cam and "t" in cam else np.zeros(3) | |
| q = tf.SO3.from_matrix(R).wxyz | |
| fov = 2 * np.arctan(pp[0] / focal) | |
| aspect = pp[0] / pp[1] | |
| self.traj_list.append((q, t)) | |
| step_index = self.all_steps.index(step) if step in self.all_steps else 0 | |
| camera_color = self.camera_colors[step_index] | |
| camera_color_rgb = tuple((camera_color[:3] * 255).astype(int)) | |
| self.server.scene.add_frame( | |
| f"/frames/{step}/camera_frame", | |
| wxyz=q, | |
| position=t, | |
| axes_length=0.05, | |
| axes_radius=0.002, | |
| origin_radius=0.002, | |
| ) | |
| frustum_handle = self.server.scene.add_camera_frustum( | |
| name=f"/frames/{step}/camera", | |
| fov=fov, | |
| aspect=aspect, | |
| wxyz=q, | |
| position=t, | |
| scale=0.03, | |
| color=camera_color_rgb, | |
| ) | |
| def _(event) -> None: | |
| look_at_pt = t + R[:, 2] * 0.5 # look ahead along camera Z | |
| up_dir = -R[:, 1] | |
| for client in self.server.get_clients().values(): | |
| client.camera.up_direction = tuple(up_dir) | |
| client.camera.position = tuple(t) | |
| client.camera.look_at = tuple(look_at_pt) | |
| self.cam_handles.append(frustum_handle) | |
| def animate(self): | |
| """Setup and run animation controls.""" | |
| with self.server.gui.add_folder("Playback"): | |
| self.gui_timestep = self.server.gui.add_slider( | |
| "Train Step", min=0, max=self.num_frames - 1, step=1, initial_value=0, disabled=False | |
| ) | |
| gui_next_frame = self.server.gui.add_button("Next Step", disabled=False) | |
| gui_prev_frame = self.server.gui.add_button("Prev Step", disabled=False) | |
| gui_playing = self.server.gui.add_checkbox("Playing", True) | |
| gui_framerate = self.server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=20) | |
| gui_framerate_options = self.server.gui.add_button_group("FPS options", ("10", "20", "30", "60")) | |
| def _(_) -> None: | |
| self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames | |
| def _(_) -> None: | |
| self.gui_timestep.value = (self.gui_timestep.value - 1) % self.num_frames | |
| def _(_) -> None: | |
| self.gui_timestep.disabled = gui_playing.value | |
| gui_next_frame.disabled = gui_playing.value | |
| gui_prev_frame.disabled = gui_playing.value | |
| def _(_) -> None: | |
| gui_framerate.value = int(gui_framerate_options.value) | |
| prev_timestep = self.gui_timestep.value | |
| def _(_) -> None: | |
| nonlocal prev_timestep | |
| current_timestep = self.gui_timestep.value | |
| if self.current_frame_image is not None and hasattr(self, 'original_images'): | |
| if current_timestep < len(self.original_images): | |
| self.current_frame_image.image = self.original_images[current_timestep] | |
| with self.server.atomic(): | |
| self.frame_nodes[current_timestep].visible = True | |
| self.frame_nodes[prev_timestep].visible = False | |
| self.server.flush() | |
| prev_timestep = current_timestep | |
| self.server.scene.add_frame("/frames", show_axes=False) | |
| self.frame_nodes = [] | |
| for i in range(self.num_frames): | |
| step = self.all_steps[i] | |
| self.frame_nodes.append( | |
| self.server.scene.add_frame(f"/frames/{step}", show_axes=False) | |
| ) | |
| self.add_pc(step) | |
| if self.show_camera: | |
| downsample_factor = int(self.camera_downsample_slider.value) | |
| if i % downsample_factor == 0: | |
| self.add_camera(step) | |
| prev_timestep = self.gui_timestep.value | |
| while True: | |
| if self.on_replay: | |
| pass | |
| else: | |
| if gui_playing.value: | |
| self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames | |
| self.update_frame_visibility() | |
| time.sleep(1.0 / gui_framerate.value) | |
| def _take_screenshot(self, client: Optional[Any] = None): | |
| """Capture a screenshot from the current view and save to file. | |
| Args: | |
| client: The viser client that triggered the action. If None, | |
| uses the first connected client. | |
| """ | |
| output_path = self.screenshot_path.value | |
| res_str = self.screenshot_resolution.value | |
| # Resolve client | |
| if client is None: | |
| clients = list(self.server.get_clients().values()) | |
| if not clients: | |
| self.screenshot_status.value = "Error: no client connected" | |
| return | |
| client = clients[0] | |
| try: | |
| self.screenshot_status.value = "Capturing..." | |
| if res_str == "Current": | |
| # Use default render size | |
| width, height = 1920, 1080 | |
| else: | |
| width, height = map(int, res_str.split("x")) | |
| render = client.camera.get_render(height=height, width=width) | |
| if render is not None: | |
| frame = np.array(render) | |
| if frame.shape[2] == 4: | |
| frame = frame[:, :, :3] | |
| frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) | |
| cv2.imwrite(output_path, frame_bgr) | |
| self.screenshot_status.value = f"Saved: {output_path}" | |
| print(f"Screenshot saved to {output_path} ({width}x{height})") | |
| else: | |
| self.screenshot_status.value = "Error: render returned None" | |
| print("Screenshot failed: render returned None") | |
| except Exception as e: | |
| self.screenshot_status.value = f"Error: {e}" | |
| print(f"Screenshot error: {e}") | |
| def save_video( | |
| self, | |
| output_path: str = "output_pointcloud.mp4", | |
| fps: int = 30, | |
| resolution: str = "1920x1080", | |
| save_original_video: bool = True | |
| ): | |
| """Save point cloud animation as video.""" | |
| try: | |
| if hasattr(self, 'video_status'): | |
| self.video_status.value = "Saving video..." | |
| print(f"Saving video to {output_path}...") | |
| width, height = map(int, resolution.split('x')) | |
| temp_dir = tempfile.mkdtemp(prefix="viser_video_") | |
| print(f"Temporary directory: {temp_dir}") | |
| print("Waiting for client connection...") | |
| timeout = 10 | |
| start_time = time.time() | |
| while len(self.server.get_clients()) == 0: | |
| time.sleep(0.1) | |
| if time.time() - start_time > timeout: | |
| raise RuntimeError("No client connected. Please open the visualization in a browser first.") | |
| print("Client connected. Starting to render frames...") | |
| clients = list(self.server.get_clients().values()) | |
| client = clients[0] | |
| if not hasattr(self, 'gui_timestep'): | |
| raise RuntimeError("Animation not initialized. Please ensure animate() is called before save_video().") | |
| for i in tqdm(range(self.num_frames), desc="Rendering frames"): | |
| self.gui_timestep.value = i | |
| time.sleep(0.1) | |
| try: | |
| screenshot = client.camera.get_render(height=height, width=width) | |
| if screenshot is not None: | |
| frame = np.array(screenshot) | |
| if frame.shape[2] == 4: | |
| frame = frame[:, :, :3] | |
| frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) | |
| frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png") | |
| cv2.imwrite(frame_path, frame) | |
| else: | |
| frame = self._render_frame_fallback(i, width, height) | |
| frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png") | |
| cv2.imwrite(frame_path, frame) | |
| except Exception as e: | |
| print(f"Warning: Error capturing frame {i}: {e}, using fallback") | |
| frame = self._render_frame_fallback(i, width, height) | |
| frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png") | |
| cv2.imwrite(frame_path, frame) | |
| print("Encoding video with ffmpeg...") | |
| ffmpeg_cmd = [ | |
| 'ffmpeg', '-y', '-framerate', str(fps), | |
| '-i', os.path.join(temp_dir, 'frame_%06d.png'), | |
| '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18', | |
| output_path | |
| ] | |
| result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True) | |
| if result.returncode == 0: | |
| print(f"Point cloud video saved successfully to {output_path}") | |
| if hasattr(self, 'video_status'): | |
| self.video_status.value = f"Saved to {output_path}" | |
| else: | |
| print(f"FFmpeg error: {result.stderr}") | |
| if hasattr(self, 'video_status'): | |
| self.video_status.value = "Error: FFmpeg failed" | |
| if save_original_video and hasattr(self, 'original_images') and len(self.original_images) > 0: | |
| self._save_original_video(output_path, fps, width, height) | |
| shutil.rmtree(temp_dir) | |
| print("Temporary files cleaned up") | |
| except Exception as e: | |
| print(f"Error saving video: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| if hasattr(self, 'video_status'): | |
| self.video_status.value = f"Error: {str(e)}" | |
| def _save_original_video(self, pointcloud_video_path: str, fps: int, width: int, height: int): | |
| """Save original images as video.""" | |
| base_path = os.path.splitext(pointcloud_video_path)[0] | |
| original_video_path = f"{base_path}_original.mp4" | |
| print(f"Saving original images video to {original_video_path}...") | |
| try: | |
| temp_dir = tempfile.mkdtemp(prefix="original_video_") | |
| for i, img in enumerate(tqdm(self.original_images, desc="Saving original frames")): | |
| frame = cv2.resize(img, (width, height)) | |
| if len(frame.shape) == 3 and frame.shape[2] == 3: | |
| frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) | |
| frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png") | |
| cv2.imwrite(frame_path, frame) | |
| print("Encoding original video with ffmpeg...") | |
| ffmpeg_cmd = [ | |
| 'ffmpeg', '-y', '-framerate', str(fps), | |
| '-i', os.path.join(temp_dir, 'frame_%06d.png'), | |
| '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18', | |
| original_video_path | |
| ] | |
| result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True) | |
| if result.returncode == 0: | |
| print(f"Original video saved successfully to {original_video_path}") | |
| else: | |
| print(f"FFmpeg error for original video: {result.stderr}") | |
| shutil.rmtree(temp_dir) | |
| except Exception as e: | |
| print(f"Error saving original video: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| def _render_frame_fallback(self, frame_idx: int, width: int, height: int) -> np.ndarray: | |
| """Fallback rendering when screenshot capture fails.""" | |
| if hasattr(self, 'original_images') and frame_idx < len(self.original_images): | |
| frame = self.original_images[frame_idx].copy() | |
| frame = cv2.resize(frame, (width, height)) | |
| cv2.putText(frame, f"Frame {frame_idx}", (10, 30), | |
| cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) | |
| return frame | |
| else: | |
| frame = np.zeros((height, width, 3), dtype=np.uint8) | |
| cv2.putText(frame, f"Frame {frame_idx} - No render available", | |
| (width//4, height//2), | |
| cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2) | |
| return frame | |
| def run(self, background_mode: bool = False): | |
| """Run the viewer.""" | |
| self.animate() | |
| if background_mode: | |
| def server_loop(): | |
| while True: | |
| time.sleep(0.001) | |
| thread = threading.Thread(target=server_loop, daemon=True) | |
| thread.start() | |
| else: | |
| while True: | |
| time.sleep(10.0) | |