multimodalart (HF Staff) and Claude Opus 4.6 (1M context) committed
Commit 535cf33 · 1 parent: 0edb2a4

Add VOID VLM-Mask-Reasoner quadmask generation demo


4-stage pipeline matching the Netflix VOID repo:
- Stage 1: SAM2 video segmentation (transformers Sam2VideoModel)
- Stage 2: Gemini VLM scene analysis (repo code directly)
- Stage 3: SAM3 text-prompted segmentation (transformers Sam3Model)
- Stage 4: Lossless quadmask combination (repo code directly)

Gradio UI with point-click selection, progress tracking,
overlay visualization, and lossless quadmask download.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

README.md CHANGED
@@ -1,13 +1,13 @@
  ---
  title: VOID Quadmask Reasoner
- emoji: 🚀
- colorFrom: green
- colorTo: yellow
+ emoji: 🎭
+ colorFrom: gray
+ colorTo: purple
  sdk: gradio
  sdk_version: 6.11.0
  python_version: '3.12'
  app_file: app.py
  pinned: false
+ license: mit
+ short_description: 'VLM-Mask-Reasoner: Generate quadmasks for VOID video inpainting'
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
VLM-MASK-REASONER/README.md ADDED
@@ -0,0 +1,116 @@
# VLM Mask Reasoner — Mask Generation Pipeline

Generates quadmasks for video inpainting by combining user-guided SAM2 segmentation with VLM (Gemini) scene reasoning. The output `quadmask_0.mp4` encodes four semantic layers that the inpainting model uses to understand what to remove and what to preserve.

---

## Quadmask Values

| Value | Meaning |
|-------|---------|
| `0` | Primary object (to be removed) |
| `63` | Overlap of primary and affected regions |
| `127` | Affected objects (shadows, reflections, held items) |
| `255` | Background — keep as-is |

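The four values can be derived from two binary masks, one for the primary object and one for the affected regions. A minimal sketch of that encoding, assuming two boolean NumPy arrays of the same shape (the function name and signature are illustrative, not the repo's API):

```python
import numpy as np

def combine_quadmask(primary: np.ndarray, affected: np.ndarray) -> np.ndarray:
    """Encode a primary (black) mask and an affected (grey) mask into a
    single quadmask frame using the value table above."""
    quad = np.full(primary.shape, 255, dtype=np.uint8)  # background everywhere
    quad[affected] = 127                                # affected objects
    quad[primary] = 0                                   # primary object
    quad[primary & affected] = 63                       # overlap of both
    return quad
```

Because the assignments go from weakest to strongest claim (background, affected, primary, overlap), each pixel ends up with exactly one of the four values.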
---

## Step 1 — Select Points via GUI

Launch the point selector GUI:

```bash
python point_selector_gui.py
```

Use this GUI to place sparse click points on the object(s) you want removed.

**A few things to know:**

- Points can be placed on **any frame**, not just the first. If the object you want to remove only appears later in the video, navigate to that frame and click there.
- You can place points across **multiple frames** — useful when there are multiple distinct objects to remove, or when an object's position shifts significantly over time.
- The GUI saves your selections to a `config_points.json` file. Keep track of where this is saved — you'll pass it to the pipeline next.

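The exact schema of `config_points.json` is defined by `point_selector_gui.py`; the layout below is purely hypothetical, shown only to illustrate a quick sanity check before launching the pipeline:

```python
import json

# Hypothetical layout -- the real schema is whatever point_selector_gui.py
# writes; field names here are illustrative only.
example = {
    "video_path": "input_video.mp4",
    "output_dir": "results/clip_0",
    "points": [
        {"frame": 0, "x": 412, "y": 230, "label": 1},
        {"frame": 48, "x": 150, "y": 310, "label": 1},
    ],
}

def has_points(config: dict) -> bool:
    """Basic sanity check: at least one click point is present."""
    return bool(config.get("points"))

# Round-trip through JSON as if the config had been read from disk.
config = json.loads(json.dumps(example))
```

A check like this catches an empty or mis-saved config before any GPU time is spent.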
---

## Step 2 — Run the Pipeline

Once you have your `config_points.json`, run all stages with a single command:

```bash
bash run_pipeline.sh <config_points.json>
```

Optional flags:

```bash
bash run_pipeline.sh <config_points.json> \
    --sam2-checkpoint ../sam2_hiera_large.pt \
    --device cuda
```

This runs four stages automatically:

1. **Stage 1 — SAM2 Segmentation:** Propagates your point clicks into a per-frame black mask for the primary object.
2. **Stage 2 — VLM Analysis (Gemini):** Analyzes the scene to identify affected objects — things like shadows, reflections, or items the primary object is interacting with.
3. **Stage 3 — Grey Mask Generation:** Produces a grey mask track for the affected objects identified in Stage 2.
4. **Stage 4 — Combine into Quadmask:** Merges the black and grey masks into the final `quadmask_0.mp4`.

The output `quadmask_0.mp4` is written into each video's `output_dir` as specified in the config.

> **Note on grey values in frame 1:** The inpainting model was trained with grey-valued regions (`127`) starting from frame 1 onward — not on the very first frame. We find this convention improves inference quality, so the pipeline automatically clears any grey pixels from frame 0 of the final quadmask before saving.

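The frame-0 convention in the note above maps directly onto the quadmask values: grey-only pixels revert to background, and overlap pixels keep their black component. A minimal sketch, assuming the quadmask is a list of `uint8` NumPy frames (the helper name is illustrative, not the repo's API):

```python
import numpy as np

def clear_grey_in_first_frame(frames: list) -> None:
    """Clear grey regions from frame 0 in place:
    127 (affected only) -> 255 (background),
    63 (overlap)        -> 0   (primary, black kept)."""
    if not frames:
        return
    f0 = frames[0]
    f0[f0 == 127] = 255
    f0[f0 == 63] = 0
```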
---

## Step 3 (Optional) — Manual Mask Correction

If the generated quadmask needs refinement, you can correct it interactively:

```bash
python edit_quadmask.py
```

Point the GUI to the folder containing `quadmask_0.mp4`. You can paint over regions frame by frame to fix any mask errors before running inference. The corrected mask is saved back to `quadmask_0.mp4` in the same folder.

---

## Installation & Dependencies

### 1. Python dependencies

Install the main requirements from the repo root:

```bash
pip install -r requirements.txt
```

### 2. SAM2

SAM2 must be installed separately (it is not on PyPI):

```bash
pip install git+https://github.com/facebookresearch/segment-anything-2.git
```

Then download the SAM2 checkpoint. The pipeline defaults to `sam2_hiera_large.pt` one level above this directory:

```bash
# from the repo root (or wherever you want to store it)
wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt
```

If you place the checkpoint elsewhere, pass it explicitly:

```bash
bash run_pipeline.sh config_points.json --sam2-checkpoint /path/to/sam2_hiera_large.pt
```

> SAM2 requires **Python ≥ 3.10** and **PyTorch ≥ 2.3.1** with CUDA. See the [SAM2 repo](https://github.com/facebookresearch/segment-anything-2) for full system requirements.

### 3. Gemini API key

Stage 2 uses the Gemini VLM. Export your API key before running the pipeline:

```bash
export GEMINI_API_KEY="your_key_here"
```
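Stage 2 will fail without this variable set, so it is worth checking before any GPU work starts. A small sketch of a fail-fast guard (the pipeline's own error handling may differ; the helper name is illustrative):

```python
import os

def require_gemini_key() -> str:
    """Return the Gemini API key from the environment, or raise a clear
    error before any pipeline work starts."""
    key = os.environ.get("GEMINI_API_KEY", "").strip()
    if not key:
        raise RuntimeError(
            'GEMINI_API_KEY is not set. Run: export GEMINI_API_KEY="your_key_here"'
        )
    return key
```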
VLM-MASK-REASONER/__pycache__/stage1_sam2_segmentation.cpython-312.pyc ADDED
Binary file (19.3 kB)

VLM-MASK-REASONER/__pycache__/stage2_vlm_analysis.cpython-312.pyc ADDED
Binary file (42.6 kB)

VLM-MASK-REASONER/__pycache__/stage3a_generate_grey_masks_v2.cpython-312.pyc ADDED
Binary file (26 kB)

VLM-MASK-REASONER/__pycache__/stage4_combine_masks.cpython-312.pyc ADDED
Binary file (9.67 kB)
VLM-MASK-REASONER/edit_quadmask.py ADDED
@@ -0,0 +1,831 @@
#!/usr/bin/env python3
"""
Mask Editor GUI - Edit gridified video masks with grid toggling and brush tools
"""
import cv2
import numpy as np
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
from PIL import Image, ImageTk
import subprocess
from pathlib import Path
import copy
import time

class MaskEditorGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("Mask Editor")

        # Video data
        self.rgb_frames = []
        self.mask_frames = []
        self.current_frame = 0
        self.grid_rows = 0
        self.grid_cols = 0
        self.min_grid = 8

        # Edit state
        self.undo_stack = []
        self.redo_stack = []
        self.current_tool = "grid"  # "grid" or "brush"
        self.brush_size = 20
        self.brush_mode = "add"  # "add" or "erase"

        # Display state
        self.display_scale = 1.0
        self.rgb_photo = None
        self.mask_photo = None
        self.dragging = False
        self.last_brush_pos = None
        self.last_update_time = 0
        self.update_interval = 0.2  # Update every 200ms during dragging (5 FPS - less choppy)
        self.cached_rgb_frame = None  # Cache current RGB frame
        self.cached_frame_idx = -1  # Track which frame is cached
        self.pending_update = False  # Track if update is needed after drag
        self.brush_repeat_id = None  # Timer for continuous brush application

        # Paths
        self.folder_path = None
        self.mask_path = None
        self.rgb_path = None

        self.setup_ui()

    def setup_ui(self):
        """Setup the GUI layout"""
        # Menu bar
        menubar = tk.Menu(self.root)
        self.root.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="File", menu=file_menu)
        file_menu.add_command(label="Open Folder", command=self.load_folder)
        file_menu.add_command(label="Save Mask", command=self.save_mask)
        file_menu.add_separator()
        file_menu.add_command(label="Exit", command=self.root.quit)

        edit_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="Edit", menu=edit_menu)
        edit_menu.add_command(label="Undo", command=self.undo, accelerator="Ctrl+Z")
        edit_menu.add_command(label="Redo", command=self.redo, accelerator="Ctrl+Y")

        # Keyboard shortcuts
        self.root.bind("<Control-z>", lambda e: self.undo())
        self.root.bind("<Control-y>", lambda e: self.redo())
        self.root.bind("<Left>", lambda e: self.prev_frame())
        self.root.bind("<Right>", lambda e: self.next_frame())

        # Top toolbar
        toolbar = ttk.Frame(self.root)
        toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        ttk.Label(toolbar, text="Folder:").pack(side=tk.LEFT)
        self.folder_label = ttk.Label(toolbar, text="None", foreground="gray")
        self.folder_label.pack(side=tk.LEFT, padx=5)

        ttk.Button(toolbar, text="Open Folder", command=self.load_folder).pack(side=tk.LEFT, padx=5)
        ttk.Button(toolbar, text="Save Mask", command=self.save_mask).pack(side=tk.LEFT, padx=5)

        # Main content area
        content = ttk.Frame(self.root)
        content.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Left panel - Original video
        left_panel = ttk.LabelFrame(content, text="Original Video")
        left_panel.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5)

        self.rgb_canvas = tk.Canvas(left_panel, width=640, height=480, bg='black')
        self.rgb_canvas.pack(fill=tk.BOTH, expand=True)

        # Right panel - Mask
        right_panel = ttk.LabelFrame(content, text="Mask (Editable)")
        right_panel.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5)

        self.mask_canvas = tk.Canvas(right_panel, width=640, height=480, bg='black')
        self.mask_canvas.pack(fill=tk.BOTH, expand=True)
        self.mask_canvas.bind("<Button-1>", self.on_mask_click)
        self.mask_canvas.bind("<B1-Motion>", self.on_mask_drag)
        self.mask_canvas.bind("<ButtonRelease-1>", self.on_mask_release)

        # Bottom controls
        controls = ttk.Frame(self.root)
        controls.pack(side=tk.BOTTOM, fill=tk.X, padx=5, pady=5)

        # Frame navigation
        nav_frame = ttk.LabelFrame(controls, text="Frame Navigation")
        nav_frame.pack(side=tk.TOP, fill=tk.X, pady=5)

        ttk.Button(nav_frame, text="<<", command=self.first_frame, width=5).pack(side=tk.LEFT, padx=2)
        ttk.Button(nav_frame, text="<", command=self.prev_frame, width=5).pack(side=tk.LEFT, padx=2)

        self.frame_label = ttk.Label(nav_frame, text="Frame: 0 / 0")
        self.frame_label.pack(side=tk.LEFT, padx=10)

        ttk.Button(nav_frame, text=">", command=self.next_frame, width=5).pack(side=tk.LEFT, padx=2)
        ttk.Button(nav_frame, text=">>", command=self.last_frame, width=5).pack(side=tk.LEFT, padx=2)

        self.frame_slider = ttk.Scale(nav_frame, from_=0, to=100, orient=tk.HORIZONTAL,
                                      command=self.on_slider_change)
        self.frame_slider.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=10)

        # Tool selection
        tool_frame = ttk.LabelFrame(controls, text="Tools")
        tool_frame.pack(side=tk.TOP, fill=tk.X, pady=5)

        self.tool_var = tk.StringVar(value="grid")
        ttk.Radiobutton(tool_frame, text="Grid Toggle", variable=self.tool_var,
                        value="grid", command=self.on_tool_change).pack(side=tk.LEFT, padx=5)
        ttk.Radiobutton(tool_frame, text="Grid Black Toggle", variable=self.tool_var,
                        value="grid_black", command=self.on_tool_change).pack(side=tk.LEFT, padx=5)
        ttk.Radiobutton(tool_frame, text="Brush (Add Black)", variable=self.tool_var,
                        value="brush_add", command=self.on_tool_change).pack(side=tk.LEFT, padx=5)
        ttk.Radiobutton(tool_frame, text="Brush (Erase Black)", variable=self.tool_var,
                        value="brush_erase", command=self.on_tool_change).pack(side=tk.LEFT, padx=5)

        ttk.Label(tool_frame, text="Brush Size:").pack(side=tk.LEFT, padx=10)
        self.brush_slider = ttk.Scale(tool_frame, from_=5, to=100, orient=tk.HORIZONTAL,
                                      command=self.on_brush_size_change)
        self.brush_slider.set(20)
        self.brush_slider.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5)
        self.brush_size_label = ttk.Label(tool_frame, text="20")
        self.brush_size_label.pack(side=tk.LEFT, padx=5)

        # Copy from previous frame
        copy_frame = ttk.LabelFrame(controls, text="Copy from Previous Frame")
        copy_frame.pack(side=tk.TOP, fill=tk.X, pady=5)

        ttk.Button(copy_frame, text="Copy Black Mask",
                   command=self.copy_black_from_previous).pack(side=tk.LEFT, padx=5)
        ttk.Button(copy_frame, text="Copy Grey Mask",
                   command=self.copy_grey_from_previous).pack(side=tk.LEFT, padx=5)

        # Info panel
        info_frame = ttk.Frame(controls)
        info_frame.pack(side=tk.TOP, fill=tk.X, pady=5)

        self.info_label = ttk.Label(info_frame, text="Load a folder to begin", foreground="blue")
        self.info_label.pack(side=tk.LEFT, padx=5)

        self.grid_info_label = ttk.Label(info_frame, text="Grid: N/A")
        self.grid_info_label.pack(side=tk.RIGHT, padx=5)

    def calculate_square_grid(self, width, height, min_grid=8):
        """Calculate grid dimensions to make square cells"""
        aspect_ratio = width / height

        if width >= height:
            grid_rows = min_grid
            grid_cols = max(min_grid, round(min_grid * aspect_ratio))
        else:
            grid_cols = min_grid
            grid_rows = max(min_grid, round(min_grid / aspect_ratio))

        return grid_rows, grid_cols

    def load_folder(self):
        """Load a folder containing rgb_full.mp4/input_video.mp4 and quadmask_0.mp4"""
        folder = filedialog.askdirectory(title="Select Folder")
        if not folder:
            return

        folder_path = Path(folder)

        # Find RGB video
        rgb_path = None
        for name in ["rgb_full.mp4", "input_video.mp4"]:
            candidate = folder_path / name
            if candidate.exists():
                rgb_path = candidate
                break

        mask_path = folder_path / "quadmask_0.mp4"

        if not rgb_path or not mask_path.exists():
            messagebox.showerror("Error", "Folder must contain quadmask_0.mp4 and rgb_full.mp4 or input_video.mp4")
            return

        self.folder_path = folder_path
        self.rgb_path = rgb_path
        self.mask_path = mask_path

        # Load videos
        self.load_videos()

    def load_videos(self):
        """Load RGB and mask videos into memory"""
        self.info_label.config(text="Loading videos...")
        self.root.update()

        # Load RGB frames
        self.rgb_frames = self.read_video_frames(self.rgb_path)

        # Load mask frames
        self.mask_frames = self.read_video_frames(self.mask_path)

        if len(self.rgb_frames) != len(self.mask_frames):
            messagebox.showwarning("Warning",
                                   f"Frame count mismatch: RGB={len(self.rgb_frames)}, Mask={len(self.mask_frames)}")

        if len(self.mask_frames) == 0:
            messagebox.showerror("Error", "No frames loaded")
            return

        # Calculate grid dimensions
        height, width = self.mask_frames[0].shape[:2]
        self.grid_rows, self.grid_cols = self.calculate_square_grid(width, height, self.min_grid)

        # Calculate display scale
        max_width = 600
        max_height = 450
        scale_w = max_width / width
        scale_h = max_height / height
        self.display_scale = min(scale_w, scale_h, 1.0)

        # Update UI
        self.folder_label.config(text=self.folder_path.name, foreground="black")
        self.grid_info_label.config(text=f"Grid: {self.grid_rows}x{self.grid_cols}")
        self.frame_slider.config(to=len(self.mask_frames)-1)
        self.current_frame = 0
        self.undo_stack = []
        self.redo_stack = []

        self.update_display()
        self.info_label.config(text=f"Loaded {len(self.mask_frames)} frames", foreground="green")

    def read_video_frames(self, video_path):
        """Read all frames from a video"""
        cap = cv2.VideoCapture(str(video_path))
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert to grayscale if needed
            if len(frame.shape) == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            frames.append(frame)
        cap.release()
        return frames

    def write_video_frames(self, frames, output_path, fps=12):
        """Write frames to a video file using lossless H.264"""
        if not frames:
            return

        height, width = frames[0].shape[:2]

        # Write temp AVI first
        temp_avi = output_path.with_suffix('.avi')
        fourcc = cv2.VideoWriter_fourcc(*'FFV1')
        out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (width, height), isColor=False)

        for frame in frames:
            if len(frame.shape) == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            out.write(frame)

        out.release()

        # Convert to LOSSLESS H.264 (qp=0)
        cmd = [
            'ffmpeg', '-y', '-i', str(temp_avi),
            '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
            '-pix_fmt', 'yuv444p', '-r', '12',
            str(output_path)
        ]
        subprocess.run(cmd, capture_output=True)
        temp_avi.unlink()

    def update_display(self, fast_mode=False):
        """Update both canvas displays (or just mask in fast mode)"""
        if not self.mask_frames:
            return

        # Cache RGB frame if needed (only in full mode)
        if not fast_mode and self.cached_frame_idx != self.current_frame:
            if self.current_frame < len(self.rgb_frames):
                rgb_frame = self.rgb_frames[self.current_frame]
                self.cached_rgb_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_GRAY2RGB) if len(rgb_frame.shape) == 2 else rgb_frame.copy()
                self.cached_frame_idx = self.current_frame
            else:
                self.cached_rgb_frame = None

        if not fast_mode:
            # Update frame label
            self.frame_label.config(text=f"Frame: {self.current_frame + 1} / {len(self.mask_frames)}")
            self.frame_slider.set(self.current_frame)

            # Display RGB frame
            if self.cached_rgb_frame is not None:
                rgb_display = cv2.resize(self.cached_rgb_frame, None, fx=self.display_scale, fy=self.display_scale)
                rgb_image = Image.fromarray(rgb_display)
                self.rgb_photo = ImageTk.PhotoImage(rgb_image)
                self.rgb_canvas.delete("all")
                self.rgb_canvas.create_image(0, 0, anchor=tk.NW, image=self.rgb_photo)

        # Display mask frame with grid overlay
        mask_frame = self.mask_frames[self.current_frame]
        # Use simple mode (no RGB blending) during fast updates for speed
        mask_display = self.create_mask_visualization(mask_frame, self.cached_rgb_frame, simple_mode=fast_mode)
        mask_display = cv2.resize(mask_display, None, fx=self.display_scale, fy=self.display_scale)
        mask_image = Image.fromarray(mask_display)
        self.mask_photo = ImageTk.PhotoImage(mask_image)
        self.mask_canvas.delete("all")
        self.mask_canvas.create_image(0, 0, anchor=tk.NW, image=self.mask_photo)

        # Draw grid overlay (skip in fast mode for performance)
        if not fast_mode:
            self.draw_grid_overlay()

    def create_mask_visualization(self, mask_frame, rgb_frame=None, simple_mode=False):
        """Create RGB visualization of mask with color coding and RGB background"""
        height, width = mask_frame.shape
        vis = np.zeros((height, width, 3), dtype=np.uint8)

        if simple_mode:
            # Fast simple mode - no blending, just solid colors
            vis[mask_frame == 255] = [150, 150, 150]  # Background - gray
            vis[mask_frame == 127] = [0, 200, 0]      # Gridified - green
            vis[mask_frame == 63] = [200, 200, 0]     # Overlap - yellow
            vis[mask_frame == 0] = [200, 0, 0]        # Black - red
            return vis

        # If RGB frame is provided, use it as background
        if rgb_frame is not None:
            if len(rgb_frame.shape) == 2:
                # Convert grayscale to RGB
                rgb_background = cv2.cvtColor(rgb_frame, cv2.COLOR_GRAY2RGB)
            else:
                rgb_background = rgb_frame.copy()

            # Use it at 50% opacity as base
            vis = (rgb_background * 0.5).astype(np.uint8)

        # Color coding with transparency to show background:
        # 0 (black) -> Red tint (to indicate removal area)
        # 63 (overlap) -> Yellow tint
        # 127 (gridified) -> Green tint
        # 255 (background) -> Keep RGB background visible

        # Background areas - show RGB at 60% brightness
        bg_mask = mask_frame == 255
        if rgb_frame is not None:
            vis[bg_mask] = (rgb_background[bg_mask] * 0.6).astype(np.uint8)
        else:
            vis[bg_mask] = [150, 150, 150]

        # Green overlay for gridified areas - blend 40% background + 60% green tint
        green_mask = mask_frame == 127
        if rgb_frame is not None:
            vis[green_mask] = np.clip(rgb_background[green_mask] * 0.4 + np.array([0, 180, 0]) * 0.6, 0, 255).astype(np.uint8)
        else:
            vis[green_mask] = [0, 200, 0]

        # Yellow overlay for overlap areas - blend 40% background + 60% yellow tint
        yellow_mask = mask_frame == 63
        if rgb_frame is not None:
            vis[yellow_mask] = np.clip(rgb_background[yellow_mask] * 0.4 + np.array([180, 180, 0]) * 0.6, 0, 255).astype(np.uint8)
        else:
            vis[yellow_mask] = [200, 200, 0]

        # Red tint for black areas (removal) - blend 30% background + 70% red tint
        black_mask = mask_frame == 0
        if rgb_frame is not None:
            vis[black_mask] = np.clip(rgb_background[black_mask] * 0.3 + np.array([200, 0, 0]) * 0.7, 0, 255).astype(np.uint8)
        else:
            vis[black_mask] = [200, 0, 0]

        return vis

    def draw_grid_overlay(self):
        """Draw grid lines on mask canvas"""
        if not self.mask_frames:
            return

        height, width = self.mask_frames[0].shape

        scaled_width = int(width * self.display_scale)
        scaled_height = int(height * self.display_scale)

        cell_width = scaled_width / self.grid_cols
        cell_height = scaled_height / self.grid_rows

        # Draw vertical lines
        for col in range(self.grid_cols + 1):
            x = int(col * cell_width)
            self.mask_canvas.create_line(x, 0, x, scaled_height, fill='red', width=1, tags='grid')

        # Draw horizontal lines
        for row in range(self.grid_rows + 1):
            y = int(row * cell_height)
            self.mask_canvas.create_line(0, y, scaled_width, y, fill='red', width=1, tags='grid')

    def get_grid_from_pos(self, x, y):
        """Get grid row, col from canvas position"""
        if not self.mask_frames:
            return None, None

        height, width = self.mask_frames[0].shape

        # Convert to frame coordinates
        frame_x = int(x / self.display_scale)
        frame_y = int(y / self.display_scale)

        if frame_x < 0 or frame_x >= width or frame_y < 0 or frame_y >= height:
            return None, None

        cell_width = width / self.grid_cols
        cell_height = height / self.grid_rows

        col = int(frame_x / cell_width)
        row = int(frame_y / cell_height)

        return row, col

    def toggle_grid(self, row, col):
        """Toggle a grid cell between 127 and 255, handling 63 overlaps"""
        if row is None or col is None:
            return

        if row < 0 or row >= self.grid_rows or col < 0 or col >= self.grid_cols:
            return

        # Save state for undo
        self.save_state()

        mask = self.mask_frames[self.current_frame]
        height, width = mask.shape

        cell_width = width / self.grid_cols
        cell_height = height / self.grid_rows

        y1 = int(row * cell_height)
        y2 = int((row + 1) * cell_height)
        x1 = int(col * cell_width)
        x2 = int((col + 1) * cell_width)

        grid_region = mask[y1:y2, x1:x2]

        # Check if grid has any 127 or 63 values
        has_active = np.any((grid_region == 127) | (grid_region == 63))

        if has_active:
            # Turn OFF: 127->255, 63->0, keep 0 and 255 as is
            mask[y1:y2, x1:x2] = np.where(grid_region == 127, 255,
                                          np.where(grid_region == 63, 0, grid_region))
        else:
            # Turn ON: 255->127, 0->63, keep others as is
            mask[y1:y2, x1:x2] = np.where(grid_region == 255, 127,
                                          np.where(grid_region == 0, 63, grid_region))

        self.update_display()

    def toggle_grid_black(self, row, col):
        """Toggle black mask in a grid cell"""
        if row is None or col is None:
            return

        if row < 0 or row >= self.grid_rows or col < 0 or col >= self.grid_cols:
            return

        # Save state for undo
        self.save_state()

        mask = self.mask_frames[self.current_frame]
        height, width = mask.shape

        cell_width = width / self.grid_cols
        cell_height = height / self.grid_rows

        y1 = int(row * cell_height)
        y2 = int((row + 1) * cell_height)
        x1 = int(col * cell_width)
        x2 = int((col + 1) * cell_width)

        grid_region = mask[y1:y2, x1:x2]

        # Check if grid has any black (0 or 63 values)
        has_black = np.any((grid_region == 0) | (grid_region == 63))

        if has_black:
            # Turn OFF black: 0->255, 63->127, keep 127 and 255 as is
            mask[y1:y2, x1:x2] = np.where(grid_region == 0, 255,
                                          np.where(grid_region == 63, 127, grid_region))
        else:
            # Turn ON black: 255->0, 127->63, keep others as is
            mask[y1:y2, x1:x2] = np.where(grid_region == 255, 0,
                                          np.where(grid_region == 127, 63, grid_region))

        self.update_display()

    def apply_brush(self, x, y, mode="add"):
        """Apply brush to add/erase black mask (vectorized for speed)"""
        if not self.mask_frames:
            return

        mask = self.mask_frames[self.current_frame]
        height, width = mask.shape

        # Convert to frame coordinates
        frame_x = int(x / self.display_scale)
        frame_y = int(y / self.display_scale)

        if frame_x < 0 or frame_x >= width or frame_y < 0 or frame_y >= height:
            return

        # Create circular brush using vectorized operations
        radius = int(self.brush_size / 2)

        y1 = max(0, frame_y - radius)
        y2 = min(height, frame_y + radius + 1)
        x1 = max(0, frame_x - radius)
        x2 = min(width, frame_x + radius + 1)

        # Get the region
        region = mask[y1:y2, x1:x2]

        # Create coordinate grids for distance calculation
        yy, xx = np.ogrid[y1:y2, x1:x2]
        dist = np.sqrt((xx - frame_x)**2 + (yy - frame_y)**2)
        brush_mask = dist <= radius

        if mode == "add":
            # Add black: 255->0, 127->63
            region[brush_mask & (region == 255)] = 0
            region[brush_mask & (region == 127)] = 63
        else:  # erase
            # Erase black: 0->255, 63->127
            region[brush_mask & (region == 0)] = 255
            region[brush_mask & (region == 63)] = 127

    def on_mask_click(self, event):
        """Handle click on mask canvas"""
        if not self.mask_frames:
            return

        tool = self.tool_var.get()

        if tool == "grid":
            row, col = self.get_grid_from_pos(event.x, event.y)
            self.toggle_grid(row, col)
        elif tool == "grid_black":
            row, col = self.get_grid_from_pos(event.x, event.y)
            self.toggle_grid_black(row, col)
        elif tool in ["brush_add", "brush_erase"]:
            self.save_state()
            mode = "add" if tool == "brush_add" else "erase"
            self.apply_brush(event.x, event.y, mode)
            self.dragging = True
            self.last_brush_pos = (event.x, event.y)
            self.last_update_time = time.time()
            self.update_display(fast_mode=True)
            # Start continuous brush application
            self.schedule_brush_repeat()

    def on_mask_drag(self, event):
        """Handle drag on mask canvas with throttled updates"""
        if not self.dragging:
            return

        tool = self.tool_var.get()
        if tool in ["brush_add", "brush_erase"]:
            # Update brush position when moving
            self.last_brush_pos = (event.x, event.y)
            mode = "add" if tool == "brush_add" else "erase"
            self.apply_brush(event.x, event.y, mode)

            # Only update display if enough time has passed (fast mode - no grid)
            current_time = time.time()
            if current_time - self.last_update_time >= self.update_interval:
                self.update_display(fast_mode=True)
                self.last_update_time = current_time

    def on_mask_release(self, event):
        """Handle release on mask canvas"""
        self.dragging = False
        self.last_brush_pos = None
        # Cancel continuous brush application
        if self.brush_repeat_id:
            self.root.after_cancel(self.brush_repeat_id)
            self.brush_repeat_id = None
        # Final full update when releasing to show the complete result with blending
        self.update_display(fast_mode=False)

    def schedule_brush_repeat(self):
        """Schedule continuous brush application while mouse is held down"""
        if self.dragging and self.last_brush_pos:
            tool = self.tool_var.get()
            if tool in ["brush_add", "brush_erase"]:
                mode = "add" if tool == "brush_add" else "erase"
                x, y = self.last_brush_pos
                self.apply_brush(x, y, mode)

                # Update display if enough time has passed
                current_time = time.time()
                if current_time - self.last_update_time >= self.update_interval:
                    self.update_display(fast_mode=True)
                    self.last_update_time = current_time

            # Schedule next application (every 30ms for smooth continuous painting)
            self.brush_repeat_id = self.root.after(30, self.schedule_brush_repeat)

    def copy_black_from_previous(self):
        """Copy ONLY black component from previous frame, preserving grey in current frame"""
        if not self.mask_frames:
            messagebox.showwarning("Warning", "No mask loaded")
            return

        if self.current_frame == 0:
            messagebox.showwarning("Warning", "Cannot copy from previous frame - already at first frame")
            return

        # Save state for undo
        self.save_state()

        prev_mask = self.mask_frames[self.current_frame - 1]
        curr_mask = self.mask_frames[self.current_frame]

        # Copy ONLY the black component from previous frame
        # Where prev has black (0 or 63): add black to curr
        # Where prev doesn't have black (127 or 255): remove black from curr

        has_black_in_prev = (prev_mask == 0) | (prev_mask == 63)
        no_black_in_prev = (prev_mask == 127) | (prev_mask == 255)

        # Remove black where prev doesn't have it (preserve grey)
        curr_mask[no_black_in_prev & (curr_mask == 0)] = 255   # 0 -> 255
        curr_mask[no_black_in_prev & (curr_mask == 63)] = 127  # 63 -> 127 (keep grey)

        # Add black where prev has it (preserve grey)
        curr_mask[has_black_in_prev & (curr_mask == 255)] = 0  # 255 -> 0
        curr_mask[has_black_in_prev & (curr_mask == 127)] = 63 # 127 -> 63 (keep grey, add black)

        self.update_display()
        self.info_label.config(text="Copied black mask from previous frame", foreground="green")

    def copy_grey_from_previous(self):
        """Copy ONLY grey component from previous frame, preserving black in current frame"""
        if not self.mask_frames:
            messagebox.showwarning("Warning", "No mask loaded")
            return

        if self.current_frame == 0:
            messagebox.showwarning("Warning", "Cannot copy from previous frame - already at first frame")
            return

        # Save state for undo
        self.save_state()
679
+
680
+ prev_mask = self.mask_frames[self.current_frame - 1]
681
+ curr_mask = self.mask_frames[self.current_frame]
682
+
683
+ # Copy ONLY the grey component from previous frame
684
+ # Where prev has grey (127 or 63): add grey to curr
685
+ # Where prev doesn't have grey (0 or 255): remove grey from curr
686
+
687
+ has_grey_in_prev = (prev_mask == 127) | (prev_mask == 63)
688
+ no_grey_in_prev = (prev_mask == 0) | (prev_mask == 255)
689
+
690
+ # Remove grey where prev doesn't have it (preserve black)
691
+ curr_mask[no_grey_in_prev & (curr_mask == 127)] = 255 # 127 β†’ 255
692
+ curr_mask[no_grey_in_prev & (curr_mask == 63)] = 0 # 63 β†’ 0 (keep black)
693
+
694
+ # Add grey where prev has it (preserve black)
695
+ curr_mask[has_grey_in_prev & (curr_mask == 255)] = 127 # 255 β†’ 127
696
+ curr_mask[has_grey_in_prev & (curr_mask == 0)] = 63 # 0 β†’ 63 (keep black, add grey)
697
+
698
+ self.update_display()
699
+ self.info_label.config(text="Copied grey mask from previous frame", foreground="green")
700
+
701
+ def save_state(self):
702
+ """Save current state for undo"""
703
+ if not self.mask_frames:
704
+ return
705
+
706
+ # Save deep copy of current frame
707
+ state = {
708
+ 'frame': self.current_frame,
709
+ 'mask': self.mask_frames[self.current_frame].copy()
710
+ }
711
+ self.undo_stack.append(state)
712
+ self.redo_stack.clear()
713
+
714
+ # Limit undo stack size
715
+ if len(self.undo_stack) > 50:
716
+ self.undo_stack.pop(0)
717
+
718
+ def undo(self):
719
+ """Undo last edit"""
720
+ if not self.undo_stack:
721
+ return
722
+
723
+ # Save current state to redo
724
+ redo_state = {
725
+ 'frame': self.current_frame,
726
+ 'mask': self.mask_frames[self.current_frame].copy()
727
+ }
728
+ self.redo_stack.append(redo_state)
729
+
730
+ # Restore previous state
731
+ state = self.undo_stack.pop()
732
+ self.current_frame = state['frame']
733
+ self.mask_frames[self.current_frame] = state['mask']
734
+
735
+ self.update_display()
736
+
737
+ def redo(self):
738
+ """Redo last undone edit"""
739
+ if not self.redo_stack:
740
+ return
741
+
742
+ # Save current state to undo
743
+ undo_state = {
744
+ 'frame': self.current_frame,
745
+ 'mask': self.mask_frames[self.current_frame].copy()
746
+ }
747
+ self.undo_stack.append(undo_state)
748
+
749
+ # Restore redo state
750
+ state = self.redo_stack.pop()
751
+ self.current_frame = state['frame']
752
+ self.mask_frames[self.current_frame] = state['mask']
753
+
754
+ self.update_display()
755
+
756
+ def save_mask(self):
757
+ """Save edited mask back to quadmask_0.mp4"""
758
+ if not self.mask_frames or not self.mask_path:
759
+ messagebox.showwarning("Warning", "No mask loaded")
760
+ return
761
+
762
+ # Confirm save
763
+ result = messagebox.askyesno("Confirm Save",
764
+ f"Save mask to {self.mask_path.name}?\nThis will overwrite the existing file.")
765
+ if not result:
766
+ return
767
+
768
+ self.info_label.config(text="Saving mask...", foreground="blue")
769
+ self.root.update()
770
+
771
+ # Write video
772
+ self.write_video_frames(self.mask_frames, self.mask_path)
773
+
774
+ self.info_label.config(text="Mask saved successfully!", foreground="green")
775
+ messagebox.showinfo("Success", f"Mask saved to {self.mask_path.name}!")
776
+
777
+ def first_frame(self):
778
+ """Go to first frame"""
779
+ self.current_frame = 0
780
+ self.update_display()
781
+
782
+ def last_frame(self):
783
+ """Go to last frame"""
784
+ if self.mask_frames:
785
+ self.current_frame = len(self.mask_frames) - 1
786
+ self.update_display()
787
+
788
+ def prev_frame(self):
789
+ """Go to previous frame"""
790
+ if self.current_frame > 0:
791
+ self.current_frame -= 1
792
+ self.update_display()
793
+
794
+ def next_frame(self):
795
+ """Go to next frame"""
796
+ if self.mask_frames and self.current_frame < len(self.mask_frames) - 1:
797
+ self.current_frame += 1
798
+ self.update_display()
799
+
800
+ def on_slider_change(self, value):
801
+ """Handle slider change"""
802
+ if not self.mask_frames:
803
+ return
804
+
805
+ new_frame = int(float(value))
806
+ if new_frame != self.current_frame:
807
+ self.current_frame = new_frame
808
+ self.update_display()
809
+
810
+ def on_tool_change(self):
811
+ """Handle tool selection change"""
812
+ tool = self.tool_var.get()
813
+ if tool == "grid":
814
+ self.info_label.config(text="Grid Toggle: Click grids to toggle 127↔255", foreground="blue")
815
+ elif tool == "grid_black":
816
+ self.info_label.config(text="Grid Black Toggle: Click grids to toggle black mask (0/63)", foreground="blue")
817
+ elif tool == "brush_add":
818
+ self.info_label.config(text="Brush (Add): Paint black mask areas", foreground="blue")
819
+ else: # brush_erase
820
+ self.info_label.config(text="Brush (Erase): Erase black mask areas", foreground="blue")
821
+
822
+ def on_brush_size_change(self, value):
823
+ """Handle brush size change"""
824
+ self.brush_size = int(float(value))
825
+ self.brush_size_label.config(text=str(self.brush_size))
826
+
827
+ if __name__ == "__main__":
828
+ root = tk.Tk()
829
+ root.geometry("1400x800")
830
+ app = MaskEditorGUI(root)
831
+ root.mainloop()
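
The `copy_black_from_previous` / `copy_grey_from_previous` editors above treat each quadmask pixel as two independent boolean layers (black and grey) packed into four grey values: 255 = neither, 127 = grey only, 0 = black only, 63 = both. A minimal standalone sketch of that encoding, with illustrative helper names (`decompose`, `compose`, `copy_black` are not part of the editor):

```python
import numpy as np

# Quadmask pixel encoding: 255 = background, 127 = grey layer only,
# 0 = black layer only, 63 = black + grey.

def decompose(mask: np.ndarray):
    """Split a quadmask frame into (black, grey) boolean layers."""
    black = (mask == 0) | (mask == 63)
    grey = (mask == 127) | (mask == 63)
    return black, grey

def compose(black: np.ndarray, grey: np.ndarray) -> np.ndarray:
    """Recombine the boolean layers into the four quadmask values."""
    out = np.full(black.shape, 255, dtype=np.uint8)
    out[grey & ~black] = 127
    out[black & ~grey] = 0
    out[black & grey] = 63
    return out

def copy_black(prev: np.ndarray, curr: np.ndarray) -> np.ndarray:
    """Layer-level equivalent of copy_black_from_previous:
    take the black layer from prev, keep curr's grey layer."""
    prev_black, _ = decompose(prev)
    _, curr_grey = decompose(curr)
    return compose(prev_black, curr_grey)
```

Working in decomposed layers gives the same result as the in-place value remapping in the editor, and makes the "preserve grey while replacing black" invariant explicit.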
VLM-MASK-REASONER/point_selector_gui.py ADDED
@@ -0,0 +1,601 @@
+ #!/usr/bin/env python3
+ """
+ Point Selector GUI - Multi-Frame Support
+
+ NEW: Support adding points across multiple frames for complex cases
+ Example: Car appears at frame 0, hand carrying it appears at frame 30
+     → Add points on car at frame 0, points on hand at frame 30
+     → Both get segmented together as "primary object to remove"
+
+ Usage:
+     python point_selector_gui.py --config pexel_test_config.json
+ """
+
+ import cv2
+ import numpy as np
+ import tkinter as tk
+ from tkinter import ttk, filedialog, messagebox
+ from PIL import Image, ImageTk
+ import json
+ import argparse
+ from pathlib import Path
+ from typing import List, Dict, Tuple
+
+
+ class PointSelectorGUI:
+     def __init__(self, root, config_path=None):
+         self.root = root
+         self.root.title("Point Selector - Multi-Frame Support")
+
+         # Data
+         self.config_path = config_path
+         self.config_data = None
+         self.current_video_idx = 0
+         self.current_frame_idx = 0
+         self.video_captures = []
+         self.total_frames_list = []
+
+         # NEW: Points organized by frame
+         self.points_by_frame = {}  # {frame_idx: [(x, y), ...]}
+         self.all_points_by_frame = []  # List of dicts for all videos
+
+         # Display
+         self.display_scale = 1.0
+         self.photo = None
+         self.point_radius = 8
+
+         self.setup_ui()
+
+         if config_path:
+             self.load_config_direct(config_path)
+
+     def setup_ui(self):
+         """Setup the GUI layout"""
+         # Menu bar
+         menubar = tk.Menu(self.root)
+         self.root.config(menu=menubar)
+
+         file_menu = tk.Menu(menubar, tearoff=0)
+         menubar.add_cascade(label="File", menu=file_menu)
+         file_menu.add_command(label="Load Config", command=self.load_config)
+         file_menu.add_command(label="Save Points", command=self.save_points)
+         file_menu.add_separator()
+         file_menu.add_command(label="Exit", command=self.root.quit)
+
+         # Top toolbar
+         toolbar = ttk.Frame(self.root)
+         toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
+
+         ttk.Label(toolbar, text="Config:").pack(side=tk.LEFT)
+         self.config_label = ttk.Label(toolbar, text="None", foreground="gray")
+         self.config_label.pack(side=tk.LEFT, padx=5)
+
+         ttk.Button(toolbar, text="Load Config", command=self.load_config).pack(side=tk.LEFT, padx=5)
+         ttk.Button(toolbar, text="Save All Points", command=self.save_points).pack(side=tk.LEFT, padx=5)
+
+         # Video info
+         info_frame = ttk.Frame(self.root)
+         info_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
+
+         self.video_label = ttk.Label(info_frame, text="Video: None", font=("Arial", 10, "bold"))
+         self.video_label.pack(side=tk.LEFT, padx=5)
+
+         self.instruction_label = ttk.Label(info_frame, text="", foreground="blue")
+         self.instruction_label.pack(side=tk.LEFT, padx=10)
+
+         # Frame navigation controls - COMPACT (Ctrl+←/→ shortcuts)
+         frame_nav = ttk.LabelFrame(self.root, text="Frame Navigation")
+         frame_nav.pack(side=tk.TOP, fill=tk.X, padx=5, pady=2)
+
+         btn_frame = ttk.Frame(frame_nav)
+         btn_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=2)
+
+         # Compact buttons
+         ttk.Button(btn_frame, text="<<", command=self.first_frame, width=3).pack(side=tk.LEFT, padx=1)
+         ttk.Button(btn_frame, text="<10", command=lambda: self.prev_frame(10), width=3).pack(side=tk.LEFT, padx=1)
+         ttk.Button(btn_frame, text="<", command=lambda: self.prev_frame(1), width=3).pack(side=tk.LEFT, padx=1)
+
+         self.frame_label = ttk.Label(btn_frame, text="F: 0/0", font=("Arial", 9))
+         self.frame_label.pack(side=tk.LEFT, padx=8)
+
+         ttk.Button(btn_frame, text=">", command=lambda: self.next_frame(1), width=3).pack(side=tk.LEFT, padx=1)
+         ttk.Button(btn_frame, text="10>", command=lambda: self.next_frame(10), width=3).pack(side=tk.LEFT, padx=1)
+         ttk.Button(btn_frame, text=">>", command=self.last_frame, width=3).pack(side=tk.LEFT, padx=1)
+
+         # Slider inline
+         self.frame_slider = ttk.Scale(btn_frame, from_=0, to=100, orient=tk.HORIZONTAL, command=self.on_slider_change, length=250)
+         self.frame_slider.pack(side=tk.LEFT, padx=5)
+
+         # Frames with points inline
+         ttk.Label(btn_frame, text="Points:", font=("Arial", 8)).pack(side=tk.LEFT, padx=3)
+         self.frames_with_points_label = ttk.Label(btn_frame, text="None", foreground="blue", font=("Arial", 8))
+         self.frames_with_points_label.pack(side=tk.LEFT)
+
+         # Main canvas - SMALLER to fit everything
+         canvas_frame = ttk.LabelFrame(self.root, text="Click to add points")
+         canvas_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5, pady=2)
+
+         self.canvas = tk.Canvas(canvas_frame, width=800, height=450, bg='black', cursor="crosshair")
+         self.canvas.pack(fill=tk.BOTH, expand=True)
+         self.canvas.bind("<Button-1>", self.on_canvas_click)
+
+         # Bottom controls - COMPACT
+         controls = ttk.Frame(self.root)
+         controls.pack(side=tk.BOTTOM, fill=tk.X, padx=5, pady=2)
+
+         # Point info - compact
+         point_info = ttk.Frame(controls)
+         point_info.pack(side=tk.TOP, fill=tk.X, pady=2)
+
+         self.point_count_label = ttk.Label(point_info, text="Pts: 0", font=("Arial", 9))
+         self.point_count_label.pack(side=tk.LEFT, padx=5)
+
+         ttk.Button(point_info, text="Clear Frame", command=self.clear_current_frame, width=10).pack(side=tk.LEFT, padx=2)
+         ttk.Button(point_info, text="Clear ALL", command=self.clear_all_frames, width=9).pack(side=tk.LEFT, padx=2)
+         ttk.Button(point_info, text="Undo", command=self.undo_last_point, width=6).pack(side=tk.LEFT, padx=2)
+
+         # Video navigation - compact
+         nav_frame = ttk.Frame(controls)
+         nav_frame.pack(side=tk.TOP, fill=tk.X, pady=2)
+
+         ttk.Button(nav_frame, text="<< First", command=self.first_video, width=8).pack(side=tk.LEFT, padx=2)
+         ttk.Button(nav_frame, text="< Prev", command=self.prev_video, width=8).pack(side=tk.LEFT, padx=2)
+
+         self.nav_label = ttk.Label(nav_frame, text="Video: 0/0", font=("Arial", 10, "bold"))
+         self.nav_label.pack(side=tk.LEFT, padx=15)
+
+         ttk.Button(nav_frame, text="Save & Next >", command=self.save_and_next, width=12).pack(side=tk.LEFT, padx=2)
+         ttk.Button(nav_frame, text="Last >>", command=self.last_video, width=8).pack(side=tk.LEFT, padx=2)
+
+         # Status - compact
+         self.status_label = ttk.Label(controls, text="Load config", foreground="blue", font=("Arial", 8))
+         self.status_label.pack(side=tk.TOP, pady=2)
+
+         # Keyboard shortcuts
+         self.root.bind("<space>", lambda e: self.save_and_next())
+         self.root.bind("<Left>", lambda e: self.prev_video())
+         self.root.bind("<Right>", lambda e: self.save_and_next())
+         self.root.bind("<Control-z>", lambda e: self.undo_last_point())
+         self.root.bind("<Control-Left>", lambda e: self.prev_frame(1))
+         self.root.bind("<Control-Right>", lambda e: self.next_frame(1))
+         self.root.bind("<Control-Shift-Left>", lambda e: self.prev_frame(10))
+         self.root.bind("<Control-Shift-Right>", lambda e: self.next_frame(10))
+
+     def load_config_direct(self, config_path):
+         """Load config from path (for command line usage)"""
+         self.config_path = Path(config_path)
+
+         try:
+             with open(self.config_path, 'r') as f:
+                 self.config_data = json.load(f)
+         except Exception as e:
+             messagebox.showerror("Error", f"Failed to load config: {e}")
+             return
+
+         self.process_config()
+
+     def load_config(self):
+         """Load JSON config file via dialog"""
+         filepath = filedialog.askopenfilename(
+             title="Select Config JSON",
+             filetypes=[("JSON files", "*.json"), ("All files", "*.*")]
+         )
+
+         if not filepath:
+             return
+
+         self.config_path = Path(filepath)
+
+         try:
+             with open(self.config_path, 'r') as f:
+                 self.config_data = json.load(f)
+         except Exception as e:
+             messagebox.showerror("Error", f"Failed to load config: {e}")
+             return
+
+         self.process_config()
+
+     def process_config(self):
+         """Process loaded config"""
+         # Validate config
+         if isinstance(self.config_data, list):
+             videos = self.config_data
+         elif isinstance(self.config_data, dict) and "videos" in self.config_data:
+             videos = self.config_data["videos"]
+         else:
+             messagebox.showerror("Error", "Config must be a list or have 'videos' key")
+             return
+
+         if not isinstance(videos, list) or len(videos) == 0:
+             messagebox.showerror("Error", "No videos in config")
+             return
+
+         self.videos = videos
+
+         # Open video captures
+         self.status_label.config(text="Opening video files...", foreground="blue")
+         self.root.update()
+
+         self.open_videos()
+
+         # Initialize storage - now dict per video
+         self.all_points_by_frame = [{} for _ in range(len(self.videos))]
+
+         # Load existing points if available
+         self.load_existing_points()
+
+         # Update UI
+         self.config_label.config(text=self.config_path.name, foreground="black")
+         self.current_video_idx = 0
+         self.current_frame_idx = 0
+         self.display_current_video()
+
+         self.status_label.config(
+             text=f"Loaded {len(self.videos)} videos. Navigate frames and click points. Can add points on multiple frames!",
+             foreground="green"
+         )
+
+     def open_videos(self):
+         """Open all videos for frame navigation"""
+         self.video_captures = []
+         self.total_frames_list = []
+
+         for i, video_info in enumerate(self.videos):
+             video_path = video_info.get("video_path", "")
+
+             if not video_path:
+                 self.video_captures.append(None)
+                 self.total_frames_list.append(0)
+                 continue
+
+             video_path = Path(video_path)
+             if not video_path.is_absolute():
+                 video_path = self.config_path.parent / video_path
+
+             if not video_path.exists():
+                 messagebox.showwarning("Warning", f"Video not found: {video_path}")
+                 self.video_captures.append(None)
+                 self.total_frames_list.append(0)
+                 continue
+
+             cap = cv2.VideoCapture(str(video_path))
+             if not cap.isOpened():
+                 messagebox.showwarning("Warning", f"Failed to open video: {video_path}")
+                 self.video_captures.append(None)
+                 self.total_frames_list.append(0)
+                 continue
+
+             total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+             self.video_captures.append(cap)
+             self.total_frames_list.append(total_frames)
+
+             self.status_label.config(text=f"Opened video {i+1}/{len(self.videos)}", foreground="blue")
+             self.root.update()
+
+     def load_existing_points(self):
+         """Load existing points from output file if it exists"""
+         output_path = self.config_path.parent / f"{self.config_path.stem}_points.json"
+
+         if not output_path.exists():
+             return
+
+         try:
+             with open(output_path, 'r') as f:
+                 existing_data = json.load(f)
+
+             if isinstance(existing_data, list):
+                 existing_videos = existing_data
+             elif isinstance(existing_data, dict) and "videos" in existing_data:
+                 existing_videos = existing_data["videos"]
+             else:
+                 return
+
+             for i, video_data in enumerate(existing_videos):
+                 if i < len(self.all_points_by_frame):
+                     # Load multi-frame format
+                     points_by_frame = video_data.get("primary_points_by_frame", {})
+                     # Convert string keys to int
+                     self.all_points_by_frame[i] = {int(k): v for k, v in points_by_frame.items()}
+
+             self.status_label.config(text="Loaded existing points", foreground="green")
+         except Exception as e:
+             print(f"Warning: Could not load existing points: {e}")
+
+     def get_current_frame(self):
+         """Get frame at current_frame_idx from current video"""
+         if self.current_video_idx >= len(self.video_captures):
+             return None
+
+         cap = self.video_captures[self.current_video_idx]
+         if cap is None:
+             return None
+
+         cap.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame_idx)
+         ret, frame = cap.read()
+
+         if not ret:
+             return None
+
+         return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+     def display_current_video(self):
+         """Display current video frame"""
+         if not self.video_captures:
+             return
+
+         video_info = self.videos[self.current_video_idx]
+         video_path = video_info.get("video_path", "")
+
+         # Update labels
+         self.video_label.config(text=f"Video: {Path(video_path).name}")
+         instruction = video_info.get("instruction", "")
+         if instruction:
+             self.instruction_label.config(text=f"Instruction: {instruction}")
+
+         self.nav_label.config(text=f"Video: {self.current_video_idx + 1}/{len(self.videos)}")
+
+         # Load points for this video
+         self.points_by_frame = self.all_points_by_frame[self.current_video_idx].copy()
+
+         # Update frame controls
+         total_frames = self.total_frames_list[self.current_video_idx]
+         self.frame_slider.config(to=max(1, total_frames - 1))
+         self.frame_slider.set(self.current_frame_idx)
+         self.frame_label.config(text=f"F: {self.current_frame_idx}/{total_frames - 1}")
+
+         # Update frames with points display
+         self.update_frames_display()
+
+         self.display_frame()
+
+     def update_frames_display(self):
+         """Update display showing which frames have points"""
+         if not self.points_by_frame:
+             self.frames_with_points_label.config(text="None", foreground="gray")
+         else:
+             frames = sorted(self.points_by_frame.keys())
+             frames_str = ", ".join(f"F{f}" for f in frames)
+             total_points = sum(len(pts) for pts in self.points_by_frame.values())
+             self.frames_with_points_label.config(
+                 text=f"{frames_str} ({total_points} total points)",
+                 foreground="green"
+             )
+
+     def display_frame(self):
+         """Display current frame with points"""
+         frame = self.get_current_frame()
+         if frame is None:
+             return
+
+         # Draw points for CURRENT frame
+         vis = frame.copy()
+         current_points = self.points_by_frame.get(self.current_frame_idx, [])
+
+         for i, (x, y) in enumerate(current_points):
+             cv2.circle(vis, (x, y), self.point_radius, (255, 0, 0), -1)
+             cv2.circle(vis, (x, y), self.point_radius + 2, (255, 255, 255), 2)
+             cv2.putText(vis, str(i + 1), (x + 12, y + 12), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
+
+         # Show indicator if other frames have points
+         if len(self.points_by_frame) > 0:
+             other_frames = [f for f in self.points_by_frame.keys() if f != self.current_frame_idx]
+             if other_frames:
+                 text = f"Other frames with points: {', '.join(map(str, sorted(other_frames)))}"
+                 cv2.putText(vis, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
+
+         # Scale for display
+         h, w = vis.shape[:2]
+         max_width, max_height = 800, 450
+         scale_w = max_width / w
+         scale_h = max_height / h
+         self.display_scale = min(scale_w, scale_h, 1.0)
+
+         new_w = int(w * self.display_scale)
+         new_h = int(h * self.display_scale)
+         vis_resized = cv2.resize(vis, (new_w, new_h))
+
+         # Convert to PIL and display
+         pil_img = Image.fromarray(vis_resized)
+         self.photo = ImageTk.PhotoImage(pil_img)
+         self.canvas.delete("all")
+         self.canvas.create_image(0, 0, anchor=tk.NW, image=self.photo)
+
+         self.point_count_label.config(text=f"Pts on F{self.current_frame_idx}: {len(current_points)}")
+
+     def on_canvas_click(self, event):
+         """Handle click on canvas - add point to CURRENT frame"""
+         # Convert to frame coordinates
+         x = int(event.x / self.display_scale)
+         y = int(event.y / self.display_scale)
+
+         # Add to current frame
+         if self.current_frame_idx not in self.points_by_frame:
+             self.points_by_frame[self.current_frame_idx] = []
+
+         self.points_by_frame[self.current_frame_idx].append((x, y))
+         self.update_frames_display()
+         self.display_frame()
+
+     def clear_current_frame(self):
+         """Clear points for current frame only"""
+         if self.current_frame_idx in self.points_by_frame:
+             del self.points_by_frame[self.current_frame_idx]
+             self.update_frames_display()
+             self.display_frame()
+
+     def clear_all_frames(self):
+         """Clear all points for current video"""
+         result = messagebox.askyesno("Clear All", "Clear points from ALL frames?")
+         if result:
+             self.points_by_frame = {}
+             self.update_frames_display()
+             self.display_frame()
+
+     def undo_last_point(self):
+         """Remove last point from current frame"""
+         if self.current_frame_idx in self.points_by_frame and self.points_by_frame[self.current_frame_idx]:
+             self.points_by_frame[self.current_frame_idx].pop()
+             if not self.points_by_frame[self.current_frame_idx]:
+                 del self.points_by_frame[self.current_frame_idx]
+             self.update_frames_display()
+             self.display_frame()
+
+     # Frame navigation methods
+     def first_frame(self):
+         """Jump to first frame"""
+         self.current_frame_idx = 0
+         self.frame_slider.set(self.current_frame_idx)
+         self.update_frame_display()
+
+     def last_frame(self):
+         """Jump to last frame"""
+         total_frames = self.total_frames_list[self.current_video_idx]
+         self.current_frame_idx = max(0, total_frames - 1)
+         self.frame_slider.set(self.current_frame_idx)
+         self.update_frame_display()
+
+     def prev_frame(self, step=1):
+         """Go to previous frame"""
+         self.current_frame_idx = max(0, self.current_frame_idx - step)
+         self.frame_slider.set(self.current_frame_idx)
+         self.update_frame_display()
+
+     def next_frame(self, step=1):
+         """Go to next frame"""
+         total_frames = self.total_frames_list[self.current_video_idx]
+         self.current_frame_idx = min(total_frames - 1, self.current_frame_idx + step)
+         self.frame_slider.set(self.current_frame_idx)
+         self.update_frame_display()
+
+     def on_slider_change(self, value):
+         """Handle slider change"""
+         self.current_frame_idx = int(float(value))
+         self.update_frame_display()
+
+     def update_frame_display(self):
+         """Update frame label and display"""
+         total_frames = self.total_frames_list[self.current_video_idx]
+         self.frame_label.config(text=f"F: {self.current_frame_idx}/{total_frames - 1}")
+         self.display_frame()
+
+     # Video navigation
+     def first_video(self):
+         """Jump to first video"""
+         self.save_current_points()
+         self.current_video_idx = 0
+         self.current_frame_idx = 0
+         self.display_current_video()
+
+     def last_video(self):
+         """Jump to last video"""
+         self.save_current_points()
+         self.current_video_idx = len(self.videos) - 1
+         self.current_frame_idx = 0
+         self.display_current_video()
+
+     def prev_video(self):
+         """Go to previous video"""
+         if self.current_video_idx > 0:
+             self.save_current_points()
+             self.current_video_idx -= 1
+             self.current_frame_idx = 0
+             self.display_current_video()
+
+     def save_and_next(self):
+         """Save current points and move to next video"""
+         if len(self.points_by_frame) == 0:
+             result = messagebox.askyesno("No Points", "No points selected for any frame. Continue to next video?")
+             if not result:
+                 return
+
+         self.save_current_points()
+
+         if self.current_video_idx < len(self.videos) - 1:
+             self.current_video_idx += 1
+             self.current_frame_idx = 0
+             self.display_current_video()
+         else:
+             messagebox.showinfo("Complete", "All videos processed!")
+
+     def save_current_points(self):
+         """Save current video's points to storage"""
+         self.all_points_by_frame[self.current_video_idx] = self.points_by_frame.copy()
+
+     def save_points(self):
+         """Save all points to JSON file"""
+         if not self.config_path:
+             messagebox.showerror("Error", "No config loaded")
+             return
+
+         # Save current video first
+         self.save_current_points()
+
+         # Build output
+         output_videos = []
+         for i, video_info in enumerate(self.videos):
+             video_data = video_info.copy()
+
+             points_by_frame = self.all_points_by_frame[i]
+
+             # Convert to serializable format (int keys → string keys for JSON)
+             video_data["primary_points_by_frame"] = {
+                 str(frame_idx): points for frame_idx, points in points_by_frame.items()
+             }
+
+             # Also save list of frames for easy access
+             video_data["primary_frames"] = sorted(points_by_frame.keys())
+
+             # Backwards compatibility: if only one frame, save as before
+             if len(points_by_frame) == 1:
+                 frame_idx = list(points_by_frame.keys())[0]
+                 video_data["first_appears_frame"] = frame_idx
+                 video_data["primary_points"] = points_by_frame[frame_idx]
+             elif len(points_by_frame) > 1:
+                 # Multiple frames - use first frame as "first_appears_frame"
+                 video_data["first_appears_frame"] = min(points_by_frame.keys())
+                 # Flatten all points for backwards compat (not ideal but helps)
+                 all_points = []
+                 for frame_idx in sorted(points_by_frame.keys()):
+                     all_points.extend(points_by_frame[frame_idx])
+                 video_data["primary_points"] = all_points
+
+             output_videos.append(video_data)
+
+         # Match input format
+         if isinstance(self.config_data, list):
+             output_data = output_videos
+         else:
+             output_data = {"videos": output_videos}
+
+         # Save
+         output_path = self.config_path.parent / f"{self.config_path.stem}_points.json"
+
+         try:
+             with open(output_path, 'w') as f:
+                 json.dump(output_data, f, indent=2)
+
+             self.status_label.config(text=f"Saved to {output_path.name}", foreground="green")
+             messagebox.showinfo("Success", f"Points saved to:\n{output_path}")
+         except Exception as e:
+             messagebox.showerror("Error", f"Failed to save: {e}")
+
+     def __del__(self):
+         """Clean up video captures"""
+         for cap in self.video_captures:
+             if cap is not None:
+                 cap.release()
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Point Selector GUI - Multi-Frame Support")
+     parser.add_argument("--config", help="Config JSON file to load")
+     args = parser.parse_args()
+
+     root = tk.Tk()
+     root.geometry("900x750")  # Compact height to fit on screen
+     gui = PointSelectorGUI(root, config_path=args.config)
+     root.mainloop()
+
+
+ if __name__ == "__main__":
+     main()
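
The `_points.json` written by `save_points` keys `primary_points_by_frame` by stringified frame index (JSON object keys must be strings) and mirrors a flattened `primary_points` list for legacy single-frame consumers, while `load_existing_points` converts the keys back to ints. A small standalone sketch of that round-trip (the helper names `serialize_points` / `deserialize_points` are illustrative, not part of the GUI):

```python
import json

def serialize_points(points_by_frame):
    """Mirror save_points: string keys for JSON, plus backwards-compat fields."""
    data = {
        "primary_points_by_frame": {str(f): pts for f, pts in points_by_frame.items()},
        "primary_frames": sorted(points_by_frame),
    }
    if points_by_frame:
        data["first_appears_frame"] = min(points_by_frame)
        # Flattened view in frame order, as the legacy single-frame format expects
        data["primary_points"] = [p for f in sorted(points_by_frame)
                                  for p in points_by_frame[f]]
    return data

def deserialize_points(data):
    """Mirror load_existing_points: convert string keys back to ints."""
    return {int(k): v for k, v in data.get("primary_points_by_frame", {}).items()}

points = {0: [(120, 80)], 30: [(300, 200), (310, 210)]}
blob = json.dumps(serialize_points(points))
restored = deserialize_points(json.loads(blob))
```

Note that point tuples come back as two-element lists after a JSON round-trip, which is why downstream stages index them positionally rather than relying on tuple type.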
VLM-MASK-REASONER/run_pipeline.sh ADDED
@@ -0,0 +1,74 @@
+ #!/bin/bash
+ # run_pipeline.sh
+ # Runs stages 1-4 given a points config JSON (output of point_selector_gui.py)
+ #
+ # Usage:
+ #   bash run_pipeline.sh <config_points.json> [--sam2-checkpoint PATH] [--device cuda]
+ #
+ # Example:
+ #   bash run_pipeline.sh my_config_points.json
+ #   bash run_pipeline.sh my_config_points.json --sam2-checkpoint ../sam2_hiera_large.pt
+
+ set -e
+
+ # ── Arguments ──────────────────────────────────────────────────────────────────
+ CONFIG="$1"
+ if [ -z "$CONFIG" ]; then
+     echo "Usage: bash run_pipeline.sh <config_points.json> [--sam2-checkpoint PATH] [--device cuda]"
+     exit 1
+ fi
+
+ SAM2_CHECKPOINT="../sam2_hiera_large.pt"
+ DEVICE="cuda"
+
+ # Parse optional flags
+ shift
+ while [[ $# -gt 0 ]]; do
+     case "$1" in
+         --sam2-checkpoint) SAM2_CHECKPOINT="$2"; shift 2 ;;
+         --device) DEVICE="$2"; shift 2 ;;
+         *) echo "Unknown argument: $1"; exit 1 ;;
+     esac
+ done
+
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+
+ echo "=========================================="
+ echo "Void Mask Generation Pipeline"
+ echo "=========================================="
+ echo "Config: $CONFIG"
+ echo "SAM2 checkpoint: $SAM2_CHECKPOINT"
+ echo "Device: $DEVICE"
+ echo "=========================================="
+
+ # ── Stage 1: SAM2 Segmentation ─────────────────────────────────────────────────
+ echo ""
+ echo "[1/4] SAM2 segmentation..."
+ python "$SCRIPT_DIR/stage1_sam2_segmentation.py" \
+     --config "$CONFIG" \
+     --sam2-checkpoint "$SAM2_CHECKPOINT" \
+     --device "$DEVICE"
+
+ # ── Stage 2: VLM Analysis ──────────────────────────────────────────────────────
+ echo ""
+ echo "[2/4] VLM analysis (Gemini)..."
+ python "$SCRIPT_DIR/stage2_vlm_analysis.py" \
+     --config "$CONFIG"
+
+ # ── Stage 3a: Generate Grey Masks ─────────────────────────────────────────────
+ echo ""
+ echo "[3/4] Generating grey masks..."
+ python "$SCRIPT_DIR/stage3a_generate_grey_masks_v2.py" \
+     --config "$CONFIG"
+
+ # ── Stage 4: Combine into Quadmask ────────────────────────────────────────────
+ echo ""
+ echo "[4/4] Combining masks into quadmask_0.mp4..."
+ python "$SCRIPT_DIR/stage4_combine_masks.py" \
+     --config "$CONFIG"
+
+ echo ""
+ echo "=========================================="
+ echo "Pipeline complete!"
+ echo "Output: quadmask_0.mp4 in each video's output_dir"
+ echo "=========================================="
VLM-MASK-REASONER/stage1_sam2_segmentation.py ADDED
@@ -0,0 +1,419 @@
+ #!/usr/bin/env python3
+ """
+ Stage 1: SAM2 Point-Prompted Segmentation
+
+ Takes user-selected points and generates pixel-perfect masks of primary objects
+ using SAM2 video tracking.
+
+ Input: <config>_points.json (with primary_points)
+ Output: For each video:
+     - black_mask.mp4: Primary object mask (0=object, 255=background)
+     - first_frame.jpg: First frame for VLM analysis
+     - segmentation_info.json: Metadata
+
+ Usage:
+     python stage1_sam2_segmentation.py --config more_dyn_2_config_points.json
+ """
+
+ import os
+ import sys
+ import json
+ import argparse
+ import cv2
+ import numpy as np
+ import torch
+ import tempfile
+ import shutil
+ from pathlib import Path
+ from typing import Dict, List, Tuple
+ import subprocess
+
+ # Check SAM2 availability
+ try:
+     from sam2.build_sam import build_sam2_video_predictor
+     SAM2_AVAILABLE = True
+ except ImportError:
+     SAM2_AVAILABLE = False
+     print("⚠️ SAM2 not installed. Install with:")
+     print("   pip install git+https://github.com/facebookresearch/segment-anything-2.git")
+     sys.exit(1)
+
+
+ class SAM2PointSegmenter:
+     """SAM2 video segmentation with point prompts"""
+
+     def __init__(self, checkpoint_path: str, model_cfg: str = "sam2_hiera_l.yaml", device: str = "cuda"):
+         print(f"  Loading SAM2 video predictor...")
+         self.device = device
+         self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=device)
+         print(f"  ✓ SAM2 loaded on {device}")
+
+     def segment_video(self, video_path: str, points: List[List[int]] = None,
+                       output_mask_path: str = None, temp_dir: str = None,
+                       first_appears_frame: int = 0,
+                       points_by_frame: Dict[int, List[List[int]]] = None) -> Dict:
+         """
+         Segment video using point prompts (single or multi-frame).
+
+         Args:
+             video_path: Path to input video
+             points: List of [x, y] points on object (single frame, legacy)
+             output_mask_path: Path to save mask video
+             temp_dir: Directory for temporary frames
+             first_appears_frame: Frame index where points were selected (for single frame)
+             points_by_frame: Dict mapping frame_idx → [[x, y], ...] (multi-frame support)
+
+         Returns:
+             Dict with segmentation metadata
+         """
+         # Handle both old and new formats
+         if points_by_frame is not None:
+             # Multi-frame format
+             if not points_by_frame or len(points_by_frame) == 0:
+                 raise ValueError("No points provided")
+         elif points is not None:
+             # Single frame format (backwards compat)
+             if not points or len(points) == 0:
+                 raise ValueError("No points provided")
+             points_by_frame = {first_appears_frame: points}
+         else:
+             raise ValueError("Must provide either points or points_by_frame")
+
+         # Create temp directory for frames
+         if temp_dir is None:
+             temp_dir = tempfile.mkdtemp(prefix="sam2_frames_")
+             cleanup = True
+         else:
+             Path(temp_dir).mkdir(parents=True, exist_ok=True)
+             cleanup = False
+
+         print(f"  Extracting frames to: {temp_dir}")
+         frame_files = self._extract_frames(video_path, temp_dir)
+
+         if len(frame_files) == 0:
+             raise RuntimeError(f"No frames extracted from {video_path}")
+
+         # Get video properties
+         cap = cv2.VideoCapture(video_path)
+         fps = cap.get(cv2.CAP_PROP_FPS)
+         frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+         total_frames = len(frame_files)
+         cap.release()
+
+         # Count total points across all frames
+         total_points = sum(len(pts) for pts in points_by_frame.values())
+         print(f"  Video: {frame_width}x{frame_height}, {total_frames} frames @ {fps} fps")
+         print(f"  Using {total_points} points across {len(points_by_frame)} frame(s) for segmentation")
+
+         # Initialize SAM2
+         print(f"  Initializing SAM2...")
+         inference_state = self.predictor.init_state(video_path=temp_dir)
+
+         # Add points for each frame (all with obj_id=1 to merge into single mask)
+         for frame_idx in sorted(points_by_frame.keys()):
+             frame_points = points_by_frame[frame_idx]
+
+             # Convert points to numpy array
+             points_np = np.array(frame_points, dtype=np.float32)
+             labels_np = np.ones(len(frame_points), dtype=np.int32)  # All positive
+
+             # Calculate bounding box from points (with 10% margin for hair/clothes)
+             x_coords = points_np[:, 0]
+             y_coords = points_np[:, 1]
+
+             x_min, x_max = x_coords.min(), x_coords.max()
+             y_min, y_max = y_coords.min(), y_coords.max()
+
+             # Add 10% margin
+             x_margin = (x_max - x_min) * 0.1
+             y_margin = (y_max - y_min) * 0.1
+
+             box = np.array([
+                 max(0, x_min - x_margin),
+                 max(0, y_min - y_margin),
+                 min(frame_width, x_max + x_margin),
+                 min(frame_height, y_max + y_margin)
+             ], dtype=np.float32)
+
+             print(f"  Adding {len(frame_points)} points + box to frame {frame_idx}")
+             print(f"    Points: {frame_points[:3]}..." if len(frame_points) > 3 else f"    Points: {frame_points}")
+             print(f"    Box: [{int(box[0])}, {int(box[1])}, {int(box[2])}, {int(box[3])}]")
+
+             # Add points + box to this frame (all use obj_id=1 to merge)
+             _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
+                 inference_state=inference_state,
+                 frame_idx=frame_idx,
+                 obj_id=1,
+                 points=points_np,
+                 labels=labels_np,
+                 box=box,
+             )
+
+         print(f"  Propagating through video...")
+
+         # Propagate through video
+         video_segments = {}
+         for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
+             # Get mask for object ID 1
+             mask_logits = out_mask_logits[out_obj_ids.index(1)]
+             mask = (mask_logits > 0.0).cpu().numpy().squeeze()
+             video_segments[out_frame_idx] = mask
+
+         print(f"  ✓ Segmented {len(video_segments)} frames")
+
+         # Write mask video
+         print(f"  Writing mask video...")
+         self._write_mask_video(video_segments, output_mask_path, fps, frame_width, frame_height)
+
+         # Cleanup
+         if cleanup:
+             shutil.rmtree(temp_dir)
+
+         # Build metadata
+         metadata = {
+             "total_frames": total_frames,
+             "frame_width": frame_width,
+             "frame_height": frame_height,
+             "fps": fps,
+         }
+
+         # Add points info based on format
+         if points_by_frame:
+             total_points = sum(len(pts) for pts in points_by_frame.values())
+             metadata["num_points"] = total_points
+             metadata["points_by_frame"] = {str(k): v for k, v in points_by_frame.items()}
+         else:
+             metadata["num_points"] = len(points) if points else 0
+             metadata["points"] = points
+
+         return metadata
+
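The box-from-points logic in `segment_video` can be exercised on its own. A small sketch (NumPy only, with made-up click coordinates) showing how the 10% margin expands the point extent while staying clamped to the frame bounds:

```python
import numpy as np

# Hypothetical clicks on an object in a 1280x720 frame
frame_width, frame_height = 1280, 720
points_np = np.array([[400, 200], [600, 500], [500, 350]], dtype=np.float32)

x_min, x_max = points_np[:, 0].min(), points_np[:, 0].max()
y_min, y_max = points_np[:, 1].min(), points_np[:, 1].max()

# 10% margin on each side, clamped to the frame bounds
x_margin = (x_max - x_min) * 0.1  # 20.0
y_margin = (y_max - y_min) * 0.1  # 30.0
box = np.array([
    max(0, x_min - x_margin),
    max(0, y_min - y_margin),
    min(frame_width, x_max + x_margin),
    min(frame_height, y_max + y_margin),
], dtype=np.float32)

print(box.tolist())  # [380.0, 170.0, 620.0, 530.0]
```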
+     def _extract_frames(self, video_path: str, output_dir: str) -> List[str]:
+         """Extract video frames as JPG files"""
+         Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+         cap = cv2.VideoCapture(video_path)
+         frame_idx = 0
+         frame_files = []
+
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+
+             # SAM2 expects frames named as frame_000000.jpg, frame_000001.jpg, etc.
+             frame_filename = f"{frame_idx:06d}.jpg"
+             frame_path = os.path.join(output_dir, frame_filename)
+             cv2.imwrite(frame_path, frame)
+             frame_files.append(frame_path)
+             frame_idx += 1
+
+             if frame_idx % 20 == 0:
+                 print(f"    Extracted {frame_idx} frames...", end='\r')
+
+         cap.release()
+         print(f"    Extracted {frame_idx} frames")
+
+         return frame_files
+
+     def _write_mask_video(self, masks: Dict[int, np.ndarray], output_path: str,
+                           fps: float, width: int, height: int):
+         """Write masks to video file"""
+         # Write temp AVI first
+         temp_avi = Path(output_path).with_suffix('.avi')
+         fourcc = cv2.VideoWriter_fourcc(*'FFV1')
+         out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (width, height), isColor=False)
+
+         for frame_idx in sorted(masks.keys()):
+             mask = masks[frame_idx]
+             # Convert boolean mask to 0/255
+             mask_uint8 = np.where(mask, 0, 255).astype(np.uint8)
+             out.write(mask_uint8)
+
+         out.release()
+
+         # Convert to lossless MP4
+         cmd = [
+             'ffmpeg', '-y', '-i', str(temp_avi),
+             '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
+             '-pix_fmt', 'yuv444p',
+             str(output_path)
+         ]
+         subprocess.run(cmd, capture_output=True)
+         temp_avi.unlink()
+
+         print(f"  ✓ Saved mask video: {output_path}")
+
+
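Note the inverted polarity in `_write_mask_video`: `np.where(mask, 0, 255)` writes the tracked object as black (0) and the background as white (255), matching the `black_mask.mp4` convention stated in the module docstring. A tiny NumPy check of that mapping:

```python
import numpy as np

# 2x3 boolean mask: True = object pixel, False = background
mask = np.array([[True, False, False],
                 [True, True, False]])

# Same inversion as _write_mask_video: object -> 0, background -> 255
mask_uint8 = np.where(mask, 0, 255).astype(np.uint8)

print(mask_uint8.tolist())
# [[0, 255, 255], [0, 0, 255]]

# Recovering the object region later (as stage 2 does with mask_frame == 0)
object_mask = (mask_uint8 == 0)
assert (object_mask == mask).all()
```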
+ def process_config(config_path: str, sam2_checkpoint: str, device: str = "cuda"):
+     """Process all videos in config"""
+     config_path = Path(config_path)
+
+     # Load config
+     with open(config_path, 'r') as f:
+         config_data = json.load(f)
+
+     # Handle both formats
+     if isinstance(config_data, list):
+         videos = config_data
+     elif isinstance(config_data, dict) and "videos" in config_data:
+         videos = config_data["videos"]
+     else:
+         raise ValueError("Config must be a list or have 'videos' key")
+
+     print(f"\n{'='*70}")
+     print(f"Stage 1: SAM2 Point-Prompted Segmentation")
+     print(f"{'='*70}")
+     print(f"Config: {config_path.name}")
+     print(f"Videos: {len(videos)}")
+     print(f"Device: {device}")
+     print(f"{'='*70}\n")
+
+     # Initialize SAM2
+     segmenter = SAM2PointSegmenter(sam2_checkpoint, device=device)
+
+     # Process each video
+     for i, video_info in enumerate(videos):
+         video_path = video_info.get("video_path", "")
+         instruction = video_info.get("instruction", "")
+         output_dir = video_info.get("output_dir", "")
+
+         # Read points - support both single-frame and multi-frame formats
+         points_by_frame_raw = video_info.get("primary_points_by_frame", None)
+         points = video_info.get("primary_points", [])
+         first_appears_frame = video_info.get("first_appears_frame", 0)
+
+         # Convert points_by_frame from string keys to int keys
+         points_by_frame = None
+         if points_by_frame_raw:
+             points_by_frame = {int(k): v for k, v in points_by_frame_raw.items()}
+
+         if not video_path:
+             print(f"\n⚠️ Video {i+1}: No video_path, skipping")
+             continue
+
+         if not points and not points_by_frame:
+             print(f"\n⚠️ Video {i+1}: No primary_points selected, skipping")
+             continue
+
+         video_path = Path(video_path)
+         if not video_path.exists():
+             print(f"\n⚠️ Video {i+1}: File not found: {video_path}, skipping")
+             continue
+
+         print(f"\n{'─'*70}")
+         print(f"Video {i+1}/{len(videos)}: {video_path.name}")
+         print(f"{'─'*70}")
+         print(f"Instruction: {instruction}")
+
+         if points_by_frame:
+             total_points = sum(len(pts) for pts in points_by_frame.values())
+             print(f"Points: {total_points} across {len(points_by_frame)} frame(s)")
+             print(f"Frames: {sorted(points_by_frame.keys())}")
+         else:
+             print(f"Points: {len(points)}")
+             print(f"First appears frame: {first_appears_frame}")
+
+         # Setup output directory
+         if output_dir:
+             output_dir = Path(output_dir)
+         else:
+             # Create unique output directory per video
+             video_name = video_path.stem  # Get video name without extension
+             output_dir = video_path.parent / f"{video_name}_masks_output"
+
+         output_dir.mkdir(parents=True, exist_ok=True)
+         print(f"Output: {output_dir}")
+
+         try:
+             # Segment video - use multi-frame or single-frame format
+             black_mask_path = output_dir / "black_mask.mp4"
+             if points_by_frame:
+                 metadata = segmenter.segment_video(
+                     str(video_path),
+                     output_mask_path=str(black_mask_path),
+                     points_by_frame=points_by_frame
+                 )
+                 # Use first frame for VLM analysis
+                 first_frame_for_vlm = min(points_by_frame.keys())
+             else:
+                 metadata = segmenter.segment_video(
+                     str(video_path),
+                     points=points,
+                     output_mask_path=str(black_mask_path),
+                     first_appears_frame=first_appears_frame
+                 )
+                 first_frame_for_vlm = first_appears_frame
+
+             # Save frame where object appears for VLM analysis
+             cap = cv2.VideoCapture(str(video_path))
+             cap.set(cv2.CAP_PROP_POS_FRAMES, first_frame_for_vlm)
+             ret, first_frame = cap.read()
+             cap.release()
+
+             if ret:
+                 first_frame_path = output_dir / "first_frame.jpg"
+                 cv2.imwrite(str(first_frame_path), first_frame)
+                 print(f"  ✓ Saved first frame (frame {first_frame_for_vlm}): {first_frame_path.name}")
+
+             # Copy input video
+             input_copy_path = output_dir / "input_video.mp4"
+             if not input_copy_path.exists():
+                 shutil.copy2(video_path, input_copy_path)
+                 print(f"  ✓ Copied input video")
+
+             # Save metadata
+             metadata["video_path"] = str(video_path)
+             metadata["instruction"] = instruction
+             if points_by_frame:
+                 metadata["primary_points_by_frame"] = {str(k): v for k, v in points_by_frame.items()}
+                 metadata["primary_frames"] = sorted(points_by_frame.keys())
+                 metadata["first_appears_frame"] = min(points_by_frame.keys())
+             else:
+                 metadata["primary_points"] = points
+                 metadata["first_appears_frame"] = first_appears_frame
+
+             metadata_path = output_dir / "segmentation_info.json"
+             with open(metadata_path, 'w') as f:
+                 json.dump(metadata, f, indent=2)
+             print(f"  ✓ Saved metadata: {metadata_path.name}")
+
+             print(f"\n✅ Video {i+1} complete!")
+
+         except Exception as e:
+             print(f"\n❌ Error processing video {i+1}: {e}")
+             import traceback
+             traceback.print_exc()
+             continue
+
+     print(f"\n{'='*70}")
+     print(f"Stage 1 Complete!")
+     print(f"{'='*70}\n")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Stage 1: SAM2 Point-Prompted Segmentation")
+     parser.add_argument("--config", required=True, help="Config JSON with primary_points")
+     parser.add_argument("--sam2-checkpoint", default="../sam2_hiera_large.pt",
+                         help="Path to SAM2 checkpoint")
+     parser.add_argument("--device", default="cuda", help="Device (cuda/cpu)")
+     args = parser.parse_args()
+
+     if not SAM2_AVAILABLE:
+         print("❌ SAM2 not available")
+         sys.exit(1)
+
+     # Check checkpoint exists
+     checkpoint_path = Path(args.sam2_checkpoint)
+     if not checkpoint_path.exists():
+         print(f"❌ Checkpoint not found: {checkpoint_path}")
+         print(f"   Download with:")
+         print(f"   wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt")
+         sys.exit(1)
+
+     process_config(args.config, str(checkpoint_path), args.device)
+
+
+ if __name__ == "__main__":
+     main()
VLM-MASK-REASONER/stage2_vlm_analysis.py ADDED
@@ -0,0 +1,1022 @@
+ #!/usr/bin/env python3
+ """
+ Stage 2: VLM Analysis - Identify Affected Objects & Physics
+
+ Analyzes videos with primary masks to identify:
+ - Integral belongings (to add to black mask)
+ - Affected objects (shadows, reflections, held items)
+ - Physics behavior (will_move, needs_trajectory)
+
+ Input: Config from Stage 1 (with output_dir containing black_mask.mp4, first_frame.jpg)
+ Output: For each video:
+     - vlm_analysis.json: Identified objects and physics reasoning
+
+ Usage:
+     python stage2_vlm_analysis.py --config more_dyn_2_config_points_absolute.json
+ """
+
+ import os
+ import sys
+ import json
+ import argparse
+ import cv2
+ import numpy as np
+ import base64
+ from pathlib import Path
+ from typing import Dict, List
+ from PIL import Image, ImageDraw
+
+ import openai
+
+ DEFAULT_MODEL = "gemini-3-pro-preview"
+
+
+ def image_to_data_url(image_path: str) -> str:
+     """Convert image file to base64 data URL"""
+     with open(image_path, 'rb') as f:
+         img_data = base64.b64encode(f.read()).decode('utf-8')
+
+     # Detect format
+     ext = Path(image_path).suffix.lower()
+     if ext == '.png':
+         mime = 'image/png'
+     elif ext in ['.jpg', '.jpeg']:
+         mime = 'image/jpeg'
+     else:
+         mime = 'image/jpeg'
+
+     return f"data:{mime};base64,{img_data}"
+
+
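The data-URL scheme used here can be checked without any image files. A minimal sketch using stand-in bytes in place of real JPEG data:

```python
import base64

# Stand-in for file bytes read from disk (a real call would read a .jpg)
fake_jpeg_bytes = b"\xff\xd8\xff\xe0 not a real image"
img_data = base64.b64encode(fake_jpeg_bytes).decode("utf-8")
data_url = f"data:image/jpeg;base64,{img_data}"

print(data_url[:23])  # data:image/jpeg;base64,

# The base64 payload round-trips losslessly
assert base64.b64decode(img_data) == fake_jpeg_bytes
```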
+ def video_to_data_url(video_path: str) -> str:
+     """Convert video file to base64 data URL"""
+     with open(video_path, 'rb') as f:
+         video_data = base64.b64encode(f.read()).decode('utf-8')
+     return f"data:video/mp4;base64,{video_data}"
+
+
+ def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> tuple:
+     """Calculate grid dimensions matching stage3a logic"""
+     aspect_ratio = width / height
+     if width >= height:
+         grid_rows = min_grid
+         grid_cols = max(min_grid, round(min_grid * aspect_ratio))
+     else:
+         grid_cols = min_grid
+         grid_rows = max(min_grid, round(min_grid / aspect_ratio))
+     return grid_rows, grid_cols
+
+
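As a worked example of the grid sizing (the function is restated here so the snippet runs standalone): a 16:9 landscape frame keeps the minimum 8 rows and widens to 14 columns so that cells stay roughly square, and a portrait frame mirrors that.

```python
# Same logic as calculate_square_grid in this file, repeated so the
# example is self-contained.
def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> tuple:
    aspect_ratio = width / height
    if width >= height:
        grid_rows = min_grid
        grid_cols = max(min_grid, round(min_grid * aspect_ratio))
    else:
        grid_cols = min_grid
        grid_rows = max(min_grid, round(min_grid / aspect_ratio))
    return grid_rows, grid_cols

print(calculate_square_grid(1920, 1080))  # (8, 14)
print(calculate_square_grid(1080, 1920))  # (14, 8)
print(calculate_square_grid(640, 640))    # (8, 8)
```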
+ def create_first_frame_with_mask_overlay(first_frame_path: str, black_mask_path: str,
+                                          output_path: str, frame_idx: int = 0) -> str:
+     """Create visualization of first frame with red overlay on primary object
+
+     Args:
+         first_frame_path: Path to first_frame.jpg
+         black_mask_path: Path to black_mask.mp4
+         output_path: Where to save overlay
+         frame_idx: Which frame to extract from black_mask.mp4 (default: 0)
+     """
+     # Load first frame
+     frame = cv2.imread(first_frame_path)
+     if frame is None:
+         raise ValueError(f"Failed to load first frame: {first_frame_path}")
+
+     # Load black mask video and get the specified frame
+     cap = cv2.VideoCapture(black_mask_path)
+     cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
+     ret, mask_frame = cap.read()
+     cap.release()
+
+     if not ret:
+         raise ValueError(f"Failed to load black mask frame {frame_idx}: {black_mask_path}")
+
+     # Convert mask to binary (0 = object, 255 = background)
+     if len(mask_frame.shape) == 3:
+         mask_frame = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)
+
+     object_mask = (mask_frame == 0)
+
+     # Create red overlay on object
+     overlay = frame.copy()
+     overlay[object_mask] = [0, 0, 255]  # Red in BGR
+
+     # Blend: 60% original + 40% red overlay
+     result = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)
+
+     # Save
+     cv2.imwrite(output_path, result)
+     return output_path
+
+
+ def create_gridded_frame_overlay(first_frame_path: str, black_mask_path: str,
+                                  output_path: str, min_grid: int = 8) -> tuple:
+     """Create first frame with BOTH red mask overlay AND grid lines
+
+     Returns: (output_path, grid_rows, grid_cols)
+     """
+     # Load first frame
+     frame = cv2.imread(first_frame_path)
+     if frame is None:
+         raise ValueError(f"Failed to load first frame: {first_frame_path}")
+
+     h, w = frame.shape[:2]
+
+     # Load black mask
+     cap = cv2.VideoCapture(black_mask_path)
+     ret, mask_frame = cap.read()
+     cap.release()
+
+     if not ret:
+         raise ValueError(f"Failed to load black mask: {black_mask_path}")
+
+     if len(mask_frame.shape) == 3:
+         mask_frame = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)
+
+     object_mask = (mask_frame == 0)
+
+     # Create red overlay
+     overlay = frame.copy()
+     overlay[object_mask] = [0, 0, 255]
+     result = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)
+
+     # Calculate grid
+     grid_rows, grid_cols = calculate_square_grid(w, h, min_grid)
+
+     # Draw grid lines
+     cell_width = w / grid_cols
+     cell_height = h / grid_rows
+
+     # Vertical lines
+     for col in range(1, grid_cols):
+         x = int(col * cell_width)
+         cv2.line(result, (x, 0), (x, h), (255, 255, 0), 1)  # Yellow lines
+
+     # Horizontal lines
+     for row in range(1, grid_rows):
+         y = int(row * cell_height)
+         cv2.line(result, (0, y), (w, y), (255, 255, 0), 1)
+
+     # Add grid labels
+     font = cv2.FONT_HERSHEY_SIMPLEX
+     font_scale = 0.3
+     thickness = 1
+
+     # Label columns at top
+     for col in range(grid_cols):
+         x = int((col + 0.5) * cell_width)
+         cv2.putText(result, str(col), (x-5, 15), font, font_scale, (255, 255, 0), thickness)
+
+     # Label rows on left
+     for row in range(grid_rows):
+         y = int((row + 0.5) * cell_height)
+         cv2.putText(result, str(row), (5, y+5), font, font_scale, (255, 255, 0), thickness)
+
+     cv2.imwrite(output_path, result)
+     return output_path, grid_rows, grid_cols
+
+
+ def create_multi_frame_grid_samples(video_path: str, output_dir: Path,
+                                     min_grid: int = 8,
+                                     sample_points: list = [0.0, 0.11, 0.22, 0.33, 0.44, 0.56, 0.67, 0.78, 0.89, 1.0]) -> tuple:
+     """
+     Create gridded frame samples at multiple time points in video.
+     Helps VLM see objects that appear mid-video with grid reference.
+
+     Args:
+         video_path: Path to video
+         output_dir: Where to save samples
+         min_grid: Minimum grid size
+         sample_points: List of normalized positions [0.0-1.0] to sample
+
+     Returns: (sample_paths, grid_rows, grid_cols)
+     """
+     cap = cv2.VideoCapture(str(video_path))
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     # Calculate grid (same for all frames)
+     grid_rows, grid_cols = calculate_square_grid(w, h, min_grid)
+     cell_width = w / grid_cols
+     cell_height = h / grid_rows
+
+     sample_paths = []
+
+     for i, t in enumerate(sample_points):
+         frame_idx = int(t * (total_frames - 1))
+         frame_idx = max(0, min(frame_idx, total_frames - 1))
+
+         cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
+         ret, frame = cap.read()
+         if not ret:
+             continue
+
+         # Draw grid
+         result = frame.copy()
+
+         # Vertical lines
+         for col in range(1, grid_cols):
+             x = int(col * cell_width)
+             cv2.line(result, (x, 0), (x, h), (255, 255, 0), 2)
+
+         # Horizontal lines
+         for row in range(1, grid_rows):
+             y = int(row * cell_height)
+             cv2.line(result, (0, y), (w, y), (255, 255, 0), 2)
+
+         # Add grid labels
+         font = cv2.FONT_HERSHEY_SIMPLEX
+         font_scale = 0.4
+         thickness = 1
+
+         # Label columns
+         for col in range(grid_cols):
+             x = int((col + 0.5) * cell_width)
+             cv2.putText(result, str(col), (x-8, 20), font, font_scale, (255, 255, 0), thickness)
+
+         # Label rows
+         for row in range(grid_rows):
+             y = int((row + 0.5) * cell_height)
+             cv2.putText(result, str(row), (10, y+8), font, font_scale, (255, 255, 0), thickness)
+
+         # Add frame number and percentage
+         label = f"Frame {frame_idx} ({int(t*100)}%)"
+         cv2.putText(result, label, (10, h-10), font, 0.5, (255, 255, 0), 2)
+
+         # Save
+         output_path = output_dir / f"grid_sample_frame_{frame_idx:04d}.jpg"
+         cv2.imwrite(str(output_path), result)
+         sample_paths.append(output_path)
+
+     cap.release()
+     return sample_paths, grid_rows, grid_cols
+
+
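The normalized sample positions map to concrete frame indices via `int(t * (total_frames - 1))`, clamped to the valid range. For a hypothetical 100-frame clip that gives:

```python
sample_points = [0.0, 0.11, 0.22, 0.33, 0.44, 0.56, 0.67, 0.78, 0.89, 1.0]
total_frames = 100  # made-up clip length

# Same mapping as create_multi_frame_grid_samples
frame_indices = [
    max(0, min(int(t * (total_frames - 1)), total_frames - 1))
    for t in sample_points
]
print(frame_indices)  # [0, 10, 21, 32, 43, 55, 66, 77, 88, 99]
```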
+ def make_vlm_analysis_prompt(instruction: str, grid_rows: int, grid_cols: int,
+                              has_multi_frame_grids: bool = False) -> str:
+     """Create VLM prompt for analyzing video with primary mask"""
+
+     grid_context = ""
+     if has_multi_frame_grids:
+         grid_context = f"""
+ 1. **Multiple Grid Reference Frames**: Sampled frames at 0%, 11%, 22%, 33%, 44%, 56%, 67%, 78%, 89%, 100% of video
+    - Each frame shows YELLOW GRID with {grid_rows} rows × {grid_cols} columns
+    - Grid cells labeled (row, col) starting from (0, 0) at top-left
+    - Frame number shown at bottom
+    - Use these to locate objects that appear MID-VIDEO and track object positions across time
+ 2. **First Frame with RED mask**: Shows what will be REMOVED (primary object)
+ 3. **Full Video**: Complete action and interactions"""
+     else:
+         grid_context = f"""
+ 1. **First Frame with Grid**: PRIMARY OBJECT highlighted in RED + GRID OVERLAY
+    - The red overlay shows what will be REMOVED (already masked)
+    - Yellow grid with {grid_rows} rows × {grid_cols} columns
+    - Grid cells are labeled (row, col) starting from (0, 0) at top-left
+ 2. **Full Video**: Complete scene and action"""
+
+     return f"""
+ You are an expert video analyst specializing in physics and object interactions.
+
+ ═══════════════════════════════════════════════════════════════════
+ CONTEXT
+ ═══════════════════════════════════════════════════════════════════
+
+ You will see MULTIPLE inputs:
+ {grid_context}
+
+ Edit instruction: "{instruction}"
+
+ IMPORTANT: Some objects may NOT appear in first frame. They may enter later.
+ Watch the ENTIRE video and note when each object first appears.
+
+ ═══════════════════════════════════════════════════════════════════
+ YOUR TASK
+ ═══════════════════════════════════════════════════════════════════
+
+ Analyze what would happen if the PRIMARY OBJECT (shown in red) is removed.
+ Watch the ENTIRE video to see all interactions and movements.
+
+ STEP 1: IDENTIFY INTEGRAL BELONGINGS (0-3 items)
+ ─────────────────────────────────────────────────
+ Items that should be ADDED to the primary removal mask (removed WITH primary object):
+
+ ✓ INCLUDE:
+   • Distinct wearable items: hat, backpack, jacket (if separate/visible)
+   • Vehicles/equipment being ridden: bike, skateboard, surfboard, scooter
+   • Large carried items that are part of the subject
+
+ ✗ DO NOT INCLUDE:
+   • Generic clothing (shirt, pants, shoes) - already captured with person
+   • Held items that could be set down: guitar, cup, phone, tools
+   • Objects they're interacting with but not wearing/riding
+
+ Examples:
+   • Person on bike → integral: "bike"
+   • Person with guitar → integral: none (guitar is affected, not integral)
+   • Surfer → integral: "surfboard"
+   • Boxer → integral: "boxing gloves" (wearable equipment)
+
+ STEP 2: IDENTIFY AFFECTED OBJECTS (0-5 objects)
+ ────────────────────────────────────────────────
+ Objects/effects that are SEPARATE from primary but affected by its removal.
+
+ CRITICAL: Do NOT include integral belongings from Step 1.
+
+ Two categories:
+
+ A) VISUAL ARTIFACTS (disappear when primary removed):
+    • shadow, reflection, wake, ripples, splash, footprints
+    • These vanish completely - no physics needed
+
+    **CRITICAL FOR VISUAL ARTIFACTS:**
+    You MUST provide GRID LOCALIZATIONS across the reference frames.
+    Keyword segmentation fails to isolate specific shadows/reflections.
+
+    For each visual artifact:
+    - Look at each grid reference frame you were shown
+    - Identify which grid cells the artifact occupies in EACH frame
+    - List all grid cells (row, col) that contain any part of it
+    - Be thorough - include ALL touched cells (over-mask is better than under-mask)
+
+    Format:
+    {{
+      "noun": "shadow",
+      "category": "visual_artifact",
+      "grid_localizations": [
+        {{"frame": 0, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, ...]}},
+        {{"frame": 5, "grid_regions": [{{"row": 6, "col": 4}}, ...]}},
+        // ... for each reference frame shown
+      ]
+    }}
+
+ B) PHYSICAL OBJECTS (may move, fall, or stay):
+
+    CRITICAL - Understand the difference:
+
+    **SUPPORTING vs ACTING ON:**
+    • SUPPORTING = holding UP against gravity → object WILL FALL when removed
+      Examples: holding guitar, carrying cup, person sitting on chair
+      → will_move: TRUE
+
+    • ACTING ON = touching/manipulating but object rests on stable surface → object STAYS
+      Examples: hand crushing can (can on table), hand opening can (can on counter),
+                hand pushing object (object on floor)
+      → will_move: FALSE
+
+    **Key Questions:**
+    1. Is the primary object HOLDING THIS UP against gravity?
+       - YES → will_move: true, needs_trajectory: true
+       - NO → Check next question
+
+    2. Is this object RESTING ON a stable surface (table, floor, counter)?
+       - YES → will_move: false (stays on surface when primary removed)
+       - NO → will_move: true
+
+    3. Is the primary object DOING an action TO this object?
+       - Opening can, crushing can, pushing button, turning knob
+       - When primary removed → action STOPS, object stays in current state
+       - will_move: false
+
+    **SPECIAL CASE - Object Currently Moving But Should Have Stayed:**
+    If primary object CAUSES another object to move (hitting, kicking, throwing):
383
+ - The object is currently moving in the video
384
+ - But WITHOUT primary, it would have stayed at its original position
385
+ - You MUST provide:
386
+ β€’ "currently_moving": true
387
+ β€’ "should_have_stayed": true
388
+ β€’ "original_position_grid": {{"row": R, "col": C}} - Where it started
389
+
390
+ Examples:
391
+ - Golf club hits ball β†’ Ball at tee, then flies (mark original tee position)
392
+ - Person kicks soccer ball β†’ Ball on ground, then rolls (mark original ground position)
393
+ - Hand throws object β†’ Object held, then flies (mark original held position)
394
+
395
+ Format:
396
+ {{
397
+ "noun": "golf ball",
398
+ "category": "physical",
399
+ "currently_moving": true,
400
+ "should_have_stayed": true,
401
+ "original_position_grid": {{"row": 6, "col": 7}},
402
+ "why": "ball was stationary until club hit it"
403
+ }}
404
+
405
+ For each physical object, determine:
406
+ - **will_move**: true ONLY if object will fall/move when support removed
407
+ - **first_appears_frame**: frame number object first appears (0 if from start)
408
+ - **why**: Brief explanation of relationship to primary object
409
+
410
+ IF will_move=TRUE, also provide GRID-BASED TRAJECTORY:
411
+ - **object_size_grids**: {{"rows": R, "cols": C}} - How many grid cells object occupies
412
+ IMPORTANT: Add 1 extra cell padding for safety (better to over-mask than under-mask)
413
+ Example: Object looks 2Γ—1 β†’ report as 3Γ—2
414
+
415
+ - **trajectory_path**: List of keyframe positions as grid coordinates
416
+ Format: [{{"frame": N, "grid_row": R, "grid_col": C}}, ...]
417
+ - IMPORTANT: First keyframe should be at first_appears_frame (not frame 0 if object appears later!)
418
+ - Provide 3-5 keyframes spanning from first appearance to end
419
+ - (grid_row, grid_col) is the CENTER position of object at that frame
420
+ - Use the yellow grid reference frames to determine positions
421
+ - For objects appearing mid-video: use the grid samples to locate them
422
+ - Example: Object appears at frame 15, falls to bottom
423
+ [{{"frame": 15, "grid_row": 3, "grid_col": 5}}, ← First appearance
424
+ {{"frame": 25, "grid_row": 6, "grid_col": 5}}, ← Mid-fall
425
+ {{"frame": 35, "grid_row": 9, "grid_col": 5}}] ← On ground
426
+
427
+ βœ“ Objects held/carried at ANY point in video
428
+ βœ“ Objects the primary supports or interacts with
429
+ βœ“ Visual effects visible at any time
430
+
431
+ βœ— Background objects never touched
432
+ βœ— Other people/animals with no contact
433
+ βœ— Integral belongings (already in Step 1)
434
+
435
+ STEP 3: SCENE DESCRIPTION
436
+ ──────────────────────────
437
+ Describe scene WITHOUT the primary object (1-2 sentences).
438
+ Focus on what remains and any dynamic changes (falling objects, etc).
439
+
440
+ ═══════════════════════════════════════════════════════════════════
441
+ OUTPUT FORMAT (STRICT JSON ONLY)
442
+ ═══════════════════════════════════════════════════════════════════
443
+
444
+ EXAMPLES TO LEARN FROM:
445
+
446
+ Example 1: Person holding guitar
447
+ {{
448
+ "affected_objects": [
449
+ {{
450
+ "noun": "guitar",
451
+ "will_move": true,
452
+ "why": "person is SUPPORTING guitar against gravity by holding it",
453
+ "object_size_grids": {{"rows": 3, "cols": 2}},
454
+ "trajectory_path": [
455
+ {{"frame": 0, "grid_row": 4, "grid_col": 5}},
456
+ {{"frame": 15, "grid_row": 6, "grid_col": 5}},
457
+ {{"frame": 30, "grid_row": 8, "grid_col": 6}}
458
+ ]
459
+ }}
460
+ ]
461
+ }}
462
+
463
+ Example 2: Hand crushing can on table
464
+ {{
465
+ "affected_objects": [
466
+ {{
467
+ "noun": "can",
468
+ "will_move": false,
469
+ "why": "can RESTS ON TABLE - hand is just acting on it. When hand removed, can stays on table (uncrushed)"
470
+ }}
471
+ ]
472
+ }}
473
+
474
+ Example 3: Hands opening can on counter
475
+ {{
476
+ "affected_objects": [
477
+ {{
478
+ "noun": "can",
479
+ "will_move": false,
480
+ "why": "can RESTS ON COUNTER - hands are doing opening action. When hands removed, can stays closed on counter"
481
+ }}
482
+ ]
483
+ }}
484
+
485
+ Example 4: Person sitting on chair
486
+ {{
487
+ "affected_objects": [
488
+ {{
489
+ "noun": "chair",
490
+ "will_move": false,
491
+ "why": "chair RESTS ON FLOOR - person sitting on it doesn't make it fall. Chair stays on floor when person removed"
492
+ }}
493
+ ]
494
+ }}
495
+
496
+ Example 5: Person throws ball (ball appears at frame 12)
497
+ {{
498
+ "affected_objects": [
499
+ {{
500
+ "noun": "ball",
501
+ "category": "physical",
502
+ "will_move": true,
503
+ "first_appears_frame": 12,
504
+ "why": "ball is SUPPORTED by person's hand, then thrown",
505
+ "object_size_grids": {{"rows": 2, "cols": 2}},
506
+ "trajectory_path": [
507
+ {{"frame": 12, "grid_row": 4, "grid_col": 3}},
508
+ {{"frame": 20, "grid_row": 2, "grid_col": 6}},
509
+ {{"frame": 28, "grid_row": 5, "grid_col": 8}}
510
+ ]
511
+ }}
512
+ ]
513
+ }}
514
+
515
+ Example 6: Person with shadow (shadow needs grid localization)
516
+ {{
517
+ "affected_objects": [
518
+ {{
519
+ "noun": "shadow",
520
+ "category": "visual_artifact",
521
+ "why": "cast by person on the floor",
522
+ "will_move": false,
523
+ "first_appears_frame": 0,
524
+ "movement_description": "Disappears entirely as visual artifact",
525
+ "grid_localizations": [
526
+ {{"frame": 0, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, {{"row": 7, "col": 3}}, {{"row": 7, "col": 4}}]}},
527
+ {{"frame": 12, "grid_regions": [{{"row": 6, "col": 4}}, {{"row": 6, "col": 5}}, {{"row": 7, "col": 4}}]}},
528
+ {{"frame": 23, "grid_regions": [{{"row": 5, "col": 4}}, {{"row": 6, "col": 4}}, {{"row": 6, "col": 5}}]}},
529
+ {{"frame": 35, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, {{"row": 7, "col": 3}}]}},
530
+ {{"frame": 47, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 7, "col": 3}}, {{"row": 7, "col": 4}}]}}
531
+ ]
532
+ }}
533
+ ]
534
+ }}
535
+
536
+ Example 7: Golf club hits ball (Case 4 - currently moving but should stay)
537
+ {{
538
+ "affected_objects": [
539
+ {{
540
+ "noun": "golf ball",
541
+ "category": "physical",
542
+ "currently_moving": true,
543
+ "should_have_stayed": true,
544
+ "original_position_grid": {{"row": 6, "col": 7}},
545
+ "first_appears_frame": 0,
546
+ "why": "ball was stationary on tee until club hit it. Without club, ball would remain at original position."
547
+ }}
548
+ ]
549
+ }}
550
+
551
+ YOUR OUTPUT FORMAT:
552
+ {{
553
+ "edit_instruction": "{instruction}",
554
+ "integral_belongings": [
555
+ {{
556
+ "noun": "bike",
557
+ "why": "person is riding the bike throughout the video"
558
+ }}
559
+ ],
560
+ "affected_objects": [
561
+ {{
562
+ "noun": "guitar",
563
+ "category": "physical",
564
+ "why": "person is SUPPORTING guitar against gravity by holding it",
565
+ "will_move": true,
566
+ "first_appears_frame": 0,
567
+ "movement_description": "Will fall from held position to the ground",
568
+ "object_size_grids": {{"rows": 3, "cols": 2}},
569
+ "trajectory_path": [
570
+ {{"frame": 0, "grid_row": 3, "grid_col": 6}},
571
+ {{"frame": 20, "grid_row": 6, "grid_col": 6}},
572
+ {{"frame": 40, "grid_row": 9, "grid_col": 7}}
573
+ ]
574
+ }},
575
+ {{
576
+ "noun": "shadow",
577
+ "category": "visual_artifact",
578
+ "why": "cast by person on floor",
579
+ "will_move": false,
580
+ "first_appears_frame": 0,
581
+ "movement_description": "Disappears entirely as visual artifact"
582
+ }}
583
+ ],
584
+ "scene_description": "An acoustic guitar falling to the ground in an empty room. Natural window lighting.",
585
+ "confidence": 0.85
586
+ }}
587
+
588
+ CRITICAL REMINDERS:
589
+ β€’ Watch ENTIRE video before answering
590
+ β€’ SUPPORTING vs ACTING ON:
591
+ - Primary HOLDS UP object against gravity β†’ will_move=TRUE (provide grid trajectory)
592
+ - Primary ACTS ON object (crushing, opening) but object on stable surface β†’ will_move=FALSE
593
+ - Object RESTS ON stable surface (table, floor) β†’ will_move=FALSE
594
+ β€’ For visual artifacts (shadow, reflection): will_move=false (no trajectory needed)
595
+ β€’ For held objects (guitar, cup): will_move=true (MUST provide object_size_grids + trajectory_path)
596
+ β€’ For objects on surfaces being acted on (can being crushed, can being opened): will_move=false
597
+ β€’ Grid trajectory: Add +1 cell padding to object size (over-mask is better than under-mask)
598
+ β€’ Grid trajectory: Use the yellow grid overlay to determine (row, col) positions
599
+ β€’ Be conservative - when in doubt, DON'T include
600
+ β€’ Output MUST be valid JSON only
601
+
602
+ GRID INFO: {grid_rows} rows Γ— {grid_cols} columns
603
+ EDIT INSTRUCTION: {instruction}
604
+ """.strip()
+
+
+ def call_vlm_with_images_and_video(client, model: str, image_data_urls: list,
+                                    video_data_url: str, prompt: str) -> str:
+     """Call VLM with multiple images and video"""
+     content = []
+
+     # Add all images first
+     for img_url in image_data_urls:
+         content.append({"type": "image_url", "image_url": {"url": img_url}})
+
+     # Add video
+     content.append({"type": "image_url", "image_url": {"url": video_data_url}})
+
+     # Add prompt
+     content.append({"type": "text", "text": prompt})
+
+     resp = client.chat.completions.create(
+         model=model,
+         messages=[
+             {
+                 "role": "system",
+                 "content": "You are an expert video analyst with deep understanding of physics and object interactions. Always output valid JSON only."
+             },
+             {
+                 "role": "user",
+                 "content": content
+             },
+         ],
+     )
+     return resp.choices[0].message.content
+
+
+ def parse_vlm_response(raw: str) -> Dict:
+     """Parse VLM JSON response"""
+     # Strip markdown code blocks
+     cleaned = raw.strip()
+     if cleaned.startswith("```"):
+         lines = cleaned.split('\n')
+         if lines[0].startswith("```"):
+             lines = lines[1:]
+         if lines and lines[-1].strip() == "```":
+             lines = lines[:-1]
+         cleaned = '\n'.join(lines)
+
+     try:
+         parsed = json.loads(cleaned)
+     except json.JSONDecodeError:
+         # Try to find JSON in response
+         start = cleaned.find("{")
+         end = cleaned.rfind("}")
+         if start != -1 and end != -1 and end > start:
+             parsed = json.loads(cleaned[start:end+1])
+         else:
+             raise ValueError("Failed to parse VLM response as JSON")
+
+     # Validate structure
+     result = {
+         "edit_instruction": parsed.get("edit_instruction", ""),
+         "integral_belongings": [],
+         "affected_objects": [],
+         "scene_description": parsed.get("scene_description", ""),
+         "confidence": float(parsed.get("confidence", 0.0))
+     }
+
+     # Parse integral belongings
+     for item in parsed.get("integral_belongings", [])[:3]:
+         obj = {
+             "noun": str(item.get("noun", "")).strip().lower(),
+             "why": str(item.get("why", "")).strip()[:200]
+         }
+         if obj["noun"]:
+             result["integral_belongings"].append(obj)
+
+     # Parse affected objects
+     for item in parsed.get("affected_objects", [])[:5]:
+         obj = {
+             "noun": str(item.get("noun", "")).strip().lower(),
+             "category": str(item.get("category", "physical")).strip().lower(),
+             "why": str(item.get("why", "")).strip()[:200],
+             "will_move": bool(item.get("will_move", False)),
+             "first_appears_frame": int(item.get("first_appears_frame", 0)),
+             "movement_description": str(item.get("movement_description", "")).strip()[:300]
+         }
+
+         # Parse Case 4: currently moving but should have stayed
+         if "currently_moving" in item:
+             obj["currently_moving"] = bool(item.get("currently_moving", False))
+         if "should_have_stayed" in item:
+             obj["should_have_stayed"] = bool(item.get("should_have_stayed", False))
+         if "original_position_grid" in item:
+             orig_grid = item.get("original_position_grid", {})
+             obj["original_position_grid"] = {
+                 "row": int(orig_grid.get("row", 0)),
+                 "col": int(orig_grid.get("col", 0))
+             }
+
+         # Parse grid localizations for visual artifacts
+         if "grid_localizations" in item:
+             grid_locs = []
+             for loc in item.get("grid_localizations", []):
+                 frame_loc = {
+                     "frame": int(loc.get("frame", 0)),
+                     "grid_regions": []
+                 }
+                 for region in loc.get("grid_regions", []):
+                     frame_loc["grid_regions"].append({
+                         "row": int(region.get("row", 0)),
+                         "col": int(region.get("col", 0))
+                     })
+                 if frame_loc["grid_regions"]:  # Only add if has regions
+                     grid_locs.append(frame_loc)
+             if grid_locs:
+                 obj["grid_localizations"] = grid_locs
+
+         # Parse grid trajectory if will_move=true
+         if obj["will_move"] and "object_size_grids" in item and "trajectory_path" in item:
+             size_grids = item.get("object_size_grids", {})
+             obj["object_size_grids"] = {
+                 "rows": int(size_grids.get("rows", 2)),
+                 "cols": int(size_grids.get("cols", 2))
+             }
+
+             trajectory = []
+             for point in item.get("trajectory_path", []):
+                 trajectory.append({
+                     "frame": int(point.get("frame", 0)),
+                     "grid_row": int(point.get("grid_row", 0)),
+                     "grid_col": int(point.get("grid_col", 0))
+                 })
+
+             if trajectory:  # Only add if we have valid trajectory points
+                 obj["trajectory_path"] = trajectory
+
+         if obj["noun"]:
+             result["affected_objects"].append(obj)
+
+     return result
+
+
+ def process_video(video_info: Dict, client, model: str):
+     """Process a single video with VLM analysis"""
+     video_path = video_info.get("video_path", "")
+     instruction = video_info.get("instruction", "")
+     output_dir = video_info.get("output_dir", "")
+
+     if not output_dir:
+         print(f"  ⚠️ No output_dir specified, skipping")
+         return None
+
+     output_dir = Path(output_dir)
+     if not output_dir.exists():
+         print(f"  ⚠️ Output directory not found: {output_dir}")
+         print(f"     Run Stage 1 first to create black masks")
+         return None
+
+     # Check required files from Stage 1
+     black_mask_path = output_dir / "black_mask.mp4"
+     first_frame_path = output_dir / "first_frame.jpg"
+     input_video_path = output_dir / "input_video.mp4"
+     segmentation_info_path = output_dir / "segmentation_info.json"
+
+     if not black_mask_path.exists():
+         print(f"  ⚠️ black_mask.mp4 not found in {output_dir}")
+         print(f"     Run Stage 1 first")
+         return None
+
+     if not first_frame_path.exists():
+         print(f"  ⚠️ first_frame.jpg not found in {output_dir}")
+         return None
+
+     if not input_video_path.exists():
+         # Try original video path
+         if Path(video_path).exists():
+             input_video_path = Path(video_path)
+         else:
+             print(f"  ⚠️ Video not found: {video_path}")
+             return None
+
+     # Read segmentation metadata to get correct frame index
+     frame_idx = 0  # Default
+     if segmentation_info_path.exists():
+         try:
+             with open(segmentation_info_path, 'r') as f:
+                 seg_info = json.load(f)
+             frame_idx = seg_info.get("first_appears_frame", 0)
+             print(f"  Using frame {frame_idx} from segmentation metadata")
+         except Exception as e:
+             print(f"  Warning: Could not read segmentation_info.json: {e}")
+             print(f"  Using frame 0 as fallback")
+
+     # Get min_grid for grid calculation
+     min_grid = video_info.get('min_grid', 8)
+     use_multi_frame_grids = video_info.get('multi_frame_grids', True)  # Default: use multi-frame
+     max_video_size_mb = video_info.get('max_video_size_for_multiframe', 25)  # Default: 25MB limit
+
+     # Check video size and auto-disable multi-frame for large videos
+     if use_multi_frame_grids:
+         video_size_mb = input_video_path.stat().st_size / (1024 * 1024)
+         if video_size_mb > max_video_size_mb:
+             print(f"  ⚠️ Video size ({video_size_mb:.1f} MB) exceeds {max_video_size_mb} MB")
+             print(f"     Auto-disabling multi-frame grids to avoid API errors")
+             use_multi_frame_grids = False
+
+     print(f"  Creating frame overlays and grids...")
+     overlay_path = output_dir / "first_frame_with_mask.jpg"
+     gridded_path = output_dir / "first_frame_with_grid.jpg"
+
+     # Create regular overlay (for backwards compatibility)
+     create_first_frame_with_mask_overlay(
+         str(first_frame_path),
+         str(black_mask_path),
+         str(overlay_path),
+         frame_idx=frame_idx
+     )
+
+     image_data_urls = []
+
+     if use_multi_frame_grids:
+         # Create multi-frame grid samples for objects appearing mid-video
+         print(f"  Creating multi-frame grid samples (0%, 25%, 50%, 75%, 100%)...")
+         sample_paths, grid_rows, grid_cols = create_multi_frame_grid_samples(
+             str(input_video_path),
+             output_dir,
+             min_grid=min_grid
+         )
+
+         # Encode all grid samples
+         for sample_path in sample_paths:
+             image_data_urls.append(image_to_data_url(str(sample_path)))
+
+         # Also add the first frame with mask overlay
+         _, _, _ = create_gridded_frame_overlay(
+             str(first_frame_path),
+             str(black_mask_path),
+             str(gridded_path),
+             min_grid=min_grid
+         )
+         image_data_urls.append(image_to_data_url(str(gridded_path)))
+
+         print(f"  Grid: {grid_rows}x{grid_cols}, {len(sample_paths)} sample frames + masked frame")
+
+     else:
+         # Single gridded first frame (old approach)
+         _, grid_rows, grid_cols = create_gridded_frame_overlay(
+             str(first_frame_path),
+             str(black_mask_path),
+             str(gridded_path),
+             min_grid=min_grid
+         )
+         image_data_urls.append(image_to_data_url(str(gridded_path)))
+         print(f"  Grid: {grid_rows}x{grid_cols} (single frame)")
+
+     print(f"  Encoding video for VLM...")
+
+     # Check video size
+     video_size_mb = input_video_path.stat().st_size / (1024 * 1024)
+     print(f"  Video size: {video_size_mb:.1f} MB")
+
+     if video_size_mb > 20:
+         print(f"  ⚠️ Warning: Large video may cause API errors")
+         if use_multi_frame_grids:
+             print(f"     Consider setting multi_frame_grids=false for large videos")
+
+     video_data_url = video_to_data_url(str(input_video_path))
+
+     print(f"  Calling {model}...")
+     prompt = make_vlm_analysis_prompt(instruction, grid_rows, grid_cols,
+                                       has_multi_frame_grids=use_multi_frame_grids)
+
+     try:
+         try:
+             raw_response = call_vlm_with_images_and_video(
+                 client, model, image_data_urls, video_data_url, prompt
+             )
+         except Exception as e:
+             # If multi-frame fails (likely payload size issue), fall back to single frame
+             if use_multi_frame_grids and "400" in str(e):
+                 print(f"  ⚠️ Multi-frame request failed (payload too large?)")
+                 print(f"     Falling back to single-frame grid mode...")
+
+                 # Retry with just the gridded first frame
+                 image_data_urls = [image_to_data_url(str(gridded_path))]
+                 prompt = make_vlm_analysis_prompt(instruction, grid_rows, grid_cols,
+                                                   has_multi_frame_grids=False)
+
+                 try:
+                     raw_response = call_vlm_with_images_and_video(
+                         client, model, image_data_urls, video_data_url, prompt
+                     )
+                     print(f"  ✓ Single-frame fallback succeeded")
+                 except Exception as e2:
+                     raise e2  # Re-raise if fallback also fails
+             else:
+                 raise  # Re-raise if not a 400 or not multi-frame mode
+
+         # Parse and save results (runs whether first call succeeded or fallback succeeded)
+         print(f"  Parsing VLM response...")
+         analysis = parse_vlm_response(raw_response)
+
+         # Save results
+         output_path = output_dir / "vlm_analysis.json"
+         with open(output_path, 'w') as f:
+             json.dump(analysis, f, indent=2)
+
+         print(f"  ✓ Saved VLM analysis: {output_path.name}")
+
+         # Print summary
+         print(f"\n  Summary:")
+         print(f"  - Integral belongings: {len(analysis['integral_belongings'])}")
+         for obj in analysis['integral_belongings']:
+             print(f"    • {obj['noun']}: {obj['why']}")
+
+         print(f"  - Affected objects: {len(analysis['affected_objects'])}")
+         for obj in analysis['affected_objects']:
+             move_str = "WILL MOVE" if obj['will_move'] else "STAYS/DISAPPEARS"
+             traj_str = ""
+             if obj.get('will_move') and 'trajectory_path' in obj:
+                 num_points = len(obj['trajectory_path'])
+                 size = obj.get('object_size_grids', {})
+                 traj_str = f" (trajectory: {num_points} keyframes, size: {size.get('rows')}×{size.get('cols')} grids)"
+             print(f"    • {obj['noun']}: {move_str}{traj_str}")
+
+         return analysis
+
+     except Exception as e:
+         print(f"  ❌ VLM analysis failed: {e}")
+         import traceback
+         traceback.print_exc()
+         return None
+
+
+ def process_config(config_path: str, model: str = DEFAULT_MODEL):
+     """Process all videos in config"""
+     config_path = Path(config_path)
+
+     # Load config
+     with open(config_path, 'r') as f:
+         config_data = json.load(f)
+
+     # Handle both formats
+     if isinstance(config_data, list):
+         videos = config_data
+     elif isinstance(config_data, dict) and "videos" in config_data:
+         videos = config_data["videos"]
+     else:
+         raise ValueError("Config must be a list or have 'videos' key")
+
+     print(f"\n{'='*70}")
+     print(f"Stage 2: VLM Analysis - Identify Affected Objects")
+     print(f"{'='*70}")
+     print(f"Config: {config_path.name}")
+     print(f"Videos: {len(videos)}")
+     print(f"Model: {model}")
+     print(f"{'='*70}\n")
+
+     # Initialize VLM client
+     api_key = os.environ.get("GEMINI_API_KEY")
+     if not api_key:
+         raise RuntimeError("GEMINI_API_KEY environment variable not set")
+     client = openai.OpenAI(
+         api_key=api_key,
+         base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+     )
+
+     # Process each video
+     results = []
+     for i, video_info in enumerate(videos):
+         video_path = video_info.get("video_path", "")
+         instruction = video_info.get("instruction", "")
+
+         print(f"\n{'─'*70}")
+         print(f"Video {i+1}/{len(videos)}: {Path(video_path).name}")
+         print(f"{'─'*70}")
+         print(f"Instruction: {instruction}")
+
+         try:
+             analysis = process_video(video_info, client, model)
+             results.append({
+                 "video": video_path,
+                 "success": analysis is not None,
+                 "analysis": analysis
+             })
+
+             if analysis:
+                 print(f"\n✅ Video {i+1} complete!")
+             else:
+                 print(f"\n⚠️ Video {i+1} skipped")
+
+         except Exception as e:
+             print(f"\n❌ Error processing video {i+1}: {e}")
+             results.append({
+                 "video": video_path,
+                 "success": False,
+                 "error": str(e)
+             })
+             continue
+
+     # Summary
+     print(f"\n{'='*70}")
+     print(f"Stage 2 Complete!")
+     print(f"{'='*70}")
+     successful = sum(1 for r in results if r["success"])
+     print(f"Successful: {successful}/{len(videos)}")
+     print(f"{'='*70}\n")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Stage 2: VLM Analysis")
+     parser.add_argument("--config", required=True, help="Config JSON from Stage 1")
+     parser.add_argument("--model", default=DEFAULT_MODEL, help="VLM model name")
+     args = parser.parse_args()
+
+     process_config(args.config, args.model)
+
+
+ if __name__ == "__main__":
+     main()
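The fence-stripping path in `parse_vlm_response` is worth exercising on its own, since Gemini frequently wraps its JSON answer in a Markdown ```json block. A self-contained sketch of that same stripping logic (the `strip_fences` name is illustrative, not part of the script):

```python
import json

def strip_fences(raw: str) -> str:
    # Same behavior as the fence handling in parse_vlm_response:
    # drop a leading ``` / ```json line and a trailing ``` line,
    # keep everything in between untouched.
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        lines = cleaned.split('\n')
        if lines[0].startswith("```"):
            lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        cleaned = '\n'.join(lines)
    return cleaned

raw = '```json\n{"confidence": 0.85, "affected_objects": []}\n```'
parsed = json.loads(strip_fences(raw))
print(parsed["confidence"])  # 0.85
```

If the model emits prose around the JSON instead of a fence, the script's fallback (`find("{")` / `rfind("}")`) covers that case; only a response with no brace-delimited object at all raises.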
VLM-MASK-REASONER/stage2_vlm_analysis_cf.py ADDED
@@ -0,0 +1,1024 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Stage 2: VLM Analysis - Identify Affected Objects & Physics
4
+ (Cloudflare AI Gateway variant)
5
+
6
+ Identical to stage2_vlm_analysis.py but routes through the internal CF AI Gateway
7
+ instead of calling the Gemini API directly. Video is sent as sampled frames rather
8
+ than a raw video data URL (not supported by the OpenAI-compat endpoint).
9
+
10
+ Required environment variables:
11
+ CF_PROJECT_ID - Cloudflare AI Gateway project ID
12
+ CF_USER_ID - Cloudflare AI Gateway user ID
13
+ MODEL_ID - Model identifier to use (e.g. "gemini-3-pro-preview")
14
+
15
+ Usage:
16
+ python stage2_vlm_analysis_cf.py --config my_config_points.json
17
+ """
18
+
19
+ import os
20
+ import sys
21
+ import json
22
+ import argparse
23
+ import cv2
24
+ import numpy as np
25
+ import base64
26
+ from pathlib import Path
27
+ from typing import Dict, List
28
+ from PIL import Image, ImageDraw
29
+
30
+ import openai
31
+
32
+ DEFAULT_MODEL = "gemini-3-pro-preview"
33
+
34
+
35
+ def image_to_data_url(image_path: str) -> str:
36
+ """Convert image file to base64 data URL"""
37
+ with open(image_path, 'rb') as f:
38
+ img_data = base64.b64encode(f.read()).decode('utf-8')
39
+
40
+ # Detect format
41
+ ext = Path(image_path).suffix.lower()
42
+ if ext == '.png':
43
+ mime = 'image/png'
44
+ elif ext in ['.jpg', '.jpeg']:
45
+ mime = 'image/jpeg'
46
+ else:
47
+ mime = 'image/jpeg'
48
+
49
+ return f"data:{mime};base64,{img_data}"
50
+
51
+
52
+ def video_to_data_url(video_path: str) -> str:
53
+ """Convert video file to base64 data URL"""
54
+ with open(video_path, 'rb') as f:
55
+ video_data = base64.b64encode(f.read()).decode('utf-8')
56
+ return f"data:video/mp4;base64,{video_data}"
57
+
58
+
59
+ def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> tuple:
60
+ """Calculate grid dimensions matching stage3a logic"""
61
+ aspect_ratio = width / height
62
+ if width >= height:
63
+ grid_rows = min_grid
64
+ grid_cols = max(min_grid, round(min_grid * aspect_ratio))
65
+ else:
66
+ grid_cols = min_grid
67
+ grid_rows = max(min_grid, round(min_grid / aspect_ratio))
68
+ return grid_rows, grid_cols
69
+
70
+
71
+ def create_first_frame_with_mask_overlay(first_frame_path: str, black_mask_path: str,
72
+ output_path: str, frame_idx: int = 0) -> str:
73
+ """Create visualization of first frame with red overlay on primary object
74
+
75
+ Args:
76
+ first_frame_path: Path to first_frame.jpg
77
+ black_mask_path: Path to black_mask.mp4
78
+ output_path: Where to save overlay
79
+ frame_idx: Which frame to extract from black_mask.mp4 (default: 0)
80
+ """
81
+ # Load first frame
82
+ frame = cv2.imread(first_frame_path)
83
+ if frame is None:
84
+ raise ValueError(f"Failed to load first frame: {first_frame_path}")
85
+
86
+ # Load black mask video and get the specified frame
87
+ cap = cv2.VideoCapture(black_mask_path)
88
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
89
+ ret, mask_frame = cap.read()
90
+ cap.release()
91
+
92
+ if not ret:
93
+ raise ValueError(f"Failed to load black mask frame {frame_idx}: {black_mask_path}")
94
+
95
+ # Convert mask to binary (0 = object, 255 = background)
96
+ if len(mask_frame.shape) == 3:
97
+ mask_frame = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)
98
+
99
+ object_mask = (mask_frame == 0)
100
+
101
+ # Create red overlay on object
102
+ overlay = frame.copy()
103
+ overlay[object_mask] = [0, 0, 255] # Red in BGR
104
+
105
+ # Blend: 60% original + 40% red overlay
106
+ result = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)
107
+
108
+ # Save
109
+ cv2.imwrite(output_path, result)
110
+ return output_path
111
+
112
+
113
+ def create_gridded_frame_overlay(first_frame_path: str, black_mask_path: str,
+ output_path: str, min_grid: int = 8) -> tuple:
+ """Create first frame with BOTH red mask overlay AND grid lines
+
+ Returns: (output_path, grid_rows, grid_cols)
+ """
+ # Load first frame
+ frame = cv2.imread(first_frame_path)
+ if frame is None:
+ raise ValueError(f"Failed to load first frame: {first_frame_path}")
+
+ h, w = frame.shape[:2]
+
+ # Load black mask
+ cap = cv2.VideoCapture(black_mask_path)
+ ret, mask_frame = cap.read()
+ cap.release()
+
+ if not ret:
+ raise ValueError(f"Failed to load black mask: {black_mask_path}")
+
+ if len(mask_frame.shape) == 3:
+ mask_frame = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)
+
+ object_mask = (mask_frame == 0)
+
+ # Create red overlay
+ overlay = frame.copy()
+ overlay[object_mask] = [0, 0, 255]
+ result = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)
+
+ # Calculate grid
+ grid_rows, grid_cols = calculate_square_grid(w, h, min_grid)
+
+ # Draw grid lines
+ cell_width = w / grid_cols
+ cell_height = h / grid_rows
+
+ # Vertical lines
+ for col in range(1, grid_cols):
+ x = int(col * cell_width)
+ cv2.line(result, (x, 0), (x, h), (0, 255, 255), 1) # Yellow in BGR is (0, 255, 255)
+
+ # Horizontal lines
+ for row in range(1, grid_rows):
+ y = int(row * cell_height)
+ cv2.line(result, (0, y), (w, y), (0, 255, 255), 1)
+
+ # Add grid labels
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ font_scale = 0.3
+ thickness = 1
+
+ # Label columns at top
+ for col in range(grid_cols):
+ x = int((col + 0.5) * cell_width)
+ cv2.putText(result, str(col), (x-5, 15), font, font_scale, (0, 255, 255), thickness)
+
+ # Label rows on left
+ for row in range(grid_rows):
+ y = int((row + 0.5) * cell_height)
+ cv2.putText(result, str(row), (5, y+5), font, font_scale, (0, 255, 255), thickness)
+
+ cv2.imwrite(output_path, result)
+ return output_path, grid_rows, grid_cols
+
+
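`calculate_square_grid` (defined earlier in this file) fixes the cell counts, and the drawing code above maps grid coordinates to pixels with `cell_width = w / grid_cols` and `cell_height = h / grid_rows`. Later stages need the inverse mapping, from a VLM's (row, col) answer back to a pixel rectangle; a sketch under assumed numbers (the function name and the 1280x720 frame with a 9x16 grid are illustrative):

```python
# Illustrative inverse of the grid drawing above: pixel bounds of one cell.
def cell_to_rect(w, h, grid_rows, grid_cols, row, col):
    """Return (x0, y0, x1, y1) pixel bounds of grid cell (row, col)."""
    cw, ch = w / grid_cols, h / grid_rows
    return (int(col * cw), int(row * ch), int((col + 1) * cw), int((row + 1) * ch))

print(cell_to_rect(1280, 720, 9, 16, 0, 0))   # (0, 0, 80, 80) -- top-left cell
print(cell_to_rect(1280, 720, 9, 16, 8, 15))  # (1200, 640, 1280, 720) -- bottom-right cell
```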
+ def create_multi_frame_grid_samples(video_path: str, output_dir: Path,
+ min_grid: int = 8,
+ sample_points: list = [0.0, 0.11, 0.22, 0.33, 0.44, 0.56, 0.67, 0.78, 0.89, 1.0]) -> tuple:
+ """
+ Create gridded frame samples at multiple time points in video.
+ Helps VLM see objects that appear mid-video with grid reference.
+
+ Args:
+ video_path: Path to video
+ output_dir: Where to save samples
+ min_grid: Minimum grid size
+ sample_points: List of normalized positions [0.0-1.0] to sample
+
+ Returns: (sample_paths, grid_rows, grid_cols)
+ """
+ cap = cv2.VideoCapture(str(video_path))
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+ # Calculate grid (same for all frames)
+ grid_rows, grid_cols = calculate_square_grid(w, h, min_grid)
+ cell_width = w / grid_cols
+ cell_height = h / grid_rows
+
+ sample_paths = []
+
+ for i, t in enumerate(sample_points):
+ frame_idx = int(t * (total_frames - 1))
+ frame_idx = max(0, min(frame_idx, total_frames - 1))
+
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
+ ret, frame = cap.read()
+ if not ret:
+ continue
+
+ # Draw grid
+ result = frame.copy()
+
+ # Vertical lines
+ for col in range(1, grid_cols):
+ x = int(col * cell_width)
+ cv2.line(result, (x, 0), (x, h), (0, 255, 255), 2) # Yellow in BGR
+
+ # Horizontal lines
+ for row in range(1, grid_rows):
+ y = int(row * cell_height)
+ cv2.line(result, (0, y), (w, y), (0, 255, 255), 2)
+
+ # Add grid labels
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ font_scale = 0.4
+ thickness = 1
+
+ # Label columns
+ for col in range(grid_cols):
+ x = int((col + 0.5) * cell_width)
+ cv2.putText(result, str(col), (x-8, 20), font, font_scale, (0, 255, 255), thickness)
+
+ # Label rows
+ for row in range(grid_rows):
+ y = int((row + 0.5) * cell_height)
+ cv2.putText(result, str(row), (10, y+8), font, font_scale, (0, 255, 255), thickness)
+
+ # Add frame number and percentage
+ label = f"Frame {frame_idx} ({int(t*100)}%)"
+ cv2.putText(result, label, (10, h-10), font, 0.5, (0, 255, 255), 2)
+
+ # Save
+ output_path = output_dir / f"grid_sample_frame_{frame_idx:04d}.jpg"
+ cv2.imwrite(str(output_path), result)
+ sample_paths.append(output_path)
+
+ cap.release()
+ return sample_paths, grid_rows, grid_cols
+
+
+ def make_vlm_analysis_prompt(instruction: str, grid_rows: int, grid_cols: int,
+ has_multi_frame_grids: bool = False) -> str:
+ """Create VLM prompt for analyzing video with primary mask"""
+
+ grid_context = ""
+ if has_multi_frame_grids:
+ grid_context = f"""
+ 1. **Multiple Grid Reference Frames**: Sampled frames at 0%, 11%, 22%, 33%, 44%, 56%, 67%, 78%, 89%, 100% of video
+ - Each frame shows YELLOW GRID with {grid_rows} rows × {grid_cols} columns
+ - Grid cells labeled (row, col) starting from (0, 0) at top-left
+ - Frame number shown at bottom
+ - Use these to locate objects that appear MID-VIDEO and track object positions across time
+ 2. **First Frame with RED mask**: Shows what will be REMOVED (primary object)
+ 3. **Full Video**: Complete action and interactions"""
+ else:
+ grid_context = f"""
+ 1. **First Frame with Grid**: PRIMARY OBJECT highlighted in RED + GRID OVERLAY
+ - The red overlay shows what will be REMOVED (already masked)
+ - Yellow grid with {grid_rows} rows × {grid_cols} columns
+ - Grid cells are labeled (row, col) starting from (0, 0) at top-left
+ 2. **Full Video**: Complete scene and action"""
+
+ return f"""
+ You are an expert video analyst specializing in physics and object interactions.
+
+ ═══════════════════════════════════════════════════════════════════
+ CONTEXT
+ ═══════════════════════════════════════════════════════════════════
+
+ You will see MULTIPLE inputs:
+ {grid_context}
+
+ Edit instruction: "{instruction}"
+
+ IMPORTANT: Some objects may NOT appear in first frame. They may enter later.
+ Watch the ENTIRE video and note when each object first appears.
+
+ ═══════════════════════════════════════════════════════════════════
+ YOUR TASK
+ ═══════════════════════════════════════════════════════════════════
+
+ Analyze what would happen if the PRIMARY OBJECT (shown in red) is removed.
+ Watch the ENTIRE video to see all interactions and movements.
+
+ STEP 1: IDENTIFY INTEGRAL BELONGINGS (0-3 items)
+ ─────────────────────────────────────────────────
+ Items that should be ADDED to the primary removal mask (removed WITH primary object):
+
+ ✓ INCLUDE:
+ • Distinct wearable items: hat, backpack, jacket (if separate/visible)
+ • Vehicles/equipment being ridden: bike, skateboard, surfboard, scooter
+ • Large carried items that are part of the subject
+
+ ✗ DO NOT INCLUDE:
+ • Generic clothing (shirt, pants, shoes) - already captured with person
+ • Held items that could be set down: guitar, cup, phone, tools
+ • Objects they're interacting with but not wearing/riding
+
+ Examples:
+ • Person on bike → integral: "bike"
+ • Person with guitar → integral: none (guitar is affected, not integral)
+ • Surfer → integral: "surfboard"
+ • Boxer → integral: "boxing gloves" (wearable equipment)
+
+ STEP 2: IDENTIFY AFFECTED OBJECTS (0-5 objects)
+ ────────────────────────────────────────────────
+ Objects/effects that are SEPARATE from primary but affected by its removal.
+
+ CRITICAL: Do NOT include integral belongings from Step 1.
+
+ Two categories:
+
+ A) VISUAL ARTIFACTS (disappear when primary removed):
+ • shadow, reflection, wake, ripples, splash, footprints
+ • These vanish completely - no physics needed
+
+ **CRITICAL FOR VISUAL ARTIFACTS:**
+ You MUST provide GRID LOCALIZATIONS across the reference frames.
+ Keyword segmentation fails to isolate specific shadows/reflections.
+
+ For each visual artifact:
+ - Look at each grid reference frame you were shown
+ - Identify which grid cells the artifact occupies in EACH frame
+ - List all grid cells (row, col) that contain any part of it
+ - Be thorough - include ALL touched cells (over-mask is better than under-mask)
+
+ Format:
+ {{
+ "noun": "shadow",
+ "category": "visual_artifact",
+ "grid_localizations": [
+ {{"frame": 0, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, ...]}},
+ {{"frame": 5, "grid_regions": [{{"row": 6, "col": 4}}, ...]}},
+ // ... for each reference frame shown
+ ]
+ }}
+
+ B) PHYSICAL OBJECTS (may move, fall, or stay):
+
+ CRITICAL - Understand the difference:
+
+ **SUPPORTING vs ACTING ON:**
+ • SUPPORTING = holding UP against gravity → object WILL FALL when removed
+ Examples: holding guitar, carrying cup
+ → will_move: TRUE
+
+ • ACTING ON = touching/manipulating but object rests on stable surface → object STAYS
+ Examples: hand crushing can (can on table), hand opening can (can on counter),
+ hand pushing object (object on floor)
+ → will_move: FALSE
+
+ **Key Questions:**
+ 1. Is the primary object HOLDING THIS UP against gravity?
+ - YES → will_move: true, needs_trajectory: true
+ - NO → Check next question
+
+ 2. Is this object RESTING ON a stable surface (table, floor, counter)?
+ - YES → will_move: false (stays on surface when primary removed)
+ - NO → will_move: true
+
+ 3. Is the primary object DOING an action TO this object?
+ - Opening can, crushing can, pushing button, turning knob
+ - When primary removed → action STOPS, object stays in current state
+ - will_move: false
+
+ **SPECIAL CASE - Object Currently Moving But Should Have Stayed:**
+ If primary object CAUSES another object to move (hitting, kicking, throwing):
+ - The object is currently moving in the video
+ - But WITHOUT primary, it would have stayed at its original position
+ - You MUST provide:
+ • "currently_moving": true
+ • "should_have_stayed": true
+ • "original_position_grid": {{"row": R, "col": C}} - Where it started
+
+ Examples:
+ - Golf club hits ball → Ball at tee, then flies (mark original tee position)
+ - Person kicks soccer ball → Ball on ground, then rolls (mark original ground position)
+ - Hand throws object → Object held, then flies (mark original held position)
+
+ Format:
+ {{
+ "noun": "golf ball",
+ "category": "physical",
+ "currently_moving": true,
+ "should_have_stayed": true,
+ "original_position_grid": {{"row": 6, "col": 7}},
+ "why": "ball was stationary until club hit it"
+ }}
+
+ For each physical object, determine:
+ - **will_move**: true ONLY if object will fall/move when support removed
+ - **first_appears_frame**: frame number object first appears (0 if from start)
+ - **why**: Brief explanation of relationship to primary object
+
+ IF will_move=TRUE, also provide GRID-BASED TRAJECTORY:
+ - **object_size_grids**: {{"rows": R, "cols": C}} - How many grid cells object occupies
+ IMPORTANT: Add 1 extra cell padding for safety (better to over-mask than under-mask)
+ Example: Object looks 2×1 → report as 3×2
+
+ - **trajectory_path**: List of keyframe positions as grid coordinates
+ Format: [{{"frame": N, "grid_row": R, "grid_col": C}}, ...]
+ - IMPORTANT: First keyframe should be at first_appears_frame (not frame 0 if object appears later!)
+ - Provide 3-5 keyframes spanning from first appearance to end
+ - (grid_row, grid_col) is the CENTER position of object at that frame
+ - Use the yellow grid reference frames to determine positions
+ - For objects appearing mid-video: use the grid samples to locate them
+ - Example: Object appears at frame 15, falls to bottom
+ [{{"frame": 15, "grid_row": 3, "grid_col": 5}}, ← First appearance
+ {{"frame": 25, "grid_row": 6, "grid_col": 5}}, ← Mid-fall
+ {{"frame": 35, "grid_row": 9, "grid_col": 5}}] ← On ground
+
+ ✓ Objects held/carried at ANY point in video
+ ✓ Objects the primary supports or interacts with
+ ✓ Visual effects visible at any time
+
+ ✗ Background objects never touched
+ ✗ Other people/animals with no contact
+ ✗ Integral belongings (already in Step 1)
+
+ STEP 3: SCENE DESCRIPTION
+ ──────────────────────────
+ Describe scene WITHOUT the primary object (1-2 sentences).
+ Focus on what remains and any dynamic changes (falling objects, etc).
+
+ ═══════════════════════════════════════════════════════════════════
+ OUTPUT FORMAT (STRICT JSON ONLY)
+ ═══════════════════════════════════════════════════════════════════
+
+ EXAMPLES TO LEARN FROM:
+
+ Example 1: Person holding guitar
+ {{
+ "affected_objects": [
+ {{
+ "noun": "guitar",
+ "will_move": true,
+ "why": "person is SUPPORTING guitar against gravity by holding it",
+ "object_size_grids": {{"rows": 3, "cols": 2}},
+ "trajectory_path": [
+ {{"frame": 0, "grid_row": 4, "grid_col": 5}},
+ {{"frame": 15, "grid_row": 6, "grid_col": 5}},
+ {{"frame": 30, "grid_row": 8, "grid_col": 6}}
+ ]
+ }}
+ ]
+ }}
+
+ Example 2: Hand crushing can on table
+ {{
+ "affected_objects": [
+ {{
+ "noun": "can",
+ "will_move": false,
+ "why": "can RESTS ON TABLE - hand is just acting on it. When hand removed, can stays on table (uncrushed)"
+ }}
+ ]
+ }}
+
+ Example 3: Hands opening can on counter
+ {{
+ "affected_objects": [
+ {{
+ "noun": "can",
+ "will_move": false,
+ "why": "can RESTS ON COUNTER - hands are doing opening action. When hands removed, can stays closed on counter"
+ }}
+ ]
+ }}
+
+ Example 4: Person sitting on chair
+ {{
+ "affected_objects": [
+ {{
+ "noun": "chair",
+ "will_move": false,
+ "why": "chair RESTS ON FLOOR - person sitting on it doesn't make it fall. Chair stays on floor when person removed"
+ }}
+ ]
+ }}
+
+ Example 5: Person throws ball (ball appears at frame 12)
+ {{
+ "affected_objects": [
+ {{
+ "noun": "ball",
+ "category": "physical",
+ "will_move": true,
+ "first_appears_frame": 12,
+ "why": "ball is SUPPORTED by person's hand, then thrown",
+ "object_size_grids": {{"rows": 2, "cols": 2}},
+ "trajectory_path": [
+ {{"frame": 12, "grid_row": 4, "grid_col": 3}},
+ {{"frame": 20, "grid_row": 2, "grid_col": 6}},
+ {{"frame": 28, "grid_row": 5, "grid_col": 8}}
+ ]
+ }}
+ ]
+ }}
+
+ Example 6: Person with shadow (shadow needs grid localization)
+ {{
+ "affected_objects": [
+ {{
+ "noun": "shadow",
+ "category": "visual_artifact",
+ "why": "cast by person on the floor",
+ "will_move": false,
+ "first_appears_frame": 0,
+ "movement_description": "Disappears entirely as visual artifact",
+ "grid_localizations": [
+ {{"frame": 0, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, {{"row": 7, "col": 3}}, {{"row": 7, "col": 4}}]}},
+ {{"frame": 12, "grid_regions": [{{"row": 6, "col": 4}}, {{"row": 6, "col": 5}}, {{"row": 7, "col": 4}}]}},
+ {{"frame": 23, "grid_regions": [{{"row": 5, "col": 4}}, {{"row": 6, "col": 4}}, {{"row": 6, "col": 5}}]}},
+ {{"frame": 35, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, {{"row": 7, "col": 3}}]}},
+ {{"frame": 47, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 7, "col": 3}}, {{"row": 7, "col": 4}}]}}
+ ]
+ }}
+ ]
+ }}
+
+ Example 7: Golf club hits ball (Case 4 - currently moving but should stay)
+ {{
+ "affected_objects": [
+ {{
+ "noun": "golf ball",
+ "category": "physical",
+ "currently_moving": true,
+ "should_have_stayed": true,
+ "original_position_grid": {{"row": 6, "col": 7}},
+ "first_appears_frame": 0,
+ "why": "ball was stationary on tee until club hit it. Without club, ball would remain at original position."
+ }}
+ ]
+ }}
+
+ YOUR OUTPUT FORMAT:
+ {{
+ "edit_instruction": "{instruction}",
+ "integral_belongings": [
+ {{
+ "noun": "bike",
+ "why": "person is riding the bike throughout the video"
+ }}
+ ],
+ "affected_objects": [
+ {{
+ "noun": "guitar",
+ "category": "physical",
+ "why": "person is SUPPORTING guitar against gravity by holding it",
+ "will_move": true,
+ "first_appears_frame": 0,
+ "movement_description": "Will fall from held position to the ground",
+ "object_size_grids": {{"rows": 3, "cols": 2}},
+ "trajectory_path": [
+ {{"frame": 0, "grid_row": 3, "grid_col": 6}},
+ {{"frame": 20, "grid_row": 6, "grid_col": 6}},
+ {{"frame": 40, "grid_row": 9, "grid_col": 7}}
+ ]
+ }},
+ {{
+ "noun": "shadow",
+ "category": "visual_artifact",
+ "why": "cast by person on floor",
+ "will_move": false,
+ "first_appears_frame": 0,
+ "movement_description": "Disappears entirely as visual artifact"
+ }}
+ ],
+ "scene_description": "An acoustic guitar falling to the ground in an empty room. Natural window lighting.",
+ "confidence": 0.85
+ }}
+
+ CRITICAL REMINDERS:
+ • Watch ENTIRE video before answering
+ • SUPPORTING vs ACTING ON:
+ - Primary HOLDS UP object against gravity → will_move=TRUE (provide grid trajectory)
+ - Primary ACTS ON object (crushing, opening) but object on stable surface → will_move=FALSE
+ - Object RESTS ON stable surface (table, floor) → will_move=FALSE
+ • For visual artifacts (shadow, reflection): will_move=false (no trajectory needed)
+ • For held objects (guitar, cup): will_move=true (MUST provide object_size_grids + trajectory_path)
+ • For objects on surfaces being acted on (can being crushed, can being opened): will_move=false
+ • Grid trajectory: Add +1 cell padding to object size (over-mask is better than under-mask)
+ • Grid trajectory: Use the yellow grid overlay to determine (row, col) positions
+ • Be conservative - when in doubt, DON'T include
+ • Output MUST be valid JSON only
+
+ GRID INFO: {grid_rows} rows × {grid_cols} columns
+ EDIT INSTRUCTION: {instruction}
+ """.strip()
+
+
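The prompt asks for only 3-5 trajectory keyframes, so a later stage presumably interpolates a position for every frame in between. A hedged sketch of that linear interpolation over `trajectory_path` entries (this helper is illustrative, not code from this repo):

```python
# Illustrative: linearly interpolate a VLM trajectory_path
# ([{"frame": N, "grid_row": R, "grid_col": C}, ...]) to a per-frame position.
def interp_trajectory(path, frame):
    path = sorted(path, key=lambda p: p["frame"])
    if frame <= path[0]["frame"]:
        return (path[0]["grid_row"], path[0]["grid_col"])
    if frame >= path[-1]["frame"]:
        return (path[-1]["grid_row"], path[-1]["grid_col"])
    for a, b in zip(path, path[1:]):
        if a["frame"] <= frame <= b["frame"]:
            t = (frame - a["frame"]) / (b["frame"] - a["frame"])
            return (a["grid_row"] + t * (b["grid_row"] - a["grid_row"]),
                    a["grid_col"] + t * (b["grid_col"] - a["grid_col"]))

path = [{"frame": 0, "grid_row": 4, "grid_col": 5},
        {"frame": 30, "grid_row": 8, "grid_col": 6}]
print(interp_trajectory(path, 15))  # (6.0, 5.5)
```

Positions before the first keyframe clamp to the first entry and positions after the last clamp to the last, matching the prompt's rule that the first keyframe sits at `first_appears_frame`.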
+ def call_vlm_with_images_and_video(client, model: str, image_data_urls: list,
+ video_data_url: str, prompt: str) -> str:
+ """Call VLM with sampled frame images.
+
+ The CF AI Gateway OpenAI-compat endpoint does not support video/mp4 base64
+ data URLs, so we rely solely on the sampled grid frames already in
+ image_data_urls. video_data_url is accepted for signature compatibility but
+ intentionally not sent.
+ """
+ content = []
+
+ # Add all sampled frame images
+ for img_url in image_data_urls:
+ content.append({"type": "image_url", "image_url": {"url": img_url}})
+
+ # Add prompt
+ content.append({"type": "text", "text": prompt})
+
+ resp = client.chat.completions.create(
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are an expert video analyst with deep understanding of physics and object interactions. Always output valid JSON only."
+ },
+ {
+ "role": "user",
+ "content": content
+ },
+ ],
+ )
+ return resp.choices[0].message.content
+
+
+ def parse_vlm_response(raw: str) -> Dict:
+ """Parse VLM JSON response"""
+ # Strip markdown code blocks
+ cleaned = raw.strip()
+ if cleaned.startswith("```"):
+ lines = cleaned.split('\n')
+ if lines[0].startswith("```"):
+ lines = lines[1:]
+ if lines and lines[-1].strip() == "```":
+ lines = lines[:-1]
+ cleaned = '\n'.join(lines)
+
+ try:
+ parsed = json.loads(cleaned)
+ except json.JSONDecodeError:
+ # Try to find JSON in response
+ start = cleaned.find("{")
+ end = cleaned.rfind("}")
+ if start != -1 and end != -1 and end > start:
+ parsed = json.loads(cleaned[start:end+1])
+ else:
+ raise ValueError("Failed to parse VLM response as JSON")
+
+ # Validate structure
+ result = {
+ "edit_instruction": parsed.get("edit_instruction", ""),
+ "integral_belongings": [],
+ "affected_objects": [],
+ "scene_description": parsed.get("scene_description", ""),
+ "confidence": float(parsed.get("confidence", 0.0))
+ }
+
+ # Parse integral belongings
+ for item in parsed.get("integral_belongings", [])[:3]:
+ obj = {
+ "noun": str(item.get("noun", "")).strip().lower(),
+ "why": str(item.get("why", "")).strip()[:200]
+ }
+ if obj["noun"]:
+ result["integral_belongings"].append(obj)
+
+ # Parse affected objects
+ for item in parsed.get("affected_objects", [])[:5]:
+ obj = {
+ "noun": str(item.get("noun", "")).strip().lower(),
+ "category": str(item.get("category", "physical")).strip().lower(),
+ "why": str(item.get("why", "")).strip()[:200],
+ "will_move": bool(item.get("will_move", False)),
+ "first_appears_frame": int(item.get("first_appears_frame", 0)),
+ "movement_description": str(item.get("movement_description", "")).strip()[:300]
+ }
+
+ # Parse Case 4: currently moving but should have stayed
+ if "currently_moving" in item:
+ obj["currently_moving"] = bool(item.get("currently_moving", False))
+ if "should_have_stayed" in item:
+ obj["should_have_stayed"] = bool(item.get("should_have_stayed", False))
+ if "original_position_grid" in item:
+ orig_grid = item.get("original_position_grid", {})
+ obj["original_position_grid"] = {
+ "row": int(orig_grid.get("row", 0)),
+ "col": int(orig_grid.get("col", 0))
+ }
+
+ # Parse grid localizations for visual artifacts
+ if "grid_localizations" in item:
+ grid_locs = []
+ for loc in item.get("grid_localizations", []):
+ frame_loc = {
+ "frame": int(loc.get("frame", 0)),
+ "grid_regions": []
+ }
+ for region in loc.get("grid_regions", []):
+ frame_loc["grid_regions"].append({
+ "row": int(region.get("row", 0)),
+ "col": int(region.get("col", 0))
+ })
+ if frame_loc["grid_regions"]: # Only add if has regions
+ grid_locs.append(frame_loc)
+ if grid_locs:
+ obj["grid_localizations"] = grid_locs
+
+ # Parse grid trajectory if will_move=true
+ if obj["will_move"] and "object_size_grids" in item and "trajectory_path" in item:
+ size_grids = item.get("object_size_grids", {})
+ obj["object_size_grids"] = {
+ "rows": int(size_grids.get("rows", 2)),
+ "cols": int(size_grids.get("cols", 2))
+ }
+
+ trajectory = []
+ for point in item.get("trajectory_path", []):
+ trajectory.append({
+ "frame": int(point.get("frame", 0)),
+ "grid_row": int(point.get("grid_row", 0)),
+ "grid_col": int(point.get("grid_col", 0))
+ })
+
+ if trajectory: # Only add if we have valid trajectory points
+ obj["trajectory_path"] = trajectory
+
+ if obj["noun"]:
+ result["affected_objects"].append(obj)
+
+ return result
+
+
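`parse_vlm_response` first strips the markdown code fence that chat models often wrap around JSON. A standalone sketch of that stripping step on a synthetic response (the triple-backtick fence is built indirectly here only to keep the example readable):

```python
# Sketch of the fence-stripping step: models often return
# ```json ... ``` instead of bare JSON.
import json

fence = "`" * 3  # literal triple backtick
raw = fence + 'json\n{"affected_objects": [], "confidence": 0.9}\n' + fence

cleaned = raw.strip()
if cleaned.startswith(fence):
    lines = cleaned.split("\n")
    if lines[0].startswith(fence):
        lines = lines[1:]          # drop opening ```json line
    if lines and lines[-1].strip() == fence:
        lines = lines[:-1]         # drop closing ``` line
    cleaned = "\n".join(lines)

print(json.loads(cleaned)["confidence"])  # 0.9
```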
+ def process_video(video_info: Dict, client, model: str):
+ """Process a single video with VLM analysis"""
+ video_path = video_info.get("video_path", "")
+ instruction = video_info.get("instruction", "")
+ output_dir = video_info.get("output_dir", "")
+
+ if not output_dir:
+ print(f" ⚠️ No output_dir specified, skipping")
+ return None
+
+ output_dir = Path(output_dir)
+ if not output_dir.exists():
+ print(f" ⚠️ Output directory not found: {output_dir}")
+ print(f" Run Stage 1 first to create black masks")
+ return None
+
+ # Check required files from Stage 1
+ black_mask_path = output_dir / "black_mask.mp4"
+ first_frame_path = output_dir / "first_frame.jpg"
+ input_video_path = output_dir / "input_video.mp4"
+ segmentation_info_path = output_dir / "segmentation_info.json"
+
+ if not black_mask_path.exists():
+ print(f" ⚠️ black_mask.mp4 not found in {output_dir}")
+ print(f" Run Stage 1 first")
+ return None
+
+ if not first_frame_path.exists():
+ print(f" ⚠️ first_frame.jpg not found in {output_dir}")
+ return None
+
+ if not input_video_path.exists():
+ # Try original video path
+ if Path(video_path).exists():
+ input_video_path = Path(video_path)
+ else:
+ print(f" ⚠️ Video not found: {video_path}")
+ return None
+
+ # Read segmentation metadata to get correct frame index
+ frame_idx = 0 # Default
+ if segmentation_info_path.exists():
+ try:
+ with open(segmentation_info_path, 'r') as f:
+ seg_info = json.load(f)
+ frame_idx = seg_info.get("first_appears_frame", 0)
+ print(f" Using frame {frame_idx} from segmentation metadata")
+ except Exception as e:
+ print(f" Warning: Could not read segmentation_info.json: {e}")
+ print(f" Using frame 0 as fallback")
+
+ # Get min_grid for grid calculation
+ min_grid = video_info.get('min_grid', 8)
+ use_multi_frame_grids = video_info.get('multi_frame_grids', True) # Default: use multi-frame
+ max_video_size_mb = video_info.get('max_video_size_for_multiframe', 25) # Default: 25MB limit
+
+ # Check video size and auto-disable multi-frame for large videos
+ if use_multi_frame_grids:
+ video_size_mb = input_video_path.stat().st_size / (1024 * 1024)
+ if video_size_mb > max_video_size_mb:
+ print(f" ⚠️ Video size ({video_size_mb:.1f} MB) exceeds {max_video_size_mb} MB")
+ print(f" Auto-disabling multi-frame grids to avoid API errors")
+ use_multi_frame_grids = False
+
+ print(f" Creating frame overlays and grids...")
+ overlay_path = output_dir / "first_frame_with_mask.jpg"
+ gridded_path = output_dir / "first_frame_with_grid.jpg"
+
+ # Create regular overlay (for backwards compatibility)
+ create_first_frame_with_mask_overlay(
+ str(first_frame_path),
+ str(black_mask_path),
+ str(overlay_path),
+ frame_idx=frame_idx
+ )
+
+ image_data_urls = []
+
+ if use_multi_frame_grids:
+ # Create multi-frame grid samples for objects appearing mid-video
+ print(f" Creating multi-frame grid samples (10 frames, 0%-100%)...")
+ sample_paths, grid_rows, grid_cols = create_multi_frame_grid_samples(
+ str(input_video_path),
+ output_dir,
+ min_grid=min_grid
+ )
+
+ # Encode all grid samples
+ for sample_path in sample_paths:
+ image_data_urls.append(image_to_data_url(str(sample_path)))
+
+ # Also add the first frame with mask overlay
+ _, _, _ = create_gridded_frame_overlay(
+ str(first_frame_path),
+ str(black_mask_path),
+ str(gridded_path),
+ min_grid=min_grid
+ )
+ image_data_urls.append(image_to_data_url(str(gridded_path)))
+
+ print(f" Grid: {grid_rows}x{grid_cols}, {len(sample_paths)} sample frames + masked frame")
+
+ else:
+ # Single gridded first frame (old approach)
+ _, grid_rows, grid_cols = create_gridded_frame_overlay(
+ str(first_frame_path),
+ str(black_mask_path),
+ str(gridded_path),
+ min_grid=min_grid
+ )
+ image_data_urls.append(image_to_data_url(str(gridded_path)))
+ print(f" Grid: {grid_rows}x{grid_cols} (single frame)")
+
+ # CF gateway does not support video/mp4 base64 -- pass None; frames already
+ # captured in image_data_urls above.
+ video_data_url = None
+
+ print(f" Calling {model}...")
+ prompt = make_vlm_analysis_prompt(instruction, grid_rows, grid_cols,
+ has_multi_frame_grids=use_multi_frame_grids)
+
+ try:
+ try:
+ raw_response = call_vlm_with_images_and_video(
+ client, model, image_data_urls, video_data_url, prompt
+ )
+ except Exception as e:
+ # If multi-frame fails (likely payload size issue), fall back to single frame
+ if use_multi_frame_grids and "400" in str(e):
+ print(f" ⚠️ Multi-frame request failed (payload too large?)")
+ print(f" Falling back to single-frame grid mode...")
+
+ # Retry with just the gridded first frame
+ image_data_urls = [image_to_data_url(str(gridded_path))]
+ prompt = make_vlm_analysis_prompt(instruction, grid_rows, grid_cols,
+ has_multi_frame_grids=False)
+
+ try:
+ raw_response = call_vlm_with_images_and_video(
+ client, model, image_data_urls, video_data_url, prompt
+ )
+ print(f" ✓ Single-frame fallback succeeded")
+ except Exception as e2:
+ raise e2 # Re-raise if fallback also fails
+ else:
+ raise # Re-raise if not a 400 or not multi-frame mode
+
+ # Parse and save results (runs whether first call succeeded or fallback succeeded)
+ print(f" Parsing VLM response...")
+ analysis = parse_vlm_response(raw_response)
+
+ # Save results
+ output_path = output_dir / "vlm_analysis.json"
+ with open(output_path, 'w') as f:
+ json.dump(analysis, f, indent=2)
+
+ print(f" ✓ Saved VLM analysis: {output_path.name}")
+
+ # Print summary
+ print(f"\n Summary:")
+ print(f" - Integral belongings: {len(analysis['integral_belongings'])}")
+ for obj in analysis['integral_belongings']:
+ print(f" • {obj['noun']}: {obj['why']}")
+
+ print(f" - Affected objects: {len(analysis['affected_objects'])}")
+ for obj in analysis['affected_objects']:
+ move_str = "WILL MOVE" if obj['will_move'] else "STAYS/DISAPPEARS"
+ traj_str = ""
+ if obj.get('will_move') and 'trajectory_path' in obj:
+ num_points = len(obj['trajectory_path'])
+ size = obj.get('object_size_grids', {})
+ traj_str = f" (trajectory: {num_points} keyframes, size: {size.get('rows')}×{size.get('cols')} grids)"
+ print(f" • {obj['noun']}: {move_str}{traj_str}")
+
+ return analysis
+
+ except Exception as e:
+ print(f" ❌ VLM analysis failed: {e}")
+ import traceback
+ traceback.print_exc()
+ return None
+
+
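The size guard in `process_video` boils down to one comparison between the file size in megabytes and `max_video_size_for_multiframe` (default 25 MB); a small sketch (the helper name is illustrative):

```python
# Illustrative: decide whether multi-frame grid samples are safe to send,
# mirroring the 25 MB default guard in process_video.
def use_multiframe(size_bytes, max_mb=25):
    return size_bytes / (1024 * 1024) <= max_mb

print(use_multiframe(10 * 1024 * 1024))  # True  -> send all grid samples
print(use_multiframe(30 * 1024 * 1024))  # False -> fall back to single frame
```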
+ def process_config(config_path: str, model: str = DEFAULT_MODEL):
+ """Process all videos in config"""
+ config_path = Path(config_path)
+
+ # Load config
+ with open(config_path, 'r') as f:
+ config_data = json.load(f)
+
+ # Handle both formats
+ if isinstance(config_data, list):
+ videos = config_data
+ elif isinstance(config_data, dict) and "videos" in config_data:
+ videos = config_data["videos"]
+ else:
+ raise ValueError("Config must be a list or have 'videos' key")
+
+ print(f"\n{'='*70}")
+ print(f"Stage 2: VLM Analysis - Identify Affected Objects")
+ print(f"{'='*70}")
+ print(f"Config: {config_path.name}")
+ print(f"Videos: {len(videos)}")
+ print(f"Model: {model}")
+ print(f"{'='*70}\n")
+
+ # Initialize VLM client (CF AI Gateway)
+ cf_project_id = os.environ.get("CF_PROJECT_ID")
+ cf_user_id = os.environ.get("CF_USER_ID")
+ if not cf_project_id or not cf_user_id:
+ raise RuntimeError("CF_PROJECT_ID and CF_USER_ID environment variables must be set")
+
+ metadata = json.dumps({"project_id": cf_project_id, "user_id": cf_user_id})
+ client = openai.OpenAI(
+ api_key=os.environ.get("GEMINI_API_KEY", "placeholder"),
+ base_url="https://ai-gateway.plain-flower-4887.workers.dev/compat",
+ default_headers={"cf-aig-metadata": metadata},
+ )
+
+ # Model comes from MODEL_ID env var; fall back to --model arg
+ model = os.environ.get("MODEL_ID", model)
+
+ # Process each video
+ results = []
+ for i, video_info in enumerate(videos):
+ video_path = video_info.get("video_path", "")
+ instruction = video_info.get("instruction", "")
+
+ print(f"\n{'─'*70}")
+ print(f"Video {i+1}/{len(videos)}: {Path(video_path).name}")
+ print(f"{'─'*70}")
+ print(f"Instruction: {instruction}")
+
+ try:
+ analysis = process_video(video_info, client, model)
+ results.append({
+ "video": video_path,
+ "success": analysis is not None,
+ "analysis": analysis
+ })
+
+ if analysis:
+ print(f"\n✅ Video {i+1} complete!")
+ else:
+ print(f"\n⚠️ Video {i+1} skipped")
+
+ except Exception as e:
+ print(f"\n❌ Error processing video {i+1}: {e}")
+ results.append({
+ "video": video_path,
+ "success": False,
+ "error": str(e)
+ })
+ continue
+
+ # Summary
+ print(f"\n{'='*70}")
+ print(f"Stage 2 Complete!")
+ print(f"{'='*70}")
+ successful = sum(1 for r in results if r["success"])
+ print(f"Successful: {successful}/{len(videos)}")
+ print(f"{'='*70}\n")
+
+
+ def main():
1015
+ parser = argparse.ArgumentParser(description="Stage 2: VLM Analysis")
1016
+ parser.add_argument("--config", required=True, help="Config JSON from Stage 1")
1017
+ parser.add_argument("--model", default=DEFAULT_MODEL, help="VLM model name")
1018
+ args = parser.parse_args()
1019
+
1020
+ process_config(args.config, args.model)
1021
+
1022
+
1023
+ if __name__ == "__main__":
1024
+ main()
VLM-MASK-REASONER/stage3a_generate_grey_masks.py ADDED
@@ -0,0 +1,436 @@
+#!/usr/bin/env python3
+"""
+Stage 3a: Generate Grey Masks - Combine VLM Logic + User Trajectories
+
+Generates grey masks (127 = affected regions) by combining:
+1. VLM-identified affected objects (segmented + gridified)
+2. User-drawn trajectories (from Stage 3b)
+3. Proximity filtering (only mask near the primary object)
+
+Input:  - vlm_analysis.json (Stage 2)
+        - black_mask.mp4 (Stage 1)
+        - trajectories.json (Stage 3b, optional)
+Output: - grey_mask.mp4 (127 = affected, 255 = background)
+
+Usage:
+    python stage3a_generate_grey_masks.py --config more_dyn_2_config_points_absolute.json
+"""
+
+import os
+import sys
+import json
+import argparse
+import subprocess
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import cv2
+import numpy as np
+from PIL import Image
+
+# Segmentation models (optional imports)
+try:
+    from sam3.model_builder import build_sam3_image_model
+    from sam3.model.sam3_image_processor import Sam3Processor
+    SAM3_AVAILABLE = True
+except ImportError:
+    SAM3_AVAILABLE = False
+
+try:
+    from lang_sam import LangSAM
+    LANGSAM_AVAILABLE = True
+except ImportError:
+    LANGSAM_AVAILABLE = False
+
+
+class SegmentationModel:
+    """Thin wrapper over the available text-prompted segmenters."""
+
+    def __init__(self, model_type: str = "sam3"):
+        self.model_type = model_type.lower()
+
+        if self.model_type == "sam3":
+            if not SAM3_AVAILABLE:
+                raise ImportError("SAM3 not available")
+            print("  Loading SAM3...")
+            model = build_sam3_image_model()
+            self.processor = Sam3Processor(model)
+            self.model = model
+        elif self.model_type == "langsam":
+            if not LANGSAM_AVAILABLE:
+                raise ImportError("LangSAM not available")
+            print("  Loading LangSAM...")
+            self.model = LangSAM()
+        else:
+            raise ValueError(f"Unknown model: {model_type}")
+
+    def segment(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
+        """Segment an object using a text prompt; returns a boolean mask."""
+        if self.model_type == "sam3":
+            return self._segment_sam3(image_pil, prompt)
+        return self._segment_langsam(image_pil, prompt)
+
+    def _segment_sam3(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
+        import torch
+        h, w = image_pil.height, image_pil.width
+        union = np.zeros((h, w), dtype=bool)
+
+        try:
+            inference_state = self.processor.set_image(image_pil)
+            output = self.processor.set_text_prompt(state=inference_state, prompt=prompt)
+            masks = output.get("masks")
+
+            if masks is None or len(masks) == 0:
+                return union
+
+            if torch.is_tensor(masks):
+                masks = masks.cpu().numpy()
+
+            # Union across however many instances/batches came back
+            if masks.ndim == 2:
+                union = masks.astype(bool)
+            elif masks.ndim == 3:
+                union = masks.any(axis=0).astype(bool)
+            elif masks.ndim == 4:
+                union = masks.any(axis=(0, 1)).astype(bool)
+
+        except Exception as e:
+            print(f"    Warning: SAM3 segmentation failed for '{prompt}': {e}")
+
+        return union
+
+    def _segment_langsam(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
+        h, w = image_pil.height, image_pil.width
+        union = np.zeros((h, w), dtype=bool)
+
+        try:
+            results = self.model.predict([image_pil], [prompt])
+            if not results:
+                return union
+
+            r0 = results[0]
+            if isinstance(r0, dict) and "masks" in r0:
+                masks = r0["masks"]
+                if masks.ndim == 4 and masks.shape[0] == 1:
+                    masks = masks[0]
+                if masks.ndim == 3:
+                    union = masks.any(axis=0).astype(bool)
+                elif masks.ndim == 2:
+                    union = masks.astype(bool)
+
+        except Exception as e:
+            print(f"    Warning: LangSAM segmentation failed for '{prompt}': {e}")
+
+        return union
+
+
+def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> Tuple[int, int]:
+    """Calculate grid dimensions so cells are approximately square."""
+    aspect_ratio = width / height
+    if width >= height:
+        grid_rows = min_grid
+        grid_cols = max(min_grid, round(min_grid * aspect_ratio))
+    else:
+        grid_cols = min_grid
+        grid_rows = max(min_grid, round(min_grid / aspect_ratio))
+    return grid_rows, grid_cols
+
+
+def gridify_mask(mask: np.ndarray, grid_rows: int, grid_cols: int) -> np.ndarray:
+    """Convert a pixel mask to a gridified mask: any pixel in a cell fills the whole cell."""
+    h, w = mask.shape
+    gridified = np.zeros((h, w), dtype=bool)
+
+    cell_width = w / grid_cols
+    cell_height = h / grid_rows
+
+    for row in range(grid_rows):
+        for col in range(grid_cols):
+            y1 = int(row * cell_height)
+            y2 = int((row + 1) * cell_height)
+            x1 = int(col * cell_width)
+            x2 = int((col + 1) * cell_width)
+
+            if mask[y1:y2, x1:x2].any():
+                gridified[y1:y2, x1:x2] = True
+
+    return gridified
+
+
+def grid_cells_to_mask(grid_cells: List[List[int]], grid_rows: int, grid_cols: int,
+                       frame_width: int, frame_height: int) -> np.ndarray:
+    """Convert a list of (row, col) grid cells to a pixel mask."""
+    mask = np.zeros((frame_height, frame_width), dtype=bool)
+
+    cell_width = frame_width / grid_cols
+    cell_height = frame_height / grid_rows
+
+    for row, col in grid_cells:
+        y1 = int(row * cell_height)
+        y2 = int((row + 1) * cell_height)
+        x1 = int(col * cell_width)
+        x2 = int((col + 1) * cell_width)
+        mask[y1:y2, x1:x2] = True
+
+    return mask
+
+
+def dilate_mask(mask: np.ndarray, kernel_size: int = 15) -> np.ndarray:
+    """Dilate a mask to create a proximity region."""
+    kernel = np.ones((kernel_size, kernel_size), np.uint8)
+    return cv2.dilate(mask.astype(np.uint8), kernel, iterations=1).astype(bool)
+
+
+def filter_by_proximity(mask: np.ndarray, primary_mask: np.ndarray, dilation: int = 15) -> np.ndarray:
+    """Keep only the parts of `mask` that fall near the primary mask."""
+    # Dilate the primary mask to create the proximity region
+    proximity_region = dilate_mask(primary_mask, dilation)
+    # Keep the mask only where it overlaps the proximity region
+    return mask & proximity_region
+
+
+def process_video_grey_masks(video_info: Dict, segmenter: SegmentationModel,
+                             trajectory_data: Dict = None):
+    """Generate grey masks for a single video."""
+    video_path = video_info.get("video_path", "")
+    output_dir = Path(video_info.get("output_dir", ""))
+
+    if not output_dir.exists():
+        print(f"  ⚠️ Output directory not found: {output_dir}")
+        return
+
+    # Load required files
+    vlm_analysis_path = output_dir / "vlm_analysis.json"
+    black_mask_path = output_dir / "black_mask.mp4"
+    input_video_path = output_dir / "input_video.mp4"
+
+    if not vlm_analysis_path.exists():
+        print("  ⚠️ vlm_analysis.json not found")
+        return
+
+    if not black_mask_path.exists():
+        print("  ⚠️ black_mask.mp4 not found")
+        return
+
+    if not input_video_path.exists():
+        input_video_path = Path(video_path)
+        if not input_video_path.exists():
+            print("  ⚠️ Video not found")
+            return
+
+    # Load VLM analysis
+    with open(vlm_analysis_path, 'r') as f:
+        analysis = json.load(f)
+
+    # Get video properties
+    cap = cv2.VideoCapture(str(input_video_path))
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    cap.release()
+
+    # Calculate grid
+    min_grid = video_info.get('min_grid', 8)
+    grid_rows, grid_cols = calculate_square_grid(frame_width, frame_height, min_grid)
+
+    print(f"  Video: {frame_width}x{frame_height}, {total_frames} frames, grid: {grid_rows}x{grid_cols}")
+
+    # Load first frame
+    cap = cv2.VideoCapture(str(input_video_path))
+    ret, first_frame = cap.read()
+    cap.release()
+
+    if not ret:
+        print("  ⚠️ Failed to read first frame")
+        return
+
+    first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+    first_frame_pil = Image.fromarray(first_frame_rgb)
+
+    # Load black mask (first frame, for proximity filtering)
+    black_cap = cv2.VideoCapture(str(black_mask_path))
+    ret, black_mask_frame = black_cap.read()
+    black_cap.release()
+
+    if not ret:
+        print("  ⚠️ Failed to read black mask")
+        return
+
+    if len(black_mask_frame.shape) == 3:
+        black_mask_frame = cv2.cvtColor(black_mask_frame, cv2.COLOR_BGR2GRAY)
+
+    primary_mask = (black_mask_frame == 0)  # 0 = primary object
+
+    # Initialize grey mask
+    grey_mask_combined = np.zeros((frame_height, frame_width), dtype=bool)
+
+    # Process affected objects from VLM
+    affected_objects = analysis.get('affected_objects', [])
+
+    print(f"  Processing {len(affected_objects)} affected object(s)...")
+
+    for obj in affected_objects:
+        noun = obj.get('noun', '')
+        category = obj.get('category', 'physical')
+        will_move = obj.get('will_move', False)
+        needs_trajectory = obj.get('needs_trajectory', False)
+
+        if not noun:
+            continue
+
+        print(f"  β€’ {noun} ({category})")
+
+        # Check whether we have trajectory data for this object
+        has_trajectory = False
+        if needs_trajectory and trajectory_data:
+            for traj in trajectory_data:
+                if traj.get('object_noun', '') == noun:
+                    # Use the trajectory grid cells
+                    print(f"    Using user-drawn trajectory ({len(traj['trajectory_grid_cells'])} cells)")
+                    traj_mask = grid_cells_to_mask(
+                        traj['trajectory_grid_cells'],
+                        grid_rows, grid_cols,
+                        frame_width, frame_height
+                    )
+                    grey_mask_combined |= traj_mask
+                    has_trajectory = True
+                    break
+
+        # If there is no trajectory, or the object doesn't need one, segment normally
+        if not has_trajectory:
+            obj_mask = segmenter.segment(first_frame_pil, noun)
+
+            if obj_mask.any():
+                print(f"    Segmented {obj_mask.sum()} pixels")
+
+                # Filter by proximity to the primary mask
+                obj_mask_filtered = filter_by_proximity(obj_mask, primary_mask, dilation=50)
+
+                if obj_mask_filtered.any():
+                    print(f"    After proximity filter: {obj_mask_filtered.sum()} pixels")
+
+                    # Gridify and add to the combined grey mask
+                    obj_mask_gridified = gridify_mask(obj_mask_filtered, grid_rows, grid_cols)
+                    grey_mask_combined |= obj_mask_gridified
+
+                    print("    βœ“ Added to grey mask")
+                else:
+                    print("    ⚠️ No pixels near primary object, skipping")
+            else:
+                print("    ⚠️ Segmentation failed")
+
+    # Generate grey mask video
+    print("  Generating grey mask video...")
+
+    # For simplicity, use the same mask for all frames
+    # (a future version could track objects through the video)
+    grey_mask_uint8 = np.where(grey_mask_combined, 127, 255).astype(np.uint8)
+
+    # Write a temp AVI (FFV1 = lossless)
+    temp_avi = output_dir / "grey_mask_temp.avi"
+    fourcc = cv2.VideoWriter_fourcc(*'FFV1')
+    out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (frame_width, frame_height), isColor=False)
+
+    for _ in range(total_frames):
+        out.write(grey_mask_uint8)
+
+    out.release()
+
+    # Convert to lossless MP4
+    grey_mask_mp4 = output_dir / "grey_mask.mp4"
+    cmd = [
+        'ffmpeg', '-y', '-i', str(temp_avi),
+        '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
+        '-pix_fmt', 'yuv444p',
+        str(grey_mask_mp4)
+    ]
+    subprocess.run(cmd, capture_output=True)
+    temp_avi.unlink()
+
+    print("  βœ“ Saved grey_mask.mp4")
+
+    # Save debug visualization (note: OpenCV writes BGR channel order)
+    debug_vis = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
+    debug_vis[grey_mask_combined] = [0, 255, 0]  # Green for affected regions
+    debug_vis[primary_mask] = [0, 0, 255]        # Red (BGR) for primary object
+    debug_path = output_dir / "debug_grey_mask.jpg"
+    cv2.imwrite(str(debug_path), debug_vis)
+    print("  βœ“ Saved debug visualization")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Stage 3a: Generate Grey Masks")
+    parser.add_argument("--config", required=True, help="Config JSON")
+    parser.add_argument("--segmentation-model", default="sam3", choices=["langsam", "sam3"],
+                        help="Segmentation model")
+    args = parser.parse_args()
+
+    config_path = Path(args.config)
+
+    # Load config
+    with open(config_path, 'r') as f:
+        config_data = json.load(f)
+
+    if isinstance(config_data, list):
+        videos = config_data
+    elif isinstance(config_data, dict) and "videos" in config_data:
+        videos = config_data["videos"]
+    else:
+        raise ValueError("Invalid config format")
+
+    # Load trajectory data if it exists
+    trajectory_path = config_path.parent / f"{config_path.stem}_trajectories.json"
+    trajectory_data = None
+
+    if trajectory_path.exists():
+        print(f"Loading trajectory data from: {trajectory_path.name}")
+        with open(trajectory_path, 'r') as f:
+            trajectory_data = json.load(f)
+        print(f"  Loaded {len(trajectory_data)} trajectory(s)")
+    else:
+        print("No trajectory data found (Stage 3b not run, or no objects needed trajectories)")
+
+    print(f"\n{'='*70}")
+    print(f"Stage 3a: Generate Grey Masks")
+    print(f"{'='*70}")
+    print(f"Videos: {len(videos)}")
+    print(f"Segmentation: {args.segmentation_model.upper()}")
+    print(f"{'='*70}\n")
+
+    # Load segmentation model
+    segmenter = SegmentationModel(args.segmentation_model)
+
+    # Process each video
+    for i, video_info in enumerate(videos):
+        video_path = video_info.get('video_path', '')
+        print(f"\n{'─'*70}")
+        print(f"Video {i+1}/{len(videos)}: {Path(video_path).parent.name}")
+        print(f"{'─'*70}")
+
+        try:
+            process_video_grey_masks(video_info, segmenter, trajectory_data)
+            print(f"\nβœ… Video {i+1} complete!")
+
+        except Exception as e:
+            print(f"\n❌ Error processing video {i+1}: {e}")
+            import traceback
+            traceback.print_exc()
+            continue
+
+    print(f"\n{'='*70}")
+    print(f"βœ… Stage 3a Complete!")
+    print(f"{'='*70}")
+    print(f"Generated grey_mask.mp4 for all videos")
+    print(f"Next: Run Stage 4 to combine black + grey masks")
+    print(f"{'='*70}\n")
+
+
+if __name__ == "__main__":
+    main()
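As a standalone sanity check of the grid logic used above, the snippet below re-implements `calculate_square_grid` and the cell-filling rule with the same arithmetic (the short side gets `min_grid` cells; a cell with any foreground pixel is filled entirely). This is an illustrative sketch, not part of the committed files:

```python
import numpy as np

def calculate_square_grid(width, height, min_grid=8):
    # Short side gets min_grid cells; long side scales by aspect ratio
    # so cells stay roughly square (same rule as the stage3a scripts).
    aspect_ratio = width / height
    if width >= height:
        return min_grid, max(min_grid, round(min_grid * aspect_ratio))
    return max(min_grid, round(min_grid / aspect_ratio)), min_grid

def gridify(mask, grid_rows, grid_cols):
    # If ANY pixel in a cell is set, the ENTIRE cell becomes set.
    h, w = mask.shape
    out = np.zeros_like(mask)
    for row in range(grid_rows):
        for col in range(grid_cols):
            y1, y2 = int(row * h / grid_rows), int((row + 1) * h / grid_rows)
            x1, x2 = int(col * w / grid_cols), int((col + 1) * w / grid_cols)
            if mask[y1:y2, x1:x2].any():
                out[y1:y2, x1:x2] = True
    return out

rows, cols = calculate_square_grid(1920, 1080)  # 16:9 frame
mask = np.zeros((1080, 1920), dtype=bool)
mask[10, 10] = True                             # one stray foreground pixel
g = gridify(mask, rows, cols)
print(rows, cols)    # 8 14
print(int(g.sum()))  # 18495 -- the full 135x137 top-left cell
```

A single pixel thus inflates to a whole cell, which is exactly the coarsening the quadmask relies on.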
VLM-MASK-REASONER/stage3a_generate_grey_masks_v2.py ADDED
@@ -0,0 +1,576 @@
+#!/usr/bin/env python3
+"""
+Stage 3a: Generate Grey Masks (CORRECTED)
+
+Corrected pipeline:
+1. For EACH affected object (from the VLM analysis):
+   a) IF the user drew a trajectory (Stage 3b):
+      - Segment the object in first_appears_frame to get its SIZE
+      - Apply that object SIZE along the trajectory path across all frames
+   b) ELSE (no user trajectory):
+      - Segment the object through ALL frames (captures any movement/changes).
+        This handles:
+        * static objects (can, chair)
+        * objects that move during the video (golf ball)
+        * dynamic effects (paint strokes, shadows)
+      - Filter by proximity to the primary object
+
+2. Accumulate all masks (one combined mask per frame)
+
+3. Gridify ALL accumulated masks:
+   - if ANY pixel falls in a grid cell, the ENTIRE cell becomes 127
+
+4. Write grey_mask.mp4
+
+Key insight: will_move / needs_trajectory matter ONLY for Stage 3b (user input).
+In Stage 3a we segment ALL affected objects through ALL frames.
+
+Input:  - vlm_analysis.json (Stage 2)
+        - black_mask.mp4 (Stage 1)
+        - trajectories.json (Stage 3b, optional)
+Output: - grey_mask.mp4 (127 = affected, 255 = background)
+
+Usage:
+    python stage3a_generate_grey_masks_v2.py --config more_dyn_2_config_points_absolute.json
+"""
+
+import os
+import sys
+import json
+import argparse
+import subprocess
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional
+
+import cv2
+import numpy as np
+from PIL import Image
+
+# SAM2 for video tracking
+try:
+    from sam2.build_sam import build_sam2_video_predictor
+    SAM2_AVAILABLE = True
+except ImportError:
+    SAM2_AVAILABLE = False
+
+# SAM3 for single-frame, text-prompted segmentation
+try:
+    from sam3.model_builder import build_sam3_image_model
+    from sam3.model.sam3_image_processor import Sam3Processor
+    SAM3_AVAILABLE = True
+except ImportError:
+    SAM3_AVAILABLE = False
+
+# LangSAM fallback
+try:
+    from lang_sam import LangSAM
+    LANGSAM_AVAILABLE = True
+except ImportError:
+    LANGSAM_AVAILABLE = False
+
+
+class SegmentationModel:
+    """Thin wrapper over the available text-prompted segmenters."""
+
+    def __init__(self, model_type: str = "sam3"):
+        self.model_type = model_type.lower()
+
+        if self.model_type == "sam3":
+            if not SAM3_AVAILABLE:
+                raise ImportError("SAM3 not available")
+            print("  Loading SAM3...")
+            model = build_sam3_image_model()
+            self.processor = Sam3Processor(model)
+            self.model = model
+        elif self.model_type == "langsam":
+            if not LANGSAM_AVAILABLE:
+                raise ImportError("LangSAM not available")
+            print("  Loading LangSAM...")
+            self.model = LangSAM()
+        else:
+            raise ValueError(f"Unknown model: {model_type}")
+
+    def segment(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
+        """Segment an object using a text prompt; returns a boolean mask."""
+        if self.model_type == "sam3":
+            return self._segment_sam3(image_pil, prompt)
+        return self._segment_langsam(image_pil, prompt)
+
+    def _segment_sam3(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
+        import torch
+        h, w = image_pil.height, image_pil.width
+        union = np.zeros((h, w), dtype=bool)
+
+        try:
+            inference_state = self.processor.set_image(image_pil)
+            output = self.processor.set_text_prompt(state=inference_state, prompt=prompt)
+            masks = output.get("masks")
+
+            if masks is None or len(masks) == 0:
+                return union
+
+            if torch.is_tensor(masks):
+                masks = masks.cpu().numpy()
+
+            # Union across however many instances/batches came back
+            if masks.ndim == 2:
+                union = masks.astype(bool)
+            elif masks.ndim == 3:
+                union = masks.any(axis=0).astype(bool)
+            elif masks.ndim == 4:
+                union = masks.any(axis=(0, 1)).astype(bool)
+
+        except Exception as e:
+            print(f"    Warning: SAM3 failed: {e}")
+
+        return union
+
+    def _segment_langsam(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
+        h, w = image_pil.height, image_pil.width
+        union = np.zeros((h, w), dtype=bool)
+
+        try:
+            results = self.model.predict([image_pil], [prompt])
+            if not results:
+                return union
+
+            r0 = results[0]
+            if isinstance(r0, dict) and "masks" in r0:
+                masks = r0["masks"]
+                if masks.ndim == 4 and masks.shape[0] == 1:
+                    masks = masks[0]
+                if masks.ndim == 3:
+                    union = masks.any(axis=0).astype(bool)
+                elif masks.ndim == 2:
+                    union = masks.astype(bool)
+
+        except Exception as e:
+            print(f"    Warning: LangSAM failed: {e}")
+
+        return union
+
+
+def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> Tuple[int, int]:
+    """Calculate grid dimensions so cells are approximately square."""
+    aspect_ratio = width / height
+    if width >= height:
+        grid_rows = min_grid
+        grid_cols = max(min_grid, round(min_grid * aspect_ratio))
+    else:
+        grid_cols = min_grid
+        grid_rows = max(min_grid, round(min_grid / aspect_ratio))
+    return grid_rows, grid_cols
+
+
+def gridify_masks(masks: List[np.ndarray], grid_rows: int, grid_cols: int) -> List[np.ndarray]:
+    """
+    Gridify masks: if ANY pixel falls in a grid cell, the ENTIRE cell becomes True.
+
+    Args:
+        masks: list of boolean masks (one per frame)
+        grid_rows, grid_cols: grid dimensions
+
+    Returns:
+        List of gridified boolean masks
+    """
+    gridified_masks = []
+
+    for mask in masks:
+        h, w = mask.shape
+        gridified = np.zeros((h, w), dtype=bool)
+
+        cell_width = w / grid_cols
+        cell_height = h / grid_rows
+
+        for row in range(grid_rows):
+            for col in range(grid_cols):
+                y1 = int(row * cell_height)
+                y2 = int((row + 1) * cell_height)
+                x1 = int(col * cell_width)
+                x2 = int((col + 1) * cell_width)
+
+                # If ANY pixel in the cell is set, fill the ENTIRE cell
+                if mask[y1:y2, x1:x2].any():
+                    gridified[y1:y2, x1:x2] = True
+
+        gridified_masks.append(gridified)
+
+    return gridified_masks
+
+
+def get_object_size(mask: np.ndarray) -> Tuple[int, int]:
+    """Get the bounding-box size (width, height) of an object mask."""
+    rows = np.any(mask, axis=1)
+    cols = np.any(mask, axis=0)
+
+    if not rows.any() or not cols.any():
+        return 0, 0
+
+    y1, y2 = np.where(rows)[0][[0, -1]]
+    x1, x2 = np.where(cols)[0][[0, -1]]
+
+    width = x2 - x1 + 1
+    height = y2 - y1 + 1
+
+    return width, height
+
+
+def apply_object_along_trajectory(obj_mask: np.ndarray, trajectory_points: List[Tuple[int, int]],
+                                  total_frames: int, frame_shape: Tuple[int, int]) -> List[np.ndarray]:
+    """
+    Place the object along the trajectory path across frames.
+
+    Args:
+        obj_mask: object mask from first_appears_frame
+        trajectory_points: list of (x, y) points defining the path
+        total_frames: total number of frames in the video
+        frame_shape: (height, width)
+
+    Returns:
+        List of masks (one per frame) with the object placed along the trajectory
+    """
+    h, w = frame_shape
+    masks = [np.zeros((h, w), dtype=bool) for _ in range(total_frames)]
+
+    if len(trajectory_points) < 2:
+        return masks
+
+    # Get the object's bounding-box size
+    obj_width, obj_height = get_object_size(obj_mask)
+
+    if obj_width == 0 or obj_height == 0:
+        return masks
+
+    num_traj_points = len(trajectory_points)
+
+    for frame_idx in range(total_frames):
+        # Map the frame index to a trajectory point
+        t = frame_idx / max(total_frames - 1, 1)  # 0.0 to 1.0
+        traj_idx = min(int(t * (num_traj_points - 1)), num_traj_points - 1)
+
+        # Get the position on the trajectory
+        x_center, y_center = trajectory_points[traj_idx]
+
+        # Place the object's bounding box centered at this position (clamped to frame)
+        x1 = max(0, int(x_center - obj_width // 2))
+        y1 = max(0, int(y_center - obj_height // 2))
+        x2 = min(w, x1 + obj_width)
+        y2 = min(h, y1 + obj_height)
+
+        masks[frame_idx][y1:y2, x1:x2] = True
+
+    return masks
+
+
+def segment_object_all_frames(video_path: str, obj_noun: str, segmenter: SegmentationModel,
+                              frame_stride: int = 1) -> List[np.ndarray]:
+    """
+    Segment an object through all frames.
+
+    Args:
+        video_path: path to the video
+        obj_noun: object to segment
+        segmenter: segmentation model
+        frame_stride: process every Nth frame (for speed); skipped frames reuse the previous mask
+
+    Returns:
+        List of boolean masks (one per frame)
+    """
+    cap = cv2.VideoCapture(video_path)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+    masks = []
+    frame_idx = 0
+
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        if frame_idx % frame_stride == 0:
+            # Segment this frame
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_pil = Image.fromarray(frame_rgb)
+            masks.append(segmenter.segment(frame_pil, obj_noun))
+
+            if (frame_idx + 1) % 10 == 0:
+                print(f"    Frame {frame_idx + 1}/{total_frames}...", end='\r')
+        else:
+            # Reuse the previous mask
+            if masks:
+                masks.append(masks[-1])
+            else:
+                masks.append(np.zeros((frame_height, frame_width), dtype=bool))
+
+        frame_idx += 1
+
+    cap.release()
+    print(f"    Segmented {total_frames} frames")
+
+    return masks
+
+
+def dilate_mask(mask: np.ndarray, kernel_size: int = 15) -> np.ndarray:
+    """Dilate a mask for proximity checking."""
+    kernel = np.ones((kernel_size, kernel_size), np.uint8)
+    return cv2.dilate(mask.astype(np.uint8), kernel, iterations=1).astype(bool)
+
+
+def filter_masks_by_proximity(masks: List[np.ndarray], primary_mask: np.ndarray,
+                              dilation: int = 50) -> List[np.ndarray]:
+    """Keep only the parts of each mask that fall near the primary mask."""
+    proximity_region = dilate_mask(primary_mask, dilation)
+    return [mask & proximity_region for mask in masks]
+
+
+def process_video_grey_masks(video_info: Dict, segmenter: SegmentationModel,
+                             trajectory_data: List[Dict] = None):
+    """Generate grey masks for a single video."""
+    video_path = video_info.get("video_path", "")
+    output_dir = Path(video_info.get("output_dir", ""))
+
+    if not output_dir.exists():
+        print("  ⚠️ Output directory not found")
+        return
+
+    # Load required files
+    vlm_analysis_path = output_dir / "vlm_analysis.json"
+    black_mask_path = output_dir / "black_mask.mp4"
+    input_video_path = output_dir / "input_video.mp4"
+
+    if not vlm_analysis_path.exists():
+        print("  ⚠️ vlm_analysis.json not found")
+        return
+
+    if not black_mask_path.exists():
+        print("  ⚠️ black_mask.mp4 not found")
+        return
+
+    if not input_video_path.exists():
+        input_video_path = Path(video_path)
+
+    # Load VLM analysis
+    with open(vlm_analysis_path, 'r') as f:
+        analysis = json.load(f)
+
+    # Get video properties
+    cap = cv2.VideoCapture(str(input_video_path))
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    cap.release()
+
+    # Calculate grid
+    min_grid = video_info.get('min_grid', 8)
+    grid_rows, grid_cols = calculate_square_grid(frame_width, frame_height, min_grid)
+
+    print(f"  Video: {frame_width}x{frame_height}, {total_frames} frames, grid: {grid_rows}x{grid_cols}")
+
+    # Load black mask (first frame, for proximity filtering)
+    black_cap = cv2.VideoCapture(str(black_mask_path))
+    ret, black_mask_frame = black_cap.read()
+    black_cap.release()
+
+    if not ret:
+        print("  ⚠️ Failed to read black mask")
+        return
+
+    if len(black_mask_frame.shape) == 3:
+        black_mask_frame = cv2.cvtColor(black_mask_frame, cv2.COLOR_BGR2GRAY)
+
+    primary_mask = (black_mask_frame == 0)  # 0 = primary object
+
+    # Initialize accumulated masks (one per frame)
+    accumulated_masks = [np.zeros((frame_height, frame_width), dtype=bool) for _ in range(total_frames)]
+
+    # Process affected objects
+    affected_objects = analysis.get('affected_objects', [])
+    print(f"  Processing {len(affected_objects)} affected object(s)...")
+
+    for obj in affected_objects:
+        noun = obj.get('noun', '')
+
+        if not noun:
+            continue
+
+        print(f"  β€’ {noun}")
+
+        # Check whether the user drew a trajectory for this object
+        has_trajectory = False
+        if trajectory_data:
+            for traj in trajectory_data:
+                if traj.get('object_noun', '') == noun and not traj.get('skipped', False):
+                    has_trajectory = True
+                    traj_points = traj.get('trajectory_points', [])
+
+                    print(f"    Using user-drawn trajectory ({len(traj_points)} points)")
+
+                    # Segment the object in first_appears_frame to get its SIZE
+                    first_frame_idx = obj.get('first_appears_frame', 0)
+                    cap = cv2.VideoCapture(str(input_video_path))
+                    cap.set(cv2.CAP_PROP_POS_FRAMES, first_frame_idx)
+                    ret, frame = cap.read()
+                    cap.release()
+
+                    if ret:
+                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                        frame_pil = Image.fromarray(frame_rgb)
+                        obj_mask = segmenter.segment(frame_pil, noun)
+
+                        if obj_mask.any():
+                            obj_width, obj_height = get_object_size(obj_mask)
+                            print(f"    Segmented object (size: {obj_width}x{obj_height} px)")
+
+                            # Apply the object SIZE along the trajectory
+                            traj_masks = apply_object_along_trajectory(
+                                obj_mask, traj_points, total_frames, (frame_height, frame_width)
+                            )
+
+                            # Accumulate
+                            for i in range(total_frames):
+                                accumulated_masks[i] |= traj_masks[i]
+
+                            print(f"    βœ“ Applied object along trajectory across {total_frames} frames")
+                        else:
+                            print("    ⚠️ Segmentation failed, using trajectory grid cells only")
+                            # Fallback: just use the trajectory grid cells
+                            grid_cells = traj.get('trajectory_grid_cells', [])
+                            for row, col in grid_cells:
+                                y1 = int(row * frame_height / grid_rows)
+                                y2 = int((row + 1) * frame_height / grid_rows)
+                                x1 = int(col * frame_width / grid_cols)
+                                x2 = int((col + 1) * frame_width / grid_cols)
+                                for i in range(total_frames):
+                                    accumulated_masks[i][y1:y2, x1:x2] = True
+
+                    break
+
+        # If there is NO user trajectory, segment through ALL frames.
+        # This captures static objects, objects that move during the video, and dynamic effects.
+        if not has_trajectory:
+            print("    Segmenting through ALL frames (captures any movement/changes)...")
+            obj_masks = segment_object_all_frames(str(input_video_path), noun, segmenter, frame_stride=5)
+
+            # Filter by proximity to the primary mask
+            obj_masks_filtered = filter_masks_by_proximity(obj_masks, primary_mask, dilation=50)
+
+            # Accumulate
+            for i in range(len(obj_masks_filtered)):
+                if i < len(accumulated_masks):
+                    accumulated_masks[i] |= obj_masks_filtered[i]
+
+            pixel_count = sum(mask.sum() for mask in obj_masks_filtered)
+            print(f"    βœ“ Segmented across {len(obj_masks_filtered)} frames ({pixel_count} total pixels)")
+
+    # GRIDIFY all accumulated masks
+    print("  Gridifying masks...")
476
+ gridified_masks = gridify_masks(accumulated_masks, grid_rows, grid_cols)
477
+
478
+ # Convert to uint8 (127 = grey, 255 = background)
479
+ grey_masks_uint8 = [np.where(mask, 127, 255).astype(np.uint8) for mask in gridified_masks]
480
+
481
+ # Write video
482
+ print(f" Writing grey_mask.mp4...")
483
+ temp_avi = output_dir / "grey_mask_temp.avi"
484
+ fourcc = cv2.VideoWriter_fourcc(*'FFV1')
485
+ out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (frame_width, frame_height), isColor=False)
486
+
487
+ for mask in grey_masks_uint8:
488
+ out.write(mask)
489
+
490
+ out.release()
491
+
492
+ # Convert to MP4
493
+ grey_mask_mp4 = output_dir / "grey_mask.mp4"
494
+ cmd = [
495
+ 'ffmpeg', '-y', '-i', str(temp_avi),
496
+ '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
497
+ '-pix_fmt', 'yuv444p',
498
+ str(grey_mask_mp4)
499
+ ]
500
+ result = subprocess.run(cmd, capture_output=True, text=True)
+ if result.returncode != 0:
+ print(f" ⚠️ ffmpeg conversion failed:\n{result.stderr}")
501
+ temp_avi.unlink()
502
+
503
+ print(f" βœ“ Saved grey_mask.mp4")
504
+
505
+ # Save debug visualization (first frame)
506
+ debug_vis = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
507
+ debug_vis[gridified_masks[0]] = [0, 255, 0] # Green
508
+ debug_vis[primary_mask] = [0, 0, 255] # Red (cv2.imwrite expects BGR)
509
+ debug_path = output_dir / "debug_grey_mask.jpg"
510
+ cv2.imwrite(str(debug_path), debug_vis)
511
+
512
+
513
+ def main():
514
+ parser = argparse.ArgumentParser(description="Stage 3a: Generate Grey Masks (Corrected)")
515
+ parser.add_argument("--config", required=True, help="Config JSON")
516
+ parser.add_argument("--segmentation-model", default="sam3", choices=["langsam", "sam3"],
517
+ help="Segmentation model")
518
+ args = parser.parse_args()
519
+
520
+ config_path = Path(args.config)
521
+
522
+ # Load config
523
+ with open(config_path, 'r') as f:
524
+ config_data = json.load(f)
525
+
526
+ if isinstance(config_data, list):
527
+ videos = config_data
528
+ elif isinstance(config_data, dict) and "videos" in config_data:
529
+ videos = config_data["videos"]
530
+ else:
531
+ raise ValueError("Invalid config format")
532
+
533
+ # Load trajectory data
534
+ trajectory_path = config_path.parent / f"{config_path.stem}_trajectories.json"
535
+ trajectory_data = None
536
+
537
+ if trajectory_path.exists():
538
+ print(f"Loading trajectory data: {trajectory_path.name}")
539
+ with open(trajectory_path, 'r') as f:
540
+ trajectory_data = json.load(f)
541
+ print(f" Loaded {len(trajectory_data)} trajectory(s)")
542
+
543
+ print(f"\n{'='*70}")
544
+ print(f"Stage 3a: Generate Grey Masks (CORRECTED)")
545
+ print(f"{'='*70}")
546
+ print(f"Videos: {len(videos)}")
547
+ print(f"Segmentation: {args.segmentation_model.upper()}")
548
+ print(f"{'='*70}\n")
549
+
550
+ # Load segmentation model
551
+ segmenter = SegmentationModel(args.segmentation_model)
552
+
553
+ # Process each video
554
+ for i, video_info in enumerate(videos):
555
+ video_path = video_info.get('video_path', '')
556
+ print(f"\n{'─'*70}")
557
+ print(f"Video {i+1}/{len(videos)}: {Path(video_path).parent.name}")
558
+ print(f"{'─'*70}")
559
+
560
+ try:
561
+ process_video_grey_masks(video_info, segmenter, trajectory_data)
562
+ print(f"\nβœ… Video {i+1} complete!")
563
+
564
+ except Exception as e:
565
+ print(f"\n❌ Error: {e}")
566
+ import traceback
567
+ traceback.print_exc()
568
+ continue
569
+
570
+ print(f"\n{'='*70}")
571
+ print(f"βœ… Stage 3a Complete!")
572
+ print(f"{'='*70}\n")
573
+
574
+
575
+ if __name__ == "__main__":
576
+ main()
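The grey-mask value scheme above (127 = affected object, 255 = background) round-trips exactly as long as the encoding stays lossless. A minimal sketch of the conversion used in stage 3a (the array values here are illustrative):

```python
import numpy as np

# Boolean mask -> uint8 grey-mask frame: 127 marks affected-object pixels,
# 255 is background, exactly as written by stage3a before video encoding.
mask = np.array([[True, False],
                 [False, True]])
grey_frame = np.where(mask, 127, 255).astype(np.uint8)

print(grey_frame)   # [[127 255]
                    #  [255 127]]

# Decoding is exact provided no lossy codec touched the pixel values.
recovered = grey_frame == 127
assert (recovered == mask).all()
```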
VLM-MASK-REASONER/stage3b_trajectory_gui.py ADDED
@@ -0,0 +1,432 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Stage 3b: Trajectory Drawing GUI (Simplified - No Segmentation)
4
+
5
+ For objects with needs_trajectory=true, user draws movement paths.
6
+
7
+ Input: Config with vlm_analysis.json in output_dir
8
+ Output: trajectory_data.json with user-drawn paths as grid cells
9
+
10
+ Usage:
11
+ python stage3b_trajectory_gui.py --config more_dyn_2_config_points_absolute.json
12
+ """
13
+
14
+ import os
15
+ import sys
16
+ import json
17
+ import argparse
18
+ import cv2
19
+ import numpy as np
20
+ import tkinter as tk
21
+ from tkinter import ttk, messagebox
22
+ from PIL import Image, ImageTk, ImageDraw, ImageFont
23
+ from pathlib import Path
24
+ from typing import Dict, List, Tuple
25
+
26
+
27
+ def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> Tuple[int, int]:
28
+ """Calculate grid dimensions for square cells"""
29
+ aspect_ratio = width / height
30
+ if width >= height:
31
+ grid_rows = min_grid
32
+ grid_cols = max(min_grid, round(min_grid * aspect_ratio))
33
+ else:
34
+ grid_cols = min_grid
35
+ grid_rows = max(min_grid, round(min_grid / aspect_ratio))
36
+ return grid_rows, grid_cols
37
+
38
+
39
+ def points_to_grid_cells(points: List[Tuple[int, int]], grid_rows: int, grid_cols: int,
40
+ frame_width: int, frame_height: int) -> List[List[int]]:
41
+ """Convert trajectory points to grid cells"""
42
+ cell_width = frame_width / grid_cols
43
+ cell_height = frame_height / grid_rows
44
+
45
+ grid_cells = set()
46
+ for x, y in points:
47
+ col = int(x / cell_width)
48
+ row = int(y / cell_height)
49
+ if 0 <= row < grid_rows and 0 <= col < grid_cols:
50
+ grid_cells.add((row, col))
51
+
52
+ # Sort by row, then col
53
+ return sorted([[r, c] for r, c in grid_cells])
54
+
55
+
56
+ class TrajectoryGUI:
57
+ def __init__(self, root, objects_data: List[Dict]):
58
+ self.root = root
59
+ self.root.title("Stage 3b: Trajectory Drawing")
60
+
61
+ self.objects_data = objects_data # List of {video_info, objects_needing_trajectory}
62
+ self.current_video_idx = 0
63
+ self.current_object_idx = 0
64
+
65
+ # Current state
66
+ self.frame = None
67
+ self.trajectory_points = []
68
+ self.drawing = False
69
+
70
+ # Display
71
+ self.display_scale = 1.0
72
+ self.photo = None
73
+
74
+ # Results storage
75
+ self.all_trajectories = [] # List of trajectories for all videos
76
+
77
+ self.setup_ui()
78
+ self.load_current_object()
79
+
80
+ def setup_ui(self):
81
+ """Setup GUI layout"""
82
+ # Top info
83
+ info_frame = ttk.Frame(self.root)
84
+ info_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
85
+
86
+ self.video_label = ttk.Label(info_frame, text="Video: ", font=("Arial", 10, "bold"))
87
+ self.video_label.pack(side=tk.LEFT, padx=5)
88
+
89
+ self.object_label = ttk.Label(info_frame, text="Object: ", foreground="blue")
90
+ self.object_label.pack(side=tk.LEFT, padx=10)
91
+
92
+ # Instructions
93
+ inst_frame = ttk.LabelFrame(self.root, text="Instructions")
94
+ inst_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
95
+
96
+ ttk.Label(inst_frame, text="1. See the frame where object is visible", foreground="blue").pack(anchor=tk.W, padx=5)
97
+ ttk.Label(inst_frame, text="2. Click and drag to draw trajectory path (RED line)", foreground="red").pack(anchor=tk.W, padx=5)
98
+ ttk.Label(inst_frame, text="3. Draw from object's current position to where it should end up", foreground="orange").pack(anchor=tk.W, padx=5)
99
+ ttk.Label(inst_frame, text="4. Click 'Clear' to restart, 'Save & Next' when done", foreground="green").pack(anchor=tk.W, padx=5)
100
+
101
+ # Canvas
102
+ canvas_frame = ttk.LabelFrame(self.root, text="Draw Trajectory Path")
103
+ canvas_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5, pady=5)
104
+
105
+ self.canvas = tk.Canvas(canvas_frame, width=800, height=600, bg='black', cursor="crosshair")
106
+ self.canvas.pack(fill=tk.BOTH, expand=True)
107
+ self.canvas.bind("<Button-1>", self.on_canvas_click)
108
+ self.canvas.bind("<B1-Motion>", self.on_canvas_drag)
109
+ self.canvas.bind("<ButtonRelease-1>", self.on_canvas_release)
110
+
111
+ # Controls
112
+ controls = ttk.Frame(self.root)
113
+ controls.pack(side=tk.BOTTOM, fill=tk.X, padx=5, pady=5)
114
+
115
+ self.status_label = ttk.Label(controls, text="Draw trajectory path for object", foreground="blue")
116
+ self.status_label.pack(side=tk.TOP, pady=5)
117
+
118
+ button_frame = ttk.Frame(controls)
119
+ button_frame.pack(side=tk.TOP)
120
+
121
+ ttk.Button(button_frame, text="Clear Trajectory", command=self.clear_trajectory).pack(side=tk.LEFT, padx=5)
122
+ ttk.Button(button_frame, text="Skip Object", command=self.skip_object).pack(side=tk.LEFT, padx=5)
123
+ ttk.Button(button_frame, text="Save & Next", command=self.save_and_next).pack(side=tk.LEFT, padx=5)
124
+
125
+ self.progress_label = ttk.Label(controls, text="", font=("Arial", 9))
126
+ self.progress_label.pack(side=tk.TOP, pady=5)
127
+
128
+ def load_current_object(self):
129
+ """Load current object for trajectory drawing"""
130
+ if self.current_video_idx >= len(self.objects_data):
131
+ # All done
132
+ self.finish()
133
+ return
134
+
135
+ data = self.objects_data[self.current_video_idx]
136
+ video_info = data['video_info']
137
+ objects_needing_traj = data['objects']
138
+
139
+ if self.current_object_idx >= len(objects_needing_traj):
140
+ # Done with this video, move to next
141
+ self.current_video_idx += 1
142
+ self.current_object_idx = 0
143
+ self.load_current_object()
144
+ return
145
+
146
+ obj = objects_needing_traj[self.current_object_idx]
147
+ video_path = video_info.get('video_path', '')
148
+ output_dir = Path(video_info.get('output_dir', ''))
149
+
150
+ # Update labels
151
+ self.video_label.config(text=f"Video: {Path(video_path).parent.name}/{Path(video_path).name}")
152
+ self.object_label.config(text=f"Object: {obj['noun']} (will fall/move)")
153
+
154
+ total_objects = sum(len(d['objects']) for d in self.objects_data)
155
+ current_obj_num = sum(len(self.objects_data[i]['objects']) for i in range(self.current_video_idx)) + self.current_object_idx + 1
156
+ self.progress_label.config(text=f"Object {current_obj_num}/{total_objects} across {len(self.objects_data)} video(s)")
157
+
158
+ # Extract frame
159
+ frame_idx = obj.get('first_appears_frame', 0)
160
+ input_video = output_dir / "input_video.mp4"
161
+
162
+ if not input_video.exists():
163
+ input_video = Path(video_path)
164
+
165
+ cap = cv2.VideoCapture(str(input_video))
166
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
167
+ ret, frame = cap.read()
168
+ cap.release()
169
+
170
+ if not ret:
171
+ messagebox.showerror("Error", f"Failed to read frame {frame_idx} from video")
172
+ self.skip_object()
173
+ return
174
+
175
+ self.frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
176
+
177
+ print(f"\n Loaded frame {frame_idx} for '{obj['noun']}'")
178
+
179
+ # Clear trajectory
180
+ self.trajectory_points = []
181
+
182
+ # Calculate grid for this video
183
+ h, w = self.frame.shape[:2]
184
+ min_grid = video_info.get('min_grid', 8)
185
+ self.grid_rows, self.grid_cols = calculate_square_grid(w, h, min_grid)
186
+
187
+ # Display
188
+ self.status_label.config(text="Draw trajectory path (click and drag)", foreground="blue")
189
+ self.display_frame()
190
+
191
+ def display_frame(self):
192
+ """Display frame with trajectory"""
193
+ if self.frame is None:
194
+ return
195
+
196
+ # Create visualization
197
+ vis = self.frame.copy()
198
+ h, w = vis.shape[:2]
199
+
200
+ # Draw trajectory
201
+ if len(self.trajectory_points) > 1:
202
+ for i in range(len(self.trajectory_points) - 1):
203
+ pt1 = self.trajectory_points[i]
204
+ pt2 = self.trajectory_points[i + 1]
205
+ cv2.line(vis, pt1, pt2, (255, 0, 0), 5) # Thicker line for visibility
206
+
207
+ # Draw start point (green) and end point (red)
208
+ if len(self.trajectory_points) > 0:
209
+ start_pt = self.trajectory_points[0]
210
+ end_pt = self.trajectory_points[-1]
211
+ cv2.circle(vis, start_pt, 8, (0, 255, 0), -1) # Green start
212
+ cv2.circle(vis, end_pt, 8, (255, 0, 0), -1) # Red end
213
+
214
+ # Scale for display
215
+ max_width, max_height = 800, 600
216
+ scale_w = max_width / w
217
+ scale_h = max_height / h
218
+ self.display_scale = min(scale_w, scale_h, 1.0)
219
+
220
+ new_w = int(w * self.display_scale)
221
+ new_h = int(h * self.display_scale)
222
+ vis_resized = cv2.resize(vis, (new_w, new_h))
223
+
224
+ # Convert to PIL and display
225
+ pil_img = Image.fromarray(vis_resized)
226
+ self.photo = ImageTk.PhotoImage(pil_img)
227
+ self.canvas.delete("all")
228
+ self.canvas.create_image(0, 0, anchor=tk.NW, image=self.photo)
229
+
230
+ def on_canvas_click(self, event):
231
+ """Start drawing trajectory"""
232
+ # Convert to frame coordinates
233
+ x = int(event.x / self.display_scale)
234
+ y = int(event.y / self.display_scale)
235
+
236
+ self.trajectory_points = [(x, y)]
237
+ self.drawing = True
238
+
239
+ def on_canvas_drag(self, event):
240
+ """Continue drawing trajectory"""
241
+ if not self.drawing:
242
+ return
243
+
244
+ x = int(event.x / self.display_scale)
245
+ y = int(event.y / self.display_scale)
246
+
247
+ # Add point if far enough from last point
248
+ if len(self.trajectory_points) > 0:
249
+ last_x, last_y = self.trajectory_points[-1]
250
+ dist = np.sqrt((x - last_x)**2 + (y - last_y)**2)
251
+ if dist > 5: # Minimum distance between points
252
+ self.trajectory_points.append((x, y))
253
+ self.display_frame()
254
+
255
+ def on_canvas_release(self, event):
256
+ """Finish drawing trajectory"""
257
+ self.drawing = False
258
+ if len(self.trajectory_points) > 0:
259
+ x = int(event.x / self.display_scale)
260
+ y = int(event.y / self.display_scale)
261
+ self.trajectory_points.append((x, y))
262
+ self.display_frame()
263
+
264
+ def clear_trajectory(self):
265
+ """Clear drawn trajectory"""
266
+ self.trajectory_points = []
267
+ self.display_frame()
268
+ self.status_label.config(text="Trajectory cleared. Draw again.", foreground="blue")
269
+
270
+ def skip_object(self):
271
+ """Skip current object without saving trajectory"""
272
+ result = messagebox.askyesno("Skip Object", "Skip this object without drawing trajectory?")
273
+ if not result:
274
+ return
275
+
276
+ # Save empty trajectory
277
+ data = self.objects_data[self.current_video_idx]
278
+ obj = data['objects'][self.current_object_idx]
279
+
280
+ self.all_trajectories.append({
281
+ 'video_path': data['video_info']['video_path'],
282
+ 'object_noun': obj['noun'],
283
+ 'trajectory_points': [],
284
+ 'trajectory_grid_cells': [],
285
+ 'skipped': True
286
+ })
287
+
288
+ self.current_object_idx += 1
289
+ self.load_current_object()
290
+
291
+ def save_and_next(self):
292
+ """Save trajectory and move to next object"""
293
+ if len(self.trajectory_points) < 2:
294
+ messagebox.showwarning("Warning", "Draw a trajectory path first (at least 2 points)")
295
+ return
296
+
297
+ # Convert to grid cells
298
+ data = self.objects_data[self.current_video_idx]
299
+ obj = data['objects'][self.current_object_idx]
300
+
301
+ grid_cells = points_to_grid_cells(
302
+ self.trajectory_points,
303
+ self.grid_rows,
304
+ self.grid_cols,
305
+ self.frame.shape[1],
306
+ self.frame.shape[0]
307
+ )
308
+
309
+ # Save
310
+ self.all_trajectories.append({
311
+ 'video_path': data['video_info']['video_path'],
312
+ 'object_noun': obj['noun'],
313
+ 'first_appears_frame': obj.get('first_appears_frame', 0),
314
+ 'trajectory_points': self.trajectory_points,
315
+ 'trajectory_grid_cells': grid_cells,
316
+ 'grid_rows': self.grid_rows,
317
+ 'grid_cols': self.grid_cols,
318
+ 'skipped': False
319
+ })
320
+
321
+ print(f" βœ“ Saved trajectory for '{obj['noun']}': {len(grid_cells)} grid cells")
322
+
323
+ self.current_object_idx += 1
324
+ self.load_current_object()
325
+
326
+ def finish(self):
327
+ """All objects done"""
328
+ self.status_label.config(text="All trajectories complete!", foreground="green")
329
+ messagebox.showinfo("Complete", "All trajectory drawings complete!\n\nSaving results...")
330
+ self.root.quit()
331
+
332
+
333
+ def find_objects_needing_trajectory(config_path: str) -> List[Dict]:
334
+ """Find all objects that need trajectory input"""
335
+ config_path = Path(config_path)
336
+
337
+ with open(config_path, 'r') as f:
338
+ config_data = json.load(f)
339
+
340
+ if isinstance(config_data, list):
341
+ videos = config_data
342
+ elif isinstance(config_data, dict) and "videos" in config_data:
343
+ videos = config_data["videos"]
344
+ else:
345
+ raise ValueError("Invalid config format")
346
+
347
+ objects_data = []
348
+
349
+ for video_info in videos:
350
+ output_dir = Path(video_info.get('output_dir', ''))
351
+ vlm_analysis_path = output_dir / "vlm_analysis.json"
352
+
353
+ if not vlm_analysis_path.exists():
354
+ print(f" Skipping {output_dir.parent.name}: no vlm_analysis.json")
355
+ continue
356
+
357
+ with open(vlm_analysis_path, 'r') as f:
358
+ analysis = json.load(f)
359
+
360
+ # Find objects with needs_trajectory=true
361
+ objects_needing_traj = [
362
+ obj for obj in analysis.get('affected_objects', [])
363
+ if obj.get('needs_trajectory', False)
364
+ ]
365
+
366
+ if objects_needing_traj:
367
+ objects_data.append({
368
+ 'video_info': video_info,
369
+ 'objects': objects_needing_traj,
370
+ 'output_dir': output_dir
371
+ })
372
+
373
+ return objects_data
374
+
375
+
376
+ def main():
377
+ parser = argparse.ArgumentParser(description="Stage 3b: Trajectory Drawing GUI")
378
+ parser.add_argument("--config", required=True, help="Config JSON")
379
+ args = parser.parse_args()
380
+
381
+ print(f"\n{'='*70}")
382
+ print(f"Stage 3b: Trajectory Drawing GUI")
383
+ print(f"{'='*70}\n")
384
+
385
+ # Find objects needing trajectories
386
+ print("Finding objects that need trajectory input...")
387
+ objects_data = find_objects_needing_trajectory(args.config)
388
+
389
+ if not objects_data:
390
+ print("\nβœ… No objects need trajectory input!")
391
+ print("All objects are either stationary or visual artifacts.")
392
+ print("Proceeding to Stage 3a for mask generation...")
393
+ return
394
+
395
+ total_objects = sum(len(d['objects']) for d in objects_data)
396
+ print(f"\nFound {total_objects} object(s) needing trajectories across {len(objects_data)} video(s):")
397
+ for d in objects_data:
398
+ video_name = Path(d['video_info']['video_path']).parent.name
399
+ print(f" β€’ {video_name}: {', '.join(obj['noun'] for obj in d['objects'])}")
400
+
401
+ # Launch GUI
402
+ print("\nLaunching trajectory drawing GUI...")
403
+ print("Instructions:")
404
+ print(" 1. See the frame where the object is visible")
405
+ print(" 2. Click and drag to draw trajectory path (RED line)")
406
+ print(" 3. Draw from object's current position to where it should end up")
407
+ print(" 4. Click 'Save & Next' when done with each object")
408
+ print("")
409
+
410
+ root = tk.Tk()
411
+ root.geometry("900x800")
412
+ gui = TrajectoryGUI(root, objects_data)
413
+ root.mainloop()
414
+
415
+ # Save trajectories
416
+ config_path = Path(args.config)
417
+ output_path = config_path.parent / f"{config_path.stem}_trajectories.json"
418
+
419
+ with open(output_path, 'w') as f:
420
+ json.dump(gui.all_trajectories, f, indent=2)
421
+
422
+ print(f"\n{'='*70}")
423
+ print(f"βœ… Stage 3b Complete!")
424
+ print(f"{'='*70}")
425
+ print(f"Saved trajectories to: {output_path}")
426
+ print(f"Total trajectories: {len(gui.all_trajectories)}")
427
+ print(f"\nNext: Run Stage 3a to generate grey masks (includes trajectories)")
428
+ print(f"{'='*70}\n")
429
+
430
+
431
+ if __name__ == "__main__":
432
+ main()
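The two grid helpers at the top of stage3b are pure functions and easy to sanity-check in isolation. A self-contained sketch with the same logic as the file above (the 1920x1080 example values are illustrative):

```python
from typing import List, Tuple

def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> Tuple[int, int]:
    # Short side gets min_grid cells; the long side scales by aspect ratio.
    aspect_ratio = width / height
    if width >= height:
        return min_grid, max(min_grid, round(min_grid * aspect_ratio))
    return max(min_grid, round(min_grid / aspect_ratio)), min_grid

def points_to_grid_cells(points, grid_rows, grid_cols, frame_width, frame_height):
    # Map pixel coordinates to (row, col) grid cells, deduplicated and sorted.
    cell_width = frame_width / grid_cols
    cell_height = frame_height / grid_rows
    cells = set()
    for x, y in points:
        col = int(x / cell_width)
        row = int(y / cell_height)
        if 0 <= row < grid_rows and 0 <= col < grid_cols:
            cells.add((row, col))
    return sorted([r, c] for r, c in cells)

rows, cols = calculate_square_grid(1920, 1080)  # 16:9 frame -> 8 x 14 grid
cells = points_to_grid_cells([(0, 0), (1919, 1079)], rows, cols, 1920, 1080)
print(rows, cols, cells)   # 8 14 [[0, 0], [7, 13]]
```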
VLM-MASK-REASONER/stage4_combine_masks.py ADDED
@@ -0,0 +1,241 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Stage 4: Combine Black and Grey Masks into Tri/Quad Mask
4
+
5
+ Combines the black mask (primary object) and grey masks (affected objects)
6
+ into a single tri-mask or quad-mask video.
7
+
8
+ Mask values:
9
+ - 0: Primary object (from black mask)
10
+ - 63: Overlap of primary and affected objects
11
+ - 127: Affected objects only (from grey masks)
12
+ - 255: Background (keep)
13
+ """
14
+
15
+ import json
16
+ import argparse
17
+ from pathlib import Path
18
+ import cv2
19
+ import numpy as np
20
+ from tqdm import tqdm
21
+
22
+
23
+ def combine_masks(black_frame, grey_frame):
24
+ """
25
+ Combine black and grey mask frames.
26
+
27
+ Rules:
28
+ - black=0, grey=255 β†’ 0 (primary object only)
29
+ - black=255, grey=127 β†’ 127 (affected object only)
30
+ - black=0, grey=127 β†’ 63 (overlap)
31
+ - black=255, grey=255 β†’ 255 (background)
32
+
33
+ Args:
34
+ black_frame: Frame from black_mask.mp4 (0=object, 255=background)
35
+ grey_frame: Frame from grey_mask.mp4 (127=object, 255=background)
36
+
37
+ Returns:
38
+ Combined mask frame
39
+ """
40
+ # Initialize with background (255)
41
+ combined = np.full_like(black_frame, 255, dtype=np.uint8)
42
+
43
+ # Primary object only (black=0, grey=255)
44
+ primary_only = (black_frame == 0) & (grey_frame == 255)
45
+ combined[primary_only] = 0
46
+
47
+ # Affected object only (black=255, grey=127)
48
+ affected_only = (black_frame == 255) & (grey_frame == 127)
49
+ combined[affected_only] = 127
50
+
51
+ # Overlap (black=0, grey=127)
52
+ overlap = (black_frame == 0) & (grey_frame == 127)
53
+ combined[overlap] = 63
54
+
55
+ return combined
56
+
57
+
58
+ def process_video(black_mask_path: Path, grey_mask_path: Path, output_path: Path):
59
+ """Combine black and grey mask videos into trimask/quadmask"""
60
+ import subprocess
61
+
62
+ print(f" Loading black mask: {black_mask_path.name}")
63
+ black_cap = cv2.VideoCapture(str(black_mask_path))
64
+
65
+ print(f" Loading grey mask: {grey_mask_path.name}")
66
+ grey_cap = cv2.VideoCapture(str(grey_mask_path))
67
+
68
+ # Get video properties
69
+ fps = black_cap.get(cv2.CAP_PROP_FPS)
70
+ width = int(black_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
71
+ height = int(black_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
72
+ total_frames = int(black_cap.get(cv2.CAP_PROP_FRAME_COUNT))
73
+
74
+ # Check grey mask has same properties
75
+ grey_total_frames = int(grey_cap.get(cv2.CAP_PROP_FRAME_COUNT))
76
+ if total_frames != grey_total_frames:
77
+ print(f" ⚠️ Warning: Frame count mismatch (black: {total_frames}, grey: {grey_total_frames})")
78
+ total_frames = min(total_frames, grey_total_frames)
79
+
80
+ print(f" Video: {width}x{height} @ {fps:.2f}fps, {total_frames} frames")
81
+ print(f" Combining masks...")
82
+
83
+ # Collect all frames first
84
+ combined_frames = []
85
+
86
+ # Process frames
87
+ for frame_idx in tqdm(range(total_frames), desc=" Combining"):
88
+ ret_black, black_frame = black_cap.read()
89
+ ret_grey, grey_frame = grey_cap.read()
90
+
91
+ if not ret_black or not ret_grey:
92
+ print(f" ⚠️ Warning: Could not read frame {frame_idx}")
93
+ break
94
+
95
+ # Convert to grayscale if needed
96
+ if len(black_frame.shape) == 3:
97
+ black_frame = cv2.cvtColor(black_frame, cv2.COLOR_BGR2GRAY)
98
+ if len(grey_frame.shape) == 3:
99
+ grey_frame = cv2.cvtColor(grey_frame, cv2.COLOR_BGR2GRAY)
100
+
101
+ # Combine
102
+ combined_frame = combine_masks(black_frame, grey_frame)
103
+ combined_frames.append(combined_frame)
104
+
105
+ # Cleanup
106
+ black_cap.release()
107
+ grey_cap.release()
108
+
109
+ # On the first frame, clamp near-grey values (100–135) to 255 (background).
110
+ # Video codecs can introduce slight luma drift around 127; this ensures no
111
+ # grey pixels survive into the final quadmask on frame 0.
112
+ if combined_frames:
113
+ f0 = combined_frames[0]
114
+ grey_pixels = (f0 > 100) & (f0 < 135)
115
+ f0[grey_pixels] = 255
116
+ combined_frames[0] = f0
117
+
118
+ # Write using LOSSLESS encoding to preserve exact mask values
119
+ print(f" Writing lossless video...")
120
+
121
+ # Write temp AVI with FFV1 codec (lossless)
122
+ temp_avi = output_path.with_suffix('.avi')
123
+ fourcc = cv2.VideoWriter_fourcc(*'FFV1')
124
+ out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (width, height), isColor=False)
125
+
126
+ for frame in combined_frames:
127
+ out.write(frame)
128
+ out.release()
129
+
130
+ # Convert to LOSSLESS H.264 (qp=0, yuv444p to preserve all luma values)
131
+ cmd = [
132
+ 'ffmpeg', '-y', '-i', str(temp_avi),
133
+ '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
134
+ '-pix_fmt', 'yuv444p',
135
+ '-r', str(fps),
136
+ str(output_path)
137
+ ]
138
+ result = subprocess.run(cmd, capture_output=True, text=True)
139
+ if result.returncode != 0:
140
+ print(f" ⚠️ Warning: ffmpeg conversion had issues")
141
+ print(result.stderr)
142
+
143
+ # Clean up temp file
144
+ temp_avi.unlink()
145
+
146
+ print(f" βœ“ Saved: {output_path.name}")
147
+
148
+
149
+ def process_config(config_path: str):
150
+ """Process all videos in config"""
151
+ config_path = Path(config_path)
152
+
153
+ # Load config
154
+ with open(config_path, 'r') as f:
155
+ config_data = json.load(f)
156
+
157
+ # Handle both formats
158
+ if isinstance(config_data, list):
159
+ videos = config_data
160
+ elif isinstance(config_data, dict) and "videos" in config_data:
161
+ videos = config_data["videos"]
162
+ else:
163
+ raise ValueError("Config must be a list or have 'videos' key")
164
+
165
+ print(f"\n{'='*70}")
166
+ print(f"Stage 4: Combine Masks into Tri/Quad Mask")
167
+ print(f"{'='*70}")
168
+ print(f"Config: {config_path.name}")
169
+ print(f"Videos: {len(videos)}")
170
+ print(f"{'='*70}\n")
171
+
172
+ # Process each video
173
+ success_count = 0
174
+ for i, video_info in enumerate(videos):
175
+ video_path = video_info.get("video_path", "")
176
+ output_dir = video_info.get("output_dir", "")
177
+
178
+ print(f"\n{'─'*70}")
179
+ print(f"Video {i+1}/{len(videos)}: {Path(video_path).name}")
180
+ print(f"{'─'*70}")
181
+
182
+ if not output_dir:
183
+ print(f" ⚠️ No output_dir specified, skipping")
184
+ continue
185
+
186
+ output_dir = Path(output_dir)
187
+ if not output_dir.exists():
188
+ print(f" ⚠️ Output directory not found: {output_dir}")
189
+ continue
190
+
191
+ # Check for required masks
192
+ black_mask_path = output_dir / "black_mask.mp4"
193
+ grey_mask_path = output_dir / "grey_mask.mp4"
194
+
195
+ if not black_mask_path.exists():
196
+ print(f" ⚠️ black_mask.mp4 not found, skipping")
197
+ continue
198
+
199
+ if not grey_mask_path.exists():
200
+ print(f" ⚠️ grey_mask.mp4 not found, skipping")
201
+ continue
202
+
203
+ # Output path
204
+ output_path = output_dir / "quadmask_0.mp4"
205
+
206
+ try:
207
+ process_video(black_mask_path, grey_mask_path, output_path)
208
+ success_count += 1
209
+ print(f"\nβœ… Video {i+1} complete!")
210
+ except Exception as e:
211
+ print(f"\n❌ Error processing video {i+1}: {e}")
212
+ import traceback
213
+ traceback.print_exc()
214
+ continue
215
+
216
+ # Summary
217
+ print(f"\n{'='*70}")
218
+ print(f"Stage 4 Complete!")
219
+ print(f"{'='*70}")
220
+ print(f"Successful: {success_count}/{len(videos)}")
221
+ print(f"Failed: {len(videos) - success_count}/{len(videos)}")
222
+ print(f"{'='*70}\n")
223
+
224
+
225
+ def main():
226
+ parser = argparse.ArgumentParser(
227
+ description="Stage 4: Combine black and grey masks into tri/quad mask"
228
+ )
229
+ parser.add_argument(
230
+ "--config",
231
+ type=str,
232
+ required=True,
233
+ help="Path to config JSON (with output_dir for each video)"
234
+ )
235
+
236
+ args = parser.parse_args()
237
+ process_config(args.config)
238
+
239
+
240
+ if __name__ == "__main__":
241
+ main()
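The combination rule in `combine_masks` is a pure per-pixel function of the two input masks, so it can be exercised on a tiny array. A self-contained sketch of the same logic (the example arrays are illustrative):

```python
import numpy as np

def combine_masks(black_frame, grey_frame):
    # Same rules as stage4_combine_masks.py:
    # 0 = primary object, 63 = overlap, 127 = affected object, 255 = background.
    combined = np.full_like(black_frame, 255, dtype=np.uint8)
    combined[(black_frame == 0) & (grey_frame == 255)] = 0
    combined[(black_frame == 255) & (grey_frame == 127)] = 127
    combined[(black_frame == 0) & (grey_frame == 127)] = 63
    return combined

black = np.array([[0, 0], [255, 255]], dtype=np.uint8)   # 0 = primary object
grey = np.array([[255, 127], [127, 255]], dtype=np.uint8)  # 127 = affected object
quad = combine_masks(black, grey)
print(quad)   # [[  0  63]
              #  [127 255]]
```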
VLM-MASK-REASONER/test_gemini_video.py ADDED
@@ -0,0 +1,98 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick Gemini API smoke test β€” samples a few frames from a video and asks a
4
+ simple question. Use this to verify your API key works before running the
5
+ full pipeline.
6
+
7
+ Usage:
8
+ export GEMINI_API_KEY="your_aistudio_key"
9
+ python test_gemini_video.py --video path/to/video.mp4
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import base64
15
+ import argparse
16
+ import cv2
17
+ import numpy as np
18
+ from pathlib import Path
19
+
20
+ import openai
21
+
22
+
23
+ FREE_TIER_MODEL = "gemini-2.0-flash"
24
+ NUM_FRAMES = 4 # keep low for free tier rate limits
25
+
26
+
27
+ def sample_frames(video_path: str, n: int = NUM_FRAMES):
28
+ """Sample n evenly-spaced frames from the video, return as base64 data URLs."""
29
+ cap = cv2.VideoCapture(video_path)
30
+ total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
31
+ n = max(2, min(n, total)) # need >= 2 samples to avoid divide-by-zero; cap at frame count
+ indices = [int(i * (total - 1) / (n - 1)) for i in range(n)]
32
+
33
+ data_urls = []
34
+ for idx in indices:
35
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
36
+ ret, frame = cap.read()
37
+ if not ret:
38
+ continue
39
+ _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
40
+ b64 = base64.b64encode(buf).decode("utf-8")
41
+ data_urls.append(f"data:image/jpeg;base64,{b64}")
42
+
43
+ cap.release()
44
+ return data_urls
45
+
46
+
47
+ def main():
48
+ parser = argparse.ArgumentParser(description="Gemini API smoke test with video frames")
49
+ parser.add_argument("--video", required=True, help="Path to a video file")
50
+ parser.add_argument("--model", default=FREE_TIER_MODEL, help="Gemini model to use")
51
+ parser.add_argument("--frames", type=int, default=NUM_FRAMES, help="Number of frames to sample")
52
+ args = parser.parse_args()
53
+
54
+ api_key = os.environ.get("GEMINI_API_KEY")
55
+ if not api_key:
56
+ print("ERROR: GEMINI_API_KEY environment variable not set")
57
+ sys.exit(1)
58
+
59
+ video_path = Path(args.video)
60
+ if not video_path.exists():
61
+ print(f"ERROR: Video not found: {video_path}")
62
+ sys.exit(1)
63
+
64
+ print(f"Video: {video_path.name}")
65
+ print(f"Model: {args.model}")
66
+ print(f"Frames: {args.frames}")
67
+ print()
68
+
69
+ print(f"Sampling {args.frames} frames...")
70
+ data_urls = sample_frames(str(video_path), args.frames)
71
+ print(f"Got {len(data_urls)} frames. Sending to Gemini...")
72
+
73
+ client = openai.OpenAI(
74
+ api_key=api_key,
75
+ base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
76
+ )
77
+
78
+ content = [
79
+ {"type": "image_url", "image_url": {"url": url}} for url in data_urls
80
+ ]
81
+ content.append({
82
+ "type": "text",
83
+ "text": "These are evenly-spaced frames from a short video. In one sentence, describe what is happening in the video."
84
+ })
85
+
86
+ response = client.chat.completions.create(
87
+ model=args.model,
88
+ messages=[{"role": "user", "content": content}],
89
+ )
90
+
91
+ print("\n--- Gemini response ---")
92
+ print(response.choices[0].message.content)
93
+ print("-----------------------")
94
+ print("\nβœ… API key works!")
95
+
96
+
97
+ if __name__ == "__main__":
98
+ main()
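
The smoke test above inlines each sampled frame as a base64 data URL. The encoding step can be exercised standalone (a minimal sketch, independent of OpenCV; `to_data_url` is an illustrative helper name, not part of the script):

```python
import base64


def to_data_url(jpeg_bytes: bytes) -> str:
    """Wrap raw JPEG bytes in a data URL accepted by OpenAI-style image_url content."""
    b64 = base64.b64encode(jpeg_bytes).decode("utf-8")
    return f"data:image/jpeg;base64,{b64}"


# The Gemini OpenAI-compat endpoint accepts these in `image_url` content parts.
print(to_data_url(b"abc"))  # data:image/jpeg;base64,YWJj
```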
app.py CHANGED
@@ -1,18 +1,546 @@
  import gradio as gr
  import spaces
 
  @spaces.GPU
- def dudu():
-     pass
 
- def greet(name, intensity):
-     return "Hello, " + name + "!" * int(intensity)
 
- demo = gr.Interface(
-     fn=greet,
-     inputs=["text", "slider"],
-     outputs=["text"],
-     api_name="predict"
- )
 
- demo.launch()
+ """
+ VOID VLM-Mask-Reasoner — Quadmask Generation Demo
+ Generates 4-level semantic masks for interaction-aware video inpainting.
+
+ Pipeline from https://github.com/Netflix/void-model:
+   Stage 1: SAM2 segmentation → black mask (transformers Sam2Model)
+   Stage 2: Gemini VLM scene analysis → affected objects JSON (repo code)
+   Stage 3: SAM3 text-prompted segmentation → grey mask (transformers Sam3Model)
+   Stage 4: Combine black + grey → quadmask (0/63/127/255) (repo code)
+ """
+
+ import os
+ import sys
+ import json
+ import tempfile
+ import shutil
+ import subprocess
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ import torch
  import gradio as gr
  import spaces
+ import imageio
+ from PIL import Image, ImageDraw
+ from huggingface_hub import hf_hub_download
+ import openai
+
+ # ── Add repo modules to path ─────────────────────────────────────────────────
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "VLM-MASK-REASONER"))
+
+ # ── Repo imports: Stage 2 (VLM) and Stage 4 (combine) ────────────────────────
+ from stage2_vlm_analysis import (
+     process_video as vlm_process_video,
+     calculate_square_grid,
+ )
+ from stage4_combine_masks import process_video as combine_process_video
+ # Stage 3 helpers (grid logic, mask combination — not the SegmentationModel)
+ from stage3a_generate_grey_masks_v2 import (
+     calculate_square_grid as calc_grid_3a,
+     gridify_masks,
+     filter_masks_by_proximity,
+     segment_object_all_frames as _repo_segment_all_frames,
+     process_video_grey_masks,
+ )
+
+ # ── Constants ─────────────────────────────────────────────────────────────────
+ SAM2_MODEL_ID = "facebook/sam2.1-hiera-large"
+ SAM3_MODEL_ID = "jetjodh/sam3"
+ DEFAULT_VLM_MODEL = "gemini-3-flash-preview"
+ MAX_FRAMES = 197
+ FPS_DEFAULT = 12
+ FRAME_STRIDE = 4  # Process every Nth frame for SAM2 tracking
+
+ # ── Load transformers SAM2 (video model with propagation support) ─────────────
+ print("Loading SAM2 video model (transformers)...")
+ from transformers import Sam2VideoModel, Sam2VideoProcessor
+ from transformers.models.sam2_video.modeling_sam2_video import Sam2VideoInferenceSession
+ sam2_model = Sam2VideoModel.from_pretrained(SAM2_MODEL_ID).to("cuda")
+ sam2_processor = Sam2VideoProcessor.from_pretrained(SAM2_MODEL_ID)
+ print("SAM2 video model ready.")
+
+ # ── Load transformers SAM3 ───────────────────────────────────────────────────
+ print("Loading SAM3 model (transformers)...")
+ from transformers import Sam3Model, Sam3Processor
+ sam3_model = Sam3Model.from_pretrained(SAM3_MODEL_ID).to("cuda")
+ sam3_processor = Sam3Processor.from_pretrained(SAM3_MODEL_ID)
+ print("SAM3 ready.")
+
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ # STAGE 1: SAM2 VIDEO SEGMENTATION (transformers Sam2VideoModel)
+ # Uses proper video propagation with memory — matches repo's propagate_in_video
+ # ══════════════════════════════════════════════════════════════════════════════
+
+ def stage1_segment_video(frames: list, points: list, **kwargs) -> list:
+     """Segment the primary object across all video frames using SAM2 video propagation.
+     Matches repo: point prompts + bounding box on frame 0, propagate through video.
+     Returns a list of uint8 masks (0=object, 255=background)."""
+     total = len(frames)
+     h, w = frames[0].shape[:2]
+
+     # Preprocess all frames
+     pil_frames = [Image.fromarray(f) for f in frames]
+     inputs = sam2_processor(images=pil_frames, return_tensors="pt").to(sam2_model.device)
+
+     # Create inference session with all frames
+     session = Sam2VideoInferenceSession(
+         video=inputs["pixel_values"],
+         video_height=h,
+         video_width=w,
+         inference_device=sam2_model.device,
+         inference_state_device=sam2_model.device,
+         dtype=torch.float32,
+     )
+
+     # Add point prompts + bounding box on frame 0 via the processor
+     # (handles normalization, object registration, and obj_with_new_inputs)
+     pts = np.array(points, dtype=np.float32)
+     x_min, x_max = pts[:, 0].min(), pts[:, 0].max()
+     y_min, y_max = pts[:, 1].min(), pts[:, 1].max()
+     x_margin = max((x_max - x_min) * 0.1, 10)
+     y_margin = max((y_max - y_min) * 0.1, 10)
+     box = [
+         max(0, x_min - x_margin),
+         max(0, y_min - y_margin),
+         min(w, x_max + x_margin),
+         min(h, y_max + y_margin),
+     ]
+
+     sam2_processor.process_new_points_or_boxes_for_video_frame(
+         inference_session=session,
+         frame_idx=0,
+         obj_ids=[1],
+         input_points=[[[[float(p[0]), float(p[1])] for p in points]]],
+         input_labels=[[[1] * len(points)]],
+         input_boxes=[[[float(box[0]), float(box[1]), float(box[2]), float(box[3])]]],
+     )
+
+     # Run forward on the prompted frame first (populates cond_frame_outputs)
+     with torch.no_grad():
+         sam2_model(session, frame_idx=0)
+
+     # Propagate through all frames (matches repo's propagate_in_video)
+     video_segments = {}
+     with torch.no_grad():
+         for output in sam2_model.propagate_in_video_iterator(session):
+             frame_idx = output.frame_idx
+             # pred_masks shape varies — get the raw logits and resize to original
+             mask_logits = output.pred_masks[0].cpu().float()  # first object
+             # Ensure 4D for interpolation: (1, 1, H_model, W_model)
+             while mask_logits.dim() < 4:
+                 mask_logits = mask_logits.unsqueeze(0)
+             mask_resized = torch.nn.functional.interpolate(
+                 mask_logits, size=(h, w), mode="bilinear", align_corners=False
+             )
+             mask = (mask_resized.squeeze() > 0.0).numpy()
+             video_segments[frame_idx] = mask
+
+     # Convert to uint8 masks (0=object, 255=background)
+     all_masks = []
+     for idx in range(total):
+         if idx in video_segments:
+             mask_bool = video_segments[idx]
+         else:
+             nearest = min(video_segments.keys(), key=lambda k: abs(k - idx))
+             mask_bool = video_segments[nearest]
+         mask_uint8 = np.where(mask_bool, 0, 255).astype(np.uint8)
+         all_masks.append(mask_uint8)
+
+     return all_masks
+
+
+ def write_mask_video(masks: list, fps: float, output_path: str):
+     """Write a list of uint8 grayscale masks to a lossless MP4."""
+     h, w = masks[0].shape[:2]
+     temp_avi = str(Path(output_path).with_suffix('.avi'))
+     fourcc = cv2.VideoWriter_fourcc(*'FFV1')
+     out = cv2.VideoWriter(temp_avi, fourcc, fps, (w, h), isColor=False)
+     for mask in masks:
+         out.write(mask)
+     out.release()
+
+     cmd = [
+         'ffmpeg', '-y', '-i', temp_avi,
+         '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
+         '-pix_fmt', 'yuv444p', str(output_path),
+     ]
+     subprocess.run(cmd, capture_output=True)
+     if os.path.exists(temp_avi):
+         os.unlink(temp_avi)
+
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ # STAGE 3: SAM3 TEXT-PROMPTED SEGMENTATION (transformers)
+ # — Drop-in replacement for repo's SegmentationModel.segment()
+ # ══════════════════════════════════════════════════════════════════════════════
+
+ class TransformersSam3Segmenter:
+     """Matches the interface of the repo's SegmentationModel for stage3a."""
+     model_type = "sam3"
+
+     def segment(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
+         """Segment object by text prompt. Returns boolean mask."""
+         h, w = image_pil.height, image_pil.width
+         union = np.zeros((h, w), dtype=bool)
+
+         try:
+             inputs = sam3_processor(
+                 images=image_pil, text=prompt, return_tensors="pt"
+             ).to(sam3_model.device)
+
+             with torch.no_grad():
+                 outputs = sam3_model(**inputs)
+
+             results = sam3_processor.post_process_instance_segmentation(
+                 outputs,
+                 threshold=0.3,
+                 mask_threshold=0.5,
+                 target_sizes=inputs.get("original_sizes").tolist(),
+             )[0]
+
+             masks = results.get("masks")
+             if masks is not None and len(masks) > 0:
+                 if torch.is_tensor(masks):
+                     masks = masks.cpu().numpy()
+                 if masks.ndim == 2:
+                     union = masks.astype(bool)
+                 elif masks.ndim == 3:
+                     union = masks.any(axis=0).astype(bool)
+                 elif masks.ndim == 4:
+                     union = masks.any(axis=(0, 1)).astype(bool)
+         except Exception as e:
+             print(f"  Warning: SAM3 segmentation failed for '{prompt}': {e}")
+
+         return union
+
+
+ seg_model = TransformersSam3Segmenter()
+
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ # HELPERS
+ # ══════════════════════════════════════════════════════════════════════════════
+
+ def extract_frames(video_path: str, max_frames: int = MAX_FRAMES):
+     """Extract frames from video. Returns (frames_rgb_list, fps)."""
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT
+     frames = []
+     while len(frames) < max_frames:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+     cap.release()
+     return frames, fps
+
+
+ def draw_points_on_image(image: np.ndarray, points: list, radius: int = 6) -> np.ndarray:
+     pil_img = Image.fromarray(image.copy())
+     draw = ImageDraw.Draw(pil_img)
+     for i, (x, y) in enumerate(points):
+         r = radius
+         draw.ellipse([x - r, y - r, x + r, y + r], fill="red", outline="white", width=2)
+         draw.text((x + r + 2, y - r), str(i + 1), fill="white")
+     return np.array(pil_img)
+
+
+ def frames_to_video(frames: list, fps: float) -> str:
+     tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+     tmp_path = tmp.name
+     tmp.close()
+     writer = imageio.get_writer(tmp_path, fps=fps, codec='libx264',
+                                 output_params=['-crf', '18', '-pix_fmt', 'yuv420p'])
+     for frame in frames:
+         writer.append_data(frame)
+     writer.close()
+     return tmp_path
+
+
+ def create_quadmask_visualization(video_path: str, quadmask_path: str) -> str:
+     cap_vid = cv2.VideoCapture(video_path)
+     cap_qm = cv2.VideoCapture(quadmask_path)
+     fps = cap_vid.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT
+
+     vis_frames = []
+     while True:
+         ret_v, frame = cap_vid.read()
+         ret_q, qm_frame = cap_qm.read()
+         if not ret_v or not ret_q:
+             break
+         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         qm = cv2.cvtColor(qm_frame, cv2.COLOR_BGR2GRAY) if len(qm_frame.shape) == 3 else qm_frame
+
+         # Snap compression noise back to the 4 canonical levels
+         qm = np.where(qm <= 31, 0, qm)
+         qm = np.where((qm > 31) & (qm <= 95), 63, qm)
+         qm = np.where((qm > 95) & (qm <= 191), 127, qm)
+         qm = np.where(qm > 191, 255, qm)
+
+         overlay = frame_rgb.copy()
+         overlay[qm == 0] = [255, 50, 50]
+         overlay[qm == 63] = [255, 200, 0]
+         overlay[qm == 127] = [50, 255, 50]
+         result = cv2.addWeighted(frame_rgb, 0.5, overlay, 0.5, 0)
+         result[qm == 255] = frame_rgb[qm == 255]
+         vis_frames.append(result)
+
+     cap_vid.release()
+     cap_qm.release()
+     return frames_to_video(vis_frames, fps) if vis_frames else None
+
+
+ def create_quadmask_color_video(quadmask_path: str) -> str:
+     cap = cv2.VideoCapture(quadmask_path)
+     fps = cap.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT
+     color_frames = []
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         qm = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) if len(frame.shape) == 3 else frame
+         qm = np.where(qm <= 31, 0, qm)
+         qm = np.where((qm > 31) & (qm <= 95), 63, qm)
+         qm = np.where((qm > 95) & (qm <= 191), 127, qm)
+         qm = np.where(qm > 191, 255, qm)
+         h, w = qm.shape
+         color = np.full((h, w, 3), 255, dtype=np.uint8)
+         color[qm == 0] = [0, 0, 0]
+         color[qm == 63] = [80, 80, 80]
+         color[qm == 127] = [160, 160, 160]
+         color_frames.append(color)
+     cap.release()
+     return frames_to_video(color_frames, fps) if color_frames else None
+
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ # MAIN PIPELINE
+ # ══════════════════════════════════════════════════════════════════════════════
 
  @spaces.GPU
+ def run_pipeline(video_path: str, points_json: str, instruction: str,
+                  progress=gr.Progress(track_tqdm=False)):
+     """Run the full VLM-Mask-Reasoner pipeline."""
+     if not video_path:
+         raise gr.Error("Please upload a video.")
+     if not points_json or points_json == "[]":
+         raise gr.Error("Please click on the image to select at least one point on the primary object.")
+     if not instruction.strip():
+         raise gr.Error("Please enter an edit instruction.")
 
+     points = json.loads(points_json)
+     if len(points) == 0:
+         raise gr.Error("Please select at least one point on the primary object.")
 
+     api_key = os.environ.get("GEMINI_API_KEY", "")
+
+     # Create temp output directory
+     output_dir = Path(tempfile.mkdtemp(prefix="void_quadmask_"))
+     input_video_path = output_dir / "input_video.mp4"
+     shutil.copy2(video_path, input_video_path)
+
+     # ── Stage 1: SAM2 Segmentation ──────────────────────────────────────────
+     progress(0.05, desc="Stage 1: SAM2 segmentation...")
+     frames, fps = extract_frames(str(input_video_path))
+     if len(frames) < 2:
+         raise gr.Error("Video must have at least 2 frames.")
+
+     black_masks = stage1_segment_video(frames, points, stride=FRAME_STRIDE)
+     black_mask_path = output_dir / "black_mask.mp4"
+     write_mask_video(black_masks, fps, str(black_mask_path))
+
+     # Save first frame for VLM analysis
+     first_frame_path = output_dir / "first_frame.jpg"
+     cv2.imwrite(str(first_frame_path), cv2.cvtColor(frames[0], cv2.COLOR_RGB2BGR))
+
+     # Save segmentation metadata (Stage 2 expects this)
+     seg_info = {
+         "total_frames": len(frames),
+         "frame_width": frames[0].shape[1],
+         "frame_height": frames[0].shape[0],
+         "fps": fps,
+         "video_path": str(input_video_path),
+         "instruction": instruction,
+         "primary_points_by_frame": {"0": points},
+         "first_appears_frame": 0,
+     }
+     with open(output_dir / "segmentation_info.json", 'w') as f:
+         json.dump(seg_info, f, indent=2)
+
+     progress(0.3, desc="Stage 1 complete.")
+
+     # ── Stage 2: VLM Analysis (repo code) ───────────────────────────────────
+     analysis = None
+     if api_key:
+         progress(0.35, desc="Stage 2: VLM analysis (calling Gemini)...")
+         try:
+             video_info = {
+                 "video_path": str(input_video_path),
+                 "instruction": instruction,
+                 "output_dir": str(output_dir),
+                 "multi_frame_grids": True,
+             }
+             client = openai.OpenAI(
+                 api_key=api_key,
+                 base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+             )
+             analysis = vlm_process_video(video_info, client, DEFAULT_VLM_MODEL)
+             progress(0.55, desc="Stage 2 complete.")
+         except Exception as e:
+             gr.Warning(f"VLM analysis failed: {e}. Generating binary mask only.")
+             analysis = None
+     else:
+         gr.Warning("No GEMINI_API_KEY set. Generating binary mask only (no VLM analysis).")
+
+     # ── Stage 3: Grey Mask Generation (repo logic + transformers SAM3) ──────
+     grey_mask_path = output_dir / "grey_mask.mp4"
+     vlm_analysis_path = output_dir / "vlm_analysis.json"
+
+     if analysis and vlm_analysis_path.exists():
+         progress(0.6, desc="Stage 3: Generating grey masks (SAM3 segmentation)...")
+         try:
+             video_info_3 = {
+                 "video_path": str(input_video_path),
+                 "output_dir": str(output_dir),
+                 "min_grid": 8,
+             }
+             # Uses the repo's process_video_grey_masks with our TransformersSam3Segmenter
+             process_video_grey_masks(video_info_3, seg_model)
+             progress(0.8, desc="Stage 3 complete.")
+         except Exception as e:
+             gr.Warning(f"Stage 3 failed: {e}. Generating binary mask only.")
+
+     # ── Stage 4: Combine into Quadmask (repo code) ──────────────────────────
+     quadmask_path = output_dir / "quadmask_0.mp4"
+     if grey_mask_path.exists():
+         progress(0.85, desc="Stage 4: Combining into quadmask...")
+         combine_process_video(black_mask_path, grey_mask_path, quadmask_path)
+     else:
+         shutil.copy2(black_mask_path, quadmask_path)
+
+     progress(0.9, desc="Creating visualizations...")
+
+     # ── Visualization outputs ───────────────────────────────────────────────
+     overlay_path = create_quadmask_visualization(str(input_video_path), str(quadmask_path))
+     color_path = create_quadmask_color_video(str(quadmask_path))
+
+     analysis_text = ""
+     if vlm_analysis_path.exists():
+         with open(vlm_analysis_path) as f:
+             analysis_text = f.read()
+     else:
+         analysis_text = "No VLM analysis available."
+
+     progress(1.0, desc="Done!")
+     return str(quadmask_path), overlay_path, color_path, analysis_text
+
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ # GRADIO UI
+ # ══════════════════════════════════════════════════════════════════════════════
+
+ def on_video_upload(video_path):
+     if not video_path:
+         return None, None, "[]", gr.update(interactive=False)
+     frames, _ = extract_frames(video_path, max_frames=1)
+     if not frames:
+         return None, None, "[]", gr.update(interactive=False)
+     return frames[0], frames[0], "[]", gr.update(interactive=True)
+
+
+ def on_frame_select(clean_frame, points_json, evt: gr.SelectData):
+     if clean_frame is None:
+         return None, points_json
+     points = json.loads(points_json) if points_json else []
+     x, y = evt.index
+     points.append([int(x), int(y)])
+     annotated = draw_points_on_image(clean_frame, points)
+     return annotated, json.dumps(points)
+
+
+ def on_clear_points(clean_frame):
+     if clean_frame is not None:
+         return clean_frame, "[]"
+     return None, "[]"
+
+
+ DESCRIPTION = """
+ # VOID VLM-Mask-Reasoner — Quadmask Generation
+
+ Generate **4-level semantic masks** (quadmasks) for interaction-aware video inpainting with [VOID](https://github.com/Netflix/void-model).
+
+ **Pipeline:** Click points on object → SAM2 segments it → Gemini VLM reasons about interactions → SAM3 segments affected objects → Quadmask generated
+
+ Use the generated quadmask with the [VOID inpainting demo](https://huggingface.co/spaces/sam-motamed/VOID).
+ """
+
+ QUADMASK_EXPLAINER = """
+ ### Quadmask format
+
+ | Pixel Value | Color | Meaning |
+ |-------------|-------|---------|
+ | **0** (black) | Red overlay | Primary object to remove |
+ | **63** (dark grey) | Yellow overlay | Overlap of primary + affected zone |
+ | **127** (mid grey) | Green overlay | Affected region (shadows, reflections, physics) |
+ | **255** (white) | Original | Background — keep as-is |
+ """
+
+ with gr.Blocks(title="VOID VLM-Mask-Reasoner", theme=gr.themes.Default()) as demo:
+     gr.Markdown(DESCRIPTION)
+
+     points_state = gr.State("[]")
+     clean_frame_state = gr.State(None)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             video_input = gr.Video(label="Upload Video", sources=["upload"])
+             frame_display = gr.Image(
+                 label="Click to select primary object points (click multiple spots on the object)",
+                 interactive=True, type="numpy",
+             )
+             with gr.Row():
+                 clear_btn = gr.Button("Clear Points", size="sm")
+                 points_display = gr.Textbox(label="Selected Points", value="[]",
+                                             interactive=False, max_lines=2)
+             instruction_input = gr.Textbox(
+                 label="Edit instruction — describe what to remove",
+                 placeholder="e.g., remove the person", lines=1,
+             )
+             generate_btn = gr.Button("Generate Quadmask", variant="primary", size="lg",
+                                      interactive=False)
+
+         with gr.Column(scale=1):
+             output_quadmask_file = gr.File(label="Download lossless quadmask_0.mp4 (use this with VOID)")
+             with gr.Tabs():
+                 with gr.TabItem("Quadmask Overlay"):
+                     output_overlay = gr.Video(label="Quadmask overlay on original video")
+                 with gr.TabItem("Raw Quadmask"):
+                     output_color = gr.Video(label="Color-coded quadmask")
+                 with gr.TabItem("VLM Analysis"):
+                     output_analysis = gr.Code(label="VLM Analysis JSON", language="json")
+
+     video_input.change(
+         fn=on_video_upload, inputs=[video_input],
+         outputs=[frame_display, clean_frame_state, points_state, generate_btn],
+     )
+     points_state.change(lambda p: p, inputs=points_state, outputs=points_display)
+     frame_display.select(
+         fn=on_frame_select, inputs=[clean_frame_state, points_state],
+         outputs=[frame_display, points_state],
+     )
+     clear_btn.click(
+         fn=on_clear_points, inputs=[clean_frame_state],
+         outputs=[frame_display, points_state],
+     )
+     generate_btn.click(
+         fn=run_pipeline, inputs=[video_input, points_state, instruction_input],
+         outputs=[output_quadmask_file, output_overlay, output_color, output_analysis],
+     )
+
+     gr.Markdown(QUADMASK_EXPLAINER)
 
+ if __name__ == "__main__":
+     demo.launch()
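
The visualization helpers above snap compressed pixel values back to the four canonical quadmask levels before coloring them. That bucketing can be sketched standalone with NumPy (thresholds copied from the app code; `quantize_quadmask` is an illustrative name, not a function in the repo):

```python
import numpy as np


def quantize_quadmask(qm: np.ndarray) -> np.ndarray:
    """Snap near-lossless grayscale values to the 4 quadmask levels 0/63/127/255."""
    qm = np.where(qm <= 31, 0, qm)
    qm = np.where((qm > 31) & (qm <= 95), 63, qm)
    qm = np.where((qm > 95) & (qm <= 191), 127, qm)
    qm = np.where(qm > 191, 255, qm)
    return qm.astype(np.uint8)


# Values near each level collapse onto it: 0, 63, 127, 255
print(quantize_quadmask(np.array([3, 60, 130, 250])))
```

Since `quadmask_0.mp4` is written losslessly (`-qp 0`, `yuv444p`), this step is a safety net rather than a necessity, but it keeps the overlays correct even for re-encoded masks.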
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torchvision==0.24
+ transformers>=4.50.0
+ accelerate
+ gradio
+ numpy<2.0
+ opencv-python-headless
+ Pillow
+ imageio
+ imageio-ffmpeg
+ openai
+ huggingface_hub
+ spaces
+ tqdm