# REALM-demo / app.py
import gradio as gr
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import tempfile
import trimesh
from realm.model_factory import REALM_creator
from realm.utils.vis import VisMast3r, voxel_to_rgb_image
# ---------------------------------------------------------------------------
# Global State & Initialization
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
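# One model instance is cached per task config and loaded lazily on first use.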
_MODEL_CACHE = {}
def get_model(task_name: str):
"""Fetch the model from cache, or build it if it hasn't been loaded yet."""
config_map = {
"Depth": "depth",
"Segmentation": "segmentation",
"Point Cloud & Correspondences": "mast3r"
}
config_name = config_map.get(task_name)
if not config_name:
raise ValueError(f"Unknown task: {task_name}")
if config_name not in _MODEL_CACHE:
print(f"Loading REALM model for '{config_name}'...")
config_path = Path("configs") / f"{config_name}.yaml"
model = REALM_creator(config_path).to(device)
model.eval()
_MODEL_CACHE[config_name] = model
return _MODEL_CACHE[config_name]
# ---------------------------------------------------------------------------
# Preprocessing & Visualization Helpers
# ---------------------------------------------------------------------------
def image_to_normalized_tensor(img_rgb: np.ndarray) -> torch.Tensor:
"""Convert an RGB numpy array [0, 255] to a normalized PyTorch tensor."""
img_tensor = torch.from_numpy(img_rgb).float().permute(2, 0, 1) / 255.0
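    # Standard ImageNet normalization statistics (mean / std).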
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
return (img_tensor - mean) / std
def parse_event_file(file_path: str) -> torch.Tensor:
"""Load a .npz or .npy event voxel and resize it to [1, C, 448, 448]."""
if file_path.endswith('.npz'):
with np.load(file_path) as data:
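            # Use the first array stored in the archive as the voxel grid.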
voxel_np = data[data.files[0]]
else:
voxel_np = np.load(file_path)
ev_tensor = torch.from_numpy(voxel_np).float()
if ev_tensor.ndim == 3:
ev_tensor = ev_tensor.unsqueeze(0)
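    # Bilinearly resize the voxel grid to the 448x448 input resolution expected by the model.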
ev_tensor = F.interpolate(ev_tensor, size=(448, 448), mode="bilinear", align_corners=False)
return ev_tensor
def get_input_data(file_obj):
"""
    Parse the upload and return (model input tensor, display image as [H, W, 3] uint8).
"""
file_path = file_obj.name
ext = Path(file_path).suffix.lower()
if ext in ['.npz', '.npy']:
ev_tensor = parse_event_file(file_path)
# Convert the event voxel grid into an RGB display image for visualization
disp_img = voxel_to_rgb_image(ev_tensor.squeeze(0), 1.0)
if isinstance(disp_img, torch.Tensor):
disp_img = disp_img.cpu().numpy()
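        # Rescale to 0-255 uint8 if the visualizer returned values in [0, 1].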
if disp_img.max() <= 1.0:
disp_img = disp_img * 255.0
disp_img = np.clip(disp_img, 0, 255).astype(np.uint8)
return ev_tensor.to(device), disp_img
else:
img_bgr = cv2.imread(file_path)
if img_bgr is None:
raise gr.Error(f"Could not read the uploaded image: {Path(file_path).name}")
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img_resized = cv2.resize(img_rgb, (448, 448))
tensor = image_to_normalized_tensor(img_resized).unsqueeze(0).to(device)
return tensor, img_resized
def vis_depth(depth_map: np.ndarray, cmap_name: str = "magma") -> np.ndarray:
"""Visualise a depth map with logarithmic scaling."""
valid = (depth_map > 1e-6) & np.isfinite(depth_map)
if valid.sum() == 0:
return np.zeros((*depth_map.shape, 3), dtype=np.uint8)
log_depth = np.zeros_like(depth_map)
log_depth[valid] = np.log(depth_map[valid])
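    # Clip to the 2nd-95th percentile of log-depth to suppress outliers before normalizing.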
log_min = np.percentile(log_depth[valid], 2)
log_max = np.percentile(log_depth[valid], 95)
log_clipped = np.clip(log_depth, log_min, log_max)
norm = np.zeros_like(depth_map)
if log_max > log_min:
norm[valid] = (log_clipped[valid] - log_min) / (log_max - log_min)
else:
norm[valid] = 0.5
colored = plt.get_cmap(cmap_name)(norm)[:, :, :3]
colored[~valid] = 0.0
return (colored * 255).astype(np.uint8)
def build_seg_color_map() -> np.ndarray:
"""Hardcoded 11-class Cityscapes color map."""
color_map = np.zeros((256, 3), dtype=np.uint8)
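    # Train IDs follow an 11-class Cityscapes-style palette (e.g. road = (128, 64, 128), car = (0, 0, 142)).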
mapping = {
0: (70, 130, 180), 1: (70, 70, 70), 2: (190, 153, 153),
3: (220, 20, 60), 4: (153, 153, 153), 5: (128, 64, 128),
6: (244, 35, 232), 7: (107, 142, 35), 8: (0, 0, 142),
9: (102, 102, 156), 10: (250, 170, 30)
}
for train_id, rgb in mapping.items():
color_map[train_id] = rgb
return color_map
SEG_COLOR_MAP = build_seg_color_map()
def decode_seg_map(label_mask: np.ndarray) -> np.ndarray:
"""Map predicted label IDs to RGB colors. Void (255) is rendered black."""
safe_mask = label_mask.copy()
safe_mask[safe_mask == 255] = 0
rgb = SEG_COLOR_MAP[safe_mask].copy()
rgb[label_mask == 255] = 0
return rgb.astype(np.uint8)
# ---------------------------------------------------------------------------
# Task-Specific Inference
# ---------------------------------------------------------------------------
@torch.inference_mode()
def run_depth(model, inp_tensor, disp_img):
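    # Autocast runs the forward pass in mixed precision to reduce memory use and latency.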
with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
pred_padded = model(inp_tensor, {"upsample": True, "H": 448, "W": 448})
pred_np = pred_padded.squeeze().float().cpu().numpy()
depth_vis = vis_depth(pred_np, "magma")
# Stitch the input and the prediction side-by-side
combined_vis = np.hstack((disp_img, depth_vis))
return combined_vis
@torch.inference_mode()
def run_segmentation(model, inp_tensor, disp_img):
with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
pred_out = model(inp_tensor, {"upsample": True, "H": 448, "W": 448})
pred_mask = torch.argmax(pred_out, dim=1).squeeze().cpu().numpy()
seg_vis = decode_seg_map(pred_mask)
# Stitch the input and the prediction side-by-side
combined_vis = np.hstack((disp_img, seg_vis))
return combined_vis
@torch.inference_mode()
def run_matching(model, view1_tensor, view1_disp, view2_tensor, view2_disp):
# Forward Pass
with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu"):
out1, out2 = model({'view1': view1_tensor, 'view2': view2_tensor}, {'H': 448, 'W': 448})
# --- 1. Correspondences Gallery Visualization ---
try:
disp1_bgr = cv2.cvtColor(view1_disp, cv2.COLOR_RGB2BGR)
disp2_bgr = cv2.cvtColor(view2_disp, cv2.COLOR_RGB2BGR)
data_in_vis = {
'view1': disp1_bgr,
'view2': disp2_bgr,
'pred1': out1,
'pred2': out2
}
match_img_bgr = VisMast3r(data_in_vis, n_viz=50)
match_img_rgb = cv2.cvtColor(match_img_bgr, cv2.COLOR_BGR2RGB)
gallery_out = [match_img_rgb]
except Exception as e:
print(f"Visualization error: {e}")
gallery_out = []
# --- 2. 3D Point Cloud Generation (.glb) ---
try:
gr.Info("Generating 3D point cloud.")
# Extract points in a common reference frame
pts3d_1 = out1['pts3d'].squeeze(0).float().cpu().numpy().reshape(-1, 3)
pts3d_2 = out2['pts3d_in_other_view'].squeeze(0).float().cpu().numpy().reshape(-1, 3)
# Flatten the display images to assign colors to each point
colors_1 = view1_disp.reshape(-1, 3)
colors_2 = view2_disp.reshape(-1, 3)
# Basic filtering to remove invalid background depths (Z <= 0)
valid_1 = np.isfinite(pts3d_1).all(axis=1) & (pts3d_1[:, 2] > 0)
valid_2 = np.isfinite(pts3d_2).all(axis=1) & (pts3d_2[:, 2] > 0)
pts_combined = np.vstack((pts3d_1[valid_1], pts3d_2[valid_2]))
colors_combined = np.vstack((colors_1[valid_1], colors_2[valid_2]))
gr.Info("Downsampling point cloud for web rendering.")
max_points = 50000 # Safe limit for smooth web rendering
if len(pts_combined) > max_points:
# Randomly select indices without replacement
indices = np.random.choice(len(pts_combined), max_points, replace=False)
pts_combined = pts_combined[indices]
colors_combined = colors_combined[indices]
# -----------------------------------------------------------
# Build Trimesh point cloud
pc = trimesh.PointCloud(pts_combined, colors=colors_combined)
scene = trimesh.Scene([pc])
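        # Rotate 180 degrees about the X axis (flip Y and Z) so the camera-frame points display upright in the glTF viewer.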
transform = np.array([
[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1]
])
scene.apply_transform(transform)
tmp_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
scene.export(tmp_file.name)
glb_path = tmp_file.name
except Exception as e:
print(f"Point cloud export error: {e}")
glb_path = None
return glb_path, gallery_out
# ---------------------------------------------------------------------------
# Gradio Router & UI Callbacks
# ---------------------------------------------------------------------------
def process_inference(task, view1_file, view2_file):
gr.Info("Loading model weights... This may take a few minutes if downloading for the first time.")
model = get_model(task)
if view1_file is None:
raise gr.Error("Please upload View 1.")
# Automatically parse the input into a tensor and an RGB display image
view1_tensor, view1_disp = get_input_data(view1_file)
depth_out = gr.update(visible=False)
seg_out = gr.update(visible=False)
pc_out = gr.update(visible=False)
match_out = gr.update(visible=False)
if task == "Depth":
res = run_depth(model, view1_tensor, view1_disp)
depth_out = gr.update(value=res, visible=True)
elif task == "Segmentation":
res = run_segmentation(model, view1_tensor, view1_disp)
seg_out = gr.update(value=res, visible=True)
elif task == "Point Cloud & Correspondences":
if view2_file is None:
raise gr.Error("Matching requires View 2.")
view2_tensor, view2_disp = get_input_data(view2_file)
glb_path, matches = run_matching(model, view1_tensor, view1_disp, view2_tensor, view2_disp)
pc_out = gr.update(value=glb_path, visible=True)
match_out = gr.update(value=matches, visible=True)
return depth_out, seg_out, pc_out, match_out
def update_ui_visibility(task):
show_v2 = (task == "Point Cloud & Correspondences")
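    # Returned updates map, in order, to: view2_input, depth_output, seg_output, pc_output, match_gallery.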
return (
gr.update(visible=show_v2),
gr.update(visible=(task == "Depth")),
gr.update(visible=(task == "Segmentation")),
gr.update(visible=(task == "Point Cloud & Correspondences")),
gr.update(visible=(task == "Point Cloud & Correspondences"))
)
# ---------------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------------
EXAMPLE_DATA = [
{
"task": "Depth",
"v1": "examples/sample_depth.npz",
"v2": None
},
{
"task": "Segmentation",
"v1": "examples/sample_seg.npz",
"v2": None
},
{
"task": "Point Cloud & Correspondences",
"v1": "examples/matching_v1.jpg",
"v2": "examples/matching_v2.npz"
}
]
css = """
.gradio-container {
max-width: 1200px !important;
margin: 0 auto !important;
}
button.primary {background-color: #0E2841 !important; border-color: #0E2841 !important;}
.loading-note {font-size: 0.85em; color: #666; text-align: center; margin-top: 5px;}
.logo-container {
display: flex;
justify-content: center;
align-items: center;
margin-bottom: -15px;
}
.logo-img img {
max-height: 120px !important;
width: auto !important;
object-fit: contain !important;
background-color: transparent !important;
}
.example-gallery button, .example-gallery .image-container,
.match-gallery button, .match-gallery .image-container {
aspect-ratio: 2 / 1 !important;
border-radius: 8px !important;
transition: transform 0.2s ease;
cursor: pointer;
}
.example-gallery button:hover, .example-gallery .image-container:hover {
aspect-ratio: 2 / 1 !important;
transform: scale(1.02);
border: 2px solid #0E2841 !important;
}
.example-gallery, .example-gallery * {
scrollbar-width: none !important; /* For Firefox */
-ms-overflow-style: none !important; /* For IE and Edge */
}
.example-gallery ::-webkit-scrollbar {
display: none !important; /* For Chrome, Safari, Opera */
}
.example-gallery .grid-container {
overflow-x: hidden !important; /* Prevents container scrolling */
}
"""
with gr.Blocks(title="REALM Inference Demo", css=css) as demo:
with gr.Row(elem_classes="logo-container"):
gr.Image(
value="logo.png",
show_label=False,
interactive=False,
container=False,
elem_classes="logo-img"
)
gr.Markdown("<h1 style='text-align: center; color: #FFFFF; margin-top: 0;'>An RGB and Event Aligned Latent Manifold for Cross-Modal Perception</h1>")
gr.Markdown(
"""
<p style='text-align: center; font-size: 1.1em; margin-top: 0;'>
This interactive demo showcases the capabilities of the REALM model for depth estimation, semantic segmentation, and 3D correspondence matching using RGB and event data:
<a href='https://papers.starslab.ca/realm' target='_blank'>
📄 papers.starslab.ca/realm
</a>
</p>
"""
)
with gr.Row():
with gr.Column(scale=1):
task_dropdown = gr.Dropdown(
choices=["Depth", "Segmentation", "Point Cloud & Correspondences"],
value="Depth",
label="Select Inference Task",
interactive=True
)
with gr.Row():
view1_input = gr.File(label="View 1 (Image, .npz, or .npy)", file_types=["image", ".npz", ".npy"])
view2_input = gr.File(label="View 2 (Image, .npz, or .npy)", file_types=["image", ".npz", ".npy"], visible=False)
run_btn = gr.Button("Run Inference", variant="primary")
gr.Markdown("*Note: The depth estimation and semantic segmentation models are specifically evaluated on street-driving scenarios from the MVSEC and DSEC environments. In contrast, the dense feature matching and 3D correspondence pipeline is fully generalizable and works across any arbitrary scene.*", elem_classes="loading-note")
gr.Markdown("*Note: The interface may pause during the first run of each task while heavy model weights are downloaded from Hugging Face.*", elem_classes="loading-note")
with gr.Column(scale=1):
depth_output = gr.Image(label="Predicted Depth", visible=True)
seg_output = gr.Image(label="Segmentation Mask", visible=False)
pc_output = gr.Model3D(label="3D Point Cloud", visible=False)
match_gallery = gr.Gallery(label="Correspondences", columns=1, visible=False)
gr.Markdown("### Try it out with Sample Data")
example_gallery = gr.Gallery(
value=[
# Format: (Thumbnail Image Path, Display Caption)
("examples/thumbnail_depth.jpg", "Depth (Events)"),
("examples/thumbnail_seg.jpg", "Segmentation (Events)"),
("examples/thumbnail_match.jpg", "Matching (Events + RGB)")
],
allow_preview=False, # Prevents opening the image in fullscreen
columns=3,
show_label=False,
elem_classes="example-gallery"
)
# 1. Event Handler to load the files based on what was clicked
def load_example_data(evt: gr.SelectData):
idx = evt.index
data = EXAMPLE_DATA[idx]
return data["task"], data["v1"], data["v2"]
# 2. Chain the events together
example_gallery.select(
fn=load_example_data,
inputs=None,
outputs=[task_dropdown, view1_input, view2_input]
).then(
# Update the UI visibility (hide/show View 2)
fn=update_ui_visibility,
inputs=[task_dropdown],
outputs=[view2_input, depth_output, seg_output, pc_output, match_gallery]
).then(
# Run the model
fn=process_inference,
inputs=[task_dropdown, view1_input, view2_input],
outputs=[depth_output, seg_output, pc_output, match_gallery]
)
    # --- Event Listeners ---
    task_dropdown.change(
        fn=update_ui_visibility,
        inputs=[task_dropdown],
        outputs=[view2_input, depth_output, seg_output, pc_output, match_gallery]
    )
run_btn.click(
fn=process_inference,
inputs=[task_dropdown, view1_input, view2_input],
outputs=[depth_output, seg_output, pc_output, match_gallery]
)
# --- Credits & Dataset Acknowledgements Footer ---
gr.HTML("""
<div style="text-align: center; margin-top: 40px; padding: 20px; border-top: 1px solid #e5e7eb; color: #6b7280; font-size: 1.0em;">
<p>This interactive demonstration utilizes the <strong><a href="https://github.com/naver/dune">DUNE</a></strong> encoder for image processing and the <strong><a href="https://github.com/naver/mast3r">MAST3R</a></strong> decoder for dense feature matching and 3D correspondence.</p>
<p>The sample data provided in the examples are derived from the following open-source datasets:</p>
<div style="display: flex; justify-content: center; gap: 20px; margin-top: 10px;">
<a href="https://star-datasets.github.io/vector/" target="_blank" style="color: #FFFFF; text-decoration: none; font-weight: 500;">VECtor Benchmark</a> |
<a href="https://daniilidis-group.github.io/mvsec/" target="_blank" style="color: #FFFFF; text-decoration: none; font-weight: 500;">MVSEC</a> |
<a href="https://github.com/uzh-rpg/DSEC" target="_blank" style="color: #FFFFF; text-decoration: none; font-weight: 500;">DSEC</a>
</div>
</div>
""")
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False
    )