import gradio as gr
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import tempfile
import trimesh

from realm.model_factory import REALM_creator
from realm.utils.vis import VisMast3r, voxel_to_rgb_image

# ---------------------------------------------------------------------------
# Global State & Initialization
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_MODEL_CACHE = {}

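# Models are loaded lazily and cached per config, so switching tasks in the UI does not
# rebuild or re-download weights that have already been loaded.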
def get_model(task_name: str):
    """Fetch the model from cache, or build it if it hasn't been loaded yet."""
    config_map = {
        "Depth": "depth",
        "Segmentation": "segmentation",
        "Point Cloud & Correspondences": "mast3r"
    }
    config_name = config_map.get(task_name)
    if not config_name:
        raise ValueError(f"Unknown task: {task_name}")

    if config_name not in _MODEL_CACHE:
        print(f"Loading REALM model for '{config_name}'...")
        config_path = Path("configs") / f"{config_name}.yaml"
        model = REALM_creator(config_path).to(device)
        model.eval()
        _MODEL_CACHE[config_name] = model
    return _MODEL_CACHE[config_name]

# ---------------------------------------------------------------------------
# Preprocessing & Visualization Helpers
# ---------------------------------------------------------------------------
def image_to_normalized_tensor(img_rgb: np.ndarray) -> torch.Tensor:
    """Convert an RGB numpy array [0, 255] to a normalized PyTorch tensor."""
    img_tensor = torch.from_numpy(img_rgb).float().permute(2, 0, 1) / 255.0
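    # ImageNet channel statistics, the standard normalization for pretrained vision backbones.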
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (img_tensor - mean) / std

def parse_event_file(file_path: str) -> torch.Tensor:
    """Load a .npz or .npy event voxel and resize it to [1, C, 448, 448]."""
    if file_path.endswith('.npz'):
        with np.load(file_path) as data:
            voxel_np = data[data.files[0]]
    else:
        voxel_np = np.load(file_path)
    ev_tensor = torch.from_numpy(voxel_np).float()
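    # Add a batch dimension if the voxel is stored as [C, H, W].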
    if ev_tensor.ndim == 3:
        ev_tensor = ev_tensor.unsqueeze(0)
    ev_tensor = F.interpolate(ev_tensor, size=(448, 448), mode="bilinear", align_corners=False)
    return ev_tensor
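
# Example (hypothetical file name): parse_event_file("voxel.npz") returns a tensor of shape
# [1, C, 448, 448]; C depends on how many temporal bins the event voxel was built with.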

def get_input_data(file_obj):
    """
    Parses the upload and returns a tuple: (Model Tensor, Display Image [H, W, 3] uint8)
    """
    file_path = file_obj.name
    ext = Path(file_path).suffix.lower()

    if ext in ['.npz', '.npy']:
        ev_tensor = parse_event_file(file_path)
        # Convert the event voxel grid into an RGB display image for visualization
        disp_img = voxel_to_rgb_image(ev_tensor.squeeze(0), 1.0)
        if isinstance(disp_img, torch.Tensor):
            disp_img = disp_img.cpu().numpy()
        if disp_img.max() <= 1.0:
            disp_img = disp_img * 255.0
        disp_img = np.clip(disp_img, 0, 255).astype(np.uint8)
        return ev_tensor.to(device), disp_img
    else:
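        # Standard RGB path: OpenCV reads BGR, so convert to RGB and resize to the
        # model's 448x448 input resolution.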
        img_bgr = cv2.imread(file_path)
        if img_bgr is None:
            raise gr.Error(f"Could not read the uploaded image: {Path(file_path).name}")
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        img_resized = cv2.resize(img_rgb, (448, 448))
        tensor = image_to_normalized_tensor(img_resized).unsqueeze(0).to(device)
        return tensor, img_resized

def vis_depth(depth_map: np.ndarray, cmap_name: str = "magma") -> np.ndarray:
    """Visualise a depth map with logarithmic scaling."""
    valid = (depth_map > 1e-6) & np.isfinite(depth_map)
    if valid.sum() == 0:
        return np.zeros((*depth_map.shape, 3), dtype=np.uint8)
    log_depth = np.zeros_like(depth_map)
    log_depth[valid] = np.log(depth_map[valid])
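    # Clip to the 2nd-95th percentile range so a few extreme depths do not wash out the colormap.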
    log_min = np.percentile(log_depth[valid], 2)
    log_max = np.percentile(log_depth[valid], 95)
    log_clipped = np.clip(log_depth, log_min, log_max)
    norm = np.zeros_like(depth_map)
    if log_max > log_min:
        norm[valid] = (log_clipped[valid] - log_min) / (log_max - log_min)
    else:
        norm[valid] = 0.5
    colored = plt.get_cmap(cmap_name)(norm)[:, :, :3]
    colored[~valid] = 0.0
    return (colored * 255).astype(np.uint8)

def build_seg_color_map() -> np.ndarray:
    """Hardcoded 11-class Cityscapes color map."""
    color_map = np.zeros((256, 3), dtype=np.uint8)
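    # Train-ID -> RGB. The colors match the standard Cityscapes palette; the class order is
    # assumed from that palette (e.g. sky, building, fence, person, pole, road, sidewalk,
    # vegetation, car, wall, traffic light).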
    mapping = {
        0: (70, 130, 180), 1: (70, 70, 70), 2: (190, 153, 153),
        3: (220, 20, 60), 4: (153, 153, 153), 5: (128, 64, 128),
        6: (244, 35, 232), 7: (107, 142, 35), 8: (0, 0, 142),
        9: (102, 102, 156), 10: (250, 170, 30)
    }
    for train_id, rgb in mapping.items():
        color_map[train_id] = rgb
    return color_map

SEG_COLOR_MAP = build_seg_color_map()

def decode_seg_map(label_mask: np.ndarray) -> np.ndarray:
    """Map predicted label IDs to RGB colors. Void (255) is rendered black."""
    safe_mask = label_mask.copy()
    safe_mask[safe_mask == 255] = 0
    rgb = SEG_COLOR_MAP[safe_mask].copy()
    rgb[label_mask == 255] = 0
    return rgb.astype(np.uint8)

# ---------------------------------------------------------------------------
# Task-Specific Inference
# ---------------------------------------------------------------------------
def run_depth(model, inp_tensor, disp_img):
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        pred_padded = model(inp_tensor, {"upsample": True, "H": 448, "W": 448})
    pred_np = pred_padded.squeeze().float().cpu().numpy()
    depth_vis = vis_depth(pred_np, "magma")
    # Stitch the input and the prediction side-by-side
    combined_vis = np.hstack((disp_img, depth_vis))
    return combined_vis

def run_segmentation(model, inp_tensor, disp_img):
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        pred_out = model(inp_tensor, {"upsample": True, "H": 448, "W": 448})
    pred_mask = torch.argmax(pred_out, dim=1).squeeze().cpu().numpy()
    seg_vis = decode_seg_map(pred_mask)
    # Stitch the input and the prediction side-by-side
    combined_vis = np.hstack((disp_img, seg_vis))
    return combined_vis

def run_matching(model, view1_tensor, view1_disp, view2_tensor, view2_disp):
    # Forward Pass
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        out1, out2 = model({'view1': view1_tensor, 'view2': view2_tensor}, {'H': 448, 'W': 448})

    # --- 1. Correspondences Gallery Visualization ---
    try:
        disp1_bgr = cv2.cvtColor(view1_disp, cv2.COLOR_RGB2BGR)
        disp2_bgr = cv2.cvtColor(view2_disp, cv2.COLOR_RGB2BGR)
        data_in_vis = {
            'view1': disp1_bgr,
            'view2': disp2_bgr,
            'pred1': out1,
            'pred2': out2
        }
        match_img_bgr = VisMast3r(data_in_vis, n_viz=50)
        match_img_rgb = cv2.cvtColor(match_img_bgr, cv2.COLOR_BGR2RGB)
        gallery_out = [match_img_rgb]
    except Exception as e:
        print(f"Visualization error: {e}")
        gallery_out = []

    # --- 2. 3D Point Cloud Generation (.glb) ---
    try:
        gr.Info("Generating 3D point cloud.")
        # Extract points in a common reference frame
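        # Following the MASt3R output convention, 'pts3d' holds view 1's point map in its own
        # camera frame, while 'pts3d_in_other_view' places view 2's points in view 1's frame.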
        pts3d_1 = out1['pts3d'].squeeze(0).float().cpu().numpy().reshape(-1, 3)
        pts3d_2 = out2['pts3d_in_other_view'].squeeze(0).float().cpu().numpy().reshape(-1, 3)

        # Flatten the display images to assign colors to each point
        colors_1 = view1_disp.reshape(-1, 3)
        colors_2 = view2_disp.reshape(-1, 3)

        # Basic filtering to remove invalid background depths (Z <= 0)
        valid_1 = np.isfinite(pts3d_1).all(axis=1) & (pts3d_1[:, 2] > 0)
        valid_2 = np.isfinite(pts3d_2).all(axis=1) & (pts3d_2[:, 2] > 0)
        pts_combined = np.vstack((pts3d_1[valid_1], pts3d_2[valid_2]))
        colors_combined = np.vstack((colors_1[valid_1], colors_2[valid_2]))

        gr.Info("Downsampling point cloud for web rendering.")
        max_points = 50000  # Safe limit for smooth web rendering
        if len(pts_combined) > max_points:
            # Randomly select indices without replacement
            indices = np.random.choice(len(pts_combined), max_points, replace=False)
            pts_combined = pts_combined[indices]
            colors_combined = colors_combined[indices]

        # -----------------------------------------------------------
        # Build Trimesh point cloud
        pc = trimesh.PointCloud(pts_combined, colors=colors_combined)
        scene = trimesh.Scene([pc])
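
        # Flip the Y and Z axes: the predicted points follow the OpenCV camera convention
        # (x right, y down, z forward), whereas the glTF viewer expects y up / z toward the camera.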
        transform = np.array([
            [1, 0, 0, 0],
            [0, -1, 0, 0],
            [0, 0, -1, 0],
            [0, 0, 0, 1]
        ])
        scene.apply_transform(transform)

        tmp_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
        scene.export(tmp_file.name)
        glb_path = tmp_file.name
    except Exception as e:
        print(f"Point cloud export error: {e}")
        glb_path = None

    return glb_path, gallery_out

# ---------------------------------------------------------------------------
# Gradio Router & UI Callbacks
# ---------------------------------------------------------------------------
def process_inference(task, view1_file, view2_file):
    if view1_file is None:
        raise gr.Error("Please upload View 1.")

    gr.Info("Loading model weights... This may take a few minutes if downloading for the first time.")
    model = get_model(task)

    # Automatically parse the input into a tensor and an RGB display image
    view1_tensor, view1_disp = get_input_data(view1_file)
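
    # All outputs start hidden; only the component for the selected task is made visible.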
    depth_out = gr.update(visible=False)
    seg_out = gr.update(visible=False)
    pc_out = gr.update(visible=False)
    match_out = gr.update(visible=False)

    if task == "Depth":
        res = run_depth(model, view1_tensor, view1_disp)
        depth_out = gr.update(value=res, visible=True)
    elif task == "Segmentation":
        res = run_segmentation(model, view1_tensor, view1_disp)
        seg_out = gr.update(value=res, visible=True)
    elif task == "Point Cloud & Correspondences":
        if view2_file is None:
            raise gr.Error("Matching requires View 2.")
        view2_tensor, view2_disp = get_input_data(view2_file)
        glb_path, matches = run_matching(model, view1_tensor, view1_disp, view2_tensor, view2_disp)
        pc_out = gr.update(value=glb_path, visible=True)
        match_out = gr.update(value=matches, visible=True)

    return depth_out, seg_out, pc_out, match_out

def update_ui_visibility(task):
    show_v2 = (task == "Point Cloud & Correspondences")
    return (
        gr.update(visible=show_v2),
        gr.update(visible=(task == "Depth")),
        gr.update(visible=(task == "Segmentation")),
        gr.update(visible=(task == "Point Cloud & Correspondences")),
        gr.update(visible=(task == "Point Cloud & Correspondences"))
    )
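
# The tuple returned above must stay in the same order as the `outputs` list wired to the
# listeners below: (view2_input, depth_output, seg_output, pc_output, match_gallery).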

# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------
EXAMPLE_DATA = [
    {
        "task": "Depth",
        "v1": "examples/sample_depth.npz",
        "v2": None
    },
    {
        "task": "Segmentation",
        "v1": "examples/sample_seg.npz",
        "v2": None
    },
    {
        "task": "Point Cloud & Correspondences",
        "v1": "examples/matching_v1.jpg",
        "v2": "examples/matching_v2.npz"
    }
]

css = """
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
}
button.primary {background-color: #0E2841 !important; border-color: #0E2841 !important;}
.loading-note {font-size: 0.85em; color: #666; text-align: center; margin-top: 5px;}
.logo-container {
    display: flex;
    justify-content: center;
    align-items: center;
    margin-bottom: -15px;
}
.logo-img img {
    max-height: 120px !important;
    width: auto !important;
    object-fit: contain !important;
    background-color: transparent !important;
}
.example-gallery button, .example-gallery .image-container,
.match-gallery button, .match-gallery .image-container {
    aspect-ratio: 2 / 1 !important;
    border-radius: 8px !important;
    transition: transform 0.2s ease;
    cursor: pointer;
}
.example-gallery button:hover, .example-gallery .image-container:hover {
    transform: scale(1.02);
    border: 2px solid #0E2841 !important;
}
.example-gallery, .example-gallery * {
    scrollbar-width: none !important;  /* For Firefox */
    -ms-overflow-style: none !important;  /* For IE and Edge */
}
.example-gallery ::-webkit-scrollbar {
    display: none !important;  /* For Chrome, Safari, Opera */
}
.example-gallery .grid-container {
    overflow-x: hidden !important;  /* Prevents container scrolling */
}
"""

with gr.Blocks(title="REALM Inference Demo", css=css) as demo:
    with gr.Row(elem_classes="logo-container"):
        gr.Image(
            value="logo.png",
            show_label=False,
            interactive=False,
            container=False,
            elem_classes="logo-img"
        )
    gr.Markdown("<h1 style='text-align: center; color: #FFFFFF; margin-top: 0;'>An RGB and Event Aligned Latent Manifold for Cross-Modal Perception</h1>")
    gr.Markdown(
        """
        <p style='text-align: center; font-size: 1.1em; margin-top: 0;'>
        This interactive demo showcases the capabilities of the REALM model for depth estimation, semantic segmentation, and 3D correspondence matching using RGB and event data:
        <a href='https://papers.starslab.ca/realm' target='_blank'>
        📄 papers.starslab.ca/realm
        </a>
        </p>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            task_dropdown = gr.Dropdown(
                choices=["Depth", "Segmentation", "Point Cloud & Correspondences"],
                value="Depth",
                label="Select Inference Task",
                interactive=True
            )
            with gr.Row():
                view1_input = gr.File(label="View 1 (Image, .npz, or .npy)", file_types=["image", ".npz", ".npy"])
                view2_input = gr.File(label="View 2 (Image, .npz, or .npy)", file_types=["image", ".npz", ".npy"], visible=False)
            run_btn = gr.Button("Run Inference", variant="primary")
            gr.Markdown("*Note: The depth estimation and semantic segmentation models are evaluated specifically on street-driving scenes from the MVSEC and DSEC datasets. In contrast, the dense feature matching and 3D correspondence pipeline generalizes to arbitrary scenes.*", elem_classes="loading-note")
            gr.Markdown("*Note: The interface may pause during the first run of each task while the model weights are downloaded from Hugging Face.*", elem_classes="loading-note")
        with gr.Column(scale=1):
            depth_output = gr.Image(label="Predicted Depth", visible=True)
            seg_output = gr.Image(label="Segmentation Mask", visible=False)
            pc_output = gr.Model3D(label="3D Point Cloud", visible=False)
            match_gallery = gr.Gallery(label="Correspondences", columns=1, visible=False)

    gr.Markdown("### Try it out with Sample Data")
    example_gallery = gr.Gallery(
        value=[
            # Format: (Thumbnail Image Path, Display Caption)
            ("examples/thumbnail_depth.jpg", "Depth (Events)"),
            ("examples/thumbnail_seg.jpg", "Segmentation (Events)"),
            ("examples/thumbnail_match.jpg", "Matching (Events + RGB)")
        ],
        allow_preview=False,  # Prevents opening the image in fullscreen
        columns=3,
        show_label=False,
        elem_classes="example-gallery"
    )

    # 1. Event handler to load the example files based on which thumbnail was clicked
    def load_example_data(evt: gr.SelectData):
        idx = evt.index
        data = EXAMPLE_DATA[idx]
        return data["task"], data["v1"], data["v2"]

    # 2. Chain the events together
    example_gallery.select(
        fn=load_example_data,
        inputs=None,
        outputs=[task_dropdown, view1_input, view2_input]
    ).then(
        # Update the UI visibility (hide/show View 2)
        fn=update_ui_visibility,
        inputs=[task_dropdown],
        outputs=[view2_input, depth_output, seg_output, pc_output, match_gallery]
    ).then(
        # Run the model
        fn=process_inference,
        inputs=[task_dropdown, view1_input, view2_input],
        outputs=[depth_output, seg_output, pc_output, match_gallery]
    )

    # --- Standard Event Listeners ---
    task_dropdown.change(
        fn=update_ui_visibility,
        inputs=[task_dropdown],
        outputs=[view2_input, depth_output, seg_output, pc_output, match_gallery]
    )
    run_btn.click(
        fn=process_inference,
        inputs=[task_dropdown, view1_input, view2_input],
        outputs=[depth_output, seg_output, pc_output, match_gallery]
    )

    # --- Credits & Dataset Acknowledgements Footer ---
    gr.HTML("""
    <div style="text-align: center; margin-top: 40px; padding: 20px; border-top: 1px solid #e5e7eb; color: #6b7280; font-size: 1.0em;">
        <p>This interactive demonstration uses the <strong><a href="https://github.com/naver/dune">DUNE</a></strong> encoder for image processing and the <strong><a href="https://github.com/naver/mast3r">MASt3R</a></strong> decoder for dense feature matching and 3D correspondence.</p>
        <p>The sample data provided in the examples is derived from the following open-source datasets:</p>
        <div style="display: flex; justify-content: center; gap: 20px; margin-top: 10px;">
            <a href="https://star-datasets.github.io/vector/" target="_blank" style="color: #FFFFFF; text-decoration: none; font-weight: 500;">VECtor Benchmark</a> |
            <a href="https://daniilidis-group.github.io/mvsec/" target="_blank" style="color: #FFFFFF; text-decoration: none; font-weight: 500;">MVSEC</a> |
            <a href="https://github.com/uzh-rpg/DSEC" target="_blank" style="color: #FFFFFF; text-decoration: none; font-weight: 500;">DSEC</a>
        </div>
    </div>
    """)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False
    )