"""Gradio Web Demo for WildDet3D (5-mode). Supports 5 prompt modes: - Text: Enter text like "car.person.traffic light" (one-to-many) - Visual: Click box on image, text="visual" (one-to-many) - Visual+Label: Click box + category label (one-to-many) - Geometry: Click box on image, text="geometric" (one-to-one) - Geometry+Label: Click box + category label (one-to-one) - Point: Click on image to select point Requirements: pip install gradio>=5.0.0 Usage: python demo/huggingface/app.py Then open http://localhost:7860 in browser. """ import os import sys from pathlib import Path # Add paths: support both local dev and HuggingFace Space. # Local dev: demo/huggingface/app.py -> repo root = ../../ # HF Space: wilddet3d/ is bundled in the same directory as app.py _this_dir = Path(__file__).resolve().parent if (_this_dir / "wilddet3d").exists(): # HuggingFace Space: everything bundled next to app.py sys.path.insert(0, str(_this_dir)) else: # Local dev: repo root is two levels up repo_root = _this_dir.parent.parent sys.path.insert(0, str(repo_root)) import spaces import gradio as gr import numpy as np import torch import cv2 from PIL import Image import json from wilddet3d.inference import build_model, WildDet3DPredictor from wilddet3d.preprocessing import preprocess from wilddet3d.vis.visualize import draw_3d_boxes def cross_category_nms( boxes2d, boxes3d, scores, scores_2d, scores_3d, class_ids, iou_threshold=0.8, ): """Cross-category NMS: suppress overlapping boxes across categories. For boxes with 2D IoU > threshold, keep the one with higher combined score regardless of category. Args: boxes2d: (N, 4) tensor, pixel xyxy. boxes3d: (N, 10) tensor. scores: (N,) combined scores (for ranking). scores_2d: (N,) 2D scores. scores_3d: (N,) 3D scores. class_ids: (N,) class indices. iou_threshold: IoU threshold for suppression. Returns: Filtered tensors (boxes2d, boxes3d, scores, scores_2d, scores_3d, class_ids). 
""" if len(boxes2d) <= 1: return boxes2d, boxes3d, scores, scores_2d, scores_3d, class_ids # Sort by combined score descending order = scores.argsort(descending=True) boxes2d = boxes2d[order] boxes3d = boxes3d[order] scores = scores[order] scores_2d = scores_2d[order] scores_3d = scores_3d[order] class_ids = class_ids[order] # Compute pairwise IoU x1 = torch.max(boxes2d[:, None, 0], boxes2d[None, :, 0]) y1 = torch.max(boxes2d[:, None, 1], boxes2d[None, :, 1]) x2 = torch.min(boxes2d[:, None, 2], boxes2d[None, :, 2]) y2 = torch.min(boxes2d[:, None, 3], boxes2d[None, :, 3]) inter = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) area = ( (boxes2d[:, 2] - boxes2d[:, 0]) * (boxes2d[:, 3] - boxes2d[:, 1]) ) union = area[:, None] + area[None, :] - inter iou = inter / (union + 1e-6) n = len(boxes2d) suppressed = set() keep = [] for i in range(n): if i in suppressed: continue keep.append(i) for j in range(i + 1, n): if j in suppressed: continue if iou[i, j] >= iou_threshold: suppressed.add(j) keep = torch.tensor(keep, dtype=torch.long, device=boxes2d.device) return ( boxes2d[keep], boxes3d[keep], scores[keep], scores_2d[keep], scores_3d[keep], class_ids[keep], ) # ---- BEV Renderer (inline JS from bev-renderer.js) ---- BEV_RENDERER_JS = r""" var BEV_EDGES=[[0,1],[1,2],[2,3],[3,0],[4,5],[5,6],[6,7],[7,4],[0,4],[1,5],[2,6],[3,7]]; var BEV_FACES=[[0,1,2,3],[4,5,6,7],[0,1,5,4],[2,3,7,6],[0,3,7,4],[1,2,6,5]]; class BEVRenderer{ constructor(canvasId){this.canvas=canvasId?document.getElementById(canvasId):null;this.ctx=this.canvas?this.canvas.getContext('2d'):null;this.bgColor='#f8f8f8';} render(boxes,colors,elevDeg){ if(elevDeg===undefined)elevDeg=35; this._resizeCanvas();var ctx=this.ctx;var w=this.canvas.width;var h=this.canvas.height; ctx.fillStyle=this.bgColor;ctx.fillRect(0,0,w,h); var validBoxes=[];var labels=[];var boxColors=[]; for(var i=0;im)m=d;}return Math.max(m*2,2.0);} _findGroundY(allCorners){var minY=Infinity;for(var i=0;iw-2)px=w-2-pillW;if(py<2)py=2; 
ctx.globalAlpha=0.8;ctx.fillStyle=color;var r=3;ctx.beginPath();ctx.moveTo(px+r,py);ctx.lineTo(px+pillW-r,py);ctx.arcTo(px+pillW,py,px+pillW,py+r,r);ctx.lineTo(px+pillW,py+pillH-r);ctx.arcTo(px+pillW,py+pillH,px+pillW-r,py+pillH,r);ctx.lineTo(px+r,py+pillH);ctx.arcTo(px,py+pillH,px,py+pillH-r,r);ctx.lineTo(px,py+r);ctx.arcTo(px,py,px+r,py,r);ctx.closePath();ctx.fill(); ctx.globalAlpha=1.0;ctx.fillStyle='#fff';ctx.textAlign='left';ctx.textBaseline='top';ctx.fillText(label,px+padH,py+padV);ctx.restore();} } """ # Color palette for BEV boxes (per category) BEV_COLORS = [ "#e74c3c", "#3b82f6", "#22c55e", "#f59e0b", "#a855f7", "#06b6d4", "#ec4899", "#f97316", ] def make_bev_html(boxes3d_np, class_ids_np, class_names, elev_deg=35): """Generate HTML with embedded BEV canvas renderer. Args: boxes3d_np: (N, 10) numpy array of 3D boxes. class_ids_np: (N,) numpy array of class indices. class_names: list of class name strings. elev_deg: initial elevation angle in degrees. Returns: HTML string with embedded canvas + JS. """ boxes_json = [] colors_json = [] for i in range(len(boxes3d_np)): cid = int(class_ids_np[i]) name = class_names[cid] if cid < len(class_names) else "object" b = boxes3d_np[i] label = name boxes_json.append({ "box3d": b.tolist(), "category": label, }) colors_json.append(BEV_COLORS[cid % len(BEV_COLORS)]) boxes_data = json.dumps(boxes_json) colors_data = json.dumps(colors_json) # Build a self-contained HTML page for the iframe. # Gradio gr.HTML uses innerHTML which does NOT execute """ # Escape for srcdoc attribute (double-quote safe) escaped = inner_html.replace("&", "&").replace('"', """) return ( f'' ) def draw_points_on_image(image, points): """Draw points on image. 
    Args:
        image: numpy array (H, W, 3)
        points: list of (x, y, label) tuples

    Returns:
        Image with points drawn
    """
    img = image.copy()
    if img.dtype != np.uint8:
        # Clamp float/other dtypes into displayable 8-bit range.
        img = np.clip(img, 0, 255).astype(np.uint8)
    h, w = img.shape[:2]
    # Marker radius scales with resolution, never below 4 px.
    radius = max(4, int(min(h, w) * 0.012))
    for x, y, label in points:
        # Positive = green, Negative = red
        fill = (0, 255, 0) if label == 1 else (255, 0, 0)
        cv2.circle(img, (int(x), int(y)), radius, fill, -1)
        # White outline keeps the dot visible on any background.
        cv2.circle(img, (int(x), int(y)), radius, (255, 255, 255), 2)
    return img


def draw_box_on_image(image, box, thickness=3):
    """Draw box on image.

    Args:
        image: numpy array (H, W, 3)
        box: [x1, y1, x2, y2] coordinates
        thickness: line thickness

    Returns:
        Image with box drawn
    """
    img = image.copy()
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    x1, y1, x2, y2 = [int(v) for v in box]
    color = (255, 50, 50)  # bright red in RGB
    cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
    return img


# HuggingFace Model repo for checkpoints
HF_MODEL_REPO = "allenai/WildDet3D"
HF_CKPT_NAME = "wilddet3d_alldata_all_prompt_v1.0.pt"

# Local checkpoint paths (tried in order)
LOCAL_CHECKPOINTS = [
    "ckpt/wilddet3d.pt",  # release repo layout
]

# Default demo image path
DEFAULT_IMAGE_PATH = "assets/demo/rgb.png"

# Global model (loaded once)
_cached_model = None


def _resolve_checkpoint():
    """Resolve checkpoint: local if exists, else download from HF Hub."""
    for path in LOCAL_CHECKPOINTS:
        if os.path.exists(path):
            return path
    # Import lazily so local-checkpoint users don't need huggingface_hub.
    from huggingface_hub import hf_hub_download
    # HF_TOKEN is only needed for gated/private repos; None is fine
    # for public ones.
    hf_token = os.environ.get("HF_TOKEN")
    print(f"Downloading checkpoint from {HF_MODEL_REPO}...")
    ckpt = hf_hub_download(
        repo_id=HF_MODEL_REPO, filename=HF_CKPT_NAME, token=hf_token
    )
    return ckpt


def get_model():
    """Load model once and cache it."""
    global _cached_model
    if _cached_model is None:
        ckpt_path = _resolve_checkpoint()
        print(f"Loading WildDet3D model from {ckpt_path}...")
        # score_threshold=0.0: keep everything here; the UI applies the
        # user-adjustable score slider later.
        _cached_model = build_model(
            checkpoint=ckpt_path,
            score_threshold=0.0,
            canonical_rotation=True,
            skip_pretrained=True,
        )
        print("Model loaded!")
    return _cached_model


def load_default_image():
    """Load the default demo image."""
    if os.path.exists(DEFAULT_IMAGE_PATH):
        return np.array(Image.open(DEFAULT_IMAGE_PATH).convert("RGB"))
    return None


def load_default_intrinsics():
    """Return placeholder intrinsics values."""
    # Zeros mean "not provided"; run_wilddet3d rejects fx/fy <= 0 when
    # the user opts into actual intrinsics.
    return 0, 0, 0, 0


def format_intrinsics(K):
    """Format intrinsics tensor for display."""
    if K is None:
        return "Not available"
    if isinstance(K, torch.Tensor):
        K = K.cpu().numpy()
    if K.ndim == 3:
        # Drop batch dimension.
        K = K[0]
    return (
        f"fx={K[0, 0]:.2f}, fy={K[1, 1]:.2f}, "
        f"cx={K[0, 2]:.2f}, cy={K[1, 2]:.2f}"
    )


def scale_intrinsics_to_original(K, input_hw, original_hw):
    """Scale intrinsics from model input resolution to original."""
    if K is None:
        return None
    # Copy first so the caller's tensor/array is not mutated in place.
    if isinstance(K, torch.Tensor):
        K = K.clone()
    else:
        K = K.copy()
    input_h, input_w = input_hw
    orig_h, orig_w = original_hw
    scale_x = orig_w / input_w
    scale_y = orig_h / input_h
    # Handle both batched (B, 3, 3) and single (3, 3) matrices.
    if K.ndim == 3:
        K[:, 0, 0] *= scale_x
        K[:, 1, 1] *= scale_y
        K[:, 0, 2] *= scale_x
        K[:, 1, 2] *= scale_y
    else:
        K[0, 0] *= scale_x
        K[1, 1] *= scale_y
        K[0, 2] *= scale_x
        K[1, 2] *= scale_y
    return K


def transform_coords_to_input_space(x, y, original_hw, input_hw, padding):
    """Transform coords from original image space to preprocessed input.
    Args:
        x, y: Coordinates in original image space
        original_hw: (H, W) of original image
        input_hw: (H, W) of preprocessed image (e.g., 1008x1008)
        padding: (pad_left, pad_right, pad_top, pad_bottom)

    Returns:
        (new_x, new_y) in preprocessed input space
    """
    orig_h, orig_w = original_hw
    pad_left, pad_right, pad_top, pad_bottom = padding
    # Content area = model input size minus the letterbox padding.
    content_w = input_hw[1] - pad_left - pad_right
    content_h = input_hw[0] - pad_top - pad_bottom
    scale_x = content_w / orig_w
    scale_y = content_h / orig_h
    new_x = x * scale_x + pad_left
    new_y = y * scale_y + pad_top
    return new_x, new_y


def on_image_select(
    evt: gr.SelectData,
    image,
    original_image,
    state,
    prompt_mode,
    point_label,
):
    """Handle click on image and visualize the click."""
    if image is None:
        return state, "Please upload an image first", None
    x, y = evt.index[0], evt.index[1]
    label = 1 if "Positive" in point_label else 0
    # Shallow-copy the state lists so the incoming dict is not mutated.
    new_state = {
        "points": list(state.get("points", [])),
        "box": list(state.get("box", [])),
    }
    # Always draw on the clean original so annotations don't stack up.
    vis_image = (
        original_image.copy() if original_image is not None else image.copy()
    )

    if prompt_mode == "Point":
        new_state["points"].append((x, y, label))
        # Points and boxes are mutually exclusive prompts.
        new_state["box"] = []
        label_str = "+" if label == 1 else "-"
        info = (
            f"Points: {len(new_state['points'])} total. "
            f"Last: ({x}, {y}) [{label_str}]"
        )
        vis_image = draw_points_on_image(vis_image, new_state["points"])
    elif prompt_mode in ("Box-to-Multi-Object", "Box-to-Single-Object"):
        new_state["points"] = []
        box_clicks = list(new_state.get("box", []))
        box_clicks.append((x, y))
        if len(box_clicks) == 1:
            # First corner: remember it and wait for the second click.
            new_state["box"] = box_clicks
            info = (
                f"[{prompt_mode}] Corner 1: ({x}, {y}) "
                f"- click again for corner 2"
            )
            vis_image = draw_points_on_image(vis_image, [(x, y, 1)])
        elif len(box_clicks) >= 2:
            # Second corner: normalize so (x1, y1) <= (x2, y2).
            x1, y1 = box_clicks[0]
            x2, y2 = box_clicks[1]
            box = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
            new_state["box"] = [(box[0], box[1]), (box[2], box[3])]
            info = (
                f"[{prompt_mode}] Box: "
                f"({box[0]}, {box[1]}) -> ({box[2]}, {box[3]})"
            )
            vis_image = draw_box_on_image(vis_image, box)
        else:
            info = f"Box clicks: {box_clicks}"
    else:
        info = "Text mode - just enter text and click Run"

    return new_state, info, vis_image


def clear_clicks(state, original_image):
    """Reset click state and restore original image."""
    new_state = {"points": [], "box": []}
    return (
        new_state,
        "Cleared - ready for new clicks",
        original_image.copy() if original_image is not None else None,
    )


@spaces.GPU
def run_wilddet3d(
    image,
    original_image,
    state,
    prompt_mode,
    text_prompt,
    use_label,
    label_text,
    score_thres,
    use_actual_K,
    fx,
    fy,
    cx,
    cy,
):
    """Run WildDet3D with selected prompt mode."""
    if image is None:
        # NOTE(review): this guard returns "" as the 4th value while the
        # other early returns use None -- confirm which gr.HTML expects.
        return None, "Please upload an image first", None, ""

    # Use original (clean) image for model input in point/box modes,
    # since `image` may have drawn points/boxes on it.
if original_image is not None and prompt_mode != "Text": image = original_image # Convert RGBA to RGB if needed if image.ndim == 3 and image.shape[2] == 4: image = image[:, :, :3] device = "cuda" if torch.cuda.is_available() else "cpu" detector = get_model() # Build intrinsics matrix if use_actual_K: if fx <= 0 or fy <= 0: return ( None, "Please enter valid intrinsics (fx, fy must be > 0)", None, None, ) intrinsics = np.array([ [fx, 0, cx], [0, fy, cy], [0, 0, 1] ], dtype=np.float32) else: intrinsics = None # preprocess uses default placeholder # Preprocess image data = preprocess(image.astype(np.float32), intrinsics) # Build prompt_text for box/point modes if prompt_mode == "Box-to-Multi-Object": prefix = "visual" elif prompt_mode == "Box-to-Single-Object": prefix = "geometric" else: prefix = "geometric" # Point mode default if prompt_mode != "Text": if use_label and label_text and label_text.strip(): geo_prompt_text = f"{prefix}: {label_text.strip()}" else: geo_prompt_text = prefix # Initialize prompt info for visualization prompt_points = None prompt_box = None # Run based on prompt mode if prompt_mode == "Text": input_texts = [ t.strip() for t in text_prompt.split(".") if t.strip() ] if not input_texts: input_texts = ["object"] results = detector( images=data["images"].to(device), intrinsics=data["intrinsics"].to(device)[None], input_hw=[data["input_hw"]], original_hw=[data["original_hw"]], padding=[data["padding"]], input_texts=input_texts, return_predicted_intrinsics=True, ) ( boxes, boxes3d, scores, scores_2d, scores_3d, class_ids, depth_maps, predicted_K, confidence_maps, ) = results class_id_mapping = {i: t for i, t in enumerate(input_texts)} elif prompt_mode in ("Box-to-Multi-Object", "Box-to-Single-Object"): box_coords = state.get("box", []) if len(box_coords) < 2: return ( None, "Please click twice on the image to define a box", None, None, ) x1_orig, y1_orig = box_coords[0] x2_orig, y2_orig = box_coords[1] x1, y1 = transform_coords_to_input_space( 
            x1_orig,
            y1_orig,
            data["original_hw"],
            data["input_hw"],
            data["padding"],
        )
        x2, y2 = transform_coords_to_input_space(
            x2_orig,
            y2_orig,
            data["original_hw"],
            data["input_hw"],
            data["padding"],
        )
        box_xyxy = [float(x1), float(y1), float(x2), float(y2)]
        prompt_box = [x1_orig, y1_orig, x2_orig, y2_orig]
        results = detector(
            images=data["images"].to(device),
            intrinsics=data["intrinsics"].to(device)[None],
            input_hw=[data["input_hw"]],
            original_hw=[data["original_hw"]],
            padding=[data["padding"]],
            input_boxes=[box_xyxy],
            prompt_text=geo_prompt_text,
            return_predicted_intrinsics=True,
        )
        (
            boxes,
            boxes3d,
            scores,
            scores_2d,
            scores_3d,
            class_ids,
            depth_maps,
            predicted_K,
            confidence_maps,
        ) = results
        class_id_mapping = {0: geo_prompt_text}
    elif prompt_mode == "Point":
        points = state.get("points", [])
        if not points:
            return (
                None,
                "Please click on the image to select a point",
                None,
                None,
            )
        # Map each clicked point into model input space.
        transformed_points = []
        for x_orig, y_orig, lbl in points:
            x, y = transform_coords_to_input_space(
                x_orig,
                y_orig,
                data["original_hw"],
                data["input_hw"],
                data["padding"],
            )
            transformed_points.append((x, y, lbl))
        prompt_points = points
        results = detector(
            images=data["images"].to(device),
            intrinsics=data["intrinsics"].to(device)[None],
            input_hw=[data["input_hw"]],
            original_hw=[data["original_hw"]],
            padding=[data["padding"]],
            input_points=[transformed_points],
            prompt_text=geo_prompt_text,
            return_predicted_intrinsics=True,
        )
        (
            boxes,
            boxes3d,
            scores,
            scores_2d,
            scores_3d,
            class_ids,
            depth_maps,
            predicted_K,
            confidence_maps,
        ) = results
        class_id_mapping = {0: geo_prompt_text}
    else:
        return None, f"Unknown prompt mode: {prompt_mode}", None, None

    # Cross-category NMS (suppress duplicates across categories)
    if len(boxes[0]) > 1:
        (
            boxes[0],
            boxes3d[0],
            scores[0],
            scores_2d[0],
            scores_3d[0],
            class_ids[0],
        ) = cross_category_nms(
            boxes[0],
            boxes3d[0],
            scores[0],
            scores_2d[0],
            scores_3d[0],
            class_ids[0],
            iou_threshold=0.8,
        )

    # For one-to-one modes (Point, Box-to-Single-Object), keep only
    # the highest confidence detection.
    if prompt_mode in ("Point", "Box-to-Single-Object") and len(boxes[0]) > 1:
        best = scores[0].argmax()
        boxes[0] = boxes[0][best:best+1]
        boxes3d[0] = boxes3d[0][best:best+1]
        scores[0] = scores[0][best:best+1]
        scores_2d[0] = scores_2d[0][best:best+1]
        scores_3d[0] = scores_3d[0][best:best+1]
        class_ids[0] = class_ids[0][best:best+1]

    # Scale predicted intrinsics to original resolution
    # NOTE(review): predicted_K_scaled is not used below in this view.
    predicted_K_scaled = scale_intrinsics_to_original(
        predicted_K,
        input_hw=data["input_hw"],
        original_hw=data["original_hw"],
    )

    # Format intrinsics info
    orig_h, orig_w = data["original_hw"]
    intrinsics_info = f"Image: {orig_w}x{orig_h}\n"
    if use_actual_K:
        intrinsics_info += (
            f"Intrinsics: fx={fx:.2f}, fy={fy:.2f}, "
            f"cx={cx:.2f}, cy={cy:.2f}"
        )
    else:
        intrinsics_info += (
            f"Intrinsics: default (focal={max(orig_h, orig_w)})"
        )

    # 2D visualization
    img_2d = visualize_results(
        data,
        boxes3d,
        scores,
        scores_2d,
        scores_3d,
        class_ids,
        class_id_mapping,
        score_thres,
    )

    # Depth map visualization (with confidence mask if available)
    depth_vis_img = None
    if depth_maps is not None and len(depth_maps) > 0:
        depth_np_raw = depth_maps[0].cpu().numpy()
        d = depth_np_raw.squeeze()
        # Crop the letterbox padding away before normalizing.
        pad_l, pad_r, pad_t, pad_b = data["padding"]
        h_end = d.shape[0] - pad_b if pad_b > 0 else d.shape[0]
        w_end = d.shape[1] - pad_r if pad_r > 0 else d.shape[1]
        d_crop = d[pad_t:h_end, pad_l:w_end]
        # Get confidence mask if available
        conf_mask = None
        if confidence_maps is not None and len(confidence_maps) > 0:
            c = confidence_maps[0].cpu().numpy().squeeze()
            c_crop = c[pad_t:h_end, pad_l:w_end]
            conf_mask = c_crop > 0.5
        # Normalize using only valid (non-near-zero) depth values.
        d_valid = d_crop[d_crop > 0.01]
        if len(d_valid) > 0:
            d_min, d_max = d_valid.min(), d_valid.max()
            d_norm = np.clip(
                (d_crop - d_min) / (d_max - d_min + 1e-6), 0, 1
            )
            # Invert so nearer surfaces get the warm end of the colormap.
            d_norm = (1.0 - d_norm) * 255
            d_norm = d_norm.astype(np.uint8)
            depth_vis_img = cv2.applyColorMap(d_norm, cv2.COLORMAP_TURBO)
            depth_vis_img = cv2.cvtColor(depth_vis_img, cv2.COLOR_BGR2RGB)
            # Apply confidence mask: low-confidence regions -> gray
            if conf_mask is not None:
                gray_bg = np.full_like(depth_vis_img, 200)
                depth_vis_img = np.where(
                    conf_mask[:, :, None], depth_vis_img, gray_bg
                )
            depth_vis_img = Image.fromarray(depth_vis_img)

    # BEV visualization
    bev_html = ""
    mask_bev = scores_2d[0] >= score_thres
    if mask_bev.sum() > 0:
        bev_boxes = boxes3d[0][mask_bev].cpu().numpy()
        bev_cids = class_ids[0][mask_bev].cpu().numpy()
        bev_names = [
            class_id_mapping.get(i, str(i))
            for i in range(max(len(class_id_mapping), 1))
        ]
        bev_html = make_bev_html(bev_boxes, bev_cids, bev_names)

    return img_2d, intrinsics_info, depth_vis_img, bev_html


def visualize_results(
    data,
    boxes3d,
    scores,
    scores_2d,
    scores_3d,
    class_ids,
    class_id_mapping,
    score_thres,
):
    """Visualize 3D detection results using wilddet3d.vis.draw_3d_boxes."""
    # Filter per image by the 2D score threshold.
    filtered_boxes3d = []
    filtered_scores_2d = []
    filtered_scores_3d = []
    filtered_class_ids = []
    for i in range(len(boxes3d)):
        mask = scores_2d[i] >= score_thres
        filtered_boxes3d.append(boxes3d[i][mask])
        if scores_2d is not None:
            filtered_scores_2d.append(scores_2d[i][mask])
        else:
            filtered_scores_2d.append(torch.zeros_like(scores[i][mask]))
        if scores_3d is not None:
            filtered_scores_3d.append(scores_3d[i][mask])
        else:
            filtered_scores_3d.append(torch.zeros_like(scores[i][mask]))
        filtered_class_ids.append(class_ids[i][mask])

    # Get original image (clean, no prompt annotations)
    # original_images is (1, 3, H, W) float32 [0, 255] after ToTensor
    # -> squeeze batch, permute to (H, W, 3) for visualization
    original_img = (
        data["original_images"]
        .cpu().squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
    )

    # Use wilddet3d's draw_3d_boxes for visualization
    K = data["original_intrinsics"].cpu().numpy()
    if K.ndim == 3:
        K = K[0]
    class_names = [
        class_id_mapping.get(i, str(i))
        for i in range(max(len(class_id_mapping), 1))
    ]

    # Draw 3D boxes with 2D/3D score labels
    if len(filtered_boxes3d) > 0 and len(filtered_boxes3d[0]) > 0:
        pil_img = draw_3d_boxes(
            image=original_img,
            boxes3d=filtered_boxes3d[0],
            intrinsics=K,
            scores_2d=filtered_scores_2d[0],
            scores_3d=filtered_scores_3d[0],
            class_ids=filtered_class_ids[0],
            class_names=class_names,
            n_colors=max(len(class_id_mapping), 1),
        )
    else:
        # No detections above threshold: return the clean image as-is.
        pil_img = Image.fromarray(original_img)

    return pil_img


# Load default values
default_fx, default_fy, default_cx, default_cy = load_default_intrinsics()
default_image = load_default_image()

# Build Gradio interface
with gr.Blocks(
    title="WildDet3D: 3D Detection",
    css="""
    .column-form { border: none !important; }
    .gradio-container { max-width: 100% !important; padding: 0 !important; }
    .contain { max-width: 100% !important; }
    * { outline: none !important; }
    .gr-group, .gr-box, .gr-panel, [class*="column"], [class*="Column"] {
        border-color: transparent !important;
        box-shadow: none !important;
    }
    """,
) as demo:
    # ---- Terms of Use gate ----
    # Shown first; the main app stays hidden until the user agrees.
    with gr.Column(visible=True) as terms_page:
        gr.Markdown(
            "# WildDet3D: Scaling Promptable 3D Detection in the Wild\n"
            "### Allen Institute for AI (Ai2)"
        )
        gr.Markdown("""
