import tempfile
from pathlib import Path

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import trimesh

from realm.model_factory import REALM_creator
from realm.utils.vis import VisMast3r, voxel_to_rgb_image

# ---------------------------------------------------------------------------
# Global State & Initialization
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_MODEL_CACHE = {}


def get_model(task_name: str):
    """Fetch the model from cache, or build it if it hasn't been loaded yet."""
    config_map = {
        "Depth": "depth",
        "Segmentation": "segmentation",
        "Point Cloud & Correspondences": "mast3r",
    }
    config_name = config_map.get(task_name)
    if not config_name:
        raise ValueError(f"Unknown task: {task_name}")

    if config_name not in _MODEL_CACHE:
        print(f"Loading REALM model for '{config_name}'...")
        config_path = Path("configs") / f"{config_name}.yaml"
        model = REALM_creator(config_path).to(device)
        model.eval()
        _MODEL_CACHE[config_name] = model
    return _MODEL_CACHE[config_name]


# ---------------------------------------------------------------------------
# Preprocessing & Visualization Helpers
# ---------------------------------------------------------------------------
def image_to_normalized_tensor(img_rgb: np.ndarray) -> torch.Tensor:
    """Convert an RGB numpy array [0, 255] to a normalized PyTorch tensor."""
    img_tensor = torch.from_numpy(img_rgb).float().permute(2, 0, 1) / 255.0
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (img_tensor - mean) / std


def parse_event_file(file_path: str) -> torch.Tensor:
    """Load a .npz or .npy event voxel and resize it to [1, C, 448, 448]."""
    if file_path.endswith('.npz'):
        with np.load(file_path) as data:
            voxel_np = data[data.files[0]]
    else:
        voxel_np = np.load(file_path)

    ev_tensor = torch.from_numpy(voxel_np).float()
    if ev_tensor.ndim == 3:
        ev_tensor = ev_tensor.unsqueeze(0)
    ev_tensor = F.interpolate(ev_tensor, size=(448, 448), mode="bilinear", align_corners=False)
    return ev_tensor
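
# Example of the assumed event-file layout (illustrative only; the file name, key name,
# and 260x346 sensor resolution are assumptions, not part of the repo):
#
#   np.savez("events.npz", voxel=np.random.randn(5, 260, 346).astype(np.float32))
#
# parse_event_file("events.npz") would then return a float tensor of shape [1, 5, 448, 448].
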
def get_input_data(file_obj):
    """Parse the upload and return a tuple: (model tensor, display image [H, W, 3] uint8)."""
    # gr.File may hand us either a tempfile-like object (with .name) or a plain path string,
    # depending on the Gradio version, so accept both.
    file_path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
    ext = Path(file_path).suffix.lower()

    if ext in ['.npz', '.npy']:
        ev_tensor = parse_event_file(file_path)
        # Convert the event voxel grid into an RGB display image for visualization
        disp_img = voxel_to_rgb_image(ev_tensor.squeeze(0), 1.0)
        if isinstance(disp_img, torch.Tensor):
            disp_img = disp_img.cpu().numpy()
        if disp_img.max() <= 1.0:
            disp_img = disp_img * 255.0
        disp_img = np.clip(disp_img, 0, 255).astype(np.uint8)
        return ev_tensor.to(device), disp_img
    else:
        img_bgr = cv2.imread(file_path)
        if img_bgr is None:
            raise gr.Error(f"Could not read the uploaded image: {Path(file_path).name}")
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        img_resized = cv2.resize(img_rgb, (448, 448))
        tensor = image_to_normalized_tensor(img_resized).unsqueeze(0).to(device)
        return tensor, img_resized


def vis_depth(depth_map: np.ndarray, cmap_name: str = "magma") -> np.ndarray:
    """Visualise a depth map with logarithmic scaling."""
    valid = (depth_map > 1e-6) & np.isfinite(depth_map)
    if valid.sum() == 0:
        return np.zeros((*depth_map.shape, 3), dtype=np.uint8)

    log_depth = np.zeros_like(depth_map)
    log_depth[valid] = np.log(depth_map[valid])
    log_min = np.percentile(log_depth[valid], 2)
    log_max = np.percentile(log_depth[valid], 95)
    log_clipped = np.clip(log_depth, log_min, log_max)

    norm = np.zeros_like(depth_map)
    if log_max > log_min:
        norm[valid] = (log_clipped[valid] - log_min) / (log_max - log_min)
    else:
        norm[valid] = 0.5

    colored = plt.get_cmap(cmap_name)(norm)[:, :, :3]
    colored[~valid] = 0.0
    return (colored * 255).astype(np.uint8)


def build_seg_color_map() -> np.ndarray:
    """Hardcoded 11-class Cityscapes color map."""
    color_map = np.zeros((256, 3), dtype=np.uint8)
    mapping = {
        0: (70, 130, 180),
        1: (70, 70, 70),
        2: (190, 153, 153),
        3: (220, 20, 60),
        4: (153, 153, 153),
        5: (128, 64, 128),
        6: (244, 35, 232),
        7: (107, 142, 35),
        8: (0, 0, 142),
        9: (102, 102, 156),
        10: (250, 170, 30),
    }
    for train_id, rgb in mapping.items():
        color_map[train_id] = rgb
    return color_map


SEG_COLOR_MAP = build_seg_color_map()


def decode_seg_map(label_mask: np.ndarray) -> np.ndarray:
    """Map predicted label IDs to RGB colors. Void (255) is rendered black."""
    safe_mask = label_mask.copy()
    safe_mask[safe_mask == 255] = 0
    rgb = SEG_COLOR_MAP[safe_mask].copy()
    rgb[label_mask == 255] = 0
    return rgb.astype(np.uint8)


# ---------------------------------------------------------------------------
# Task-Specific Inference
# ---------------------------------------------------------------------------
@torch.inference_mode()
def run_depth(model, inp_tensor, disp_img):
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        pred_padded = model(inp_tensor, {"upsample": True, "H": 448, "W": 448})

    pred_np = pred_padded.squeeze().float().cpu().numpy()
    depth_vis = vis_depth(pred_np, "magma")
    # Stitch the input and the prediction side-by-side
    combined_vis = np.hstack((disp_img, depth_vis))
    return combined_vis


@torch.inference_mode()
def run_segmentation(model, inp_tensor, disp_img):
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        pred_out = model(inp_tensor, {"upsample": True, "H": 448, "W": 448})

    pred_mask = torch.argmax(pred_out, dim=1).squeeze().cpu().numpy()
    seg_vis = decode_seg_map(pred_mask)
    # Stitch the input and the prediction side-by-side
    combined_vis = np.hstack((disp_img, seg_vis))
    return combined_vis


@torch.inference_mode()
def run_matching(model, view1_tensor, view1_disp, view2_tensor, view2_disp):
    # Forward pass
    with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        out1, out2 = model({'view1': view1_tensor, 'view2': view2_tensor}, {'H': 448, 'W': 448})

    # --- 1. Correspondences Gallery Visualization ---
    try:
        disp1_bgr = cv2.cvtColor(view1_disp, cv2.COLOR_RGB2BGR)
        disp2_bgr = cv2.cvtColor(view2_disp, cv2.COLOR_RGB2BGR)
        data_in_vis = {'view1': disp1_bgr, 'view2': disp2_bgr, 'pred1': out1, 'pred2': out2}
        match_img_bgr = VisMast3r(data_in_vis, n_viz=50)
        match_img_rgb = cv2.cvtColor(match_img_bgr, cv2.COLOR_BGR2RGB)
        gallery_out = [match_img_rgb]
    except Exception as e:
        print(f"Visualization error: {e}")
        gallery_out = []
    # --- 2. 3D Point Cloud Generation (.glb) ---
    try:
        gr.Info("Generating 3D point cloud.")
        # Extract points in a common reference frame
        pts3d_1 = out1['pts3d'].squeeze(0).float().cpu().numpy().reshape(-1, 3)
        pts3d_2 = out2['pts3d_in_other_view'].squeeze(0).float().cpu().numpy().reshape(-1, 3)

        # Flatten the display images to assign colors to each point
        colors_1 = view1_disp.reshape(-1, 3)
        colors_2 = view2_disp.reshape(-1, 3)

        # Basic filtering to remove invalid background depths (Z <= 0)
        valid_1 = np.isfinite(pts3d_1).all(axis=1) & (pts3d_1[:, 2] > 0)
        valid_2 = np.isfinite(pts3d_2).all(axis=1) & (pts3d_2[:, 2] > 0)

        pts_combined = np.vstack((pts3d_1[valid_1], pts3d_2[valid_2]))
        colors_combined = np.vstack((colors_1[valid_1], colors_2[valid_2]))

        gr.Info("Downsampling point cloud for web rendering.")
        max_points = 50000  # Safe limit for smooth web rendering
        if len(pts_combined) > max_points:
            # Randomly select indices without replacement
            indices = np.random.choice(len(pts_combined), max_points, replace=False)
            pts_combined = pts_combined[indices]
            colors_combined = colors_combined[indices]

        # -----------------------------------------------------------
        # Build Trimesh point cloud
        pc = trimesh.PointCloud(pts_combined, colors=colors_combined)
        scene = trimesh.Scene([pc])

        # Flip Y and Z so camera-frame points (+Y down, +Z forward) match the
        # glTF viewer convention (+Y up, -Z forward)
        transform = np.array([
            [1, 0, 0, 0],
            [0, -1, 0, 0],
            [0, 0, -1, 0],
            [0, 0, 0, 1],
        ])
        scene.apply_transform(transform)

        tmp_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
        tmp_file.close()  # Close the handle so trimesh can write to the path
        scene.export(tmp_file.name)
        glb_path = tmp_file.name
    except Exception as e:
        print(f"Point cloud export error: {e}")
        glb_path = None

    return glb_path, gallery_out


# ---------------------------------------------------------------------------
# Gradio Router & UI Callbacks
# ---------------------------------------------------------------------------
def process_inference(task, view1_file, view2_file):
    gr.Info("Loading model weights... This may take a few minutes if downloading for the first time.")
    model = get_model(task)

    if view1_file is None:
        raise gr.Error("Please upload View 1.")

    # Automatically parse the input into a tensor and an RGB display image
    view1_tensor, view1_disp = get_input_data(view1_file)

    depth_out = gr.update(visible=False)
    seg_out = gr.update(visible=False)
    pc_out = gr.update(visible=False)
    match_out = gr.update(visible=False)

    if task == "Depth":
        res = run_depth(model, view1_tensor, view1_disp)
        depth_out = gr.update(value=res, visible=True)
    elif task == "Segmentation":
        res = run_segmentation(model, view1_tensor, view1_disp)
        seg_out = gr.update(value=res, visible=True)
    elif task == "Point Cloud & Correspondences":
        if view2_file is None:
            raise gr.Error("Matching requires View 2.")
        view2_tensor, view2_disp = get_input_data(view2_file)
        glb_path, matches = run_matching(model, view1_tensor, view1_disp, view2_tensor, view2_disp)
        pc_out = gr.update(value=glb_path, visible=True)
        match_out = gr.update(value=matches, visible=True)

    return depth_out, seg_out, pc_out, match_out


def update_ui_visibility(task):
    show_v2 = (task == "Point Cloud & Correspondences")
    return (
        gr.update(visible=show_v2),
        gr.update(visible=(task == "Depth")),
        gr.update(visible=(task == "Segmentation")),
        gr.update(visible=(task == "Point Cloud & Correspondences")),
        gr.update(visible=(task == "Point Cloud & Correspondences")),
    )


# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------
EXAMPLE_DATA = [
    {"task": "Depth", "v1": "examples/sample_depth.npz", "v2": None},
    {"task": "Segmentation", "v1": "examples/sample_seg.npz", "v2": None},
    {"task": "Point Cloud & Correspondences", "v1": "examples/matching_v1.jpg", "v2": "examples/matching_v2.npz"},
]

css = """
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
}
.match-gallery button, .match-gallery .image-container {
    aspect-ratio: 2 / 1 !important;
}
button.primary {background-color: #0E2841 !important; border-color: #0E2841 !important;}
.loading-note {font-size: 0.85em; color: #666; text-align: center; margin-top: 5px;}
.logo-container {
    display: flex;
    justify-content: center;
    align-items: center;
    margin-bottom: -15px;
}
.logo-img img {
    max-height: 120px !important;
    width: auto !important;
    object-fit: contain !important;
    background-color: transparent !important;
}
.example-gallery button, .example-gallery .image-container,
.match-gallery button, .match-gallery .image-container {
    aspect-ratio: 2 / 1 !important;
    border-radius: 8px !important;
    transition: transform 0.2s ease;
    cursor: pointer;
}
.example-gallery button:hover, .example-gallery .image-container:hover {
    aspect-ratio: 2 / 1 !important;
    transform: scale(1.02);
    border: 2px solid #0E2841 !important;
}
.example-gallery, .example-gallery * {
    scrollbar-width: none !important;      /* For Firefox */
    -ms-overflow-style: none !important;   /* For IE and Edge */
}
.example-gallery ::-webkit-scrollbar {
    display: none !important;              /* For Chrome, Safari, Opera */
}
.example-gallery .grid-container {
    overflow-x: hidden !important;         /* Prevents container scrolling */
}
"""

# Pass the custom CSS to gr.Blocks so it is applied to the page
with gr.Blocks(title="REALM Inference Demo", css=css) as demo:
    with gr.Row(elem_classes="logo-container"):
        gr.Image(
            value="logo.png",
            show_label=False,
            interactive=False,
            container=False,
            elem_classes="logo-img",
        )

    gr.Markdown(
        """
This interactive demo showcases the capabilities of the REALM model for depth estimation,
semantic segmentation, and 3D correspondence matching using RGB and event data:
📄 papers.starslab.ca/realm
        """
""" ) with gr.Row(): with gr.Column(scale=1): task_dropdown = gr.Dropdown( choices=["Depth", "Segmentation", "Point Cloud & Correspondences"], value="Depth", label="Select Inference Task", interactive=True ) with gr.Row(): view1_input = gr.File(label="View 1 (Image, .npz, or .npy)", file_types=["image", ".npz", ".npy"]) view2_input = gr.File(label="View 2 (Image, .npz, or .npy)", file_types=["image", ".npz", ".npy"], visible=False) run_btn = gr.Button("Run Inference", variant="primary") gr.Markdown("*Note: The depth estimation and semantic segmentation models are specifically evaluated on street-driving scenarios from the MVSEC and DSEC environments. In contrast, the dense feature matching and 3D correspondence pipeline is fully generalizable and works across any arbitrary scene.*", elem_classes="loading-note") gr.Markdown("*Note: The interface may pause during the first run of each task while heavy model weights are downloaded from Hugging Face.*", elem_classes="loading-note") with gr.Column(scale=1): depth_output = gr.Image(label="Predicted Depth", visible=True) seg_output = gr.Image(label="Segmentation Mask", visible=False) pc_output = gr.Model3D(label="3D Point Cloud", visible=False) match_gallery = gr.Gallery(label="Correspondences", columns=1, visible=False) gr.Markdown("### Try it out with Sample Data") example_gallery = gr.Gallery( value=[ # Format: (Thumbnail Image Path, Display Caption) ("examples/thumbnail_depth.jpg", "Depth (Events)"), ("examples/thumbnail_seg.jpg", "Segmentation (Events)"), ("examples/thumbnail_match.jpg", "Matching (Events + RGB)") ], allow_preview=False, # Prevents opening the image in fullscreen columns=3, show_label=False, elem_classes="example-gallery" ) # 1. Event Handler to load the files based on what was clicked def load_example_data(evt: gr.SelectData): idx = evt.index data = EXAMPLE_DATA[idx] return data["task"], data["v1"], data["v2"] # 2. Chain the events together example_gallery.select( fn=load_example_data, inputs=None, outputs=[task_dropdown, view1_input, view2_input] ).then( # Update the UI visibility (hide/show View 2) fn=update_ui_visibility, inputs=[task_dropdown], outputs=[view2_input, depth_output, seg_output, pc_output, match_gallery] ).then( # Run the model fn=process_inference, inputs=[task_dropdown, view1_input, view2_input], outputs=[depth_output, seg_output, pc_output, match_gallery] ) # --- Standard Event Listeners (Keep these exactly as they are) --- task_dropdown.change( fn=update_ui_visibility, inputs=[task_dropdown], outputs=[view2_input, depth_output, seg_output, pc_output, match_gallery] ) # --- Event Listeners --- task_dropdown.change( fn=update_ui_visibility, inputs=[task_dropdown], outputs=[view2_input, depth_output, seg_output, pc_output, match_gallery] ) run_btn.click( fn=process_inference, inputs=[task_dropdown, view1_input, view2_input], outputs=[depth_output, seg_output, pc_output, match_gallery] ) # --- Credits & Dataset Acknowledgements Footer --- gr.HTML(""" """) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, css=css, ssr_mode=False)