import gradio as gr
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import tempfile
import trimesh

from realm.model_factory import REALM_creator
from realm.utils.vis import VisMast3r, voxel_to_rgb_image

# ---------------------------------------------------------------------------
# Global State & Initialization
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_MODEL_CACHE = {}


def get_model(task_name: str):
    """Fetch the model from cache, or build it if it hasn't been loaded yet."""
    config_map = {
        "Depth": "depth",
        "Segmentation": "segmentation",
        "Point Cloud & Correspondences": "mast3r",
    }
    config_name = config_map.get(task_name)
    if not config_name:
        raise ValueError(f"Unknown task: {task_name}")

    if config_name not in _MODEL_CACHE:
        print(f"Loading REALM model for '{config_name}'...")
        config_path = Path("configs") / f"{config_name}.yaml"
        model = REALM_creator(config_path).to(device)
        model.eval()
        _MODEL_CACHE[config_name] = model

    return _MODEL_CACHE[config_name]


# ---------------------------------------------------------------------------
# Preprocessing & Visualization Helpers
# ---------------------------------------------------------------------------
def image_to_normalized_tensor(img_rgb: np.ndarray) -> torch.Tensor:
    """Convert an RGB numpy array in [0, 255] to a normalized PyTorch tensor."""
    img_tensor = torch.from_numpy(img_rgb).float().permute(2, 0, 1) / 255.0
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (img_tensor - mean) / std


def parse_event_file(file_path: str) -> torch.Tensor:
    """Load a .npz or .npy event voxel and resize it to [1, C, 448, 448]."""
    if file_path.endswith(".npz"):
        with np.load(file_path) as data:
            voxel_np = data[data.files[0]]
    else:
        voxel_np = np.load(file_path)

    ev_tensor = torch.from_numpy(voxel_np).float()
    if ev_tensor.ndim == 3:
        ev_tensor = ev_tensor.unsqueeze(0)

    ev_tensor = F.interpolate(ev_tensor, size=(448, 448), mode="bilinear", align_corners=False)
    return ev_tensor


def get_input_data(file_obj):
    """
    Parse the upload and return a tuple:
    (model tensor, display image [H, W, 3] uint8).
    """
    # Newer Gradio versions pass the upload as a plain path string; older ones
    # pass a tempfile-like object with a .name attribute. Handle both.
    file_path = file_obj if isinstance(file_obj, str) else file_obj.name
    ext = Path(file_path).suffix.lower()

    if ext in [".npz", ".npy"]:
        ev_tensor = parse_event_file(file_path)
        # Convert the event voxel grid into an RGB display image for visualization
        disp_img = voxel_to_rgb_image(ev_tensor.squeeze(0), 1.0)
        if isinstance(disp_img, torch.Tensor):
            disp_img = disp_img.cpu().numpy()
        if disp_img.max() <= 1.0:
            disp_img = disp_img * 255.0
        disp_img = np.clip(disp_img, 0, 255).astype(np.uint8)
        return ev_tensor.to(device), disp_img
    else:
        img_bgr = cv2.imread(file_path)
        if img_bgr is None:
            raise gr.Error(f"Could not read the uploaded image: {Path(file_path).name}")
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        img_resized = cv2.resize(img_rgb, (448, 448))
        tensor = image_to_normalized_tensor(img_resized).unsqueeze(0).to(device)
        return tensor, img_resized


def vis_depth(depth_map: np.ndarray, cmap_name: str = "magma") -> np.ndarray:
    """Visualise a depth map with logarithmic scaling."""
    valid = (depth_map > 1e-6) & np.isfinite(depth_map)
    if valid.sum() == 0:
        return np.zeros((*depth_map.shape, 3), dtype=np.uint8)

    log_depth = np.zeros_like(depth_map)
    log_depth[valid] = np.log(depth_map[valid])
    log_min = np.percentile(log_depth[valid], 2)
    log_max = np.percentile(log_depth[valid], 95)
    log_clipped = np.clip(log_depth, log_min, log_max)

    norm = np.zeros_like(depth_map)
    if log_max > log_min:
        norm[valid] = (log_clipped[valid] - log_min) / (log_max - log_min)
    else:
        norm[valid] = 0.5

    colored = plt.get_cmap(cmap_name)(norm)[:, :, :3]
    colored[~valid] = 0.0
    return (colored * 255).astype(np.uint8)
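
# Optional smoke-test helper (illustrative sketch, not part of the demo flow).
# It writes a synthetic event voxel in the layout `parse_event_file` expects:
# a float array of shape [C, H, W] stored inside an .npz archive. The 5 time
# bins and 260x346 resolution below are assumptions chosen for testing only;
# real REALM inputs may use different shapes.
def make_dummy_event_voxel(out_path: str = "examples/dummy_voxel.npz",
                           bins: int = 5, height: int = 260, width: int = 346) -> str:
    """Save a random event voxel grid that the demo can load for a dry run."""
    voxel = np.random.randn(bins, height, width).astype(np.float32)
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(out_path, voxel=voxel)
    return out_path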
def build_seg_color_map() -> np.ndarray:
    """Hardcoded 11-class Cityscapes color map."""
    color_map = np.zeros((256, 3), dtype=np.uint8)
    mapping = {
        0: (70, 130, 180), 1: (70, 70, 70), 2: (190, 153, 153), 3: (220, 20, 60),
        4: (153, 153, 153), 5: (128, 64, 128), 6: (244, 35, 232), 7: (107, 142, 35),
        8: (0, 0, 142), 9: (102, 102, 156), 10: (250, 170, 30),
    }
    for train_id, rgb in mapping.items():
        color_map[train_id] = rgb
    return color_map


SEG_COLOR_MAP = build_seg_color_map()


def decode_seg_map(label_mask: np.ndarray) -> np.ndarray:
    """Map predicted label IDs to RGB colors. Void (255) is rendered black."""
    safe_mask = label_mask.copy()
    safe_mask[safe_mask == 255] = 0
    rgb = SEG_COLOR_MAP[safe_mask].copy()
    rgb[label_mask == 255] = 0
    return rgb.astype(np.uint8)


# ---------------------------------------------------------------------------
# Task-Specific Inference
# ---------------------------------------------------------------------------
@torch.inference_mode()
def run_depth(model, inp_tensor, disp_img):
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        pred_padded = model(inp_tensor, {"upsample": True, "H": 448, "W": 448})
    pred_np = pred_padded.squeeze().float().cpu().numpy()
    depth_vis = vis_depth(pred_np, "magma")
    # Stitch the input and the prediction side-by-side
    combined_vis = np.hstack((disp_img, depth_vis))
    return combined_vis


@torch.inference_mode()
def run_segmentation(model, inp_tensor, disp_img):
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        pred_out = model(inp_tensor, {"upsample": True, "H": 448, "W": 448})
    pred_mask = torch.argmax(pred_out, dim=1).squeeze().cpu().numpy()
    seg_vis = decode_seg_map(pred_mask)
    # Stitch the input and the prediction side-by-side
    combined_vis = np.hstack((disp_img, seg_vis))
    return combined_vis
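
# Optional helper (illustrative, not wired into the UI): alpha-blend the decoded
# segmentation colors over the display image instead of stitching them side by
# side. The 0.5 opacity is an arbitrary choice for readability.
def overlay_segmentation(disp_img: np.ndarray, seg_vis: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend a decoded segmentation map onto the input visualization."""
    if disp_img.shape[:2] != seg_vis.shape[:2]:
        seg_vis = cv2.resize(seg_vis, (disp_img.shape[1], disp_img.shape[0]),
                             interpolation=cv2.INTER_NEAREST)
    blended = cv2.addWeighted(disp_img, 1.0 - alpha, seg_vis, alpha, 0.0)
    return blended.astype(np.uint8)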
@torch.inference_mode()
def run_matching(model, view1_tensor, view1_disp, view2_tensor, view2_disp):
    # Forward pass
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        out1, out2 = model({"view1": view1_tensor, "view2": view2_tensor}, {"H": 448, "W": 448})

    # --- 1. Correspondences Gallery Visualization ---
    try:
        disp1_bgr = cv2.cvtColor(view1_disp, cv2.COLOR_RGB2BGR)
        disp2_bgr = cv2.cvtColor(view2_disp, cv2.COLOR_RGB2BGR)
        data_in_vis = {
            "view1": disp1_bgr, "view2": disp2_bgr,
            "pred1": out1, "pred2": out2,
        }
        match_img_bgr = VisMast3r(data_in_vis, n_viz=50)
        match_img_rgb = cv2.cvtColor(match_img_bgr, cv2.COLOR_BGR2RGB)
        gallery_out = [match_img_rgb]
    except Exception as e:
        print(f"Visualization error: {e}")
        gallery_out = []

    # --- 2. 3D Point Cloud Generation (.glb) ---
    try:
        gr.Info("Generating 3D point cloud.")
        # Extract points in a common reference frame
        pts3d_1 = out1["pts3d"].squeeze(0).float().cpu().numpy().reshape(-1, 3)
        pts3d_2 = out2["pts3d_in_other_view"].squeeze(0).float().cpu().numpy().reshape(-1, 3)

        # Flatten the display images to assign colors to each point
        colors_1 = view1_disp.reshape(-1, 3)
        colors_2 = view2_disp.reshape(-1, 3)

        # Basic filtering to remove invalid background depths (Z <= 0)
        valid_1 = np.isfinite(pts3d_1).all(axis=1) & (pts3d_1[:, 2] > 0)
        valid_2 = np.isfinite(pts3d_2).all(axis=1) & (pts3d_2[:, 2] > 0)

        pts_combined = np.vstack((pts3d_1[valid_1], pts3d_2[valid_2]))
        colors_combined = np.vstack((colors_1[valid_1], colors_2[valid_2]))

        gr.Info("Downsampling point cloud for web rendering.")
        max_points = 50000  # Safe limit for smooth web rendering
        if len(pts_combined) > max_points:
            # Randomly select indices without replacement
            indices = np.random.choice(len(pts_combined), max_points, replace=False)
            pts_combined = pts_combined[indices]
            colors_combined = colors_combined[indices]

        # Build the Trimesh point cloud and rotate it 180 degrees about X so the
        # camera-frame points (y down, z forward) render upright in the glTF viewer.
        pc = trimesh.PointCloud(pts_combined, colors=colors_combined)
        scene = trimesh.Scene([pc])
        transform = np.array([
            [1, 0, 0, 0],
            [0, -1, 0, 0],
            [0, 0, -1, 0],
            [0, 0, 0, 1],
        ])
        scene.apply_transform(transform)

        tmp_file = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
        tmp_file.close()  # Close the handle before exporting so the write works on all platforms
        scene.export(tmp_file.name)
        glb_path = tmp_file.name
    except Exception as e:
        print(f"Point cloud export error: {e}")
        glb_path = None

    return glb_path, gallery_out
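
# Optional headless usage sketch (assumption: run outside the Gradio UI, e.g. for
# batch processing). It reuses get_model/get_input_data/run_depth exactly as the
# UI does; the output path is arbitrary.
def run_depth_headless(input_path: str, output_path: str = "depth_result.png") -> str:
    """Run the depth task on a single file and save the side-by-side visualization."""
    model = get_model("Depth")
    inp_tensor, disp_img = get_input_data(input_path)
    combined = run_depth(model, inp_tensor, disp_img)
    cv2.imwrite(output_path, cv2.cvtColor(combined, cv2.COLOR_RGB2BGR))
    return output_path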

# ---------------------------------------------------------------------------
# Gradio Router & UI Callbacks
# ---------------------------------------------------------------------------
def process_inference(task, view1_file, view2_file):
    if view1_file is None:
        raise gr.Error("Please upload View 1.")

    gr.Info("Loading model weights... This may take a few minutes if downloading for the first time.")
    model = get_model(task)

    # Automatically parse the input into a tensor and an RGB display image
    view1_tensor, view1_disp = get_input_data(view1_file)

    depth_out = gr.update(visible=False)
    seg_out = gr.update(visible=False)
    pc_out = gr.update(visible=False)
    match_out = gr.update(visible=False)

    if task == "Depth":
        res = run_depth(model, view1_tensor, view1_disp)
        depth_out = gr.update(value=res, visible=True)
    elif task == "Segmentation":
        res = run_segmentation(model, view1_tensor, view1_disp)
        seg_out = gr.update(value=res, visible=True)
    elif task == "Point Cloud & Correspondences":
        if view2_file is None:
            raise gr.Error("Matching requires View 2.")
        view2_tensor, view2_disp = get_input_data(view2_file)
        glb_path, matches = run_matching(model, view1_tensor, view1_disp, view2_tensor, view2_disp)
        pc_out = gr.update(value=glb_path, visible=True)
        match_out = gr.update(value=matches, visible=True)

    return depth_out, seg_out, pc_out, match_out


def update_ui_visibility(task):
    show_v2 = (task == "Point Cloud & Correspondences")
    return (
        gr.update(visible=show_v2),
        gr.update(visible=(task == "Depth")),
        gr.update(visible=(task == "Segmentation")),
        gr.update(visible=(task == "Point Cloud & Correspondences")),
        gr.update(visible=(task == "Point Cloud & Correspondences")),
    )


# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------
EXAMPLE_DATA = [
    {"task": "Depth", "v1": "examples/sample_depth.npz", "v2": None},
    {"task": "Segmentation", "v1": "examples/sample_seg.npz", "v2": None},
    {"task": "Point Cloud & Correspondences", "v1": "examples/matching_v1.jpg", "v2": "examples/matching_v2.npz"},
]

css = """
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
}
.match-gallery button, .match-gallery .image-container {
    aspect-ratio: 2 / 1 !important;
}
button.primary {background-color: #0E2841 !important; border-color: #0E2841 !important;}
.loading-note {font-size: 0.85em; color: #666; text-align: center; margin-top: 5px;}
.logo-container {
    display: flex;
    justify-content: center;
    align-items: center;
    margin-bottom: -15px;
}
.logo-img img {
    max-height: 120px !important;
    width: auto !important;
    object-fit: contain !important;
    background-color: transparent !important;
}
.example-gallery button, .example-gallery .image-container,
.match-gallery button, .match-gallery .image-container {
    aspect-ratio: 2 / 1 !important;
    border-radius: 8px !important;
    transition: transform 0.2s ease;
    cursor: pointer;
}
.example-gallery button:hover, .example-gallery .image-container:hover {
    aspect-ratio: 2 / 1 !important;
    transform: scale(1.02);
    border: 2px solid #0E2841 !important;
}
.example-gallery, .example-gallery * {
    scrollbar-width: none !important;  /* For Firefox */
    -ms-overflow-style: none !important;  /* For IE and Edge */
}
.example-gallery ::-webkit-scrollbar {
    display: none !important;  /* For Chrome, Safari, Opera */
}
.example-gallery .grid-container {
    overflow-x: hidden !important;  /* Prevents container scrolling */
}
"""

with gr.Blocks(title="REALM Inference Demo", css=css) as demo:
    with gr.Row(elem_classes="logo-container"):
        gr.Image(
            value="logo.png",
            show_label=False,
            interactive=False,
            container=False,
            elem_classes="logo-img"
        )

    gr.Markdown("An RGB and Event Aligned Latent Manifold for Cross-Modal Perception")
    gr.Markdown(
        """
        This interactive demo showcases the capabilities of the REALM model for depth estimation,
        semantic segmentation, and 3D correspondence matching using RGB and event data:
        📄 [papers.starslab.ca/realm](https://papers.starslab.ca/realm)
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            task_dropdown = gr.Dropdown(
                choices=["Depth", "Segmentation", "Point Cloud & Correspondences"],
                value="Depth",
                label="Select Inference Task",
                interactive=True
            )
            with gr.Row():
                view1_input = gr.File(label="View 1 (Image, .npz, or .npy)", file_types=["image", ".npz", ".npy"])
                view2_input = gr.File(label="View 2 (Image, .npz, or .npy)", file_types=["image", ".npz", ".npy"], visible=False)

            run_btn = gr.Button("Run Inference", variant="primary")

            gr.Markdown(
                "*Note: The depth estimation and semantic segmentation models are evaluated specifically on "
                "street-driving scenes from the MVSEC and DSEC datasets. In contrast, the dense feature matching "
                "and 3D correspondence pipeline generalizes to arbitrary scenes.*",
                elem_classes="loading-note"
            )
            gr.Markdown(
                "*Note: The interface may pause during the first run of each task while the model weights are "
                "downloaded from Hugging Face.*",
                elem_classes="loading-note"
            )

        with gr.Column(scale=1):
            depth_output = gr.Image(label="Predicted Depth", visible=True)
            seg_output = gr.Image(label="Segmentation Mask", visible=False)
            pc_output = gr.Model3D(label="3D Point Cloud", visible=False)
            match_gallery = gr.Gallery(label="Correspondences", columns=1, visible=False)

    gr.Markdown("### Try it out with Sample Data")
    example_gallery = gr.Gallery(
        value=[
            # Format: (Thumbnail Image Path, Display Caption)
            ("examples/thumbnail_depth.jpg", "Depth (Events)"),
            ("examples/thumbnail_seg.jpg", "Segmentation (Events)"),
            ("examples/thumbnail_match.jpg", "Matching (Events + RGB)"),
        ],
        allow_preview=False,  # Prevents opening the image in fullscreen
        columns=3,
        show_label=False,
        elem_classes="example-gallery"
    )

    # 1. Event handler to load the example files based on which thumbnail was clicked
    def load_example_data(evt: gr.SelectData):
        idx = evt.index
        data = EXAMPLE_DATA[idx]
        return data["task"], data["v1"], data["v2"]

    # 2. Chain the events: load the example, update visibility, then run inference
    example_gallery.select(
        fn=load_example_data,
        inputs=None,
        outputs=[task_dropdown, view1_input, view2_input]
    ).then(
        # Update the UI visibility (hide/show View 2)
        fn=update_ui_visibility,
        inputs=[task_dropdown],
        outputs=[view2_input, depth_output, seg_output, pc_output, match_gallery]
    ).then(
        # Run the model
        fn=process_inference,
        inputs=[task_dropdown, view1_input, view2_input],
        outputs=[depth_output, seg_output, pc_output, match_gallery]
    )

    # --- Event Listeners ---
    task_dropdown.change(
        fn=update_ui_visibility,
        inputs=[task_dropdown],
        outputs=[view2_input, depth_output, seg_output, pc_output, match_gallery]
    )

    run_btn.click(
        fn=process_inference,
        inputs=[task_dropdown, view1_input, view2_input],
        outputs=[depth_output, seg_output, pc_output, match_gallery]
    )

    # --- Credits & Dataset Acknowledgements Footer ---
    gr.HTML("""

        <div style="text-align: center;">
            <p>This interactive demonstration uses the DUNE encoder for image processing and the
            MASt3R decoder for dense feature matching and 3D correspondence.</p>
            <p>The sample data provided in the examples is derived from the following open-source datasets:
            VECtor Benchmark | MVSEC | DSEC</p>
        </div>
    """)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False
    )