### WildDet3D Terms of Use

By using WildDet3D, you agree:

- to Ai2's Terms of Use and Responsible Use Guidelines;
- you will not submit or upload personal, sensitive, confidential, or proprietary information to WildDet3D; and
- none of your uploaded content or inputs to WildDet3D will violate Hugging Face's Code of Conduct or Content Policy.
**If you do not agree with any of these statements, please do not access or use WildDet3D.**
        """)
        agree_btn = gr.Button(
            "Agree & Use WildDet3D",
            variant="primary",
            size="lg",
        )

    # ---- Main app (hidden until agreed) ----
    with gr.Column(visible=False) as main_app:
        gr.Markdown(
            "# WildDet3D: Scaling Promptable 3D Detection in the Wild\n"
            "### Allen Institute for AI (Ai2)"
        )
        gr.Markdown("""
**How to use:**
- **Text**: Enter object names (e.g., "car.person.traffic light"), click Run
- **Box-to-Multi-Object**: Draw box -> detect ALL similar objects (one-to-many)
- **Box-to-Single-Object**: Draw box -> detect ONLY the boxed object (one-to-one)
- **Point**: Click on object, click Run
        """)

        # State for click coordinates and original image
        click_state = gr.State({"points": [], "box": []})
        original_image_state = gr.State(
            default_image.copy() if default_image is not None else None
        )

        with gr.Row():
            # Left column: Input
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Input Image (click for Box/Point mode)",
                    type="numpy",
                    value=default_image,
                    interactive=True,
                    sources=["upload", "clipboard"],
                )

                # Prompt settings
                prompt_mode = gr.Radio(
                    choices=[
                        "Text",
                        "Box-to-Multi-Object",
                        "Box-to-Single-Object",
                        "Point",
                    ],
                    value="Box-to-Single-Object",
                    label="Prompt Mode",
                )
                text_prompt = gr.Textbox(
                    label="Text Prompt (separate categories with '.')",
                    value="person.chair.monitor.pen",
                    placeholder="e.g. chair.table.monitor",
                    visible=False,
                )
                # Box mode caption
                box_caption = gr.Markdown(
                    "Click the top-left corner, then the bottom-right corner to draw a box.",
                    visible=True,
                )
                # Point mode caption
                point_caption = gr.Markdown(
                    "Click on the image to add points. Use positive to include, negative to exclude.",
                    visible=False,
                )
                # Point mode controls
                point_label = gr.Radio(
                    choices=["Positive (include)", "Negative (exclude)"],
                    value="Positive (include)",
                    label="Point Label",
                    visible=False,
                )

                # Hidden states (kept for function signatures)
                use_label = gr.Checkbox(value=False, visible=False)
                label_text = gr.Textbox(value="", visible=False)
                click_info = gr.Textbox(value="", visible=False)

                with gr.Row():
                    clear_btn = gr.Button(
                        "Clear Clicks", visible=True
                    )
                    run_btn = gr.Button("Run Detection", variant="primary")

                # Intrinsics settings
                use_actual_K = gr.Checkbox(
                    label="Use Actual Intrinsics (uncheck to use default)",
                    value=False,
                )
                with gr.Row(visible=False) as intrinsics_row:
                    fx = gr.Number(label="fx", value=0)
                    fy = gr.Number(label="fy", value=0)
                    cx = gr.Number(label="cx", value=0)
                    cy = gr.Number(label="cy", value=0)
                score_thres = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.3,
                    step=0.05,
                    label="Score Threshold",
                )

            # Right column: Output
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="3D Detection Results", type="pil"
                )
                # NOTE(review): the placeholder markup around this text
                # appears to have been stripped from the file (likely a
                # <div> wrapper) -- restore from the original source.
                bev_output = gr.HTML(
                    show_label=False,
                    value=""
                    "BEV view will appear after detection"
                    "",
                )
                gr.Markdown(
                    "*Predictions filtered by per-category NMS and "
                    "cross-category NMS. "
                    "For object shape and location, this is raw model "
                    "output without alignment with point cloud "
                    "or the model's own predicted depth.*",
                )
                depth_image = gr.Image(label="Model Predicted Depth Map", type="pil")
                intrinsics_info = gr.Textbox(
                    label="Intrinsics Info", interactive=False
                )

        # Toggle visibility based on prompt mode
        def on_mode_change(mode, orig_img):
            is_text = mode == "Text"
            is_point = mode == "Point"
            is_box = mode in ("Box-to-Multi-Object", "Box-to-Single-Object")
            # Switching modes clears clicks and restores the clean image.
            new_state = {"points": [], "box": []}
            restored_img = orig_img.copy() if orig_img is not None else None
            # Order must match the outputs= list of prompt_mode.change.
            return (
                gr.update(visible=is_text),
                gr.update(visible=is_box),
                gr.update(visible=is_point),
                gr.update(visible=is_point),
                gr.update(visible=not is_text),
                new_state,
                restored_img,
            )

        prompt_mode.change(
            on_mode_change,
            inputs=[prompt_mode, original_image_state],
            outputs=[
                text_prompt,
                box_caption,
                point_caption,
                point_label,
                clear_btn,
                click_state,
                input_image,
            ],
        )

        # Toggle intrinsics input visibility
        def on_intrinsics_toggle(use_actual):
            return gr.update(visible=use_actual)

        use_actual_K.change(
            on_intrinsics_toggle,
            inputs=[use_actual_K],
            outputs=[intrinsics_row],
        )

        # Connect events
        input_image.select(
            on_image_select,
            inputs=[
                input_image,
                original_image_state,
                click_state,
                prompt_mode,
                point_label,
            ],
            outputs=[click_state, click_info, input_image],
        )
        clear_btn.click(
            clear_clicks,
            inputs=[click_state, original_image_state],
            outputs=[click_state, click_info, input_image],
        )

        # When new image is uploaded, save it as original
        def on_image_upload(image):
            if image is None:
                return None, {"points": [], "box": []}, "Upload an image"
            return (
                image.copy(),
                {"points": [], "box": []},
                "Image loaded - select mode and click",
            )

        input_image.upload(
            on_image_upload,
            inputs=[input_image],
            outputs=[original_image_state, click_state, click_info],
        )

        run_btn.click(
            run_wilddet3d,
            inputs=[
                input_image,
                original_image_state,
                click_state,
                prompt_mode,
                text_prompt,
                use_label,
                label_text,
                score_thres,
                use_actual_K,
                fx,
                fy,
                cx,
                cy,
            ],
            outputs=[output_image, intrinsics_info, depth_image, bev_output],
        )

    # ---- Terms agree handler ----
    def on_agree():
        # Hide the terms gate, reveal the main app.
        return gr.update(visible=False), gr.update(visible=True)

    agree_btn.click(
        on_agree,
        inputs=[],
        outputs=[terms_page, main_app],
    )


if __name__ == "__main__":
    print("=" * 60)
    print("WildDet3D Web Demo")
    print("=" * 60)
    print()
    print("Starting server...")
    # GRADIO_SERVER_PORT lets deployments override the default 7860.
    port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo.launch(share=False, server_name="0.0.0.0", server_port=port)