Spaces: Running on Zero
Commit 535cf33
Parent(s): 0edb2a4

Add VOID VLM-Mask-Reasoner quadmask generation demo

4-stage pipeline matching Netflix VOID repo:
- Stage 1: SAM2 video segmentation (transformers Sam2VideoModel)
- Stage 2: Gemini VLM scene analysis (repo code directly)
- Stage 3: SAM3 text-prompted segmentation (transformers Sam3Model)
- Stage 4: Lossless quadmask combination (repo code directly)
Gradio UI with point-click selection, progress tracking,
overlay visualization, and lossless quadmask download.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- README.md +5 -5
- VLM-MASK-REASONER/README.md +116 -0
- VLM-MASK-REASONER/__pycache__/stage1_sam2_segmentation.cpython-312.pyc +0 -0
- VLM-MASK-REASONER/__pycache__/stage2_vlm_analysis.cpython-312.pyc +0 -0
- VLM-MASK-REASONER/__pycache__/stage3a_generate_grey_masks_v2.cpython-312.pyc +0 -0
- VLM-MASK-REASONER/__pycache__/stage4_combine_masks.cpython-312.pyc +0 -0
- VLM-MASK-REASONER/edit_quadmask.py +831 -0
- VLM-MASK-REASONER/point_selector_gui.py +601 -0
- VLM-MASK-REASONER/run_pipeline.sh +74 -0
- VLM-MASK-REASONER/stage1_sam2_segmentation.py +419 -0
- VLM-MASK-REASONER/stage2_vlm_analysis.py +1022 -0
- VLM-MASK-REASONER/stage2_vlm_analysis_cf.py +1024 -0
- VLM-MASK-REASONER/stage3a_generate_grey_masks.py +436 -0
- VLM-MASK-REASONER/stage3a_generate_grey_masks_v2.py +576 -0
- VLM-MASK-REASONER/stage3b_trajectory_gui.py +432 -0
- VLM-MASK-REASONER/stage4_combine_masks.py +241 -0
- VLM-MASK-REASONER/test_gemini_video.py +98 -0
- app.py +539 -11
- requirements.txt +13 -0
README.md CHANGED

```diff
@@ -1,13 +1,13 @@
 ---
 title: VOID Quadmask Reasoner
-emoji:
-colorFrom:
-colorTo:
+emoji: π
+colorFrom: gray
+colorTo: purple
 sdk: gradio
 sdk_version: 6.11.0
 python_version: '3.12'
 app_file: app.py
 pinned: false
+license: mit
+short_description: 'VLM-Mask-Reasoner: Generate quadmasks for VOID video inpainting'
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
```
VLM-MASK-REASONER/README.md ADDED
@@ -0,0 +1,116 @@
# VLM Mask Reasoner - Mask Generation Pipeline

Generates quadmasks for video inpainting by combining user-guided SAM2 segmentation with VLM (Gemini) scene reasoning. The output `quadmask_0.mp4` encodes four semantic layers that the inpainting model uses to understand what to remove and what to preserve.

---

## Quadmask Values

| Value | Meaning |
|-------|---------|
| `0` | Primary object (to be removed) |
| `63` | Overlap of primary and affected regions |
| `127` | Affected objects (shadows, reflections, held items) |
| `255` | Background (keep as-is) |

---
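The four layers above live in a single greyscale frame, so a quadmask can be inspected with a few NumPy comparisons. A minimal sketch (the tiny 4x4 frame here is synthetic, just to illustrate the value coding):

```python
import numpy as np

# Count pixels per quadmask layer in one greyscale frame.
# The four values follow the table above; any other value in a
# real quadmask would indicate a lossy encode.
QUAD_VALUES = {
    0: "primary object (remove)",
    63: "primary + affected overlap",
    127: "affected (shadows, reflections)",
    255: "background (keep)",
}

frame = np.full((4, 4), 255, dtype=np.uint8)  # mostly background
frame[0, 0] = 0      # one primary-object pixel
frame[1, 1] = 127    # one affected pixel

for value, meaning in QUAD_VALUES.items():
    count = int((frame == value).sum())
    print(f"{value:3d} {meaning}: {count} px")
```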

## Step 1 - Select Points via GUI

Launch the point selector GUI:

```bash
python point_selector_gui.py
```

Use this GUI to place sparse click points on the object(s) you want removed.

**A few things to know:**

- Points can be placed on **any frame**, not just the first. If the object you want to remove only appears later in the video, navigate to that frame and click there.
- You can place points across **multiple frames** - useful when there are multiple distinct objects to remove, or when an object's position shifts significantly over time.
- The GUI saves your selections to a `config_points.json` file. Keep track of where this is saved - you'll pass it to the pipeline next.

---
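Conceptually, the config just ties each input video to its click points and an output location. The sketch below round-trips a made-up example through JSON; the field names (`videos`, `points`, `output_dir`, etc.) are illustrative assumptions only - inspect the `config_points.json` that `point_selector_gui.py` actually writes for the real schema.

```python
import json

# Hypothetical example of a point-selection config. All field names
# here are assumptions for illustration, not the real schema.
example = {
    "videos": [
        {
            "video_path": "inputs/clip.mp4",
            "output_dir": "outputs/clip",
            "points": [
                {"frame": 0, "x": 412, "y": 280, "label": 1},
                {"frame": 35, "x": 198, "y": 310, "label": 1},
            ],
        }
    ]
}

# Round-trip through JSON, as the pipeline would when reading the file.
config = json.loads(json.dumps(example))
for video in config["videos"]:
    print(video["video_path"], "->", len(video["points"]), "point(s)")
```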

## Step 2 - Run the Pipeline

Once you have your `config_points.json`, run all stages with a single command:

```bash
bash run_pipeline.sh <config_points.json>
```

Optional flags:

```bash
bash run_pipeline.sh <config_points.json> \
    --sam2-checkpoint ../sam2_hiera_large.pt \
    --device cuda
```

This runs four stages automatically:

1. **Stage 1 - SAM2 Segmentation:** Propagates your point clicks into a per-frame black mask for the primary object.
2. **Stage 2 - VLM Analysis (Gemini):** Analyzes the scene to identify affected objects - things like shadows, reflections, or items the primary object is interacting with.
3. **Stage 3 - Grey Mask Generation:** Produces a grey mask track for the affected objects identified in Stage 2.
4. **Stage 4 - Combine into Quadmask:** Merges the black and grey masks into the final `quadmask_0.mp4`.

The output `quadmask_0.mp4` is written into each video's `output_dir` as specified in the config.
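The Stage 4 merge can be sketched as a per-pixel value assignment over the two boolean masks. This is only an illustration of the four output values; the repo's actual implementation is in `stage4_combine_masks.py`:

```python
import numpy as np

# Minimal sketch of merging a black (primary) and grey (affected)
# boolean mask into the four quadmask values. Assumes both inputs
# have the same shape; not the repo's implementation.
def combine_quadmask(primary: np.ndarray, affected: np.ndarray) -> np.ndarray:
    quad = np.full(primary.shape, 255, dtype=np.uint8)  # background
    quad[affected] = 127                                # affected only
    quad[primary] = 0                                   # primary only
    quad[primary & affected] = 63                       # overlap of both
    return quad

primary = np.array([[True, True], [False, False]])
affected = np.array([[False, True], [True, False]])
print(combine_quadmask(primary, affected))
# Each pixel ends up as exactly one of 0 / 63 / 127 / 255.
```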

> **Note on grey values in frame 1:** The inpainting model was trained with grey-valued regions (`127`) starting from frame 1 onward, not on the very first frame. We find this convention improves inference quality, so the pipeline automatically clears any grey pixels from frame 0 of the final quadmask before saving.

---
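The frame-0 clearing step amounts to rewriting two pixel values in the first frame only. A sketch under stated assumptions - how overlap pixels (`63`) are handled on frame 0 is not specified above, so folding them back into the primary value (`0`) is this sketch's assumption, not necessarily the pipeline's behaviour:

```python
import numpy as np

# Clear grey (127) regions from frame 0 only, per the note above.
# Treating overlap (63) as primary (0) on frame 0 is an assumption.
def clear_grey_frame0(frames: list[np.ndarray]) -> list[np.ndarray]:
    first = frames[0].copy()
    first[first == 127] = 255  # affected-only -> background
    first[first == 63] = 0     # overlap -> primary
    return [first] + frames[1:]

f0 = np.array([[127, 63], [0, 255]], dtype=np.uint8)
f1 = np.array([[127, 127], [0, 255]], dtype=np.uint8)
out = clear_grey_frame0([f0, f1])
print(out[0])  # frame 0 has no grey values left
print(out[1])  # later frames keep their grey regions
```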

## Step 3 (Optional) - Manual Mask Correction

If the generated quadmask needs refinement, you can correct it interactively:

```bash
python edit_quadmask.py
```

Point the GUI to the folder containing `quadmask_0.mp4`. You can paint over regions frame-by-frame to fix any mask errors before running inference. The corrected mask is saved back to `quadmask_0.mp4` in the same folder.

---

## Installation & Dependencies

### 1. Python dependencies

Install the main requirements from the repo root:

```bash
pip install -r requirements.txt
```

### 2. SAM2

SAM2 must be installed separately (it is not on PyPI):

```bash
pip install git+https://github.com/facebookresearch/segment-anything-2.git
```

Then download the SAM2 checkpoint. The pipeline defaults to `sam2_hiera_large.pt` one level above this directory:

```bash
# from the repo root (or wherever you want to store it)
wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt
```

If you place the checkpoint elsewhere, pass it explicitly:

```bash
bash run_pipeline.sh config_points.json --sam2-checkpoint /path/to/sam2_hiera_large.pt
```

> SAM2 requires **Python ≥ 3.10** and **PyTorch ≥ 2.3.1** with CUDA. See the [SAM2 repo](https://github.com/facebookresearch/segment-anything-2) for full system requirements.

### 3. Gemini API key

Stage 2 uses the Gemini VLM. Export your API key before running the pipeline:

```bash
export GEMINI_API_KEY="your_key_here"
```
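Since Stage 2 fails without the key, a quick pre-flight check can save a wasted run. A small sketch (not part of the repo) that inspects an environment mapping:

```python
import os

# Pre-flight check: Stage 2 needs GEMINI_API_KEY in the environment.
# Taking the environment as a parameter keeps the check testable.
def check_gemini_key(env: dict) -> str:
    return "ok" if env.get("GEMINI_API_KEY") else "missing"

status = check_gemini_key(dict(os.environ))
if status == "missing":
    print("GEMINI_API_KEY is not set - export it before running run_pipeline.sh")
else:
    print("GEMINI_API_KEY is set")
```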
VLM-MASK-REASONER/__pycache__/stage1_sam2_segmentation.cpython-312.pyc ADDED
Binary file (19.3 kB)

VLM-MASK-REASONER/__pycache__/stage2_vlm_analysis.cpython-312.pyc ADDED
Binary file (42.6 kB)

VLM-MASK-REASONER/__pycache__/stage3a_generate_grey_masks_v2.cpython-312.pyc ADDED
Binary file (26 kB)

VLM-MASK-REASONER/__pycache__/stage4_combine_masks.cpython-312.pyc ADDED
Binary file (9.67 kB)
VLM-MASK-REASONER/edit_quadmask.py ADDED
@@ -0,0 +1,831 @@
```python
#!/usr/bin/env python3
"""
Mask Editor GUI - Edit gridified video masks with grid toggling and brush tools
"""
import cv2
import numpy as np
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
from PIL import Image, ImageTk
import subprocess
from pathlib import Path
import copy
import time

class MaskEditorGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("Mask Editor")

        # Video data
        self.rgb_frames = []
        self.mask_frames = []
        self.current_frame = 0
        self.grid_rows = 0
        self.grid_cols = 0
        self.min_grid = 8

        # Edit state
        self.undo_stack = []
        self.redo_stack = []
        self.current_tool = "grid"  # "grid" or "brush"
        self.brush_size = 20
        self.brush_mode = "add"  # "add" or "erase"

        # Display state
        self.display_scale = 1.0
        self.rgb_photo = None
        self.mask_photo = None
        self.dragging = False
        self.last_brush_pos = None
        self.last_update_time = 0
        self.update_interval = 0.2  # Update every 200ms during dragging (5 FPS - less choppy)
        self.cached_rgb_frame = None  # Cache current RGB frame
        self.cached_frame_idx = -1  # Track which frame is cached
        self.pending_update = False  # Track if update is needed after drag
        self.brush_repeat_id = None  # Timer for continuous brush application

        # Paths
        self.folder_path = None
        self.mask_path = None
        self.rgb_path = None

        self.setup_ui()

    def setup_ui(self):
        """Setup the GUI layout"""
        # Menu bar
        menubar = tk.Menu(self.root)
        self.root.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="File", menu=file_menu)
        file_menu.add_command(label="Open Folder", command=self.load_folder)
        file_menu.add_command(label="Save Mask", command=self.save_mask)
        file_menu.add_separator()
        file_menu.add_command(label="Exit", command=self.root.quit)

        edit_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="Edit", menu=edit_menu)
        edit_menu.add_command(label="Undo", command=self.undo, accelerator="Ctrl+Z")
        edit_menu.add_command(label="Redo", command=self.redo, accelerator="Ctrl+Y")

        # Keyboard shortcuts
        self.root.bind("<Control-z>", lambda e: self.undo())
        self.root.bind("<Control-y>", lambda e: self.redo())
        self.root.bind("<Left>", lambda e: self.prev_frame())
        self.root.bind("<Right>", lambda e: self.next_frame())

        # Top toolbar
        toolbar = ttk.Frame(self.root)
        toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        ttk.Label(toolbar, text="Folder:").pack(side=tk.LEFT)
        self.folder_label = ttk.Label(toolbar, text="None", foreground="gray")
        self.folder_label.pack(side=tk.LEFT, padx=5)

        ttk.Button(toolbar, text="Open Folder", command=self.load_folder).pack(side=tk.LEFT, padx=5)
        ttk.Button(toolbar, text="Save Mask", command=self.save_mask).pack(side=tk.LEFT, padx=5)

        # Main content area
        content = ttk.Frame(self.root)
        content.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Left panel - Original video
        left_panel = ttk.LabelFrame(content, text="Original Video")
        left_panel.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5)

        self.rgb_canvas = tk.Canvas(left_panel, width=640, height=480, bg='black')
        self.rgb_canvas.pack(fill=tk.BOTH, expand=True)

        # Right panel - Mask
        right_panel = ttk.LabelFrame(content, text="Mask (Editable)")
        right_panel.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5)

        self.mask_canvas = tk.Canvas(right_panel, width=640, height=480, bg='black')
        self.mask_canvas.pack(fill=tk.BOTH, expand=True)
        self.mask_canvas.bind("<Button-1>", self.on_mask_click)
        self.mask_canvas.bind("<B1-Motion>", self.on_mask_drag)
        self.mask_canvas.bind("<ButtonRelease-1>", self.on_mask_release)

        # Bottom controls
        controls = ttk.Frame(self.root)
        controls.pack(side=tk.BOTTOM, fill=tk.X, padx=5, pady=5)

        # Frame navigation
        nav_frame = ttk.LabelFrame(controls, text="Frame Navigation")
        nav_frame.pack(side=tk.TOP, fill=tk.X, pady=5)

        ttk.Button(nav_frame, text="<<", command=self.first_frame, width=5).pack(side=tk.LEFT, padx=2)
        ttk.Button(nav_frame, text="<", command=self.prev_frame, width=5).pack(side=tk.LEFT, padx=2)

        self.frame_label = ttk.Label(nav_frame, text="Frame: 0 / 0")
        self.frame_label.pack(side=tk.LEFT, padx=10)

        ttk.Button(nav_frame, text=">", command=self.next_frame, width=5).pack(side=tk.LEFT, padx=2)
        ttk.Button(nav_frame, text=">>", command=self.last_frame, width=5).pack(side=tk.LEFT, padx=2)

        self.frame_slider = ttk.Scale(nav_frame, from_=0, to=100, orient=tk.HORIZONTAL,
                                      command=self.on_slider_change)
        self.frame_slider.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=10)

        # Tool selection
        tool_frame = ttk.LabelFrame(controls, text="Tools")
        tool_frame.pack(side=tk.TOP, fill=tk.X, pady=5)

        self.tool_var = tk.StringVar(value="grid")
        ttk.Radiobutton(tool_frame, text="Grid Toggle", variable=self.tool_var,
                        value="grid", command=self.on_tool_change).pack(side=tk.LEFT, padx=5)
        ttk.Radiobutton(tool_frame, text="Grid Black Toggle", variable=self.tool_var,
                        value="grid_black", command=self.on_tool_change).pack(side=tk.LEFT, padx=5)
        ttk.Radiobutton(tool_frame, text="Brush (Add Black)", variable=self.tool_var,
                        value="brush_add", command=self.on_tool_change).pack(side=tk.LEFT, padx=5)
        ttk.Radiobutton(tool_frame, text="Brush (Erase Black)", variable=self.tool_var,
                        value="brush_erase", command=self.on_tool_change).pack(side=tk.LEFT, padx=5)

        ttk.Label(tool_frame, text="Brush Size:").pack(side=tk.LEFT, padx=10)
        self.brush_slider = ttk.Scale(tool_frame, from_=5, to=100, orient=tk.HORIZONTAL,
                                      command=self.on_brush_size_change)
        self.brush_slider.set(20)
        self.brush_slider.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5)
        self.brush_size_label = ttk.Label(tool_frame, text="20")
        self.brush_size_label.pack(side=tk.LEFT, padx=5)

        # Copy from previous frame
        copy_frame = ttk.LabelFrame(controls, text="Copy from Previous Frame")
        copy_frame.pack(side=tk.TOP, fill=tk.X, pady=5)

        ttk.Button(copy_frame, text="Copy Black Mask",
                   command=self.copy_black_from_previous).pack(side=tk.LEFT, padx=5)
        ttk.Button(copy_frame, text="Copy Grey Mask",
                   command=self.copy_grey_from_previous).pack(side=tk.LEFT, padx=5)

        # Info panel
        info_frame = ttk.Frame(controls)
        info_frame.pack(side=tk.TOP, fill=tk.X, pady=5)

        self.info_label = ttk.Label(info_frame, text="Load a folder to begin", foreground="blue")
        self.info_label.pack(side=tk.LEFT, padx=5)

        self.grid_info_label = ttk.Label(info_frame, text="Grid: N/A")
        self.grid_info_label.pack(side=tk.RIGHT, padx=5)

    def calculate_square_grid(self, width, height, min_grid=8):
        """Calculate grid dimensions to make square cells"""
        aspect_ratio = width / height

        if width >= height:
            grid_rows = min_grid
            grid_cols = max(min_grid, round(min_grid * aspect_ratio))
        else:
            grid_cols = min_grid
            grid_rows = max(min_grid, round(min_grid / aspect_ratio))

        return grid_rows, grid_cols

    def load_folder(self):
        """Load a folder containing rgb_full.mp4/input_video.mp4 and quadmask_0.mp4"""
        folder = filedialog.askdirectory(title="Select Folder")
        if not folder:
            return

        folder_path = Path(folder)

        # Find RGB video
        rgb_path = None
        for name in ["rgb_full.mp4", "input_video.mp4"]:
            candidate = folder_path / name
            if candidate.exists():
                rgb_path = candidate
                break

        mask_path = folder_path / "quadmask_0.mp4"

        if not rgb_path or not mask_path.exists():
            messagebox.showerror("Error", "Folder must contain quadmask_0.mp4 and rgb_full.mp4 or input_video.mp4")
            return

        self.folder_path = folder_path
        self.rgb_path = rgb_path
        self.mask_path = mask_path

        # Load videos
        self.load_videos()

    def load_videos(self):
        """Load RGB and mask videos into memory"""
        self.info_label.config(text="Loading videos...")
        self.root.update()

        # Load RGB frames
        self.rgb_frames = self.read_video_frames(self.rgb_path)

        # Load mask frames
        self.mask_frames = self.read_video_frames(self.mask_path)

        if len(self.rgb_frames) != len(self.mask_frames):
            messagebox.showwarning("Warning",
                f"Frame count mismatch: RGB={len(self.rgb_frames)}, Mask={len(self.mask_frames)}")

        if len(self.mask_frames) == 0:
            messagebox.showerror("Error", "No frames loaded")
            return

        # Calculate grid dimensions
        height, width = self.mask_frames[0].shape[:2]
        self.grid_rows, self.grid_cols = self.calculate_square_grid(width, height, self.min_grid)

        # Calculate display scale
        max_width = 600
        max_height = 450
        scale_w = max_width / width
        scale_h = max_height / height
        self.display_scale = min(scale_w, scale_h, 1.0)

        # Update UI
        self.folder_label.config(text=self.folder_path.name, foreground="black")
        self.grid_info_label.config(text=f"Grid: {self.grid_rows}x{self.grid_cols}")
        self.frame_slider.config(to=len(self.mask_frames)-1)
        self.current_frame = 0
        self.undo_stack = []
        self.redo_stack = []

        self.update_display()
        self.info_label.config(text=f"Loaded {len(self.mask_frames)} frames", foreground="green")

    def read_video_frames(self, video_path):
        """Read all frames from a video"""
        cap = cv2.VideoCapture(str(video_path))
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert to grayscale if needed
            if len(frame.shape) == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            frames.append(frame)
        cap.release()
        return frames

    def write_video_frames(self, frames, output_path, fps=12):
        """Write frames to a video file using lossless H.264"""
        if not frames:
            return

        height, width = frames[0].shape[:2]

        # Write temp AVI first
        temp_avi = output_path.with_suffix('.avi')
        fourcc = cv2.VideoWriter_fourcc(*'FFV1')
        out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (width, height), isColor=False)

        for frame in frames:
            if len(frame.shape) == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            out.write(frame)

        out.release()

        # Convert to LOSSLESS H.264 (qp=0)
        cmd = [
            'ffmpeg', '-y', '-i', str(temp_avi),
            '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
            '-pix_fmt', 'yuv444p', '-r', '12',
            str(output_path)
        ]
        subprocess.run(cmd, capture_output=True)
        temp_avi.unlink()

    def update_display(self, fast_mode=False):
        """Update both canvas displays (or just mask in fast mode)"""
        if not self.mask_frames:
            return

        # Cache RGB frame if needed (only in full mode)
        if not fast_mode and self.cached_frame_idx != self.current_frame:
            if self.current_frame < len(self.rgb_frames):
                rgb_frame = self.rgb_frames[self.current_frame]
                self.cached_rgb_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_GRAY2RGB) if len(rgb_frame.shape) == 2 else rgb_frame.copy()
                self.cached_frame_idx = self.current_frame
            else:
                self.cached_rgb_frame = None

        if not fast_mode:
            # Update frame label
            self.frame_label.config(text=f"Frame: {self.current_frame + 1} / {len(self.mask_frames)}")
            self.frame_slider.set(self.current_frame)

            # Display RGB frame
            if self.cached_rgb_frame is not None:
                rgb_display = cv2.resize(self.cached_rgb_frame, None, fx=self.display_scale, fy=self.display_scale)
                rgb_image = Image.fromarray(rgb_display)
                self.rgb_photo = ImageTk.PhotoImage(rgb_image)
                self.rgb_canvas.delete("all")
                self.rgb_canvas.create_image(0, 0, anchor=tk.NW, image=self.rgb_photo)

        # Display mask frame with grid overlay
        mask_frame = self.mask_frames[self.current_frame]
        # Use simple mode (no RGB blending) during fast updates for speed
        mask_display = self.create_mask_visualization(mask_frame, self.cached_rgb_frame, simple_mode=fast_mode)
        mask_display = cv2.resize(mask_display, None, fx=self.display_scale, fy=self.display_scale)
        mask_image = Image.fromarray(mask_display)
        self.mask_photo = ImageTk.PhotoImage(mask_image)
        self.mask_canvas.delete("all")
        self.mask_canvas.create_image(0, 0, anchor=tk.NW, image=self.mask_photo)

        # Draw grid overlay (skip in fast mode for performance)
        if not fast_mode:
            self.draw_grid_overlay()

    def create_mask_visualization(self, mask_frame, rgb_frame=None, simple_mode=False):
        """Create RGB visualization of mask with color coding and RGB background"""
        height, width = mask_frame.shape
        vis = np.zeros((height, width, 3), dtype=np.uint8)

        if simple_mode:
            # Fast simple mode - no blending, just solid colors
            vis[mask_frame == 255] = [150, 150, 150]  # Background - gray
            vis[mask_frame == 127] = [0, 200, 0]      # Gridified - green
            vis[mask_frame == 63] = [200, 200, 0]     # Overlap - yellow
            vis[mask_frame == 0] = [200, 0, 0]        # Black - red
            return vis

        # If RGB frame is provided, use it as background
        if rgb_frame is not None:
            if len(rgb_frame.shape) == 2:
                # Convert grayscale to RGB
                rgb_background = cv2.cvtColor(rgb_frame, cv2.COLOR_GRAY2RGB)
            else:
                rgb_background = rgb_frame.copy()

            # Use it at 50% opacity as base
            vis = (rgb_background * 0.5).astype(np.uint8)

        # Color coding with transparency to show background:
        # 0 (black) -> Red tint (to indicate removal area)
        # 63 (overlap) -> Yellow tint
        # 127 (gridified) -> Green tint
        # 255 (background) -> Keep RGB background visible

        # Background areas - show RGB at 60% brightness
        bg_mask = mask_frame == 255
        if rgb_frame is not None:
            vis[bg_mask] = (rgb_background[bg_mask] * 0.6).astype(np.uint8)
        else:
            vis[bg_mask] = [150, 150, 150]

        # Green overlay for gridified areas - blend 40% background + 60% green tint
        green_mask = mask_frame == 127
        if rgb_frame is not None:
            vis[green_mask] = np.clip(rgb_background[green_mask] * 0.4 + np.array([0, 180, 0]) * 0.6, 0, 255).astype(np.uint8)
        else:
            vis[green_mask] = [0, 200, 0]

        # Yellow overlay for overlap areas - blend 40% background + 60% yellow tint
        yellow_mask = mask_frame == 63
        if rgb_frame is not None:
            vis[yellow_mask] = np.clip(rgb_background[yellow_mask] * 0.4 + np.array([180, 180, 0]) * 0.6, 0, 255).astype(np.uint8)
        else:
            vis[yellow_mask] = [200, 200, 0]

        # Red tint for black areas (removal) - blend 30% background + 70% red tint
        black_mask = mask_frame == 0
        if rgb_frame is not None:
            vis[black_mask] = np.clip(rgb_background[black_mask] * 0.3 + np.array([200, 0, 0]) * 0.7, 0, 255).astype(np.uint8)
        else:
            vis[black_mask] = [200, 0, 0]

        return vis

    def draw_grid_overlay(self):
        """Draw grid lines on mask canvas"""
        if not self.mask_frames:
            return

        height, width = self.mask_frames[0].shape

        scaled_width = int(width * self.display_scale)
        scaled_height = int(height * self.display_scale)

        cell_width = scaled_width / self.grid_cols
        cell_height = scaled_height / self.grid_rows

        # Draw vertical lines
        for col in range(self.grid_cols + 1):
            x = int(col * cell_width)
            self.mask_canvas.create_line(x, 0, x, scaled_height, fill='red', width=1, tags='grid')

        # Draw horizontal lines
        for row in range(self.grid_rows + 1):
            y = int(row * cell_height)
            self.mask_canvas.create_line(0, y, scaled_width, y, fill='red', width=1, tags='grid')

    def get_grid_from_pos(self, x, y):
        """Get grid row, col from canvas position"""
        if not self.mask_frames:
            return None, None

        height, width = self.mask_frames[0].shape

        # Convert to frame coordinates
        frame_x = int(x / self.display_scale)
        frame_y = int(y / self.display_scale)

        if frame_x < 0 or frame_x >= width or frame_y < 0 or frame_y >= height:
            return None, None

        cell_width = width / self.grid_cols
        cell_height = height / self.grid_rows

        col = int(frame_x / cell_width)
        row = int(frame_y / cell_height)

        return row, col

    def toggle_grid(self, row, col):
        """Toggle a grid cell between 127 and 255, handling 63 overlaps"""
        if row is None or col is None:
```

(diff truncated at line 448 of 831)
|
| 449 |
+
return
|
| 450 |
+
|
| 451 |
+
if row < 0 or row >= self.grid_rows or col < 0 or col >= self.grid_cols:
|
| 452 |
+
return
|
| 453 |
+
|
| 454 |
+
# Save state for undo
|
| 455 |
+
self.save_state()
|
| 456 |
+
|
| 457 |
+
mask = self.mask_frames[self.current_frame]
|
| 458 |
+
height, width = mask.shape
|
| 459 |
+
|
| 460 |
+
cell_width = width / self.grid_cols
|
| 461 |
+
cell_height = height / self.grid_rows
|
| 462 |
+
|
| 463 |
+
y1 = int(row * cell_height)
|
| 464 |
+
y2 = int((row + 1) * cell_height)
|
| 465 |
+
x1 = int(col * cell_width)
|
| 466 |
+
x2 = int((col + 1) * cell_width)
|
| 467 |
+
|
| 468 |
+
grid_region = mask[y1:y2, x1:x2]
|
| 469 |
+
|
| 470 |
+
# Check if grid has any 127 or 63 values
|
| 471 |
+
has_active = np.any((grid_region == 127) | (grid_region == 63))
|
| 472 |
+
|
| 473 |
+
if has_active:
|
| 474 |
+
# Turn OFF: 127->255, 63->0, keep 0 and 255 as is
|
| 475 |
+
mask[y1:y2, x1:x2] = np.where(grid_region == 127, 255,
|
| 476 |
+
np.where(grid_region == 63, 0, grid_region))
|
| 477 |
+
else:
|
| 478 |
+
# Turn ON: 255->127, 0->63, keep others as is
|
| 479 |
+
mask[y1:y2, x1:x2] = np.where(grid_region == 255, 127,
|
| 480 |
+
np.where(grid_region == 0, 63, grid_region))
|
| 481 |
+
|
| 482 |
+
self.update_display()
|
| 483 |
+
|
| 484 |
+
def toggle_grid_black(self, row, col):
|
| 485 |
+
"""Toggle black mask in a grid cell"""
|
| 486 |
+
if row is None or col is None:
|
| 487 |
+
return
|
| 488 |
+
|
| 489 |
+
if row < 0 or row >= self.grid_rows or col < 0 or col >= self.grid_cols:
|
| 490 |
+
return
|
| 491 |
+
|
| 492 |
+
# Save state for undo
|
| 493 |
+
self.save_state()
|
| 494 |
+
|
| 495 |
+
mask = self.mask_frames[self.current_frame]
|
| 496 |
+
height, width = mask.shape
|
| 497 |
+
|
| 498 |
+
cell_width = width / self.grid_cols
|
| 499 |
+
cell_height = height / self.grid_rows
|
| 500 |
+
|
| 501 |
+
y1 = int(row * cell_height)
|
| 502 |
+
y2 = int((row + 1) * cell_height)
|
| 503 |
+
x1 = int(col * cell_width)
|
| 504 |
+
x2 = int((col + 1) * cell_width)
|
| 505 |
+
|
| 506 |
+
grid_region = mask[y1:y2, x1:x2]
|
| 507 |
+
|
| 508 |
+
# Check if grid has any black (0 or 63 values)
|
| 509 |
+
has_black = np.any((grid_region == 0) | (grid_region == 63))
|
| 510 |
+
|
| 511 |
+
if has_black:
|
| 512 |
+
# Turn OFF black: 0->255, 63->127, keep 127 and 255 as is
|
| 513 |
+
mask[y1:y2, x1:x2] = np.where(grid_region == 0, 255,
|
| 514 |
+
np.where(grid_region == 63, 127, grid_region))
|
| 515 |
+
else:
|
| 516 |
+
# Turn ON black: 255->0, 127->63, keep others as is
|
| 517 |
+
mask[y1:y2, x1:x2] = np.where(grid_region == 255, 0,
|
| 518 |
+
np.where(grid_region == 127, 63, grid_region))
|
| 519 |
+
|
| 520 |
+
self.update_display()
|
| 521 |
+
|
| 522 |
+
def apply_brush(self, x, y, mode="add"):
|
| 523 |
+
"""Apply brush to add/erase black mask (vectorized for speed)"""
|
| 524 |
+
if not self.mask_frames:
|
| 525 |
+
return
|
| 526 |
+
|
| 527 |
+
mask = self.mask_frames[self.current_frame]
|
| 528 |
+
height, width = mask.shape
|
| 529 |
+
|
| 530 |
+
# Convert to frame coordinates
|
| 531 |
+
frame_x = int(x / self.display_scale)
|
| 532 |
+
frame_y = int(y / self.display_scale)
|
| 533 |
+
|
| 534 |
+
if frame_x < 0 or frame_x >= width or frame_y < 0 or frame_y >= height:
|
| 535 |
+
return
|
| 536 |
+
|
| 537 |
+
# Create circular brush using vectorized operations
|
| 538 |
+
radius = int(self.brush_size / 2)
|
| 539 |
+
|
| 540 |
+
y1 = max(0, frame_y - radius)
|
| 541 |
+
y2 = min(height, frame_y + radius + 1)
|
| 542 |
+
x1 = max(0, frame_x - radius)
|
| 543 |
+
x2 = min(width, frame_x + radius + 1)
|
| 544 |
+
|
| 545 |
+
# Get the region
|
| 546 |
+
region = mask[y1:y2, x1:x2]
|
| 547 |
+
|
| 548 |
+
# Create coordinate grids for distance calculation
|
| 549 |
+
yy, xx = np.ogrid[y1:y2, x1:x2]
|
| 550 |
+
dist = np.sqrt((xx - frame_x)**2 + (yy - frame_y)**2)
|
| 551 |
+
brush_mask = dist <= radius
|
| 552 |
+
|
| 553 |
+
if mode == "add":
|
| 554 |
+
# Add black: 255->0, 127->63
|
| 555 |
+
region[brush_mask & (region == 255)] = 0
|
| 556 |
+
region[brush_mask & (region == 127)] = 63
|
| 557 |
+
else: # erase
|
| 558 |
+
# Erase black: 0->255, 63->127
|
| 559 |
+
region[brush_mask & (region == 0)] = 255
|
| 560 |
+
region[brush_mask & (region == 63)] = 127
|
| 561 |
+
|
| 562 |
+
def on_mask_click(self, event):
|
| 563 |
+
"""Handle click on mask canvas"""
|
| 564 |
+
if not self.mask_frames:
|
| 565 |
+
return
|
| 566 |
+
|
| 567 |
+
tool = self.tool_var.get()
|
| 568 |
+
|
| 569 |
+
if tool == "grid":
|
| 570 |
+
row, col = self.get_grid_from_pos(event.x, event.y)
|
| 571 |
+
self.toggle_grid(row, col)
|
| 572 |
+
elif tool == "grid_black":
|
| 573 |
+
row, col = self.get_grid_from_pos(event.x, event.y)
|
| 574 |
+
self.toggle_grid_black(row, col)
|
| 575 |
+
elif tool in ["brush_add", "brush_erase"]:
|
| 576 |
+
self.save_state()
|
| 577 |
+
mode = "add" if tool == "brush_add" else "erase"
|
| 578 |
+
self.apply_brush(event.x, event.y, mode)
|
| 579 |
+
self.dragging = True
|
| 580 |
+
self.last_brush_pos = (event.x, event.y)
|
| 581 |
+
self.last_update_time = time.time()
|
| 582 |
+
self.update_display(fast_mode=True)
|
| 583 |
+
# Start continuous brush application
|
| 584 |
+
self.schedule_brush_repeat()
|
| 585 |
+
|
| 586 |
+
def on_mask_drag(self, event):
|
| 587 |
+
"""Handle drag on mask canvas with throttled updates"""
|
| 588 |
+
if not self.dragging:
|
| 589 |
+
return
|
| 590 |
+
|
| 591 |
+
tool = self.tool_var.get()
|
| 592 |
+
if tool in ["brush_add", "brush_erase"]:
|
| 593 |
+
# Update brush position when moving
|
| 594 |
+
self.last_brush_pos = (event.x, event.y)
|
| 595 |
+
mode = "add" if tool == "brush_add" else "erase"
|
| 596 |
+
self.apply_brush(event.x, event.y, mode)
|
| 597 |
+
|
| 598 |
+
# Only update display if enough time has passed (fast mode - no grid)
|
| 599 |
+
current_time = time.time()
|
| 600 |
+
if current_time - self.last_update_time >= self.update_interval:
|
| 601 |
+
self.update_display(fast_mode=True)
|
| 602 |
+
self.last_update_time = current_time
|
| 603 |
+
|
| 604 |
+
def on_mask_release(self, event):
|
| 605 |
+
"""Handle release on mask canvas"""
|
| 606 |
+
self.dragging = False
|
| 607 |
+
self.last_brush_pos = None
|
| 608 |
+
# Cancel continuous brush application
|
| 609 |
+
if self.brush_repeat_id:
|
| 610 |
+
self.root.after_cancel(self.brush_repeat_id)
|
| 611 |
+
self.brush_repeat_id = None
|
| 612 |
+
# Final full update when releasing to show the complete result with blending
|
| 613 |
+
self.update_display(fast_mode=False)
|
| 614 |
+
|
| 615 |
+
def schedule_brush_repeat(self):
|
| 616 |
+
"""Schedule continuous brush application while mouse is held down"""
|
| 617 |
+
if self.dragging and self.last_brush_pos:
|
| 618 |
+
tool = self.tool_var.get()
|
| 619 |
+
if tool in ["brush_add", "brush_erase"]:
|
| 620 |
+
mode = "add" if tool == "brush_add" else "erase"
|
| 621 |
+
x, y = self.last_brush_pos
|
| 622 |
+
self.apply_brush(x, y, mode)
|
| 623 |
+
|
| 624 |
+
# Update display if enough time has passed
|
| 625 |
+
current_time = time.time()
|
| 626 |
+
if current_time - self.last_update_time >= self.update_interval:
|
| 627 |
+
self.update_display(fast_mode=True)
|
| 628 |
+
self.last_update_time = current_time
|
| 629 |
+
|
| 630 |
+
# Schedule next application (every 30ms for smooth continuous painting)
|
| 631 |
+
self.brush_repeat_id = self.root.after(30, self.schedule_brush_repeat)
|
| 632 |
+
|
| 633 |
+
def copy_black_from_previous(self):
|
| 634 |
+
"""Copy ONLY black component from previous frame, preserving grey in current frame"""
|
| 635 |
+
if not self.mask_frames:
|
| 636 |
+
messagebox.showwarning("Warning", "No mask loaded")
|
| 637 |
+
return
|
| 638 |
+
|
| 639 |
+
if self.current_frame == 0:
|
| 640 |
+
messagebox.showwarning("Warning", "Cannot copy from previous frame - already at first frame")
|
| 641 |
+
return
|
| 642 |
+
|
| 643 |
+
# Save state for undo
|
| 644 |
+
self.save_state()
|
| 645 |
+
|
| 646 |
+
prev_mask = self.mask_frames[self.current_frame - 1]
|
| 647 |
+
curr_mask = self.mask_frames[self.current_frame]
|
| 648 |
+
|
| 649 |
+
# Copy ONLY the black component from previous frame
|
| 650 |
+
# Where prev has black (0 or 63): add black to curr
|
| 651 |
+
# Where prev doesn't have black (127 or 255): remove black from curr
|
| 652 |
+
|
| 653 |
+
has_black_in_prev = (prev_mask == 0) | (prev_mask == 63)
|
| 654 |
+
no_black_in_prev = (prev_mask == 127) | (prev_mask == 255)
|
| 655 |
+
|
| 656 |
+
# Remove black where prev doesn't have it (preserve grey)
|
| 657 |
+
curr_mask[no_black_in_prev & (curr_mask == 0)] = 255 # 0 β 255
|
| 658 |
+
curr_mask[no_black_in_prev & (curr_mask == 63)] = 127 # 63 β 127 (keep grey)
|
| 659 |
+
|
| 660 |
+
# Add black where prev has it (preserve grey)
|
| 661 |
+
curr_mask[has_black_in_prev & (curr_mask == 255)] = 0 # 255 β 0
|
| 662 |
+
curr_mask[has_black_in_prev & (curr_mask == 127)] = 63 # 127 β 63 (keep grey, add black)
|
| 663 |
+
|
| 664 |
+
self.update_display()
|
| 665 |
+
self.info_label.config(text="Copied black mask from previous frame", foreground="green")
|
| 666 |
+
|
| 667 |
+
def copy_grey_from_previous(self):
|
| 668 |
+
"""Copy ONLY grey component from previous frame, preserving black in current frame"""
|
| 669 |
+
if not self.mask_frames:
|
| 670 |
+
messagebox.showwarning("Warning", "No mask loaded")
|
| 671 |
+
return
|
| 672 |
+
|
| 673 |
+
if self.current_frame == 0:
|
| 674 |
+
messagebox.showwarning("Warning", "Cannot copy from previous frame - already at first frame")
|
| 675 |
+
return
|
| 676 |
+
|
| 677 |
+
# Save state for undo
|
| 678 |
+
self.save_state()
|
| 679 |
+
|
| 680 |
+
prev_mask = self.mask_frames[self.current_frame - 1]
|
| 681 |
+
curr_mask = self.mask_frames[self.current_frame]
|
| 682 |
+
|
| 683 |
+
# Copy ONLY the grey component from previous frame
|
| 684 |
+
# Where prev has grey (127 or 63): add grey to curr
|
| 685 |
+
# Where prev doesn't have grey (0 or 255): remove grey from curr
|
| 686 |
+
|
| 687 |
+
has_grey_in_prev = (prev_mask == 127) | (prev_mask == 63)
|
| 688 |
+
no_grey_in_prev = (prev_mask == 0) | (prev_mask == 255)
|
| 689 |
+
|
| 690 |
+
# Remove grey where prev doesn't have it (preserve black)
|
| 691 |
+
curr_mask[no_grey_in_prev & (curr_mask == 127)] = 255 # 127 β 255
|
| 692 |
+
curr_mask[no_grey_in_prev & (curr_mask == 63)] = 0 # 63 β 0 (keep black)
|
| 693 |
+
|
| 694 |
+
# Add grey where prev has it (preserve black)
|
| 695 |
+
curr_mask[has_grey_in_prev & (curr_mask == 255)] = 127 # 255 β 127
|
| 696 |
+
curr_mask[has_grey_in_prev & (curr_mask == 0)] = 63 # 0 β 63 (keep black, add grey)
|
| 697 |
+
|
| 698 |
+
self.update_display()
|
| 699 |
+
self.info_label.config(text="Copied grey mask from previous frame", foreground="green")
|
| 700 |
+
|
| 701 |
+
def save_state(self):
|
| 702 |
+
"""Save current state for undo"""
|
| 703 |
+
if not self.mask_frames:
|
| 704 |
+
return
|
| 705 |
+
|
| 706 |
+
# Save deep copy of current frame
|
| 707 |
+
state = {
|
| 708 |
+
'frame': self.current_frame,
|
| 709 |
+
'mask': self.mask_frames[self.current_frame].copy()
|
| 710 |
+
}
|
| 711 |
+
self.undo_stack.append(state)
|
| 712 |
+
self.redo_stack.clear()
|
| 713 |
+
|
| 714 |
+
# Limit undo stack size
|
| 715 |
+
if len(self.undo_stack) > 50:
|
| 716 |
+
self.undo_stack.pop(0)
|
| 717 |
+
|
| 718 |
+
def undo(self):
|
| 719 |
+
"""Undo last edit"""
|
| 720 |
+
if not self.undo_stack:
|
| 721 |
+
return
|
| 722 |
+
|
| 723 |
+
# Save current state to redo
|
| 724 |
+
redo_state = {
|
| 725 |
+
'frame': self.current_frame,
|
| 726 |
+
'mask': self.mask_frames[self.current_frame].copy()
|
| 727 |
+
}
|
| 728 |
+
self.redo_stack.append(redo_state)
|
| 729 |
+
|
| 730 |
+
# Restore previous state
|
| 731 |
+
state = self.undo_stack.pop()
|
| 732 |
+
self.current_frame = state['frame']
|
| 733 |
+
self.mask_frames[self.current_frame] = state['mask']
|
| 734 |
+
|
| 735 |
+
self.update_display()
|
| 736 |
+
|
| 737 |
+
def redo(self):
|
| 738 |
+
"""Redo last undone edit"""
|
| 739 |
+
if not self.redo_stack:
|
| 740 |
+
return
|
| 741 |
+
|
| 742 |
+
# Save current state to undo
|
| 743 |
+
undo_state = {
|
| 744 |
+
'frame': self.current_frame,
|
| 745 |
+
'mask': self.mask_frames[self.current_frame].copy()
|
| 746 |
+
}
|
| 747 |
+
self.undo_stack.append(undo_state)
|
| 748 |
+
|
| 749 |
+
# Restore redo state
|
| 750 |
+
state = self.redo_stack.pop()
|
| 751 |
+
self.current_frame = state['frame']
|
| 752 |
+
self.mask_frames[self.current_frame] = state['mask']
|
| 753 |
+
|
| 754 |
+
self.update_display()
|
| 755 |
+
|
| 756 |
+
def save_mask(self):
|
| 757 |
+
"""Save edited mask back to quadmask_0.mp4"""
|
| 758 |
+
if not self.mask_frames or not self.mask_path:
|
| 759 |
+
messagebox.showwarning("Warning", "No mask loaded")
|
| 760 |
+
return
|
| 761 |
+
|
| 762 |
+
# Confirm save
|
| 763 |
+
result = messagebox.askyesno("Confirm Save",
|
| 764 |
+
f"Save mask to {self.mask_path.name}?\nThis will overwrite the existing file.")
|
| 765 |
+
if not result:
|
| 766 |
+
return
|
| 767 |
+
|
| 768 |
+
self.info_label.config(text="Saving mask...", foreground="blue")
|
| 769 |
+
self.root.update()
|
| 770 |
+
|
| 771 |
+
# Write video
|
| 772 |
+
self.write_video_frames(self.mask_frames, self.mask_path)
|
| 773 |
+
|
| 774 |
+
self.info_label.config(text="Mask saved successfully!", foreground="green")
|
| 775 |
+
messagebox.showinfo("Success", f"Mask saved to {self.mask_path.name}!")
|
| 776 |
+
|
| 777 |
+
def first_frame(self):
|
| 778 |
+
"""Go to first frame"""
|
| 779 |
+
self.current_frame = 0
|
| 780 |
+
self.update_display()
|
| 781 |
+
|
| 782 |
+
def last_frame(self):
|
| 783 |
+
"""Go to last frame"""
|
| 784 |
+
if self.mask_frames:
|
| 785 |
+
self.current_frame = len(self.mask_frames) - 1
|
| 786 |
+
self.update_display()
|
| 787 |
+
|
| 788 |
+
def prev_frame(self):
|
| 789 |
+
"""Go to previous frame"""
|
| 790 |
+
if self.current_frame > 0:
|
| 791 |
+
self.current_frame -= 1
|
| 792 |
+
self.update_display()
|
| 793 |
+
|
| 794 |
+
def next_frame(self):
|
| 795 |
+
"""Go to next frame"""
|
| 796 |
+
if self.mask_frames and self.current_frame < len(self.mask_frames) - 1:
|
| 797 |
+
self.current_frame += 1
|
| 798 |
+
self.update_display()
|
| 799 |
+
|
| 800 |
+
def on_slider_change(self, value):
|
| 801 |
+
"""Handle slider change"""
|
| 802 |
+
if not self.mask_frames:
|
| 803 |
+
return
|
| 804 |
+
|
| 805 |
+
new_frame = int(float(value))
|
| 806 |
+
if new_frame != self.current_frame:
|
| 807 |
+
self.current_frame = new_frame
|
| 808 |
+
self.update_display()
|
| 809 |
+
|
| 810 |
+
def on_tool_change(self):
|
| 811 |
+
"""Handle tool selection change"""
|
| 812 |
+
tool = self.tool_var.get()
|
| 813 |
+
if tool == "grid":
|
| 814 |
+
self.info_label.config(text="Grid Toggle: Click grids to toggle 127β255", foreground="blue")
|
| 815 |
+
elif tool == "grid_black":
|
| 816 |
+
self.info_label.config(text="Grid Black Toggle: Click grids to toggle black mask (0/63)", foreground="blue")
|
| 817 |
+
elif tool == "brush_add":
|
| 818 |
+
self.info_label.config(text="Brush (Add): Paint black mask areas", foreground="blue")
|
| 819 |
+
else: # brush_erase
|
| 820 |
+
self.info_label.config(text="Brush (Erase): Erase black mask areas", foreground="blue")
|
| 821 |
+
|
| 822 |
+
def on_brush_size_change(self, value):
|
| 823 |
+
"""Handle brush size change"""
|
| 824 |
+
self.brush_size = int(float(value))
|
| 825 |
+
self.brush_size_label.config(text=str(self.brush_size))
|
| 826 |
+
|
| 827 |
+
if __name__ == "__main__":
|
| 828 |
+
root = tk.Tk()
|
| 829 |
+
root.geometry("1400x800")
|
| 830 |
+
app = MaskEditorGUI(root)
|
| 831 |
+
root.mainloop()
|
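The grid and brush tools above all manipulate the same four-value quadmask encoding (0 = black/removal, 63 = black+grey overlap, 127 = grey/gridified, 255 = background). A minimal standalone sketch of the grey-toggle rule, extracted from `toggle_grid` for illustration (the helper name `toggle_grey` is ours, not repo code):

```python
import numpy as np

# Quadmask values: 0 = black (removal), 63 = black+grey overlap,
#                  127 = grey (gridified), 255 = background.
def toggle_grey(region):
    """Toggle the grey component of a cell: 255<->127 and 0<->63."""
    if np.any((region == 127) | (region == 63)):
        # Grey present -> remove it (127 -> 255, 63 -> 0)
        return np.where(region == 127, 255,
                        np.where(region == 63, 0, region))
    # Grey absent -> add it (255 -> 127, 0 -> 63)
    return np.where(region == 255, 127,
                    np.where(region == 0, 63, region))

cell = np.array([[255, 0], [255, 255]], dtype=np.uint8)
on = toggle_grey(cell)   # grey added:   [[127, 63], [127, 127]]
off = toggle_grey(on)    # grey removed: [[255, 0], [255, 255]]
```

Because each transition has an exact inverse, toggling a cell twice restores its original values, which is why the editor can treat repeated clicks on the same cell as on/off without tracking extra state.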
VLM-MASK-REASONER/point_selector_gui.py
ADDED
@@ -0,0 +1,601 @@
#!/usr/bin/env python3
"""
Point Selector GUI - Multi-Frame Support

NEW: Supports adding points across multiple frames for complex cases.
Example: a car appears at frame 0, a hand carrying it appears at frame 30
    -> add points on the car at frame 0 and on the hand at frame 30
    -> both get segmented together as the "primary object to remove"

Usage:
    python point_selector_gui_multiframe.py --config pexel_test_config.json
"""

import cv2
import numpy as np
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
from PIL import Image, ImageTk
import json
import argparse
from pathlib import Path
from typing import List, Dict, Tuple


class PointSelectorGUI:
    def __init__(self, root, config_path=None):
        self.root = root
        self.root.title("Point Selector - Multi-Frame Support")

        # Data
        self.config_path = config_path
        self.config_data = None
        self.current_video_idx = 0
        self.current_frame_idx = 0
        self.video_captures = []
        self.total_frames_list = []

        # NEW: points organized by frame
        self.points_by_frame = {}      # {frame_idx: [(x, y), ...]}
        self.all_points_by_frame = []  # list of dicts, one per video

        # Display
        self.display_scale = 1.0
        self.photo = None
        self.point_radius = 8

        self.setup_ui()

        if config_path:
            self.load_config_direct(config_path)

    def setup_ui(self):
        """Set up the GUI layout"""
        # Menu bar
        menubar = tk.Menu(self.root)
        self.root.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="File", menu=file_menu)
        file_menu.add_command(label="Load Config", command=self.load_config)
        file_menu.add_command(label="Save Points", command=self.save_points)
        file_menu.add_separator()
        file_menu.add_command(label="Exit", command=self.root.quit)

        # Top toolbar
        toolbar = ttk.Frame(self.root)
        toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        ttk.Label(toolbar, text="Config:").pack(side=tk.LEFT)
        self.config_label = ttk.Label(toolbar, text="None", foreground="gray")
        self.config_label.pack(side=tk.LEFT, padx=5)

        ttk.Button(toolbar, text="Load Config", command=self.load_config).pack(side=tk.LEFT, padx=5)
        ttk.Button(toolbar, text="Save All Points", command=self.save_points).pack(side=tk.LEFT, padx=5)

        # Video info
        info_frame = ttk.Frame(self.root)
        info_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        self.video_label = ttk.Label(info_frame, text="Video: None", font=("Arial", 10, "bold"))
        self.video_label.pack(side=tk.LEFT, padx=5)

        self.instruction_label = ttk.Label(info_frame, text="", foreground="blue")
        self.instruction_label.pack(side=tk.LEFT, padx=10)

        # Frame navigation controls - compact (Ctrl+Left/Right shortcuts)
        frame_nav = ttk.LabelFrame(self.root, text="Frame Navigation")
        frame_nav.pack(side=tk.TOP, fill=tk.X, padx=5, pady=2)

        btn_frame = ttk.Frame(frame_nav)
        btn_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=2)

        # Compact buttons
        ttk.Button(btn_frame, text="<<", command=self.first_frame, width=3).pack(side=tk.LEFT, padx=1)
        ttk.Button(btn_frame, text="<10", command=lambda: self.prev_frame(10), width=3).pack(side=tk.LEFT, padx=1)
        ttk.Button(btn_frame, text="<", command=lambda: self.prev_frame(1), width=3).pack(side=tk.LEFT, padx=1)

        self.frame_label = ttk.Label(btn_frame, text="F: 0/0", font=("Arial", 9))
        self.frame_label.pack(side=tk.LEFT, padx=8)

        ttk.Button(btn_frame, text=">", command=lambda: self.next_frame(1), width=3).pack(side=tk.LEFT, padx=1)
        ttk.Button(btn_frame, text="10>", command=lambda: self.next_frame(10), width=3).pack(side=tk.LEFT, padx=1)
        ttk.Button(btn_frame, text=">>", command=self.last_frame, width=3).pack(side=tk.LEFT, padx=1)

        # Slider inline
        self.frame_slider = ttk.Scale(btn_frame, from_=0, to=100, orient=tk.HORIZONTAL, command=self.on_slider_change, length=250)
        self.frame_slider.pack(side=tk.LEFT, padx=5)

        # Frames with points, inline
        ttk.Label(btn_frame, text="Points:", font=("Arial", 8)).pack(side=tk.LEFT, padx=3)
        self.frames_with_points_label = ttk.Label(btn_frame, text="None", foreground="blue", font=("Arial", 8))
        self.frames_with_points_label.pack(side=tk.LEFT)

        # Main canvas - smaller so everything fits
        canvas_frame = ttk.LabelFrame(self.root, text="Click to add points")
        canvas_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5, pady=2)

        self.canvas = tk.Canvas(canvas_frame, width=800, height=450, bg='black', cursor="crosshair")
        self.canvas.pack(fill=tk.BOTH, expand=True)
        self.canvas.bind("<Button-1>", self.on_canvas_click)

        # Bottom controls - compact
        controls = ttk.Frame(self.root)
        controls.pack(side=tk.BOTTOM, fill=tk.X, padx=5, pady=2)

        # Point info - compact
        point_info = ttk.Frame(controls)
        point_info.pack(side=tk.TOP, fill=tk.X, pady=2)

        self.point_count_label = ttk.Label(point_info, text="Pts: 0", font=("Arial", 9))
        self.point_count_label.pack(side=tk.LEFT, padx=5)

        ttk.Button(point_info, text="Clear Frame", command=self.clear_current_frame, width=10).pack(side=tk.LEFT, padx=2)
        ttk.Button(point_info, text="Clear ALL", command=self.clear_all_frames, width=9).pack(side=tk.LEFT, padx=2)
        ttk.Button(point_info, text="Undo", command=self.undo_last_point, width=6).pack(side=tk.LEFT, padx=2)

        # Video navigation - compact
        nav_frame = ttk.Frame(controls)
        nav_frame.pack(side=tk.TOP, fill=tk.X, pady=2)

        ttk.Button(nav_frame, text="<< First", command=self.first_video, width=8).pack(side=tk.LEFT, padx=2)
        ttk.Button(nav_frame, text="< Prev", command=self.prev_video, width=8).pack(side=tk.LEFT, padx=2)

        self.nav_label = ttk.Label(nav_frame, text="Video: 0/0", font=("Arial", 10, "bold"))
        self.nav_label.pack(side=tk.LEFT, padx=15)

        ttk.Button(nav_frame, text="Save & Next >", command=self.save_and_next, width=12).pack(side=tk.LEFT, padx=2)
        ttk.Button(nav_frame, text="Last >>", command=self.last_video, width=8).pack(side=tk.LEFT, padx=2)

        # Status - compact
        self.status_label = ttk.Label(controls, text="Load config", foreground="blue", font=("Arial", 8))
        self.status_label.pack(side=tk.TOP, pady=2)

        # Keyboard shortcuts
        self.root.bind("<space>", lambda e: self.save_and_next())
        self.root.bind("<Left>", lambda e: self.prev_video())
        self.root.bind("<Right>", lambda e: self.save_and_next())
        self.root.bind("<Control-z>", lambda e: self.undo_last_point())
        self.root.bind("<Control-Left>", lambda e: self.prev_frame(1))
        self.root.bind("<Control-Right>", lambda e: self.next_frame(1))
        self.root.bind("<Control-Shift-Left>", lambda e: self.prev_frame(10))
        self.root.bind("<Control-Shift-Right>", lambda e: self.next_frame(10))

    def load_config_direct(self, config_path):
        """Load config from a path (for command-line usage)"""
        self.config_path = Path(config_path)

        try:
            with open(self.config_path, 'r') as f:
                self.config_data = json.load(f)
        except Exception as e:
            messagebox.showerror("Error", f"Failed to load config: {e}")
            return

        self.process_config()

    def load_config(self):
        """Load a JSON config file via a dialog"""
        filepath = filedialog.askopenfilename(
            title="Select Config JSON",
            filetypes=[("JSON files", "*.json"), ("All files", "*.*")]
        )

        if not filepath:
            return

        self.config_path = Path(filepath)

        try:
            with open(self.config_path, 'r') as f:
                self.config_data = json.load(f)
        except Exception as e:
            messagebox.showerror("Error", f"Failed to load config: {e}")
            return

        self.process_config()

    def process_config(self):
        """Process the loaded config"""
        # Validate config
        if isinstance(self.config_data, list):
            videos = self.config_data
        elif isinstance(self.config_data, dict) and "videos" in self.config_data:
            videos = self.config_data["videos"]
        else:
            messagebox.showerror("Error", "Config must be a list or have a 'videos' key")
            return

        if not isinstance(videos, list) or len(videos) == 0:
            messagebox.showerror("Error", "No videos in config")
            return

        self.videos = videos

        # Open video captures
        self.status_label.config(text="Opening video files...", foreground="blue")
        self.root.update()

        self.open_videos()

        # Initialize storage - now one dict per video
        self.all_points_by_frame = [{} for _ in range(len(self.videos))]

        # Load existing points if available
        self.load_existing_points()

        # Update UI
        self.config_label.config(text=self.config_path.name, foreground="black")
        self.current_video_idx = 0
        self.current_frame_idx = 0
        self.display_current_video()

        self.status_label.config(
            text=f"Loaded {len(self.videos)} videos. Navigate frames and click points. Points can be added on multiple frames!",
            foreground="green"
        )

    def open_videos(self):
        """Open all videos for frame navigation"""
        self.video_captures = []
        self.total_frames_list = []

        for i, video_info in enumerate(self.videos):
            video_path = video_info.get("video_path", "")

            if not video_path:
                self.video_captures.append(None)
                self.total_frames_list.append(0)
                continue

            video_path = Path(video_path)
            if not video_path.is_absolute():
                video_path = self.config_path.parent / video_path

            if not video_path.exists():
                messagebox.showwarning("Warning", f"Video not found: {video_path}")
                self.video_captures.append(None)
                self.total_frames_list.append(0)
                continue

            cap = cv2.VideoCapture(str(video_path))
            if not cap.isOpened():
                messagebox.showwarning("Warning", f"Failed to open video: {video_path}")
                self.video_captures.append(None)
                self.total_frames_list.append(0)
                continue

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            self.video_captures.append(cap)
            self.total_frames_list.append(total_frames)

            self.status_label.config(text=f"Opened video {i+1}/{len(self.videos)}", foreground="blue")
            self.root.update()

    def load_existing_points(self):
        """Load existing points from the output file if it exists"""
        output_path = self.config_path.parent / f"{self.config_path.stem}_points.json"
|
| 278 |
+
|
| 279 |
+
if not output_path.exists():
|
| 280 |
+
return
|
| 281 |
+
|
| 282 |
+
try:
|
| 283 |
+
with open(output_path, 'r') as f:
|
| 284 |
+
existing_data = json.load(f)
|
| 285 |
+
|
| 286 |
+
if isinstance(existing_data, list):
|
| 287 |
+
existing_videos = existing_data
|
| 288 |
+
elif isinstance(existing_data, dict) and "videos" in existing_data:
|
| 289 |
+
existing_videos = existing_data["videos"]
|
| 290 |
+
else:
|
| 291 |
+
return
|
| 292 |
+
|
| 293 |
+
for i, video_data in enumerate(existing_videos):
|
| 294 |
+
if i < len(self.all_points_by_frame):
|
| 295 |
+
# Load multi-frame format
|
| 296 |
+
points_by_frame = video_data.get("primary_points_by_frame", {})
|
| 297 |
+
# Convert string keys to int
|
| 298 |
+
self.all_points_by_frame[i] = {int(k): v for k, v in points_by_frame.items()}
|
| 299 |
+
|
| 300 |
+
self.status_label.config(text="Loaded existing points", foreground="green")
|
| 301 |
+
except Exception as e:
|
| 302 |
+
print(f"Warning: Could not load existing points: {e}")
|
| 303 |
+
|
| 304 |
+
def get_current_frame(self):
|
| 305 |
+
"""Get frame at current_frame_idx from current video"""
|
| 306 |
+
if self.current_video_idx >= len(self.video_captures):
|
| 307 |
+
return None
|
| 308 |
+
|
| 309 |
+
cap = self.video_captures[self.current_video_idx]
|
| 310 |
+
if cap is None:
|
| 311 |
+
return None
|
| 312 |
+
|
| 313 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame_idx)
|
| 314 |
+
ret, frame = cap.read()
|
| 315 |
+
|
| 316 |
+
if not ret:
|
| 317 |
+
return None
|
| 318 |
+
|
| 319 |
+
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 320 |
+
|
| 321 |
+
def display_current_video(self):
|
| 322 |
+
"""Display current video frame"""
|
| 323 |
+
if not self.video_captures:
|
| 324 |
+
return
|
| 325 |
+
|
| 326 |
+
video_info = self.videos[self.current_video_idx]
|
| 327 |
+
video_path = video_info.get("video_path", "")
|
| 328 |
+
|
| 329 |
+
# Update labels
|
| 330 |
+
self.video_label.config(text=f"Video: {Path(video_path).name}")
|
| 331 |
+
instruction = video_info.get("instruction", "")
|
| 332 |
+
if instruction:
|
| 333 |
+
self.instruction_label.config(text=f"Instruction: {instruction}")
|
| 334 |
+
|
| 335 |
+
self.nav_label.config(text=f"Video: {self.current_video_idx + 1}/{len(self.videos)}")
|
| 336 |
+
|
| 337 |
+
# Load points for this video
|
| 338 |
+
self.points_by_frame = self.all_points_by_frame[self.current_video_idx].copy()
|
| 339 |
+
|
| 340 |
+
# Update frame controls
|
| 341 |
+
total_frames = self.total_frames_list[self.current_video_idx]
|
| 342 |
+
self.frame_slider.config(to=max(1, total_frames - 1))
|
| 343 |
+
self.frame_slider.set(self.current_frame_idx)
|
| 344 |
+
self.frame_label.config(text=f"F: {self.current_frame_idx}/{total_frames - 1}")
|
| 345 |
+
|
| 346 |
+
# Update frames with points display
|
| 347 |
+
self.update_frames_display()
|
| 348 |
+
|
| 349 |
+
self.display_frame()
|
| 350 |
+
|
| 351 |
+
def update_frames_display(self):
|
| 352 |
+
"""Update display showing which frames have points"""
|
| 353 |
+
if not self.points_by_frame:
|
| 354 |
+
self.frames_with_points_label.config(text="None", foreground="gray")
|
| 355 |
+
else:
|
| 356 |
+
frames = sorted(self.points_by_frame.keys())
|
| 357 |
+
frames_str = ", ".join(f"F{f}" for f in frames)
|
| 358 |
+
total_points = sum(len(pts) for pts in self.points_by_frame.values())
|
| 359 |
+
self.frames_with_points_label.config(
|
| 360 |
+
text=f"{frames_str} ({total_points} total points)",
|
| 361 |
+
foreground="green"
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
def display_frame(self):
|
| 365 |
+
"""Display current frame with points"""
|
| 366 |
+
frame = self.get_current_frame()
|
| 367 |
+
if frame is None:
|
| 368 |
+
return
|
| 369 |
+
|
| 370 |
+
# Draw points for CURRENT frame
|
| 371 |
+
vis = frame.copy()
|
| 372 |
+
current_points = self.points_by_frame.get(self.current_frame_idx, [])
|
| 373 |
+
|
| 374 |
+
for i, (x, y) in enumerate(current_points):
|
| 375 |
+
cv2.circle(vis, (x, y), self.point_radius, (255, 0, 0), -1)
|
| 376 |
+
cv2.circle(vis, (x, y), self.point_radius + 2, (255, 255, 255), 2)
|
| 377 |
+
cv2.putText(vis, str(i + 1), (x + 12, y + 12), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
| 378 |
+
|
| 379 |
+
# Show indicator if other frames have points
|
| 380 |
+
if len(self.points_by_frame) > 0:
|
| 381 |
+
other_frames = [f for f in self.points_by_frame.keys() if f != self.current_frame_idx]
|
| 382 |
+
if other_frames:
|
| 383 |
+
text = f"Other frames with points: {', '.join(map(str, sorted(other_frames)))}"
|
| 384 |
+
cv2.putText(vis, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
| 385 |
+
|
| 386 |
+
# Scale for display
|
| 387 |
+
h, w = vis.shape[:2]
|
| 388 |
+
max_width, max_height = 800, 450
|
| 389 |
+
scale_w = max_width / w
|
| 390 |
+
scale_h = max_height / h
|
| 391 |
+
self.display_scale = min(scale_w, scale_h, 1.0)
|
| 392 |
+
|
| 393 |
+
new_w = int(w * self.display_scale)
|
| 394 |
+
new_h = int(h * self.display_scale)
|
| 395 |
+
vis_resized = cv2.resize(vis, (new_w, new_h))
|
| 396 |
+
|
| 397 |
+
# Convert to PIL and display
|
| 398 |
+
pil_img = Image.fromarray(vis_resized)
|
| 399 |
+
self.photo = ImageTk.PhotoImage(pil_img)
|
| 400 |
+
self.canvas.delete("all")
|
| 401 |
+
self.canvas.create_image(0, 0, anchor=tk.NW, image=self.photo)
|
| 402 |
+
|
| 403 |
+
self.point_count_label.config(text=f"Pts on F{self.current_frame_idx}: {len(current_points)}")
|
| 404 |
+
|
| 405 |
+
def on_canvas_click(self, event):
|
| 406 |
+
"""Handle click on canvas - add point to CURRENT frame"""
|
| 407 |
+
# Convert to frame coordinates
|
| 408 |
+
x = int(event.x / self.display_scale)
|
| 409 |
+
y = int(event.y / self.display_scale)
|
| 410 |
+
|
| 411 |
+
# Add to current frame
|
| 412 |
+
if self.current_frame_idx not in self.points_by_frame:
|
| 413 |
+
self.points_by_frame[self.current_frame_idx] = []
|
| 414 |
+
|
| 415 |
+
self.points_by_frame[self.current_frame_idx].append((x, y))
|
| 416 |
+
self.update_frames_display()
|
| 417 |
+
self.display_frame()
|
| 418 |
+
|
| 419 |
+
def clear_current_frame(self):
|
| 420 |
+
"""Clear points for current frame only"""
|
| 421 |
+
if self.current_frame_idx in self.points_by_frame:
|
| 422 |
+
del self.points_by_frame[self.current_frame_idx]
|
| 423 |
+
self.update_frames_display()
|
| 424 |
+
self.display_frame()
|
| 425 |
+
|
| 426 |
+
def clear_all_frames(self):
|
| 427 |
+
"""Clear all points for current video"""
|
| 428 |
+
result = messagebox.askyesno("Clear All", "Clear points from ALL frames?")
|
| 429 |
+
if result:
|
| 430 |
+
self.points_by_frame = {}
|
| 431 |
+
self.update_frames_display()
|
| 432 |
+
self.display_frame()
|
| 433 |
+
|
| 434 |
+
def undo_last_point(self):
|
| 435 |
+
"""Remove last point from current frame"""
|
| 436 |
+
if self.current_frame_idx in self.points_by_frame and self.points_by_frame[self.current_frame_idx]:
|
| 437 |
+
self.points_by_frame[self.current_frame_idx].pop()
|
| 438 |
+
if not self.points_by_frame[self.current_frame_idx]:
|
| 439 |
+
del self.points_by_frame[self.current_frame_idx]
|
| 440 |
+
self.update_frames_display()
|
| 441 |
+
self.display_frame()
|
| 442 |
+
|
| 443 |
+
# Frame navigation methods
|
| 444 |
+
def first_frame(self):
|
| 445 |
+
"""Jump to first frame"""
|
| 446 |
+
self.current_frame_idx = 0
|
| 447 |
+
self.frame_slider.set(self.current_frame_idx)
|
| 448 |
+
self.update_frame_display()
|
| 449 |
+
|
| 450 |
+
def last_frame(self):
|
| 451 |
+
"""Jump to last frame"""
|
| 452 |
+
total_frames = self.total_frames_list[self.current_video_idx]
|
| 453 |
+
self.current_frame_idx = max(0, total_frames - 1)
|
| 454 |
+
self.frame_slider.set(self.current_frame_idx)
|
| 455 |
+
self.update_frame_display()
|
| 456 |
+
|
| 457 |
+
def prev_frame(self, step=1):
|
| 458 |
+
"""Go to previous frame"""
|
| 459 |
+
self.current_frame_idx = max(0, self.current_frame_idx - step)
|
| 460 |
+
self.frame_slider.set(self.current_frame_idx)
|
| 461 |
+
self.update_frame_display()
|
| 462 |
+
|
| 463 |
+
def next_frame(self, step=1):
|
| 464 |
+
"""Go to next frame"""
|
| 465 |
+
total_frames = self.total_frames_list[self.current_video_idx]
|
| 466 |
+
self.current_frame_idx = min(total_frames - 1, self.current_frame_idx + step)
|
| 467 |
+
self.frame_slider.set(self.current_frame_idx)
|
| 468 |
+
self.update_frame_display()
|
| 469 |
+
|
| 470 |
+
def on_slider_change(self, value):
|
| 471 |
+
"""Handle slider change"""
|
| 472 |
+
self.current_frame_idx = int(float(value))
|
| 473 |
+
self.update_frame_display()
|
| 474 |
+
|
| 475 |
+
def update_frame_display(self):
|
| 476 |
+
"""Update frame label and display"""
|
| 477 |
+
total_frames = self.total_frames_list[self.current_video_idx]
|
| 478 |
+
self.frame_label.config(text=f"F: {self.current_frame_idx}/{total_frames - 1}")
|
| 479 |
+
self.display_frame()
|
| 480 |
+
|
| 481 |
+
# Video navigation
|
| 482 |
+
def first_video(self):
|
| 483 |
+
"""Jump to first video"""
|
| 484 |
+
self.save_current_points()
|
| 485 |
+
self.current_video_idx = 0
|
| 486 |
+
self.current_frame_idx = 0
|
| 487 |
+
self.display_current_video()
|
| 488 |
+
|
| 489 |
+
def last_video(self):
|
| 490 |
+
"""Jump to last video"""
|
| 491 |
+
self.save_current_points()
|
| 492 |
+
self.current_video_idx = len(self.videos) - 1
|
| 493 |
+
self.current_frame_idx = 0
|
| 494 |
+
self.display_current_video()
|
| 495 |
+
|
| 496 |
+
def prev_video(self):
|
| 497 |
+
"""Go to previous video"""
|
| 498 |
+
if self.current_video_idx > 0:
|
| 499 |
+
self.save_current_points()
|
| 500 |
+
self.current_video_idx -= 1
|
| 501 |
+
self.current_frame_idx = 0
|
| 502 |
+
self.display_current_video()
|
| 503 |
+
|
| 504 |
+
def save_and_next(self):
|
| 505 |
+
"""Save current points and move to next video"""
|
| 506 |
+
if len(self.points_by_frame) == 0:
|
| 507 |
+
result = messagebox.askyesno("No Points", "No points selected for any frame. Continue to next video?")
|
| 508 |
+
if not result:
|
| 509 |
+
return
|
| 510 |
+
|
| 511 |
+
self.save_current_points()
|
| 512 |
+
|
| 513 |
+
if self.current_video_idx < len(self.videos) - 1:
|
| 514 |
+
self.current_video_idx += 1
|
| 515 |
+
self.current_frame_idx = 0
|
| 516 |
+
self.display_current_video()
|
| 517 |
+
else:
|
| 518 |
+
messagebox.showinfo("Complete", "All videos processed!")
|
| 519 |
+
|
| 520 |
+
def save_current_points(self):
|
| 521 |
+
"""Save current video's points to storage"""
|
| 522 |
+
self.all_points_by_frame[self.current_video_idx] = self.points_by_frame.copy()
|
| 523 |
+
|
| 524 |
+
def save_points(self):
|
| 525 |
+
"""Save all points to JSON file"""
|
| 526 |
+
if not self.config_path:
|
| 527 |
+
messagebox.showerror("Error", "No config loaded")
|
| 528 |
+
return
|
| 529 |
+
|
| 530 |
+
# Save current video first
|
| 531 |
+
self.save_current_points()
|
| 532 |
+
|
| 533 |
+
# Build output
|
| 534 |
+
output_videos = []
|
| 535 |
+
for i, video_info in enumerate(self.videos):
|
| 536 |
+
video_data = video_info.copy()
|
| 537 |
+
|
| 538 |
+
points_by_frame = self.all_points_by_frame[i]
|
| 539 |
+
|
| 540 |
+
# Convert to serializable format (int keys β string keys for JSON)
|
| 541 |
+
video_data["primary_points_by_frame"] = {
|
| 542 |
+
str(frame_idx): points for frame_idx, points in points_by_frame.items()
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
# Also save list of frames for easy access
|
| 546 |
+
video_data["primary_frames"] = sorted(points_by_frame.keys())
|
| 547 |
+
|
| 548 |
+
# Backwards compatibility: if only one frame, save as before
|
| 549 |
+
if len(points_by_frame) == 1:
|
| 550 |
+
frame_idx = list(points_by_frame.keys())[0]
|
| 551 |
+
video_data["first_appears_frame"] = frame_idx
|
| 552 |
+
video_data["primary_points"] = points_by_frame[frame_idx]
|
| 553 |
+
elif len(points_by_frame) > 1:
|
| 554 |
+
# Multiple frames - use first frame as "first_appears_frame"
|
| 555 |
+
video_data["first_appears_frame"] = min(points_by_frame.keys())
|
| 556 |
+
# Flatten all points for backwards compat (not ideal but helps)
|
| 557 |
+
all_points = []
|
| 558 |
+
for frame_idx in sorted(points_by_frame.keys()):
|
| 559 |
+
all_points.extend(points_by_frame[frame_idx])
|
| 560 |
+
video_data["primary_points"] = all_points
|
| 561 |
+
|
| 562 |
+
output_videos.append(video_data)
|
| 563 |
+
|
| 564 |
+
# Match input format
|
| 565 |
+
if isinstance(self.config_data, list):
|
| 566 |
+
output_data = output_videos
|
| 567 |
+
else:
|
| 568 |
+
output_data = {"videos": output_videos}
|
| 569 |
+
|
| 570 |
+
# Save
|
| 571 |
+
output_path = self.config_path.parent / f"{self.config_path.stem}_points.json"
|
| 572 |
+
|
| 573 |
+
try:
|
| 574 |
+
with open(output_path, 'w') as f:
|
| 575 |
+
json.dump(output_data, f, indent=2)
|
| 576 |
+
|
| 577 |
+
self.status_label.config(text=f"Saved to {output_path.name}", foreground="green")
|
| 578 |
+
messagebox.showinfo("Success", f"Points saved to:\n{output_path}")
|
| 579 |
+
except Exception as e:
|
| 580 |
+
messagebox.showerror("Error", f"Failed to save: {e}")
|
| 581 |
+
|
| 582 |
+
def __del__(self):
|
| 583 |
+
"""Clean up video captures"""
|
| 584 |
+
for cap in self.video_captures:
|
| 585 |
+
if cap is not None:
|
| 586 |
+
cap.release()
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def main():
|
| 590 |
+
parser = argparse.ArgumentParser(description="Point Selector GUI - Multi-Frame Support")
|
| 591 |
+
parser.add_argument("--config", help="Config JSON file to load")
|
| 592 |
+
args = parser.parse_args()
|
| 593 |
+
|
| 594 |
+
root = tk.Tk()
|
| 595 |
+
root.geometry("900x750") # Compact height to fit on screen
|
| 596 |
+
gui = PointSelectorGUI(root, config_path=args.config)
|
| 597 |
+
root.mainloop()
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
if __name__ == "__main__":
|
| 601 |
+
main()
|
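The `save_points`/`load_existing_points` pair above hinges on an easy-to-miss detail: JSON object keys are always strings, so the integer frame indices in `primary_points_by_frame` are stringified on save and must be converted back to ints on load. A minimal sketch of that round trip (the helper names are illustrative, not from the repo):

```python
import json

def serialize_points(points_by_frame):
    # Mirror save_points: JSON object keys must be strings
    return json.dumps({str(k): v for k, v in points_by_frame.items()})

def deserialize_points(payload):
    # Mirror load_existing_points: string keys back to int frame indices
    return {int(k): v for k, v in json.loads(payload).items()}

points = {0: [[120, 80], [130, 95]], 17: [[200, 150]]}
restored = deserialize_points(serialize_points(points))
assert restored == points
assert sorted(restored.keys()) == [0, 17]  # what "primary_frames" records
```

Skipping the `int(k)` step on load would silently create a second, string-keyed entry next to any int-keyed one added by clicking, which is why both the GUI and stage 1 normalize keys immediately after `json.load`.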
VLM-MASK-REASONER/run_pipeline.sh
ADDED
|
@@ -0,0 +1,74 @@
|
#!/bin/bash
# run_pipeline.sh
# Runs stages 1-4 given a points config JSON (output of point_selector_gui.py)
#
# Usage:
#   bash run_pipeline.sh <config_points.json> [--sam2-checkpoint PATH] [--device cuda]
#
# Example:
#   bash run_pipeline.sh my_config_points.json
#   bash run_pipeline.sh my_config_points.json --sam2-checkpoint ../sam2_hiera_large.pt

set -e

# ── Arguments ──────────────────────────────────────────────────────────────────
CONFIG="$1"
if [ -z "$CONFIG" ]; then
    echo "Usage: bash run_pipeline.sh <config_points.json> [--sam2-checkpoint PATH] [--device cuda]"
    exit 1
fi

SAM2_CHECKPOINT="../sam2_hiera_large.pt"
DEVICE="cuda"

# Parse optional flags
shift
while [[ $# -gt 0 ]]; do
    case "$1" in
        --sam2-checkpoint) SAM2_CHECKPOINT="$2"; shift 2 ;;
        --device) DEVICE="$2"; shift 2 ;;
        *) echo "Unknown argument: $1"; exit 1 ;;
    esac
done

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

echo "=========================================="
echo "Void Mask Generation Pipeline"
echo "=========================================="
echo "Config: $CONFIG"
echo "SAM2 checkpoint: $SAM2_CHECKPOINT"
echo "Device: $DEVICE"
echo "=========================================="

# ── Stage 1: SAM2 Segmentation ─────────────────────────────────────────────────
echo ""
echo "[1/4] SAM2 segmentation..."
python "$SCRIPT_DIR/stage1_sam2_segmentation.py" \
    --config "$CONFIG" \
    --sam2-checkpoint "$SAM2_CHECKPOINT" \
    --device "$DEVICE"

# ── Stage 2: VLM Analysis ──────────────────────────────────────────────────────
echo ""
echo "[2/4] VLM analysis (Gemini)..."
python "$SCRIPT_DIR/stage2_vlm_analysis.py" \
    --config "$CONFIG"

# ── Stage 3a: Generate Grey Masks ──────────────────────────────────────────────
echo ""
echo "[3/4] Generating grey masks..."
python "$SCRIPT_DIR/stage3a_generate_grey_masks_v2.py" \
    --config "$CONFIG"

# ── Stage 4: Combine into Quadmask ─────────────────────────────────────────────
echo ""
echo "[4/4] Combining masks into quadmask_0.mp4..."
python "$SCRIPT_DIR/stage4_combine_masks.py" \
    --config "$CONFIG"

echo ""
echo "=========================================="
echo "Pipeline complete!"
echo "Output: quadmask_0.mp4 in each video's output_dir"
echo "=========================================="
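For callers that prefer Python over bash, the same four-stage chain that run_pipeline.sh executes can be sketched with `subprocess`. This is a hypothetical helper, not part of the repo; it only mirrors which scripts are invoked and which flags each stage takes (only stage 1 receives `--sam2-checkpoint` and `--device`):

```python
import subprocess
from pathlib import Path

# Stage scripts in execution order, as invoked by run_pipeline.sh
STAGES = [
    "stage1_sam2_segmentation.py",
    "stage2_vlm_analysis.py",
    "stage3a_generate_grey_masks_v2.py",
    "stage4_combine_masks.py",
]

def build_commands(config, script_dir=".",
                   sam2_checkpoint="../sam2_hiera_large.pt", device="cuda"):
    """Build the per-stage command lines that run_pipeline.sh executes."""
    cmds = []
    for script in STAGES:
        cmd = ["python", str(Path(script_dir) / script), "--config", config]
        if script.startswith("stage1"):
            # Only stage 1 takes the SAM2 checkpoint and device flags
            cmd += ["--sam2-checkpoint", sam2_checkpoint, "--device", device]
        cmds.append(cmd)
    return cmds

def run_pipeline(config, **kwargs):
    for cmd in build_commands(config, **kwargs):
        subprocess.run(cmd, check=True)  # check=True mirrors bash's `set -e`

if __name__ == "__main__":
    for cmd in build_commands("my_config_points.json"):
        print(" ".join(cmd))
```

`check=True` makes a non-zero exit from any stage raise `CalledProcessError`, matching the bash script's fail-fast `set -e` behavior.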
VLM-MASK-REASONER/stage1_sam2_segmentation.py
ADDED
|
@@ -0,0 +1,419 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Stage 1: SAM2 Point-Prompted Segmentation
|
| 4 |
+
|
| 5 |
+
Takes user-selected points and generates pixel-perfect masks of primary objects
|
| 6 |
+
using SAM2 video tracking.
|
| 7 |
+
|
| 8 |
+
Input: <config>_points.json (with primary_points)
|
| 9 |
+
Output: For each video:
|
| 10 |
+
- black_mask.mp4: Primary object mask (0=object, 255=background)
|
| 11 |
+
- first_frame.jpg: First frame for VLM analysis
|
| 12 |
+
- segmentation_info.json: Metadata
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
python stage1_sam2_segmentation.py --config more_dyn_2_config_points.json
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
import json
|
| 21 |
+
import argparse
|
| 22 |
+
import cv2
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import tempfile
|
| 26 |
+
import shutil
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Dict, List, Tuple
|
| 29 |
+
import subprocess
|
| 30 |
+
|
| 31 |
+
# Check SAM2 availability
|
| 32 |
+
try:
|
| 33 |
+
from sam2.build_sam import build_sam2_video_predictor
|
| 34 |
+
SAM2_AVAILABLE = True
|
| 35 |
+
except ImportError:
|
| 36 |
+
SAM2_AVAILABLE = False
|
| 37 |
+
print("β οΈ SAM2 not installed. Install with:")
|
| 38 |
+
print(" pip install git+https://github.com/facebookresearch/segment-anything-2.git")
|
| 39 |
+
sys.exit(1)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SAM2PointSegmenter:
|
| 43 |
+
"""SAM2 video segmentation with point prompts"""
|
| 44 |
+
|
| 45 |
+
def __init__(self, checkpoint_path: str, model_cfg: str = "sam2_hiera_l.yaml", device: str = "cuda"):
|
| 46 |
+
print(f" Loading SAM2 video predictor...")
|
| 47 |
+
self.device = device
|
| 48 |
+
self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=device)
|
| 49 |
+
print(f" β SAM2 loaded on {device}")
|
| 50 |
+
|
| 51 |
+
def segment_video(self, video_path: str, points: List[List[int]] = None,
|
| 52 |
+
output_mask_path: str = None, temp_dir: str = None,
|
| 53 |
+
first_appears_frame: int = 0,
|
| 54 |
+
points_by_frame: Dict[int, List[List[int]]] = None) -> Dict:
|
| 55 |
+
"""
|
| 56 |
+
Segment video using point prompts (single or multi-frame).
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
video_path: Path to input video
|
| 60 |
+
points: List of [x, y] points on object (single frame, legacy)
|
| 61 |
+
output_mask_path: Path to save mask video
|
| 62 |
+
temp_dir: Directory for temporary frames
|
| 63 |
+
first_appears_frame: Frame index where points were selected (for single frame)
|
| 64 |
+
points_by_frame: Dict mapping frame_idx β [[x, y], ...] (multi-frame support)
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
Dict with segmentation metadata
|
| 68 |
+
"""
|
| 69 |
+
# Handle both old and new formats
|
| 70 |
+
if points_by_frame is not None:
|
| 71 |
+
# Multi-frame format
|
| 72 |
+
if not points_by_frame or len(points_by_frame) == 0:
|
| 73 |
+
raise ValueError("No points provided")
|
| 74 |
+
elif points is not None:
|
| 75 |
+
# Single frame format (backwards compat)
|
| 76 |
+
if not points or len(points) == 0:
|
| 77 |
+
raise ValueError("No points provided")
|
| 78 |
+
points_by_frame = {first_appears_frame: points}
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError("Must provide either points or points_by_frame")
|
| 81 |
+
|
| 82 |
+
# Create temp directory for frames
|
| 83 |
+
if temp_dir is None:
|
| 84 |
+
temp_dir = tempfile.mkdtemp(prefix="sam2_frames_")
|
| 85 |
+
cleanup = True
|
| 86 |
+
else:
|
| 87 |
+
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
| 88 |
+
cleanup = False
|
| 89 |
+
|
| 90 |
+
print(f" Extracting frames to: {temp_dir}")
|
| 91 |
+
frame_files = self._extract_frames(video_path, temp_dir)
|
| 92 |
+
|
| 93 |
+
if len(frame_files) == 0:
|
| 94 |
+
raise RuntimeError(f"No frames extracted from {video_path}")
|
| 95 |
+
|
| 96 |
+
# Get video properties
|
| 97 |
+
cap = cv2.VideoCapture(video_path)
|
| 98 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 99 |
+
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 100 |
+
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 101 |
+
total_frames = len(frame_files)
|
| 102 |
+
cap.release()
|
| 103 |
+
|
| 104 |
+
# Count total points across all frames
|
| 105 |
+
total_points = sum(len(pts) for pts in points_by_frame.values())
|
| 106 |
+
print(f" Video: {frame_width}x{frame_height}, {total_frames} frames @ {fps} fps")
|
| 107 |
+
print(f" Using {total_points} points across {len(points_by_frame)} frame(s) for segmentation")
|
| 108 |
+
|
| 109 |
+
# Initialize SAM2
|
| 110 |
+
print(f" Initializing SAM2...")
|
| 111 |
+
inference_state = self.predictor.init_state(video_path=temp_dir)
|
| 112 |
+
|
| 113 |
+
# Add points for each frame (all with obj_id=1 to merge into single mask)
|
| 114 |
+
for frame_idx in sorted(points_by_frame.keys()):
|
| 115 |
+
frame_points = points_by_frame[frame_idx]
|
| 116 |
+
|
| 117 |
+
# Convert points to numpy array
|
| 118 |
+
points_np = np.array(frame_points, dtype=np.float32)
|
| 119 |
+
labels_np = np.ones(len(frame_points), dtype=np.int32) # All positive
|
| 120 |
+
|
| 121 |
+
# Calculate bounding box from points (with 10% margin for hair/clothes)
|
| 122 |
+
x_coords = points_np[:, 0]
|
| 123 |
+
y_coords = points_np[:, 1]
|
| 124 |
+
|
| 125 |
+
x_min, x_max = x_coords.min(), x_coords.max()
|
| 126 |
+
y_min, y_max = y_coords.min(), y_coords.max()
|
| 127 |
+
|
| 128 |
+
# Add 10% margin
|
| 129 |
+
x_margin = (x_max - x_min) * 0.1
|
| 130 |
+
y_margin = (y_max - y_min) * 0.1
|
| 131 |
+
|
| 132 |
+
box = np.array([
|
| 133 |
+
max(0, x_min - x_margin),
|
| 134 |
+
max(0, y_min - y_margin),
|
| 135 |
+
min(frame_width, x_max + x_margin),
|
| 136 |
+
min(frame_height, y_max + y_margin)
|
| 137 |
+
], dtype=np.float32)
|
| 138 |
+
|
| 139 |
+
print(f" Adding {len(frame_points)} points + box to frame {frame_idx}")
|
| 140 |
+
print(f" Points: {frame_points[:3]}..." if len(frame_points) > 3 else f" Points: {frame_points}")
|
| 141 |
+
print(f" Box: [{int(box[0])}, {int(box[1])}, {int(box[2])}, {int(box[3])}]")
|
| 142 |
+
|
| 143 |
+
# Add points + box to this frame (all use obj_id=1 to merge)
|
| 144 |
+
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
|
| 145 |
+
inference_state=inference_state,
|
| 146 |
+
frame_idx=frame_idx,
|
| 147 |
+
obj_id=1,
|
| 148 |
+
points=points_np,
|
| 149 |
+
labels=labels_np,
|
| 150 |
+
box=box,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
print(f" Propagating through video...")
|
| 154 |
+
|
| 155 |
+
# Propagate through video
|
| 156 |
+
video_segments = {}
|
| 157 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
|
| 158 |
+
# Get mask for object ID 1
|
| 159 |
+
mask_logits = out_mask_logits[out_obj_ids.index(1)]
|
| 160 |
+
mask = (mask_logits > 0.0).cpu().numpy().squeeze()
|
| 161 |
+
video_segments[out_frame_idx] = mask
|
| 162 |
+
|
| 163 |
+
print(f" β Segmented {len(video_segments)} frames")
|
| 164 |
+
|
| 165 |
+
# Write mask video
|
| 166 |
+
print(f" Writing mask video...")
|
| 167 |
+
self._write_mask_video(video_segments, output_mask_path, fps, frame_width, frame_height)
|
| 168 |
+
|
| 169 |
+
# Cleanup
|
| 170 |
+
if cleanup:
|
| 171 |
+
shutil.rmtree(temp_dir)
|
| 172 |
+
|
| 173 |
+
# Build metadata
|
| 174 |
+
metadata = {
|
| 175 |
+
"total_frames": total_frames,
|
| 176 |
+
"frame_width": frame_width,
|
| 177 |
+
"frame_height": frame_height,
|
| 178 |
+
"fps": fps,
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
# Add points info based on format
|
| 182 |
+
if points_by_frame:
|
| 183 |
+
total_points = sum(len(pts) for pts in points_by_frame.values())
|
| 184 |
+
metadata["num_points"] = total_points
|
| 185 |
+
metadata["points_by_frame"] = {str(k): v for k, v in points_by_frame.items()}
|
| 186 |
+
else:
|
| 187 |
+
metadata["num_points"] = len(points) if points else 0
|
| 188 |
+
metadata["points"] = points
|
| 189 |
+
|
| 190 |
+
return metadata
|
| 191 |
+
|
| 192 |
+
    def _extract_frames(self, video_path: str, output_dir: str) -> List[str]:
        """Extract video frames as JPG files"""
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        cap = cv2.VideoCapture(video_path)
        frame_idx = 0
        frame_files = []

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # SAM2 expects frames named as frame_000000.jpg, frame_000001.jpg, etc.
            frame_filename = f"{frame_idx:06d}.jpg"
            frame_path = os.path.join(output_dir, frame_filename)
            cv2.imwrite(frame_path, frame)
            frame_files.append(frame_path)
            frame_idx += 1

            if frame_idx % 20 == 0:
                print(f"  Extracted {frame_idx} frames...", end='\r')

        cap.release()
        print(f"  Extracted {frame_idx} frames")

        return frame_files

    def _write_mask_video(self, masks: Dict[int, np.ndarray], output_path: str,
                          fps: float, width: int, height: int):
        """Write masks to video file"""
        # Write temp AVI first
        temp_avi = Path(output_path).with_suffix('.avi')
        fourcc = cv2.VideoWriter_fourcc(*'FFV1')
        out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (width, height), isColor=False)

        for frame_idx in sorted(masks.keys()):
            mask = masks[frame_idx]
            # Convert boolean mask to 0/255 (object = black, background = white)
            mask_uint8 = np.where(mask, 0, 255).astype(np.uint8)
            out.write(mask_uint8)

        out.release()

        # Convert to lossless MP4
        cmd = [
            'ffmpeg', '-y', '-i', str(temp_avi),
            '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
            '-pix_fmt', 'yuv444p',
            str(output_path)
        ]
        subprocess.run(cmd, capture_output=True)
        temp_avi.unlink()

        print(f"  ✓ Saved mask video: {output_path}")
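The mask polarity used above is easy to get backwards: the primary object is stored as BLACK (0) on a WHITE (255) background. A minimal NumPy-only sketch of that convention (no OpenCV or ffmpeg needed):

```python
import numpy as np

# Boolean object mask: True where the primary object is.
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True            # object occupies the centre 2x2 block

# Same conversion as _write_mask_video: object -> 0, background -> 255.
frame = np.where(mask, 0, 255).astype(np.uint8)

assert frame[2, 2] == 0          # object pixels are black
assert frame[0, 0] == 255        # background pixels are white
```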

def process_config(config_path: str, sam2_checkpoint: str, device: str = "cuda"):
    """Process all videos in config"""
    config_path = Path(config_path)

    # Load config
    with open(config_path, 'r') as f:
        config_data = json.load(f)

    # Handle both formats
    if isinstance(config_data, list):
        videos = config_data
    elif isinstance(config_data, dict) and "videos" in config_data:
        videos = config_data["videos"]
    else:
        raise ValueError("Config must be a list or have 'videos' key")

    print(f"\n{'='*70}")
    print(f"Stage 1: SAM2 Point-Prompted Segmentation")
    print(f"{'='*70}")
    print(f"Config: {config_path.name}")
    print(f"Videos: {len(videos)}")
    print(f"Device: {device}")
    print(f"{'='*70}\n")

    # Initialize SAM2
    segmenter = SAM2PointSegmenter(sam2_checkpoint, device=device)

    # Process each video
    for i, video_info in enumerate(videos):
        video_path = video_info.get("video_path", "")
        instruction = video_info.get("instruction", "")
        output_dir = video_info.get("output_dir", "")

        # Read points - support both single-frame and multi-frame formats
        points_by_frame_raw = video_info.get("primary_points_by_frame", None)
        points = video_info.get("primary_points", [])
        first_appears_frame = video_info.get("first_appears_frame", 0)

        # Convert points_by_frame from string keys to int keys
        points_by_frame = None
        if points_by_frame_raw:
            points_by_frame = {int(k): v for k, v in points_by_frame_raw.items()}

        if not video_path:
            print(f"\n⚠️  Video {i+1}: No video_path, skipping")
            continue

        if not points and not points_by_frame:
            print(f"\n⚠️  Video {i+1}: No primary_points selected, skipping")
            continue

        video_path = Path(video_path)
        if not video_path.exists():
            print(f"\n⚠️  Video {i+1}: File not found: {video_path}, skipping")
            continue

        print(f"\n{'─'*70}")
        print(f"Video {i+1}/{len(videos)}: {video_path.name}")
        print(f"{'─'*70}")
        print(f"Instruction: {instruction}")

        if points_by_frame:
            total_points = sum(len(pts) for pts in points_by_frame.values())
            print(f"Points: {total_points} across {len(points_by_frame)} frame(s)")
            print(f"Frames: {sorted(points_by_frame.keys())}")
        else:
            print(f"Points: {len(points)}")
            print(f"First appears frame: {first_appears_frame}")

        # Setup output directory
        if output_dir:
            output_dir = Path(output_dir)
        else:
            # Create unique output directory per video
            video_name = video_path.stem  # Get video name without extension
            output_dir = video_path.parent / f"{video_name}_masks_output"

        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output: {output_dir}")

        try:
            # Segment video - use multi-frame or single-frame format
            black_mask_path = output_dir / "black_mask.mp4"
            if points_by_frame:
                metadata = segmenter.segment_video(
                    str(video_path),
                    output_mask_path=str(black_mask_path),
                    points_by_frame=points_by_frame
                )
                # Use first frame for VLM analysis
                first_frame_for_vlm = min(points_by_frame.keys())
            else:
                metadata = segmenter.segment_video(
                    str(video_path),
                    points=points,
                    output_mask_path=str(black_mask_path),
                    first_appears_frame=first_appears_frame
                )
                first_frame_for_vlm = first_appears_frame

            # Save frame where object appears for VLM analysis
            cap = cv2.VideoCapture(str(video_path))
            cap.set(cv2.CAP_PROP_POS_FRAMES, first_frame_for_vlm)
            ret, first_frame = cap.read()
            cap.release()

            if ret:
                first_frame_path = output_dir / "first_frame.jpg"
                cv2.imwrite(str(first_frame_path), first_frame)
                print(f"  ✓ Saved first frame (frame {first_frame_for_vlm}): {first_frame_path.name}")

            # Copy input video
            input_copy_path = output_dir / "input_video.mp4"
            if not input_copy_path.exists():
                shutil.copy2(video_path, input_copy_path)
                print(f"  ✓ Copied input video")

            # Save metadata
            metadata["video_path"] = str(video_path)
            metadata["instruction"] = instruction
            if points_by_frame:
                metadata["primary_points_by_frame"] = {str(k): v for k, v in points_by_frame.items()}
                metadata["primary_frames"] = sorted(points_by_frame.keys())
                metadata["first_appears_frame"] = min(points_by_frame.keys())
            else:
                metadata["primary_points"] = points
                metadata["first_appears_frame"] = first_appears_frame

            metadata_path = output_dir / "segmentation_info.json"
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)
            print(f"  ✓ Saved metadata: {metadata_path.name}")

            print(f"\n✅ Video {i+1} complete!")

        except Exception as e:
            print(f"\n❌ Error processing video {i+1}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\n{'='*70}")
    print(f"Stage 1 Complete!")
    print(f"{'='*70}\n")


def main():
    parser = argparse.ArgumentParser(description="Stage 1: SAM2 Point-Prompted Segmentation")
    parser.add_argument("--config", required=True, help="Config JSON with primary_points")
    parser.add_argument("--sam2-checkpoint", default="../sam2_hiera_large.pt",
                        help="Path to SAM2 checkpoint")
    parser.add_argument("--device", default="cuda", help="Device (cuda/cpu)")
    args = parser.parse_args()

    if not SAM2_AVAILABLE:
        print("❌ SAM2 not available")
        sys.exit(1)

    # Check checkpoint exists
    checkpoint_path = Path(args.sam2_checkpoint)
    if not checkpoint_path.exists():
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        print(f"   Download with:")
        print(f"   wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt")
        sys.exit(1)

    process_config(args.config, str(checkpoint_path), args.device)


if __name__ == "__main__":
    main()
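`process_config` accepts either a bare list of video entries or a dict with a `"videos"` key. A minimal sketch of a valid config (file paths and point coordinates are hypothetical), plus the same list/dict normalization and string-to-int key conversion the loader performs:

```python
import json

# Hypothetical example config; each entry needs "video_path" plus either
# "primary_points" or "primary_points_by_frame" to be processed.
config = {
    "videos": [
        {
            "video_path": "clips/surfer.mp4",
            "instruction": "remove the surfer",
            "primary_points": [[412, 230]],   # (x, y) clicks on the object
            "first_appears_frame": 0,
        },
        {
            "video_path": "clips/golfer.mp4",
            "instruction": "remove the golfer",
            # Multi-frame format: JSON keys are strings, converted to ints on load
            "primary_points_by_frame": {"0": [[100, 80]], "24": [[140, 90]]},
        },
    ]
}

config_data = json.loads(json.dumps(config))  # round-trip as the loader would

# Same normalization as process_config:
videos = config_data if isinstance(config_data, list) else config_data["videos"]
points_by_frame = {int(k): v
                   for k, v in videos[1]["primary_points_by_frame"].items()}

assert len(videos) == 2
assert sorted(points_by_frame.keys()) == [0, 24]
```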
VLM-MASK-REASONER/stage2_vlm_analysis.py
ADDED
@@ -0,0 +1,1022 @@
#!/usr/bin/env python3
"""
Stage 2: VLM Analysis - Identify Affected Objects & Physics

Analyzes videos with primary masks to identify:
- Integral belongings (to add to black mask)
- Affected objects (shadows, reflections, held items)
- Physics behavior (will_move, needs_trajectory)

Input: Config from Stage 1 (with output_dir containing black_mask.mp4, first_frame.jpg)
Output: For each video:
  - vlm_analysis.json: Identified objects and physics reasoning

Usage:
    python stage2_vlm_analysis.py --config more_dyn_2_config_points_absolute.json
"""

import os
import sys
import json
import argparse
import cv2
import numpy as np
import base64
from pathlib import Path
from typing import Dict, List
from PIL import Image, ImageDraw

import openai

DEFAULT_MODEL = "gemini-3-pro-preview"


def image_to_data_url(image_path: str) -> str:
    """Convert image file to base64 data URL"""
    with open(image_path, 'rb') as f:
        img_data = base64.b64encode(f.read()).decode('utf-8')

    # Detect format
    ext = Path(image_path).suffix.lower()
    if ext == '.png':
        mime = 'image/png'
    elif ext in ['.jpg', '.jpeg']:
        mime = 'image/jpeg'
    else:
        mime = 'image/jpeg'

    return f"data:{mime};base64,{img_data}"


def video_to_data_url(video_path: str) -> str:
    """Convert video file to base64 data URL"""
    with open(video_path, 'rb') as f:
        video_data = base64.b64encode(f.read()).decode('utf-8')
    return f"data:video/mp4;base64,{video_data}"


def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> tuple:
    """Calculate grid dimensions matching stage3a logic"""
    aspect_ratio = width / height
    if width >= height:
        grid_rows = min_grid
        grid_cols = max(min_grid, round(min_grid * aspect_ratio))
    else:
        grid_cols = min_grid
        grid_rows = max(min_grid, round(min_grid / aspect_ratio))
    return grid_rows, grid_cols
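The grid logic above keeps the short side at `min_grid` cells and scales the long side by the aspect ratio, so cells stay roughly square. A worked example with a standalone copy of the same arithmetic:

```python
def calculate_square_grid(width, height, min_grid=8):
    # Same logic as the stage 2 helper: the short side gets min_grid cells,
    # the long side is scaled up by the aspect ratio (rounded).
    aspect_ratio = width / height
    if width >= height:
        return min_grid, max(min_grid, round(min_grid * aspect_ratio))
    return max(min_grid, round(min_grid / aspect_ratio)), min_grid

# Returns (rows, cols):
assert calculate_square_grid(1920, 1080) == (8, 14)   # landscape 16:9
assert calculate_square_grid(1080, 1920) == (14, 8)   # portrait 9:16
assert calculate_square_grid(512, 512) == (8, 8)      # square
```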

def create_first_frame_with_mask_overlay(first_frame_path: str, black_mask_path: str,
                                         output_path: str, frame_idx: int = 0) -> str:
    """Create visualization of first frame with red overlay on primary object

    Args:
        first_frame_path: Path to first_frame.jpg
        black_mask_path: Path to black_mask.mp4
        output_path: Where to save overlay
        frame_idx: Which frame to extract from black_mask.mp4 (default: 0)
    """
    # Load first frame
    frame = cv2.imread(first_frame_path)
    if frame is None:
        raise ValueError(f"Failed to load first frame: {first_frame_path}")

    # Load black mask video and get the specified frame
    cap = cv2.VideoCapture(black_mask_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, mask_frame = cap.read()
    cap.release()

    if not ret:
        raise ValueError(f"Failed to load black mask frame {frame_idx}: {black_mask_path}")

    # Convert mask to binary (0 = object, 255 = background)
    if len(mask_frame.shape) == 3:
        mask_frame = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)

    object_mask = (mask_frame == 0)

    # Create red overlay on object
    overlay = frame.copy()
    overlay[object_mask] = [0, 0, 255]  # Red in BGR

    # Blend: 60% original + 40% red overlay
    result = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)

    # Save
    cv2.imwrite(output_path, result)
    return output_path


def create_gridded_frame_overlay(first_frame_path: str, black_mask_path: str,
                                 output_path: str, min_grid: int = 8) -> tuple:
    """Create first frame with BOTH red mask overlay AND grid lines

    Returns: (output_path, grid_rows, grid_cols)
    """
    # Load first frame
    frame = cv2.imread(first_frame_path)
    if frame is None:
        raise ValueError(f"Failed to load first frame: {first_frame_path}")

    h, w = frame.shape[:2]

    # Load black mask
    cap = cv2.VideoCapture(black_mask_path)
    ret, mask_frame = cap.read()
    cap.release()

    if not ret:
        raise ValueError(f"Failed to load black mask: {black_mask_path}")

    if len(mask_frame.shape) == 3:
        mask_frame = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)

    object_mask = (mask_frame == 0)

    # Create red overlay
    overlay = frame.copy()
    overlay[object_mask] = [0, 0, 255]
    result = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)

    # Calculate grid
    grid_rows, grid_cols = calculate_square_grid(w, h, min_grid)

    # Draw grid lines
    cell_width = w / grid_cols
    cell_height = h / grid_rows

    # Vertical lines
    for col in range(1, grid_cols):
        x = int(col * cell_width)
        cv2.line(result, (x, 0), (x, h), (255, 255, 0), 1)  # Yellow lines

    # Horizontal lines
    for row in range(1, grid_rows):
        y = int(row * cell_height)
        cv2.line(result, (0, y), (w, y), (255, 255, 0), 1)

    # Add grid labels
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.3
    thickness = 1

    # Label columns at top
    for col in range(grid_cols):
        x = int((col + 0.5) * cell_width)
        cv2.putText(result, str(col), (x-5, 15), font, font_scale, (255, 255, 0), thickness)

    # Label rows on left
    for row in range(grid_rows):
        y = int((row + 0.5) * cell_height)
        cv2.putText(result, str(row), (5, y+5), font, font_scale, (255, 255, 0), thickness)

    cv2.imwrite(output_path, result)
    return output_path, grid_rows, grid_cols
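The 60/40 blend means each masked pixel keeps 60% of its original colour plus 40% pure red. A NumPy-only sketch of the same arithmetic (equivalent to the `cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)` call, channels in BGR order):

```python
import numpy as np

frame = np.full((2, 2, 3), 100, dtype=np.uint8)    # flat grey frame (BGR)
object_mask = np.array([[True, False], [False, False]])

overlay = frame.copy()
overlay[object_mask] = [0, 0, 255]                 # paint object pixels red (BGR)

# addWeighted computes 0.6*frame + 0.4*overlay per channel (with rounding)
result = (0.6 * frame + 0.4 * overlay).round().astype(np.uint8)

assert tuple(result[0, 0]) == (60, 60, 162)        # masked pixel: reddened
assert tuple(result[0, 1]) == (100, 100, 100)      # unmasked pixel: unchanged
```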

def create_multi_frame_grid_samples(video_path: str, output_dir: Path,
                                    min_grid: int = 8,
                                    sample_points: list = [0.0, 0.11, 0.22, 0.33, 0.44, 0.56, 0.67, 0.78, 0.89, 1.0]) -> tuple:
    """
    Create gridded frame samples at multiple time points in video.
    Helps VLM see objects that appear mid-video with grid reference.

    Args:
        video_path: Path to video
        output_dir: Where to save samples
        min_grid: Minimum grid size
        sample_points: List of normalized positions [0.0-1.0] to sample

    Returns: (sample_paths, grid_rows, grid_cols)
    """
    cap = cv2.VideoCapture(str(video_path))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Calculate grid (same for all frames)
    grid_rows, grid_cols = calculate_square_grid(w, h, min_grid)
    cell_width = w / grid_cols
    cell_height = h / grid_rows

    sample_paths = []

    for i, t in enumerate(sample_points):
        frame_idx = int(t * (total_frames - 1))
        frame_idx = max(0, min(frame_idx, total_frames - 1))

        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if not ret:
            continue

        # Draw grid
        result = frame.copy()

        # Vertical lines
        for col in range(1, grid_cols):
            x = int(col * cell_width)
            cv2.line(result, (x, 0), (x, h), (255, 255, 0), 2)

        # Horizontal lines
        for row in range(1, grid_rows):
            y = int(row * cell_height)
            cv2.line(result, (0, y), (w, y), (255, 255, 0), 2)

        # Add grid labels
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.4
        thickness = 1

        # Label columns
        for col in range(grid_cols):
            x = int((col + 0.5) * cell_width)
            cv2.putText(result, str(col), (x-8, 20), font, font_scale, (255, 255, 0), thickness)

        # Label rows
        for row in range(grid_rows):
            y = int((row + 0.5) * cell_height)
            cv2.putText(result, str(row), (10, y+8), font, font_scale, (255, 255, 0), thickness)

        # Add frame number and percentage
        label = f"Frame {frame_idx} ({int(t*100)}%)"
        cv2.putText(result, label, (10, h-10), font, 0.5, (255, 255, 0), 2)

        # Save
        output_path = output_dir / f"grid_sample_frame_{frame_idx:04d}.jpg"
        cv2.imwrite(str(output_path), result)
        sample_paths.append(output_path)

    cap.release()
    return sample_paths, grid_rows, grid_cols
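Each normalized sample point maps to a concrete frame index via `int(t * (total_frames - 1))`, clamped to the valid range. For a hypothetical 100-frame clip the ten default sample points resolve like this:

```python
sample_points = [0.0, 0.11, 0.22, 0.33, 0.44, 0.56, 0.67, 0.78, 0.89, 1.0]
total_frames = 100

# Same index computation and clamping as create_multi_frame_grid_samples
indices = [max(0, min(int(t * (total_frames - 1)), total_frames - 1))
           for t in sample_points]

assert indices == [0, 10, 21, 32, 43, 55, 66, 77, 88, 99]
```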

def make_vlm_analysis_prompt(instruction: str, grid_rows: int, grid_cols: int,
                             has_multi_frame_grids: bool = False) -> str:
    """Create VLM prompt for analyzing video with primary mask"""

    grid_context = ""
    if has_multi_frame_grids:
        grid_context = f"""
1. **Multiple Grid Reference Frames**: Sampled frames at 0%, 11%, 22%, 33%, 44%, 56%, 67%, 78%, 89%, 100% of video
   - Each frame shows YELLOW GRID with {grid_rows} rows × {grid_cols} columns
   - Grid cells labeled (row, col) starting from (0, 0) at top-left
   - Frame number shown at bottom
   - Use these to locate objects that appear MID-VIDEO and track object positions across time
2. **First Frame with RED mask**: Shows what will be REMOVED (primary object)
3. **Full Video**: Complete action and interactions"""
    else:
        grid_context = f"""
1. **First Frame with Grid**: PRIMARY OBJECT highlighted in RED + GRID OVERLAY
   - The red overlay shows what will be REMOVED (already masked)
   - Yellow grid with {grid_rows} rows × {grid_cols} columns
   - Grid cells are labeled (row, col) starting from (0, 0) at top-left
2. **Full Video**: Complete scene and action"""

    return f"""
You are an expert video analyst specializing in physics and object interactions.

───────────────────────────────────────────────────────────────────
CONTEXT
───────────────────────────────────────────────────────────────────

You will see MULTIPLE inputs:
{grid_context}

Edit instruction: "{instruction}"

IMPORTANT: Some objects may NOT appear in first frame. They may enter later.
Watch the ENTIRE video and note when each object first appears.

───────────────────────────────────────────────────────────────────
YOUR TASK
───────────────────────────────────────────────────────────────────

Analyze what would happen if the PRIMARY OBJECT (shown in red) is removed.
Watch the ENTIRE video to see all interactions and movements.

STEP 1: IDENTIFY INTEGRAL BELONGINGS (0-3 items)
─────────────────────────────────────────────────
Items that should be ADDED to the primary removal mask (removed WITH primary object):

✓ INCLUDE:
  • Distinct wearable items: hat, backpack, jacket (if separate/visible)
  • Vehicles/equipment being ridden: bike, skateboard, surfboard, scooter
  • Large carried items that are part of the subject

✗ DO NOT INCLUDE:
  • Generic clothing (shirt, pants, shoes) - already captured with person
  • Held items that could be set down: guitar, cup, phone, tools
  • Objects they're interacting with but not wearing/riding

Examples:
  • Person on bike → integral: "bike"
  • Person with guitar → integral: none (guitar is affected, not integral)
  • Surfer → integral: "surfboard"
  • Boxer → integral: "boxing gloves" (wearable equipment)

STEP 2: IDENTIFY AFFECTED OBJECTS (0-5 objects)
────────────────────────────────────────────────
Objects/effects that are SEPARATE from primary but affected by its removal.

CRITICAL: Do NOT include integral belongings from Step 1.

Two categories:

A) VISUAL ARTIFACTS (disappear when primary removed):
   • shadow, reflection, wake, ripples, splash, footprints
   • These vanish completely - no physics needed

   **CRITICAL FOR VISUAL ARTIFACTS:**
   You MUST provide GRID LOCALIZATIONS across the reference frames.
   Keyword segmentation fails to isolate specific shadows/reflections.

   For each visual artifact:
   - Look at each grid reference frame you were shown
   - Identify which grid cells the artifact occupies in EACH frame
   - List all grid cells (row, col) that contain any part of it
   - Be thorough - include ALL touched cells (over-mask is better than under-mask)

   Format:
   {{
     "noun": "shadow",
     "category": "visual_artifact",
     "grid_localizations": [
       {{"frame": 0, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, ...]}},
       {{"frame": 5, "grid_regions": [{{"row": 6, "col": 4}}, ...]}},
       // ... for each reference frame shown
     ]
   }}

B) PHYSICAL OBJECTS (may move, fall, or stay):

   CRITICAL - Understand the difference:

   **SUPPORTING vs ACTING ON:**
   • SUPPORTING = holding UP against gravity → object WILL FALL when removed
     Examples: holding guitar, carrying cup, person sitting on chair
     → will_move: TRUE

   • ACTING ON = touching/manipulating but object rests on stable surface → object STAYS
     Examples: hand crushing can (can on table), hand opening can (can on counter),
               hand pushing object (object on floor)
     → will_move: FALSE

   **Key Questions:**
   1. Is the primary object HOLDING THIS UP against gravity?
      - YES → will_move: true, needs_trajectory: true
      - NO → Check next question

   2. Is this object RESTING ON a stable surface (table, floor, counter)?
      - YES → will_move: false (stays on surface when primary removed)
      - NO → will_move: true

   3. Is the primary object DOING an action TO this object?
      - Opening can, crushing can, pushing button, turning knob
      - When primary removed → action STOPS, object stays in current state
      - will_move: false

   **SPECIAL CASE - Object Currently Moving But Should Have Stayed:**
   If primary object CAUSES another object to move (hitting, kicking, throwing):
   - The object is currently moving in the video
   - But WITHOUT primary, it would have stayed at its original position
   - You MUST provide:
     • "currently_moving": true
     • "should_have_stayed": true
     • "original_position_grid": {{"row": R, "col": C}} - Where it started

   Examples:
   - Golf club hits ball → Ball at tee, then flies (mark original tee position)
   - Person kicks soccer ball → Ball on ground, then rolls (mark original ground position)
   - Hand throws object → Object held, then flies (mark original held position)

   Format:
   {{
     "noun": "golf ball",
     "category": "physical",
     "currently_moving": true,
     "should_have_stayed": true,
     "original_position_grid": {{"row": 6, "col": 7}},
     "why": "ball was stationary until club hit it"
   }}

For each physical object, determine:
- **will_move**: true ONLY if object will fall/move when support removed
- **first_appears_frame**: frame number object first appears (0 if from start)
- **why**: Brief explanation of relationship to primary object

IF will_move=TRUE, also provide GRID-BASED TRAJECTORY:
- **object_size_grids**: {{"rows": R, "cols": C}} - How many grid cells object occupies
  IMPORTANT: Add 1 extra cell padding for safety (better to over-mask than under-mask)
  Example: Object looks 2×1 → report as 3×2

- **trajectory_path**: List of keyframe positions as grid coordinates
  Format: [{{"frame": N, "grid_row": R, "grid_col": C}}, ...]
  - IMPORTANT: First keyframe should be at first_appears_frame (not frame 0 if object appears later!)
  - Provide 3-5 keyframes spanning from first appearance to end
  - (grid_row, grid_col) is the CENTER position of object at that frame
  - Use the yellow grid reference frames to determine positions
  - For objects appearing mid-video: use the grid samples to locate them
  - Example: Object appears at frame 15, falls to bottom
    [{{"frame": 15, "grid_row": 3, "grid_col": 5}},   ← First appearance
     {{"frame": 25, "grid_row": 6, "grid_col": 5}},   ← Mid-fall
     {{"frame": 35, "grid_row": 9, "grid_col": 5}}]   ← On ground

✓ Objects held/carried at ANY point in video
✓ Objects the primary supports or interacts with
✓ Visual effects visible at any time

✗ Background objects never touched
✗ Other people/animals with no contact
✗ Integral belongings (already in Step 1)

STEP 3: SCENE DESCRIPTION
──────────────────────────
Describe scene WITHOUT the primary object (1-2 sentences).
Focus on what remains and any dynamic changes (falling objects, etc).

───────────────────────────────────────────────────────────────────
OUTPUT FORMAT (STRICT JSON ONLY)
───────────────────────────────────────────────────────────────────

EXAMPLES TO LEARN FROM:

Example 1: Person holding guitar
{{
  "affected_objects": [
    {{
      "noun": "guitar",
      "will_move": true,
      "why": "person is SUPPORTING guitar against gravity by holding it",
      "object_size_grids": {{"rows": 3, "cols": 2}},
      "trajectory_path": [
        {{"frame": 0, "grid_row": 4, "grid_col": 5}},
        {{"frame": 15, "grid_row": 6, "grid_col": 5}},
        {{"frame": 30, "grid_row": 8, "grid_col": 6}}
      ]
    }}
  ]
}}

Example 2: Hand crushing can on table
{{
  "affected_objects": [
    {{
      "noun": "can",
      "will_move": false,
+
"why": "can RESTS ON TABLE - hand is just acting on it. When hand removed, can stays on table (uncrushed)"
|
| 470 |
+
}}
|
| 471 |
+
]
|
| 472 |
+
}}
|
| 473 |
+
|
| 474 |
+
Example 3: Hands opening can on counter
|
| 475 |
+
{{
|
| 476 |
+
"affected_objects": [
|
| 477 |
+
{{
|
| 478 |
+
"noun": "can",
|
| 479 |
+
"will_move": false,
|
| 480 |
+
"why": "can RESTS ON COUNTER - hands are doing opening action. When hands removed, can stays closed on counter"
|
| 481 |
+
}}
|
| 482 |
+
]
|
| 483 |
+
}}
|
| 484 |
+
|
| 485 |
+
Example 4: Person sitting on chair
|
| 486 |
+
{{
|
| 487 |
+
"affected_objects": [
|
| 488 |
+
{{
|
| 489 |
+
"noun": "chair",
|
| 490 |
+
"will_move": false,
|
| 491 |
+
"why": "chair RESTS ON FLOOR - person sitting on it doesn't make it fall. Chair stays on floor when person removed"
|
| 492 |
+
}}
|
| 493 |
+
]
|
| 494 |
+
}}
|
| 495 |
+
|
| 496 |
+
Example 5: Person throws ball (ball appears at frame 12)
|
| 497 |
+
{{
|
| 498 |
+
"affected_objects": [
|
| 499 |
+
{{
|
| 500 |
+
"noun": "ball",
|
| 501 |
+
"category": "physical",
|
| 502 |
+
"will_move": true,
|
| 503 |
+
"first_appears_frame": 12,
|
| 504 |
+
"why": "ball is SUPPORTED by person's hand, then thrown",
|
| 505 |
+
"object_size_grids": {{"rows": 2, "cols": 2}},
|
| 506 |
+
"trajectory_path": [
|
| 507 |
+
{{"frame": 12, "grid_row": 4, "grid_col": 3}},
|
| 508 |
+
{{"frame": 20, "grid_row": 2, "grid_col": 6}},
|
| 509 |
+
{{"frame": 28, "grid_row": 5, "grid_col": 8}}
|
| 510 |
+
]
|
| 511 |
+
}}
|
| 512 |
+
]
|
| 513 |
+
}}
|
| 514 |
+
|
| 515 |
+
Example 6: Person with shadow (shadow needs grid localization)
|
| 516 |
+
{{
|
| 517 |
+
"affected_objects": [
|
| 518 |
+
{{
|
| 519 |
+
"noun": "shadow",
|
| 520 |
+
"category": "visual_artifact",
|
| 521 |
+
"why": "cast by person on the floor",
|
| 522 |
+
"will_move": false,
|
| 523 |
+
"first_appears_frame": 0,
|
| 524 |
+
"movement_description": "Disappears entirely as visual artifact",
|
| 525 |
+
"grid_localizations": [
|
| 526 |
+
{{"frame": 0, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, {{"row": 7, "col": 3}}, {{"row": 7, "col": 4}}]}},
|
| 527 |
+
{{"frame": 12, "grid_regions": [{{"row": 6, "col": 4}}, {{"row": 6, "col": 5}}, {{"row": 7, "col": 4}}]}},
|
| 528 |
+
{{"frame": 23, "grid_regions": [{{"row": 5, "col": 4}}, {{"row": 6, "col": 4}}, {{"row": 6, "col": 5}}]}},
|
| 529 |
+
{{"frame": 35, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, {{"row": 7, "col": 3}}]}},
|
| 530 |
+
{{"frame": 47, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 7, "col": 3}}, {{"row": 7, "col": 4}}]}}
|
| 531 |
+
]
|
| 532 |
+
}}
|
| 533 |
+
]
|
| 534 |
+
}}
|
| 535 |
+
|
| 536 |
+
Example 7: Golf club hits ball (Case 4 - currently moving but should stay)
|
| 537 |
+
{{
|
| 538 |
+
"affected_objects": [
|
| 539 |
+
{{
|
| 540 |
+
"noun": "golf ball",
|
| 541 |
+
"category": "physical",
|
| 542 |
+
"currently_moving": true,
|
| 543 |
+
"should_have_stayed": true,
|
| 544 |
+
"original_position_grid": {{"row": 6, "col": 7}},
|
| 545 |
+
"first_appears_frame": 0,
|
| 546 |
+
"why": "ball was stationary on tee until club hit it. Without club, ball would remain at original position."
|
| 547 |
+
}}
|
| 548 |
+
]
|
| 549 |
+
}}
|
| 550 |
+
|
| 551 |
+
YOUR OUTPUT FORMAT:
|
| 552 |
+
{{
|
| 553 |
+
"edit_instruction": "{instruction}",
|
| 554 |
+
"integral_belongings": [
|
| 555 |
+
{{
|
| 556 |
+
"noun": "bike",
|
| 557 |
+
"why": "person is riding the bike throughout the video"
|
| 558 |
+
}}
|
| 559 |
+
],
|
| 560 |
+
"affected_objects": [
|
| 561 |
+
{{
|
| 562 |
+
"noun": "guitar",
|
| 563 |
+
"category": "physical",
|
| 564 |
+
"why": "person is SUPPORTING guitar against gravity by holding it",
|
| 565 |
+
"will_move": true,
|
| 566 |
+
"first_appears_frame": 0,
|
| 567 |
+
"movement_description": "Will fall from held position to the ground",
|
| 568 |
+
"object_size_grids": {{"rows": 3, "cols": 2}},
|
| 569 |
+
"trajectory_path": [
|
| 570 |
+
{{"frame": 0, "grid_row": 3, "grid_col": 6}},
|
| 571 |
+
{{"frame": 20, "grid_row": 6, "grid_col": 6}},
|
| 572 |
+
{{"frame": 40, "grid_row": 9, "grid_col": 7}}
|
| 573 |
+
]
|
| 574 |
+
}},
|
| 575 |
+
{{
|
| 576 |
+
"noun": "shadow",
|
| 577 |
+
"category": "visual_artifact",
|
| 578 |
+
"why": "cast by person on floor",
|
| 579 |
+
"will_move": false,
|
| 580 |
+
"first_appears_frame": 0,
|
| 581 |
+
"movement_description": "Disappears entirely as visual artifact"
|
| 582 |
+
}}
|
| 583 |
+
],
|
| 584 |
+
"scene_description": "An acoustic guitar falling to the ground in an empty room. Natural window lighting.",
|
| 585 |
+
"confidence": 0.85
|
| 586 |
+
}}
|
| 587 |
+
|
| 588 |
+
CRITICAL REMINDERS:
|
| 589 |
+
β’ Watch ENTIRE video before answering
|
| 590 |
+
β’ SUPPORTING vs ACTING ON:
|
| 591 |
+
- Primary HOLDS UP object against gravity β will_move=TRUE (provide grid trajectory)
|
| 592 |
+
- Primary ACTS ON object (crushing, opening) but object on stable surface β will_move=FALSE
|
| 593 |
+
- Object RESTS ON stable surface (table, floor) β will_move=FALSE
|
| 594 |
+
β’ For visual artifacts (shadow, reflection): will_move=false (no trajectory needed)
|
| 595 |
+
β’ For held objects (guitar, cup): will_move=true (MUST provide object_size_grids + trajectory_path)
|
| 596 |
+
β’ For objects on surfaces being acted on (can being crushed, can being opened): will_move=false
|
| 597 |
+
β’ Grid trajectory: Add +1 cell padding to object size (over-mask is better than under-mask)
|
| 598 |
+
β’ Grid trajectory: Use the yellow grid overlay to determine (row, col) positions
|
| 599 |
+
β’ Be conservative - when in doubt, DON'T include
|
| 600 |
+
β’ Output MUST be valid JSON only
|
| 601 |
+
|
| 602 |
+
GRID INFO: {grid_rows} rows Γ {grid_cols} columns
|
| 603 |
+
EDIT INSTRUCTION: {instruction}
|
| 604 |
+
""".strip()
|
| 605 |
+
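The trajectory format requested above is sparse (3-5 keyframes), so downstream mask generation has to expand it to one grid position per frame. A minimal sketch of linear keyframe interpolation; the helper name `interpolate_trajectory` is illustrative, not taken from the repo:

```python
def interpolate_trajectory(keyframes, total_frames):
    """Linearly interpolate sparse {frame, grid_row, grid_col} keyframes
    into one (row, col) center position per frame, holding the last
    keyframe position through the end of the video."""
    keyframes = sorted(keyframes, key=lambda k: k["frame"])
    positions = {}
    for a, b in zip(keyframes, keyframes[1:]):
        span = b["frame"] - a["frame"]
        for f in range(a["frame"], b["frame"] + 1):
            t = (f - a["frame"]) / span if span else 0.0
            positions[f] = (
                round(a["grid_row"] + t * (b["grid_row"] - a["grid_row"])),
                round(a["grid_col"] + t * (b["grid_col"] - a["grid_col"])),
            )
    last = keyframes[-1]
    for f in range(last["frame"], total_frames):
        positions[f] = (last["grid_row"], last["grid_col"])
    return positions
```

With the falling-object example from the prompt (frames 15/25/35), this yields the keyframe positions exactly at their frames and intermediate rows in between.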


def call_vlm_with_images_and_video(client, model: str, image_data_urls: list,
                                   video_data_url: str, prompt: str) -> str:
    """Call VLM with multiple images and video"""
    content = []

    # Add all images first
    for img_url in image_data_urls:
        content.append({"type": "image_url", "image_url": {"url": img_url}})

    # Add video
    content.append({"type": "image_url", "image_url": {"url": video_data_url}})

    # Add prompt
    content.append({"type": "text", "text": prompt})

    resp = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": "You are an expert video analyst with deep understanding of physics and object interactions. Always output valid JSON only."
            },
            {
                "role": "user",
                "content": content
            },
        ],
    )
    return resp.choices[0].message.content

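The user-message payload assembled above is plain data; a standalone sketch of the same ordering (grid images first, then the video, then the text prompt) that can be inspected without an API client:

```python
def build_user_content(image_data_urls, video_data_url, prompt):
    """Mirror the content ordering used above: image parts, one video
    part (sent via the image_url field), then the text prompt."""
    parts = [{"type": "image_url", "image_url": {"url": u}} for u in image_data_urls]
    parts.append({"type": "image_url", "image_url": {"url": video_data_url}})
    parts.append({"type": "text", "text": prompt})
    return parts
```

The helper name is hypothetical; it just isolates the list construction from the network call.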
def parse_vlm_response(raw: str) -> Dict:
    """Parse VLM JSON response"""
    # Strip markdown code blocks
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        lines = cleaned.split('\n')
        if lines[0].startswith("```"):
            lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        cleaned = '\n'.join(lines)

    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        # Try to find JSON in response
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start != -1 and end != -1 and end > start:
            parsed = json.loads(cleaned[start:end+1])
        else:
            raise ValueError("Failed to parse VLM response as JSON")

    # Validate structure
    result = {
        "edit_instruction": parsed.get("edit_instruction", ""),
        "integral_belongings": [],
        "affected_objects": [],
        "scene_description": parsed.get("scene_description", ""),
        "confidence": float(parsed.get("confidence", 0.0))
    }

    # Parse integral belongings
    for item in parsed.get("integral_belongings", [])[:3]:
        obj = {
            "noun": str(item.get("noun", "")).strip().lower(),
            "why": str(item.get("why", "")).strip()[:200]
        }
        if obj["noun"]:
            result["integral_belongings"].append(obj)

    # Parse affected objects
    for item in parsed.get("affected_objects", [])[:5]:
        obj = {
            "noun": str(item.get("noun", "")).strip().lower(),
            "category": str(item.get("category", "physical")).strip().lower(),
            "why": str(item.get("why", "")).strip()[:200],
            "will_move": bool(item.get("will_move", False)),
            "first_appears_frame": int(item.get("first_appears_frame", 0)),
            "movement_description": str(item.get("movement_description", "")).strip()[:300]
        }

        # Parse Case 4: currently moving but should have stayed
        if "currently_moving" in item:
            obj["currently_moving"] = bool(item.get("currently_moving", False))
        if "should_have_stayed" in item:
            obj["should_have_stayed"] = bool(item.get("should_have_stayed", False))
        if "original_position_grid" in item:
            orig_grid = item.get("original_position_grid", {})
            obj["original_position_grid"] = {
                "row": int(orig_grid.get("row", 0)),
                "col": int(orig_grid.get("col", 0))
            }

        # Parse grid localizations for visual artifacts
        if "grid_localizations" in item:
            grid_locs = []
            for loc in item.get("grid_localizations", []):
                frame_loc = {
                    "frame": int(loc.get("frame", 0)),
                    "grid_regions": []
                }
                for region in loc.get("grid_regions", []):
                    frame_loc["grid_regions"].append({
                        "row": int(region.get("row", 0)),
                        "col": int(region.get("col", 0))
                    })
                if frame_loc["grid_regions"]:  # Only add if has regions
                    grid_locs.append(frame_loc)
            if grid_locs:
                obj["grid_localizations"] = grid_locs

        # Parse grid trajectory if will_move=true
        if obj["will_move"] and "object_size_grids" in item and "trajectory_path" in item:
            size_grids = item.get("object_size_grids", {})
            obj["object_size_grids"] = {
                "rows": int(size_grids.get("rows", 2)),
                "cols": int(size_grids.get("cols", 2))
            }

            trajectory = []
            for point in item.get("trajectory_path", []):
                trajectory.append({
                    "frame": int(point.get("frame", 0)),
                    "grid_row": int(point.get("grid_row", 0)),
                    "grid_col": int(point.get("grid_col", 0))
                })

            if trajectory:  # Only add if we have valid trajectory points
                obj["trajectory_path"] = trajectory

        if obj["noun"]:
            result["affected_objects"].append(obj)

    return result

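The fence-stripping and brace-scan fallback used by the parser can be exercised in isolation. A standalone sketch of just that recovery step (same strategy, re-declared here so it is self-contained):

```python
import json

def recover_json(raw: str) -> dict:
    """Strip markdown code fences, then fall back to scanning for the
    outermost {...} span if a direct json.loads fails."""
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        lines = cleaned.split("\n")[1:]            # drop opening fence line
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]                     # drop closing fence line
        cleaned = "\n".join(lines)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start != -1 and end > start:
            return json.loads(cleaned[start:end + 1])
        raise ValueError("no JSON object found")
```

This handles the two common VLM failure modes: a ```json fenced reply, and a reply with conversational text around the JSON object.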
def process_video(video_info: Dict, client, model: str):
    """Process a single video with VLM analysis"""
    video_path = video_info.get("video_path", "")
    instruction = video_info.get("instruction", "")
    output_dir = video_info.get("output_dir", "")

    if not output_dir:
        print(f"  ⚠️ No output_dir specified, skipping")
        return None

    output_dir = Path(output_dir)
    if not output_dir.exists():
        print(f"  ⚠️ Output directory not found: {output_dir}")
        print(f"     Run Stage 1 first to create black masks")
        return None

    # Check required files from Stage 1
    black_mask_path = output_dir / "black_mask.mp4"
    first_frame_path = output_dir / "first_frame.jpg"
    input_video_path = output_dir / "input_video.mp4"
    segmentation_info_path = output_dir / "segmentation_info.json"

    if not black_mask_path.exists():
        print(f"  ⚠️ black_mask.mp4 not found in {output_dir}")
        print(f"     Run Stage 1 first")
        return None

    if not first_frame_path.exists():
        print(f"  ⚠️ first_frame.jpg not found in {output_dir}")
        return None

    if not input_video_path.exists():
        # Try original video path
        if Path(video_path).exists():
            input_video_path = Path(video_path)
        else:
            print(f"  ⚠️ Video not found: {video_path}")
            return None

    # Read segmentation metadata to get correct frame index
    frame_idx = 0  # Default
    if segmentation_info_path.exists():
        try:
            with open(segmentation_info_path, 'r') as f:
                seg_info = json.load(f)
            frame_idx = seg_info.get("first_appears_frame", 0)
            print(f"  Using frame {frame_idx} from segmentation metadata")
        except Exception as e:
            print(f"  Warning: Could not read segmentation_info.json: {e}")
            print(f"  Using frame 0 as fallback")

    # Get min_grid for grid calculation
    min_grid = video_info.get('min_grid', 8)
    use_multi_frame_grids = video_info.get('multi_frame_grids', True)  # Default: use multi-frame
    max_video_size_mb = video_info.get('max_video_size_for_multiframe', 25)  # Default: 25MB limit

    # Check video size and auto-disable multi-frame for large videos
    if use_multi_frame_grids:
        video_size_mb = input_video_path.stat().st_size / (1024 * 1024)
        if video_size_mb > max_video_size_mb:
            print(f"  ⚠️ Video size ({video_size_mb:.1f} MB) exceeds {max_video_size_mb} MB")
            print(f"     Auto-disabling multi-frame grids to avoid API errors")
            use_multi_frame_grids = False

    print(f"  Creating frame overlays and grids...")
    overlay_path = output_dir / "first_frame_with_mask.jpg"
    gridded_path = output_dir / "first_frame_with_grid.jpg"

    # Create regular overlay (for backwards compatibility)
    create_first_frame_with_mask_overlay(
        str(first_frame_path),
        str(black_mask_path),
        str(overlay_path),
        frame_idx=frame_idx
    )

    image_data_urls = []

    if use_multi_frame_grids:
        # Create multi-frame grid samples for objects appearing mid-video
        print(f"  Creating multi-frame grid samples (0%, 25%, 50%, 75%, 100%)...")
        sample_paths, grid_rows, grid_cols = create_multi_frame_grid_samples(
            str(input_video_path),
            output_dir,
            min_grid=min_grid
        )

        # Encode all grid samples
        for sample_path in sample_paths:
            image_data_urls.append(image_to_data_url(str(sample_path)))

        # Also add the first frame with mask overlay
        _, _, _ = create_gridded_frame_overlay(
            str(first_frame_path),
            str(black_mask_path),
            str(gridded_path),
            min_grid=min_grid
        )
        image_data_urls.append(image_to_data_url(str(gridded_path)))

        print(f"  Grid: {grid_rows}x{grid_cols}, {len(sample_paths)} sample frames + masked frame")

    else:
        # Single gridded first frame (old approach)
        _, grid_rows, grid_cols = create_gridded_frame_overlay(
            str(first_frame_path),
            str(black_mask_path),
            str(gridded_path),
            min_grid=min_grid
        )
        image_data_urls.append(image_to_data_url(str(gridded_path)))
        print(f"  Grid: {grid_rows}x{grid_cols} (single frame)")

    print(f"  Encoding video for VLM...")

    # Check video size
    video_size_mb = input_video_path.stat().st_size / (1024 * 1024)
    print(f"  Video size: {video_size_mb:.1f} MB")

    if video_size_mb > 20:
        print(f"  ⚠️ Warning: Large video may cause API errors")
        if use_multi_frame_grids:
            print(f"     Consider setting multi_frame_grids=false for large videos")

    video_data_url = video_to_data_url(str(input_video_path))

    print(f"  Calling {model}...")
    prompt = make_vlm_analysis_prompt(instruction, grid_rows, grid_cols,
                                      has_multi_frame_grids=use_multi_frame_grids)

    try:
        try:
            raw_response = call_vlm_with_images_and_video(
                client, model, image_data_urls, video_data_url, prompt
            )
        except Exception as e:
            # If multi-frame fails (likely payload size issue), fall back to single frame
            if use_multi_frame_grids and "400" in str(e):
                print(f"  ⚠️ Multi-frame request failed (payload too large?)")
                print(f"     Falling back to single-frame grid mode...")

                # Retry with just the gridded first frame
                image_data_urls = [image_to_data_url(str(gridded_path))]
                prompt = make_vlm_analysis_prompt(instruction, grid_rows, grid_cols,
                                                  has_multi_frame_grids=False)

                try:
                    raw_response = call_vlm_with_images_and_video(
                        client, model, image_data_urls, video_data_url, prompt
                    )
                    print(f"  ✓ Single-frame fallback succeeded")
                except Exception as e2:
                    raise e2  # Re-raise if fallback also fails
            else:
                raise  # Re-raise if not a 400 or not multi-frame mode

        # Parse and save results (runs whether first call succeeded or fallback succeeded)
        print(f"  Parsing VLM response...")
        analysis = parse_vlm_response(raw_response)

        # Save results
        output_path = output_dir / "vlm_analysis.json"
        with open(output_path, 'w') as f:
            json.dump(analysis, f, indent=2)

        print(f"  ✓ Saved VLM analysis: {output_path.name}")

        # Print summary
        print(f"\n  Summary:")
        print(f"  - Integral belongings: {len(analysis['integral_belongings'])}")
        for obj in analysis['integral_belongings']:
            print(f"    • {obj['noun']}: {obj['why']}")

        print(f"  - Affected objects: {len(analysis['affected_objects'])}")
        for obj in analysis['affected_objects']:
            move_str = "WILL MOVE" if obj['will_move'] else "STAYS/DISAPPEARS"
            traj_str = ""
            if obj.get('will_move') and 'trajectory_path' in obj:
                num_points = len(obj['trajectory_path'])
                size = obj.get('object_size_grids', {})
                traj_str = f" (trajectory: {num_points} keyframes, size: {size.get('rows')}×{size.get('cols')} grids)"
            print(f"    • {obj['noun']}: {move_str}{traj_str}")

        return analysis

    except Exception as e:
        print(f"  ✗ VLM analysis failed: {e}")
        import traceback
        traceback.print_exc()
        return None

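The nested try/except in process_video implements a "retry once with a reduced payload on a 400-style error" pattern. The control flow, stripped of the pipeline specifics (function and variable names here are illustrative):

```python
def call_with_fallback(call, full_payload, reduced_payload):
    """Try the full multi-image request; on an error whose message
    mentions 400, retry once with a reduced payload; otherwise re-raise."""
    try:
        return call(full_payload)
    except Exception as e:
        if "400" not in str(e):
            raise
        return call(reduced_payload)
```

Matching on "400" in the exception text is a coarse heuristic carried over from the original code; a stricter version would inspect a status-code attribute on the SDK's exception type.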
def process_config(config_path: str, model: str = DEFAULT_MODEL):
    """Process all videos in config"""
    config_path = Path(config_path)

    # Load config
    with open(config_path, 'r') as f:
        config_data = json.load(f)

    # Handle both formats
    if isinstance(config_data, list):
        videos = config_data
    elif isinstance(config_data, dict) and "videos" in config_data:
        videos = config_data["videos"]
    else:
        raise ValueError("Config must be a list or have 'videos' key")

    print(f"\n{'='*70}")
    print(f"Stage 2: VLM Analysis - Identify Affected Objects")
    print(f"{'='*70}")
    print(f"Config: {config_path.name}")
    print(f"Videos: {len(videos)}")
    print(f"Model: {model}")
    print(f"{'='*70}\n")

    # Initialize VLM client
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise RuntimeError("GEMINI_API_KEY environment variable not set")
    client = openai.OpenAI(
        api_key=api_key,
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
    )

    # Process each video
    results = []
    for i, video_info in enumerate(videos):
        video_path = video_info.get("video_path", "")
        instruction = video_info.get("instruction", "")

        print(f"\n{'─'*70}")
        print(f"Video {i+1}/{len(videos)}: {Path(video_path).name}")
        print(f"{'─'*70}")
        print(f"Instruction: {instruction}")

        try:
            analysis = process_video(video_info, client, model)
            results.append({
                "video": video_path,
                "success": analysis is not None,
                "analysis": analysis
            })

            if analysis:
                print(f"\n✅ Video {i+1} complete!")
            else:
                print(f"\n⚠️ Video {i+1} skipped")

        except Exception as e:
            print(f"\n✗ Error processing video {i+1}: {e}")
            results.append({
                "video": video_path,
                "success": False,
                "error": str(e)
            })
            continue

    # Summary
    print(f"\n{'='*70}")
    print(f"Stage 2 Complete!")
    print(f"{'='*70}")
    successful = sum(1 for r in results if r["success"])
    print(f"Successful: {successful}/{len(videos)}")
    print(f"{'='*70}\n")

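process_config accepts either a bare JSON list of video entries or an object with a "videos" key. That normalization step on its own, as a self-contained sketch:

```python
def extract_videos(config_data):
    """Accept either a bare list of video entries or {"videos": [...]},
    matching the two config formats process_config handles."""
    if isinstance(config_data, list):
        return config_data
    if isinstance(config_data, dict) and "videos" in config_data:
        return config_data["videos"]
    raise ValueError("Config must be a list or have 'videos' key")
```

Keeping the check permissive lets Stage 1 emit either shape without breaking Stage 2.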
def main():
    parser = argparse.ArgumentParser(description="Stage 2: VLM Analysis")
    parser.add_argument("--config", required=True, help="Config JSON from Stage 1")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="VLM model name")
    args = parser.parse_args()

    process_config(args.config, args.model)


if __name__ == "__main__":
    main()
VLM-MASK-REASONER/stage2_vlm_analysis_cf.py ADDED
@@ -0,0 +1,1024 @@
#!/usr/bin/env python3
"""
Stage 2: VLM Analysis - Identify Affected Objects & Physics
(Cloudflare AI Gateway variant)

Identical to stage2_vlm_analysis.py but routes through the internal CF AI Gateway
instead of calling the Gemini API directly. Video is sent as sampled frames rather
than a raw video data URL (not supported by the OpenAI-compat endpoint).

Required environment variables:
    CF_PROJECT_ID - Cloudflare AI Gateway project ID
    CF_USER_ID    - Cloudflare AI Gateway user ID
    MODEL_ID      - Model identifier to use (e.g. "gemini-3-pro-preview")

Usage:
    python stage2_vlm_analysis_cf.py --config my_config_points.json
"""

import os
import sys
import json
import argparse
import cv2
import numpy as np
import base64
from pathlib import Path
from typing import Dict, List
from PIL import Image, ImageDraw

import openai

DEFAULT_MODEL = "gemini-3-pro-preview"

def image_to_data_url(image_path: str) -> str:
    """Convert an image file to a base64 data URL."""
    with open(image_path, 'rb') as f:
        img_data = base64.b64encode(f.read()).decode('utf-8')

    # Detect format from the file extension
    ext = Path(image_path).suffix.lower()
    if ext == '.png':
        mime = 'image/png'
    elif ext in ['.jpg', '.jpeg']:
        mime = 'image/jpeg'
    else:
        mime = 'image/jpeg'

    return f"data:{mime};base64,{img_data}"

def video_to_data_url(video_path: str) -> str:
    """Convert a video file to a base64 data URL."""
    with open(video_path, 'rb') as f:
        video_data = base64.b64encode(f.read()).decode('utf-8')
    return f"data:video/mp4;base64,{video_data}"

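The two helpers above share the same shape: read bytes, base64-encode, and wrap in a `data:<mime>;base64,` URL. A minimal standalone sketch of that pattern (`to_data_url` is a hypothetical name, not part of the repo):

```python
import base64


def to_data_url(data: bytes, mime: str) -> str:
    """Wrap raw bytes in a base64 data URL, as the helpers above do per file type."""
    encoded = base64.b64encode(data).decode('utf-8')
    return f"data:{mime};base64,{encoded}"


print(to_data_url(b"hi", "image/png"))  # data:image/png;base64,aGk=
```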
def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> tuple:
    """Calculate grid dimensions matching the stage3a logic."""
    aspect_ratio = width / height
    if width >= height:
        grid_rows = min_grid
        grid_cols = max(min_grid, round(min_grid * aspect_ratio))
    else:
        grid_cols = min_grid
        grid_rows = max(min_grid, round(min_grid / aspect_ratio))
    return grid_rows, grid_cols

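The grid sizing gives the short side of the frame exactly `min_grid` cells and scales the long side by the aspect ratio, so cells stay roughly square in either orientation. A self-contained sketch of the same arithmetic (reimplemented here so it runs without OpenCV):

```python
def square_grid(width: int, height: int, min_grid: int = 8) -> tuple:
    """Short side gets min_grid cells; long side is scaled by aspect ratio."""
    aspect = width / height
    if width >= height:
        return min_grid, max(min_grid, round(min_grid * aspect))
    return max(min_grid, round(min_grid / aspect)), min_grid


print(square_grid(1920, 1080))  # landscape 16:9 -> (8, 14)
print(square_grid(1080, 1920))  # portrait 9:16 -> (14, 8)
print(square_grid(800, 800))    # square -> (8, 8)
```

Because both branches clamp with `max(min_grid, ...)`, neither dimension can ever drop below `min_grid`.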
def create_first_frame_with_mask_overlay(first_frame_path: str, black_mask_path: str,
                                         output_path: str, frame_idx: int = 0) -> str:
    """Create a visualization of the first frame with a red overlay on the primary object.

    Args:
        first_frame_path: Path to first_frame.jpg
        black_mask_path: Path to black_mask.mp4
        output_path: Where to save the overlay
        frame_idx: Which frame to extract from black_mask.mp4 (default: 0)
    """
    # Load first frame
    frame = cv2.imread(first_frame_path)
    if frame is None:
        raise ValueError(f"Failed to load first frame: {first_frame_path}")

    # Load the black mask video and get the specified frame
    cap = cv2.VideoCapture(black_mask_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, mask_frame = cap.read()
    cap.release()

    if not ret:
        raise ValueError(f"Failed to load black mask frame {frame_idx}: {black_mask_path}")

    # Convert mask to binary (0 = object, 255 = background)
    if len(mask_frame.shape) == 3:
        mask_frame = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)

    object_mask = (mask_frame == 0)

    # Create red overlay on the object
    overlay = frame.copy()
    overlay[object_mask] = [0, 0, 255]  # Red in BGR

    # Blend: 60% original + 40% red overlay
    result = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)

    # Save
    cv2.imwrite(output_path, result)
    return output_path

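The 60/40 blend used above can be sketched in plain NumPy (hypothetical helper name; assumes BGR uint8 frames, like OpenCV's `addWeighted` with `gamma=0`):

```python
import numpy as np


def blend_red_overlay(frame: np.ndarray, object_mask: np.ndarray) -> np.ndarray:
    """Paint masked pixels red, then blend 60% original with 40% overlay."""
    overlay = frame.copy()
    overlay[object_mask] = (0, 0, 255)  # red in BGR
    blended = 0.6 * frame.astype(np.float64) + 0.4 * overlay.astype(np.float64)
    return np.clip(np.round(blended), 0, 255).astype(np.uint8)


frame = np.full((2, 2, 3), 100, dtype=np.uint8)
mask = np.array([[True, False], [False, False]])
out = blend_red_overlay(frame, mask)
print(out[0, 0])  # masked pixel shifts toward red: [60 60 162]
print(out[0, 1])  # unmasked pixel is unchanged: [100 100 100]
```

Masked pixels keep 60% of their original value while the red channel gains `0.4 * 255`, which is why the overlay reads as a translucent tint rather than a solid fill.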
def create_gridded_frame_overlay(first_frame_path: str, black_mask_path: str,
                                 output_path: str, min_grid: int = 8) -> tuple:
    """Create the first frame with BOTH the red mask overlay AND grid lines.

    Returns: (output_path, grid_rows, grid_cols)
    """
    # Load first frame
    frame = cv2.imread(first_frame_path)
    if frame is None:
        raise ValueError(f"Failed to load first frame: {first_frame_path}")

    h, w = frame.shape[:2]

    # Load black mask
    cap = cv2.VideoCapture(black_mask_path)
    ret, mask_frame = cap.read()
    cap.release()

    if not ret:
        raise ValueError(f"Failed to load black mask: {black_mask_path}")

    if len(mask_frame.shape) == 3:
        mask_frame = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)

    object_mask = (mask_frame == 0)

    # Create red overlay
    overlay = frame.copy()
    overlay[object_mask] = [0, 0, 255]
    result = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)

    # Calculate grid
    grid_rows, grid_cols = calculate_square_grid(w, h, min_grid)

    # Draw grid lines
    cell_width = w / grid_cols
    cell_height = h / grid_rows

    # Vertical lines
    for col in range(1, grid_cols):
        x = int(col * cell_width)
        cv2.line(result, (x, 0), (x, h), (255, 255, 0), 1)  # Yellow lines

    # Horizontal lines
    for row in range(1, grid_rows):
        y = int(row * cell_height)
        cv2.line(result, (0, y), (w, y), (255, 255, 0), 1)

    # Add grid labels
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.3
    thickness = 1

    # Label columns at top
    for col in range(grid_cols):
        x = int((col + 0.5) * cell_width)
        cv2.putText(result, str(col), (x - 5, 15), font, font_scale, (255, 255, 0), thickness)

    # Label rows on left
    for row in range(grid_rows):
        y = int((row + 0.5) * cell_height)
        cv2.putText(result, str(row), (5, y + 5), font, font_scale, (255, 255, 0), thickness)

    cv2.imwrite(output_path, result)
    return output_path, grid_rows, grid_cols

def create_multi_frame_grid_samples(video_path: str, output_dir: Path,
                                    min_grid: int = 8,
                                    sample_points: list = [0.0, 0.11, 0.22, 0.33, 0.44, 0.56, 0.67, 0.78, 0.89, 1.0]) -> tuple:
    """
    Create gridded frame samples at multiple time points in the video.
    Helps the VLM see objects that appear mid-video with a grid reference.

    Args:
        video_path: Path to video
        output_dir: Where to save samples
        min_grid: Minimum grid size
        sample_points: List of normalized positions [0.0-1.0] to sample

    Returns: (sample_paths, grid_rows, grid_cols)
    """
    cap = cv2.VideoCapture(str(video_path))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Calculate grid (same for all frames)
    grid_rows, grid_cols = calculate_square_grid(w, h, min_grid)
    cell_width = w / grid_cols
    cell_height = h / grid_rows

    sample_paths = []

    for i, t in enumerate(sample_points):
        frame_idx = int(t * (total_frames - 1))
        frame_idx = max(0, min(frame_idx, total_frames - 1))

        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if not ret:
            continue

        # Draw grid
        result = frame.copy()

        # Vertical lines
        for col in range(1, grid_cols):
            x = int(col * cell_width)
            cv2.line(result, (x, 0), (x, h), (255, 255, 0), 2)

        # Horizontal lines
        for row in range(1, grid_rows):
            y = int(row * cell_height)
            cv2.line(result, (0, y), (w, y), (255, 255, 0), 2)

        # Add grid labels
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.4
        thickness = 1

        # Label columns
        for col in range(grid_cols):
            x = int((col + 0.5) * cell_width)
            cv2.putText(result, str(col), (x - 8, 20), font, font_scale, (255, 255, 0), thickness)

        # Label rows
        for row in range(grid_rows):
            y = int((row + 0.5) * cell_height)
            cv2.putText(result, str(row), (10, y + 8), font, font_scale, (255, 255, 0), thickness)

        # Add frame number and percentage
        label = f"Frame {frame_idx} ({int(t * 100)}%)"
        cv2.putText(result, label, (10, h - 10), font, 0.5, (255, 255, 0), 2)

        # Save
        output_path = output_dir / f"grid_sample_frame_{frame_idx:04d}.jpg"
        cv2.imwrite(str(output_path), result)
        sample_paths.append(output_path)

    cap.release()
    return sample_paths, grid_rows, grid_cols

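The frame selection above maps each normalized time point onto a concrete frame index with `int(t * (total_frames - 1))`, clamped to the valid range. Isolated as a sketch (hypothetical function name, no OpenCV needed):

```python
def sample_frame_indices(total_frames: int, sample_points=(0.0, 0.5, 1.0)) -> list:
    """Map normalized positions [0.0-1.0] onto clamped frame indices."""
    indices = []
    for t in sample_points:
        idx = int(t * (total_frames - 1))
        indices.append(max(0, min(idx, total_frames - 1)))
    return indices


print(sample_frame_indices(48))  # [0, 23, 47]
print(sample_frame_indices(1))   # degenerate 1-frame video -> [0, 0, 0]
```

Note that `t=1.0` maps to `total_frames - 1` exactly, so the last sample is always the final frame rather than one past the end.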
def make_vlm_analysis_prompt(instruction: str, grid_rows: int, grid_cols: int,
                             has_multi_frame_grids: bool = False) -> str:
    """Create the VLM prompt for analyzing the video with the primary mask."""

    grid_context = ""
    if has_multi_frame_grids:
        grid_context = f"""
1. **Multiple Grid Reference Frames**: Sampled frames at 0%, 11%, 22%, 33%, 44%, 56%, 67%, 78%, 89%, 100% of video
   - Each frame shows YELLOW GRID with {grid_rows} rows × {grid_cols} columns
   - Grid cells labeled (row, col) starting from (0, 0) at top-left
   - Frame number shown at bottom
   - Use these to locate objects that appear MID-VIDEO and track object positions across time
2. **First Frame with RED mask**: Shows what will be REMOVED (primary object)
3. **Full Video**: Complete action and interactions"""
    else:
        grid_context = f"""
1. **First Frame with Grid**: PRIMARY OBJECT highlighted in RED + GRID OVERLAY
   - The red overlay shows what will be REMOVED (already masked)
   - Yellow grid with {grid_rows} rows × {grid_cols} columns
   - Grid cells are labeled (row, col) starting from (0, 0) at top-left
2. **Full Video**: Complete scene and action"""

    return f"""
You are an expert video analyst specializing in physics and object interactions.

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
CONTEXT
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

You will see MULTIPLE inputs:
{grid_context}

Edit instruction: "{instruction}"

IMPORTANT: Some objects may NOT appear in first frame. They may enter later.
Watch the ENTIRE video and note when each object first appears.

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
YOUR TASK
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Analyze what would happen if the PRIMARY OBJECT (shown in red) is removed.
Watch the ENTIRE video to see all interactions and movements.

STEP 1: IDENTIFY INTEGRAL BELONGINGS (0-3 items)
─────────────────────────────────────────────────
Items that should be ADDED to the primary removal mask (removed WITH primary object):

✅ INCLUDE:
   • Distinct wearable items: hat, backpack, jacket (if separate/visible)
   • Vehicles/equipment being ridden: bike, skateboard, surfboard, scooter
   • Large carried items that are part of the subject

❌ DO NOT INCLUDE:
   • Generic clothing (shirt, pants, shoes) - already captured with person
   • Held items that could be set down: guitar, cup, phone, tools
   • Objects they're interacting with but not wearing/riding

Examples:
   • Person on bike → integral: "bike"
   • Person with guitar → integral: none (guitar is affected, not integral)
   • Surfer → integral: "surfboard"
   • Boxer → integral: "boxing gloves" (wearable equipment)

STEP 2: IDENTIFY AFFECTED OBJECTS (0-5 objects)
────────────────────────────────────────────────
Objects/effects that are SEPARATE from primary but affected by its removal.

CRITICAL: Do NOT include integral belongings from Step 1.

Two categories:

A) VISUAL ARTIFACTS (disappear when primary removed):
   • shadow, reflection, wake, ripples, splash, footprints
   • These vanish completely - no physics needed

   **CRITICAL FOR VISUAL ARTIFACTS:**
   You MUST provide GRID LOCALIZATIONS across the reference frames.
   Keyword segmentation fails to isolate specific shadows/reflections.

   For each visual artifact:
   - Look at each grid reference frame you were shown
   - Identify which grid cells the artifact occupies in EACH frame
   - List all grid cells (row, col) that contain any part of it
   - Be thorough - include ALL touched cells (over-mask is better than under-mask)

   Format:
   {{
     "noun": "shadow",
     "category": "visual_artifact",
     "grid_localizations": [
       {{"frame": 0, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, ...]}},
       {{"frame": 5, "grid_regions": [{{"row": 6, "col": 4}}, ...]}},
       // ... for each reference frame shown
     ]
   }}

B) PHYSICAL OBJECTS (may move, fall, or stay):

   CRITICAL - Understand the difference:

   **SUPPORTING vs ACTING ON:**
   • SUPPORTING = holding UP against gravity → object WILL FALL when removed
     Examples: holding guitar, carrying cup, person sitting on chair
     → will_move: TRUE

   • ACTING ON = touching/manipulating but object rests on stable surface → object STAYS
     Examples: hand crushing can (can on table), hand opening can (can on counter),
               hand pushing object (object on floor)
     → will_move: FALSE

   **Key Questions:**
   1. Is the primary object HOLDING THIS UP against gravity?
      - YES → will_move: true, needs_trajectory: true
      - NO → Check next question

   2. Is this object RESTING ON a stable surface (table, floor, counter)?
      - YES → will_move: false (stays on surface when primary removed)
      - NO → will_move: true

   3. Is the primary object DOING an action TO this object?
      - Opening can, crushing can, pushing button, turning knob
      - When primary removed → action STOPS, object stays in current state
      - will_move: false

   **SPECIAL CASE - Object Currently Moving But Should Have Stayed:**
   If primary object CAUSES another object to move (hitting, kicking, throwing):
   - The object is currently moving in the video
   - But WITHOUT primary, it would have stayed at its original position
   - You MUST provide:
     • "currently_moving": true
     • "should_have_stayed": true
     • "original_position_grid": {{"row": R, "col": C}} - Where it started

   Examples:
   - Golf club hits ball → Ball at tee, then flies (mark original tee position)
   - Person kicks soccer ball → Ball on ground, then rolls (mark original ground position)
   - Hand throws object → Object held, then flies (mark original held position)

   Format:
   {{
     "noun": "golf ball",
     "category": "physical",
     "currently_moving": true,
     "should_have_stayed": true,
     "original_position_grid": {{"row": 6, "col": 7}},
     "why": "ball was stationary until club hit it"
   }}

   For each physical object, determine:
   - **will_move**: true ONLY if object will fall/move when support removed
   - **first_appears_frame**: frame number object first appears (0 if from start)
   - **why**: Brief explanation of relationship to primary object

   IF will_move=TRUE, also provide GRID-BASED TRAJECTORY:
   - **object_size_grids**: {{"rows": R, "cols": C}} - How many grid cells object occupies
     IMPORTANT: Add 1 extra cell padding for safety (better to over-mask than under-mask)
     Example: Object looks 2×1 → report as 3×2

   - **trajectory_path**: List of keyframe positions as grid coordinates
     Format: [{{"frame": N, "grid_row": R, "grid_col": C}}, ...]
     - IMPORTANT: First keyframe should be at first_appears_frame (not frame 0 if object appears later!)
     - Provide 3-5 keyframes spanning from first appearance to end
     - (grid_row, grid_col) is the CENTER position of object at that frame
     - Use the yellow grid reference frames to determine positions
     - For objects appearing mid-video: use the grid samples to locate them
     - Example: Object appears at frame 15, falls to bottom
       [{{"frame": 15, "grid_row": 3, "grid_col": 5}},  ← First appearance
        {{"frame": 25, "grid_row": 6, "grid_col": 5}},  ← Mid-fall
        {{"frame": 35, "grid_row": 9, "grid_col": 5}}]  ← On ground

✅ Objects held/carried at ANY point in video
✅ Objects the primary supports or interacts with
✅ Visual effects visible at any time

❌ Background objects never touched
❌ Other people/animals with no contact
❌ Integral belongings (already in Step 1)

STEP 3: SCENE DESCRIPTION
──────────────────────────
Describe scene WITHOUT the primary object (1-2 sentences).
Focus on what remains and any dynamic changes (falling objects, etc).

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
OUTPUT FORMAT (STRICT JSON ONLY)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

EXAMPLES TO LEARN FROM:

Example 1: Person holding guitar
{{
  "affected_objects": [
    {{
      "noun": "guitar",
      "will_move": true,
      "why": "person is SUPPORTING guitar against gravity by holding it",
      "object_size_grids": {{"rows": 3, "cols": 2}},
      "trajectory_path": [
        {{"frame": 0, "grid_row": 4, "grid_col": 5}},
        {{"frame": 15, "grid_row": 6, "grid_col": 5}},
        {{"frame": 30, "grid_row": 8, "grid_col": 6}}
      ]
    }}
  ]
}}

Example 2: Hand crushing can on table
{{
  "affected_objects": [
    {{
      "noun": "can",
      "will_move": false,
      "why": "can RESTS ON TABLE - hand is just acting on it. When hand removed, can stays on table (uncrushed)"
    }}
  ]
}}

Example 3: Hands opening can on counter
{{
  "affected_objects": [
    {{
      "noun": "can",
      "will_move": false,
      "why": "can RESTS ON COUNTER - hands are doing opening action. When hands removed, can stays closed on counter"
    }}
  ]
}}

Example 4: Person sitting on chair
{{
  "affected_objects": [
    {{
      "noun": "chair",
      "will_move": false,
      "why": "chair RESTS ON FLOOR - person sitting on it doesn't make it fall. Chair stays on floor when person removed"
    }}
  ]
}}

Example 5: Person throws ball (ball appears at frame 12)
{{
  "affected_objects": [
    {{
      "noun": "ball",
      "category": "physical",
      "will_move": true,
      "first_appears_frame": 12,
      "why": "ball is SUPPORTED by person's hand, then thrown",
      "object_size_grids": {{"rows": 2, "cols": 2}},
      "trajectory_path": [
        {{"frame": 12, "grid_row": 4, "grid_col": 3}},
        {{"frame": 20, "grid_row": 2, "grid_col": 6}},
        {{"frame": 28, "grid_row": 5, "grid_col": 8}}
      ]
    }}
  ]
}}

Example 6: Person with shadow (shadow needs grid localization)
{{
  "affected_objects": [
    {{
      "noun": "shadow",
      "category": "visual_artifact",
      "why": "cast by person on the floor",
      "will_move": false,
      "first_appears_frame": 0,
      "movement_description": "Disappears entirely as visual artifact",
      "grid_localizations": [
        {{"frame": 0, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, {{"row": 7, "col": 3}}, {{"row": 7, "col": 4}}]}},
        {{"frame": 12, "grid_regions": [{{"row": 6, "col": 4}}, {{"row": 6, "col": 5}}, {{"row": 7, "col": 4}}]}},
        {{"frame": 23, "grid_regions": [{{"row": 5, "col": 4}}, {{"row": 6, "col": 4}}, {{"row": 6, "col": 5}}]}},
        {{"frame": 35, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 6, "col": 4}}, {{"row": 7, "col": 3}}]}},
        {{"frame": 47, "grid_regions": [{{"row": 6, "col": 3}}, {{"row": 7, "col": 3}}, {{"row": 7, "col": 4}}]}}
      ]
    }}
  ]
}}

Example 7: Golf club hits ball (Case 4 - currently moving but should stay)
{{
  "affected_objects": [
    {{
      "noun": "golf ball",
      "category": "physical",
      "currently_moving": true,
      "should_have_stayed": true,
      "original_position_grid": {{"row": 6, "col": 7}},
      "first_appears_frame": 0,
      "why": "ball was stationary on tee until club hit it. Without club, ball would remain at original position."
    }}
  ]
}}

YOUR OUTPUT FORMAT:
{{
  "edit_instruction": "{instruction}",
  "integral_belongings": [
    {{
      "noun": "bike",
      "why": "person is riding the bike throughout the video"
    }}
  ],
  "affected_objects": [
    {{
      "noun": "guitar",
      "category": "physical",
      "why": "person is SUPPORTING guitar against gravity by holding it",
      "will_move": true,
      "first_appears_frame": 0,
      "movement_description": "Will fall from held position to the ground",
      "object_size_grids": {{"rows": 3, "cols": 2}},
      "trajectory_path": [
        {{"frame": 0, "grid_row": 3, "grid_col": 6}},
        {{"frame": 20, "grid_row": 6, "grid_col": 6}},
        {{"frame": 40, "grid_row": 9, "grid_col": 7}}
      ]
    }},
    {{
      "noun": "shadow",
      "category": "visual_artifact",
      "why": "cast by person on floor",
      "will_move": false,
      "first_appears_frame": 0,
      "movement_description": "Disappears entirely as visual artifact"
    }}
  ],
  "scene_description": "An acoustic guitar falling to the ground in an empty room. Natural window lighting.",
  "confidence": 0.85
}}

CRITICAL REMINDERS:
• Watch ENTIRE video before answering
• SUPPORTING vs ACTING ON:
  - Primary HOLDS UP object against gravity → will_move=TRUE (provide grid trajectory)
  - Primary ACTS ON object (crushing, opening) but object on stable surface → will_move=FALSE
  - Object RESTS ON stable surface (table, floor) → will_move=FALSE
• For visual artifacts (shadow, reflection): will_move=false (no trajectory needed)
• For held objects (guitar, cup): will_move=true (MUST provide object_size_grids + trajectory_path)
• For objects on surfaces being acted on (can being crushed, can being opened): will_move=false
• Grid trajectory: Add +1 cell padding to object size (over-mask is better than under-mask)
• Grid trajectory: Use the yellow grid overlay to determine (row, col) positions
• Be conservative - when in doubt, DON'T include
• Output MUST be valid JSON only

GRID INFO: {grid_rows} rows × {grid_cols} columns
EDIT INSTRUCTION: {instruction}
""".strip()

def call_vlm_with_images_and_video(client, model: str, image_data_urls: list,
                                   video_data_url: str, prompt: str) -> str:
    """Call the VLM with sampled frame images.

    The CF AI Gateway OpenAI-compat endpoint does not support video/mp4 base64
    data URLs, so we rely solely on the sampled grid frames already in
    image_data_urls. video_data_url is accepted for signature compatibility but
    intentionally not sent.
    """
    content = []

    # Add all sampled frame images
    for img_url in image_data_urls:
        content.append({"type": "image_url", "image_url": {"url": img_url}})

    # Add prompt
    content.append({"type": "text", "text": prompt})

    resp = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": "You are an expert video analyst with deep understanding of physics and object interactions. Always output valid JSON only."
            },
            {
                "role": "user",
                "content": content
            },
        ],
    )
    return resp.choices[0].message.content

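The user message above follows the OpenAI multimodal content convention: a list of `image_url` parts followed by one `text` part. A sketch of just the list construction, testable without a client (hypothetical helper name):

```python
def build_user_content(image_data_urls: list, prompt: str) -> list:
    """Build an OpenAI-style multimodal content list: images first, prompt last."""
    content = [{"type": "image_url", "image_url": {"url": url}}
               for url in image_data_urls]
    content.append({"type": "text", "text": prompt})
    return content


content = build_user_content(
    ["data:image/jpeg;base64,AAA", "data:image/jpeg;base64,BBB"],
    "Analyze the masked object.",
)
print(len(content))      # 3 parts: two images plus the text prompt
print(content[-1]["type"])  # text
```

Placing the text part last keeps the prompt adjacent to the model's answer, which is a common ordering for image-heavy requests, though the endpoint accepts either order.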
def parse_vlm_response(raw: str) -> Dict:
    """Parse the VLM JSON response."""
    # Strip markdown code blocks
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        lines = cleaned.split('\n')
        if lines[0].startswith("```"):
            lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        cleaned = '\n'.join(lines)

    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        # Try to find JSON in the response
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start != -1 and end != -1 and end > start:
            parsed = json.loads(cleaned[start:end + 1])
        else:
            raise ValueError("Failed to parse VLM response as JSON")

    # Validate structure
    result = {
        "edit_instruction": parsed.get("edit_instruction", ""),
        "integral_belongings": [],
        "affected_objects": [],
        "scene_description": parsed.get("scene_description", ""),
        "confidence": float(parsed.get("confidence", 0.0))
    }

    # Parse integral belongings
    for item in parsed.get("integral_belongings", [])[:3]:
        obj = {
            "noun": str(item.get("noun", "")).strip().lower(),
            "why": str(item.get("why", "")).strip()[:200]
        }
        if obj["noun"]:
            result["integral_belongings"].append(obj)

    # Parse affected objects
    for item in parsed.get("affected_objects", [])[:5]:
        obj = {
            "noun": str(item.get("noun", "")).strip().lower(),
            "category": str(item.get("category", "physical")).strip().lower(),
            "why": str(item.get("why", "")).strip()[:200],
            "will_move": bool(item.get("will_move", False)),
            "first_appears_frame": int(item.get("first_appears_frame", 0)),
            "movement_description": str(item.get("movement_description", "")).strip()[:300]
        }

        # Parse Case 4: currently moving but should have stayed
        if "currently_moving" in item:
            obj["currently_moving"] = bool(item.get("currently_moving", False))
        if "should_have_stayed" in item:
            obj["should_have_stayed"] = bool(item.get("should_have_stayed", False))
        if "original_position_grid" in item:
            orig_grid = item.get("original_position_grid", {})
            obj["original_position_grid"] = {
                "row": int(orig_grid.get("row", 0)),
                "col": int(orig_grid.get("col", 0))
            }

        # Parse grid localizations for visual artifacts
        if "grid_localizations" in item:
            grid_locs = []
            for loc in item.get("grid_localizations", []):
                frame_loc = {
                    "frame": int(loc.get("frame", 0)),
                    "grid_regions": []
                }
                for region in loc.get("grid_regions", []):
                    frame_loc["grid_regions"].append({
                        "row": int(region.get("row", 0)),
                        "col": int(region.get("col", 0))
                    })
                if frame_loc["grid_regions"]:  # Only add if it has regions
                    grid_locs.append(frame_loc)
            if grid_locs:
                obj["grid_localizations"] = grid_locs

        # Parse grid trajectory if will_move=true
        if obj["will_move"] and "object_size_grids" in item and "trajectory_path" in item:
            size_grids = item.get("object_size_grids", {})
            obj["object_size_grids"] = {
                "rows": int(size_grids.get("rows", 2)),
                "cols": int(size_grids.get("cols", 2))
            }

            trajectory = []
            for point in item.get("trajectory_path", []):
                trajectory.append({
                    "frame": int(point.get("frame", 0)),
                    "grid_row": int(point.get("grid_row", 0)),
                    "grid_col": int(point.get("grid_col", 0))
                })

            if trajectory:  # Only add if we have valid trajectory points
                obj["trajectory_path"] = trajectory

        if obj["noun"]:
            result["affected_objects"].append(obj)

    return result

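The fence-stripping and brace-scanning fallback at the top of the parser can be exercised in isolation. A self-contained sketch of that recovery logic (hypothetical function name, same two-step strategy: strip markdown fences, then fall back to the outermost `{...}` span):

```python
import json


def extract_json(raw: str) -> dict:
    """Recover a JSON object from a model reply, tolerating markdown fences."""
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        lines = cleaned.split("\n")
        lines = lines[1:]  # drop opening fence (with optional language tag)
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]  # drop closing fence
        cleaned = "\n".join(lines)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Fall back to the outermost brace-delimited span
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start != -1 and end > start:
            return json.loads(cleaned[start:end + 1])
        raise


print(extract_json('```json\n{"confidence": 0.85}\n```'))
print(extract_json('Here is the result: {"a": 1} - done.'))
```

The fallback is deliberately crude: it assumes exactly one top-level object in the reply, which holds for this prompt since the model is instructed to emit JSON only.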
| 749 |
+
def process_video(video_info: Dict, client, model: str):
    """Process a single video with VLM analysis"""
    video_path = video_info.get("video_path", "")
    instruction = video_info.get("instruction", "")
    output_dir = video_info.get("output_dir", "")

    if not output_dir:
        print(f"  ⚠️ No output_dir specified, skipping")
        return None

    output_dir = Path(output_dir)
    if not output_dir.exists():
        print(f"  ⚠️ Output directory not found: {output_dir}")
        print(f"     Run Stage 1 first to create black masks")
        return None

    # Check required files from Stage 1
    black_mask_path = output_dir / "black_mask.mp4"
    first_frame_path = output_dir / "first_frame.jpg"
    input_video_path = output_dir / "input_video.mp4"
    segmentation_info_path = output_dir / "segmentation_info.json"

    if not black_mask_path.exists():
        print(f"  ⚠️ black_mask.mp4 not found in {output_dir}")
        print(f"     Run Stage 1 first")
        return None

    if not first_frame_path.exists():
        print(f"  ⚠️ first_frame.jpg not found in {output_dir}")
        return None

    if not input_video_path.exists():
        # Try original video path
        if Path(video_path).exists():
            input_video_path = Path(video_path)
        else:
            print(f"  ⚠️ Video not found: {video_path}")
            return None

    # Read segmentation metadata to get correct frame index
    frame_idx = 0  # Default
    if segmentation_info_path.exists():
        try:
            with open(segmentation_info_path, 'r') as f:
                seg_info = json.load(f)
            frame_idx = seg_info.get("first_appears_frame", 0)
            print(f"  Using frame {frame_idx} from segmentation metadata")
        except Exception as e:
            print(f"  Warning: Could not read segmentation_info.json: {e}")
            print(f"  Using frame 0 as fallback")

    # Get min_grid for grid calculation
    min_grid = video_info.get('min_grid', 8)
    use_multi_frame_grids = video_info.get('multi_frame_grids', True)  # Default: use multi-frame
    max_video_size_mb = video_info.get('max_video_size_for_multiframe', 25)  # Default: 25 MB limit

    # Check video size and auto-disable multi-frame for large videos
    if use_multi_frame_grids:
        video_size_mb = input_video_path.stat().st_size / (1024 * 1024)
        if video_size_mb > max_video_size_mb:
            print(f"  ⚠️ Video size ({video_size_mb:.1f} MB) exceeds {max_video_size_mb} MB")
            print(f"     Auto-disabling multi-frame grids to avoid API errors")
            use_multi_frame_grids = False

    print(f"  Creating frame overlays and grids...")
    overlay_path = output_dir / "first_frame_with_mask.jpg"
    gridded_path = output_dir / "first_frame_with_grid.jpg"

    # Create regular overlay (for backwards compatibility)
    create_first_frame_with_mask_overlay(
        str(first_frame_path),
        str(black_mask_path),
        str(overlay_path),
        frame_idx=frame_idx
    )

    image_data_urls = []

    if use_multi_frame_grids:
        # Create multi-frame grid samples for objects appearing mid-video
        print(f"  Creating multi-frame grid samples (0%, 25%, 50%, 75%, 100%)...")
        sample_paths, grid_rows, grid_cols = create_multi_frame_grid_samples(
            str(input_video_path),
            output_dir,
            min_grid=min_grid
        )

        # Encode all grid samples
        for sample_path in sample_paths:
            image_data_urls.append(image_to_data_url(str(sample_path)))

        # Also add the first frame with mask overlay
        _, _, _ = create_gridded_frame_overlay(
            str(first_frame_path),
            str(black_mask_path),
            str(gridded_path),
            min_grid=min_grid
        )
        image_data_urls.append(image_to_data_url(str(gridded_path)))

        print(f"  Grid: {grid_rows}x{grid_cols}, {len(sample_paths)} sample frames + masked frame")

    else:
        # Single gridded first frame (old approach)
        _, grid_rows, grid_cols = create_gridded_frame_overlay(
            str(first_frame_path),
            str(black_mask_path),
            str(gridded_path),
            min_grid=min_grid
        )
        image_data_urls.append(image_to_data_url(str(gridded_path)))
        print(f"  Grid: {grid_rows}x{grid_cols} (single frame)")

    # CF gateway does not support video/mp4 base64; pass None. The frames are
    # already captured in image_data_urls above.
    video_data_url = None

    print(f"  Calling {model}...")
    prompt = make_vlm_analysis_prompt(instruction, grid_rows, grid_cols,
                                      has_multi_frame_grids=use_multi_frame_grids)

    try:
        try:
            raw_response = call_vlm_with_images_and_video(
                client, model, image_data_urls, video_data_url, prompt
            )
        except Exception as e:
            # If multi-frame fails (likely payload size issue), fall back to single frame
            if use_multi_frame_grids and "400" in str(e):
                print(f"  ⚠️ Multi-frame request failed (payload too large?)")
                print(f"     Falling back to single-frame grid mode...")

                # Retry with just the gridded first frame
                image_data_urls = [image_to_data_url(str(gridded_path))]
                prompt = make_vlm_analysis_prompt(instruction, grid_rows, grid_cols,
                                                  has_multi_frame_grids=False)

                try:
                    raw_response = call_vlm_with_images_and_video(
                        client, model, image_data_urls, video_data_url, prompt
                    )
                    print(f"  ✓ Single-frame fallback succeeded")
                except Exception as e2:
                    raise e2  # Re-raise if fallback also fails
            else:
                raise  # Re-raise if not a 400 or not multi-frame mode

        # Parse and save results (runs whether first call succeeded or fallback succeeded)
        print(f"  Parsing VLM response...")
        analysis = parse_vlm_response(raw_response)

        # Save results
        output_path = output_dir / "vlm_analysis.json"
        with open(output_path, 'w') as f:
            json.dump(analysis, f, indent=2)

        print(f"  ✓ Saved VLM analysis: {output_path.name}")

        # Print summary
        print(f"\n  Summary:")
        print(f"  - Integral belongings: {len(analysis['integral_belongings'])}")
        for obj in analysis['integral_belongings']:
            print(f"    • {obj['noun']}: {obj['why']}")

        print(f"  - Affected objects: {len(analysis['affected_objects'])}")
        for obj in analysis['affected_objects']:
            move_str = "WILL MOVE" if obj['will_move'] else "STAYS/DISAPPEARS"
            traj_str = ""
            if obj.get('will_move') and 'trajectory_path' in obj:
                num_points = len(obj['trajectory_path'])
                size = obj.get('object_size_grids', {})
                traj_str = f" (trajectory: {num_points} keyframes, size: {size.get('rows')}×{size.get('cols')} grids)"
            print(f"    • {obj['noun']}: {move_str}{traj_str}")

        return analysis

    except Exception as e:
        print(f"  ❌ VLM analysis failed: {e}")
        import traceback
        traceback.print_exc()
        return None

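The nested try/except in `process_video` implements a retry-with-fallback pattern: attempt the rich multi-frame request first, and on a 400-style failure retry once with a smaller single-frame payload. A generic sketch of just that pattern (`fake_call` is a stand-in for the real VLM call):

```python
# Generic retry-with-fallback: try the rich payload, retry once with the
# small payload only when the error looks like an HTTP 400 (payload too
# large); any other error propagates unchanged.
def call_with_fallback(call, rich_payload, small_payload):
    try:
        return call(rich_payload)
    except Exception as e:
        if "400" in str(e):      # payload likely too large
            return call(small_payload)
        raise                    # unrelated error: propagate

def fake_call(payload):
    # Stand-in for the VLM call: rejects requests with too many images.
    if len(payload) > 2:
        raise RuntimeError("400 Bad Request: payload too large")
    return f"ok:{len(payload)} images"

print(call_with_fallback(fake_call, ["img"] * 6, ["img"]))  # ok:1 images
```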
def process_config(config_path: str, model: str = DEFAULT_MODEL):
    """Process all videos in config"""
    config_path = Path(config_path)

    # Load config
    with open(config_path, 'r') as f:
        config_data = json.load(f)

    # Handle both formats
    if isinstance(config_data, list):
        videos = config_data
    elif isinstance(config_data, dict) and "videos" in config_data:
        videos = config_data["videos"]
    else:
        raise ValueError("Config must be a list or have 'videos' key")

    print(f"\n{'='*70}")
    print(f"Stage 2: VLM Analysis - Identify Affected Objects")
    print(f"{'='*70}")
    print(f"Config: {config_path.name}")
    print(f"Videos: {len(videos)}")
    print(f"Model: {model}")
    print(f"{'='*70}\n")

    # Initialize VLM client (CF AI Gateway)
    cf_project_id = os.environ.get("CF_PROJECT_ID")
    cf_user_id = os.environ.get("CF_USER_ID")
    if not cf_project_id or not cf_user_id:
        raise RuntimeError("CF_PROJECT_ID and CF_USER_ID environment variables must be set")

    metadata = json.dumps({"project_id": cf_project_id, "user_id": cf_user_id})
    client = openai.OpenAI(
        api_key=os.environ.get("GEMINI_API_KEY", "placeholder"),
        base_url="https://ai-gateway.plain-flower-4887.workers.dev/compat",
        default_headers={"cf-aig-metadata": metadata},
    )

    # Model comes from MODEL_ID env var; fall back to --model arg
    model = os.environ.get("MODEL_ID", model)

    # Process each video
    results = []
    for i, video_info in enumerate(videos):
        video_path = video_info.get("video_path", "")
        instruction = video_info.get("instruction", "")

        print(f"\n{'─'*70}")
        print(f"Video {i+1}/{len(videos)}: {Path(video_path).name}")
        print(f"{'─'*70}")
        print(f"Instruction: {instruction}")

        try:
            analysis = process_video(video_info, client, model)
            results.append({
                "video": video_path,
                "success": analysis is not None,
                "analysis": analysis
            })

            if analysis:
                print(f"\n✅ Video {i+1} complete!")
            else:
                print(f"\n⚠️ Video {i+1} skipped")

        except Exception as e:
            print(f"\n❌ Error processing video {i+1}: {e}")
            results.append({
                "video": video_path,
                "success": False,
                "error": str(e)
            })
            continue

    # Summary
    print(f"\n{'='*70}")
    print(f"Stage 2 Complete!")
    print(f"{'='*70}")
    successful = sum(1 for r in results if r["success"])
    print(f"Successful: {successful}/{len(videos)}")
    print(f"{'='*70}\n")

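`process_config` accepts two config shapes: a bare list of video entries, or an object with a `"videos"` key. The branch can be isolated as a small helper (hypothetical name, mirroring the logic above):

```python
# Normalize the two accepted config shapes into a plain list of entries.
def normalize_config(config_data):
    if isinstance(config_data, list):
        return config_data
    if isinstance(config_data, dict) and "videos" in config_data:
        return config_data["videos"]
    raise ValueError("Config must be a list or have 'videos' key")

print(len(normalize_config([{"video_path": "a.mp4"}])))              # 1
print(len(normalize_config({"videos": [{"video_path": "a.mp4"}]})))  # 1
```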
def main():
    parser = argparse.ArgumentParser(description="Stage 2: VLM Analysis")
    parser.add_argument("--config", required=True, help="Config JSON from Stage 1")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="VLM model name")
    args = parser.parse_args()

    process_config(args.config, args.model)


if __name__ == "__main__":
    main()
VLM-MASK-REASONER/stage3a_generate_grey_masks.py
ADDED
|
@@ -0,0 +1,436 @@
#!/usr/bin/env python3
"""
Stage 3a: Generate Grey Masks - Combine VLM Logic + User Trajectories

Generates grey masks (127=affected regions) by combining:
1. VLM-identified affected objects (segmented + gridified)
2. User-drawn trajectories (from Stage 3b)
3. Proximity filtering (only mask near primary object)

Input:  - vlm_analysis.json (Stage 2)
        - black_mask.mp4 (Stage 1)
        - trajectories.json (Stage 3b, optional)
Output: - grey_mask.mp4 (127=affected, 255=background)

Usage:
    python stage3a_generate_grey_masks.py --config more_dyn_2_config_points_absolute.json
"""

import os
import sys
import json
import argparse
import cv2
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple
from PIL import Image
import subprocess

# Segmentation model
try:
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor
    SAM3_AVAILABLE = True
except ImportError:
    SAM3_AVAILABLE = False

try:
    from lang_sam import LangSAM
    LANGSAM_AVAILABLE = True
except ImportError:
    LANGSAM_AVAILABLE = False

class SegmentationModel:
    """Wrapper for segmentation"""

    def __init__(self, model_type: str = "sam3"):
        self.model_type = model_type.lower()

        if self.model_type == "sam3":
            if not SAM3_AVAILABLE:
                raise ImportError("SAM3 not available")
            print(f"  Loading SAM3...")
            model = build_sam3_image_model()
            self.processor = Sam3Processor(model)
            self.model = model
        elif self.model_type == "langsam":
            if not LANGSAM_AVAILABLE:
                raise ImportError("LangSAM not available")
            print(f"  Loading LangSAM...")
            self.model = LangSAM()
        else:
            raise ValueError(f"Unknown model: {model_type}")

    def segment(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
        """Segment object using text prompt"""
        if self.model_type == "sam3":
            return self._segment_sam3(image_pil, prompt)
        else:
            return self._segment_langsam(image_pil, prompt)

    def _segment_sam3(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
        import torch
        h, w = image_pil.height, image_pil.width
        union = np.zeros((h, w), dtype=bool)

        try:
            inference_state = self.processor.set_image(image_pil)
            output = self.processor.set_text_prompt(state=inference_state, prompt=prompt)
            masks = output.get("masks")

            if masks is None or len(masks) == 0:
                return union

            if torch.is_tensor(masks):
                masks = masks.cpu().numpy()

            if masks.ndim == 2:
                union = masks.astype(bool)
            elif masks.ndim == 3:
                union = masks.any(axis=0).astype(bool)
            elif masks.ndim == 4:
                union = masks.any(axis=(0, 1)).astype(bool)

        except Exception as e:
            print(f"  Warning: SAM3 segmentation failed for '{prompt}': {e}")

        return union

    def _segment_langsam(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
        h, w = image_pil.height, image_pil.width
        union = np.zeros((h, w), dtype=bool)

        try:
            results = self.model.predict([image_pil], [prompt])
            if not results:
                return union

            r0 = results[0]
            if isinstance(r0, dict) and "masks" in r0:
                masks = r0["masks"]
                if masks.ndim == 4 and masks.shape[0] == 1:
                    masks = masks[0]
                if masks.ndim == 3:
                    union = masks.any(axis=0).astype(bool)
                elif masks.ndim == 2:
                    union = masks.astype(bool)

        except Exception as e:
            print(f"  Warning: LangSAM segmentation failed for '{prompt}': {e}")

        return union

def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> Tuple[int, int]:
    """Calculate grid dimensions for square cells"""
    aspect_ratio = width / height
    if width >= height:
        grid_rows = min_grid
        grid_cols = max(min_grid, round(min_grid * aspect_ratio))
    else:
        grid_cols = min_grid
        grid_rows = max(min_grid, round(min_grid / aspect_ratio))
    return grid_rows, grid_cols

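A quick worked example of the square-cell grid math: the short side gets `min_grid` cells and the long side scales with the aspect ratio, so a 16:9 landscape frame with `min_grid=8` gets 8 rows and round(8 × 16/9) = 14 columns. The same arithmetic as a standalone sketch:

```python
# Standalone copy of the grid math: square cells, short side = min_grid.
def calculate_square_grid(width, height, min_grid=8):
    aspect_ratio = width / height
    if width >= height:
        return min_grid, max(min_grid, round(min_grid * aspect_ratio))
    return max(min_grid, round(min_grid / aspect_ratio)), min_grid

print(calculate_square_grid(1920, 1080))  # (8, 14)
print(calculate_square_grid(1080, 1920))  # (14, 8)
```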
def gridify_mask(mask: np.ndarray, grid_rows: int, grid_cols: int) -> np.ndarray:
    """Convert pixel mask to gridified mask"""
    h, w = mask.shape
    gridified = np.zeros((h, w), dtype=bool)

    cell_width = w / grid_cols
    cell_height = h / grid_rows

    for row in range(grid_rows):
        for col in range(grid_cols):
            y1 = int(row * cell_height)
            y2 = int((row + 1) * cell_height)
            x1 = int(col * cell_width)
            x2 = int((col + 1) * cell_width)

            cell_region = mask[y1:y2, x1:x2]
            if cell_region.any():
                gridified[y1:y2, x1:x2] = True

    return gridified

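Gridification snaps a pixel mask to the grid: any cell touched by at least one "on" pixel becomes fully on. A tiny demonstration on a 4×4 image with a 2×2 grid:

```python
import numpy as np

# Same cell-snapping logic as gridify_mask above, on a toy array.
def gridify(mask, grid_rows, grid_cols):
    h, w = mask.shape
    out = np.zeros((h, w), dtype=bool)
    for r in range(grid_rows):
        for c in range(grid_cols):
            y1, y2 = int(r * h / grid_rows), int((r + 1) * h / grid_rows)
            x1, x2 = int(c * w / grid_cols), int((c + 1) * w / grid_cols)
            if mask[y1:y2, x1:x2].any():
                out[y1:y2, x1:x2] = True
    return out

mask = np.zeros((4, 4), dtype=bool)
mask[0, 0] = True            # a single pixel in the top-left cell
g = gridify(mask, 2, 2)
print(int(g.sum()))          # 4 -> the whole 2x2 top-left cell is on
```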
def grid_cells_to_mask(grid_cells: List[List[int]], grid_rows: int, grid_cols: int,
                       frame_width: int, frame_height: int) -> np.ndarray:
    """Convert grid cells to mask"""
    mask = np.zeros((frame_height, frame_width), dtype=bool)

    cell_width = frame_width / grid_cols
    cell_height = frame_height / grid_rows

    for row, col in grid_cells:
        y1 = int(row * cell_height)
        y2 = int((row + 1) * cell_height)
        x1 = int(col * cell_width)
        x2 = int((col + 1) * cell_width)
        mask[y1:y2, x1:x2] = True

    return mask

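`grid_cells_to_mask` is the inverse direction: user-selected `(row, col)` cells are painted into a full-resolution boolean mask. A small demonstration:

```python
import numpy as np

# Paint selected grid cells into a pixel mask (same math as above).
def cells_to_mask(cells, grid_rows, grid_cols, width, height):
    mask = np.zeros((height, width), dtype=bool)
    for row, col in cells:
        y1, y2 = int(row * height / grid_rows), int((row + 1) * height / grid_rows)
        x1, x2 = int(col * width / grid_cols), int((col + 1) * width / grid_cols)
        mask[y1:y2, x1:x2] = True
    return mask

# Two diagonal cells of a 2x2 grid over an 8x8 frame: each cell is 4x4 px.
m = cells_to_mask([[0, 0], [1, 1]], grid_rows=2, grid_cols=2, width=8, height=8)
print(int(m.sum()))  # 32
```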
def dilate_mask(mask: np.ndarray, kernel_size: int = 15) -> np.ndarray:
    """Dilate mask to create proximity region"""
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.dilate(mask.astype(np.uint8), kernel, iterations=1).astype(bool)

def filter_by_proximity(mask: np.ndarray, primary_mask: np.ndarray, dilation: int = 15) -> np.ndarray:
    """Filter mask to only include regions near primary mask"""
    # Dilate primary mask to create proximity region
    proximity_region = dilate_mask(primary_mask, dilation)

    # Only keep mask where it overlaps with proximity region
    filtered = mask & proximity_region

    return filtered

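The proximity filter grows the primary mask and intersects it with the candidate mask, so only candidate pixels within `dilation` pixels of the primary object survive. A pure-numpy sketch of the same idea (the real code uses `cv2.dilate`; here dilation is approximated with shifted ORs, and `np.roll` wraps at the borders, which is fine for this toy example):

```python
import numpy as np

# Approximate binary dilation by OR-ing shifted copies of the mask.
def dilate(mask, radius):
    out = mask.copy()
    for dy in range(-radius, radius + 1):
        for dx in range(-radius, radius + 1):
            out |= np.roll(np.roll(mask, dy, axis=0), dx, axis=1)
    return out

primary = np.zeros((10, 10), dtype=bool)
primary[5, 5] = True
candidate = np.zeros((10, 10), dtype=bool)
candidate[5, 7] = True   # 2 px from the primary object: kept
candidate[0, 0] = True   # far away: dropped

near = candidate & dilate(primary, radius=2)
print(int(near.sum()))  # 1
```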
def process_video_grey_masks(video_info: Dict, segmenter: SegmentationModel,
                             trajectory_data: Dict = None):
    """Generate grey masks for a single video"""
    video_path = video_info.get("video_path", "")
    output_dir = Path(video_info.get("output_dir", ""))

    if not output_dir.exists():
        print(f"  ⚠️ Output directory not found: {output_dir}")
        return

    # Load required files
    vlm_analysis_path = output_dir / "vlm_analysis.json"
    black_mask_path = output_dir / "black_mask.mp4"
    input_video_path = output_dir / "input_video.mp4"

    if not vlm_analysis_path.exists():
        print(f"  ⚠️ vlm_analysis.json not found")
        return

    if not black_mask_path.exists():
        print(f"  ⚠️ black_mask.mp4 not found")
        return

    if not input_video_path.exists():
        input_video_path = Path(video_path)
        if not input_video_path.exists():
            print(f"  ⚠️ Video not found")
            return

    # Load VLM analysis
    with open(vlm_analysis_path, 'r') as f:
        analysis = json.load(f)

    # Get video properties
    cap = cv2.VideoCapture(str(input_video_path))
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    # Calculate grid
    min_grid = video_info.get('min_grid', 8)
    grid_rows, grid_cols = calculate_square_grid(frame_width, frame_height, min_grid)

    print(f"  Video: {frame_width}x{frame_height}, {total_frames} frames, grid: {grid_rows}x{grid_cols}")

    # Load first frame
    cap = cv2.VideoCapture(str(input_video_path))
    ret, first_frame = cap.read()
    cap.release()

    if not ret:
        print(f"  ⚠️ Failed to read first frame")
        return

    first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
    first_frame_pil = Image.fromarray(first_frame_rgb)

    # Load black mask (first frame for proximity filtering)
    black_cap = cv2.VideoCapture(str(black_mask_path))
    ret, black_mask_frame = black_cap.read()
    black_cap.release()

    if not ret:
        print(f"  ⚠️ Failed to read black mask")
        return

    if len(black_mask_frame.shape) == 3:
        black_mask_frame = cv2.cvtColor(black_mask_frame, cv2.COLOR_BGR2GRAY)

    primary_mask = (black_mask_frame == 0)  # 0 = primary object

    # Initialize grey mask
    grey_mask_combined = np.zeros((frame_height, frame_width), dtype=bool)

    # Process affected objects from VLM
    affected_objects = analysis.get('affected_objects', [])

    print(f"  Processing {len(affected_objects)} affected object(s)...")

    for obj in affected_objects:
        noun = obj.get('noun', '')
        category = obj.get('category', 'physical')
        will_move = obj.get('will_move', False)
        needs_trajectory = obj.get('needs_trajectory', False)

        if not noun:
            continue

        print(f"  • {noun} ({category})")

        # Check if we have trajectory data for this object
        has_trajectory = False
        if needs_trajectory and trajectory_data:
            for traj in trajectory_data:
                if traj.get('object_noun', '') == noun:
                    # Use trajectory grid cells
                    print(f"    Using user-drawn trajectory ({len(traj['trajectory_grid_cells'])} cells)")
                    traj_mask = grid_cells_to_mask(
                        traj['trajectory_grid_cells'],
                        grid_rows, grid_cols,
                        frame_width, frame_height
                    )
                    grey_mask_combined |= traj_mask
                    has_trajectory = True
                    break

        # If no trajectory or doesn't need one, segment normally
        if not has_trajectory:
            # Segment object
            obj_mask = segmenter.segment(first_frame_pil, noun)

            if obj_mask.any():
                print(f"    Segmented {obj_mask.sum()} pixels")

                # Filter by proximity to primary mask
                obj_mask_filtered = filter_by_proximity(obj_mask, primary_mask, dilation=50)

                if obj_mask_filtered.any():
                    print(f"    After proximity filter: {obj_mask_filtered.sum()} pixels")

                    # Gridify
                    obj_mask_gridified = gridify_mask(obj_mask_filtered, grid_rows, grid_cols)

                    # Add to combined grey mask
                    grey_mask_combined |= obj_mask_gridified

                    print(f"    ✓ Added to grey mask")
                else:
                    print(f"    ⚠️ No pixels near primary object, skipping")
            else:
                print(f"    ⚠️ Segmentation failed")

    # Generate grey mask video
    print(f"  Generating grey mask video...")

    # For simplicity, use same mask for all frames
    # (In future, could track objects through video)
    grey_mask_uint8 = np.where(grey_mask_combined, 127, 255).astype(np.uint8)

    # Write temp AVI
    temp_avi = output_dir / "grey_mask_temp.avi"
    fourcc = cv2.VideoWriter_fourcc(*'FFV1')
    out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (frame_width, frame_height), isColor=False)

    for _ in range(total_frames):
        out.write(grey_mask_uint8)

    out.release()

    # Convert to MP4
    grey_mask_mp4 = output_dir / "grey_mask.mp4"
    cmd = [
        'ffmpeg', '-y', '-i', str(temp_avi),
        '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
        '-pix_fmt', 'yuv444p',
        str(grey_mask_mp4)
    ]
    subprocess.run(cmd, capture_output=True)
    temp_avi.unlink()

    print(f"  ✓ Saved grey_mask.mp4")

    # Save debug visualization
    debug_vis = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
    debug_vis[grey_mask_combined] = [0, 255, 0]  # Green for affected regions
    debug_vis[primary_mask] = [255, 0, 0]  # Red for primary
    debug_path = output_dir / "debug_grey_mask.jpg"
    cv2.imwrite(str(debug_path), debug_vis)
    print(f"  ✓ Saved debug visualization")

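The value convention used when writing the mask frames: the grey mask encodes affected regions as 127 and background as 255, while the Stage 1 black mask uses 0 for the primary object. A minimal demonstration of the encoding step:

```python
import numpy as np

# Boolean affected-region mask -> uint8 grey-mask frame (127 / 255).
grey_region = np.array([[True, False],
                        [False, True]])
grey_frame = np.where(grey_region, 127, 255).astype(np.uint8)
print(grey_frame.tolist())  # [[127, 255], [255, 127]]
```

Stage 4 then combines the black (0) and grey (127) layers losslessly, which is why the ffmpeg call above uses `-qp 0` with `yuv444p`: lossy encoding would smear the exact 0/127/255 levels the quadmask depends on.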
def main():
    parser = argparse.ArgumentParser(description="Stage 3a: Generate Grey Masks")
    parser.add_argument("--config", required=True, help="Config JSON")
    parser.add_argument("--segmentation-model", default="sam3", choices=["langsam", "sam3"],
                        help="Segmentation model")
    args = parser.parse_args()

    config_path = Path(args.config)

    # Load config
    with open(config_path, 'r') as f:
        config_data = json.load(f)

    if isinstance(config_data, list):
        videos = config_data
    elif isinstance(config_data, dict) and "videos" in config_data:
        videos = config_data["videos"]
    else:
        raise ValueError("Invalid config format")

    # Load trajectory data if it exists
    trajectory_path = config_path.parent / f"{config_path.stem}_trajectories.json"
    trajectory_data = None

    if trajectory_path.exists():
        print(f"Loading trajectory data from: {trajectory_path.name}")
        with open(trajectory_path, 'r') as f:
            trajectory_data = json.load(f)
        print(f"  Loaded {len(trajectory_data)} trajectory(s)")
    else:
        print(f"No trajectory data found (Stage 3b not run or no objects needed trajectories)")

    print(f"\n{'='*70}")
    print(f"Stage 3a: Generate Grey Masks")
    print(f"{'='*70}")
    print(f"Videos: {len(videos)}")
    print(f"Segmentation: {args.segmentation_model.upper()}")
    print(f"{'='*70}\n")

    # Load segmentation model
    segmenter = SegmentationModel(args.segmentation_model)

    # Process each video
    for i, video_info in enumerate(videos):
        video_path = video_info.get('video_path', '')
        print(f"\n{'─'*70}")
        print(f"Video {i+1}/{len(videos)}: {Path(video_path).parent.name}")
        print(f"{'─'*70}")

        try:
            process_video_grey_masks(video_info, segmenter, trajectory_data)
            print(f"\n✅ Video {i+1} complete!")

        except Exception as e:
            print(f"\n❌ Error processing video {i+1}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\n{'='*70}")
    print(f"✅ Stage 3a Complete!")
    print(f"{'='*70}")
    print(f"Generated grey_mask.mp4 for all videos")
    print(f"Next: Run Stage 4 to combine black + grey masks")
    print(f"{'='*70}\n")


if __name__ == "__main__":
    main()
VLM-MASK-REASONER/stage3a_generate_grey_masks_v2.py
ADDED
|
@@ -0,0 +1,576 @@
#!/usr/bin/env python3
"""
Stage 3a: Generate Grey Masks (CORRECTED)

Correct pipeline:
1. For EACH affected object (from VLM analysis):
   a) IF user drew trajectory (Stage 3b):
      - Segment object in first_appears_frame → get SIZE
      - Apply object SIZE along trajectory path across all frames
   b) ELSE (no user trajectory):
      - Segment object through ALL frames (captures any movement/changes)
      - This handles:
        * Static objects (can, chair)
        * Objects that move during video (golf ball)
        * Dynamic effects (paint strokes, shadows)
      - Filter by proximity to primary object

2. Accumulate all masks (one combined mask per frame)

3. Gridify ALL accumulated masks:
   - If ANY pixel in grid cell → ENTIRE cell = 127

4. Write grey_mask.mp4

Key insight: will_move / needs_trajectory are ONLY for Stage 3b (user input).
In Stage 3a, we segment ALL affected objects through ALL frames.

Input:  - vlm_analysis.json (Stage 2)
        - black_mask.mp4 (Stage 1)
        - trajectories.json (Stage 3b, optional)
Output: - grey_mask.mp4 (127=affected, 255=background)

Usage:
    python stage3a_generate_grey_masks_v2.py --config more_dyn_2_config_points_absolute.json
"""

import os
import sys
import json
import argparse
import cv2
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from PIL import Image
import subprocess

# SAM2 for video tracking
try:
    from sam2.build_sam import build_sam2_video_predictor
    SAM2_AVAILABLE = True
except ImportError:
    SAM2_AVAILABLE = False

# SAM3 for single-frame segmentation
try:
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor
    SAM3_AVAILABLE = True
except ImportError:
    SAM3_AVAILABLE = False

# LangSAM
try:
    from lang_sam import LangSAM
    LANGSAM_AVAILABLE = True
except ImportError:
    LANGSAM_AVAILABLE = False


class SegmentationModel:
    """Wrapper for segmentation"""

    def __init__(self, model_type: str = "sam3"):
        self.model_type = model_type.lower()

        if self.model_type == "sam3":
            if not SAM3_AVAILABLE:
                raise ImportError("SAM3 not available")
            print(f"  Loading SAM3...")
            model = build_sam3_image_model()
            self.processor = Sam3Processor(model)
            self.model = model
        elif self.model_type == "langsam":
            if not LANGSAM_AVAILABLE:
                raise ImportError("LangSAM not available")
            print(f"  Loading LangSAM...")
            self.model = LangSAM()
        else:
            raise ValueError(f"Unknown model: {model_type}")

    def segment(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
        """Segment object using text prompt - returns boolean mask"""
        if self.model_type == "sam3":
            return self._segment_sam3(image_pil, prompt)
        else:
            return self._segment_langsam(image_pil, prompt)

    def _segment_sam3(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
        import torch
        h, w = image_pil.height, image_pil.width
        union = np.zeros((h, w), dtype=bool)

        try:
            inference_state = self.processor.set_image(image_pil)
            output = self.processor.set_text_prompt(state=inference_state, prompt=prompt)
            masks = output.get("masks")

            if masks is None or len(masks) == 0:
                return union

            if torch.is_tensor(masks):
                masks = masks.cpu().numpy()

            if masks.ndim == 2:
                union = masks.astype(bool)
            elif masks.ndim == 3:
                union = masks.any(axis=0).astype(bool)
            elif masks.ndim == 4:
                union = masks.any(axis=(0, 1)).astype(bool)

        except Exception as e:
            print(f"  Warning: SAM3 failed: {e}")

        return union

    def _segment_langsam(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
        h, w = image_pil.height, image_pil.width
        union = np.zeros((h, w), dtype=bool)

        try:
            results = self.model.predict([image_pil], [prompt])
            if not results:
                return union

            r0 = results[0]
            if isinstance(r0, dict) and "masks" in r0:
                masks = r0["masks"]
                if masks.ndim == 4 and masks.shape[0] == 1:
                    masks = masks[0]
                if masks.ndim == 3:
                    union = masks.any(axis=0).astype(bool)
                elif masks.ndim == 2:
                    union = masks.astype(bool)

        except Exception as e:
            print(f"  Warning: LangSAM failed: {e}")

        return union


def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> Tuple[int, int]:
    """Calculate grid dimensions for square cells"""
    aspect_ratio = width / height
    if width >= height:
        grid_rows = min_grid
        grid_cols = max(min_grid, round(min_grid * aspect_ratio))
    else:
        grid_cols = min_grid
        grid_rows = max(min_grid, round(min_grid / aspect_ratio))
    return grid_rows, grid_cols


def gridify_masks(masks: List[np.ndarray], grid_rows: int, grid_cols: int) -> List[np.ndarray]:
    """
    Gridify masks: if ANY pixel in grid cell → ENTIRE cell = True

    Args:
        masks: List of boolean masks (one per frame)
        grid_rows, grid_cols: Grid dimensions

    Returns:
        List of gridified boolean masks
    """
    gridified_masks = []

    for mask in masks:
        h, w = mask.shape
        gridified = np.zeros((h, w), dtype=bool)

        cell_width = w / grid_cols
        cell_height = h / grid_rows

        for row in range(grid_rows):
            for col in range(grid_cols):
                y1 = int(row * cell_height)
                y2 = int((row + 1) * cell_height)
                x1 = int(col * cell_width)
                x2 = int((col + 1) * cell_width)

                cell_region = mask[y1:y2, x1:x2]
                # If ANY pixel in cell → ENTIRE cell
                if cell_region.any():
                    gridified[y1:y2, x1:x2] = True

        gridified_masks.append(gridified)

    return gridified_masks


def get_object_size(mask: np.ndarray) -> Tuple[int, int]:
    """Get bounding box size of object"""
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    if not rows.any() or not cols.any():
        return 0, 0

    y1, y2 = np.where(rows)[0][[0, -1]]
    x1, x2 = np.where(cols)[0][[0, -1]]

    width = x2 - x1 + 1
    height = y2 - y1 + 1

    return width, height


def apply_object_along_trajectory(obj_mask: np.ndarray, trajectory_points: List[Tuple[int, int]],
                                  total_frames: int, frame_shape: Tuple[int, int]) -> List[np.ndarray]:
    """
    Apply object along trajectory path across frames.

    Args:
        obj_mask: Object mask from first_appears_frame
        trajectory_points: List of (x, y) points defining path
        total_frames: Total number of frames in video
        frame_shape: (height, width)

    Returns:
        List of masks (one per frame) with object placed along trajectory
    """
    h, w = frame_shape
    masks = [np.zeros((h, w), dtype=bool) for _ in range(total_frames)]

    if len(trajectory_points) < 2:
        return masks

    # Get object size
    obj_width, obj_height = get_object_size(obj_mask)

    if obj_width == 0 or obj_height == 0:
        return masks

    # Interpolate trajectory across frames
    num_traj_points = len(trajectory_points)

    for frame_idx in range(total_frames):
        # Map frame index to trajectory point
        t = frame_idx / max(total_frames - 1, 1)  # 0.0 to 1.0
        traj_idx = int(t * (num_traj_points - 1))
        traj_idx = min(traj_idx, num_traj_points - 1)

        # Get position on trajectory
        x_center, y_center = trajectory_points[traj_idx]

        # Place object at this position
        x1 = max(0, int(x_center - obj_width // 2))
        y1 = max(0, int(y_center - obj_height // 2))
        x2 = min(w, x1 + obj_width)
        y2 = min(h, y1 + obj_height)

        # Place object mask
        masks[frame_idx][y1:y2, x1:x2] = True

    return masks


def segment_object_all_frames(video_path: str, obj_noun: str, segmenter: SegmentationModel,
                              frame_stride: int = 1) -> List[np.ndarray]:
    """
    Segment object through all frames.

    Args:
        video_path: Path to video
        obj_noun: Object to segment
        segmenter: Segmentation model
        frame_stride: Process every Nth frame (for speed)

    Returns:
        List of boolean masks (one per frame)
    """
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    masks = []
    frame_idx = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        if frame_idx % frame_stride == 0:
            # Segment this frame
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_pil = Image.fromarray(frame_rgb)
            mask = segmenter.segment(frame_pil, obj_noun)
            masks.append(mask)

            if (frame_idx + 1) % 10 == 0:
                print(f"  Frame {frame_idx + 1}/{total_frames}...", end='\r')
        else:
            # Reuse previous mask
            if masks:
                masks.append(masks[-1])
            else:
                masks.append(np.zeros((frame_height, frame_width), dtype=bool))

        frame_idx += 1

    cap.release()
    print(f"  Segmented {total_frames} frames")

    return masks


def dilate_mask(mask: np.ndarray, kernel_size: int = 15) -> np.ndarray:
    """Dilate mask for proximity checking"""
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.dilate(mask.astype(np.uint8), kernel, iterations=1).astype(bool)


def filter_masks_by_proximity(masks: List[np.ndarray], primary_mask: np.ndarray,
                              dilation: int = 50) -> List[np.ndarray]:
    """Filter masks to only include regions near primary mask"""
    proximity_region = dilate_mask(primary_mask, dilation)

    filtered = []
    for mask in masks:
        filtered_mask = mask & proximity_region
        filtered.append(filtered_mask)

    return filtered


def process_video_grey_masks(video_info: Dict, segmenter: SegmentationModel,
                             trajectory_data: List[Dict] = None):
    """Generate grey masks for a single video"""
    video_path = video_info.get("video_path", "")
    output_dir = Path(video_info.get("output_dir", ""))

    if not output_dir.exists():
        print(f"  ⚠️ Output directory not found")
        return

    # Load required files
    vlm_analysis_path = output_dir / "vlm_analysis.json"
    black_mask_path = output_dir / "black_mask.mp4"
    input_video_path = output_dir / "input_video.mp4"

    if not vlm_analysis_path.exists():
        print(f"  ⚠️ vlm_analysis.json not found")
        return

    if not black_mask_path.exists():
        print(f"  ⚠️ black_mask.mp4 not found")
        return

    if not input_video_path.exists():
        input_video_path = Path(video_path)

    # Load VLM analysis
    with open(vlm_analysis_path, 'r') as f:
        analysis = json.load(f)

    # Get video properties
    cap = cv2.VideoCapture(str(input_video_path))
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    # Calculate grid
    min_grid = video_info.get('min_grid', 8)
    grid_rows, grid_cols = calculate_square_grid(frame_width, frame_height, min_grid)

    print(f"  Video: {frame_width}x{frame_height}, {total_frames} frames, grid: {grid_rows}x{grid_cols}")

    # Load black mask (first frame for proximity filtering)
    black_cap = cv2.VideoCapture(str(black_mask_path))
    ret, black_mask_frame = black_cap.read()
    black_cap.release()

    if len(black_mask_frame.shape) == 3:
        black_mask_frame = cv2.cvtColor(black_mask_frame, cv2.COLOR_BGR2GRAY)

    primary_mask = (black_mask_frame == 0)  # 0 = primary object

    # Initialize accumulated masks (one per frame)
    accumulated_masks = [np.zeros((frame_height, frame_width), dtype=bool) for _ in range(total_frames)]

    # Process affected objects
    affected_objects = analysis.get('affected_objects', [])
    print(f"  Processing {len(affected_objects)} affected object(s)...")

    for obj in affected_objects:
        noun = obj.get('noun', '')

        if not noun:
            continue

        print(f"  • {noun}")

        # Check if we have USER TRAJECTORY for this object
        has_trajectory = False
        if trajectory_data:
            for traj in trajectory_data:
                if traj.get('object_noun', '') == noun and not traj.get('skipped', False):
                    has_trajectory = True
                    traj_points = traj.get('trajectory_points', [])

                    print(f"    Using user-drawn trajectory ({len(traj_points)} points)")

                    # Segment object in first_appears_frame to get SIZE
                    first_frame_idx = obj.get('first_appears_frame', 0)
                    cap = cv2.VideoCapture(str(input_video_path))
                    cap.set(cv2.CAP_PROP_POS_FRAMES, first_frame_idx)
                    ret, frame = cap.read()
                    cap.release()

                    if ret:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frame_pil = Image.fromarray(frame_rgb)
                        obj_mask = segmenter.segment(frame_pil, noun)

                        if obj_mask.any():
                            obj_width, obj_height = get_object_size(obj_mask)
                            print(f"    Segmented object (size: {obj_width}x{obj_height} px)")

                            # Apply object SIZE along trajectory
                            traj_masks = apply_object_along_trajectory(
                                obj_mask, traj_points, total_frames, (frame_height, frame_width)
                            )

                            # Accumulate
                            for i in range(total_frames):
                                accumulated_masks[i] |= traj_masks[i]

                            print(f"    ✓ Applied object along trajectory across {total_frames} frames")
                        else:
                            print(f"    ⚠️ Segmentation failed, using trajectory grid cells only")
                            # Fallback: just use trajectory grid cells
                            grid_cells = traj.get('trajectory_grid_cells', [])
                            for row, col in grid_cells:
                                y1 = int(row * frame_height / grid_rows)
                                y2 = int((row + 1) * frame_height / grid_rows)
                                x1 = int(col * frame_width / grid_cols)
                                x2 = int((col + 1) * frame_width / grid_cols)
                                for i in range(total_frames):
                                    accumulated_masks[i][y1:y2, x1:x2] = True

                    break

        # If NO user trajectory, segment through ALL frames
        # This captures: static objects, objects that move during video, dynamic effects
        if not has_trajectory:
            print(f"    Segmenting through ALL frames (captures any movement/changes)...")
            obj_masks = segment_object_all_frames(str(input_video_path), noun, segmenter, frame_stride=5)

            # Filter by proximity to primary mask
            obj_masks_filtered = filter_masks_by_proximity(obj_masks, primary_mask, dilation=50)

            # Accumulate
            for i in range(len(obj_masks_filtered)):
                if i < len(accumulated_masks):
                    accumulated_masks[i] |= obj_masks_filtered[i]

            pixel_count = sum(mask.sum() for mask in obj_masks_filtered)
            print(f"    ✓ Segmented across {len(obj_masks_filtered)} frames ({pixel_count} total pixels)")

    # GRIDIFY all accumulated masks
    print(f"  Gridifying masks...")
    gridified_masks = gridify_masks(accumulated_masks, grid_rows, grid_cols)

    # Convert to uint8 (127 = grey, 255 = background)
    grey_masks_uint8 = [np.where(mask, 127, 255).astype(np.uint8) for mask in gridified_masks]

    # Write video
    print(f"  Writing grey_mask.mp4...")
    temp_avi = output_dir / "grey_mask_temp.avi"
    fourcc = cv2.VideoWriter_fourcc(*'FFV1')
    out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (frame_width, frame_height), isColor=False)

    for mask in grey_masks_uint8:
        out.write(mask)

    out.release()

    # Convert to MP4
    grey_mask_mp4 = output_dir / "grey_mask.mp4"
    cmd = [
        'ffmpeg', '-y', '-i', str(temp_avi),
        '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
        '-pix_fmt', 'yuv444p',
        str(grey_mask_mp4)
    ]
    subprocess.run(cmd, capture_output=True)
    temp_avi.unlink()

    print(f"  ✓ Saved grey_mask.mp4")

    # Save debug visualization (first frame)
    debug_vis = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
    debug_vis[gridified_masks[0]] = [0, 255, 0]  # Green
    debug_vis[primary_mask] = [255, 0, 0]  # Red
    debug_path = output_dir / "debug_grey_mask.jpg"
    cv2.imwrite(str(debug_path), debug_vis)


def main():
    parser = argparse.ArgumentParser(description="Stage 3a: Generate Grey Masks (Corrected)")
    parser.add_argument("--config", required=True, help="Config JSON")
    parser.add_argument("--segmentation-model", default="sam3", choices=["langsam", "sam3"],
                        help="Segmentation model")
    args = parser.parse_args()

    config_path = Path(args.config)

    # Load config
    with open(config_path, 'r') as f:
        config_data = json.load(f)

    if isinstance(config_data, list):
        videos = config_data
    elif isinstance(config_data, dict) and "videos" in config_data:
        videos = config_data["videos"]
    else:
        raise ValueError("Invalid config format")

    # Load trajectory data
    trajectory_path = config_path.parent / f"{config_path.stem}_trajectories.json"
    trajectory_data = None

    if trajectory_path.exists():
        print(f"Loading trajectory data: {trajectory_path.name}")
        with open(trajectory_path, 'r') as f:
            trajectory_data = json.load(f)
        print(f"  Loaded {len(trajectory_data)} trajectory(s)")

    print(f"\n{'='*70}")
    print(f"Stage 3a: Generate Grey Masks (CORRECTED)")
    print(f"{'='*70}")
    print(f"Videos: {len(videos)}")
    print(f"Segmentation: {args.segmentation_model.upper()}")
    print(f"{'='*70}\n")

    # Load segmentation model
    segmenter = SegmentationModel(args.segmentation_model)

    # Process each video
    for i, video_info in enumerate(videos):
        video_path = video_info.get('video_path', '')
        print(f"\n{'─'*70}")
        print(f"Video {i+1}/{len(videos)}: {Path(video_path).parent.name}")
        print(f"{'─'*70}")

        try:
            process_video_grey_masks(video_info, segmenter, trajectory_data)
            print(f"\n✅ Video {i+1} complete!")

        except Exception as e:
            print(f"\n❌ Error: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\n{'='*70}")
    print(f"✅ Stage 3a Complete!")
    print(f"{'='*70}\n")


if __name__ == "__main__":
    main()
VLM-MASK-REASONER/stage3b_trajectory_gui.py ADDED
@@ -0,0 +1,432 @@
#!/usr/bin/env python3
"""
Stage 3b: Trajectory Drawing GUI (Simplified - No Segmentation)

For objects with needs_trajectory=true, user draws movement paths.

Input:  Config with vlm_analysis.json in output_dir
Output: trajectory_data.json with user-drawn paths as grid cells

Usage:
    python stage3b_trajectory_gui.py --config more_dyn_2_config_points_absolute.json
"""

import os
import sys
import json
import argparse
import cv2
import numpy as np
import tkinter as tk
from tkinter import ttk, messagebox
from PIL import Image, ImageTk, ImageDraw, ImageFont
from pathlib import Path
from typing import Dict, List, Tuple


def calculate_square_grid(width: int, height: int, min_grid: int = 8) -> Tuple[int, int]:
    """Calculate grid dimensions for square cells"""
    aspect_ratio = width / height
    if width >= height:
        grid_rows = min_grid
        grid_cols = max(min_grid, round(min_grid * aspect_ratio))
    else:
        grid_cols = min_grid
        grid_rows = max(min_grid, round(min_grid / aspect_ratio))
    return grid_rows, grid_cols


def points_to_grid_cells(points: List[Tuple[int, int]], grid_rows: int, grid_cols: int,
                         frame_width: int, frame_height: int) -> List[List[int]]:
    """Convert trajectory points to grid cells"""
    cell_width = frame_width / grid_cols
    cell_height = frame_height / grid_rows

    grid_cells = set()
    for x, y in points:
        col = int(x / cell_width)
        row = int(y / cell_height)
        if 0 <= row < grid_rows and 0 <= col < grid_cols:
            grid_cells.add((row, col))

    # Sort by row, then col
    return sorted([[r, c] for r, c in grid_cells])


class TrajectoryGUI:
    def __init__(self, root, objects_data: List[Dict]):
        self.root = root
        self.root.title("Stage 3b: Trajectory Drawing")

        self.objects_data = objects_data  # List of {video_info, objects_needing_trajectory}
        self.current_video_idx = 0
        self.current_object_idx = 0

        # Current state
        self.frame = None
        self.trajectory_points = []
        self.drawing = False

        # Display
        self.display_scale = 1.0
        self.photo = None

        # Results storage
        self.all_trajectories = []  # List of trajectories for all videos

        self.setup_ui()
        self.load_current_object()

    def setup_ui(self):
        """Setup GUI layout"""
        # Top info
        info_frame = ttk.Frame(self.root)
        info_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        self.video_label = ttk.Label(info_frame, text="Video: ", font=("Arial", 10, "bold"))
        self.video_label.pack(side=tk.LEFT, padx=5)

        self.object_label = ttk.Label(info_frame, text="Object: ", foreground="blue")
        self.object_label.pack(side=tk.LEFT, padx=10)

        # Instructions
        inst_frame = ttk.LabelFrame(self.root, text="Instructions")
        inst_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        ttk.Label(inst_frame, text="1. See the frame where object is visible", foreground="blue").pack(anchor=tk.W, padx=5)
        ttk.Label(inst_frame, text="2. Click and drag to draw trajectory path (RED line)", foreground="red").pack(anchor=tk.W, padx=5)
        ttk.Label(inst_frame, text="3. Draw from object's current position to where it should end up", foreground="orange").pack(anchor=tk.W, padx=5)
        ttk.Label(inst_frame, text="4. Click 'Clear' to restart, 'Save & Next' when done", foreground="green").pack(anchor=tk.W, padx=5)

        # Canvas
        canvas_frame = ttk.LabelFrame(self.root, text="Draw Trajectory Path")
        canvas_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5, pady=5)

        self.canvas = tk.Canvas(canvas_frame, width=800, height=600, bg='black', cursor="crosshair")
        self.canvas.pack(fill=tk.BOTH, expand=True)
        self.canvas.bind("<Button-1>", self.on_canvas_click)
        self.canvas.bind("<B1-Motion>", self.on_canvas_drag)
        self.canvas.bind("<ButtonRelease-1>", self.on_canvas_release)

        # Controls
        controls = ttk.Frame(self.root)
        controls.pack(side=tk.BOTTOM, fill=tk.X, padx=5, pady=5)

        self.status_label = ttk.Label(controls, text="Draw trajectory path for object", foreground="blue")
        self.status_label.pack(side=tk.TOP, pady=5)

        button_frame = ttk.Frame(controls)
        button_frame.pack(side=tk.TOP)

        ttk.Button(button_frame, text="Clear Trajectory", command=self.clear_trajectory).pack(side=tk.LEFT, padx=5)
        ttk.Button(button_frame, text="Skip Object", command=self.skip_object).pack(side=tk.LEFT, padx=5)
        ttk.Button(button_frame, text="Save & Next", command=self.save_and_next).pack(side=tk.LEFT, padx=5)

        self.progress_label = ttk.Label(controls, text="", font=("Arial", 9))
        self.progress_label.pack(side=tk.TOP, pady=5)

    def load_current_object(self):
        """Load current object for trajectory drawing"""
        if self.current_video_idx >= len(self.objects_data):
            # All done
            self.finish()
            return

        data = self.objects_data[self.current_video_idx]
        video_info = data['video_info']
|
| 137 |
+
objects_needing_traj = data['objects']
|
| 138 |
+
|
| 139 |
+
if self.current_object_idx >= len(objects_needing_traj):
|
| 140 |
+
# Done with this video, move to next
|
| 141 |
+
self.current_video_idx += 1
|
| 142 |
+
self.current_object_idx = 0
|
| 143 |
+
self.load_current_object()
|
| 144 |
+
return
|
| 145 |
+
|
| 146 |
+
obj = objects_needing_traj[self.current_object_idx]
|
| 147 |
+
video_path = video_info.get('video_path', '')
|
| 148 |
+
output_dir = Path(video_info.get('output_dir', ''))
|
| 149 |
+
|
| 150 |
+
# Update labels
|
| 151 |
+
self.video_label.config(text=f"Video: {Path(video_path).parent.name}/{Path(video_path).name}")
|
| 152 |
+
self.object_label.config(text=f"Object: {obj['noun']} (will fall/move)")
|
| 153 |
+
|
| 154 |
+
total_objects = sum(len(d['objects']) for d in self.objects_data)
|
| 155 |
+
current_obj_num = sum(len(self.objects_data[i]['objects']) for i in range(self.current_video_idx)) + self.current_object_idx + 1
|
| 156 |
+
self.progress_label.config(text=f"Object {current_obj_num}/{total_objects} across {len(self.objects_data)} video(s)")
|
| 157 |
+
|
| 158 |
+
# Extract frame
|
| 159 |
+
frame_idx = obj.get('first_appears_frame', 0)
|
| 160 |
+
input_video = output_dir / "input_video.mp4"
|
| 161 |
+
|
| 162 |
+
if not input_video.exists():
|
| 163 |
+
input_video = Path(video_path)
|
| 164 |
+
|
| 165 |
+
cap = cv2.VideoCapture(str(input_video))
|
| 166 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
|
| 167 |
+
ret, frame = cap.read()
|
| 168 |
+
cap.release()
|
| 169 |
+
|
| 170 |
+
if not ret:
|
| 171 |
+
messagebox.showerror("Error", f"Failed to read frame {frame_idx} from video")
|
| 172 |
+
self.skip_object()
|
| 173 |
+
return
|
| 174 |
+
|
| 175 |
+
self.frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 176 |
+
|
| 177 |
+
print(f"\n Loaded frame {frame_idx} for '{obj['noun']}'")
|
| 178 |
+
|
| 179 |
+
# Clear trajectory
|
| 180 |
+
self.trajectory_points = []
|
| 181 |
+
|
| 182 |
+
# Calculate grid for this video
|
| 183 |
+
h, w = self.frame.shape[:2]
|
| 184 |
+
min_grid = video_info.get('min_grid', 8)
|
| 185 |
+
self.grid_rows, self.grid_cols = calculate_square_grid(w, h, min_grid)
|
| 186 |
+
|
| 187 |
+
# Display
|
| 188 |
+
self.status_label.config(text="Draw trajectory path (click and drag)", foreground="blue")
|
| 189 |
+
self.display_frame()
|
| 190 |
+
|
| 191 |
+
def display_frame(self):
|
| 192 |
+
"""Display frame with trajectory"""
|
| 193 |
+
if self.frame is None:
|
| 194 |
+
return
|
| 195 |
+
|
| 196 |
+
# Create visualization
|
| 197 |
+
vis = self.frame.copy()
|
| 198 |
+
h, w = vis.shape[:2]
|
| 199 |
+
|
| 200 |
+
# Draw trajectory
|
| 201 |
+
if len(self.trajectory_points) > 1:
|
| 202 |
+
for i in range(len(self.trajectory_points) - 1):
|
| 203 |
+
pt1 = self.trajectory_points[i]
|
| 204 |
+
pt2 = self.trajectory_points[i + 1]
|
| 205 |
+
cv2.line(vis, pt1, pt2, (255, 0, 0), 5) # Thicker line for visibility
|
| 206 |
+
|
| 207 |
+
# Draw start point (green) and end point (red)
|
| 208 |
+
if len(self.trajectory_points) > 0:
|
| 209 |
+
start_pt = self.trajectory_points[0]
|
| 210 |
+
end_pt = self.trajectory_points[-1]
|
| 211 |
+
cv2.circle(vis, start_pt, 8, (0, 255, 0), -1) # Green start
|
| 212 |
+
cv2.circle(vis, end_pt, 8, (255, 0, 0), -1) # Red end
|
| 213 |
+
|
| 214 |
+
# Scale for display
|
| 215 |
+
max_width, max_height = 800, 600
|
| 216 |
+
scale_w = max_width / w
|
| 217 |
+
scale_h = max_height / h
|
| 218 |
+
self.display_scale = min(scale_w, scale_h, 1.0)
|
| 219 |
+
|
| 220 |
+
new_w = int(w * self.display_scale)
|
| 221 |
+
new_h = int(h * self.display_scale)
|
| 222 |
+
vis_resized = cv2.resize(vis, (new_w, new_h))
|
| 223 |
+
|
| 224 |
+
# Convert to PIL and display
|
| 225 |
+
pil_img = Image.fromarray(vis_resized)
|
| 226 |
+
self.photo = ImageTk.PhotoImage(pil_img)
|
| 227 |
+
self.canvas.delete("all")
|
| 228 |
+
self.canvas.create_image(0, 0, anchor=tk.NW, image=self.photo)
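The display scaling above fits the frame into the 800×600 canvas without ever upscaling; the same math in isolation (a sketch, not repo code):

```python
def fit_scale(w, h, max_w=800, max_h=600):
    """Scale factor that fits a w x h frame into max_w x max_h, capped at 1.0."""
    return min(max_w / w, max_h / h, 1.0)

print(fit_scale(1600, 1200))  # → 0.5
print(fit_scale(400, 300))    # → 1.0 (small frames are not upscaled)
```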
|
| 229 |
+
|
| 230 |
+
def on_canvas_click(self, event):
|
| 231 |
+
"""Start drawing trajectory"""
|
| 232 |
+
# Convert to frame coordinates
|
| 233 |
+
x = int(event.x / self.display_scale)
|
| 234 |
+
y = int(event.y / self.display_scale)
|
| 235 |
+
|
| 236 |
+
self.trajectory_points = [(x, y)]
|
| 237 |
+
self.drawing = True
|
| 238 |
+
|
| 239 |
+
def on_canvas_drag(self, event):
|
| 240 |
+
"""Continue drawing trajectory"""
|
| 241 |
+
if not self.drawing:
|
| 242 |
+
return
|
| 243 |
+
|
| 244 |
+
x = int(event.x / self.display_scale)
|
| 245 |
+
y = int(event.y / self.display_scale)
|
| 246 |
+
|
| 247 |
+
# Add point if far enough from last point
|
| 248 |
+
if len(self.trajectory_points) > 0:
|
| 249 |
+
last_x, last_y = self.trajectory_points[-1]
|
| 250 |
+
dist = np.sqrt((x - last_x)**2 + (y - last_y)**2)
|
| 251 |
+
if dist > 5: # Minimum distance between points
|
| 252 |
+
self.trajectory_points.append((x, y))
|
| 253 |
+
self.display_frame()
|
| 254 |
+
|
| 255 |
+
def on_canvas_release(self, event):
|
| 256 |
+
"""Finish drawing trajectory"""
|
| 257 |
+
self.drawing = False
|
| 258 |
+
if len(self.trajectory_points) > 0:
|
| 259 |
+
x = int(event.x / self.display_scale)
|
| 260 |
+
y = int(event.y / self.display_scale)
|
| 261 |
+
self.trajectory_points.append((x, y))
|
| 262 |
+
self.display_frame()
|
| 263 |
+
|
| 264 |
+
def clear_trajectory(self):
|
| 265 |
+
"""Clear drawn trajectory"""
|
| 266 |
+
self.trajectory_points = []
|
| 267 |
+
self.display_frame()
|
| 268 |
+
self.status_label.config(text="Trajectory cleared. Draw again.", foreground="blue")
|
| 269 |
+
|
| 270 |
+
def skip_object(self):
|
| 271 |
+
"""Skip current object without saving trajectory"""
|
| 272 |
+
result = messagebox.askyesno("Skip Object", "Skip this object without drawing trajectory?")
|
| 273 |
+
if not result:
|
| 274 |
+
return
|
| 275 |
+
|
| 276 |
+
# Save empty trajectory
|
| 277 |
+
data = self.objects_data[self.current_video_idx]
|
| 278 |
+
obj = data['objects'][self.current_object_idx]
|
| 279 |
+
|
| 280 |
+
self.all_trajectories.append({
|
| 281 |
+
'video_path': data['video_info']['video_path'],
|
| 282 |
+
'object_noun': obj['noun'],
|
| 283 |
+
'trajectory_points': [],
|
| 284 |
+
'trajectory_grid_cells': [],
|
| 285 |
+
'skipped': True
|
| 286 |
+
})
|
| 287 |
+
|
| 288 |
+
self.current_object_idx += 1
|
| 289 |
+
self.load_current_object()
|
| 290 |
+
|
| 291 |
+
def save_and_next(self):
|
| 292 |
+
"""Save trajectory and move to next object"""
|
| 293 |
+
if len(self.trajectory_points) < 2:
|
| 294 |
+
messagebox.showwarning("Warning", "Draw a trajectory path first (at least 2 points)")
|
| 295 |
+
return
|
| 296 |
+
|
| 297 |
+
# Convert to grid cells
|
| 298 |
+
data = self.objects_data[self.current_video_idx]
|
| 299 |
+
obj = data['objects'][self.current_object_idx]
|
| 300 |
+
|
| 301 |
+
grid_cells = points_to_grid_cells(
|
| 302 |
+
self.trajectory_points,
|
| 303 |
+
self.grid_rows,
|
| 304 |
+
self.grid_cols,
|
| 305 |
+
self.frame.shape[1],
|
| 306 |
+
self.frame.shape[0]
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
# Save
|
| 310 |
+
self.all_trajectories.append({
|
| 311 |
+
'video_path': data['video_info']['video_path'],
|
| 312 |
+
'object_noun': obj['noun'],
|
| 313 |
+
'first_appears_frame': obj.get('first_appears_frame', 0),
|
| 314 |
+
'trajectory_points': self.trajectory_points,
|
| 315 |
+
'trajectory_grid_cells': grid_cells,
|
| 316 |
+
'grid_rows': self.grid_rows,
|
| 317 |
+
'grid_cols': self.grid_cols,
|
| 318 |
+
'skipped': False
|
| 319 |
+
})
|
| 320 |
+
|
| 321 |
+
print(f" ✓ Saved trajectory for '{obj['noun']}': {len(grid_cells)} grid cells")
|
| 322 |
+
|
| 323 |
+
self.current_object_idx += 1
|
| 324 |
+
self.load_current_object()
|
| 325 |
+
|
| 326 |
+
def finish(self):
|
| 327 |
+
"""All objects done"""
|
| 328 |
+
self.status_label.config(text="All trajectories complete!", foreground="green")
|
| 329 |
+
messagebox.showinfo("Complete", "All trajectory drawings complete!\n\nSaving results...")
|
| 330 |
+
self.root.quit()
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def find_objects_needing_trajectory(config_path: str) -> List[Dict]:
|
| 334 |
+
"""Find all objects that need trajectory input"""
|
| 335 |
+
config_path = Path(config_path)
|
| 336 |
+
|
| 337 |
+
with open(config_path, 'r') as f:
|
| 338 |
+
config_data = json.load(f)
|
| 339 |
+
|
| 340 |
+
if isinstance(config_data, list):
|
| 341 |
+
videos = config_data
|
| 342 |
+
elif isinstance(config_data, dict) and "videos" in config_data:
|
| 343 |
+
videos = config_data["videos"]
|
| 344 |
+
else:
|
| 345 |
+
raise ValueError("Invalid config format")
|
| 346 |
+
|
| 347 |
+
objects_data = []
|
| 348 |
+
|
| 349 |
+
for video_info in videos:
|
| 350 |
+
output_dir = Path(video_info.get('output_dir', ''))
|
| 351 |
+
vlm_analysis_path = output_dir / "vlm_analysis.json"
|
| 352 |
+
|
| 353 |
+
if not vlm_analysis_path.exists():
|
| 354 |
+
print(f" Skipping {output_dir.parent.name}: no vlm_analysis.json")
|
| 355 |
+
continue
|
| 356 |
+
|
| 357 |
+
with open(vlm_analysis_path, 'r') as f:
|
| 358 |
+
analysis = json.load(f)
|
| 359 |
+
|
| 360 |
+
# Find objects with needs_trajectory=true
|
| 361 |
+
objects_needing_traj = [
|
| 362 |
+
obj for obj in analysis.get('affected_objects', [])
|
| 363 |
+
if obj.get('needs_trajectory', False)
|
| 364 |
+
]
|
| 365 |
+
|
| 366 |
+
if objects_needing_traj:
|
| 367 |
+
objects_data.append({
|
| 368 |
+
'video_info': video_info,
|
| 369 |
+
'objects': objects_needing_traj,
|
| 370 |
+
'output_dir': output_dir
|
| 371 |
+
})
|
| 372 |
+
|
| 373 |
+
return objects_data
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def main():
|
| 377 |
+
parser = argparse.ArgumentParser(description="Stage 3b: Trajectory Drawing GUI")
|
| 378 |
+
parser.add_argument("--config", required=True, help="Config JSON")
|
| 379 |
+
args = parser.parse_args()
|
| 380 |
+
|
| 381 |
+
print(f"\n{'='*70}")
|
| 382 |
+
print(f"Stage 3b: Trajectory Drawing GUI")
|
| 383 |
+
print(f"{'='*70}\n")
|
| 384 |
+
|
| 385 |
+
# Find objects needing trajectories
|
| 386 |
+
print("Finding objects that need trajectory input...")
|
| 387 |
+
objects_data = find_objects_needing_trajectory(args.config)
|
| 388 |
+
|
| 389 |
+
if not objects_data:
|
| 390 |
+
print("\n✅ No objects need trajectory input!")
|
| 391 |
+
print("All objects are either stationary or visual artifacts.")
|
| 392 |
+
print("Proceeding to Stage 3a for mask generation...")
|
| 393 |
+
return
|
| 394 |
+
|
| 395 |
+
total_objects = sum(len(d['objects']) for d in objects_data)
|
| 396 |
+
print(f"\nFound {total_objects} object(s) needing trajectories across {len(objects_data)} video(s):")
|
| 397 |
+
for d in objects_data:
|
| 398 |
+
video_name = Path(d['video_info']['video_path']).parent.name
|
| 399 |
+
print(f"  • {video_name}: {', '.join(obj['noun'] for obj in d['objects'])}")
|
| 400 |
+
|
| 401 |
+
# Launch GUI
|
| 402 |
+
print("\nLaunching trajectory drawing GUI...")
|
| 403 |
+
print("Instructions:")
|
| 404 |
+
print(" 1. See the frame where the object is visible")
|
| 405 |
+
print(" 2. Click and drag to draw trajectory path (RED line)")
|
| 406 |
+
print(" 3. Draw from object's current position to where it should end up")
|
| 407 |
+
print(" 4. Click 'Save & Next' when done with each object")
|
| 408 |
+
print("")
|
| 409 |
+
|
| 410 |
+
root = tk.Tk()
|
| 411 |
+
root.geometry("900x800")
|
| 412 |
+
gui = TrajectoryGUI(root, objects_data)
|
| 413 |
+
root.mainloop()
|
| 414 |
+
|
| 415 |
+
# Save trajectories
|
| 416 |
+
config_path = Path(args.config)
|
| 417 |
+
output_path = config_path.parent / f"{config_path.stem}_trajectories.json"
|
| 418 |
+
|
| 419 |
+
with open(output_path, 'w') as f:
|
| 420 |
+
json.dump(gui.all_trajectories, f, indent=2)
|
| 421 |
+
|
| 422 |
+
print(f"\n{'='*70}")
|
| 423 |
+
print(f"✅ Stage 3b Complete!")
|
| 424 |
+
print(f"{'='*70}")
|
| 425 |
+
print(f"Saved trajectories to: {output_path}")
|
| 426 |
+
print(f"Total trajectories: {len(gui.all_trajectories)}")
|
| 427 |
+
print(f"\nNext: Run Stage 3a to generate grey masks (includes trajectories)")
|
| 428 |
+
print(f"{'='*70}\n")
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
if __name__ == "__main__":
|
| 432 |
+
main()
|
VLM-MASK-REASONER/stage4_combine_masks.py
ADDED
|
@@ -0,0 +1,241 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Stage 4: Combine Black and Grey Masks into Tri/Quad Mask
|
| 4 |
+
|
| 5 |
+
Combines the black mask (primary object) and grey masks (affected objects)
|
| 6 |
+
into a single tri-mask or quad-mask video.
|
| 7 |
+
|
| 8 |
+
Mask values:
|
| 9 |
+
- 0: Primary object (from black mask)
|
| 10 |
+
- 63: Overlap of primary and affected objects
|
| 11 |
+
- 127: Affected objects only (from grey masks)
|
| 12 |
+
- 255: Background (keep)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import cv2
|
| 19 |
+
import numpy as np
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def combine_masks(black_frame, grey_frame):
|
| 24 |
+
"""
|
| 25 |
+
Combine black and grey mask frames.
|
| 26 |
+
|
| 27 |
+
Rules:
|
| 28 |
+
- black=0, grey=255 β 0 (primary object only)
|
| 29 |
+
- black=255, grey=127 β 127 (affected object only)
|
| 30 |
+
- black=0, grey=127 β 63 (overlap)
|
| 31 |
+
- black=255, grey=255 β 255 (background)
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
black_frame: Frame from black_mask.mp4 (0=object, 255=background)
|
| 35 |
+
grey_frame: Frame from grey_mask.mp4 (127=object, 255=background)
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
Combined mask frame
|
| 39 |
+
"""
|
| 40 |
+
# Initialize with background (255)
|
| 41 |
+
combined = np.full_like(black_frame, 255, dtype=np.uint8)
|
| 42 |
+
|
| 43 |
+
# Primary object only (black=0, grey=255)
|
| 44 |
+
primary_only = (black_frame == 0) & (grey_frame == 255)
|
| 45 |
+
combined[primary_only] = 0
|
| 46 |
+
|
| 47 |
+
# Affected object only (black=255, grey=127)
|
| 48 |
+
affected_only = (black_frame == 255) & (grey_frame == 127)
|
| 49 |
+
combined[affected_only] = 127
|
| 50 |
+
|
| 51 |
+
# Overlap (black=0, grey=127)
|
| 52 |
+
overlap = (black_frame == 0) & (grey_frame == 127)
|
| 53 |
+
combined[overlap] = 63
|
| 54 |
+
|
| 55 |
+
return combined
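The value mapping can be checked on a tiny synthetic frame; this re-states the rules from the docstring above as a self-contained snippet:

```python
import numpy as np

def combine_masks(black_frame, grey_frame):
    """Combine a black mask (0=object, 255=bg) and grey mask (127=object, 255=bg)."""
    combined = np.full_like(black_frame, 255, dtype=np.uint8)
    combined[(black_frame == 0) & (grey_frame == 255)] = 0      # primary only
    combined[(black_frame == 255) & (grey_frame == 127)] = 127  # affected only
    combined[(black_frame == 0) & (grey_frame == 127)] = 63     # overlap
    return combined

black = np.array([[0, 255], [0, 255]], dtype=np.uint8)
grey = np.array([[255, 127], [127, 255]], dtype=np.uint8)
print(combine_masks(black, grey))
# → [[  0 127]
#    [ 63 255]]
```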
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def process_video(black_mask_path: Path, grey_mask_path: Path, output_path: Path):
|
| 59 |
+
"""Combine black and grey mask videos into trimask/quadmask"""
|
| 60 |
+
import subprocess
|
| 61 |
+
|
| 62 |
+
print(f" Loading black mask: {black_mask_path.name}")
|
| 63 |
+
black_cap = cv2.VideoCapture(str(black_mask_path))
|
| 64 |
+
|
| 65 |
+
print(f" Loading grey mask: {grey_mask_path.name}")
|
| 66 |
+
grey_cap = cv2.VideoCapture(str(grey_mask_path))
|
| 67 |
+
|
| 68 |
+
# Get video properties
|
| 69 |
+
fps = black_cap.get(cv2.CAP_PROP_FPS)
|
| 70 |
+
width = int(black_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 71 |
+
height = int(black_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 72 |
+
total_frames = int(black_cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 73 |
+
|
| 74 |
+
# Check grey mask has same properties
|
| 75 |
+
grey_total_frames = int(grey_cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 76 |
+
if total_frames != grey_total_frames:
|
| 77 |
+
print(f" ⚠️ Warning: Frame count mismatch (black: {total_frames}, grey: {grey_total_frames})")
|
| 78 |
+
total_frames = min(total_frames, grey_total_frames)
|
| 79 |
+
|
| 80 |
+
print(f" Video: {width}x{height} @ {fps:.2f}fps, {total_frames} frames")
|
| 81 |
+
print(f" Combining masks...")
|
| 82 |
+
|
| 83 |
+
# Collect all frames first
|
| 84 |
+
combined_frames = []
|
| 85 |
+
|
| 86 |
+
# Process frames
|
| 87 |
+
for frame_idx in tqdm(range(total_frames), desc=" Combining"):
|
| 88 |
+
ret_black, black_frame = black_cap.read()
|
| 89 |
+
ret_grey, grey_frame = grey_cap.read()
|
| 90 |
+
|
| 91 |
+
if not ret_black or not ret_grey:
|
| 92 |
+
print(f" ⚠️ Warning: Could not read frame {frame_idx}")
|
| 93 |
+
break
|
| 94 |
+
|
| 95 |
+
# Convert to grayscale if needed
|
| 96 |
+
if len(black_frame.shape) == 3:
|
| 97 |
+
black_frame = cv2.cvtColor(black_frame, cv2.COLOR_BGR2GRAY)
|
| 98 |
+
if len(grey_frame.shape) == 3:
|
| 99 |
+
grey_frame = cv2.cvtColor(grey_frame, cv2.COLOR_BGR2GRAY)
|
| 100 |
+
|
| 101 |
+
# Combine
|
| 102 |
+
combined_frame = combine_masks(black_frame, grey_frame)
|
| 103 |
+
combined_frames.append(combined_frame)
|
| 104 |
+
|
| 105 |
+
# Cleanup
|
| 106 |
+
black_cap.release()
|
| 107 |
+
grey_cap.release()
|
| 108 |
+
|
| 109 |
+
# On the first frame, clamp near-grey values (100β135) to 255 (background).
|
| 110 |
+
# Video codecs can introduce slight luma drift around 127; this ensures no
|
| 111 |
+
# grey pixels survive into the final quadmask on frame 0.
|
| 112 |
+
if combined_frames:
|
| 113 |
+
f0 = combined_frames[0]
|
| 114 |
+
grey_pixels = (f0 > 100) & (f0 < 135)
|
| 115 |
+
f0[grey_pixels] = 255
|
| 116 |
+
combined_frames[0] = f0
|
| 117 |
+
|
| 118 |
+
# Write using LOSSLESS encoding to preserve exact mask values
|
| 119 |
+
print(f" Writing lossless video...")
|
| 120 |
+
|
| 121 |
+
# Write temp AVI with FFV1 codec (lossless)
|
| 122 |
+
temp_avi = output_path.with_suffix('.avi')
|
| 123 |
+
fourcc = cv2.VideoWriter_fourcc(*'FFV1')
|
| 124 |
+
out = cv2.VideoWriter(str(temp_avi), fourcc, fps, (width, height), isColor=False)
|
| 125 |
+
|
| 126 |
+
for frame in combined_frames:
|
| 127 |
+
out.write(frame)
|
| 128 |
+
out.release()
|
| 129 |
+
|
| 130 |
+
# Convert to LOSSLESS H.264 (qp=0, yuv444p to preserve all luma values)
|
| 131 |
+
cmd = [
|
| 132 |
+
'ffmpeg', '-y', '-i', str(temp_avi),
|
| 133 |
+
'-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
|
| 134 |
+
'-pix_fmt', 'yuv444p',
|
| 135 |
+
'-r', str(fps),
|
| 136 |
+
str(output_path)
|
| 137 |
+
]
|
| 138 |
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
| 139 |
+
if result.returncode != 0:
|
| 140 |
+
print(f" ⚠️ Warning: ffmpeg conversion failed")
|
| 141 |
+
print(result.stderr)
|
| 142 |
+
|
| 143 |
+
# Clean up temp file
|
| 144 |
+
temp_avi.unlink()
|
| 145 |
+
|
| 146 |
+
print(f" ✓ Saved: {output_path.name}")
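Because the quadmask must decode to exactly the four legal luma values, a quick post-encode sanity check can catch lossy-codec drift. This helper is a suggestion, not part of the repo:

```python
import numpy as np

QUAD_VALUES = {0, 63, 127, 255}

def is_valid_quadmask_frame(frame: np.ndarray) -> bool:
    """True iff every pixel is one of the four legal quadmask values."""
    return set(np.unique(frame).tolist()) <= QUAD_VALUES

clean = np.array([[0, 63], [127, 255]], dtype=np.uint8)
drifted = np.array([[0, 64], [126, 255]], dtype=np.uint8)  # e.g. codec luma drift
print(is_valid_quadmask_frame(clean), is_valid_quadmask_frame(drifted))  # → True False
```

Running a check like this over a few decoded frames of `quadmask_0.mp4` verifies that the `-qp 0` / `yuv444p` settings really preserved the values end to end.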
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def process_config(config_path: str):
|
| 150 |
+
"""Process all videos in config"""
|
| 151 |
+
config_path = Path(config_path)
|
| 152 |
+
|
| 153 |
+
# Load config
|
| 154 |
+
with open(config_path, 'r') as f:
|
| 155 |
+
config_data = json.load(f)
|
| 156 |
+
|
| 157 |
+
# Handle both formats
|
| 158 |
+
if isinstance(config_data, list):
|
| 159 |
+
videos = config_data
|
| 160 |
+
elif isinstance(config_data, dict) and "videos" in config_data:
|
| 161 |
+
videos = config_data["videos"]
|
| 162 |
+
else:
|
| 163 |
+
raise ValueError("Config must be a list or have 'videos' key")
|
| 164 |
+
|
| 165 |
+
print(f"\n{'='*70}")
|
| 166 |
+
print(f"Stage 4: Combine Masks into Tri/Quad Mask")
|
| 167 |
+
print(f"{'='*70}")
|
| 168 |
+
print(f"Config: {config_path.name}")
|
| 169 |
+
print(f"Videos: {len(videos)}")
|
| 170 |
+
print(f"{'='*70}\n")
|
| 171 |
+
|
| 172 |
+
# Process each video
|
| 173 |
+
success_count = 0
|
| 174 |
+
for i, video_info in enumerate(videos):
|
| 175 |
+
video_path = video_info.get("video_path", "")
|
| 176 |
+
output_dir = video_info.get("output_dir", "")
|
| 177 |
+
|
| 178 |
+
print(f"\n{'─'*70}")
|
| 179 |
+
print(f"Video {i+1}/{len(videos)}: {Path(video_path).name}")
|
| 180 |
+
print(f"{'─'*70}")
|
| 181 |
+
|
| 182 |
+
if not output_dir:
|
| 183 |
+
print(f" ⚠️ No output_dir specified, skipping")
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
output_dir = Path(output_dir)
|
| 187 |
+
if not output_dir.exists():
|
| 188 |
+
print(f" ⚠️ Output directory not found: {output_dir}")
|
| 189 |
+
continue
|
| 190 |
+
|
| 191 |
+
# Check for required masks
|
| 192 |
+
black_mask_path = output_dir / "black_mask.mp4"
|
| 193 |
+
grey_mask_path = output_dir / "grey_mask.mp4"
|
| 194 |
+
|
| 195 |
+
if not black_mask_path.exists():
|
| 196 |
+
print(f" ⚠️ black_mask.mp4 not found, skipping")
|
| 197 |
+
continue
|
| 198 |
+
|
| 199 |
+
if not grey_mask_path.exists():
|
| 200 |
+
print(f" ⚠️ grey_mask.mp4 not found, skipping")
|
| 201 |
+
continue
|
| 202 |
+
|
| 203 |
+
# Output path
|
| 204 |
+
output_path = output_dir / "quadmask_0.mp4"
|
| 205 |
+
|
| 206 |
+
try:
|
| 207 |
+
process_video(black_mask_path, grey_mask_path, output_path)
|
| 208 |
+
success_count += 1
|
| 209 |
+
print(f"\n✅ Video {i+1} complete!")
|
| 210 |
+
except Exception as e:
|
| 211 |
+
print(f"\n❌ Error processing video {i+1}: {e}")
|
| 212 |
+
import traceback
|
| 213 |
+
traceback.print_exc()
|
| 214 |
+
continue
|
| 215 |
+
|
| 216 |
+
# Summary
|
| 217 |
+
print(f"\n{'='*70}")
|
| 218 |
+
print(f"Stage 4 Complete!")
|
| 219 |
+
print(f"{'='*70}")
|
| 220 |
+
print(f"Successful: {success_count}/{len(videos)}")
|
| 221 |
+
print(f"Failed: {len(videos) - success_count}/{len(videos)}")
|
| 222 |
+
print(f"{'='*70}\n")
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def main():
|
| 226 |
+
parser = argparse.ArgumentParser(
|
| 227 |
+
description="Stage 4: Combine black and grey masks into tri/quad mask"
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--config",
|
| 231 |
+
type=str,
|
| 232 |
+
required=True,
|
| 233 |
+
help="Path to config JSON (with output_dir for each video)"
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
args = parser.parse_args()
|
| 237 |
+
process_config(args.config)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
if __name__ == "__main__":
|
| 241 |
+
main()
|
VLM-MASK-REASONER/test_gemini_video.py
ADDED
|
@@ -0,0 +1,98 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Quick Gemini API smoke test β samples a few frames from a video and asks a
|
| 4 |
+
simple question. Use this to verify your API key works before running the
|
| 5 |
+
full pipeline.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
export GEMINI_API_KEY="your_aistudio_key"
|
| 9 |
+
python test_gemini_video.py --video path/to/video.mp4
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import base64
|
| 15 |
+
import argparse
|
| 16 |
+
import cv2
|
| 17 |
+
import numpy as np
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import openai
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
FREE_TIER_MODEL = "gemini-2.0-flash"
|
| 24 |
+
NUM_FRAMES = 4 # keep low for free tier rate limits
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def sample_frames(video_path: str, n: int = NUM_FRAMES):
|
| 28 |
+
"""Sample n evenly-spaced frames from the video, return as base64 data URLs."""
|
| 29 |
+
cap = cv2.VideoCapture(video_path)
|
| 30 |
+
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 31 |
+
indices = [0] if n <= 1 else [int(i * (total - 1) / (n - 1)) for i in range(n)]
|
| 32 |
+
|
| 33 |
+
data_urls = []
|
| 34 |
+
for idx in indices:
|
| 35 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
| 36 |
+
ret, frame = cap.read()
|
| 37 |
+
if not ret:
|
| 38 |
+
continue
|
| 39 |
+
_, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
|
| 40 |
+
b64 = base64.b64encode(buf).decode("utf-8")
|
| 41 |
+
data_urls.append(f"data:image/jpeg;base64,{b64}")
|
| 42 |
+
|
| 43 |
+
cap.release()
|
| 44 |
+
return data_urls
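The evenly-spaced index computation can be verified on its own; this mirrors the list comprehension above, with a guard for `n = 1` to avoid dividing by zero:

```python
def sample_indices(total: int, n: int):
    """Indices of n evenly-spaced frames in a video with `total` frames."""
    if n <= 1:
        return [0]
    return [int(i * (total - 1) / (n - 1)) for i in range(n)]

print(sample_indices(100, 4))  # → [0, 33, 66, 99]
print(sample_indices(10, 1))   # → [0]
```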
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main():
|
| 48 |
+
parser = argparse.ArgumentParser(description="Gemini API smoke test with video frames")
|
| 49 |
+
parser.add_argument("--video", required=True, help="Path to a video file")
|
| 50 |
+
parser.add_argument("--model", default=FREE_TIER_MODEL, help="Gemini model to use")
|
| 51 |
+
parser.add_argument("--frames", type=int, default=NUM_FRAMES, help="Number of frames to sample")
|
| 52 |
+
args = parser.parse_args()
|
| 53 |
+
|
| 54 |
+
api_key = os.environ.get("GEMINI_API_KEY")
|
| 55 |
+
if not api_key:
|
| 56 |
+
print("ERROR: GEMINI_API_KEY environment variable not set")
|
| 57 |
+
sys.exit(1)
|
| 58 |
+
|
| 59 |
+
video_path = Path(args.video)
|
| 60 |
+
if not video_path.exists():
|
| 61 |
+
print(f"ERROR: Video not found: {video_path}")
|
| 62 |
+
sys.exit(1)
|
| 63 |
+
|
| 64 |
+
print(f"Video: {video_path.name}")
|
| 65 |
+
print(f"Model: {args.model}")
|
| 66 |
+
print(f"Frames: {args.frames}")
|
| 67 |
+
print()
|
| 68 |
+
|
| 69 |
+
print(f"Sampling {args.frames} frames...")
|
| 70 |
+
data_urls = sample_frames(str(video_path), args.frames)
|
| 71 |
+
print(f"Got {len(data_urls)} frames. Sending to Gemini...")
|
| 72 |
+
|
| 73 |
+
client = openai.OpenAI(
|
| 74 |
+
api_key=api_key,
|
| 75 |
+
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
content = [
|
| 79 |
+
{"type": "image_url", "image_url": {"url": url}} for url in data_urls
|
| 80 |
+
]
|
| 81 |
+
content.append({
|
| 82 |
+
"type": "text",
|
| 83 |
+
"text": "These are evenly-spaced frames from a short video. In one sentence, describe what is happening in the video."
|
| 84 |
+
})
|
| 85 |
+
|
| 86 |
+
response = client.chat.completions.create(
|
| 87 |
+
model=args.model,
|
| 88 |
+
messages=[{"role": "user", "content": content}],
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
print("\n--- Gemini response ---")
|
| 92 |
+
print(response.choices[0].message.content)
|
| 93 |
+
print("-----------------------")
|
| 94 |
+
print("\n✅ API key works!")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
|
| 98 |
+
main()
|
app.py
CHANGED
|
@@ -1,18 +1,546 @@
|
| 1 |
import gradio as gr
|
| 2 |
import spaces
|
| 3 |
|
| 4 |
@spaces.GPU
|
| 5 |
-
def
|
| 6 |
-
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
)
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
"""
VOID VLM-Mask-Reasoner — Quadmask Generation Demo
Generates 4-level semantic masks for interaction-aware video inpainting.

Pipeline from https://github.com/Netflix/void-model:
Stage 1: SAM2 segmentation → black mask (transformers Sam2Model)
Stage 2: Gemini VLM scene analysis → affected objects JSON (repo code)
Stage 3: SAM3 text-prompted segmentation → grey mask (transformers Sam3Model)
Stage 4: Combine black + grey → quadmask (0/63/127/255) (repo code)
"""

import os
import sys
import json
import tempfile
import shutil
import subprocess
from pathlib import Path

import cv2
import numpy as np
import torch
import gradio as gr
import spaces
import imageio
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
import openai

# ── Add repo modules to path ──────────────────────────────────────────────────
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "VLM-MASK-REASONER"))

# ── Repo imports: Stage 2 (VLM) and Stage 4 (combine) ────────────────────────
from stage2_vlm_analysis import (
    process_video as vlm_process_video,
    calculate_square_grid,
)
from stage4_combine_masks import process_video as combine_process_video
# Stage 3 helpers (grid logic, mask combination — not the SegmentationModel)
from stage3a_generate_grey_masks_v2 import (
    calculate_square_grid as calc_grid_3a,
    gridify_masks,
    filter_masks_by_proximity,
    segment_object_all_frames as _repo_segment_all_frames,
    process_video_grey_masks,
)

# ── Constants ─────────────────────────────────────────────────────────────────
SAM2_MODEL_ID = "facebook/sam2.1-hiera-large"
SAM3_MODEL_ID = "jetjodh/sam3"
DEFAULT_VLM_MODEL = "gemini-3-flash-preview"
MAX_FRAMES = 197
FPS_DEFAULT = 12
FRAME_STRIDE = 4  # Process every Nth frame for SAM2 tracking

# ── Load transformers SAM2 (video model with propagation support) ─────────────
print("Loading SAM2 video model (transformers)...")
from transformers import Sam2VideoModel, Sam2VideoProcessor
from transformers.models.sam2_video.modeling_sam2_video import Sam2VideoInferenceSession
sam2_model = Sam2VideoModel.from_pretrained(SAM2_MODEL_ID).to("cuda")
sam2_processor = Sam2VideoProcessor.from_pretrained(SAM2_MODEL_ID)
print("SAM2 video model ready.")

# ── Load transformers SAM3 ───────────────────────────────────────────────────
print("Loading SAM3 model (transformers)...")
from transformers import Sam3Model, Sam3Processor
sam3_model = Sam3Model.from_pretrained(SAM3_MODEL_ID).to("cuda")
sam3_processor = Sam3Processor.from_pretrained(SAM3_MODEL_ID)
print("SAM3 ready.")


# ──────────────────────────────────────────────────────────────────────────────
# STAGE 1: SAM2 VIDEO SEGMENTATION (transformers Sam2VideoModel)
# Uses proper video propagation with memory — matches repo's propagate_in_video
# ──────────────────────────────────────────────────────────────────────────────

def stage1_segment_video(frames: list, points: list, **kwargs) -> list:
    """Segment primary object across all video frames using SAM2 video propagation.
    Matches repo: point prompts + bounding box on frame 0, propagate through video.
    Returns list of uint8 masks (0=object, 255=background)."""
    total = len(frames)
    h, w = frames[0].shape[:2]

    # Preprocess all frames
    pil_frames = [Image.fromarray(f) for f in frames]
    inputs = sam2_processor(images=pil_frames, return_tensors="pt").to(sam2_model.device)

    # Create inference session with all frames
    session = Sam2VideoInferenceSession(
        video=inputs["pixel_values"],
        video_height=h,
        video_width=w,
        inference_device=sam2_model.device,
        inference_state_device=sam2_model.device,
        dtype=torch.float32,
    )

    # Add point prompts + bounding box on frame 0 via processor
    # (handles normalization, object registration, and obj_with_new_inputs)
    pts = np.array(points, dtype=np.float32)
    x_min, x_max = pts[:, 0].min(), pts[:, 0].max()
    y_min, y_max = pts[:, 1].min(), pts[:, 1].max()
    x_margin = max((x_max - x_min) * 0.1, 10)
    y_margin = max((y_max - y_min) * 0.1, 10)
    box = [
        max(0, x_min - x_margin),
        max(0, y_min - y_margin),
        min(w, x_max + x_margin),
        min(h, y_max + y_margin),
    ]

    sam2_processor.process_new_points_or_boxes_for_video_frame(
        inference_session=session,
        frame_idx=0,
        obj_ids=[1],
        input_points=[[[[float(p[0]), float(p[1])] for p in points]]],
        input_labels=[[[1] * len(points)]],
        input_boxes=[[[float(box[0]), float(box[1]), float(box[2]), float(box[3])]]],
    )

    # Run forward on the prompted frame first (populates cond_frame_outputs)
    with torch.no_grad():
        sam2_model(session, frame_idx=0)

    # Propagate through all frames (matches repo's propagate_in_video)
    video_segments = {}
    original_sizes = [[h, w]]
    with torch.no_grad():
        for output in sam2_model.propagate_in_video_iterator(session):
            frame_idx = output.frame_idx
            # pred_masks shape varies — get the raw logits and resize to original
            mask_logits = output.pred_masks[0].cpu().float()  # first object
            # Ensure 4D for interpolation: (1, 1, H_model, W_model)
            while mask_logits.dim() < 4:
                mask_logits = mask_logits.unsqueeze(0)
            mask_resized = torch.nn.functional.interpolate(
                mask_logits, size=(h, w), mode="bilinear", align_corners=False
            )
            mask = (mask_resized.squeeze() > 0.0).numpy()
            video_segments[frame_idx] = mask

    # Convert to uint8 masks (0=object, 255=background)
    all_masks = []
    for idx in range(total):
        if idx in video_segments:
            mask_bool = video_segments[idx]
        else:
            nearest = min(video_segments.keys(), key=lambda k: abs(k - idx))
            mask_bool = video_segments[nearest]
        mask_uint8 = np.where(mask_bool, 0, 255).astype(np.uint8)
        all_masks.append(mask_uint8)

    return all_masks


def write_mask_video(masks: list, fps: float, output_path: str):
    """Write list of uint8 grayscale masks to lossless MP4."""
    h, w = masks[0].shape[:2]
    temp_avi = str(Path(output_path).with_suffix('.avi'))
    fourcc = cv2.VideoWriter_fourcc(*'FFV1')
    out = cv2.VideoWriter(temp_avi, fourcc, fps, (w, h), isColor=False)
    for mask in masks:
        out.write(mask)
    out.release()

    cmd = [
        'ffmpeg', '-y', '-i', temp_avi,
        '-c:v', 'libx264', '-qp', '0', '-preset', 'ultrafast',
        '-pix_fmt', 'yuv444p', str(output_path),
    ]
    subprocess.run(cmd, capture_output=True)
    if os.path.exists(temp_avi):
        os.unlink(temp_avi)


# ──────────────────────────────────────────────────────────────────────────────
# STAGE 3: SAM3 TEXT-PROMPTED SEGMENTATION (transformers)
# → Drop-in replacement for repo's SegmentationModel.segment()
# ──────────────────────────────────────────────────────────────────────────────

class TransformersSam3Segmenter:
    """Matches the interface of the repo's SegmentationModel for stage3a."""
    model_type = "sam3"

    def segment(self, image_pil: Image.Image, prompt: str) -> np.ndarray:
        """Segment object by text prompt. Returns boolean mask."""
        h, w = image_pil.height, image_pil.width
        union = np.zeros((h, w), dtype=bool)

        try:
            inputs = sam3_processor(
                images=image_pil, text=prompt, return_tensors="pt"
            ).to(sam3_model.device)

            with torch.no_grad():
                outputs = sam3_model(**inputs)

            results = sam3_processor.post_process_instance_segmentation(
                outputs,
                threshold=0.3,
                mask_threshold=0.5,
                target_sizes=inputs.get("original_sizes").tolist(),
            )[0]

            masks = results.get("masks")
            if masks is not None and len(masks) > 0:
                if torch.is_tensor(masks):
                    masks = masks.cpu().numpy()
                if masks.ndim == 2:
                    union = masks.astype(bool)
                elif masks.ndim == 3:
                    union = masks.any(axis=0).astype(bool)
                elif masks.ndim == 4:
                    union = masks.any(axis=(0, 1)).astype(bool)
        except Exception as e:
            print(f"  Warning: SAM3 segmentation failed for '{prompt}': {e}")

        return union


seg_model = TransformersSam3Segmenter()


# ──────────────────────────────────────────────────────────────────────────────
# HELPERS
# ──────────────────────────────────────────────────────────────────────────────

def extract_frames(video_path: str, max_frames: int = MAX_FRAMES):
    """Extract frames from video. Returns (frames_rgb_list, fps)."""
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT
    frames = []
    while len(frames) < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    return frames, fps


def draw_points_on_image(image: np.ndarray, points: list, radius: int = 6) -> np.ndarray:
    pil_img = Image.fromarray(image.copy())
    draw = ImageDraw.Draw(pil_img)
    for i, (x, y) in enumerate(points):
        r = radius
        draw.ellipse([x - r, y - r, x + r, y + r], fill="red", outline="white", width=2)
        draw.text((x + r + 2, y - r), str(i + 1), fill="white")
    return np.array(pil_img)


def frames_to_video(frames: list, fps: float) -> str:
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp_path = tmp.name
    tmp.close()
    writer = imageio.get_writer(tmp_path, fps=fps, codec='libx264',
                                output_params=['-crf', '18', '-pix_fmt', 'yuv420p'])
    for frame in frames:
        writer.append_data(frame)
    writer.close()
    return tmp_path


def create_quadmask_visualization(video_path: str, quadmask_path: str) -> str:
    cap_vid = cv2.VideoCapture(video_path)
    cap_qm = cv2.VideoCapture(quadmask_path)
    fps = cap_vid.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT

    vis_frames = []
    while True:
        ret_v, frame = cap_vid.read()
        ret_q, qm_frame = cap_qm.read()
        if not ret_v or not ret_q:
            break
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        qm = cv2.cvtColor(qm_frame, cv2.COLOR_BGR2GRAY) if len(qm_frame.shape) == 3 else qm_frame

        qm = np.where(qm <= 31, 0, qm)
        qm = np.where((qm > 31) & (qm <= 95), 63, qm)
        qm = np.where((qm > 95) & (qm <= 191), 127, qm)
        qm = np.where(qm > 191, 255, qm)

        overlay = frame_rgb.copy()
        overlay[qm == 0] = [255, 50, 50]
        overlay[qm == 63] = [255, 200, 0]
        overlay[qm == 127] = [50, 255, 50]
        result = cv2.addWeighted(frame_rgb, 0.5, overlay, 0.5, 0)
        result[qm == 255] = frame_rgb[qm == 255]
        vis_frames.append(result)

    cap_vid.release()
    cap_qm.release()
    return frames_to_video(vis_frames, fps) if vis_frames else None


def create_quadmask_color_video(quadmask_path: str) -> str:
    cap = cv2.VideoCapture(quadmask_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or FPS_DEFAULT
    color_frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        qm = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) if len(frame.shape) == 3 else frame
        qm = np.where(qm <= 31, 0, qm)
        qm = np.where((qm > 31) & (qm <= 95), 63, qm)
        qm = np.where((qm > 95) & (qm <= 191), 127, qm)
        qm = np.where(qm > 191, 255, qm)
        h, w = qm.shape
        color = np.full((h, w, 3), 255, dtype=np.uint8)
        color[qm == 0] = [0, 0, 0]
        color[qm == 63] = [80, 80, 80]
        color[qm == 127] = [160, 160, 160]
        color_frames.append(color)
    cap.release()
    return frames_to_video(color_frames, fps) if color_frames else None


# ──────────────────────────────────────────────────────────────────────────────
# MAIN PIPELINE
# ──────────────────────────────────────────────────────────────────────────────

@spaces.GPU
def run_pipeline(video_path: str, points_json: str, instruction: str,
                 progress=gr.Progress(track_tqdm=False)):
    """Run the full VLM-Mask-Reasoner pipeline."""
    if not video_path:
        raise gr.Error("Please upload a video.")
    if not points_json or points_json == "[]":
        raise gr.Error("Please click on the image to select at least one point on the primary object.")
    if not instruction.strip():
        raise gr.Error("Please enter an edit instruction.")

    points = json.loads(points_json)
    if len(points) == 0:
        raise gr.Error("Please select at least one point on the primary object.")

    api_key = os.environ.get("GEMINI_API_KEY", "")

    # Create temp output directory
    output_dir = Path(tempfile.mkdtemp(prefix="void_quadmask_"))
    input_video_path = output_dir / "input_video.mp4"
    shutil.copy2(video_path, input_video_path)

    # ── Stage 1: SAM2 Segmentation ──────────────────────────────────────────
    progress(0.05, desc="Stage 1: SAM2 segmentation...")
    frames, fps = extract_frames(str(input_video_path))
    if len(frames) < 2:
        raise gr.Error("Video must have at least 2 frames.")

    black_masks = stage1_segment_video(frames, points, stride=FRAME_STRIDE)
    black_mask_path = output_dir / "black_mask.mp4"
    write_mask_video(black_masks, fps, str(black_mask_path))

    # Save first frame for VLM analysis
    first_frame_path = output_dir / "first_frame.jpg"
    cv2.imwrite(str(first_frame_path), cv2.cvtColor(frames[0], cv2.COLOR_RGB2BGR))

    # Save segmentation metadata (Stage 2 expects this)
    seg_info = {
        "total_frames": len(frames),
        "frame_width": frames[0].shape[1],
        "frame_height": frames[0].shape[0],
        "fps": fps,
        "video_path": str(input_video_path),
        "instruction": instruction,
        "primary_points_by_frame": {"0": points},
        "first_appears_frame": 0,
    }
    with open(output_dir / "segmentation_info.json", 'w') as f:
        json.dump(seg_info, f, indent=2)

    progress(0.3, desc="Stage 1 complete.")

    # ── Stage 2: VLM Analysis (repo code) ───────────────────────────────────
    analysis = None
    if api_key:
        progress(0.35, desc="Stage 2: VLM analysis (calling Gemini)...")
        try:
            video_info = {
                "video_path": str(input_video_path),
                "instruction": instruction,
                "output_dir": str(output_dir),
                "multi_frame_grids": True,
            }
            client = openai.OpenAI(
                api_key=api_key,
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
            )
            analysis = vlm_process_video(video_info, client, DEFAULT_VLM_MODEL)
            progress(0.55, desc="Stage 2 complete.")
        except Exception as e:
            gr.Warning(f"VLM analysis failed: {e}. Generating binary mask only.")
            analysis = None
    else:
        gr.Warning("No GEMINI_API_KEY set. Generating binary mask only (no VLM analysis).")

    # ── Stage 3: Grey Mask Generation (repo logic + transformers SAM3) ──────
    grey_mask_path = output_dir / "grey_mask.mp4"
    vlm_analysis_path = output_dir / "vlm_analysis.json"

    if analysis and vlm_analysis_path.exists():
        progress(0.6, desc="Stage 3: Generating grey masks (SAM3 segmentation)...")
        try:
            video_info_3 = {
                "video_path": str(input_video_path),
                "output_dir": str(output_dir),
                "min_grid": 8,
            }
            # Uses the repo's process_video_grey_masks with our TransformersSam3Segmenter
            process_video_grey_masks(video_info_3, seg_model)
            progress(0.8, desc="Stage 3 complete.")
        except Exception as e:
            gr.Warning(f"Stage 3 failed: {e}. Generating binary mask only.")

    # ── Stage 4: Combine into Quadmask (repo code) ─────────────────────────
    quadmask_path = output_dir / "quadmask_0.mp4"
    if grey_mask_path.exists():
        progress(0.85, desc="Stage 4: Combining into quadmask...")
        combine_process_video(black_mask_path, grey_mask_path, quadmask_path)
    else:
        shutil.copy2(black_mask_path, quadmask_path)

    progress(0.9, desc="Creating visualizations...")

    # ── Visualization outputs ───────────────────────────────────────────────
    overlay_path = create_quadmask_visualization(str(input_video_path), str(quadmask_path))
    color_path = create_quadmask_color_video(str(quadmask_path))

    analysis_text = ""
    if vlm_analysis_path.exists():
        with open(vlm_analysis_path) as f:
            analysis_text = f.read()
    else:
        analysis_text = "No VLM analysis available."

    progress(1.0, desc="Done!")
    return str(quadmask_path), overlay_path, color_path, analysis_text


# ──────────────────────────────────────────────────────────────────────────────
# GRADIO UI
# ──────────────────────────────────────────────────────────────────────────────

def on_video_upload(video_path):
    if not video_path:
        return None, None, "[]", gr.update(interactive=False)
    frames, _ = extract_frames(video_path, max_frames=1)
    if not frames:
        return None, None, "[]", gr.update(interactive=False)
    return frames[0], frames[0], "[]", gr.update(interactive=True)


def on_frame_select(clean_frame, points_json, evt: gr.SelectData):
    if clean_frame is None:
        return None, points_json
    points = json.loads(points_json) if points_json else []
    x, y = evt.index
    points.append([int(x), int(y)])
    annotated = draw_points_on_image(clean_frame, points)
    return annotated, json.dumps(points)


def on_clear_points(clean_frame):
    if clean_frame is not None:
        return clean_frame, "[]"
    return None, "[]"


DESCRIPTION = """
# VOID VLM-Mask-Reasoner — Quadmask Generation

Generate **4-level semantic masks** (quadmasks) for interaction-aware video inpainting with [VOID](https://github.com/Netflix/void-model).

**Pipeline:** Click points on object → SAM2 segments it → Gemini VLM reasons about interactions → SAM3 segments affected objects → Quadmask generated

Use the generated quadmask with the [VOID inpainting demo](https://huggingface.co/spaces/sam-motamed/VOID).
"""

QUADMASK_EXPLAINER = """
### Quadmask format

| Pixel Value | Color | Meaning |
|-------------|-------|---------|
| **0** (black) | Red overlay | Primary object to remove |
| **63** (dark grey) | Yellow overlay | Overlap of primary + affected zone |
| **127** (mid grey) | Green overlay | Affected region (shadows, reflections, physics) |
| **255** (white) | Original | Background — keep as-is |
"""

with gr.Blocks(title="VOID VLM-Mask-Reasoner", theme=gr.themes.Default()) as demo:
    gr.Markdown(DESCRIPTION)

    points_state = gr.State("[]")
    clean_frame_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.Video(label="Upload Video", sources=["upload"])
            frame_display = gr.Image(
                label="Click to select primary object points (click multiple spots on the object)",
                interactive=True, type="numpy",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear Points", size="sm")
                points_display = gr.Textbox(label="Selected Points", value="[]",
                                            interactive=False, max_lines=2)
            instruction_input = gr.Textbox(
                label="Edit instruction — describe what to remove",
                placeholder="e.g., remove the person", lines=1,
            )
            generate_btn = gr.Button("Generate Quadmask", variant="primary", size="lg",
                                     interactive=False)

        with gr.Column(scale=1):
            output_quadmask_file = gr.File(label="Download lossless quadmask_0.mp4 (use this with VOID)")
            with gr.Tabs():
                with gr.TabItem("Quadmask Overlay"):
                    output_overlay = gr.Video(label="Quadmask overlay on original video")
                with gr.TabItem("Raw Quadmask"):
                    output_color = gr.Video(label="Color-coded quadmask")
                with gr.TabItem("VLM Analysis"):
                    output_analysis = gr.Code(label="VLM Analysis JSON", language="json")

    video_input.change(
        fn=on_video_upload, inputs=[video_input],
        outputs=[frame_display, clean_frame_state, points_state, generate_btn],
    )
    points_state.change(lambda p: p, inputs=points_state, outputs=points_display)
    frame_display.select(
        fn=on_frame_select, inputs=[clean_frame_state, points_state],
        outputs=[frame_display, points_state],
    )
    clear_btn.click(
        fn=on_clear_points, inputs=[clean_frame_state],
        outputs=[frame_display, points_state],
    )
    generate_btn.click(
        fn=run_pipeline, inputs=[video_input, points_state, instruction_input],
        outputs=[output_quadmask_file, output_overlay, output_color, output_analysis],
    )

    gr.Markdown(QUADMASK_EXPLAINER)

if __name__ == "__main__":
    demo.launch()
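The visualization helpers above snap decoded pixel values back to the four canonical quadmask levels with the 31/95/191 thresholds, so the overlay stays correct even if a quadmask was re-encoded lossily somewhere along the way. That quantization can be checked in isolation; a minimal sketch of the same logic (NumPy only; the function name is ours, not from the repo):

```python
import numpy as np

def quantize_quadmask(qm: np.ndarray) -> np.ndarray:
    """Snap grey values that drifted under compression back to 0/63/127/255."""
    qm = np.where(qm <= 31, 0, qm)
    qm = np.where((qm > 31) & (qm <= 95), 63, qm)
    qm = np.where((qm > 95) & (qm <= 191), 127, qm)
    return np.where(qm > 191, 255, qm).astype(np.uint8)

# Values near each canonical level all collapse onto that level.
noisy = np.array([3, 60, 58, 130, 125, 250], dtype=np.uint8)
print(quantize_quadmask(noisy))  # each value snaps to 0, 63, 127, or 255
```

Note the bands are deliberately wide (midpoints between adjacent levels), so any single-level drift from chroma subsampling or mild compression is absorbed without ambiguity.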
requirements.txt
ADDED

torchvision==0.24
transformers>=4.50.0
accelerate
gradio
numpy<2.0
opencv-python-headless
Pillow
imageio
imageio-ffmpeg
openai
huggingface_hub
spaces
tqdm