Spaces:
Running on Zero
Running on Zero
| """ | |
| VOID VLM-Mask-Reasoner β Quadmask Generation Demo | |
| Generates 4-level semantic masks for interaction-aware video inpainting. | |
| Pipeline from https://github.com/Netflix/void-model: | |
| Stage 1: SAM2 segmentation β black mask (transformers Sam2Model) | |
| Stage 2: Gemini VLM scene analysis β affected objects JSON (repo code) | |
| Stage 3: SAM3 text-prompted segmentation β grey mask (transformers Sam3Model) | |
| Stage 4: Combine black + grey β quadmask (0/63/127/255) (repo code) | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import tempfile | |
| import shutil | |
| import subprocess | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| import imageio | |
| from PIL import Image, ImageDraw | |
| from huggingface_hub import hf_hub_download | |
| import openai | |
| # ββ Add repo modules to path βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "VLM-MASK-REASONER")) | |
| # ββ Repo imports: Stage 2 (VLM) and Stage 4 (combine) ββββββββββββββββββββββββ | |
| from stage2_vlm_analysis import ( | |
| process_video as vlm_process_video, | |
| calculate_square_grid, | |
| ) | |
| from stage4_combine_masks import process_video as combine_process_video | |
| # Stage 3 helpers (grid logic, mask combination β not the SegmentationModel) | |
| from stage3a_generate_grey_masks_v2 import ( | |
| calculate_square_grid as calc_grid_3a, | |
| gridify_masks, | |
| filter_masks_by_proximity, | |
| segment_object_all_frames as _repo_segment_all_frames, | |
| process_video_grey_masks, | |
| ) | |
# -- Constants -----------------------------------------------------------------
SAM2_MODEL_ID = "facebook/sam2.1-hiera-large"  # HF id of the SAM2 video model (Stage 1)
SAM3_MODEL_ID = "jetjodh/sam3"                 # HF id of the text-prompted SAM3 model (Stage 3)
DEFAULT_VLM_MODEL = "gemini-3-flash-preview"   # Gemini model used for Stage 2 scene reasoning
MAX_FRAMES = 197   # Hard cap on frames read from the uploaded video
FPS_DEFAULT = 12   # Fallback FPS when the container reports none/zero
FRAME_STRIDE = 4  # Process every Nth frame for SAM2 tracking
# -- Load transformers SAM2 (video model with propagation support) -------------
# Models are loaded once at module import so every request reuses them.
print("Loading SAM2 video model (transformers)...")
from transformers import Sam2VideoModel, Sam2VideoProcessor
from transformers.models.sam2_video.modeling_sam2_video import Sam2VideoInferenceSession
sam2_model = Sam2VideoModel.from_pretrained(SAM2_MODEL_ID).to("cuda")
sam2_processor = Sam2VideoProcessor.from_pretrained(SAM2_MODEL_ID)
print("SAM2 video model ready.")
# -- Load transformers SAM3 ----------------------------------------------------
print("Loading SAM3 model (transformers)...")
from transformers import Sam3Model, Sam3Processor
sam3_model = Sam3Model.from_pretrained(SAM3_MODEL_ID).to("cuda")
sam3_processor = Sam3Processor.from_pretrained(SAM3_MODEL_ID)
print("SAM3 ready.")
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # STAGE 1: SAM2 VIDEO SEGMENTATION (transformers Sam2VideoModel) | |
| # Uses proper video propagation with memory β matches repo's propagate_in_video | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def stage1_segment_video(frames: list, points: list, **kwargs) -> list:
    """Segment the primary object across all video frames via SAM2 video propagation.

    Matches the repo pipeline: point prompts plus a derived bounding box on
    frame 0, then memory-based propagation through the whole clip.

    Args:
        frames: List of HxWx3 uint8 RGB frames (all the same size).
        points: List of [x, y] pixel coordinates on the primary object.
        **kwargs: Accepted for caller compatibility (e.g. ``stride``) but
            unused; SAM2's propagation already visits every frame.

    Returns:
        List of HxW uint8 masks, one per frame (0 = object, 255 = background).
    """
    total = len(frames)
    h, w = frames[0].shape[:2]
    # Preprocess all frames for the video model.
    pil_frames = [Image.fromarray(f) for f in frames]
    inputs = sam2_processor(images=pil_frames, return_tensors="pt").to(sam2_model.device)
    # Create an inference session holding the whole clip.
    session = Sam2VideoInferenceSession(
        video=inputs["pixel_values"],
        video_height=h,
        video_width=w,
        inference_device=sam2_model.device,
        inference_state_device=sam2_model.device,
        dtype=torch.float32,
    )
    # Derive a bounding box around the clicked points with a 10% (min 10px)
    # margin, clamped to the frame; box + points give SAM2 a stronger prompt.
    pts = np.array(points, dtype=np.float32)
    x_min, x_max = pts[:, 0].min(), pts[:, 0].max()
    y_min, y_max = pts[:, 1].min(), pts[:, 1].max()
    x_margin = max((x_max - x_min) * 0.1, 10)
    y_margin = max((y_max - y_min) * 0.1, 10)
    box = [
        max(0, x_min - x_margin),
        max(0, y_min - y_margin),
        min(w, x_max + x_margin),
        min(h, y_max + y_margin),
    ]
    # The processor handles normalization, object registration, and
    # obj_with_new_inputs bookkeeping.
    sam2_processor.process_new_points_or_boxes_for_video_frame(
        inference_session=session,
        frame_idx=0,
        obj_ids=[1],
        input_points=[[[[float(p[0]), float(p[1])] for p in points]]],
        input_labels=[[[1] * len(points)]],
        input_boxes=[[[float(box[0]), float(box[1]), float(box[2]), float(box[3])]]],
    )
    # Run forward on the prompted frame first (populates cond_frame_outputs).
    with torch.no_grad():
        sam2_model(session, frame_idx=0)
    # Propagate through all frames (matches repo's propagate_in_video).
    video_segments = {}
    with torch.no_grad():
        for output in sam2_model.propagate_in_video_iterator(session):
            frame_idx = output.frame_idx
            # pred_masks shape varies -- take the first object's raw logits.
            mask_logits = output.pred_masks[0].cpu().float()
            # Ensure 4D for interpolation: (1, 1, H_model, W_model).
            while mask_logits.dim() < 4:
                mask_logits = mask_logits.unsqueeze(0)
            mask_resized = torch.nn.functional.interpolate(
                mask_logits, size=(h, w), mode="bilinear", align_corners=False
            )
            # Logit > 0 corresponds to probability > 0.5.
            video_segments[frame_idx] = (mask_resized.squeeze() > 0.0).numpy()
    # Convert to uint8 masks (0=object, 255=background). Frames the propagation
    # skipped borrow the nearest computed mask; an (unexpected) empty result
    # falls back to all-background instead of crashing on min() of no keys.
    all_masks = []
    for idx in range(total):
        if idx in video_segments:
            mask_bool = video_segments[idx]
        elif video_segments:
            nearest = min(video_segments.keys(), key=lambda k: abs(k - idx))
            mask_bool = video_segments[nearest]
        else:
            mask_bool = np.zeros((h, w), dtype=bool)
        all_masks.append(np.where(mask_bool, 0, 255).astype(np.uint8))
    return all_masks
def write_mask_video(masks: list, fps: float, output_path: str):
    """Write a list of uint8 grayscale masks to a lossless MP4.

    Two-step encode: OpenCV writes a lossless FFV1 AVI first (cv2 cannot emit
    lossless H.264 directly), then ffmpeg transcodes to x264 with qp=0 and
    yuv444p so mask pixel values survive bit-exactly.

    Args:
        masks: List of HxW uint8 grayscale masks (all the same size).
        fps: Output frame rate.
        output_path: Destination .mp4 path.

    Raises:
        RuntimeError: If the ffmpeg transcode fails.
    """
    h, w = masks[0].shape[:2]
    temp_avi = str(Path(output_path).with_suffix('.avi'))
    fourcc = cv2.VideoWriter_fourcc(*'FFV1')
    out = cv2.VideoWriter(temp_avi, fourcc, fps, (w, h), isColor=False)
    try:
        for mask in masks:
            out.write(mask)
    finally:
        out.release()
    cmd = [
        'ffmpeg', '-y', '-i', temp_avi,
        '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
        '-pix_fmt', 'yuv444p', str(output_path),
    ]
    try:
        result = subprocess.run(cmd, capture_output=True)
        # The original ignored the ffmpeg result, so an encode failure only
        # surfaced later as a missing file; raise with stderr context instead.
        if result.returncode != 0:
            stderr = result.stderr.decode(errors='replace') if result.stderr else ''
            raise RuntimeError(f"ffmpeg transcode failed for {output_path}: {stderr[-500:]}")
    finally:
        Path(temp_avi).unlink(missing_ok=True)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # STAGE 3: SAM3 TEXT-PROMPTED SEGMENTATION (transformers) | |
| # β Drop-in replacement for repo's SegmentationModel.segment() | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
class TransformersSam3Segmenter:
    """Drop-in stand-in for the repo's SegmentationModel, backed by transformers SAM3.

    Exposes the same ``segment(image, prompt)`` interface stage3a expects.
    """
    model_type = "sam3"

    def segment(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
        """Segment all instances matching a text prompt; return their union as a bool mask.

        Failures are logged and swallowed: an all-False mask is returned so a
        single bad prompt never aborts the whole grey-mask stage.
        """
        height, width = image_pil.height, image_pil.width
        combined = np.zeros((height, width), dtype=bool)
        try:
            model_inputs = sam3_processor(
                images=image_pil, text=prompt, return_tensors="pt"
            ).to(sam3_model.device)
            with torch.no_grad():
                model_outputs = sam3_model(**model_inputs)
            result = sam3_processor.post_process_instance_segmentation(
                model_outputs,
                threshold=0.3,
                mask_threshold=0.5,
                target_sizes=model_inputs.get("original_sizes").tolist(),
            )[0]
            instance_masks = result.get("masks")
            if instance_masks is not None and len(instance_masks) > 0:
                if torch.is_tensor(instance_masks):
                    instance_masks = instance_masks.cpu().numpy()
                # Collapse whatever leading instance dimensions are present.
                if instance_masks.ndim == 2:
                    combined = instance_masks.astype(bool)
                elif instance_masks.ndim == 3:
                    combined = instance_masks.any(axis=0).astype(bool)
                elif instance_masks.ndim == 4:
                    combined = instance_masks.any(axis=(0, 1)).astype(bool)
        except Exception as e:
            print(f" Warning: SAM3 segmentation failed for '{prompt}': {e}")
        return combined

seg_model = TransformersSam3Segmenter()
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # HELPERS | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def extract_frames(video_path: str, max_frames: int = MAX_FRAMES):
    """Read up to ``max_frames`` frames from a video as RGB arrays.

    Returns:
        Tuple ``(frames, fps)`` where frames is a list of HxWx3 uint8 arrays
        and fps falls back to FPS_DEFAULT when the container reports none.
    """
    capture = cv2.VideoCapture(video_path)
    fps = capture.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT
    collected = []
    for _ in range(max_frames):
        ok, bgr = capture.read()
        if not ok:
            break
        collected.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    capture.release()
    return collected, fps
def draw_points_on_image(image: np.ndarray, points: list, radius: int = 6) -> np.ndarray:
    """Return a copy of ``image`` with numbered red dots at the given (x, y) points."""
    canvas = Image.fromarray(image.copy())
    pen = ImageDraw.Draw(canvas)
    for idx, (px, py) in enumerate(points):
        bbox = [px - radius, py - radius, px + radius, py + radius]
        pen.ellipse(bbox, fill="red", outline="white", width=2)
        # Label each dot with its 1-based click order, just right of the dot.
        pen.text((px + radius + 2, py - radius), str(idx + 1), fill="white")
    return np.array(canvas)
def frames_to_video(frames: list, fps: float) -> str:
    """Encode RGB frames into a temporary H.264 MP4 and return its path."""
    handle, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(handle)  # imageio reopens the path itself
    writer = imageio.get_writer(out_path, fps=fps, codec='libx264',
                                output_params=['-crf', '18', '-pix_fmt', 'yuv420p'])
    for frame_rgb in frames:
        writer.append_data(frame_rgb)
    writer.close()
    return out_path
def create_quadmask_visualization(video_path: str, quadmask_path: str) -> str:
    """Overlay color-coded quadmask levels on the original video.

    Red = primary object (0), yellow = overlap (63), green = affected zone
    (127); background (255) is left untouched. Returns the rendered MP4 path,
    or None when no frame pair could be read.
    """
    src = cv2.VideoCapture(video_path)
    qsrc = cv2.VideoCapture(quadmask_path)
    fps = src.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT
    rendered = []
    while True:
        ok_src, bgr = src.read()
        ok_qm, raw_qm = qsrc.read()
        if not (ok_src and ok_qm):
            break
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        gray = cv2.cvtColor(raw_qm, cv2.COLOR_BGR2GRAY) if raw_qm.ndim == 3 else raw_qm
        # Snap any compression drift back to the 4 canonical quadmask levels.
        levels = np.where(gray <= 31, 0, gray)
        levels = np.where((levels > 31) & (levels <= 95), 63, levels)
        levels = np.where((levels > 95) & (levels <= 191), 127, levels)
        levels = np.where(levels > 191, 255, levels)
        tint = rgb.copy()
        tint[levels == 0] = [255, 50, 50]
        tint[levels == 63] = [255, 200, 0]
        tint[levels == 127] = [50, 255, 50]
        blended = cv2.addWeighted(rgb, 0.5, tint, 0.5, 0)
        blended[levels == 255] = rgb[levels == 255]  # keep background at full fidelity
        rendered.append(blended)
    src.release()
    qsrc.release()
    return frames_to_video(rendered, fps) if rendered else None
def create_quadmask_color_video(quadmask_path: str) -> str:
    """Render the quadmask as a standalone grayscale-coded video.

    Black = primary object, dark grey = overlap, mid grey = affected zone,
    white = background. Returns the MP4 path, or None when no frames read.
    """
    reader = cv2.VideoCapture(quadmask_path)
    fps = reader.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT
    rendered = []
    while True:
        ok, raw = reader.read()
        if not ok:
            break
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY) if raw.ndim == 3 else raw
        # Snap any compression drift back to the 4 canonical quadmask levels.
        levels = np.where(gray <= 31, 0, gray)
        levels = np.where((levels > 31) & (levels <= 95), 63, levels)
        levels = np.where((levels > 95) & (levels <= 191), 127, levels)
        levels = np.where(levels > 191, 255, levels)
        rows, cols = levels.shape
        frame_rgb = np.full((rows, cols, 3), 255, dtype=np.uint8)  # white background
        frame_rgb[levels == 0] = [0, 0, 0]
        frame_rgb[levels == 63] = [80, 80, 80]
        frame_rgb[levels == 127] = [160, 160, 160]
        rendered.append(frame_rgb)
    reader.release()
    return frames_to_video(rendered, fps) if rendered else None
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MAIN PIPELINE | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def run_pipeline(video_path: str, points_json: str, instruction: str,
                 progress=gr.Progress(track_tqdm=False)):
    """Run the full VLM-Mask-Reasoner pipeline end to end.

    Args:
        video_path: Path to the uploaded video file.
        points_json: JSON-encoded list of [x, y] points on the primary object.
        instruction: Natural-language edit instruction (what to remove).
        progress: Gradio progress reporter.

    Returns:
        Tuple of (quadmask file path, overlay video path, color-coded video
        path, VLM analysis JSON text).

    Raises:
        gr.Error: On missing or invalid user input.
    """
    # -- Input validation ----------------------------------------------------
    if not video_path:
        raise gr.Error("Please upload a video.")
    if not points_json or points_json == "[]":
        raise gr.Error("Please click on the image to select at least one point on the primary object.")
    if not instruction.strip():
        raise gr.Error("Please enter an edit instruction.")
    points = json.loads(points_json)
    if len(points) == 0:
        raise gr.Error("Please select at least one point on the primary object.")
    api_key = os.environ.get("GEMINI_API_KEY", "")
    # Create temp output directory shared by all four stages.
    output_dir = Path(tempfile.mkdtemp(prefix="void_quadmask_"))
    input_video_path = output_dir / "input_video.mp4"
    shutil.copy2(video_path, input_video_path)
    # -- Stage 1: SAM2 segmentation -> black mask ----------------------------
    progress(0.05, desc="Stage 1: SAM2 segmentation...")
    frames, fps = extract_frames(str(input_video_path))
    if len(frames) < 2:
        raise gr.Error("Video must have at least 2 frames.")
    black_masks = stage1_segment_video(frames, points, stride=FRAME_STRIDE)
    black_mask_path = output_dir / "black_mask.mp4"
    write_mask_video(black_masks, fps, str(black_mask_path))
    # Save first frame for VLM analysis.
    first_frame_path = output_dir / "first_frame.jpg"
    cv2.imwrite(str(first_frame_path), cv2.cvtColor(frames[0], cv2.COLOR_RGB2BGR))
    # Save segmentation metadata (Stage 2 reads this file from output_dir).
    seg_info = {
        "total_frames": len(frames),
        "frame_width": frames[0].shape[1],
        "frame_height": frames[0].shape[0],
        "fps": fps,
        "video_path": str(input_video_path),
        "instruction": instruction,
        "primary_points_by_frame": {"0": points},
        "first_appears_frame": 0,
    }
    with open(output_dir / "segmentation_info.json", 'w') as f:
        json.dump(seg_info, f, indent=2)
    progress(0.3, desc="Stage 1 complete.")
    # -- Stage 2: VLM analysis (repo code) -----------------------------------
    # Without an API key (or on VLM failure) the pipeline degrades gracefully
    # to a binary mask: no grey mask, black mask copied as the quadmask.
    analysis = None
    if api_key:
        progress(0.35, desc="Stage 2: VLM analysis (calling Gemini)...")
        try:
            video_info = {
                "video_path": str(input_video_path),
                "instruction": instruction,
                "output_dir": str(output_dir),
                "multi_frame_grids": True,
            }
            # Gemini is reached through its OpenAI-compatible endpoint.
            client = openai.OpenAI(
                api_key=api_key,
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
            )
            analysis = vlm_process_video(video_info, client, DEFAULT_VLM_MODEL)
            progress(0.55, desc="Stage 2 complete.")
        except Exception as e:
            gr.Warning(f"VLM analysis failed: {e}. Generating binary mask only.")
            analysis = None
    else:
        gr.Warning("No GEMINI_API_KEY set. Generating binary mask only (no VLM analysis).")
    # -- Stage 3: grey mask generation (repo logic + transformers SAM3) ------
    grey_mask_path = output_dir / "grey_mask.mp4"
    vlm_analysis_path = output_dir / "vlm_analysis.json"
    # Only run if Stage 2 produced an analysis AND wrote its JSON to disk.
    if analysis and vlm_analysis_path.exists():
        progress(0.6, desc="Stage 3: Generating grey masks (SAM3 segmentation)...")
        try:
            video_info_3 = {
                "video_path": str(input_video_path),
                "output_dir": str(output_dir),
                "min_grid": 8,
            }
            # Repo's process_video_grey_masks driven by our TransformersSam3Segmenter.
            process_video_grey_masks(video_info_3, seg_model)
            progress(0.8, desc="Stage 3 complete.")
        except Exception as e:
            gr.Warning(f"Stage 3 failed: {e}. Generating binary mask only.")
    # -- Stage 4: combine black + grey into quadmask (repo code) -------------
    quadmask_path = output_dir / "quadmask_0.mp4"
    if grey_mask_path.exists():
        progress(0.85, desc="Stage 4: Combining into quadmask...")
        # NOTE(review): Path objects are passed through here; assumed the repo
        # function accepts os.PathLike -- confirm against stage4_combine_masks.
        combine_process_video(black_mask_path, grey_mask_path, quadmask_path)
    else:
        # No grey mask produced -> the binary black mask doubles as the quadmask.
        shutil.copy2(black_mask_path, quadmask_path)
    progress(0.9, desc="Creating visualizations...")
    # -- Visualization outputs -----------------------------------------------
    overlay_path = create_quadmask_visualization(str(input_video_path), str(quadmask_path))
    color_path = create_quadmask_color_video(str(quadmask_path))
    analysis_text = ""
    if vlm_analysis_path.exists():
        with open(vlm_analysis_path) as f:
            analysis_text = f.read()
    else:
        analysis_text = "No VLM analysis available."
    progress(1.0, desc="Done!")
    return str(quadmask_path), overlay_path, color_path, analysis_text
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # GRADIO UI | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def on_video_upload(video_path):
    """Show the first frame for point selection, reset points, toggle the button."""
    if not video_path:
        return None, None, "[]", gr.update(interactive=False)
    first_frames, _ = extract_frames(video_path, max_frames=1)
    if not first_frames:
        return None, None, "[]", gr.update(interactive=False)
    frame = first_frames[0]
    # Same frame twice: once for display, once kept pristine for redraws.
    return frame, frame, "[]", gr.update(interactive=True)
def on_frame_select(clean_frame, points_json, evt: gr.SelectData):
    """Append the clicked pixel to the point list and redraw all markers."""
    if clean_frame is None:
        return None, points_json
    selected = json.loads(points_json) if points_json else []
    click_x, click_y = evt.index
    selected.append([int(click_x), int(click_y)])
    # Always redraw from the untouched frame so markers never stack.
    annotated = draw_points_on_image(clean_frame, selected)
    return annotated, json.dumps(selected)
def on_clear_points(clean_frame):
    """Reset the selected points; restore the clean first frame when one exists."""
    if clean_frame is None:
        return None, "[]"
    return clean_frame, "[]"
# Markdown shown at the top of the app (rendered by gr.Markdown below).
DESCRIPTION = """
# VOID VLM-Mask-Reasoner β Quadmask Generation
Generate **4-level semantic masks** (quadmasks) for interaction-aware video inpainting with [VOID](https://github.com/Netflix/void-model).
**Pipeline:** Click points on object β SAM2 segments it β Gemini VLM reasons about interactions β SAM3 segments affected objects β Quadmask generated
Use the generated quadmask with the [VOID inpainting demo](https://huggingface.co/spaces/sam-motamed/VOID).
"""
# Legend explaining the four quadmask pixel levels (rendered below the UI).
QUADMASK_EXPLAINER = """
### Quadmask format
| Pixel Value | Color | Meaning |
|-------------|-------|---------|
| **0** (black) | Red overlay | Primary object to remove |
| **63** (dark grey) | Yellow overlay | Overlap of primary + affected zone |
| **127** (mid grey) | Green overlay | Affected region (shadows, reflections, physics) |
| **255** (white) | Original | Background β keep as-is |
"""
# -- Gradio UI: layout and event wiring ---------------------------------------
with gr.Blocks(title="VOID VLM-Mask-Reasoner", theme=gr.themes.Default()) as demo:
    gr.Markdown(DESCRIPTION)
    # Hidden textbox holding the points as a JSON string (also an API surface
    # via its elem_id); gr.State keeps the pristine first frame for redraws.
    points_state = gr.Textbox(value="[]", visible=False, label="points_json_api", elem_id="points_json_api")
    clean_frame_state = gr.State(None)
    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            video_input = gr.Video(label="Upload Video", sources=["upload"])
            frame_display = gr.Image(
                label="Click to select primary object points (click multiple spots on the object)",
                interactive=True, type="numpy",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear Points", size="sm")
            points_display = gr.Textbox(label="Selected Points", value="[]",
                                        interactive=False, max_lines=2)
            instruction_input = gr.Textbox(
                label="Edit instruction β describe what to remove",
                placeholder="e.g., remove the person", lines=1,
            )
            generate_btn = gr.Button("Generate Quadmask", variant="primary", size="lg")
        # Right column: outputs.
        with gr.Column(scale=1):
            output_quadmask_file = gr.File(label="Download lossless quadmask_0.mp4 (use this with VOID)")
            with gr.Tabs():
                with gr.TabItem("Quadmask Overlay"):
                    output_overlay = gr.Video(label="Quadmask overlay on original video")
                with gr.TabItem("Raw Quadmask"):
                    output_color = gr.Video(label="Color-coded quadmask")
                with gr.TabItem("VLM Analysis"):
                    output_analysis = gr.Code(label="VLM Analysis JSON", language="json")
    # New upload: show first frame, reset points, enable the generate button.
    video_input.change(
        fn=on_video_upload, inputs=[video_input],
        outputs=[frame_display, clean_frame_state, points_state, generate_btn],
    )
    # Mirror the hidden JSON state into the visible read-only textbox.
    points_state.change(lambda p: p, inputs=points_state, outputs=points_display)
    # Click on the frame: append point and redraw markers.
    frame_display.select(
        fn=on_frame_select, inputs=[clean_frame_state, points_state],
        outputs=[frame_display, points_state],
    )
    clear_btn.click(
        fn=on_clear_points, inputs=[clean_frame_state],
        outputs=[frame_display, points_state],
    )
    # Main action: run the 4-stage pipeline.
    generate_btn.click(
        fn=run_pipeline, inputs=[video_input, points_state, instruction_input],
        outputs=[output_quadmask_file, output_overlay, output_color, output_analysis],
    )
    gr.Markdown(QUADMASK_EXPLAINER)
if __name__ == "__main__":
    demo.launch()