"""Gradio app for LingBot-Map: reconstruct a navigable 3D scene from a short video clip."""

import contextlib
import colorsys
import html
import json
import os
import shutil
import tempfile
import threading
import time
import zipfile
from pathlib import Path
from typing import Any

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
from scipy.spatial.transform import Rotation

try:
    import spaces
except ImportError:
    # Outside a ZeroGPU Space the `spaces` package is unavailable; this shim
    # turns `@spaces.GPU` (bare or parameterized) into a no-op decorator.
    class _SpacesShim:
        @staticmethod
        def GPU(*decorator_args, **decorator_kwargs):
            if decorator_args and callable(decorator_args[0]) and len(decorator_args) == 1 and not decorator_kwargs:
                return decorator_args[0]

            def _wrap(func):
                return func

            return _wrap

    spaces = _SpacesShim()

from lingbot_map.models.gct_stream import GCTStream
from lingbot_map.utils.geometry import closed_form_inverse_se3_general
from lingbot_map.utils.load_fn import load_and_preprocess_images
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
from lingbot_map.vis.glb_export import predictions_to_glb

ROOT = Path(__file__).resolve().parent
OUTPUT_ROOT = ROOT / "app_output"
OUTPUT_ROOT.mkdir(exist_ok=True)

HF_MODEL_REPO = "robbyant/lingbot-map"
MODEL_FILENAME = "lingbot-map.pt"
IMAGE_SIZE = 518
PATCH_SIZE = 14
DEFAULT_FPS = 8
DEFAULT_MAX_FRAMES = 24
MAX_FRAMES_HARD_LIMIT = 24
DEFAULT_SCALE_FRAMES = 4
DEFAULT_KEYFRAME_INTERVAL = 2
DEFAULT_CONF_PERCENTILE = 50.0
DEFAULT_CAMERA_ITERATIONS = 1
MAX_VISER_POINTS = 25_000

IS_SPACE_RUNTIME = bool(os.getenv("SPACE_ID"))
SKIP_EAGER_MODEL_LOAD = os.getenv("LINGBOT_SPACE_SKIP_MODEL_LOAD") == "1"

MODEL_CACHE: dict[str, Any] = {}
MODEL_CACHE_LOCK = threading.Lock()
STARTUP_NOTES: list[str] = []


def _resolve_path(file_obj: Any) -> str:
    if file_obj is None:
        return ""
    if isinstance(file_obj, str):
        return file_obj
    return getattr(file_obj, "name", "")


def _cleanup_old_runs(keep_last: int = 8) -> None:
    run_dirs = sorted([p for p in OUTPUT_ROOT.iterdir() if p.is_dir()], key=lambda p: p.stat().st_mtime)
    for stale_dir in run_dirs[:-keep_last]:
        shutil.rmtree(stale_dir, ignore_errors=True)


def _pick_runtime_device() -> torch.device:
    # Allocating a tiny tensor is a more reliable CUDA probe than
    # torch.cuda.is_available() on ZeroGPU hosts.
    try:
        torch.empty(1, device="cuda")
        return torch.device("cuda")
    except Exception:
        return torch.device("cpu")


def _load_model_bundle() -> dict[str, Any]:
    with MODEL_CACHE_LOCK:
        cached = MODEL_CACHE.get("default")
        if cached is not None:
            return cached
        device = _pick_runtime_device()
        weight_name = MODEL_FILENAME
        weight_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=weight_name)
        model = GCTStream(
            img_size=IMAGE_SIZE,
            patch_size=PATCH_SIZE,
            enable_3d_rope=True,
            max_frame_num=1024,
            kv_cache_sliding_window=64,
            kv_cache_scale_frames=8,
            kv_cache_cross_frame_special=True,
            kv_cache_include_scale_frames=True,
            use_sdpa=True,
            camera_num_iterations=DEFAULT_CAMERA_ITERATIONS,
        )
        checkpoint = torch.load(weight_path, map_location="cpu", weights_only=False)
        state_dict = checkpoint.get("model", checkpoint)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        model = model.to(device).eval()
        inference_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
        if device.type == "cuda" and getattr(model, "aggregator", None) is not None:
            model.aggregator = model.aggregator.to(dtype=inference_dtype)
        bundle = {
            "model": model,
            "device": device,
            "dtype": inference_dtype,
            "weight_name": weight_name,
            "weight_path": str(weight_path),
            "missing_keys": len(missing),
            "unexpected_keys": len(unexpected),
        }
        MODEL_CACHE["default"] = bundle
        return bundle
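
# A minimal sketch (not wired into the UI) of how the cached bundle could be
# invalidated, e.g. to force a fresh _load_model_bundle() after freeing GPU
# memory. The name _reset_model_cache is an illustrative addition; it assumes
# no inference is in flight while the cache is cleared.
def _reset_model_cache() -> None:
    with MODEL_CACHE_LOCK:
        MODEL_CACHE.clear()
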
def _eager_load_default_model() -> None:
    if not IS_SPACE_RUNTIME or SKIP_EAGER_MODEL_LOAD:
        return
    try:
        bundle = _load_model_bundle()
        STARTUP_NOTES.append(
            f"Startup preload complete on `{bundle['device']}` with `{bundle['weight_name']}`."
        )
    except Exception as exc:
        STARTUP_NOTES.append(f"Startup preload failed: {exc}")


def _extract_video_frames(video_file: str, frames_dir: Path, fps: int, max_frames: int) -> tuple[list[str], dict[str, Any]]:
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise ValueError("Could not open the uploaded video.")
    source_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    interval = max(1, round(source_fps / max(fps, 1)))
    saved_paths = []
    frame_idx = 0
    while len(saved_paths) < max_frames:
        ok, frame = cap.read()
        if not ok:
            break
        if frame_idx % interval == 0:
            output_path = frames_dir / f"{len(saved_paths):06d}.jpg"
            cv2.imwrite(str(output_path), frame)
            saved_paths.append(str(output_path))
        frame_idx += 1
    cap.release()
    return saved_paths, {
        "source_fps": round(source_fps, 2),
        "sample_interval": interval,
        "original_frame_count": total_frames,
    }


def _prepare_inputs(video_file: Any, fps: int, max_frames: int) -> tuple[torch.Tensor, list[str], Path, dict[str, Any]]:
    _cleanup_old_runs()
    work_dir = Path(tempfile.mkdtemp(prefix="lingbot-map-", dir=OUTPUT_ROOT))
    input_dir = work_dir / "inputs"
    input_dir.mkdir(parents=True, exist_ok=True)
    input_summary = {"input_mode": "video"}
    video_path = _resolve_path(video_file)
    if not video_path:
        raise ValueError("Upload one short video.")
    image_paths, video_summary = _extract_video_frames(video_path, input_dir, fps=fps, max_frames=max_frames)
    input_summary.update(video_summary)
    if len(image_paths) < 2:
        raise ValueError("Provide at least 2 frames. The Space is tuned for short multi-frame reconstructions.")
    images = load_and_preprocess_images(
        image_paths,
        mode="crop",
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
    )
    return images, image_paths, work_dir, input_summary


def _squeeze_single_batch(key: str, value: torch.Tensor) -> torch.Tensor:
    # Expected ndim for each key when a leading batch dimension of 1 is present.
    batched_dims = {
        "pose_enc": 3,
        "depth": 5,
        "depth_conf": 4,
        "world_points": 5,
        "world_points_conf": 4,
        "extrinsic": 4,
        "intrinsic": 4,
        "images": 5,
    }
    expected_ndim = batched_dims.get(key)
    if expected_ndim is None or value.ndim != expected_ndim or value.shape[0] != 1:
        return value
    return value[0]


def _postprocess_predictions(predictions: dict[str, Any], images: torch.Tensor) -> tuple[dict[str, Any], torch.Tensor]:
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    # Lift the 3x4 extrinsics to homogeneous 4x4 matrices and invert them in
    # closed form before trimming back to 3x4.
    extrinsic_4x4 = torch.zeros((*extrinsic.shape[:-2], 4, 4), device=extrinsic.device, dtype=extrinsic.dtype)
    extrinsic_4x4[..., :3, :4] = extrinsic
    extrinsic_4x4[..., 3, 3] = 1.0
    extrinsic_4x4 = closed_form_inverse_se3_general(extrinsic_4x4)
    predictions["extrinsic"] = extrinsic_4x4[..., :3, :4]
    predictions["intrinsic"] = intrinsic
    predictions.pop("pose_enc_list", None)
    predictions.pop("images", None)
    for key, value in list(predictions.items()):
        if isinstance(value, torch.Tensor):
            predictions[key] = _squeeze_single_batch(key, value.detach().to("cpu"))
    images_cpu = images.detach().to("cpu")
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return predictions, images_cpu
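
# Illustrative NumPy reference for the inversion in _postprocess_predictions:
# for a rigid transform T = [R | t; 0 0 0 1] the closed-form inverse is
# [R^T | -R^T t; 0 0 0 1]. closed_form_inverse_se3_general presumably applies
# the same identity batched over torch tensors; this sketch is for clarity
# only and is not called by the app.
def _invert_se3_numpy(matrix_4x4: np.ndarray) -> np.ndarray:
    rotation = matrix_4x4[:3, :3]
    translation = matrix_4x4[:3, 3]
    inverse = np.eye(4, dtype=matrix_4x4.dtype)
    inverse[:3, :3] = rotation.T
    inverse[:3, 3] = -rotation.T @ translation
    return inverse
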
def _prepare_for_visualization(predictions: dict[str, Any], images: torch.Tensor) -> dict[str, Any]:
    vis_predictions = {}
    for key, value in predictions.items():
        if isinstance(value, torch.Tensor):
            vis_predictions[key] = _squeeze_single_batch(key, value).detach().cpu().numpy()
        else:
            vis_predictions[key] = value
    vis_predictions["images"] = _squeeze_single_batch("images", images).detach().cpu().numpy()
    return vis_predictions


def _estimate_gpu_duration(images: torch.Tensor, num_scale_frames: int, keyframe_interval: int) -> int:
    # ZeroGPU's dynamic-duration hook calls this with the same arguments as
    # _run_inference; the budget grows with frame count, clamped to 60-180 s.
    frame_count = int(getattr(images, "shape", [DEFAULT_MAX_FRAMES])[0])
    del num_scale_frames, keyframe_interval
    return min(180, max(60, 24 + frame_count * 4))


@spaces.GPU(duration=_estimate_gpu_duration)
def _run_inference(images: torch.Tensor, num_scale_frames: int, keyframe_interval: int) -> tuple[dict[str, Any], torch.Tensor, dict[str, Any]]:
    bundle = _load_model_bundle()
    model = bundle["model"]
    device = bundle["device"]
    dtype = bundle["dtype"]
    if device.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    images = images.to(device)
    output_device = torch.device("cpu")
    autocast_context = (
        torch.amp.autocast("cuda", dtype=dtype)
        if device.type == "cuda"
        else contextlib.nullcontext()
    )
    started_at = time.time()
    with torch.no_grad():
        with autocast_context:
            predictions = model.inference_streaming(
                images,
                num_scale_frames=num_scale_frames,
                keyframe_interval=keyframe_interval,
                output_device=output_device,
            )
    inference_seconds = time.time() - started_at
    images_for_post = predictions["images"]
    del images
    if device.type == "cuda":
        torch.cuda.empty_cache()
    predictions, images_cpu = _postprocess_predictions(predictions, images_for_post)
    return predictions, images_cpu, {
        "runtime_seconds": round(inference_seconds, 2),
        "device": str(device),
        "dtype": str(dtype),
        "weight_name": bundle["weight_name"],
        "weight_path": bundle["weight_path"],
        "missing_keys": bundle["missing_keys"],
        "unexpected_keys": bundle["unexpected_keys"],
        "peak_memory_gb": round(torch.cuda.max_memory_allocated() / 1e9, 2) if device.type == "cuda" else None,
    }


def _make_preview_strip(images: torch.Tensor, output_path: Path) -> str:
    frames = _squeeze_single_batch("images", images.detach().cpu())
    count = frames.shape[0]
    indices = sorted({int(round(i)) for i in np.linspace(0, count - 1, num=min(4, count))})
    tiles = []
    for idx in indices:
        rgb = (frames[idx].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
        tile = Image.fromarray(rgb).resize((320, 220))
        tiles.append(tile)
    # Banner must be tall enough for the 220 px tiles pasted at y=72.
    banner = Image.new("RGB", (320 * len(tiles), 300), color=(245, 240, 228))
    draw = ImageDraw.Draw(banner)
    draw.text((18, 14), f"LingBot-Map preview | {count} frames", fill=(31, 41, 55))
    draw.text((18, 38), "ZeroGPU demo export", fill=(87, 96, 110))
    x_offset = 0
    for tile in tiles:
        banner.paste(tile, (x_offset, 72))
        x_offset += tile.width
    banner.save(output_path)
    return str(output_path)


def _save_predictions_npz(predictions: dict[str, Any], output_path: Path) -> str:
    arrays = {}
    for key, value in predictions.items():
        if isinstance(value, torch.Tensor):
            arrays[key] = value.detach().cpu().numpy()
    np.savez_compressed(output_path, **arrays)
    return str(output_path)


def _empty_viser_preview(message: str) -> str:
    return (
        "<div class='viser-empty'>"
        f"<p>{html.escape(message)}</p>"
        "</div>"
    )
def _count_confident_points(vis_predictions: dict[str, Any], conf_percentile: float) -> tuple[int, float]:
    conf = vis_predictions.get("world_points_conf")
    if conf is None:
        return 0, 0.0
    conf_flat = conf.reshape(-1)
    threshold = np.percentile(conf_flat, conf_percentile) if conf_percentile > 0 else 0.0
    kept = int(((conf_flat >= threshold) & (conf_flat > 1e-5)).sum())
    return kept, float(threshold)


def _prepare_viser_point_cloud(
    vis_predictions: dict[str, Any],
    conf_percentile: float,
    max_points: int = MAX_VISER_POINTS,
) -> tuple[np.ndarray, np.ndarray, float]:
    world_points = vis_predictions.get("world_points")
    conf = vis_predictions.get("world_points_conf")
    if world_points is None:
        world_points = vis_predictions.get("world_points_from_depth")
        conf = vis_predictions.get("depth_conf")
    if world_points is None:
        raise ValueError("Missing world point predictions.")
    images = vis_predictions["images"]
    if images.ndim == 4 and images.shape[1] == 3:
        images = np.transpose(images, (0, 2, 3, 1))
    points = np.asarray(world_points).reshape(-1, 3)
    colors = (np.asarray(images).reshape(-1, 3) * 255).clip(0, 255).astype(np.uint8)
    if conf is None:
        conf_flat = np.ones(points.shape[0], dtype=np.float32)
        threshold = 0.0
    else:
        conf_flat = np.asarray(conf).reshape(-1)
        threshold = np.percentile(conf_flat, conf_percentile) if conf_percentile > 0 else 0.0
    mask = (conf_flat >= threshold) & (conf_flat > 1e-5)
    points = points[mask]
    colors = colors[mask]
    if points.shape[0] == 0:
        return points.astype(np.float32), colors, float(threshold)
    if points.shape[0] > max_points:
        keep_indices = np.linspace(0, points.shape[0] - 1, num=max_points, dtype=np.int64)
        points = points[keep_indices]
        colors = colors[keep_indices]
    return points.astype(np.float32), colors, float(threshold)


def _add_viser_cameras(
    server: Any,
    vis_predictions: dict[str, Any],
    scene_extent: float,
) -> list[np.ndarray]:
    extrinsics = vis_predictions.get("extrinsic")
    intrinsics = vis_predictions.get("intrinsic")
    images = vis_predictions.get("images")
    if extrinsics is None or intrinsics is None or images is None:
        return []
    extrinsics = np.asarray(extrinsics)
    intrinsics = np.asarray(intrinsics)
    images = np.asarray(images)
    if images.ndim == 4 and images.shape[1] == 3:
        _, _, image_height, image_width = images.shape
    else:
        _, image_height, image_width, _ = images.shape
    camera_positions: list[np.ndarray] = []
    frustum_scale = max(scene_extent * 0.05, 0.05)
    for idx, world_to_camera_3x4 in enumerate(extrinsics):
        world_to_camera = np.eye(4, dtype=np.float32)
        world_to_camera[:3, :4] = world_to_camera_3x4
        camera_to_world = np.linalg.inv(world_to_camera)
        camera_positions.append(camera_to_world[:3, 3].copy())
        intrinsic = intrinsics[idx]
        fy = float(max(intrinsic[1, 1], 1e-6))
        fov = float(np.clip(2 * np.arctan2(image_height / 2.0, fy), 0.1, np.pi - 0.1))
        aspect = float(max(image_width / max(image_height, 1), 1e-3))
        quat_xyzw = Rotation.from_matrix(camera_to_world[:3, :3]).as_quat()
        wxyz = (
            float(quat_xyzw[3]),
            float(quat_xyzw[0]),
            float(quat_xyzw[1]),
            float(quat_xyzw[2]),
        )
        color = tuple(
            int(channel * 255)
            for channel in colorsys.hsv_to_rgb(idx / max(len(extrinsics), 1), 0.65, 1.0)
        )
        server.scene.add_camera_frustum(
            f"/cameras/camera_{idx:02d}",
            fov=fov,
            aspect=aspect,
            scale=frustum_scale,
            color=color,
            wxyz=wxyz,
            position=tuple(float(x) for x in camera_to_world[:3, 3]),
            variant="wireframe",
        )
    return camera_positions
def _build_viser_preview(
    vis_predictions: dict[str, Any],
    output_path: Path,
    conf_percentile: float,
) -> tuple[str, str | None, int]:
    try:
        import viser
    except ModuleNotFoundError:
        return (
            _empty_viser_preview("Static Viser preview is unavailable because `viser` is not installed."),
            None,
            0,
        )
    server = None
    try:
        points, colors, _ = _prepare_viser_point_cloud(vis_predictions, conf_percentile)
        if points.shape[0] == 0:
            return _empty_viser_preview("No confident points were available for the static Viser preview."), None, 0
        server = viser.ViserServer(port=0, verbose=False)
        server.scene.set_up_direction("+z")
        if hasattr(server.scene, "world_axes"):
            server.scene.world_axes.visible = False
        # Use a 5th-95th percentile bounding box so outliers do not blow up
        # the estimated scene extent.
        lower = np.percentile(points, 5, axis=0)
        upper = np.percentile(points, 95, axis=0)
        scene_extent = float(np.linalg.norm(upper - lower))
        scene_extent = max(scene_extent, 1e-3)
        scene_center = points.mean(axis=0)
        server.scene.add_point_cloud(
            "/reconstruction",
            points=points,
            colors=colors,
            point_size=max(scene_extent * 0.0025, 0.003),
        )
        camera_positions = _add_viser_cameras(server, vis_predictions, scene_extent)
        if camera_positions:
            camera_center = np.mean(np.asarray(camera_positions), axis=0)
            scene_center = (scene_center + camera_center) / 2.0
        server.initial_camera.look_at = tuple(float(x) for x in scene_center)
        server.initial_camera.position = tuple(
            float(x)
            for x in scene_center + np.array([scene_extent, scene_extent, max(scene_extent * 0.65, 0.25)])
        )
        server.initial_camera.up = (0.0, 0.0, 1.0)
        html_doc = server.scene.as_html(dark_mode=True)
        output_path.write_text(html_doc, encoding="utf-8")
        # Inline the exported scene via srcdoc so the preview renders without
        # the app having to serve the HTML file.
        iframe_html = (
            f"<iframe class='viser-frame' srcdoc='{html.escape(html_doc, quote=True)}'></iframe>"
        )
        return iframe_html, str(output_path), int(points.shape[0])
    except Exception as exc:
        return (
            _empty_viser_preview(f"Static Viser preview could not be created for this run: {exc}"),
            None,
            0,
        )
    finally:
        if server is not None and hasattr(server, "stop"):
            with contextlib.suppress(Exception):
                server.stop()


def _zip_outputs(work_dir: Path, paths: list[Path], output_name: str) -> str:
    zip_path = work_dir / output_name
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
        for path in paths:
            if path.exists():
                zip_file.write(path, arcname=path.name)
    return str(zip_path)


def _export_outputs(
    work_dir: Path,
    image_paths: list[str],
    predictions: dict[str, Any],
    images_cpu: torch.Tensor,
    input_summary: dict[str, Any],
    runtime_summary: dict[str, Any],
    num_scale_frames: int,
    keyframe_interval: int,
    conf_percentile: float,
) -> tuple[str, str, str, dict[str, Any]]:
    vis_predictions = _prepare_for_visualization(predictions, images_cpu)
    glb_path = work_dir / "lingbot-map-reconstruction.glb"
    scene = predictions_to_glb(
        vis_predictions,
        conf_thres=conf_percentile,
        show_cam=True,
        target_dir=str(work_dir),
        mask_sky=False,
    )
    scene.export(glb_path)
    viser_preview_html, viser_preview_path, viser_points = _build_viser_preview(
        vis_predictions,
        work_dir / "viser-preview.html",
        conf_percentile=conf_percentile,
    )
    preview_path = Path(_make_preview_strip(images_cpu, work_dir / "preview.png"))
    npz_path = Path(_save_predictions_npz(predictions, work_dir / "predictions.npz"))
    points_kept, conf_threshold = _count_confident_points(vis_predictions, conf_percentile)
    summary = {
        "model_variant": "Default",
        "model_filename": MODEL_FILENAME,
        "frames_used": len(image_paths),
        "num_scale_frames": num_scale_frames,
        "keyframe_interval": keyframe_interval,
        "confidence_percentile": conf_percentile,
        "confidence_threshold": round(conf_threshold, 4),
        "points_kept_for_glb": points_kept,
        "points_used_for_viser_preview": viser_points,
        "input_summary": input_summary,
        "runtime_summary": runtime_summary,
    }
    summary_path = work_dir / "summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    artifact_path = _zip_outputs(
        work_dir,
        [
            glb_path,
            preview_path,
            npz_path,
            summary_path,
            Path(viser_preview_path) if viser_preview_path else work_dir / "__missing__",
        ],
        output_name="lingbot-map-results.zip",
    )
    return str(glb_path), viser_preview_html, artifact_path, summary
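
# Illustrative counterpart to _save_predictions_npz (not called by the app):
# reload the predictions.npz from a downloaded results bundle for offline
# inspection. The name _load_predictions_npz is an illustrative addition.
def _load_predictions_npz(npz_path: Path) -> dict[str, np.ndarray]:
    with np.load(npz_path) as data:
        return {key: data[key] for key in data.files}
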
def _format_status(summary: dict[str, Any]) -> str:
    runtime = summary["runtime_summary"]
    input_summary = summary["input_summary"]
    lines = [
        "## Run Complete",
        f"- Model: `{summary['model_filename']}`",
        f"- Frames used: `{summary['frames_used']}`",
        f"- Input mode: `{input_summary['input_mode']}`",
        f"- Runtime: `{runtime['runtime_seconds']}s` on `{runtime['device']}`",
        f"- GLB confidence percentile: `{summary['confidence_percentile']}`",
        f"- Points kept for GLB: `{summary['points_kept_for_glb']}`",
        f"- Points used for static Viser preview: `{summary['points_used_for_viser_preview']}`",
    ]
    if runtime.get("peak_memory_gb") is not None:
        lines.append(f"- Peak GPU memory: `{runtime['peak_memory_gb']} GB`")
    if input_summary.get("sample_interval"):
        lines.append(f"- Video sample interval: `every {input_summary['sample_interval']} frame(s)`")
    return "\n".join(lines)


def reconstruct_scene(
    video_file: Any,
    fps: int,
    max_frames: int,
    num_scale_frames: int,
    keyframe_interval: int,
    conf_percentile: float,
):
    max_frames = max(2, min(int(max_frames), MAX_FRAMES_HARD_LIMIT))
    num_scale_frames = max(1, int(num_scale_frames))
    keyframe_interval = max(1, int(keyframe_interval))
    conf_percentile = float(conf_percentile)
    images, image_paths, work_dir, input_summary = _prepare_inputs(video_file=video_file, fps=int(fps), max_frames=max_frames)
    num_scale_frames = min(num_scale_frames, int(images.shape[0]))
    predictions, images_cpu, runtime_summary = _run_inference(
        images,
        num_scale_frames=num_scale_frames,
        keyframe_interval=keyframe_interval,
    )
    glb_path, viser_preview_html, artifact_path, summary = _export_outputs(
        work_dir=work_dir,
        image_paths=image_paths,
        predictions=predictions,
        images_cpu=images_cpu,
        input_summary=input_summary,
        runtime_summary=runtime_summary,
        num_scale_frames=num_scale_frames,
        keyframe_interval=keyframe_interval,
        conf_percentile=conf_percentile,
    )
    preview_path = str(work_dir / "preview.png")
    status = _format_status(summary)
    return viser_preview_html, glb_path, preview_path, artifact_path, summary, status


def _build_startup_markdown() -> str:
    if not STARTUP_NOTES:
        return ""
    return "\n".join([f"- {note}" for note in STARTUP_NOTES])


css = """
#container { max-width: 1200px; margin: 0 auto; }
.teaser img { max-height: 260px !important; object-fit: cover !important; border-radius: 8px !important; }
.viser-frame {
    width: 100%;
    height: 380px;
    border: 1px solid #d7dce5;
    border-radius: 12px;
    background: #0f1720;
}
.viser-empty {
    min-height: 380px;
    border: 1px dashed #c9d1dd;
    border-radius: 12px;
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 24px;
    text-align: center;
    background: linear-gradient(180deg, #f8fafc 0%, #eef2f7 100%);
    color: #334155;
}
footer {display: none !important;}
"""

_eager_load_default_model()
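
# Illustrative headless usage, bypassing the UI (assumes a local clip at
# ./clip.mp4; the tuple order matches reconstruct_scene's return value):
#
#   outputs = reconstruct_scene(
#       video_file="clip.mp4",
#       fps=DEFAULT_FPS,
#       max_frames=DEFAULT_MAX_FRAMES,
#       num_scale_frames=DEFAULT_SCALE_FRAMES,
#       keyframe_interval=DEFAULT_KEYFRAME_INTERVAL,
#       conf_percentile=DEFAULT_CONF_PERCENTILE,
#   )
#   viser_html, glb_path, preview_png, bundle_zip, summary, status_md = outputs
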
# Pass the stylesheet to gr.Blocks (launch() does not accept a css argument).
with gr.Blocks(title="LingBot 3D", css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("# LingBot 3D")
        gr.Markdown(
            "Upload a short video clip and get back a navigable 3D scene. "
            "Powered by the LingBot-Map checkpoint, exported as a GLB plus a downloadable results bundle."
        )
        with gr.Row():
            with gr.Column():
                video_file = gr.Video(
                    label="Input video",
                    sources=["upload"],
                    format="mp4",
                    height=380,
                )
            with gr.Column():
                gr.Markdown("### Static Viser Preview")
                viser_preview = gr.HTML(
                    value=_empty_viser_preview("Run a reconstruction to load the static Viser preview."),
                )
                with gr.Accordion("Fallback GLB preview", open=False):
                    model_preview = gr.Model3D(
                        label="GLB preview",
                        display_mode="point_cloud",
                        clear_color=[1.0, 1.0, 1.0, 1.0],
                        height=380,
                    )
        run_button = gr.Button("Build 3D Scene", variant="primary")
        status_markdown = gr.Markdown()
        with gr.Accordion("Sampling & reconstruction settings", open=False):
            with gr.Row():
                fps = gr.Slider(minimum=1, maximum=12, step=1, value=DEFAULT_FPS, label="Sampling FPS")
                max_frames = gr.Slider(minimum=2, maximum=MAX_FRAMES_HARD_LIMIT, step=1, value=DEFAULT_MAX_FRAMES, label="Max frames")
            with gr.Row():
                num_scale_frames = gr.Slider(minimum=1, maximum=8, step=1, value=DEFAULT_SCALE_FRAMES, label="Scale frames")
                keyframe_interval = gr.Slider(minimum=1, maximum=8, step=1, value=DEFAULT_KEYFRAME_INTERVAL, label="Keyframe interval")
            conf_percentile = gr.Slider(
                minimum=0,
                maximum=90,
                step=5,
                value=DEFAULT_CONF_PERCENTILE,
                label="GLB confidence percentile",
                info="Higher = fewer, more confident points",
            )
        with gr.Row():
            preview_image = gr.Image(label="Frame preview", interactive=False, height=200)
            artifact_file = gr.File(label="Download results bundle")
        summary_json = gr.JSON(visible=False)
        run_button.click(
            fn=reconstruct_scene,
            inputs=[
                video_file,
                fps,
                max_frames,
                num_scale_frames,
                keyframe_interval,
                conf_percentile,
            ],
            outputs=[
                viser_preview,
                model_preview,
                preview_image,
                artifact_file,
                summary_json,
                status_markdown,
            ],
            show_progress="full",
        )

demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch()
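
# Note: when running inside a Space (SPACE_ID set), exporting
# LINGBOT_SPACE_SKIP_MODEL_LOAD=1 skips the startup weight preload in
# _eager_load_default_model; outside a Space the preload never runs anyway.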